diff --git a/.bazelrc b/.bazelrc index a6d7a13f0..1246d336b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -10,11 +10,20 @@ build:linux --copt=-Wno-deprecated-declarations # you will typically need to spell out the compiler for local dev # BAZEL_VC= # BAZEL_VC_FULL_VERSION=14.44.3520 +# Some dependencies rely on bash so you will likely need msys2 +# BAZEL_SH=C:\msys64\usr\bin\bash.exe build:msvc --cxxopt="-std:c++20" --cxxopt="-utf-8" --host_cxxopt="-std:c++20" build:msvc --define=protobuf_allow_msvc=true build:msvc --test_tag_filters=-benchmark,-notap,-no_test_msvc build:msvc --build_tag_filters=-no_test_msvc +build:macos --cxxopt=-faligned-allocation +build:macos --cxxopt=-mmacosx-version-min=10.13 +build:macos --linkopt=-mmacosx-version-min=10.13 + +# ANTLR tool requires Java 17+. +build --java_runtime_version=remotejdk_17 + test --test_output=errors # Enable matchers in googletest diff --git a/.bazelversion b/.bazelversion index eab246c06..df5119ec6 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -7.3.2 +8.7.0 diff --git a/.github/workflows/windows_bazel_test.yml b/.github/workflows/windows_bazel_test.yml new file mode 100644 index 000000000..6d12e6861 --- /dev/null +++ b/.github/workflows/windows_bazel_test.yml @@ -0,0 +1,28 @@ +name: Windows Bazel Test + +on: + workflow_call: + workflow_dispatch: + +jobs: + test: + name: Run Bazel Tests + runs-on: windows-latest + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Setup Bazel and Bazelisk + uses: bazel-contrib/setup-bazel@0.19.0 + with: + bazelisk-cache: true + disk-cache: ${{ github.workflow }} + repository-cache: true + + - name: Run Tests + # msys2 'bash' on Windows will try to 'fix' the label prefix to + # work as a directory. + # //... won't work. + shell: bash + run: | + bazelisk test --config=msvc conformance:all conformance/policy:all \ No newline at end of file diff --git a/.github/workflows/windows_bazel_test_post_merge.yml b/.github/workflows/windows_bazel_test_post_merge.yml new file mode 100644 index 000000000..569177fcc --- /dev/null +++ b/.github/workflows/windows_bazel_test_post_merge.yml @@ -0,0 +1,13 @@ +name: Windows Bazel Test (Post-Merge) + +on: + push: + branches: + - master + +jobs: + trigger-test: + # This prevents the workflow from running automatically when someone + # pushes to their fork. + if: github.repository == 'cel-expr/cel-cpp' + uses: ./.github/workflows/windows_bazel_test.yml \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index c2c2915be..97611fc75 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,25 +12,25 @@ # # Run the following command from the root of the CEL repository: # -# gcloud builds submit --region=us -t gcr.io/cel-analysis/gcc9 . +# gcloud builds submit --region=us -t gcr.io/cel-analysis/cel-cpp/ubuntu_floor . # # Once complete get the sha256 digest from the output using the following # command: # -# gcloud artifacts versions list --package=gcc9 --repository=gcr.io \ +# gcloud artifacts versions list --package=cel-cpp/ubuntu_floor --repository=gcr.io \ # --location=us # # The cloudbuild.yaml file must be updated to use the new digest like so: # -# - name: 'gcr.io/cel-analysis/gcc9@' -FROM gcc:9 +# - name: 'gcr.io/cel-analysis/cel-cpp/ubuntu_floor@' +FROM gcr.io/cloud-marketplace/google/ubuntu2204:latest # Install Bazel prerequesites and required tools. # See https://docs.bazel.build/versions/master/install-ubuntu.html -RUN apt-get update && \ - apt-get upgrade -y && \ - apt-get install -y --no-install-recommends \ - ca-certificates \ +RUN apt-get update && apt-get upgrade -y && \ + DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + bash \ + ca-certificates \ git \ libssl-dev \ make \ @@ -41,16 +41,29 @@ RUN apt-get update && \ zip \ zlib1g-dev \ default-jdk-headless \ - clang-11 && \ - apt-get clean + clang-11 \ + gcc-9 g++-9 \ + tzdata \ + && apt-get clean -# Install Bazel. -# https://github.com/bazelbuild/bazel/releases -ARG BAZEL_VERSION="7.3.2" -ADD https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh /tmp/install_bazel.sh -RUN /bin/bash /tmp/install_bazel.sh && rm /tmp/install_bazel.sh +# Install Bazelisk. +# https://github.com/bazelbuild/bazelisk/releases +ARG BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.27.0/bazelisk-amd64.deb" +ARG BAZELISK_CHKSUM="d8b00ea975c823e15263c80200ac42979e17368547fbff4ab177af035badfa83" +ADD ${BAZELISK_URL} /tmp/bazelisk.deb + +ENV BAZELISK_CHKSUM=${BAZELISK_CHKSUM} +RUN echo "${BAZELISK_CHKSUM} */tmp/bazelisk.deb" | sha256sum --check + +RUN apt-get install /tmp/bazelisk.deb RUN mkdir -p /workspace RUN mkdir -p /bazel -ENTRYPOINT ["/usr/local/bin/bazel"] +RUN USE_BAZEL_VERSION=8.7.0 bazelisk help +RUN USE_BAZEL_VERSION=7.3.2 bazelisk help + +ENV CC=gcc-9 +ENV CXX=g++-9 + +ENTRYPOINT ["/usr/bin/bazelisk"] diff --git a/MODULE.bazel b/MODULE.bazel index fbe9b41fc..187d68164 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -31,9 +31,10 @@ bazel_dep( name = "rules_python", version = "1.6.3", ) +bazel_dep(name = "rules_license", version = "1.0.0") bazel_dep( name = "protobuf", - version = "33.4", + version = "34.1", repo_name = "com_google_protobuf", ) bazel_dep( @@ -41,20 +42,16 @@ bazel_dep( version = "20260107.0", repo_name = "com_google_absl", ) - bazel_dep( name = "googletest", version = "1.17.0.bcr.2", - dev_dependency = True, repo_name = "com_google_googletest", ) bazel_dep( name = "google_benchmark", version = "1.9.2", - dev_dependency = True, repo_name = "com_github_google_benchmark", ) - bazel_dep( name = "re2", version = "2025-11-05.bcr.1", @@ -74,16 +71,9 @@ bazel_dep( name = "platforms", version = "1.0.0", ) - -ANTLR4_VERSION = "4.13.2" - bazel_dep( name = "antlr4-cpp-runtime", - version = ANTLR4_VERSION, -) -single_version_override( - module_name = "antlr4-cpp-runtime", - patches = ["//bazel:antlr.patch"], + version = "4.13.2.bcr.2", ) python = use_extension("@rules_python//python/extensions:python.bzl", "python") @@ -95,8 +85,28 @@ python.toolchain( http_jar = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_jar") +ANTLR4_VERSION = "4.13.2" + http_jar( name = "antlr4_jar", sha256 = "eae2dfa119a64327444672aff63e9ec35a20180dc5b8090b7a6ab85125df4d76", urls = ["https://www.antlr.org/download/antlr-" + ANTLR4_VERSION + "-complete.jar"], ) + +bazel_dep( + name = "yaml-cpp", + version = "0.9.0", +) + +_CEL_POLICY_TAG = "ebfb2361f47080af643c14cf4da4c2b551a68740" + +_CEL_POLICY_SHA = "ea69e9c6b7bd5bc37d358148aebd2fcca38bc7c45a23feb635de72338e0327c1" + +http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "cel_policy", + sha256 = _CEL_POLICY_SHA, + strip_prefix = "cel-policy-%s" % _CEL_POLICY_TAG, + url = "https://github.com/cel-expr/cel-policy/archive/%s.tar.gz" % _CEL_POLICY_TAG, +) diff --git a/README.md b/README.md index 23afe2b00..7c3c26be0 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,12 @@ # C++ Implementations of the Common Expression Language +> [!WARNING] +> **On June 16, 2026, this repository will move to +> github.com/cel-expr/cel-cpp!** +> +> Please update your links and dependencies. See the [pinned +> issue](https://github.com/google/cel-cpp/issues/2029) for details. + For background on the Common Expression Language see the [cel-spec][1] repo. This is a C++ implementation of a [Common Expression Language][1] runtime, @@ -8,4 +15,4 @@ parser, and type checker. Released under the [Apache License](LICENSE). -[1]: https://github.com/google/cel-spec +[1]: https://github.com/cel-expr/cel-spec diff --git a/base/operators.cc b/base/operators.cc index 805acc5a1..b7df40b27 100644 --- a/base/operators.cc +++ b/base/operators.cc @@ -179,13 +179,13 @@ CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_TERNARY_OPERATOR) absl::optional Operator::FindByName(absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(operators_by_name.cbegin(), operators_by_name.cend(), input, OperatorDataNameComparer{}); if (it == operators_by_name.cend() || (*it)->name != input) { - return absl::nullopt; + return std::nullopt; } return Operator(*it); } @@ -193,13 +193,13 @@ absl::optional Operator::FindByName(absl::string_view input) { absl::optional Operator::FindByDisplayName(absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(operators_by_display_name.cbegin(), operators_by_display_name.cend(), input, OperatorDataDisplayNameComparer{}); if (it == operators_by_name.cend() || (*it)->display_name != input) { - return absl::nullopt; + return std::nullopt; } return Operator(*it); } @@ -208,13 +208,13 @@ absl::optional UnaryOperator::FindByName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(unary_operators_by_name.cbegin(), unary_operators_by_name.cend(), input, OperatorDataNameComparer{}); if (it == unary_operators_by_name.cend() || (*it)->name != input) { - return absl::nullopt; + return std::nullopt; } return UnaryOperator(*it); } @@ -223,14 +223,14 @@ absl::optional UnaryOperator::FindByDisplayName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(unary_operators_by_display_name.cbegin(), unary_operators_by_display_name.cend(), input, OperatorDataDisplayNameComparer{}); if (it == unary_operators_by_display_name.cend() || (*it)->display_name != input) { - return absl::nullopt; + return std::nullopt; } return UnaryOperator(*it); } @@ -239,13 +239,13 @@ absl::optional BinaryOperator::FindByName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(binary_operators_by_name.cbegin(), binary_operators_by_name.cend(), input, OperatorDataNameComparer{}); if (it == binary_operators_by_name.cend() || (*it)->name != input) { - return absl::nullopt; + return std::nullopt; } return BinaryOperator(*it); } @@ -254,14 +254,14 @@ absl::optional BinaryOperator::FindByDisplayName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(binary_operators_by_display_name.cbegin(), binary_operators_by_display_name.cend(), input, OperatorDataDisplayNameComparer{}); if (it == binary_operators_by_display_name.cend() || (*it)->display_name != input) { - return absl::nullopt; + return std::nullopt; } return BinaryOperator(*it); } @@ -270,13 +270,13 @@ absl::optional TernaryOperator::FindByName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(ternary_operators_by_name.cbegin(), ternary_operators_by_name.cend(), input, OperatorDataNameComparer{}); if (it == ternary_operators_by_name.cend() || (*it)->name != input) { - return absl::nullopt; + return std::nullopt; } return TernaryOperator(*it); } @@ -285,14 +285,14 @@ absl::optional TernaryOperator::FindByDisplayName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(ternary_operators_by_display_name.cbegin(), ternary_operators_by_display_name.cend(), input, OperatorDataDisplayNameComparer{}); if (it == ternary_operators_by_display_name.cend() || (*it)->display_name != input) { - return absl::nullopt; + return std::nullopt; } return TernaryOperator(*it); } diff --git a/base/operators_test.cc b/base/operators_test.cc index fdf95e7ae..6049f76c8 100644 --- a/base/operators_test.cc +++ b/base/operators_test.cc @@ -130,55 +130,55 @@ CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_TERNARY_OPERATOR) TEST(Operator, FindByName) { EXPECT_THAT(Operator::FindByName("@in"), Optional(Eq(Operator::In()))); EXPECT_THAT(Operator::FindByName("_in_"), Optional(Eq(Operator::OldIn()))); - EXPECT_THAT(Operator::FindByName("in"), Eq(absl::nullopt)); - EXPECT_THAT(Operator::FindByName(""), Eq(absl::nullopt)); + EXPECT_THAT(Operator::FindByName("in"), Eq(std::nullopt)); + EXPECT_THAT(Operator::FindByName(""), Eq(std::nullopt)); } TEST(Operator, FindByDisplayName) { EXPECT_THAT(Operator::FindByDisplayName("-"), Optional(Eq(Operator::Subtract()))); - EXPECT_THAT(Operator::FindByDisplayName("@in"), Eq(absl::nullopt)); - EXPECT_THAT(Operator::FindByDisplayName(""), Eq(absl::nullopt)); + EXPECT_THAT(Operator::FindByDisplayName("@in"), Eq(std::nullopt)); + EXPECT_THAT(Operator::FindByDisplayName(""), Eq(std::nullopt)); } TEST(UnaryOperator, FindByName) { EXPECT_THAT(UnaryOperator::FindByName("-_"), Optional(Eq(Operator::Negate()))); - EXPECT_THAT(UnaryOperator::FindByName("_-_"), Eq(absl::nullopt)); - EXPECT_THAT(UnaryOperator::FindByName(""), Eq(absl::nullopt)); + EXPECT_THAT(UnaryOperator::FindByName("_-_"), Eq(std::nullopt)); + EXPECT_THAT(UnaryOperator::FindByName(""), Eq(std::nullopt)); } TEST(UnaryOperator, FindByDisplayName) { EXPECT_THAT(UnaryOperator::FindByDisplayName("-"), Optional(Eq(Operator::Negate()))); - EXPECT_THAT(UnaryOperator::FindByDisplayName("&&"), Eq(absl::nullopt)); - EXPECT_THAT(UnaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); + EXPECT_THAT(UnaryOperator::FindByDisplayName("&&"), Eq(std::nullopt)); + EXPECT_THAT(UnaryOperator::FindByDisplayName(""), Eq(std::nullopt)); } TEST(BinaryOperator, FindByName) { EXPECT_THAT(BinaryOperator::FindByName("_-_"), Optional(Eq(Operator::Subtract()))); - EXPECT_THAT(BinaryOperator::FindByName("-_"), Eq(absl::nullopt)); - EXPECT_THAT(BinaryOperator::FindByName(""), Eq(absl::nullopt)); + EXPECT_THAT(BinaryOperator::FindByName("-_"), Eq(std::nullopt)); + EXPECT_THAT(BinaryOperator::FindByName(""), Eq(std::nullopt)); } TEST(BinaryOperator, FindByDisplayName) { EXPECT_THAT(BinaryOperator::FindByDisplayName("-"), Optional(Eq(Operator::Subtract()))); - EXPECT_THAT(BinaryOperator::FindByDisplayName("!"), Eq(absl::nullopt)); - EXPECT_THAT(BinaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); + EXPECT_THAT(BinaryOperator::FindByDisplayName("!"), Eq(std::nullopt)); + EXPECT_THAT(BinaryOperator::FindByDisplayName(""), Eq(std::nullopt)); } TEST(TernaryOperator, FindByName) { EXPECT_THAT(TernaryOperator::FindByName("_?_:_"), Optional(Eq(TernaryOperator::Conditional()))); - EXPECT_THAT(TernaryOperator::FindByName("-_"), Eq(absl::nullopt)); - EXPECT_THAT(TernaryOperator::FindByName(""), Eq(absl::nullopt)); + EXPECT_THAT(TernaryOperator::FindByName("-_"), Eq(std::nullopt)); + EXPECT_THAT(TernaryOperator::FindByName(""), Eq(std::nullopt)); } TEST(TernaryOperator, FindByDisplayName) { - EXPECT_THAT(TernaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); - EXPECT_THAT(TernaryOperator::FindByDisplayName("!"), Eq(absl::nullopt)); + EXPECT_THAT(TernaryOperator::FindByDisplayName(""), Eq(std::nullopt)); + EXPECT_THAT(TernaryOperator::FindByDisplayName("!"), Eq(std::nullopt)); } TEST(Operator, SupportsAbslHash) { diff --git a/bazel/antlr.bzl b/bazel/antlr.bzl index 2abbb6dbd..a4d28cdf8 100644 --- a/bazel/antlr.bzl +++ b/bazel/antlr.bzl @@ -55,6 +55,7 @@ def antlr_cc_library(name, src, package): generated, "@antlr4-cpp-runtime//:antlr4-cpp-runtime", ], + copts = ["-fexceptions"], linkstatic = 1, ) diff --git a/bazel/antlr.patch b/bazel/antlr.patch deleted file mode 100644 index c1aa9080c..000000000 --- a/bazel/antlr.patch +++ /dev/null @@ -1,30 +0,0 @@ ---- BUILD.bazel -+++ BUILD.bazel -@@ -17,21 +17,21 @@ - cc_library( - name = "antlr4-cpp-runtime", - srcs = glob(["runtime/src/**/*.cpp"]), - hdrs = ["runtime/src/antlr4-runtime.h"], - copts = ["-fexceptions"], -- defines = ["ANTLR4CPP_USING_ABSEIL"], -+ defines = ["ANTLR4CPP_USING_ABSEIL", "ANTLR4CPP_STATIC"], - features = ["-use_header_modules"], - includes = ["runtime/src"], - textual_hdrs = glob( - ["runtime/src/**/*.h"], - exclude = ["runtime/src/antlr4-runtime.h"], - ), - visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/synchronization", - ], - ) - ---- VERSION -+++ /dev/null -@@ -1,1 +1,0 @@ --4.13.2 \ No newline at end of file diff --git a/checker/BUILD b/checker/BUILD index 42e37e81d..7f3ccfef7 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -49,8 +49,11 @@ cc_library( deps = [ ":type_check_issue", "//common:ast", + "//common:decl", "//common:source", + "//common:type", "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -74,11 +77,14 @@ cc_test( cc_library( name = "type_checker", + srcs = ["type_checker.cc"], hdrs = ["type_checker.h"], deps = [ ":validation_result", "//common:ast", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", ], ) @@ -88,6 +94,7 @@ cc_library( deps = [ ":checker_options", ":type_checker", + "//common:container", "//common:decl", "//common:type", "@com_google_absl//absl/base:nullability", @@ -122,12 +129,14 @@ cc_test( srcs = ["type_checker_builder_factory_test.cc"], deps = [ ":checker_options", + ":optional", ":standard_library", ":type_checker", ":type_checker_builder", ":type_checker_builder_factory", ":validation_result", "//checker/internal:test_ast_helpers", + "//common:ast", "//common:decl", "//common:type", "//internal:status_macros", @@ -221,6 +230,8 @@ cc_library( hdrs = ["type_checker_subset_factory.h"], deps = [ ":type_checker_builder", + "//common:decl", + "//common:signature", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", diff --git a/checker/checker_options.h b/checker/checker_options.h index 0b6d1af7f..cb85337fa 100644 --- a/checker/checker_options.h +++ b/checker/checker_options.h @@ -95,6 +95,14 @@ struct CheckerOptions { // Temporary flag to allow rolling out the change. No functional changes to // evaluation behavior in either mode. bool enable_function_name_in_reference = true; + + // If true, the checker will use the proto json field names for protobuf + // messages. Unlike protojson parsers, it will not accept the standard proto + // field names as valid json field names. + // + // Note: The checked AST will contain the json field names and an extension + // tag, but will require runtime support for resolving the json field names. + bool use_json_field_names = false; }; } // namespace cel diff --git a/checker/internal/BUILD b/checker/internal/BUILD index c539a2cc9..20c476db2 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -27,10 +27,11 @@ cc_library( hdrs = ["test_ast_helpers.h"], deps = [ "//common:ast", - "//extensions/protobuf:ast_converters", "//internal:status_macros", "//parser", "//parser:options", + "//parser:parser_interface", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], @@ -64,14 +65,20 @@ cc_library( srcs = ["type_check_env.cc"], hdrs = ["type_check_env.h"], deps = [ + ":descriptor_pool_type_introspector", + ":proto_type_mask", + ":proto_type_mask_registry", "//common:constant", + "//common:container", "//common:decl", "//common:type", "//internal:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", @@ -86,8 +93,12 @@ cc_library( srcs = ["namespace_generator.cc"], hdrs = ["namespace_generator.h"], deps = [ + "//common:container", "//internal:lexis", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -101,6 +112,7 @@ cc_test( srcs = ["namespace_generator_test.cc"], deps = [ ":namespace_generator", + "//common:container", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", @@ -118,8 +130,8 @@ cc_library( "type_checker_impl.h", ], deps = [ - ":format_type_name", ":namespace_generator", + ":proto_type_mask", ":type_check_env", ":type_inference_context", "//checker:checker_options", @@ -133,9 +145,11 @@ cc_library( "//common:ast_visitor", "//common:ast_visitor_base", "//common:constant", + "//common:container", "//common:decl", "//common:expr", - "//common:source", + "//common:format_type_name", + "//common:standard_definitions", "//common:type", "//common:type_kind", "//internal:lexis", @@ -144,8 +158,10 @@ cc_library( "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -167,8 +183,11 @@ cc_test( ":type_checker_impl", "//checker:checker_options", "//checker:type_check_issue", + "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", + "//common:ast_proto", + "//common:container", "//common:decl", "//common:expr", "//common:source", @@ -176,13 +195,17 @@ cc_test( "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", + "//parser", + "//parser:macro_registry", "//testutil:baseline_tests", + "//testutil:test_macros", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", @@ -208,6 +231,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", @@ -219,8 +243,9 @@ cc_library( srcs = ["type_inference_context.cc"], hdrs = ["type_inference_context.h"], deps = [ - ":format_type_name", "//common:decl", + "//common:format_type_name", + "//common:standard_definitions", "//common:type", "//common:type_kind", "@com_google_absl//absl/container:flat_hash_map", @@ -251,24 +276,126 @@ cc_test( ) cc_library( - name = "format_type_name", - srcs = ["format_type_name.cc"], - hdrs = ["format_type_name.h"], + name = "descriptor_pool_type_introspector", + srcs = ["descriptor_pool_type_introspector.cc"], + hdrs = ["descriptor_pool_type_introspector.h"], deps = [ "//common:type", - "//common:type_kind", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", ], ) cc_test( - name = "format_type_name_test", - srcs = ["format_type_name_test.cc"], + name = "descriptor_pool_type_introspector_test", + srcs = ["descriptor_pool_type_introspector_test.cc"], deps = [ - ":format_type_name", + ":descriptor_pool_type_introspector", "//common:type", "//internal:testing", - "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "field_path", + srcs = ["field_path.cc"], + hdrs = ["field_path.h"], + deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "field_path_test", + srcs = ["field_path_test.cc"], + deps = [ + ":field_path", + "//internal:testing", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "proto_type_mask", + srcs = ["proto_type_mask.cc"], + hdrs = ["proto_type_mask.h"], + deps = [ + ":field_path", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_type_mask_test", + srcs = ["proto_type_mask_test.cc"], + deps = [ + ":field_path", + ":proto_type_mask", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "proto_type_mask_registry", + srcs = ["proto_type_mask_registry.cc"], + hdrs = ["proto_type_mask_registry.h"], + deps = [ + ":field_path", + ":proto_type_mask", + "//common:type", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", ], ) + +cc_test( + name = "proto_type_mask_registry_test", + srcs = ["proto_type_mask_registry_test.cc"], + deps = [ + ":proto_type_mask", + ":proto_type_mask_registry", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/checker/internal/descriptor_pool_type_introspector.cc b/checker/internal/descriptor_pool_type_introspector.cc new file mode 100644 index 000000000..733e4a3cb --- /dev/null +++ b/checker/internal/descriptor_pool_type_introspector.cc @@ -0,0 +1,245 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/descriptor_pool_type_introspector.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { +namespace { + +// Standard implementation for field lookups. +// Avoids building a FieldTable and just checks the DescriptorPool directly. +absl::StatusOr> +FindStructTypeFieldByNameDirectly( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type, absl::string_view name) { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool->FindMessageTypeByName(type); + if (descriptor == nullptr) { + return std::nullopt; + } + const google::protobuf::FieldDescriptor* absl_nullable field = + descriptor->FindFieldByName(name); + if (field != nullptr) { + return StructTypeField(MessageTypeField(field)); + } + + field = descriptor_pool->FindExtensionByPrintableName(descriptor, name); + if (field != nullptr) { + return StructTypeField(MessageTypeField(field)); + } + return std::nullopt; +} + +// Standard implementation for listing fields. +// Avoids building a FieldTable and just checks the DescriptorPool directly. +absl::StatusOr< + std::optional>> +ListStructTypeFieldsDirectly( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type) { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool->FindMessageTypeByName(type); + if (descriptor == nullptr) { + return std::nullopt; + } + + std::vector extensions; + descriptor_pool->FindAllExtensions(descriptor, &extensions); + + std::vector fields; + fields.reserve(descriptor->field_count() + extensions.size()); + + for (int i = 0; i < descriptor->field_count(); ++i) { + const google::protobuf::FieldDescriptor* field = descriptor->field(i); + fields.push_back({field->name(), StructTypeField(MessageTypeField(field))}); + } + + return fields; +} + +} // namespace + +using Field = DescriptorPoolTypeIntrospector::Field; + +absl::StatusOr> +DescriptorPoolTypeIntrospector::FindTypeImpl(absl::string_view name) const { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool_->FindMessageTypeByName(name); + if (descriptor != nullptr) { + return Type::Message(descriptor); + } + const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = + descriptor_pool_->FindEnumTypeByName(name); + if (enum_descriptor != nullptr) { + return Type::Enum(enum_descriptor); + } + return std::nullopt; +} + +absl::StatusOr> +DescriptorPoolTypeIntrospector::FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const { + const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = + descriptor_pool_->FindEnumTypeByName(type); + if (enum_descriptor != nullptr) { + const google::protobuf::EnumValueDescriptor* absl_nullable enum_value_descriptor = + enum_descriptor->FindValueByName(value); + if (enum_value_descriptor == nullptr) { + return std::nullopt; + } + return EnumConstant{ + .type = Type::Enum(enum_descriptor), + .type_full_name = enum_descriptor->full_name(), + .value_name = enum_value_descriptor->name(), + .number = enum_value_descriptor->number(), + }; + } + return std::nullopt; +} + +absl::StatusOr> +DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const { + if (!use_json_name_) { + return FindStructTypeFieldByNameDirectly(descriptor_pool_, type, name); + } + + const FieldTable* field_table = GetFieldTable(type); + + if (field_table == nullptr) { + return std::nullopt; + } + + if (auto it = field_table->json_name_map.find(name); + it != field_table->json_name_map.end()) { + return field_table->fields[it->second].field; + } + + if (auto it = field_table->extension_name_map.find(name); + it != field_table->extension_name_map.end()) { + return field_table->fields[it->second].field; + } + + return std::nullopt; +} + +absl::StatusOr< + std::optional>> +DescriptorPoolTypeIntrospector::ListFieldsForStructTypeImpl( + absl::string_view type) const { + if (!use_json_name_) { + return ListStructTypeFieldsDirectly(descriptor_pool_, type); + } + + const FieldTable* field_table = GetFieldTable(type); + if (field_table == nullptr) { + return std::nullopt; + } + std::vector fields; + fields.reserve(field_table->non_extensions.size()); + for (const auto& field : field_table->non_extensions) { + fields.push_back({field.json_name, field.field}); + } + return fields; +} + +const DescriptorPoolTypeIntrospector::FieldTable* +DescriptorPoolTypeIntrospector::GetFieldTable( + absl::string_view type_name) const { + absl::MutexLock lock(mu_); + if (auto it = field_tables_.find(type_name); it != field_tables_.end()) { + return it->second.get(); + } + if (cel::IsWellKnownMessageType(type_name)) { + return nullptr; + } + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool_->FindMessageTypeByName(type_name); + if (descriptor == nullptr) { + return nullptr; + } + absl::string_view stable_type_name = descriptor->full_name(); + ABSL_DCHECK(stable_type_name == type_name); + std::unique_ptr field_table = CreateFieldTable(descriptor); + const FieldTable* field_table_ptr = field_table.get(); + field_tables_[stable_type_name] = std::move(field_table); + return field_table_ptr; +} + +std::unique_ptr +DescriptorPoolTypeIntrospector::CreateFieldTable( + const google::protobuf::Descriptor* absl_nonnull descriptor) const { + ABSL_DCHECK(!IsWellKnownMessageType(descriptor)); + std::vector fields; + absl::flat_hash_map json_name_map; + absl::flat_hash_map field_name_map; + absl::flat_hash_map extension_name_map; + + std::vector extensions; + descriptor_pool_->FindAllExtensions(descriptor, &extensions); + fields.reserve(descriptor->field_count() + extensions.size()); + + for (int i = 0; i < descriptor->field_count(); i++) { + const google::protobuf::FieldDescriptor* field = descriptor->field(i); + fields.push_back(Field{ + .field = StructTypeField(MessageTypeField(field)), + .json_name = field->json_name(), + .is_extension = false, + }); + field_name_map[field->name()] = fields.size() - 1; + if (use_json_name_ && !field->json_name().empty()) { + json_name_map[field->json_name()] = fields.size() - 1; + } + } + int non_extension_count = fields.size(); + + for (const google::protobuf::FieldDescriptor* extension : extensions) { + fields.push_back(Field{ + .field = StructTypeField(MessageTypeField(extension)), + .json_name = "", + .is_extension = true, + }); + extension_name_map[extension->full_name()] = fields.size() - 1; + } + int extension_count = fields.size() - non_extension_count; + auto result = std::make_unique(); + result->descriptor = descriptor; + result->fields = std::move(fields); + result->non_extensions = + absl::MakeConstSpan(result->fields).subspan(0, non_extension_count); + result->extensions = absl::MakeConstSpan(result->fields) + .subspan(non_extension_count, extension_count); + result->json_name_map = std::move(json_name_map); + result->field_name_map = std::move(field_name_map); + result->extension_name_map = std::move(extension_name_map); + return result; +} + +} // namespace cel::checker_internal diff --git a/checker/internal/descriptor_pool_type_introspector.h b/checker/internal/descriptor_pool_type_introspector.h new file mode 100644 index 000000000..8a970ea00 --- /dev/null +++ b/checker/internal/descriptor_pool_type_introspector.h @@ -0,0 +1,105 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Implementation of `TypeIntrospector` that uses a `google::protobuf::DescriptorPool`. +// +// This is used by the type checker to resolve protobuf types and their fields +// and apply any options like using JSON names. +// +// Neither copyable nor movable. Should be managed by a TypeCheckEnv. +class DescriptorPoolTypeIntrospector : public TypeIntrospector { + public: + struct Field { + StructTypeField field; + absl::string_view json_name; + bool is_extension = false; + }; + + DescriptorPoolTypeIntrospector() = delete; + explicit DescriptorPoolTypeIntrospector( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) + : descriptor_pool_(descriptor_pool) {} + + DescriptorPoolTypeIntrospector(const DescriptorPoolTypeIntrospector&) = + delete; + DescriptorPoolTypeIntrospector& operator=( + const DescriptorPoolTypeIntrospector&) = delete; + DescriptorPoolTypeIntrospector(DescriptorPoolTypeIntrospector&&) = delete; + DescriptorPoolTypeIntrospector& operator=(DescriptorPoolTypeIntrospector&&) = + delete; + + void set_use_json_name(bool use_json_name) { use_json_name_ = use_json_name; } + + bool use_json_name() const { return use_json_name_; } + + private: + struct FieldTable { + const google::protobuf::Descriptor* absl_nonnull descriptor; + std::vector fields; + absl::Span non_extensions; + absl::Span extensions; + absl::flat_hash_map json_name_map; + absl::flat_hash_map field_name_map; + absl::flat_hash_map extension_name_map; + }; + + absl::StatusOr> FindTypeImpl( + absl::string_view name) const final; + + absl::StatusOr> FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const final; + + absl::StatusOr> FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const final; + + absl::StatusOr>> + ListFieldsForStructTypeImpl(absl::string_view type) const final; + + std::unique_ptr CreateFieldTable( + const google::protobuf::Descriptor* absl_nonnull descriptor) const; + + const FieldTable* GetFieldTable(absl::string_view type_name) const; + + // Cached map of type to field table. + mutable absl::flat_hash_map> + field_tables_ ABSL_GUARDED_BY(mu_); + + mutable absl::Mutex mu_; + bool use_json_name_ = false; + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ diff --git a/checker/internal/descriptor_pool_type_introspector_test.cc b/checker/internal/descriptor_pool_type_introspector_test.cc new file mode 100644 index 000000000..db766b347 --- /dev/null +++ b/checker/internal/descriptor_pool_type_introspector_test.cc @@ -0,0 +1,175 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/descriptor_pool_type_introspector.h" + +#include + +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::testing::AllOf; +using ::testing::Contains; +using ::testing::Eq; +using ::testing::Not; +using ::testing::Optional; +using ::testing::Property; +using ::testing::SizeIs; +using ::testing::Truly; + +TEST(DescriptorPoolTypeIntrospectorTest, FindType) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + EXPECT_THAT(introspector.FindType("cel.expr.conformance.proto3.TestAllTypes"), + IsOkAndHolds(Optional(Property(&Type::IsMessage, true)))); + EXPECT_THAT(introspector.FindType( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"), + IsOkAndHolds(Optional(Property(&Type::IsEnum, true)))); + EXPECT_THAT(introspector.FindType("non.existent.Type"), + IsOkAndHolds(Eq(std::nullopt))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindEnumConstant) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + auto result = introspector.FindEnumConstant( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum", "FOO"); + ASSERT_THAT(result, IsOkAndHolds(Optional(AllOf( + Truly([](const TypeIntrospector::EnumConstant& v) { + return v.value_name == "FOO" && v.number == 0; + }))))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindStructTypeFieldByName) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "single_int64"); + introspector.set_use_json_name(false); + + ASSERT_THAT(field, + IsOkAndHolds(Optional(Property(&StructTypeField::GetType, + Property(&Type::IsInt, true))))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, + FindStructTypeFieldByNameJsonNameIgnored) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(false); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); + + EXPECT_THAT(field, IsOkAndHolds(Eq(std::nullopt))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindExtension) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto2.TestAllTypes", + "cel.expr.conformance.proto2.int32_ext"); + + ASSERT_THAT(field, + IsOkAndHolds(Optional(Property(&StructTypeField::GetType, + Property(&Type::IsInt, true))))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindStructTypeFieldByNameWithJsonOpt) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(true); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "single_int64"); + + ASSERT_THAT(field, IsOkAndHolds(Eq(std::nullopt))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, + FindStructTypeFieldByNameWithJsonNameOpt) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(true); + + absl::StatusOr> field = + introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); + + ASSERT_THAT(field, + IsOkAndHolds(Optional(Property(&StructTypeField::GetType, + Property(&Type::IsInt, true))))); +} + +MATCHER_P(FieldListingIs, field_name, "") { return arg.name == field_name; } + +TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructType) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + absl::StatusOr< + std::optional>> + fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(260)))); + EXPECT_THAT(*fields, Optional(Contains(FieldListingIs("single_int64")))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructTypeExtensions) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + auto fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto2.TestAllTypes"); + ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(259)))); + EXPECT_THAT(**fields, Contains(FieldListingIs("single_int64"))); + EXPECT_THAT( + **fields, + Not(Contains(FieldListingIs("cel.expr.conformance.proto2.int32_ext")))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, + ListFieldsForStructTypeWithJsonNameOpt) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(true); + auto fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(260)))); + EXPECT_THAT(**fields, Contains(FieldListingIs("singleInt64"))); + EXPECT_THAT(**fields, Not(Contains(FieldListingIs("single_int64")))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructTypeNotFound) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + auto fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto3.SomeOtherType"); + EXPECT_THAT(fields, IsOkAndHolds(Eq(std::nullopt))); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/field_path.cc b/checker/internal/field_path.cc new file mode 100644 index 000000000..5ecc4219b --- /dev/null +++ b/checker/internal/field_path.cc @@ -0,0 +1,30 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/field_path.h" + +#include + +#include "absl/strings/str_join.h" +#include "absl/strings/substitute.h" + +namespace cel::checker_internal { + +std::string FieldPath::DebugString() const { + return absl::Substitute( + "FieldPath { field path: '$0', field selection: {'$1'} }", path_, + absl::StrJoin(field_selection_, "', '")); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/field_path.h b/checker/internal/field_path.h new file mode 100644 index 000000000..d67d9b935 --- /dev/null +++ b/checker/internal/field_path.h @@ -0,0 +1,77 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ + +#include +#include +#include + +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace cel::checker_internal { + +// Represents a single path within a FieldMask. +class FieldPath { + public: + explicit FieldPath(std::string path) + : path_(std::move(path)), + field_selection_(absl::StrSplit(path_, kPathDelimiter)) {} + + // Returns the input path. + // For example: "f.b.d". + absl::string_view GetPath() const { return path_; } + + // Returns the list of nested field names in the path. + // For example: {"f", "b", "d"}. + absl::Span GetFieldSelection() const { + return field_selection_; + } + + // Returns the first field name in the path. + // For example: "f". + std::string GetFieldName() const { return field_selection_.front(); } + + template + friend void AbslStringify(Sink& sink, const FieldPath& field_path) { + sink.Append(field_path.DebugString()); + } + + private: + static constexpr char kPathDelimiter = '.'; + + std::string DebugString() const; + + // The input path. For example: "f.b.d". + std::string path_; + // The list of nested field names in the path. For example: {"f", "b", "d"}. + std::vector field_selection_; +}; + +inline bool operator==(const FieldPath& lhs, const FieldPath& rhs) { + return lhs.GetFieldSelection() == rhs.GetFieldSelection(); +} + +// Compares the field selections in the field paths. +// This is only intended as an arbitrary ordering for a set. +inline bool operator<(const FieldPath& lhs, const FieldPath& rhs) { + return lhs.GetFieldSelection() < rhs.GetFieldSelection(); +} + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ diff --git a/checker/internal/field_path_test.cc b/checker/internal/field_path_test.cc new file mode 100644 index 000000000..9a1434954 --- /dev/null +++ b/checker/internal/field_path_test.cc @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/field_path.h" + +#include "absl/strings/str_cat.h" +#include "internal/testing.h" + +namespace cel::checker_internal { +namespace { + +using ::testing::ElementsAre; + +TEST(FieldPathTest, EmptyPathReturnsEmptyString) { + FieldPath field_path(""); + EXPECT_EQ(field_path.GetPath(), ""); + EXPECT_THAT(field_path.GetFieldSelection(), ElementsAre("")); + EXPECT_EQ(field_path.GetFieldName(), ""); +} + +TEST(FieldPathTest, DelimiterPathReturnsEmptyStrings) { + FieldPath field_path("."); + EXPECT_EQ(field_path.GetPath(), "."); + EXPECT_THAT(field_path.GetFieldSelection(), ElementsAre("", "")); + EXPECT_EQ(field_path.GetFieldName(), ""); +} + +TEST(FieldPathTest, FieldPathReturnsFields) { + FieldPath field_path("resource.name.other_field"); + EXPECT_EQ(field_path.GetPath(), "resource.name.other_field"); + EXPECT_THAT(field_path.GetFieldSelection(), + ElementsAre("resource", "name", "other_field")); + EXPECT_EQ(field_path.GetFieldName(), "resource"); +} + +TEST(FieldPathTest, AbslStringifyPrintsFieldSelection) { + FieldPath field_path("resource.name"); + EXPECT_EQ(absl::StrCat(field_path), + "FieldPath { field path: 'resource.name', field selection: " + "{'resource', 'name'} }"); +} + +TEST(FieldPathTest, EqualsComparesFieldSelectionAndReturnsTrue) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.name"); + EXPECT_TRUE(field_path_1 == field_path_2); +} + +TEST(FieldPathTest, EqualsComparesFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.type"); + EXPECT_FALSE(field_path_1 == field_path_2); +} + +TEST(FieldPathTest, LessThanComparesFieldSelectionAndReturnsTrue) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.type"); + EXPECT_TRUE(field_path_1 < field_path_2); +} + +TEST(FieldPathTest, LessThanComparesIdenticalFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.name"); + EXPECT_FALSE(field_path_1 < field_path_2); +} + +TEST(FieldPathTest, LessThanComparesFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.type"); + FieldPath field_path_2("resource.name"); + EXPECT_FALSE(field_path_1 < field_path_2); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/namespace_generator.cc b/checker/internal/namespace_generator.cc index e5b2cfa51..7ab7628e4 100644 --- a/checker/internal/namespace_generator.cc +++ b/checker/internal/namespace_generator.cc @@ -20,7 +20,7 @@ #include #include "absl/functional/function_ref.h" -#include "absl/status/status.h" +#include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -28,19 +28,20 @@ #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "common/container.h" #include "internal/lexis.h" namespace cel::checker_internal { namespace { -bool FieldSelectInterpretationCandidates( +bool FieldSelectInterpretationCandidatesImpl( absl::string_view prefix, - absl::Span partly_qualified_name, + absl::Span partly_qualified_name, bool prefix_is_alias, absl::FunctionRef callback) { for (int i = 0; i < partly_qualified_name.size(); ++i) { std::string buf; int count = partly_qualified_name.size() - i; - auto end_idx = count - 1; + auto end_idx = count - (prefix_is_alias ? 0 : 1); auto ident = absl::StrJoin(partly_qualified_name.subspan(0, count), "."); absl::string_view candidate = ident; if (absl::StartsWith(candidate, ".")) { @@ -54,28 +55,44 @@ bool FieldSelectInterpretationCandidates( return false; } } + if (prefix_is_alias) { + return callback(prefix, 0); + } return true; } +bool FieldSelectInterpretationCandidates( + absl::string_view prefix, + absl::Span partly_qualified_name, + absl::FunctionRef callback) { + return FieldSelectInterpretationCandidatesImpl( + prefix, partly_qualified_name, /*prefix_is_alias=*/false, callback); +} + +bool FieldSelectInterpretationCandidatesWithAlias( + absl::string_view prefix, + absl::Span partly_qualified_name, + absl::FunctionRef callback) { + return FieldSelectInterpretationCandidatesImpl( + prefix, partly_qualified_name, /*prefix_is_alias=*/true, callback); +} + } // namespace absl::StatusOr NamespaceGenerator::Create( - absl::string_view container) { + const ExpressionContainer& expression_container) { std::vector candidates; + absl::string_view container = expression_container.container(); if (container.empty()) { - return NamespaceGenerator(std::move(candidates)); + return NamespaceGenerator(&expression_container, std::move(candidates)); } - if (absl::StartsWith(container, ".")) { - return absl::InvalidArgumentError("container must not start with a '.'"); - } std::string prefix; for (auto segment : absl::StrSplit(container, '.')) { - if (!internal::LexisIsIdentifier(segment)) { - return absl::InvalidArgumentError( - "container must only contain valid identifier segments"); - } + // Assumes the the ExpressionContainer has already validated the container + // and aliases. + ABSL_DCHECK(internal::LexisIsIdentifier(segment)); if (prefix.empty()) { prefix = segment; } else { @@ -84,31 +101,75 @@ absl::StatusOr NamespaceGenerator::Create( candidates.push_back(prefix); } std::reverse(candidates.begin(), candidates.end()); - return NamespaceGenerator(std::move(candidates)); + return NamespaceGenerator(&expression_container, std::move(candidates)); } void NamespaceGenerator::GenerateCandidates( - absl::string_view unqualified_name, - absl::FunctionRef callback) { - if (absl::StartsWith(unqualified_name, ".")) { - callback(unqualified_name.substr(1)); + absl::string_view simple_name, + absl::FunctionRef callback) const { + // Special case for root-relative names. Aliases still apply first. + bool is_root_relative = absl::StartsWith(simple_name, "."); + if (is_root_relative) { + simple_name = simple_name.substr(1); + } + + // The name is unqualified, but may include a namespace (struct creation). + // This is just a quirk of the parser. + if (auto dot_pos = simple_name.find('.'); + dot_pos != absl::string_view::npos) { + absl::string_view first_segment = simple_name.substr(0, dot_pos); + absl::string_view rest = simple_name.substr(dot_pos + 1); + if (auto resolved_alias = expression_container_->FindAlias(first_segment); + !resolved_alias.empty()) { + callback(absl::StrCat(resolved_alias, ".", rest)); + return; + } + } else { + if (auto resolved_alias = expression_container_->FindAlias(simple_name); + !resolved_alias.empty()) { + callback(resolved_alias); + return; + } + } + + if (is_root_relative) { + callback(simple_name); return; } + for (const auto& prefix : candidates_) { - std::string candidate = absl::StrCat(prefix, ".", unqualified_name); + std::string candidate = absl::StrCat(prefix, ".", simple_name); if (!callback(candidate)) { return; } } - callback(unqualified_name); + callback(simple_name); } void NamespaceGenerator::GenerateCandidates( absl::Span partly_qualified_name, - absl::FunctionRef callback) { - // Special case for explicit root relative name. e.g. '.com.example.Foo' - if (!partly_qualified_name.empty() && - absl::StartsWith(partly_qualified_name[0], ".")) { + absl::FunctionRef callback) const { + if (partly_qualified_name.empty()) { + return; + } + + // Special case for root-relative names. Aliases still apply first. + absl::string_view first_segment = partly_qualified_name[0]; + bool is_root_relative = absl::StartsWith(first_segment, "."); + if (is_root_relative) { + first_segment = first_segment.substr(1); + } + + if (auto resolved_alias = expression_container_->FindAlias(first_segment); + !resolved_alias.empty()) { + FieldSelectInterpretationCandidatesWithAlias( + resolved_alias, partly_qualified_name.subspan(1), callback); + // If the alias matches, we don't check the container even if name + // resolution fails. + return; + } + + if (is_root_relative) { FieldSelectInterpretationCandidates("", partly_qualified_name, callback); return; } diff --git a/checker/internal/namespace_generator.h b/checker/internal/namespace_generator.h index 18c40dbda..61cb1956b 100644 --- a/checker/internal/namespace_generator.h +++ b/checker/internal/namespace_generator.h @@ -19,18 +19,26 @@ #include #include +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" #include "absl/functional/function_ref.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "common/container.h" namespace cel::checker_internal { // Utility class for generating namespace qualified candidates for reference // resolution. +// +// This class is expected to be scoped to a single type checking operation and +// borrows the ExpressionContainer from the TypeCheckEnv. class NamespaceGenerator { public: - static absl::StatusOr Create(absl::string_view container); + static absl::StatusOr Create( + const ExpressionContainer& expression_container + ABSL_ATTRIBUTE_LIFETIME_BOUND); // Copyable and movable. NamespaceGenerator(const NamespaceGenerator&) = default; @@ -51,8 +59,18 @@ class NamespaceGenerator { // and unqualified name foo // // com.google.foo, com.foo, foo - void GenerateCandidates(absl::string_view unqualified_name, - absl::FunctionRef callback); + // + // If aliases are present, they override the normal container resolution. + // + // Example: + // container (com.google) + // alias (foo = com.example) + // unqualified name foo + // + // com.example + void GenerateCandidates( + absl::string_view simple_name, + absl::FunctionRef callback) const; // For a partially qualified name, generate all the qualified candidates in // order of resolution precedence and pass them to the provided callback. The @@ -72,16 +90,30 @@ class NamespaceGenerator { // (com.Foo).bar, // (Foo.bar), // (Foo).bar, + // + // If aliases are present, they override the normal container resolution. + // + // Example: + // container (com.google) + // alias (Foo = com.example.Foo) + // partially qualified name Foo.bar + // + // (com.example.Foo.bar), + // (com.example.Foo).bar, void GenerateCandidates( absl::Span partly_qualified_name, - absl::FunctionRef callback); + absl::FunctionRef callback) const; private: - explicit NamespaceGenerator(std::vector candidates) - : candidates_(std::move(candidates)) {} + explicit NamespaceGenerator( + const ExpressionContainer* absl_nonnull expression_container, + std::vector candidates) + : candidates_(std::move(candidates)), + expression_container_(expression_container) {} // list of prefixes ordered from most qualified to least. std::vector candidates_; + const ExpressionContainer* absl_nonnull expression_container_; }; } // namespace cel::checker_internal diff --git a/checker/internal/namespace_generator_test.cc b/checker/internal/namespace_generator_test.cc index da174748a..ba9bb88a4 100644 --- a/checker/internal/namespace_generator_test.cc +++ b/checker/internal/namespace_generator_test.cc @@ -18,19 +18,20 @@ #include #include -#include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "common/container.h" #include "internal/testing.h" namespace cel::checker_internal { namespace { -using ::absl_testing::StatusIs; +using ::absl_testing::IsOk; using ::testing::ElementsAre; using ::testing::Pair; TEST(NamespaceGeneratorTest, EmptyContainer) { - ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create("")); + ExpressionContainer container; + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); std::vector candidates; generator.GenerateCandidates("foo", [&](absl::string_view candidate) { candidates.push_back(std::string(candidate)); @@ -40,8 +41,9 @@ TEST(NamespaceGeneratorTest, EmptyContainer) { } TEST(NamespaceGeneratorTest, MultipleSegments) { - ASSERT_OK_AND_ASSIGN(auto generator, - NamespaceGenerator::Create("com.example")); + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); std::vector candidates; generator.GenerateCandidates("foo", [&](absl::string_view candidate) { candidates.push_back(std::string(candidate)); @@ -51,8 +53,9 @@ TEST(NamespaceGeneratorTest, MultipleSegments) { } TEST(NamespaceGeneratorTest, MultipleSegmentsRootNamespace) { - ASSERT_OK_AND_ASSIGN(auto generator, - NamespaceGenerator::Create("com.example")); + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); std::vector candidates; generator.GenerateCandidates(".foo", [&](absl::string_view candidate) { candidates.push_back(std::string(candidate)); @@ -61,18 +64,46 @@ TEST(NamespaceGeneratorTest, MultipleSegmentsRootNamespace) { EXPECT_THAT(candidates, ElementsAre("foo")); } -TEST(NamespaceGeneratorTest, InvalidContainers) { - EXPECT_THAT(NamespaceGenerator::Create(".com.example"), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(NamespaceGenerator::Create("com..example"), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(NamespaceGenerator::Create("com.$example"), - StatusIs(absl::StatusCode::kInvalidArgument)); +TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretation) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + std::vector qualified_ident = {"foo", "Bar"}; + std::vector> candidates; + generator.GenerateCandidates( + qualified_ident, [&](absl::string_view candidate, int segment_index) { + candidates.push_back(std::pair(std::string(candidate), segment_index)); + return true; + }); + EXPECT_THAT( + candidates, + ElementsAre(Pair("com.example.foo.Bar", 1), Pair("com.example.foo", 0), + Pair("com.foo.Bar", 1), Pair("com.foo", 0), + Pair("foo.Bar", 1), Pair("foo", 0))); } -TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretation) { - ASSERT_OK_AND_ASSIGN(auto generator, - NamespaceGenerator::Create("com.example")); +TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretationAliasMatch) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_THAT(container.AddAlias("foo", "bar.baz"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + std::vector qualified_ident = {"foo", "Bar"}; + std::vector> candidates; + generator.GenerateCandidates( + qualified_ident, [&](absl::string_view candidate, int segment_index) { + candidates.push_back(std::pair(std::string(candidate), segment_index)); + return true; + }); + EXPECT_THAT(candidates, + ElementsAre(Pair("bar.baz.Bar", 1), Pair("bar.baz", 0))); +} + +TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretationAliasNoMatch) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_THAT(container.AddAbbreviation("foo.Bar"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + // No match on the alias (Bar) since it's not the first segment. std::vector qualified_ident = {"foo", "Bar"}; std::vector> candidates; generator.GenerateCandidates( @@ -89,8 +120,9 @@ TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretation) { TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretationRootNamespace) { - ASSERT_OK_AND_ASSIGN(auto generator, - NamespaceGenerator::Create("com.example")); + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); std::vector qualified_ident = {".foo", "Bar"}; std::vector> candidates; generator.GenerateCandidates( diff --git a/checker/internal/proto_type_mask.cc b/checker/internal/proto_type_mask.cc new file mode 100644 index 000000000..85e39cb69 --- /dev/null +++ b/checker/internal/proto_type_mask.cc @@ -0,0 +1,87 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "checker/internal/field_path.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +using ::google::protobuf::Descriptor; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::FieldDescriptor; + +absl::StatusOr FindMessage( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type_name) { + const Descriptor* descriptor = + descriptor_pool->FindMessageTypeByName(type_name); + if (descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::Substitute("type '$0' not found", type_name)); + } + return descriptor; +} + +absl::StatusOr FindField(const Descriptor* descriptor, + absl::string_view field_name) { + const FieldDescriptor* field_descriptor = + descriptor->FindFieldByName(field_name); + if (field_descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::Substitute("could not select field '$0' from type '$1'", + field_name, descriptor->full_name())); + } + return field_descriptor; +} + +absl::StatusOr> ProtoTypeMask::GetFieldNames( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) const { + CEL_ASSIGN_OR_RETURN(const Descriptor* descriptor, + FindMessage(descriptor_pool, this->GetTypeName())); + absl::btree_set field_names; + for (const FieldPath& field_path : this->GetFieldPaths()) { + std::string field_name = field_path.GetFieldName(); + CEL_ASSIGN_OR_RETURN(const FieldDescriptor* field_descriptor, + FindField(descriptor, field_name)); + field_names.insert(field_descriptor->name()); + } + return field_names; +} + +std::string ProtoTypeMask::DebugString() const { + // Represent each FieldPath by its path because it is easiest to read. + std::vector paths; + paths.reserve(field_paths_.size()); + for (const FieldPath& field_path : field_paths_) { + paths.emplace_back(field_path.GetPath()); + } + return absl::Substitute( + "ProtoTypeMask { type name: '$0', field paths: { '$1' } }", type_name_, + absl::StrJoin(paths, "', '")); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask.h b/checker/internal/proto_type_mask.h new file mode 100644 index 000000000..f7d522cba --- /dev/null +++ b/checker/internal/proto_type_mask.h @@ -0,0 +1,111 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/internal/field_path.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Returns a descriptor for the input type name. +// Returns an error if the type name is not found. +absl::StatusOr FindMessage( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type_name); + +// Returns a field descriptor for the input field name. +// Returns an error if the field name is not found. +absl::StatusOr FindField( + const google::protobuf::Descriptor* descriptor, absl::string_view field_name); + +// Represents the fraction of a protobuf type's object graph that should be +// visible within CEL expressions. +class ProtoTypeMask { + public: + explicit ProtoTypeMask(std::string type_name, + const std::vector& field_paths) + : type_name_(std::move(type_name)) { + for (const std::string& field_path : field_paths) { + field_paths_.insert(FieldPath(field_path)); + } + } + + // Returns a set of field names. The set includes the first field name from + // each field path. We are able to return a set of absl::string_view because + // the result is backed by the descriptor pool. + absl::StatusOr> GetFieldNames( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) const; + + // Returns the type's full name. + // For example: "google.rpc.context.AttributeContext". + absl::string_view GetTypeName() const { return type_name_; } + + // Returns a representation of the FieldMask, which is a set of field paths. + // For example: + // { + // FieldPath { + // field path: 'resource.name', + // field selection: {'resource', 'name'} + // }, + // FieldPath { + // field path: 'request.auth.claims', + // field selection: {'request', 'auth', 'claims'} + // } + // } + const absl::btree_set& GetFieldPaths() const { + return field_paths_; + } + + template + friend void AbslStringify(Sink& sink, const ProtoTypeMask& proto_type_mask) { + sink.Append(proto_type_mask.DebugString()); + } + + private: + std::string DebugString() const; + + // A type's full name. For example: "google.rpc.context.AttributeContext". + std::string type_name_; + // A representation of a FieldMask, which is a set of field paths. + // For example: + // { + // FieldPath { + // field path: 'resource.name', + // field selection: {'resource', 'name'} + // }, + // FieldPath { + // field path: 'request.auth.claims', + // field selection: {'request', 'auth', 'claims'} + // } + // } + // A FieldMask contains one or more paths which contain identifier characters + // that have been dot delimited, e.g. resource.name, request.auth.claims. + // For each path, all descendent fields after the last element in the path are + // visible. An empty set means all fields are hidden. + absl::btree_set field_paths_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ diff --git a/checker/internal/proto_type_mask_registry.cc b/checker/internal/proto_type_mask_registry.cc new file mode 100644 index 000000000..9c50c9784 --- /dev/null +++ b/checker/internal/proto_type_mask_registry.cc @@ -0,0 +1,180 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask_registry.h" + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "checker/internal/field_path.h" +#include "checker/internal/proto_type_mask.h" +#include "common/type.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { +namespace { + +using ::google::protobuf::Descriptor; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::FieldDescriptor; +using TypeMap = + absl::flat_hash_map>; + +// Returns a message type descriptor for the input field descriptor. +// Returns an error if the field is not a message type. +absl::StatusOr GetMessage( + const FieldDescriptor* field_descriptor) { + cel::MessageTypeField field(field_descriptor); + cel::Type type = field.GetType(); + absl::optional message_type = type.AsMessage(); + if (!message_type.has_value()) { + return absl::InvalidArgumentError(absl::Substitute( + "field '$0' is not a message type", field_descriptor->name())); + } + return &(*message_type.value()); +} + +// Inserts the type name with an empty set into types_and_visible_fields. +// Returns an error if the type name is already present with a non-empty set. +absl::Status AddAllHiddenFields(TypeMap& types_and_visible_fields, + absl::string_view type_name) { + auto result = types_and_visible_fields.find(type_name); + if (result != types_and_visible_fields.end()) { + if (!result->second.empty()) { + return absl::InvalidArgumentError( + absl::Substitute("cannot insert a proto type mask with all hidden " + "fields when type '$0' has already been inserted " + "with a proto type mask with a visible field", + type_name)); + } + return absl::OkStatus(); + } + types_and_visible_fields.insert({std::string(type_name), {}}); + return absl::OkStatus(); +} + +// Inserts the type name and field name into types_and_visible_fields. +// Returns an error if the type name is already present with an empty set. +absl::Status AddVisibleField(TypeMap& types_and_visible_fields, + absl::string_view type_name, + absl::string_view field_name) { + auto result = types_and_visible_fields.find(type_name); + if (result != types_and_visible_fields.end()) { + if (result->second.empty()) { + return absl::InvalidArgumentError(absl::Substitute( + "cannot insert a proto type mask with visible " + "field '$0' when type '$1' has already been inserted " + "with a proto type mask with all hidden fields", + field_name, type_name)); + } + result->second.insert(std::string(field_name)); + return absl::OkStatus(); + } + types_and_visible_fields.insert( + {std::string(type_name), {std::string(field_name)}}); + return absl::OkStatus(); +} + +// Processes the input proto type masks to create and return the +// types_and_visible_fields map. +// Returns an error if one of the proto type masks is not valid. For example, +// if a type is not found in the descriptor pool, if a field name is not +// found, or if a field is not a message type when we are expecting it to be. +// Returns an error if there is a conflict in field visibility when +// updating the map. +absl::StatusOr ComputeVisibleFieldsMap( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks) { + TypeMap types_and_visible_fields; + for (const ProtoTypeMask& proto_type_mask : proto_type_masks) { + absl::string_view type_name = proto_type_mask.GetTypeName(); + CEL_ASSIGN_OR_RETURN(const Descriptor* descriptor, + FindMessage(descriptor_pool, type_name)); + const absl::btree_set& field_paths = + proto_type_mask.GetFieldPaths(); + if (field_paths.empty()) { + CEL_RETURN_IF_ERROR( + AddAllHiddenFields(types_and_visible_fields, type_name)); + } + for (const FieldPath& field_path : field_paths) { + const Descriptor* target_descriptor = descriptor; + absl::Span field_selection = + field_path.GetFieldSelection(); + for (auto iterator = field_selection.begin(); + iterator != field_selection.end(); ++iterator) { + CEL_ASSIGN_OR_RETURN(const FieldDescriptor* field_descriptor, + FindField(target_descriptor, *iterator)); + CEL_RETURN_IF_ERROR(AddVisibleField(types_and_visible_fields, + target_descriptor->full_name(), + *iterator)); + if (std::next(iterator) != field_selection.end()) { + CEL_ASSIGN_OR_RETURN(target_descriptor, GetMessage(field_descriptor)); + } + } + } + } + return types_and_visible_fields; +} + +} // namespace + +absl::StatusOr> +ProtoTypeMaskRegistry::Create( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks) { + CEL_ASSIGN_OR_RETURN( + auto types_and_visible_fields, + ComputeVisibleFieldsMap(descriptor_pool, proto_type_masks)); + std::shared_ptr proto_type_mask_registry = + absl::WrapUnique(new ProtoTypeMaskRegistry(types_and_visible_fields)); + return proto_type_mask_registry; +} + +bool ProtoTypeMaskRegistry::FieldIsVisible(absl::string_view type_name, + absl::string_view field_name) const { + auto iterator = types_and_visible_fields_.find(type_name); + if (iterator != types_and_visible_fields_.end() && + !iterator->second.contains(field_name)) { + return false; + } + return true; +} + +std::string ProtoTypeMaskRegistry::DebugString() const { + std::string output = "ProtoTypeMaskRegistry { "; + for (auto& element : types_and_visible_fields_) { + absl::StrAppend(&output, "{type: '", element.first, "', visible_fields: '", + absl::StrJoin(element.second, "', '"), "'} "); + } + absl::StrAppend(&output, "}"); + return output; +} + +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask_registry.h b/checker/internal/proto_type_mask_registry.h new file mode 100644 index 000000000..338353e7d --- /dev/null +++ b/checker/internal/proto_type_mask_registry.h @@ -0,0 +1,83 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/internal/proto_type_mask.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Stores information related to ProtoTypeMasks. Visibility is defined per type, +// meaning that all messages of a type have the same visible fields. +class ProtoTypeMaskRegistry { + public: + // Processes the input proto type masks to create a ProtoTypeMaskRegistry. + // Returns an error if one of the proto type masks is not valid. For example, + // if a type is not found in the descriptor pool, if a field name is not + // found, or if a field is not a message type when we are expecting it to be. + // Returns an error if there is a conflict in field visibility when + // updating the map. + static absl::StatusOr> Create( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks); + + const absl::flat_hash_map>& + GetTypesAndVisibleFields() const { + return types_and_visible_fields_; + } + + // Returns true when the field name is visible. A field is visible if: + // 1. The type name is not a key in the map. + // 2. The type name is a key in the map and the field name is in the set of + // field names that are visible for the type. + bool FieldIsVisible(absl::string_view type_name, + absl::string_view field_name) const; + + template + friend void AbslStringify( + Sink& sink, + const std::shared_ptr& proto_type_mask_registry) { + sink.Append(proto_type_mask_registry->DebugString()); + } + + private: + explicit ProtoTypeMaskRegistry( + absl::flat_hash_map> + types_and_visible_fields) + : types_and_visible_fields_(std::move(types_and_visible_fields)) {} + + std::string DebugString() const; + + // Map of types that have a field mask where the keys are + // fully qualified type names and the values are the set of field names that + // are visible for the type. + absl::flat_hash_map> + types_and_visible_fields_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ diff --git a/checker/internal/proto_type_mask_registry_test.cc b/checker/internal/proto_type_mask_registry_test.cc new file mode 100644 index 000000000..3a73c8823 --- /dev/null +++ b/checker/internal/proto_type_mask_registry_test.cc @@ -0,0 +1,402 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask_registry.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/internal/proto_type_mask.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::AllOf; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +using TypeMap = + absl::flat_hash_map>; + +TEST(ProtoTypeMaskRegistryTest, + CreateWithEmptyInputSucceedsAndAllFieldsAreVisible) { + std::vector proto_type_masks = {}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), IsEmpty()); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithEmptyTypeReturnsError) { + std::vector proto_type_masks = {ProtoTypeMask("", {})}; + EXPECT_THAT(ProtoTypeMaskRegistry::Create( + GetSharedTestingDescriptorPool().get(), proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type '' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithUnknownTypeReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("com.example.UnknownType", {})}; + EXPECT_THAT(ProtoTypeMaskRegistry::Create( + GetSharedTestingDescriptorPool().get(), proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type 'com.example.UnknownType' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithEmptySetFieldPathSucceedsAndFieldsAreHidden) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", IsEmpty()))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDuplicateEmptySetFieldPathSucceedsAndFieldsAreHidden) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", IsEmpty()))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithEmptyFieldPathReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {""})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithDelimiterFieldPathReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {"."})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithUnknownFieldReturnsError) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"unknown_field"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneNonMessageFieldsSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "single_any", "single_timestamp"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("single_int32", "single_any", + "single_timestamp")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_int32")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_any")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_timestamp")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithDepthTwoNonMessageFieldReturnsError) { + std::vector proto_type_masks; + proto_type_masks.push_back( + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32.any_field_name"})); + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("field 'single_int32' is not a message type"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre(Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoMessageUnknownFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.unknown_field"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes.NestedMessage'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthThreeMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.NestedTestAllTypes", + UnorderedElementsAre("payload")), + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "payload")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneRepeatedMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"repeated_nested_message"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("repeated_nested_message")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "repeated_nested_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoRepeatedMessageFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"repeated_nested_message.bb"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("field 'repeated_nested_message' is not a message type"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithListOfFieldPathsSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.NestedTestAllTypes", + UnorderedElementsAre("payload")), + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message", "single_int32")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "payload")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_int32")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateAddVisibleFieldThenAllHiddenFieldsReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.bb"}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + {})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "cannot insert a proto type mask with all hidden fields when " + "type 'cel.expr.conformance.proto3.TestAllTypes.NestedMessage' " + "has already been inserted with a proto type mask with a visible " + "field"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateAddAllHiddenThenVisibleFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + {}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.bb"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "cannot insert a proto type mask with visible field 'bb' when " + "type 'cel.expr.conformance.proto3.TestAllTypes.NestedMessage' " + "has already been inserted with a proto type mask with all " + "hidden fields"))); +} + +TEST(ProtoTypeMaskRegistryTest, AbslStringifyPrintsTypesAndVisibleFieldsMap) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + absl::StrCat(proto_type_mask_registry), + AllOf(HasSubstr("ProtoTypeMaskRegistry {"), + HasSubstr("{type: 'cel.expr.conformance.proto3.TestAllTypes', " + "visible_fields: 'standalone_message'}"), + HasSubstr("{type: " + "'cel.expr.conformance.proto3.TestAllTypes.NestedMessage'" + ", visible_fields: 'bb'}"))); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask_test.cc b/checker/internal/proto_type_mask_test.cc new file mode 100644 index 000000000..0c534f8cf --- /dev/null +++ b/checker/internal/proto_type_mask_test.cc @@ -0,0 +1,143 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask.h" + +#include +#include + +#include "absl/container/btree_set.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/internal/field_path.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +TEST(ProtoTypeMaskTest, EmptyTypeNameAndEmptyFieldPathsSucceeds) { + std::string type_name = ""; + std::vector field_paths; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_EQ(proto_type_mask.GetTypeName(), ""); + EXPECT_THAT(proto_type_mask.GetFieldPaths(), IsEmpty()); +} + +TEST(ProtoTypeMaskTest, NotEmptyTypeNameAndNotEmptyFieldPathsSucceeds) { + std::string type_name = "google.type.Expr"; + std::vector field_paths = {"resource.name", "resource.type"}; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_EQ(proto_type_mask.GetTypeName(), "google.type.Expr"); + EXPECT_THAT(proto_type_mask.GetFieldPaths(), + UnorderedElementsAre(FieldPath("resource.name"), + FieldPath("resource.type"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithEmptyTypeReturnsError) { + ProtoTypeMask proto_type_mask("", {}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type '' not found"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithUnknownTypeReturnsError) { + ProtoTypeMask proto_type_mask("com.example.UnknownType", {}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type 'com.example.UnknownType' not found"))); +} + +TEST(ProtoTypeMaskTest, + GetFieldNamesWithEmptySetFieldPathSucceedsAndReturnsEmptySet) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", {}); + ASSERT_OK_AND_ASSIGN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get())); + EXPECT_THAT(field_names, IsEmpty()); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithEmptyFieldPathReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {""}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithDelimiterFieldPathReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "."}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithUnknownFieldReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"unknown_field"}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskTest, + GetFieldNamesWithValidFieldsSucceedsAndReturnsFieldNames) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "single_string"}); + ASSERT_OK_AND_ASSIGN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get())); + EXPECT_THAT(field_names, + UnorderedElementsAre("single_int32", "single_string")); +} + +TEST(ProtoTypeMaskTest, + GetFieldNamesWithValidFieldPathsSucceedsAndReturnsFieldNames) { + ProtoTypeMask proto_type_mask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32", + "child.any_field_name"}); + ASSERT_OK_AND_ASSIGN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get())); + EXPECT_THAT(field_names, UnorderedElementsAre("payload", "child")); +} + +TEST(ProtoTypeMaskTest, AbslStringifyPrintsTypeNameAndFieldPaths) { + std::string type_name = "google.type.Expr"; + std::vector field_paths = {"resource.name", "resource.type"}; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_THAT(absl::StrCat(proto_type_mask), + HasSubstr("ProtoTypeMask { type name: 'google.type.Expr', field " + "paths: { 'resource.name', 'resource.type' } }")); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/test_ast_helpers.cc b/checker/internal/test_ast_helpers.cc index 6ef7c2c05..543f70a89 100644 --- a/checker/internal/test_ast_helpers.cc +++ b/checker/internal/test_ast_helpers.cc @@ -14,29 +14,31 @@ #include "checker/internal/test_ast_helpers.h" #include -#include +#include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/ast.h" -#include "extensions/protobuf/ast_converters.h" #include "internal/status_macros.h" #include "parser/options.h" #include "parser/parser.h" +#include "parser/parser_interface.h" namespace cel::checker_internal { -using ::cel::extensions::CreateAstFromParsedExpr; -using ::google::api::expr::parser::Parse; - absl::StatusOr> MakeTestParsedAst( absl::string_view expression) { - static ParserOptions options; - options.enable_optional_syntax = true; - CEL_ASSIGN_OR_RETURN(auto parsed, - Parse(expression, /*description=*/expression, options)); + static const cel::Parser* parser = []() { + cel::ParserOptions options = {.enable_optional_syntax = true}; + auto parser = NewParserBuilder(options)->Build(); + ABSL_CHECK_OK(parser); + return parser->release(); + }(); - return CreateAstFromParsedExpr(std::move(parsed)); + CEL_ASSIGN_OR_RETURN( + auto source, + cel::NewSource(expression, /*description=*/std::string(expression))); + return parser->Parse(*source); } } // namespace cel::checker_internal diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index d856a7230..8dc83518d 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -16,6 +16,7 @@ #include #include +#include #include "absl/base/nullability.h" #include "absl/status/statusor.h" @@ -28,110 +29,58 @@ #include "common/type_introspector.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" -#include "google/protobuf/descriptor.h" namespace cel::checker_internal { const VariableDecl* absl_nullable TypeCheckEnv::LookupVariable( absl::string_view name) const { - const TypeCheckEnv* scope = this; - while (scope != nullptr) { - if (auto it = scope->variables_.find(name); it != scope->variables_.end()) { - return &it->second; - } - scope = scope->parent_; + if (auto it = variables_.find(name); it != variables_.end()) { + return &it->second; } return nullptr; } const FunctionDecl* absl_nullable TypeCheckEnv::LookupFunction( absl::string_view name) const { - const TypeCheckEnv* scope = this; - while (scope != nullptr) { - if (auto it = scope->functions_.find(name); it != scope->functions_.end()) { - return &it->second; - } - scope = scope->parent_; + if (auto it = functions_.find(name); it != functions_.end()) { + return &it->second; } + return nullptr; } -absl::StatusOr> TypeCheckEnv::LookupTypeName( +absl::StatusOr> TypeCheckEnv::LookupTypeName( absl::string_view name) const { - { - // Check the descriptor pool first, then fallback to custom type providers. - const google::protobuf::Descriptor* absl_nullable descriptor = - descriptor_pool_->FindMessageTypeByName(name); - if (descriptor != nullptr) { - return Type::Message(descriptor); - } - const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = - descriptor_pool_->FindEnumTypeByName(name); - if (enum_descriptor != nullptr) { - return Type::Enum(enum_descriptor); + for (auto iter = type_providers_.begin(); iter != type_providers_.end(); + ++iter) { + CEL_ASSIGN_OR_RETURN(auto type, (*iter)->FindType(name)); + if (type.has_value()) { + return type; } } - const TypeCheckEnv* scope = this; - do { - for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); - ++iter) { - auto type = (*iter)->FindType(name); - if (!type.ok() || type->has_value()) { - return type; - } - } - scope = scope->parent_; - } while ((scope != nullptr)); - return absl::nullopt; + return std::nullopt; } -absl::StatusOr> TypeCheckEnv::LookupEnumConstant( +absl::StatusOr> TypeCheckEnv::LookupEnumConstant( absl::string_view type, absl::string_view value) const { - { - // Check the descriptor pool first, then fallback to custom type providers. - const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = - descriptor_pool_->FindEnumTypeByName(type); - if (enum_descriptor != nullptr) { - const google::protobuf::EnumValueDescriptor* absl_nullable enum_value_descriptor = - enum_descriptor->FindValueByName(value); - if (enum_value_descriptor == nullptr) { - return absl::nullopt; - } - auto decl = - MakeVariableDecl(absl::StrCat(enum_descriptor->full_name(), ".", - enum_value_descriptor->name()), - Type::Enum(enum_descriptor)); - decl.set_value( - Constant(static_cast(enum_value_descriptor->number()))); + for (auto iter = type_providers_.begin(); iter != type_providers_.end(); + ++iter) { + CEL_ASSIGN_OR_RETURN(auto enum_constant, + (*iter)->FindEnumConstant(type, value)); + if (enum_constant.has_value()) { + auto decl = MakeVariableDecl(absl::StrCat(enum_constant->type_full_name, + ".", enum_constant->value_name), + enum_constant->type); + decl.set_value(Constant(static_cast(enum_constant->number))); return decl; } } - const TypeCheckEnv* scope = this; - do { - for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); - ++iter) { - auto enum_constant = (*iter)->FindEnumConstant(type, value); - if (!enum_constant.ok()) { - return enum_constant.status(); - } - if (enum_constant->has_value()) { - auto decl = - MakeVariableDecl(absl::StrCat((**enum_constant).type_full_name, ".", - (**enum_constant).value_name), - (**enum_constant).type); - decl.set_value( - Constant(static_cast((**enum_constant).number))); - return decl; - } - } - scope = scope->parent_; - } while (scope != nullptr); - return absl::nullopt; + return std::nullopt; } -absl::StatusOr> TypeCheckEnv::LookupTypeConstant( +absl::StatusOr> TypeCheckEnv::LookupTypeConstant( google::protobuf::Arena* absl_nonnull arena, absl::string_view name) const { - CEL_ASSIGN_OR_RETURN(absl::optional type, LookupTypeName(name)); + CEL_ASSIGN_OR_RETURN(std::optional type, LookupTypeName(name)); if (type.has_value()) { return MakeVariableDecl(type->name(), TypeType(arena, *type)); } @@ -143,45 +92,28 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( return LookupEnumConstant(enum_name_candidate, value_name_candidate); } - return absl::nullopt; + return std::nullopt; } -absl::StatusOr> TypeCheckEnv::LookupStructField( +absl::StatusOr> TypeCheckEnv::LookupStructField( absl::string_view type_name, absl::string_view field_name) const { - { - // Check the descriptor pool first, then fallback to custom type providers. - const google::protobuf::Descriptor* absl_nullable descriptor = - descriptor_pool_->FindMessageTypeByName(type_name); - if (descriptor != nullptr) { - const google::protobuf::FieldDescriptor* absl_nullable field_descriptor = - descriptor->FindFieldByName(field_name); - if (field_descriptor == nullptr) { - field_descriptor = descriptor_pool_->FindExtensionByPrintableName( - descriptor, field_name); - if (field_descriptor == nullptr) { - return absl::nullopt; - } - } - return cel::MessageTypeField(field_descriptor); - } + if (proto_type_mask_registry_ != nullptr && + !proto_type_mask_registry_->FieldIsVisible(type_name, field_name)) { + return std::nullopt; } - const TypeCheckEnv* scope = this; - do { - // Check the type providers in reverse registration order. - // Note: this doesn't allow for shadowing a type with a subset type of the - // same name -- the parent type provider will still be considered when - // checking field accesses. - for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); - ++iter) { - auto field_info = - (*iter)->FindStructTypeFieldByName(type_name, field_name); - if (!field_info.ok() || field_info->has_value()) { - return field_info; - } + // Check the type providers in registration order. + // Note: this doesn't allow for shadowing a type with a subset type of the + // same name -- the later type provider will still be considered when + // checking field accesses. + for (auto iter = type_providers_.begin(); iter != type_providers_.end(); + ++iter) { + CEL_ASSIGN_OR_RETURN( + auto field, (*iter)->FindStructTypeFieldByName(type_name, field_name)); + if (field.has_value()) { + return field; } - scope = scope->parent_; - } while (scope != nullptr); - return absl::nullopt; + } + return std::nullopt; } const VariableDecl* absl_nullable VariableScope::LookupLocalVariable( diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index a4d242fdf..00fea0ba3 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -23,15 +23,22 @@ #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "checker/internal/descriptor_pool_type_introspector.h" +#include "checker/internal/proto_type_mask.h" +#include "checker/internal/proto_type_mask_registry.h" #include "common/constant.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" +#include "internal/status_macros.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" @@ -90,27 +97,32 @@ class TypeCheckEnv { absl_nonnull std::shared_ptr descriptor_pool) : descriptor_pool_(std::move(descriptor_pool)), - container_(""), - parent_(nullptr) {} - - TypeCheckEnv(absl_nonnull std::shared_ptr - descriptor_pool, - std::shared_ptr arena) - : descriptor_pool_(std::move(descriptor_pool)), - arena_(std::move(arena)), - container_(""), - parent_(nullptr) {} + proto_type_introspector_( + std::make_shared( + descriptor_pool_.get())) { + type_providers_.push_back( + std::make_shared()); + type_providers_.push_back(proto_type_introspector_); + } - // Move-only. + TypeCheckEnv(const TypeCheckEnv&) = default; + TypeCheckEnv& operator=(const TypeCheckEnv&) = default; TypeCheckEnv(TypeCheckEnv&&) = default; TypeCheckEnv& operator=(TypeCheckEnv&&) = default; - const std::string& container() const { return container_; } + const ExpressionContainer& container() const { return container_; } - void set_container(std::string container) { + void set_container(ExpressionContainer container) { container_ = std::move(container); } + const DescriptorPoolTypeIntrospector& proto_type_introspector() const { + return *proto_type_introspector_; + } + DescriptorPoolTypeIntrospector& proto_type_introspector() { + return *proto_type_introspector_; + } + void set_expected_type(const Type& type) { expected_type_ = std::move(type); } const absl::optional& expected_type() const { return expected_type_; } @@ -146,6 +158,14 @@ class TypeCheckEnv { variables_[decl.name()] = std::move(decl); } + absl::Status CreateProtoTypeMaskRegistry( + const std::vector& proto_type_masks) { + CEL_ASSIGN_OR_RETURN(proto_type_mask_registry_, + ProtoTypeMaskRegistry::Create(descriptor_pool_.get(), + proto_type_masks)); + return absl::OkStatus(); + } + const absl::flat_hash_map& functions() const { return functions_; } @@ -163,9 +183,6 @@ class TypeCheckEnv { functions_[decl.name()] = std::move(decl); } - const TypeCheckEnv* absl_nullable parent() const { return parent_; } - void set_parent(TypeCheckEnv* parent) { parent_ = parent; } - // Returns the declaration for the given name if it is found in the current // or any parent scope. // Note: the returned declaration ptr is only valid as long as no changes are @@ -184,40 +201,43 @@ class TypeCheckEnv { absl::StatusOr> LookupTypeConstant( google::protobuf::Arena* absl_nonnull arena, absl::string_view type_name) const; - TypeCheckEnv MakeExtendedEnvironment() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - return TypeCheckEnv(this); - } - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { return descriptor_pool_.get(); } // Used to keep an arena alive if one was needed to allocate types. // - // The TypeCheckEnv does not otherwise use it. - void set_arena(std::shared_ptr arena) { + // Expected to be called exactly once if at all. + void set_arena(std::shared_ptr arena) { + ABSL_DCHECK(arena_ == nullptr || arena == arena_); arena_ = std::move(arena); } - private: - explicit TypeCheckEnv(const TypeCheckEnv* absl_nonnull parent) - : descriptor_pool_(parent->descriptor_pool_), - container_(parent != nullptr ? parent->container() : ""), - parent_(parent) {} + // Returns the arena if one was set, nullptr otherwise. + std::shared_ptr arena() const { return arena_; } + private: absl::StatusOr> LookupEnumConstant( absl::string_view type, absl::string_view value) const; absl_nonnull std::shared_ptr descriptor_pool_; + // If set, an arena was needed to allocate types in the environment. - absl_nullable std::shared_ptr arena_; - std::string container_; - const TypeCheckEnv* absl_nullable parent_; + // + // The TypeCheckEnv does not otherwise use the arena, though it may be used by + // derived TypeCheckerBuilders. + absl_nullable std::shared_ptr arena_; + ExpressionContainer container_; + + // Used to resolve fields on message types. + std::shared_ptr proto_type_introspector_; // Maps fully qualified names to declarations. absl::flat_hash_map variables_; absl::flat_hash_map functions_; + std::shared_ptr proto_type_mask_registry_; + // Type providers for custom types. std::vector> type_providers_; diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 8aa5177a5..4289fb528 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -23,17 +24,21 @@ #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/cleanup/cleanup.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "checker/internal/proto_type_mask.h" #include "checker/internal/type_check_env.h" #include "checker/internal/type_checker_impl.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" @@ -84,19 +89,75 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) { return absl::OkStatus(); } +absl::Status AddWellKnownContextDeclarationVariables( + const google::protobuf::Descriptor* absl_nonnull descriptor, + const absl::flat_hash_map>& + context_type_fields, + TypeCheckEnv& env, bool use_json_name) { + for (int i = 0; i < descriptor->field_count(); ++i) { + const google::protobuf::FieldDescriptor* field = descriptor->field(i); + // Skip fields that are hidden because of a proto type mask. + auto map_iterator = context_type_fields.find(descriptor->full_name()); + if (map_iterator != context_type_fields.end() && + !map_iterator->second.contains(field->name())) { + continue; + } + Type type = MessageTypeField(field).GetType(); + if (type.IsEnum()) { + type = IntType(); + } + absl::string_view name = field->name(); + if (use_json_name) { + name = field->json_name(); + } + if (!env.InsertVariableIfAbsent(MakeVariableDecl(name, type))) { + return absl::AlreadyExistsError( + absl::StrCat("variable '", name, + "' declared multiple times (from context declaration: '", + descriptor->full_name(), "')")); + } + } + return absl::OkStatus(); +} + absl::Status AddContextDeclarationVariables( - const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env) { - for (int i = 0; i < descriptor->field_count(); i++) { - const google::protobuf::FieldDescriptor* proto_field = descriptor->field(i); - MessageTypeField cel_field(proto_field); - Type field_type = cel_field.GetType(); - if (field_type.IsEnum()) { - field_type = IntType(); + const google::protobuf::Descriptor* absl_nonnull descriptor, + const absl::flat_hash_map>& + context_type_fields, + TypeCheckEnv& env) { + const bool use_json_name = env.proto_type_introspector().use_json_name(); + if (IsWellKnownMessageType(descriptor)) { + return AddWellKnownContextDeclarationVariables( + descriptor, context_type_fields, env, use_json_name); + } + CEL_ASSIGN_OR_RETURN(auto fields, + env.proto_type_introspector().ListFieldsForStructType( + descriptor->full_name())); + if (!fields.has_value()) { + return absl::InternalError(absl::StrCat("context declaration '", + descriptor->full_name(), + "' not found, but was expected")); + } + for (const auto& field_entry : *fields) { + Type type = field_entry.field.GetType(); + if (type.IsEnum()) { + type = IntType(); } - if (!env.InsertVariableIfAbsent( - MakeVariableDecl(cel_field.name(), field_type))) { + + absl::string_view name = field_entry.name; + + // Skip fields that are hidden because of a proto type mask. + auto map_iterator = context_type_fields.find(descriptor->full_name()); + if (map_iterator != context_type_fields.end() && + !map_iterator->second.contains(name)) { + continue; + } + + if (!env.InsertVariableIfAbsent(MakeVariableDecl(name, type))) { return absl::AlreadyExistsError( - absl::StrCat("variable '", cel_field.name(), + absl::StrCat("variable '", name, "' declared multiple times (from context declaration: '", descriptor->full_name(), "')")); } @@ -121,13 +182,13 @@ absl::StatusOr MergeFunctionDecls( return merged_decl; } -absl::optional FilterDecl(FunctionDecl decl, - const TypeCheckerSubset& subset) { +std::optional FilterDecl(FunctionDecl decl, + const TypeCheckerSubset& subset) { FunctionDecl filtered; std::string name = decl.release_name(); std::vector overloads = decl.release_overloads(); - for (const auto& ovl : overloads) { - if (subset.should_include_overload(name, ovl.id())) { + for (auto& ovl : overloads) { + if (subset.should_include_overload(name, ovl)) { absl::Status s = filtered.AddOverload(std::move(ovl)); if (!s.ok()) { // Should not be possible to construct the original decl in a way that @@ -137,7 +198,7 @@ absl::optional FilterDecl(FunctionDecl decl, } } if (filtered.overloads().empty()) { - return absl::nullopt; + return std::nullopt; } filtered.set_name(std::move(name)); return filtered; @@ -246,7 +307,7 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( for (FunctionDeclRecord& fn : config.functions) { FunctionDecl decl = std::move(fn.decl); if (subset != nullptr) { - absl::optional filtered = + std::optional filtered = FilterDecl(std::move(decl), *subset); if (!filtered.has_value()) { continue; @@ -280,7 +341,8 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( } for (const google::protobuf::Descriptor* context_type : config.context_types) { - CEL_RETURN_IF_ERROR(AddContextDeclarationVariables(context_type, env)); + CEL_RETURN_IF_ERROR(AddContextDeclarationVariables( + context_type, config.context_type_fields, env)); } for (VariableDeclRecord& var : config.variables) { @@ -302,12 +364,22 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( } } + CEL_RETURN_IF_ERROR(env.CreateProtoTypeMaskRegistry(config.proto_type_masks)); + return absl::OkStatus(); } absl::StatusOr> TypeCheckerBuilderImpl::Build() { - TypeCheckEnv env(descriptor_pool_); - env.set_container(container_); + TypeCheckEnv env(template_env_); + CEL_RETURN_IF_ERROR(ConfigureTypeCheckEnv(env)); + return std::make_unique(std::move(env), + options_); +} + +absl::Status TypeCheckerBuilderImpl::ConfigureTypeCheckEnv(TypeCheckEnv& env) { + if (expression_container_.has_value()) { + env.set_container(*expression_container_); + } if (expected_type_.has_value()) { env.set_expected_type(*expected_type_); } @@ -324,6 +396,9 @@ absl::StatusOr> TypeCheckerBuilderImpl::Build() { CEL_RETURN_IF_ERROR(BuildLibraryConfig(library, config)); } + env.proto_type_introspector().set_use_json_name( + options_.use_json_field_names); + for (const ConfigRecord& config : configs) { TypeCheckerSubset* subset = nullptr; if (!config.id.empty()) { @@ -338,12 +413,10 @@ absl::StatusOr> TypeCheckerBuilderImpl::Build() { /*subset=*/nullptr, env)); CEL_RETURN_IF_ERROR(ApplyConfig(default_config_, /*subset=*/nullptr, env)); - // A library may have been the first to initialize the arena, so we need to - // set it as the last step. - env.set_arena(arena_); - auto checker = std::make_unique( - std::move(env), options_); - return checker; + if (type_arena_ != nullptr) { + env.set_arena(type_arena_); + } + return absl::OkStatus(); } absl::Status TypeCheckerBuilderImpl::AddLibrary(CheckerLibrary library) { @@ -393,7 +466,7 @@ absl::Status TypeCheckerBuilderImpl::AddOrReplaceVariable( absl::Status TypeCheckerBuilderImpl::AddContextDeclaration( absl::string_view type) { const google::protobuf::Descriptor* desc = - descriptor_pool_->FindMessageTypeByName(type); + template_env_.descriptor_pool()->FindMessageTypeByName(type); if (desc == nullptr) { return absl::NotFoundError( absl::StrCat("context declaration '", type, "' not found")); @@ -416,6 +489,23 @@ absl::Status TypeCheckerBuilderImpl::AddContextDeclaration( return absl::OkStatus(); } +absl::Status TypeCheckerBuilderImpl::AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) { + if (field_paths.empty()) { + return absl::InvalidArgumentError("field paths cannot be the empty set"); + } + + ProtoTypeMask proto_type_mask(std::string(type), field_paths); + target_config_->proto_type_masks.push_back(proto_type_mask); + + CEL_RETURN_IF_ERROR(AddContextDeclaration(type)); + CEL_ASSIGN_OR_RETURN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(template_env_.descriptor_pool())); + target_config_->context_type_fields.insert({type, std::move(field_names)}); + return absl::OkStatus(); +} + absl::Status TypeCheckerBuilderImpl::AddFunction(const FunctionDecl& decl) { CEL_RETURN_IF_ERROR( ValidateFunctionDecl(decl, options_.enable_type_parameter_name_validation, @@ -440,7 +530,15 @@ void TypeCheckerBuilderImpl::AddTypeProvider( } void TypeCheckerBuilderImpl::set_container(absl::string_view container) { - container_ = container; + if (!expression_container_.has_value()) { + expression_container_.emplace(); + } + expression_container_->SetContainer(container).IgnoreError(); +} + +void TypeCheckerBuilderImpl::SetExpressionContainer( + ExpressionContainer container) { + expression_container_ = std::move(container); } void TypeCheckerBuilderImpl::SetExpectedType(const Type& type) { diff --git a/checker/internal/type_checker_builder_impl.h b/checker/internal/type_checker_builder_impl.h index 3b3472232..9895a8aee 100644 --- a/checker/internal/type_checker_builder_impl.h +++ b/checker/internal/type_checker_builder_impl.h @@ -21,6 +21,7 @@ #include #include "absl/base/nullability.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" @@ -28,9 +29,11 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "checker/checker_options.h" +#include "checker/internal/proto_type_mask.h" #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" @@ -39,8 +42,6 @@ namespace cel::checker_internal { -class TypeCheckerBuilderImpl; - // Builder for TypeChecker instances. class TypeCheckerBuilderImpl : public TypeCheckerBuilder { public: @@ -50,7 +51,18 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { const CheckerOptions& options) : options_(options), target_config_(&default_config_), - descriptor_pool_(std::move(descriptor_pool)) {} + template_env_(std::move(descriptor_pool)) {} + + // Constructor for building an extended TypeChecker. + explicit TypeCheckerBuilderImpl(const CheckerOptions& options, + const TypeCheckEnv& template_env) + : options_(options), + target_config_(&default_config_), + template_env_(template_env) { + if (auto arena = template_env_.arena(); arena != nullptr) { + type_arena_ = std::move(arena); + } + } // Move only. TypeCheckerBuilderImpl(const TypeCheckerBuilderImpl&) = delete; @@ -66,6 +78,8 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { absl::Status AddVariable(const VariableDecl& decl) override; absl::Status AddOrReplaceVariable(const VariableDecl& decl) override; absl::Status AddContextDeclaration(absl::string_view type) override; + absl::Status AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) override; absl::Status AddFunction(const FunctionDecl& decl) override; absl::Status MergeFunction(const FunctionDecl& decl) override; @@ -76,17 +90,20 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { void set_container(absl::string_view container) override; + void SetExpressionContainer( + ExpressionContainer expression_container) override; + const CheckerOptions& options() const override { return options_; } google::protobuf::Arena* absl_nonnull arena() override { - if (arena_ == nullptr) { - arena_ = std::make_shared(); + if (type_arena_ == nullptr) { + type_arena_ = std::make_shared(); } - return arena_.get(); + return type_arena_.get(); } const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const override { - return descriptor_pool_.get(); + return template_env_.descriptor_pool(); } private: @@ -117,6 +134,11 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { std::vector functions; std::vector> type_providers; std::vector context_types; + // Maps context type names to fields names to add as variables. + // Only includes context types that are defined with proto type masks. + absl::flat_hash_map> + context_type_fields; + std::vector proto_type_masks; }; absl::Status BuildLibraryConfig(const CheckerLibrary& library, @@ -125,6 +147,8 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { absl::Status ApplyConfig(ConfigRecord config, const TypeCheckerSubset* subset, TypeCheckEnv& env); + absl::Status ConfigureTypeCheckEnv(TypeCheckEnv& env); + CheckerOptions options_; // Default target for configuration changes. Used for direct calls to // AddVariable, AddFunction, etc. @@ -132,12 +156,12 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { // Active target for configuration changes. // This is used to track which library the change is made on behalf of. ConfigRecord* absl_nonnull target_config_; - std::shared_ptr descriptor_pool_; - std::shared_ptr arena_; + TypeCheckEnv template_env_; + std::shared_ptr type_arena_; std::vector libraries_; absl::flat_hash_map subsets_; absl::flat_hash_set library_ids_; - std::string container_; + absl::optional expression_container_; absl::optional expected_type_; }; diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc index e23c26165..fa7f80960 100644 --- a/checker/internal/type_checker_builder_impl_test.cc +++ b/checker/internal/type_checker_builder_impl_test.cc @@ -15,12 +15,15 @@ #include "checker/internal/type_checker_builder_impl.h" #include +#include #include #include +#include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "checker/checker_options.h" @@ -42,7 +45,6 @@ namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; - struct ContextDeclsTestCase { std::string expr; TypeSpec expected_type; @@ -108,6 +110,168 @@ INSTANTIATE_TEST_SUITE_P( MapTypeSpec(std::make_unique(PrimitiveType::kString), std::make_unique(DynTypeSpec())))})); +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnEmptyFieldPaths) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {}), + StatusIs(absl::StatusCode::kInvalidArgument, + "field paths cannot be the empty set")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnUnknownFieldPath) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"unknown_field"}), + StatusIs(absl::StatusCode::kInvalidArgument, + "could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'")); +} + +class ContextDeclsWithProtoTypeMaskFieldsDefinedTest + : public testing::TestWithParam {}; + +std::string LogFieldName(absl::string_view field_name, absl::string_view expr) { + return absl::StrCat("field_name: ", field_name, ", expr: ", expr); +} + +TEST_P(ContextDeclsWithProtoTypeMaskFieldsDefinedTest, + ContextDeclsWithProtoTypeMaskFieldsDefined) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {GetParam().expr}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + std::vector field_names = { + "single_int64", "single_uint32", "single_double", + "single_string", "single_any", "single_duration", + "single_bool_wrapper", "list_value", "standalone_message", + "standalone_enum", "repeated_bytes", "repeated_nested_message", + "map_int32_timestamp", "single_struct"}; + for (auto& field_name : field_names) { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(field_name)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + if (field_name == GetParam().expr) { + // The field name that is part of the proto type mask is visible. + ASSERT_TRUE(result.IsValid()) + << LogFieldName(field_name, GetParam().expr); + EXPECT_EQ(result.GetAst()->GetReturnType(), GetParam().expected_type) + << LogFieldName(field_name, GetParam().expr); + } else { + // The field names that are not part of the proto type mask are not + // visible. + EXPECT_FALSE(result.IsValid()) + << LogFieldName(field_name, GetParam().expr); + } + } +} + +INSTANTIATE_TEST_SUITE_P( + TestAllTypes, ContextDeclsWithProtoTypeMaskFieldsDefinedTest, + testing::Values( + ContextDeclsTestCase{"single_int64", TypeSpec(PrimitiveType::kInt64)}, + ContextDeclsTestCase{"single_uint32", TypeSpec(PrimitiveType::kUint64)}, + ContextDeclsTestCase{"single_double", TypeSpec(PrimitiveType::kDouble)}, + ContextDeclsTestCase{"single_string", TypeSpec(PrimitiveType::kString)}, + ContextDeclsTestCase{"single_any", TypeSpec(WellKnownTypeSpec::kAny)}, + ContextDeclsTestCase{"single_duration", + TypeSpec(WellKnownTypeSpec::kDuration)}, + ContextDeclsTestCase{ + "single_bool_wrapper", + TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool))}, + ContextDeclsTestCase{ + "list_value", + TypeSpec(ListTypeSpec(std::make_unique(DynTypeSpec())))}, + ContextDeclsTestCase{ + "standalone_message", + TypeSpec(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))}, + ContextDeclsTestCase{"standalone_enum", + TypeSpec(PrimitiveType::kInt64)}, + ContextDeclsTestCase{"repeated_bytes", + TypeSpec(ListTypeSpec(std::make_unique( + PrimitiveType::kBytes)))}, + ContextDeclsTestCase{ + "repeated_nested_message", + TypeSpec(ListTypeSpec(std::make_unique(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))))}, + ContextDeclsTestCase{ + "map_int32_timestamp", + TypeSpec(MapTypeSpec( + std::make_unique(PrimitiveType::kInt64), + std::make_unique(WellKnownTypeSpec::kTimestamp)))}, + ContextDeclsTestCase{ + "single_struct", + TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec())))})); + +TEST(ContextDeclsWithProtoTypeMaskTest, FieldsInMaskAreVisibleFieldAccess) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + // Visible field: standalone_message.bb + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("payload.standalone_message.bb")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + EXPECT_EQ(result.GetAst()->GetReturnType(), TypeSpec(PrimitiveType::kInt64)); + // Visible field: single_int32 + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int32")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + EXPECT_EQ(result.GetAst()->GetReturnType(), TypeSpec(PrimitiveType::kInt64)); + // Not Visible field: single_int64 + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int64")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, FieldsInMaskAreVisibleFieldAssignment) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + // Visible field: standalone_message.bb + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAst( + R"(cel.expr.conformance.proto3.TestAllTypes.NestedMessage{bb: 12345})")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + // Visible field: single_int32 + ASSERT_OK_AND_ASSIGN( + ast, + MakeTestParsedAst( + R"(cel.expr.conformance.proto3.TestAllTypes{single_int32: 12345})")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + // Not Visible field: single_int64 + ASSERT_OK_AND_ASSIGN( + ast, + MakeTestParsedAst( + R"(cel.expr.conformance.proto3.TestAllTypes{single_int64: 12345})")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + TEST(ContextDeclsTest, ErrorOnDuplicateContextDeclaration) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -121,6 +285,20 @@ TEST(ContextDeclsTest, ErrorOnDuplicateContextDeclaration) { "already exists")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnDuplicateContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message"}), + IsOk()); + EXPECT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + StatusIs(absl::StatusCode::kAlreadyExists, + "context declaration 'cel.expr.conformance.proto3.TestAllTypes' " + "already exists")); +} + TEST(ContextDeclsTest, ErrorOnContextDeclarationNotFound) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -130,6 +308,16 @@ TEST(ContextDeclsTest, ErrorOnContextDeclarationNotFound) { "context declaration 'com.example.UnknownType' not found")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnContextDeclarationNotFound) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclarationWithProtoTypeMask("com.example.UnknownType", + {"any_field_name"}), + StatusIs(absl::StatusCode::kNotFound, + "context declaration 'com.example.UnknownType' not found")); +} + TEST(ContextDeclsTest, ErrorOnNonStructMessageType) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -140,17 +328,28 @@ TEST(ContextDeclsTest, ErrorOnNonStructMessageType) { "context declaration 'google.protobuf.Timestamp' is not a struct")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnNonStructMessageType) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "google.protobuf.Timestamp", {"any_field_name"}), + StatusIs( + absl::StatusCode::kInvalidArgument, + "context declaration 'google.protobuf.Timestamp' is not a struct")); +} + TEST(ContextDeclsTest, CustomStructNotSupported) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); class MyTypeProvider : public cel::TypeIntrospector { public: - absl::StatusOr> FindTypeImpl( + absl::StatusOr> FindTypeImpl( absl::string_view name) const override { if (name == "com.example.MyStruct") { return common_internal::MakeBasicStructType("com.example.MyStruct"); } - return absl::nullopt; + return std::nullopt; } }; @@ -161,6 +360,28 @@ TEST(ContextDeclsTest, CustomStructNotSupported) { "context declaration 'com.example.MyStruct' not found")); } +TEST(ContextDeclsWithProtoTypeMaskTest, CustomStructNotSupported) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + class MyTypeProvider : public cel::TypeIntrospector { + public: + absl::StatusOr> FindTypeImpl( + absl::string_view name) const override { + if (name == "com.example.MyStruct") { + return common_internal::MakeBasicStructType("com.example.MyStruct"); + } + return std::nullopt; + } + }; + + builder.AddTypeProvider(std::make_unique()); + + EXPECT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "com.example.MyStruct", {"any_field_name"}), + StatusIs(absl::StatusCode::kNotFound, + "context declaration 'com.example.MyStruct' not found")); +} + TEST(ContextDeclsTest, ErrorOnOverlappingContextDeclaration) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -180,6 +401,69 @@ TEST(ContextDeclsTest, ErrorOnOverlappingContextDeclaration) { "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnOverlappingContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.TestAllTypes", {"single_int32"}), + IsOk()); + + EXPECT_THAT( + builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int32' declared multiple times (from context " + "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, + ErrorOnOverlappingContextDeclarationBothProtoTypeMasks) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.TestAllTypes", {"single_int32"}), + IsOk()); + + EXPECT_THAT( + builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int32' declared multiple times (from context " + "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, + NonOverlappingContextDeclarationBothProtoTypeMasks) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.NestedTestAllTypes", + {"payload.single_int64"}), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("single_int32")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int64")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); +} + TEST(ContextDeclsTest, ErrorOnOverlappingVariableDeclaration) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -194,6 +478,32 @@ TEST(ContextDeclsTest, ErrorOnOverlappingVariableDeclaration) { "variable 'single_int64' declared multiple times")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnOverlappingVariableDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int64"}), + IsOk()); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())), + IsOk()); + + EXPECT_THAT(builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int64' declared multiple times")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, NonOverlappingVariableDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())), + IsOk()); + + EXPECT_THAT(builder.Build(), IsOk()); +} + TEST(TypeCheckerBuilderImplTest, InvalidTypeParamNameVariableValidationDisabled) { CheckerOptions options; diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 14dce1647..bca187417 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -24,6 +25,7 @@ #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -34,11 +36,12 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "checker/checker_options.h" -#include "checker/internal/format_type_name.h" #include "checker/internal/namespace_generator.h" #include "checker/internal/type_check_env.h" +#include "checker/internal/type_checker_builder_impl.h" #include "checker/internal/type_inference_context.h" #include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" #include "common/ast_rewrite.h" @@ -48,7 +51,8 @@ #include "common/constant.h" #include "common/decl.h" #include "common/expr.h" -#include "common/source.h" +#include "common/format_type_name.h" +#include "common/standard_definitions.h" #include "common/type.h" #include "common/type_kind.h" #include "internal/status_macros.h" @@ -57,6 +61,15 @@ namespace cel::checker_internal { namespace { +bool MatchesBlock(const Expr& expr) { + if (!expr.has_call_expr()) { + return false; + } + const auto& call = expr.call_expr(); + return call.function() == "cel.@block" && call.args().size() == 2 && + call.args()[0].has_list_expr(); +} + using AstType = cel::TypeSpec; using Severity = TypeCheckIssue::Severity; @@ -66,58 +79,6 @@ std::string FormatCandidate(absl::Span qualifiers) { return absl::StrJoin(qualifiers, "."); } -SourceLocation ComputeSourceLocation(const Ast& ast, int64_t expr_id) { - const auto& source_info = ast.source_info(); - auto iter = source_info.positions().find(expr_id); - if (iter == source_info.positions().end()) { - return SourceLocation{}; - } - int32_t absolute_position = iter->second; - if (absolute_position < 0) { - return SourceLocation{}; - } - - // Find the first line offset that is greater than the absolute position. - int32_t line_idx = -1; - int32_t offset = 0; - for (int32_t i = 0; i < source_info.line_offsets().size(); ++i) { - int32_t next_offset = source_info.line_offsets()[i]; - if (next_offset <= offset) { - // Line offset is not monotonically increasing, so line information is - // invalid. - return SourceLocation{}; - } - if (absolute_position < next_offset) { - line_idx = i; - break; - } - offset = next_offset; - } - - if (line_idx < 0 || line_idx >= source_info.line_offsets().size()) { - return SourceLocation{}; - } - - int32_t rel_position = absolute_position - offset; - - return SourceLocation{line_idx + 1, rel_position}; -} - -// Special case for protobuf null fields. -bool IsPbNullFieldAssignable(const Type& value, const Type& field) { - if (field.IsNull()) { - return value.IsInt() || value.IsNull(); - } - - if (field.IsOptional() && value.IsOptional() && - field.AsOptional()->GetParameter().IsNull()) { - auto value_param = value.AsOptional()->GetParameter(); - return value_param.IsInt() || value_param.IsNull(); - } - - return false; -} - // Flatten the type to the AST type representation to remove any lifecycle // dependency between the type check environment and the AST. // @@ -238,16 +199,15 @@ class ResolveVisitor : public AstVisitorBase { struct AttributeResolution { const VariableDecl* decl; bool requires_disambiguation; + bool local; }; - ResolveVisitor(absl::string_view container, - NamespaceGenerator namespace_generator, + ResolveVisitor(NamespaceGenerator namespace_generator, const TypeCheckEnv& env, const Ast& ast, TypeInferenceContext& inference_context, std::vector& issues, google::protobuf::Arena* absl_nonnull arena) - : container_(container), - namespace_generator_(std::move(namespace_generator)), + : namespace_generator_(std::move(namespace_generator)), env_(&env), inference_context_(&inference_context), issues_(&issues), @@ -256,13 +216,23 @@ class ResolveVisitor : public AstVisitorBase { arena_(arena), current_scope_(&root_scope_) {} - void PreVisitExpr(const Expr& expr) override { expr_stack_.push_back(&expr); } + void PreVisitExpr(const Expr& expr) override { + expr_stack_.push_back(&expr); + if (expr_stack_.size() == 1 && MatchesBlock(expr)) { + ABSL_DCHECK_EQ(expr.call_expr().args().size(), 2); + ABSL_DCHECK(block_init_list_ == nullptr); + block_init_list_ = &expr.call_expr().args()[0]; + } + } void PostVisitExpr(const Expr& expr) override { if (expr_stack_.empty()) { return; } expr_stack_.pop_back(); + if (expr_stack_.size() == 2 && expr_stack_.back() == block_init_list_) { + HandleBlockIndex(&expr); + } } void PostVisitConst(const Expr& expr, const Constant& constant) override; @@ -377,15 +347,15 @@ class ResolveVisitor : public AstVisitorBase { void ReportMissingReference(const Expr& expr, absl::string_view name) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("undeclared reference to '", name, "' (in container '", - container_, "')"))); + env_->container().container(), "')"))); } void ReportUndefinedField(int64_t expr_id, absl::string_view field_name, absl::string_view struct_name) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr_id), + ast_->ComputeSourceLocation(expr_id), absl::StrCat("undefined field '", field_name, "' not found in struct '", struct_name, "'"))); } @@ -393,7 +363,7 @@ class ResolveVisitor : public AstVisitorBase { void ReportTypeMismatch(int64_t expr_id, const Type& expected, const Type& actual) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr_id), + ast_->ComputeSourceLocation(expr_id), absl::StrCat("expected type '", FormatTypeName(inference_context_->FinalizeType(expected)), "' but found '", @@ -411,7 +381,7 @@ class ResolveVisitor : public AstVisitorBase { // Lookup message type by name to support WellKnownType creation. CEL_ASSIGN_OR_RETURN( - absl::optional field_info, + std::optional field_info, env_->LookupStructField(resolved_name, field.name())); if (!field_info.has_value()) { ReportUndefinedField(field.id(), field.name(), resolved_name); @@ -421,10 +391,9 @@ class ResolveVisitor : public AstVisitorBase { if (field.optional()) { field_type = OptionalType(arena_, field_type); } - if (!inference_context_->IsAssignable(value_type, field_type) && - !IsPbNullFieldAssignable(value_type, field_type)) { + if (!inference_context_->IsAssignable(value_type, field_type)) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, field.id()), + ast_->ComputeSourceLocation(field.id()), absl::StrCat( "expected type of field '", field_info->name(), "' is '", FormatTypeName(inference_context_->FinalizeType(field_type)), @@ -438,10 +407,11 @@ class ResolveVisitor : public AstVisitorBase { return absl::OkStatus(); } - absl::optional CheckFieldType(int64_t expr_id, const Type& operand_type, - absl::string_view field_name); + std::optional CheckFieldType(int64_t expr_id, const Type& operand_type, + absl::string_view field_name); void HandleOptSelect(const Expr& expr); + void HandleBlockIndex(const Expr* expr); // Get the assigned type of the given subexpression. Should only be called if // the given subexpression is expected to have already been checked. @@ -461,7 +431,6 @@ class ResolveVisitor : public AstVisitorBase { return DynType(); } - absl::string_view container_; NamespaceGenerator namespace_generator_; const TypeCheckEnv* absl_nonnull env_; TypeInferenceContext* absl_nonnull inference_context_; @@ -475,6 +444,7 @@ class ResolveVisitor : public AstVisitorBase { std::vector expr_stack_; absl::flat_hash_map> maybe_namespaced_functions_; + const Expr* block_init_list_ = nullptr; // Select operations that need to be resolved outside of the traversal. // These are handled separately to disambiguate between namespaces and field // accesses @@ -569,7 +539,7 @@ void ResolveVisitor::PostVisitConst(const Expr& expr, break; default: ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("unsupported constant type: ", constant.kind().index()))); types_[&expr] = ErrorType(); @@ -621,7 +591,7 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { // To match the Go implementation, we just warn here, but in the future // we should consider making this an error. ReportIssue(TypeCheckIssue( - Severity::kWarning, ComputeSourceLocation(*ast_, key->id()), + Severity::kWarning, ast_->ComputeSourceLocation(key->id()), absl::StrCat( "unsupported map key type: ", FormatTypeName(inference_context_->FinalizeType(key_type))))); @@ -663,8 +633,15 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { } void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) { - // Follows list type inferencing behavior in Go (see map comments above). + if (&expr == block_init_list_) { + // Don't try to coalesce list type here because it can influence the + // resolved type of the list elements. cel.@block is always list and + // the elements are treated independently at runtime. + types_[&expr] = ListType(); + return; + } + // Follows list type inferencing behavior in Go (see map comments above). Type overall_elem_type = inference_context_->InstantiateTypeParams(TypeParamType("E")); auto assignability_context = inference_context_->CreateAssignabilityContext(); @@ -727,7 +704,7 @@ void ResolveVisitor::PostVisitStruct(const Expr& expr, if (resolved_type.kind() != TypeKind::kStruct && !IsWellKnownMessageType(resolved_name)) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("type '", resolved_name, "' does not support message creation"))); types_[&expr] = ErrorType(); @@ -878,7 +855,7 @@ void ResolveVisitor::PostVisitComprehensionSubexpression( break; default: ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, comprehension.iter_range().id()), + ast_->ComputeSourceLocation(comprehension.iter_range().id()), absl::StrCat( "expression of type '", FormatTypeName(inference_context_->FinalizeType(range_type)), @@ -919,8 +896,12 @@ const FunctionDecl* ResolveVisitor::ResolveFunctionCallShape( if (decl == nullptr) { return true; } + bool is_logical_op = (candidate == cel::StandardFunctions::kAnd || + candidate == cel::StandardFunctions::kOr) && + arg_count >= 2; for (const auto& ovl : decl->overloads()) { - if (ovl.member() == is_receiver && ovl.args().size() == arg_count) { + if (ovl.member() == is_receiver && + (ovl.args().size() == arg_count || is_logical_op)) { return false; } } @@ -944,12 +925,12 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, arg_types.push_back(GetDeducedType(&expr.call_expr().args()[i])); } - absl::optional resolution = + std::optional resolution = inference_context_->ResolveOverload(decl, arg_types, is_receiver); if (!resolution.has_value()) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("found no matching overload for '", decl.name(), "' applied to '(", absl::StrJoin(arg_types, ", ", @@ -979,11 +960,10 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, const VariableDecl* absl_nullable ResolveVisitor::LookupLocalIdentifier( absl::string_view name) { - // Note: if we see a leading dot, this shouldn't resolve to a local variable, - // but we need to check whether we need to disambiguate against a global in - // the reference map. if (absl::StartsWith(name, ".")) { - name = name.substr(1); + // Should not happen for normally parsed CEL, but prevent lookup in case + // of hand-crafted ASTs. + return nullptr; } return current_scope_->LookupLocalVariable(name); } @@ -993,7 +973,7 @@ const VariableDecl* absl_nullable ResolveVisitor::LookupGlobalIdentifier( if (const VariableDecl* decl = env_->LookupVariable(name); decl != nullptr) { return decl; } - absl::StatusOr> constant = + absl::StatusOr> constant = env_->LookupTypeConstant(arena_, name); if (!constant.ok()) { @@ -1018,13 +998,15 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, absl::string_view name) { // Local variables (comprehension, bind) are simple identifiers so we can // skip generating the different namespace-qualified candidates. - const VariableDecl* local_decl = LookupLocalIdentifier(name); - - if (local_decl != nullptr && !absl::StartsWith(name, ".")) { - attributes_[&expr] = {local_decl, false}; - types_[&expr] = - inference_context_->InstantiateTypeParams(local_decl->type()); - return; + if (!absl::StartsWith(name, ".")) { + const VariableDecl* local_decl = LookupLocalIdentifier(name); + if (local_decl != nullptr) { + attributes_[&expr] = {local_decl, /*requires_disambiguation=*/false, + /*local=*/true}; + types_[&expr] = + inference_context_->InstantiateTypeParams(local_decl->type()); + return; + } } const VariableDecl* decl = nullptr; @@ -1035,9 +1017,13 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, return decl == nullptr; }); + bool requires_disambiguation = false; + if (absl::StartsWith(name, ".")) { + requires_disambiguation = LookupLocalIdentifier(name.substr(1)) != nullptr; + } + if (decl != nullptr) { - attributes_[&expr] = {decl, - /* requires_disambiguation= */ local_decl != nullptr}; + attributes_[&expr] = {decl, requires_disambiguation, /*local=*/false}; types_[&expr] = inference_context_->InstantiateTypeParams(decl->type()); return; } @@ -1053,35 +1039,49 @@ void ResolveVisitor::ResolveQualifiedIdentifier( return; } + int matched_segment_index = -1; + const VariableDecl* decl = nullptr; + bool requires_disambiguation = false; + bool is_local = false; // Local variables (comprehension, bind) are simple identifiers so we can // skip generating the different namespace-qualified candidates. - const VariableDecl* local_decl = LookupLocalIdentifier(qualifiers[0]); - const VariableDecl* decl = nullptr; - - int matched_segment_index = -1; - - if (local_decl != nullptr && !absl::StartsWith(qualifiers[0], ".")) { - decl = local_decl; - matched_segment_index = 0; - } else { - namespace_generator_.GenerateCandidates( - qualifiers, [&decl, &matched_segment_index, this]( - absl::string_view candidate, int segment_index) { - decl = LookupGlobalIdentifier(candidate); - if (decl != nullptr) { - matched_segment_index = segment_index; - return false; - } - return true; - }); + if (!absl::StartsWith(qualifiers[0], ".")) { + const VariableDecl* local_decl = LookupLocalIdentifier(qualifiers[0]); + if (local_decl != nullptr) { + decl = local_decl; + matched_segment_index = 0; + is_local = true; + goto resolve_select_trail; + } } + namespace_generator_.GenerateCandidates( + qualifiers, [&decl, &matched_segment_index, this]( + absl::string_view candidate, int segment_index) { + decl = LookupGlobalIdentifier(candidate); + if (decl != nullptr) { + matched_segment_index = segment_index; + return false; + } + return true; + }); + if (decl == nullptr) { ReportMissingReference(expr, FormatCandidate(qualifiers)); types_[&expr] = ErrorType(); return; } + if (absl::StartsWith(qualifiers[0], ".")) { + const VariableDecl* local_decl = + LookupLocalIdentifier(qualifiers[0].substr(1)); + if (local_decl != nullptr) { + requires_disambiguation = true; + } + } + +resolve_select_trail: + const int num_select_opts = qualifiers.size() - matched_segment_index - 1; const Expr* root = &expr; @@ -1092,9 +1092,7 @@ void ResolveVisitor::ResolveQualifiedIdentifier( root = &root->select_expr().operand(); } - attributes_[root] = {decl, - /* requires_disambiguation= */ decl != local_decl && - local_decl != nullptr}; + attributes_[root] = {decl, requires_disambiguation, is_local}; types_[root] = inference_context_->InstantiateTypeParams(decl->type()); // fix-up select operations that were deferred. @@ -1104,9 +1102,9 @@ void ResolveVisitor::ResolveQualifiedIdentifier( } } -absl::optional ResolveVisitor::CheckFieldType(int64_t id, - const Type& operand_type, - absl::string_view field) { +std::optional ResolveVisitor::CheckFieldType(int64_t id, + const Type& operand_type, + absl::string_view field) { if (operand_type.kind() == TypeKind::kDyn || operand_type.kind() == TypeKind::kAny) { return DynType(); @@ -1118,11 +1116,11 @@ absl::optional ResolveVisitor::CheckFieldType(int64_t id, auto field_info = env_->LookupStructField(struct_type.name(), field); if (!field_info.ok()) { status_.Update(field_info.status()); - return absl::nullopt; + return std::nullopt; } if (!field_info->has_value()) { ReportUndefinedField(id, field, struct_type.name()); - return absl::nullopt; + return std::nullopt; } auto type = field_info->value().GetType(); if (type.kind() == TypeKind::kEnum) { @@ -1149,12 +1147,12 @@ absl::optional ResolveVisitor::CheckFieldType(int64_t id, } ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, id), + ast_->ComputeSourceLocation(id), absl::StrCat( "expression of type '", FormatTypeName(inference_context_->FinalizeType(operand_type)), "' cannot be the operand of a select operation"))); - return absl::nullopt; + return std::nullopt; } void ResolveVisitor::ResolveSelectOperation(const Expr& expr, @@ -1162,7 +1160,7 @@ void ResolveVisitor::ResolveSelectOperation(const Expr& expr, const Expr& operand) { const Type& operand_type = GetDeducedType(&operand); - absl::optional result_type; + std::optional result_type; int64_t id = expr.id(); // Support short-hand optional chaining. if (operand_type.IsOptional()) { @@ -1209,7 +1207,7 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) { operand_type = operand_type.GetOptional().GetParameter(); } - absl::optional field_type = CheckFieldType( + std::optional field_type = CheckFieldType( expr.id(), operand_type, field->const_expr().string_value()); if (!field_type.has_value()) { types_[&expr] = ErrorType(); @@ -1226,16 +1224,56 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) { } } +void ResolveVisitor::HandleBlockIndex(const Expr* expr) { + ABSL_DCHECK(block_init_list_ != nullptr); + ABSL_DCHECK(block_init_list_->has_list_expr()); + const auto& elements = block_init_list_->list_expr().elements(); + int index = -1; + for (size_t i = 0; i < elements.size(); ++i) { + if (&elements[i].expr() == expr) { + index = i; + break; + } + } + if (index < 0) { + status_.Update(absl::InternalError( + "could not resolve expression as a cel.@block subexpression")); + return; + } + std::string var_name = absl::StrCat("@index", index); + + // Block is typically manually assembled from logically separate + // expressions so fix the type instead of inferring any remaining free type + // params as for normal subexpressions. + auto type = inference_context_->FinalizeType(GetDeducedType(expr)); + + VariableDecl decl = MakeVariableDecl(var_name, std::move(type)); + + // The C++ runtime requires that the indexes are topologically ordered. + // They just come into scope in order as we walk the AST so we don't need + // to do any additional work to check references to other initializers in + // an init expr. + // + // TODO(uncreated-issue/90): This is slightly inconsistent with the java + // runtime implementation which just requires the references to be acyclic. + auto* scope = + comprehension_vars_.emplace_back(current_scope_->MakeNestedScope()).get(); + scope->InsertVariableIfAbsent(std::move(decl)); + current_scope_ = scope; +} + class ResolveRewriter : public AstRewriterBase { public: explicit ResolveRewriter(const ResolveVisitor& visitor, const TypeInferenceContext& inference_context, const CheckerOptions& options, - Ast::ReferenceMap& references, Ast::TypeMap& types) + Ast::ReferenceMap& references, Ast::TypeMap& types, + ValidationResult::TypeMap& resolved_types) : visitor_(visitor), inference_context_(inference_context), reference_map_(references), type_map_(types), + resolved_types_(resolved_types), options_(options) {} bool PostVisitRewrite(Expr& expr) override { bool rewritten = false; @@ -1282,14 +1320,15 @@ class ResolveRewriter : public AstRewriterBase { if (auto iter = visitor_.types().find(&expr); iter != visitor_.types().end()) { - auto flattened_type = - FlattenType(inference_context_.FinalizeType(iter->second)); + cel::Type finalized_type = inference_context_.FinalizeType(iter->second); + auto flattened_type = FlattenType(finalized_type); if (!flattened_type.ok()) { status_.Update(flattened_type.status()); return rewritten; } type_map_[expr.id()] = *std::move(flattened_type); + resolved_types_[expr.id()] = finalized_type; rewritten = true; } @@ -1304,23 +1343,28 @@ class ResolveRewriter : public AstRewriterBase { const TypeInferenceContext& inference_context_; Ast::ReferenceMap& reference_map_; Ast::TypeMap& type_map_; + ValidationResult::TypeMap& resolved_types_; const CheckerOptions& options_; }; } // namespace -absl::StatusOr TypeCheckerImpl::Check( - std::unique_ptr ast) const { - google::protobuf::Arena type_arena; +absl::StatusOr TypeCheckerImpl::CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* arena) const { + std::optional type_arena; + if (arena == nullptr) { + type_arena.emplace(); + arena = &(*type_arena); + } std::vector issues; CEL_ASSIGN_OR_RETURN(auto generator, NamespaceGenerator::Create(env_.container())); TypeInferenceContext type_inference_context( - &type_arena, options_.enable_legacy_null_assignment); - ResolveVisitor visitor(env_.container(), std::move(generator), env_, *ast, - type_inference_context, issues, &type_arena); + arena, options_.enable_legacy_null_assignment); + ResolveVisitor visitor(std::move(generator), env_, *ast, + type_inference_context, issues, arena); TraversalOptions opts; opts.use_comprehension_callbacks = true; @@ -1365,16 +1409,35 @@ absl::StatusOr TypeCheckerImpl::Check( // Apply updates as needed. // Happens in a second pass to simplify validating that pointers haven't // been invalidated by other updates. + ValidationResult::TypeMap resolved_types; ResolveRewriter rewriter(visitor, type_inference_context, options_, ast->mutable_reference_map(), - ast->mutable_type_map()); + ast->mutable_type_map(), resolved_types); AstRewrite(ast->mutable_root_expr(), rewriter); CEL_RETURN_IF_ERROR(rewriter.status()); ast->set_is_checked(true); + if (options_.use_json_field_names) { + ast->mutable_source_info().mutable_extensions().push_back( + cel::ExtensionSpec("json_name", + std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime})); + } + + auto result = ValidationResult(std::move(ast), std::move(issues)); + if (!type_arena.has_value()) { + // cel::Type values will expire after this function returns when the local + // arena is destructed. Only set the resolved type map if we're using the + // caller's arena. + result.SetResolvedTypeMap(std::move(resolved_types)); + } + + return result; +} - return ValidationResult(std::move(ast), std::move(issues)); +std::unique_ptr TypeCheckerImpl::ToBuilder() const { + return std::make_unique(options_, env_); } } // namespace cel::checker_internal diff --git a/checker/internal/type_checker_impl.h b/checker/internal/type_checker_impl.h index 1b9062ec1..9ee9a50d0 100644 --- a/checker/internal/type_checker_impl.h +++ b/checker/internal/type_checker_impl.h @@ -22,6 +22,7 @@ #include "checker/checker_options.h" #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" +#include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" #include "google/protobuf/arena.h" @@ -41,8 +42,10 @@ class TypeCheckerImpl : public TypeChecker { TypeCheckerImpl(TypeCheckerImpl&&) = delete; TypeCheckerImpl& operator=(TypeCheckerImpl&&) = delete; - absl::StatusOr Check( - std::unique_ptr ast) const override; + absl::StatusOr CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* arena) const override; + + std::unique_ptr ToBuilder() const override; private: TypeCheckEnv env_; diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index c36051376..61ef7d55b 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -26,6 +26,7 @@ #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -33,8 +34,11 @@ #include "checker/internal/test_ast_helpers.h" #include "checker/internal/type_check_env.h" #include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" +#include "common/ast_proto.h" +#include "common/container.h" #include "common/decl.h" #include "common/expr.h" #include "common/source.h" @@ -43,11 +47,15 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" +#include "parser/macro_registry.h" +#include "parser/parser.h" #include "testutil/baseline_tests.h" +#include "testutil/test_macros.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" namespace cel { namespace checker_internal { @@ -74,6 +82,7 @@ using AstType = cel::TypeSpec; using Severity = TypeCheckIssue::Severity; namespace testpb3 = ::cel::expr::conformance::proto3; +namespace testpb2 = ::cel::expr::conformance::proto2; std::string SevString(Severity severity) { switch (severity) { @@ -105,6 +114,17 @@ google::protobuf::Arena* absl_nonnull TestTypeArena() { return &(*kArena); } +absl::StatusOr> MakeTestParsedAstWithMacros( + absl::string_view expression, const cel::MacroRegistry& registry) { + CEL_ASSIGN_OR_RETURN( + auto source, + cel::NewSource(expression, /*description=*/std::string(expression))); + CEL_ASSIGN_OR_RETURN(auto parsed_expr, google::api::expr::parser::Parse( + *source, registry, + {.enable_optional_syntax = true})); + return cel::CreateAstFromParsedExpr(parsed_expr); +} + FunctionDecl MakeIdentFunction() { auto decl = MakeFunctionDecl( "identity", @@ -269,6 +289,12 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena /*return_type=*/TypeType(arena, TypeParamType("A")), TypeParamType("A")))); + Type kParam(TypeParamType("T")); + CEL_ASSIGN_OR_RETURN( + auto block_decl, + MakeFunctionDecl("cel.@block", MakeOverloadDecl("cel_block_list", kParam, + ListType(), kParam))); + env.InsertFunctionIfAbsent(std::move(not_op)); env.InsertFunctionIfAbsent(std::move(not_strictly_false)); env.InsertFunctionIfAbsent(std::move(add_op)); @@ -286,6 +312,7 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena env.InsertFunctionIfAbsent(std::move(to_type)); env.InsertFunctionIfAbsent(std::move(to_duration)); env.InsertFunctionIfAbsent(std::move(to_timestamp)); + env.InsertFunctionIfAbsent(std::move(block_decl)); return absl::OkStatus(); } @@ -305,6 +332,78 @@ TEST(TypeCheckerImplTest, SmokeTest) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } +TEST(TypeCheckerImplTest, BlockMacroSupport) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAstWithMacros( + "cel.block([1, 2], cel.index(0) + cel.index(1))", registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // Overall type should be int. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + auto root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(), + PrimitiveType::kInt64); +} + +TEST(TypeCheckerImplTest, BlockMacroSupportMixedTypes) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(1))", + registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // cel.index(1) refers to 'a' which is string. + // So overall type should be string. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + auto root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(), + PrimitiveType::kString); +} + +TEST(TypeCheckerImplTest, BadIndex) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(2))", + registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), + HasSubstr("undeclared reference to '@index2' (in container")); +} + TEST(TypeCheckerImplTest, SimpleIdentsResolved) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); @@ -582,6 +681,34 @@ TEST(TypeCheckerImplTest, NamespacedFunctionSkipsFieldCheck) { EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); } +TEST(TypeCheckerImplTest, NamespacedFunctionWithAbbreviation) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + // Variables + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + + FunctionDecl foo; + foo.set_name("x.y.foo"); + ASSERT_THAT( + foo.AddOverload(MakeOverloadDecl("x_y_foo_int", + /*return_type=*/IntType(), IntType())), + IsOk()); + env.InsertFunctionIfAbsent(std::move(foo)); + env.set_container(*MakeExpressionContainer("", "x.y.foo")); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_TRUE(checked_ast->root_expr().has_call_expr()) + << absl::StrCat("kind: ", checked_ast->root_expr().kind().index()); + EXPECT_EQ(checked_ast->root_expr().call_expr().function(), "x.y.foo"); + EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); +} + TEST(TypeCheckerImplTest, MixedListTypeToDyn) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); @@ -757,7 +884,7 @@ TEST(TypeCheckerImplTest, NestedComprehensions) { TEST(TypeCheckerImplTest, ComprehensionVarsShadowNamespacePriorityRules) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("com"); + env.set_container(*MakeExpressionContainer("com")); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); @@ -1345,6 +1472,93 @@ TEST(TypeCheckerImplTest, TypeInferredFromStructCreation) { std::make_unique(DynTypeSpec()))))))); } +struct VariadicLogicalCheckerTestCase { + std::string expr; +}; + +class VariadicLogicalCheckerTest + : public testing::TestWithParam {}; + +TEST_P(VariadicLogicalCheckerTest, Check) { + const auto& test_case = GetParam(); + + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource(test_case.expr)); + ASSERT_OK_AND_ASSIGN(auto parsed_ast, parser->Parse(*source)); + + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + TypeCheckerImpl impl(std::move(env)); + auto checker_builder = impl.ToBuilder(); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("a", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("b", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("c", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("d", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("e", BoolType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, checker_builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker->Check(std::move(parsed_ast))); + + ASSERT_TRUE(result.IsValid()) + << absl::StrJoin(result.GetIssues(), "\n", + [](std::string* out, const TypeCheckIssue& issue) { + absl::StrAppend(out, issue.message()); + }); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(AstType(PrimitiveType::kBool))))); +} + +INSTANTIATE_TEST_SUITE_P( + VariadicLogicalChecker, VariadicLogicalCheckerTest, + testing::Values(VariadicLogicalCheckerTestCase{"true && false && true"}, + VariadicLogicalCheckerTestCase{"a && b && c && d"}, + VariadicLogicalCheckerTestCase{"a || b || c || d"}, + VariadicLogicalCheckerTestCase{"a && b && (c || d || e)"}, + VariadicLogicalCheckerTestCase{"a && b && c"}, + VariadicLogicalCheckerTestCase{"a || b || c"}, + VariadicLogicalCheckerTestCase{"[a, b, c].exists(x, x)"}, + VariadicLogicalCheckerTestCase{"[a, b, c].all(x, x)"})); + +TEST(TypeCheckerImplTest, VariadicLogicalOperatorsError) { + cel::expr::ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + call_expr { + function: "_&&_" + args { const_expr { bool_value: true } } + } + } + )pb", + &parsed_expr)); + ASSERT_OK_AND_ASSIGN(auto parsed_ast, + cel::CreateAstFromParsedExpr(parsed_expr)); + + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + impl.Check(std::move(parsed_ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, "undeclared reference"))); +} + TEST(TypeCheckerImplTest, ExpectedTypeMatches) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); @@ -1383,6 +1597,44 @@ TEST(TypeCheckerImplTest, ExpectedTypeDoesntMatch) { "expected type 'map(string, string)' but found 'map(string, int)'"))); } +TEST(TypeCheckerImplTest, ToBuilder) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + TypeCheckerImpl impl(std::move(env)); + auto builder = impl.ToBuilder(); + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); + ASSERT_OK_AND_ASSIGN(auto new_checker, builder->Build()); + + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + new_checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerImplTest, ToBuilderPropagatesArena) { + auto arena = std::make_shared(); + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_arena(arena); + + Type list_type = ListType(arena.get(), IntType()); + ASSERT_TRUE( + env.InsertVariableIfAbsent(MakeVariableDecl("my_list", list_type))); + + auto base_checker = std::make_unique(std::move(env)); + + std::unique_ptr builder = base_checker->ToBuilder(); + + base_checker.reset(); + arena.reset(); + + ASSERT_OK_AND_ASSIGN(auto derived_checker, builder->Build()); + + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("my_list")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + derived_checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + TEST(TypeCheckerImplTest, BadSourcePosition) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); @@ -1462,7 +1714,7 @@ TEST(TypeCheckerImplTest, BadLineOffsets) { TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("google.protobuf"); + env.set_container(*MakeExpressionContainer("google.protobuf")); env.AddTypeProvider(std::make_unique()); TypeCheckerImpl impl(std::move(env)); @@ -1483,7 +1735,7 @@ TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) { TEST(TypeCheckerImplTest, ContainerLookupForMessageCreationNoRewrite) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("google.protobuf"); + env.set_container(*MakeExpressionContainer("google.protobuf")); env.AddTypeProvider(std::make_unique()); CheckerOptions options; @@ -1508,7 +1760,7 @@ TEST(TypeCheckerImplTest, ContainerLookupForMessageCreationNoRewrite) { TEST(TypeCheckerImplTest, EnumValueCopiedToReferenceMap) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("cel.expr.conformance.proto3"); + env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, @@ -1538,7 +1790,7 @@ TEST_P(WktCreationTest, MessageCreation) { const CheckedExprTestCase& test_case = GetParam(); TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.AddTypeProvider(std::make_unique()); - env.set_container("google.protobuf"); + env.set_container(*MakeExpressionContainer("google.protobuf")); ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); @@ -1688,15 +1940,229 @@ INSTANTIATE_TEST_SUITE_P( .expected_result_type = AstType(PrimitiveType::kBool), })); +TEST(AliasTest, ImportVariable) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("cel.expr.conformance", + "com.example.TestVariable1", + "com.example.TestVariable2")); + env.set_container(std::move(container)); + + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + + ASSERT_TRUE(env.InsertVariableIfAbsent( + MakeVariableDecl("com.example.TestVariable1", + MessageType(testpb3::TestAllTypes::descriptor())))); + ASSERT_TRUE(env.InsertVariableIfAbsent( + MakeVariableDecl("com.example.TestVariable2", + MessageType(testpb2::TestAllTypes::descriptor())))); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAst( + "TestVariable1.single_int64 == TestVariable2.single_int64")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + ASSERT_TRUE(checked_ast->root_expr().has_call_expr()); + ASSERT_EQ(checked_ast->root_expr().call_expr().function(), "_==_"); + ASSERT_THAT(checked_ast->root_expr().call_expr().args(), SizeIs(2)); + ASSERT_EQ(checked_ast->root_expr() + .call_expr() + .args()[0] + .select_expr() + .operand() + .ident_expr() + .name(), + "com.example.TestVariable1"); + ASSERT_EQ(checked_ast->root_expr() + .call_expr() + .args()[1] + .select_expr() + .operand() + .ident_expr() + .name(), + "com.example.TestVariable2"); +} + +TEST(AliasTest, AliasToContainerResolvesMessage) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ExpressionContainer container; + ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); + + env.set_container(std::move(container)); + + google::protobuf::LinkMessageReflection(); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("pb3.TestAllTypes{single_int64: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT( + checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(AstType(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes")))))); + + EXPECT_THAT( + checked_ast->reference_map(), + Contains(Pair(checked_ast->root_expr().id(), + Property(&Reference::name, + "cel.expr.conformance.proto3.TestAllTypes")))); + + EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), + "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(AliasTest, AliasSimpleName) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ExpressionContainer container; + ASSERT_THAT(container.AddAlias("foo", "bar"), IsOk()); + + env.set_container(std::move(container)); + + google::protobuf::LinkMessageReflection(); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + env.InsertOrReplaceVariable(MakeVariableDecl("bar", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_EQ(checked_ast->root_expr().ident_expr().name(), "bar"); +} + +TEST(AliasTest, AliasPreventsContainerResolution) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("cel.expr")); + ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); + env.set_container(std::move(container)); + + ASSERT_TRUE(env.InsertVariableIfAbsent( + MakeVariableDecl("cel.expr.pb3.FooVariable", IntType()))); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("FooVariable")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'FooVariable'"))); + } + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("pb3.FooVariable")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'pb3.FooVariable'"))); + } + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("expr.pb3.FooVariable")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->root_expr().ident_expr().name(), + "cel.expr.pb3.FooVariable"); + } +} + +TEST(AliasTest, AliasPreventsDisambiguation) { + // Copying behavior from cel-go and cel-java. + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ExpressionContainer container; + ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); + env.set_container(std::move(container)); + env.InsertOrReplaceVariable(MakeVariableDecl("pb3.Foo", IntType())); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + + { + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst("pb3.TestAllTypes{single_int64: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), + "cel.expr.conformance.proto3.TestAllTypes"); + } + { + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst(".pb3.TestAllTypes{single_int64: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), + "cel.expr.conformance.proto3.TestAllTypes"); + } + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("pb3.Foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'pb3.Foo'"))); + } + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(".pb3.Foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to '.pb3.Foo'"))); + } +} + class GenericMessagesTest : public testing::TestWithParam { }; -TEST_P(GenericMessagesTest, TypeChecksProto3) { +TEST_P(GenericMessagesTest, TypeChecksProto3Imports) { const CheckedExprTestCase& test_case = GetParam(); google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("cel.expr.conformance.proto3"); + env.set_container(*MakeExpressionContainer( + "", "cel.expr.conformance.proto3.TestAllTypes", + "cel.expr.conformance.proto3.NestedTestAllTypes")); google::protobuf::LinkMessageReflection(); ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( @@ -1714,11 +2180,40 @@ TEST_P(GenericMessagesTest, TypeChecksProto3) { return; } - ASSERT_TRUE(result.IsValid()) - << absl::StrJoin(result.GetIssues(), "\n", - [](std::string* out, const TypeCheckIssue& issue) { - absl::StrAppend(out, issue.message()); - }); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(test_case.expected_result_type)))) + << cel::test::FormatBaselineAst(*checked_ast); +} + +TEST_P(GenericMessagesTest, TypeChecksProto3Container) { + const CheckedExprTestCase& test_case = GetParam(); + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); + google::protobuf::LinkMessageReflection(); + + ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( + "test_msg", MessageType(testpb3::TestAllTypes::descriptor())))); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + if (!test_case.error_substring.empty()) { + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, + test_case.error_substring))); + return; + } + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); @@ -1839,6 +2334,12 @@ INSTANTIATE_TEST_SUITE_P( .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, + CheckedExprTestCase{ + .expr = "TestAllTypes{repeated_nested_message: " + "[TestAllTypes.NestedMessage{bb: 42}]}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, CheckedExprTestCase{ .expr = "TestAllTypes{single_duration: duration('1s')}", .expected_result_type = AstType( @@ -1966,11 +2467,6 @@ INSTANTIATE_TEST_SUITE_P( .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, - CheckedExprTestCase{ - .expr = "TestAllTypes{null_value: null}", - .expected_result_type = AstType( - MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), - }, // Legacy nullability behaviors. CheckedExprTestCase{ .expr = "TestAllTypes{single_duration: null}", @@ -2252,7 +2748,7 @@ TEST_P(StrictNullAssignmentTest, TypeChecksProto3) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("cel.expr.conformance.proto3"); + env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); google::protobuf::LinkMessageReflection(); ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 96d985071..4f738b804 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -28,8 +28,9 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" -#include "checker/internal/format_type_name.h" #include "common/decl.h" +#include "common/format_type_name.h" +#include "common/standard_definitions.h" #include "common/type.h" #include "common/type_kind.h" @@ -133,7 +134,7 @@ FunctionOverloadInstance InstantiateFunctionOverload( // Converts a wrapper type to its corresponding primitive type. // Returns nullopt if the type is not a wrapper type. -absl::optional WrapperToPrimitive(const Type& t) { +std::optional WrapperToPrimitive(const Type& t) { switch (t.kind()) { case TypeKind::kBoolWrapper: return BoolType(); @@ -148,7 +149,7 @@ absl::optional WrapperToPrimitive(const Type& t) { case TypeKind::kUintWrapper: return UintType(); default: - return absl::nullopt; + return std::nullopt; } } @@ -286,7 +287,7 @@ bool TypeInferenceContext::IsAssignableInternal( } // Type is as concrete as it can be under current substitutions. - if (absl::optional wrapped_type = WrapperToPrimitive(to_subs); + if (std::optional wrapped_type = WrapperToPrimitive(to_subs); wrapped_type.has_value()) { return from_subs.IsNull() || IsAssignableInternal(*wrapped_type, from_subs, @@ -531,27 +532,34 @@ bool TypeInferenceContext::IsAssignableWithConstraints( return false; } -absl::optional +std::optional TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, absl::Span argument_types, bool is_receiver) { - absl::optional result_type; + std::optional result_type; + + bool is_logical_op = (decl.name() == cel::StandardFunctions::kAnd || + decl.name() == cel::StandardFunctions::kOr) && + argument_types.size() >= 2; std::vector matching_overloads; for (const auto& ovl : decl.overloads()) { if (ovl.member() != is_receiver || - argument_types.size() != ovl.args().size()) { + (!is_logical_op && argument_types.size() != ovl.args().size())) { continue; } auto call_type_instance = InstantiateFunctionOverload(*this, ovl); - ABSL_DCHECK_EQ(argument_types.size(), - call_type_instance.param_types.size()); + if (!is_logical_op) { + ABSL_DCHECK_EQ(argument_types.size(), + call_type_instance.param_types.size()); + } bool is_match = true; AssignabilityContext assignability_context = CreateAssignabilityContext(); for (int i = 0; i < argument_types.size(); ++i) { + int param_index = is_logical_op ? 0 : i; if (!assignability_context.IsAssignable( - argument_types[i], call_type_instance.param_types[i])) { + argument_types[i], call_type_instance.param_types[param_index])) { is_match = false; break; } @@ -571,7 +579,7 @@ TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, } if (!result_type.has_value() || matching_overloads.empty()) { - return absl::nullopt; + return std::nullopt; } return OverloadResolution{ .result_type = FullySubstitute(*result_type, /*free_to_dyn=*/false), @@ -649,14 +657,14 @@ bool TypeInferenceContext::AssignabilityContext::IsAssignable(const Type& from, std::string TypeInferenceContext::DebugString() const { return absl::StrCat( "type_parameter_bindings: ", - absl::StrJoin( - type_parameter_bindings_, "\n ", - [](std::string* out, const auto& binding) { - absl::StrAppend( - out, binding.first, " (", binding.second.name, ") -> ", - checker_internal::FormatTypeName( - binding.second.type.value_or(Type(TypeParamType("none"))))); - })); + absl::StrJoin(type_parameter_bindings_, "\n ", + [](std::string* out, const auto& binding) { + absl::StrAppend( + out, binding.first, " (", binding.second.name, + ") -> ", + cel::FormatTypeName(binding.second.type.value_or( + Type(TypeParamType("none"))))); + })); } void TypeInferenceContext::AssignabilityContext:: diff --git a/checker/internal/type_inference_context_test.cc b/checker/internal/type_inference_context_test.cc index d1bf7fa6d..458d08ff1 100644 --- a/checker/internal/type_inference_context_test.cc +++ b/checker/internal/type_inference_context_test.cc @@ -291,7 +291,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadBasic) { MakeOverloadDecl("add_double", DoubleType(), DoubleType(), DoubleType()))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), IntType()}, false); ASSERT_TRUE(resolution.has_value()); EXPECT_THAT(resolution->result_type, IsTypeKind(TypeKind::kInt)); @@ -309,7 +309,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadFails) { MakeOverloadDecl("add_double", DoubleType(), DoubleType(), DoubleType()))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), DoubleType()}, false); ASSERT_FALSE(resolution.has_value()); } @@ -324,7 +324,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithParamsNoMatch) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), DoubleType()}, false); ASSERT_FALSE(resolution.has_value()); } @@ -341,7 +341,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithMixedParamsMatch) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_a, list_of_a}, false); ASSERT_TRUE(resolution.has_value()) << context.DebugString(); } @@ -359,7 +359,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithMixedParamsMatch2) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_a, list_of_int}, false); ASSERT_TRUE(resolution.has_value()) << context.DebugString(); EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("equals"))); @@ -375,7 +375,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithParamsMatches) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), IntType()}, false); ASSERT_TRUE(resolution.has_value()); EXPECT_TRUE(resolution->result_type.IsBool()); @@ -394,7 +394,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsMatch) { Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); - absl::optional resolution = + std::optional resolution = context.ResolveOverload( decl, {list_of_a_instance, ListType(&arena, IntType())}, false); ASSERT_TRUE(resolution.has_value()); @@ -407,7 +407,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsMatch) { EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("add_list"))); - absl::optional resolution2 = + std::optional resolution2 = context.ResolveOverload( decl, {ListType(&arena, IntType()), list_of_a_instance}, false); ASSERT_TRUE(resolution2.has_value()); @@ -433,7 +433,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsNoMatch) { Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_a_instance, IntType()}, false); EXPECT_FALSE(resolution.has_value()); } @@ -450,13 +450,13 @@ TEST(TypeInferenceContextTest, InferencesAccumulate) { Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); - absl::optional resolution1 = + std::optional resolution1 = context.ResolveOverload(decl, {list_of_a_instance, list_of_a_instance}, false); ASSERT_TRUE(resolution1.has_value()); EXPECT_TRUE(resolution1->result_type.IsList()); - absl::optional resolution2 = + std::optional resolution2 = context.ResolveOverload( decl, {resolution1->result_type, ListType(&arena, IntType())}, false); ASSERT_TRUE(resolution2.has_value()); @@ -480,7 +480,7 @@ TEST(TypeInferenceContextTest, DebugString) { MakeFunctionDecl("_+_", MakeOverloadDecl("add_list", list_of_a, list_of_a, list_of_a))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_int, list_of_int}, false); ASSERT_TRUE(resolution.has_value()); EXPECT_TRUE(resolution->result_type.IsList()); @@ -517,7 +517,7 @@ class TypeInferenceContextWrapperTypesTest TEST_P(TypeInferenceContextWrapperTypesTest, ResolvePrimitiveArg) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), test_case.wrapper_type, test_case.wrapped_primitive_type}, @@ -534,7 +534,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, ResolvePrimitiveArg) { TEST_P(TypeInferenceContextWrapperTypesTest, ResolveWrapperArg) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload( ternary_decl_, {BoolType(), test_case.wrapper_type, test_case.wrapper_type}, false); @@ -550,7 +550,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, ResolveWrapperArg) { TEST_P(TypeInferenceContextWrapperTypesTest, ResolveNullArg) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), test_case.wrapper_type, NullType()}, false); @@ -566,7 +566,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, ResolveNullArg) { TEST_P(TypeInferenceContextWrapperTypesTest, NullWidens) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), NullType(), test_case.wrapper_type}, false); @@ -582,7 +582,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, NullWidens) { TEST_P(TypeInferenceContextWrapperTypesTest, PrimitiveWidens) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), test_case.wrapped_primitive_type, test_case.wrapper_type}, @@ -622,7 +622,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithUnionTypePromotion) { /*result_type=*/TypeParamType("A"), BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {BoolType(), NullType(), IntWrapperType()}, false); ASSERT_TRUE(resolution.has_value()); @@ -648,7 +648,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithTypeType) { TypeType(&arena, TypeParamType("A")), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {StringType()}, false); ASSERT_TRUE(resolution.has_value()); @@ -680,7 +680,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithInferredTypeType) { BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(to_type_decl, {StringType()}, false); ASSERT_TRUE(resolution.has_value()); diff --git a/checker/optional_test.cc b/checker/optional_test.cc index 28ae9a889..87c14f0cd 100644 --- a/checker/optional_test.cc +++ b/checker/optional_test.cc @@ -267,15 +267,16 @@ INSTANTIATE_TEST_SUITE_P( IsOptionalType(TypeSpec(PrimitiveType::kString))}, TestCase{"optional.of('abc').optFlatMap(x, optional.of(x + 'def'))", IsOptionalType(TypeSpec(PrimitiveType::kString))}, - // Legacy nullability behaviors. TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: " "optional.of(0)}", Eq(TypeSpec(MessageTypeSpec( "cel.expr.conformance.proto3.TestAllTypes")))}, - TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: null}", - Eq(TypeSpec(MessageTypeSpec( - "cel.expr.conformance.proto3.TestAllTypes")))}, - TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: " + // Legacy nullability behaviors. + TestCase{ + "cel.expr.conformance.proto3.TestAllTypes{?single_value: null}", + Eq(TypeSpec( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")))}, + TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_value: " "optional.of(null)}", Eq(TypeSpec(MessageTypeSpec( "cel.expr.conformance.proto3.TestAllTypes")))}, diff --git a/checker/standard_library.cc b/checker/standard_library.cc index 4cd9e9831..744a171ef 100644 --- a/checker/standard_library.cc +++ b/checker/standard_library.cc @@ -14,6 +14,7 @@ #include "checker/standard_library.h" +#include #include #include "absl/base/no_destructor.h" @@ -833,11 +834,8 @@ absl::Status AddTypeConstantVariables(TypeCheckerBuilder& builder) { absl::Status AddEnumConstants(TypeCheckerBuilder& builder) { VariableDecl pb_null; pb_null.set_name("google.protobuf.NullValue.NULL_VALUE"); - // TODO(uncreated-issue/74): This is interpreted as an enum (int) or null in - // different cases. We should add some additional spec tests to cover this and - // update the behavior to be consistent. pb_null.set_type(IntType()); - pb_null.set_value(Constant(nullptr)); + pb_null.set_value(Constant(int64_t{0})); CEL_RETURN_IF_ERROR(builder.AddVariable(std::move(pb_null))); return absl::OkStatus(); } diff --git a/checker/type_checker.cc b/checker/type_checker.cc new file mode 100644 index 000000000..6d59e144d --- /dev/null +++ b/checker/type_checker.cc @@ -0,0 +1,36 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/type_checker.h" + +namespace cel { +absl::StatusOr TypeChecker::Check( + std::unique_ptr ast) const { + return CheckImpl(std::move(ast), nullptr); +} + +absl::StatusOr TypeChecker::Check( + std::unique_ptr ast, google::protobuf::Arena* arena) const { + return CheckImpl(std::move(ast), arena); +} + +absl::StatusOr TypeChecker::Check(const Ast& ast) const { + return CheckImpl(std::make_unique(ast), nullptr); +} + +absl::StatusOr TypeChecker::Check( + const Ast& ast, google::protobuf::Arena* arena) const { + return CheckImpl(std::make_unique(ast), arena); +} +} // namespace cel diff --git a/checker/type_checker.h b/checker/type_checker.h index 993eafb71..edb6cc91f 100644 --- a/checker/type_checker.h +++ b/checker/type_checker.h @@ -16,13 +16,18 @@ #define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_H_ #include +#include +#include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "checker/validation_result.h" #include "common/ast.h" +#include "google/protobuf/arena.h" namespace cel { +class TypeCheckerBuilder; + // TypeChecker interface. // // Checks references and type agreement for a parsed CEL expression. @@ -40,10 +45,19 @@ class TypeChecker { // A non-ok status is returned if type checking can't reasonably complete // (e.g. if an internal precondition is violated or an extension returns an // error). - virtual absl::StatusOr Check( - std::unique_ptr ast) const = 0; + absl::StatusOr Check(std::unique_ptr ast) const; + absl::StatusOr Check(std::unique_ptr ast, + google::protobuf::Arena* arena) const; + absl::StatusOr Check(const Ast& ast) const; + absl::StatusOr Check(const Ast& ast, + google::protobuf::Arena* arena) const; + + // Returns a builder initialized with the configuration of this type checker. + virtual std::unique_ptr ToBuilder() const = 0; - // TODO(uncreated-issue/73): add overload for cref AST. + private: + virtual absl::StatusOr CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* absl_nullable arena) const = 0; }; } // namespace cel diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index e5942b157..c2d0cbf7b 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -17,6 +17,7 @@ #include #include +#include #include "absl/base/nullability.h" #include "absl/functional/any_invocable.h" @@ -25,6 +26,7 @@ #include "absl/strings/string_view.h" #include "checker/checker_options.h" #include "checker/type_checker.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" @@ -34,7 +36,6 @@ namespace cel { class TypeCheckerBuilder; -class TypeCheckerBuilderImpl; // Functional implementation to apply the library features to a // TypeCheckerBuilder. @@ -51,7 +52,7 @@ struct CheckerLibrary { // Represents a declaration to only use a subset of a library. struct TypeCheckerSubset { using FunctionPredicate = absl::AnyInvocable; + absl::string_view function, const OverloadDecl& overload) const>; // The id of the library to subset. Only one subset can be applied per // library id. @@ -102,6 +103,27 @@ class TypeCheckerBuilder { // Note: only protobuf backed struct types are supported at this time. virtual absl::Status AddContextDeclaration(absl::string_view type) = 0; + // Declares struct type by fully qualified name as a context declaration. + // + // This version accepts a mask in terms of field selections from the + // context type. The mask specifies which fields are visible on the + // struct and its members. The visible fields for a type accumulate + // across calls. This is a lightweight way to adjust the type checking + // behavior for a group of related types. + // + // Context declarations are a way to declare a group of variables based on the + // definition of a struct type. Each top level field of the struct that is + // also the first field name in a field path is declared as an individual + // variable of the field type. + // + // It is an error if the type contains a field that overlaps with another + // declared variable. It is an error if the input field paths is the empty + // set. + // + // Note: only protobuf backed struct types are supported at this time. + virtual absl::Status AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) = 0; + // Adds a function declaration that may be referenced in expressions checked // with the resulting TypeChecker. virtual absl::Status AddFunction(const FunctionDecl& decl) = 0; @@ -132,10 +154,16 @@ class TypeCheckerBuilder { // // This is used for resolving references in the expressions being built. // + // Prefer setting the container via SetExpressionContainer(). + // // Note: if set multiple times, the last value is used. This can lead to - // surprising behavior if used in a custom library. + // surprising behavior if used in a custom library. If container is not a + // valid container name, the operation is ignored. virtual void set_container(absl::string_view container) = 0; + virtual void SetExpressionContainer( + ExpressionContainer expression_container) = 0; + // The current options for the TypeChecker being built. virtual const CheckerOptions& options() const = 0; diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index d5cf47fee..40406948d 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -23,10 +23,12 @@ #include "absl/strings/string_view.h" #include "checker/checker_options.h" #include "checker/internal/test_ast_helpers.h" +#include "checker/optional.h" #include "checker/standard_library.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" #include "checker/validation_result.h" +#include "common/ast.h" #include "common/decl.h" #include "common/type.h" #include "internal/status_macros.h" @@ -233,8 +235,8 @@ TEST(TypeCheckerBuilderTest, AddLibraryIncludeSubset) { ASSERT_THAT( builder->AddLibrarySubset( {"testlib", - [](absl::string_view /*function*/, absl::string_view overload_id) { - return (overload_id == "add_int" || overload_id == "sub_int"); + [](absl::string_view /*function*/, const OverloadDecl& overload) { + return (overload.id() == "add_int" || overload.id() == "sub_int"); }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); @@ -272,9 +274,8 @@ TEST(TypeCheckerBuilderTest, AddLibraryExcludeSubset) { ASSERT_THAT( builder->AddLibrarySubset( {"testlib", - [](absl::string_view /*function*/, absl::string_view overload_id) { - return (overload_id != "add_int" && overload_id != "sub_int"); - ; + [](absl::string_view /*function*/, const OverloadDecl& overload) { + return (overload.id() != "add_int" && overload.id() != "sub_int"); }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); @@ -311,7 +312,7 @@ TEST(TypeCheckerBuilderTest, AddLibrarySubsetRemoveAllOvl) { ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); ASSERT_THAT(builder->AddLibrarySubset({"testlib", [](absl::string_view function, - absl::string_view /*overload_id*/) { + const OverloadDecl& /*overload*/) { return function != "add"; }}), IsOk()); @@ -350,12 +351,12 @@ TEST(TypeCheckerBuilderTest, AddLibraryOneSubsetPerLibraryId) { ASSERT_THAT( builder->AddLibrarySubset( {"testlib", [](absl::string_view function, - absl::string_view /*overload_id*/) { return true; }}), + const OverloadDecl& /*overload*/) { return true; }}), IsOk()); EXPECT_THAT( builder->AddLibrarySubset( {"testlib", [](absl::string_view function, - absl::string_view /*overload_id*/) { return true; }}), + const OverloadDecl& /*overload*/) { return true; }}), StatusIs(absl::StatusCode::kAlreadyExists)); } @@ -367,7 +368,7 @@ TEST(TypeCheckerBuilderTest, AddLibrarySubsetLibraryIdRequireds) { ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); EXPECT_THAT(builder->AddLibrarySubset({"", [](absl::string_view function, - absl::string_view /*overload_id*/) { + const OverloadDecl& /*overload*/) { return function == "add"; }}), StatusIs(absl::StatusCode::kInvalidArgument)); @@ -394,6 +395,27 @@ TEST(TypeCheckerBuilderTest, AddContextDeclaration) { EXPECT_TRUE(result.IsValid()); } +TEST(TypeCheckerBuilderTest, AddContextDeclarationWithProtoTypeMask) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl("increment", MakeOverloadDecl("increment_int", IntType(), + IntType()))); + + ASSERT_THAT(builder->AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int64"}), + IsOk()); + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("increment(single_int64)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + TEST(TypeCheckerBuilderTest, WellKnownTypeContextDeclarationError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, @@ -426,6 +448,32 @@ TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclaration) { ASSERT_TRUE(result.IsValid()); } +TEST(TypeCheckerBuilderTest, + AllowWellKnownTypeContextDeclarationWithProtoTypeMask) { + CheckerOptions options; + options.allow_well_known_type_context_declarations = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + + ASSERT_THAT(builder->AddContextDeclarationWithProtoTypeMask( + "google.protobuf.Any", {"value"}), + IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + // Visible field: value + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("value")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + // Not visible field: type_url + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("type_url")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationStruct) { CheckerOptions options; options.allow_well_known_type_context_declarations = true; @@ -464,7 +512,7 @@ TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationValue) { // Note: one of fields are all added with safe traversal, so // we lose the union discriminator information. R"cel( - null_value == null && + null_value == 0 && number_value == 0.0 && string_value == '' && list_value == [] && @@ -496,6 +544,113 @@ TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationInt64Value) { ASSERT_TRUE(result.IsValid()); } +TEST(TypeCheckerBuilderTest, ContextDeclarationWithJsonName) { + CheckerOptions options; + options.use_json_field_names = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + + ASSERT_THAT(builder->AddContextDeclaration("cel.cpp.testutil.TestJsonNames"), + IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( + R"cel(int32_snake_case_json_name == 1 && + int64CamelCaseJsonName == 2 && + uint32DefaultJsonName == 3u && + // `uint64-custom-json-name` == 4u && + single_string == 'shadows' && + singleString == 'shadowed')cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->GetReturnType(), TypeSpec(PrimitiveType::kBool)); + EXPECT_THAT( + checked_ast->source_info().extensions(), + ElementsAre(cel::ExtensionSpec( + "json_name", std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime}))); +} + +TEST(TypeCheckerBuilderTest, JsonFieldNameOptionStructCreation) { + CheckerOptions options; + options.use_json_field_names = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( + R"cel(cel.cpp.testutil.TestJsonNames{ + int32_snake_case_json_name: 1, + int64CamelCaseJsonName: 2, + uint32DefaultJsonName: 3u, + `uint64-custom-json-name`: 4u, + single_string: 'shadows', + singleString: 'shadowed' + })cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->GetReturnType(), + TypeSpec(MessageTypeSpec("cel.cpp.testutil.TestJsonNames"))); + EXPECT_THAT( + checked_ast->source_info().extensions(), + ElementsAre(cel::ExtensionSpec( + "json_name", std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime}))); +} + +TEST(TypeCheckerBuilderTest, JsonFieldNameOptionFieldAccess) { + CheckerOptions options; + options.use_json_field_names = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT( + builder->AddVariable(MakeVariableDecl( + "jsonObj", + cel::MessageType(builder->descriptor_pool()->FindMessageTypeByName( + "cel.cpp.testutil.TestJsonNames")))), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( + R"cel( + jsonObj.int32_snake_case_json_name == 1 && + jsonObj.int64CamelCaseJsonName == 2 && + jsonObj.uint32DefaultJsonName == 3u && + jsonObj.`uint64-custom-json-name` == 4u && + jsonObj.single_string == 'shadows' && + jsonObj.singleString == 'shadowed' && + jsonObj.`cel.cpp.testutil.int32_snake_case_ext` == 5 && + jsonObj.`cel.cpp.testutil.int64CamelCaseExt` == 6 + )cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->GetReturnType(), TypeSpec(PrimitiveType::kBool)); + EXPECT_THAT( + checked_ast->source_info().extensions(), + ElementsAre(cel::ExtensionSpec( + "json_name", std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime}))); +} + TEST(TypeCheckerBuilderTest, AddLibraryRedeclaredError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, @@ -635,5 +790,63 @@ TEST(TypeCheckerBuilderTest, AddFunctionNoOverlapWithStdMacroError) { EXPECT_THAT(builder->AddFunction(fn_decl), IsOk()); } +TEST(TypeCheckerBuilderTest, ToBuilderIndependenceAndInheritance) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl("addOne", + MakeOverloadDecl("addOne_int", IntType(), IntType()))); + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker1, builder->Build()); + + // Exercise checker1. + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("addOne(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result1, + checker1->Check(std::move(ast))); + EXPECT_TRUE(result1.IsValid()); + } + + // Start new builder via ToBuilder. + auto builder2 = checker1->ToBuilder(); + ASSERT_THAT(builder2->AddVariable(MakeVariableDecl("y", IntType())), IsOk()); + ASSERT_THAT(builder2->AddLibrary(OptionalCheckerLibrary()), IsOk()); + builder2->SetExpectedType(IntType()); + + ASSERT_OK_AND_ASSIGN(auto checker2, builder2->Build()); + + { + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst("optional.of(addOne(x)).orValue(0) + y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result2, + checker2->Check(std::move(ast))); + EXPECT_TRUE(result2.IsValid()); + } + + // Demonstrate checker1 is unmodified and independent (still does not know + // about y). + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result_y_checker1_again, + checker1->Check(std::move(ast))); + EXPECT_FALSE(result_y_checker1_again.IsValid()); + } + + // Same for optional library functions. + { + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("optional.none().orValue(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker1->Check(std::move(ast))); + EXPECT_FALSE(result.IsValid()); + } +} + } // namespace } // namespace cel diff --git a/checker/type_checker_subset_factory.cc b/checker/type_checker_subset_factory.cc index 6a05ce220..1b146c5a5 100644 --- a/checker/type_checker_subset_factory.cc +++ b/checker/type_checker_subset_factory.cc @@ -21,14 +21,21 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/signature.h" namespace cel { TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( absl::flat_hash_set overload_ids) { return [overload_ids = std::move(overload_ids)]( - absl::string_view /*function*/, absl::string_view overload_id) { - return overload_ids.contains(overload_id); + absl::string_view function, const OverloadDecl& overload) { + if (overload_ids.contains(overload.id())) { + return true; + } + auto signature = + MakeOverloadSignature(function, overload.args(), overload.member()); + return signature.ok() && overload_ids.contains(*signature); }; } @@ -41,8 +48,13 @@ TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate( absl::flat_hash_set overload_ids) { return [overload_ids = std::move(overload_ids)]( - absl::string_view /*function*/, absl::string_view overload_id) { - return !overload_ids.contains(overload_id); + absl::string_view function, const OverloadDecl& overload) { + if (overload_ids.contains(overload.id())) { + return false; + } + auto signature = + MakeOverloadSignature(function, overload.args(), overload.member()); + return !signature.ok() || !overload_ids.contains(*signature); }; } diff --git a/checker/type_checker_subset_factory_test.cc b/checker/type_checker_subset_factory_test.cc index fa38e1c0d..5b644ec7c 100644 --- a/checker/type_checker_subset_factory_test.cc +++ b/checker/type_checker_subset_factory_test.cc @@ -43,6 +43,8 @@ TEST(TypeCheckerSubsetFactoryTest, IncludeOverloadsByIdPredicate) { StandardOverloadIds::kEquals, StandardOverloadIds::kNotEquals, StandardOverloadIds::kNotStrictlyFalse, + "matches(string,string)", + "string.matches(string)", }; ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({ @@ -65,15 +67,19 @@ TEST(TypeCheckerSubsetFactoryTest, IncludeOverloadsByIdPredicate) { EXPECT_TRUE(r.IsValid()); + // Allowed by signature. + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); + EXPECT_TRUE(r.IsValid()); + + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("matches(r'foo.*', 'foobar')")); + EXPECT_TRUE(r.IsValid()); + // Not in allowlist. ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3")); EXPECT_FALSE(r.IsValid()); ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'")); EXPECT_FALSE(r.IsValid()); - - ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); - EXPECT_FALSE(r.IsValid()); } TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) { @@ -83,6 +89,8 @@ TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) { absl::string_view exclude_list[] = { StandardOverloadIds::kMatches, StandardOverloadIds::kMatchesMember, + "size(string)", + "string.size()", }; ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({ @@ -105,18 +113,35 @@ TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) { EXPECT_TRUE(r.IsValid()); - // Not in allowlist. + // Allowed. ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3")); EXPECT_TRUE(r.IsValid()); ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'")); EXPECT_TRUE(r.IsValid()); + // Excluded by ID. ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); EXPECT_FALSE(r.IsValid()); ASSERT_OK_AND_ASSIGN(r, compiler->Compile("matches(r'foo.*', 'foobar')")); EXPECT_FALSE(r.IsValid()); + + // Excluded by signature (top-level function). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("size('abc')")); + EXPECT_FALSE(r.IsValid()); + + // Allowed (other overloads of size). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("size([1, 2, 3])")); + EXPECT_TRUE(r.IsValid()); + + // Excluded by signature (member function). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc'.size()")); + EXPECT_FALSE(r.IsValid()); + + // Allowed (other overloads of size member). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("[1, 2, 3].size()")); + EXPECT_TRUE(r.IsValid()); } } // namespace diff --git a/checker/validation_result.h b/checker/validation_result.h index 45f949739..7417e9969 100644 --- a/checker/validation_result.h +++ b/checker/validation_result.h @@ -15,26 +15,32 @@ #ifndef THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_ +#include #include #include #include #include #include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "checker/type_check_issue.h" #include "common/ast.h" +#include "common/decl.h" #include "common/source.h" +#include "common/type.h" namespace cel { -// ValidationResult holds the result of TypeChecking. +// ValidationResult holds the result of type checking. // // Error states are captured as type check issues where possible. class ValidationResult { public: + using TypeMap = absl::flat_hash_map; + ValidationResult(std::unique_ptr ast, std::vector issues) : ast_(std::move(ast)), issues_(std::move(issues)) {} @@ -58,6 +64,8 @@ class ValidationResult { absl::Span GetIssues() const { return issues_; } + void AddIssue(TypeCheckIssue issue) { issues_.push_back(std::move(issue)); } + // The source expression may optionally be set if it is available. const cel::Source* absl_nullable GetSource() const { return source_.get(); } @@ -69,6 +77,18 @@ class ValidationResult { return std::move(source_); } + // Returns the resolved type map for the AST. + // + // Only populated if the AST was checked with an explicit arena. + // + // The type entries may have storage in the arena or reference type + // information from the type checker that produced the AST. This means the map + // is only valid as long as both the type checker and the arena are valid. + const TypeMap& GetResolvedTypeMap() const { return resolved_type_map_; } + void SetResolvedTypeMap(TypeMap resolved_type_map) { + resolved_type_map_ = std::move(resolved_type_map); + } + // Returns a string representation of the issues in the result suitable for // display. // @@ -87,6 +107,7 @@ class ValidationResult { private: absl_nullable std::unique_ptr ast_; + TypeMap resolved_type_map_; std::vector issues_; absl_nullable std::unique_ptr source_; }; diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 8272378f6..dec359f25 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -1,5 +1,5 @@ steps: -- name: 'gcr.io/cel-analysis/gcc9@sha256:4d5ff2e55224398807235a44b57e9c5793e922ac46e9ff428536bb8f8e5790ce' +- name: 'gcr.io/cel-analysis/cel-cpp/ubuntu_floor@sha256:211a0c505b361d987b3d8b08a5144a84e62cb95edc3f897fe46d5cd3f556f79d' args: - '--output_base=/bazel' # This is mandatory to avoid steps accidently sharing data. - 'test' @@ -16,7 +16,7 @@ steps: - '--google_default_credentials' id: gcc-9 waitFor: ['-'] -- name: 'gcr.io/cel-analysis/gcc9@sha256:4d5ff2e55224398807235a44b57e9c5793e922ac46e9ff428536bb8f8e5790ce' +- name: 'gcr.io/cel-analysis/cel-cpp/ubuntu_floor@sha256:211a0c505b361d987b3d8b08a5144a84e62cb95edc3f897fe46d5cd3f556f79d' env: - 'CC=clang-11' - 'CXX=clang++-11' diff --git a/codelab/network_functions.cc b/codelab/network_functions.cc index f4f729827..6cc1505a9 100644 --- a/codelab/network_functions.cc +++ b/codelab/network_functions.cc @@ -213,8 +213,7 @@ absl::Status NetworkAddressRepEqual( return absl::OkStatus(); } const NetworkAddressRep rep = content.To(); - absl::optional other_rep = - NetworkAddressRep::Unwrap(other); + std::optional other_rep = NetworkAddressRep::Unwrap(other); ABSL_DCHECK(other_rep.has_value()); *result = cel::BoolValue(rep.IsEqualTo(*other_rep)); return absl::OkStatus(); @@ -311,7 +310,7 @@ cel::Value parseAddress( google::protobuf::Arena* absl_nonnull arena) { std::string buf; absl::string_view addr = str.ToStringView(&buf); - absl::optional rep = NetworkAddressRep::Parse(addr); + std::optional rep = NetworkAddressRep::Parse(addr); if (!rep.has_value()) { return cel::ErrorValue(absl::InvalidArgumentError("invalid address")); } @@ -321,7 +320,7 @@ cel::Value parseAddress( cel::Value parseAddressOrZero(const cel::StringValue& str) { std::string buf; absl::string_view addr = str.ToStringView(&buf); - absl::optional rep = NetworkAddressRep::Parse(addr); + std::optional rep = NetworkAddressRep::Parse(addr); static const NetworkAddressRep kZero; if (!rep.has_value()) { return NetworkAddressRep::MakeValue(kZero); @@ -336,8 +335,7 @@ cel::Value parseAddressMatcher( google::protobuf::Arena* absl_nonnull arena) { std::string buf; absl::string_view addr = str.ToStringView(&buf); - absl::optional rep = - NetworkAddressMatcher::Parse(addr); + std::optional rep = NetworkAddressMatcher::Parse(addr); if (!rep.has_value()) { return cel::ErrorValue( absl::InvalidArgumentError("invalid address matcher")); @@ -365,12 +363,12 @@ cel::Value NetworkAddressRep::MakeValue(const NetworkAddressRep& rep) { cel::OpaqueValueContent::From(rep)); } -absl::optional NetworkAddressRep::Unwrap( +std::optional NetworkAddressRep::Unwrap( const cel::Value& value) { auto opaque = value.AsOpaque(); if (!opaque.has_value() || opaque->GetTypeId() != cel::TypeId()) { - return absl::nullopt; + return std::nullopt; } // Note: safety depends on: @@ -381,16 +379,16 @@ absl::optional NetworkAddressRep::Unwrap( return opaque->content().To(); } -absl::optional NetworkAddressRep::Parse( +std::optional NetworkAddressRep::Parse( absl::string_view str) { uint32_t ipv4 = 0; char ipv6[16]; auto version = ParseAddressImpl(str, &ipv4, ipv6); if (!version.ok()) { - return absl::nullopt; + return std::nullopt; } if (*version != IpVersion::kIPv4) { - return absl::nullopt; + return std::nullopt; } NetworkAddressRep rep; rep.version_ = *version; @@ -418,13 +416,13 @@ bool NetworkAddressRep::IsLessThan(const NetworkAddressRep& other) const { return false; } -absl::optional NetworkAddressMatcher::Parse( +std::optional NetworkAddressMatcher::Parse( absl::string_view str) { // range style addr-addr int dash_pos = str.find('-'); if (dash_pos == absl::string_view::npos) { // TODO(uncreated-issue/86): CIDR style addr/prefix-length - return absl::nullopt; + return std::nullopt; } absl::string_view min_str = str.substr(0, dash_pos); absl::string_view max_str = str.substr(dash_pos + 1); @@ -433,23 +431,23 @@ absl::optional NetworkAddressMatcher::Parse( NetworkRangev6 v6; auto min_parse = ParseAddressImpl(min_str, &v4.min_incl, v6.min_incl); if (!min_parse.ok()) { - return absl::nullopt; + return std::nullopt; } auto max_parse = ParseAddressImpl(max_str, &v4.max_incl, v6.max_incl); if (!max_parse.ok()) { - return absl::nullopt; + return std::nullopt; } if (*min_parse != *max_parse) { - return absl::nullopt; + return std::nullopt; } NetworkAddressMatcher rep; if (*min_parse == IpVersion::kIPv4) { if (v4.min_incl > v4.max_incl) { - return absl::nullopt; + return std::nullopt; } rep.ranges_v4_.push_back(v4); } else if (*min_parse == IpVersion::kIPv6) { - return absl::nullopt; + return std::nullopt; } return rep; diff --git a/common/BUILD b/common/BUILD index da96b1c98..0426c0827 100644 --- a/common/BUILD +++ b/common/BUILD @@ -25,6 +25,7 @@ cc_library( hdrs = ["ast.h"], deps = [ ":expr", + ":source", "//common/ast:metadata", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", @@ -39,11 +40,81 @@ cc_test( deps = [ ":ast", ":expr", + ":source", "//internal:testing", "@com_google_absl//absl/container:flat_hash_map", ], ) +cc_library( + name = "type_spec_resolver", + srcs = ["type_spec_resolver.cc"], + hdrs = ["type_spec_resolver.h"], + deps = [ + ":ast", + ":type", + ":type_kind", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_spec_resolver_test", + srcs = ["type_spec_resolver_test.cc"], + deps = [ + ":ast", + ":type", + ":type_kind", + ":type_spec_resolver", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "signature", + srcs = ["signature.cc"], + hdrs = ["signature.h"], + deps = [ + ":ast", + ":type", + ":type_spec_resolver", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "signature_test", + srcs = ["signature_test.cc"], + deps = [ + ":ast", + ":signature", + ":type", + ":type_kind", + ":type_spec_resolver", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "expr", srcs = ["expr.cc"], @@ -110,14 +181,14 @@ cc_library( hdrs = ["decl.h"], deps = [ ":constant", + ":signature", ":type", ":type_kind", "//internal:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -367,7 +438,6 @@ cc_library( ":allocator", ":arena", ":data", - ":native_type", ":reference_count", "//common/internal:metadata", "//common/internal:reference_count", @@ -389,13 +459,9 @@ cc_test( ":allocator", ":data", ":memory", - ":native_type", "//common/internal:reference_count", "//internal:testing", "@com_google_absl//absl/base:nullability", - "@com_google_absl//absl/debugging:leak_check", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:struct_cc_proto", ], @@ -541,12 +607,9 @@ cc_library( ], ) + [ "type.h", - "type_factory.h", "type_introspector.h", - "type_manager.h", ], deps = [ - ":memory", ":type_kind", "//internal:string_pool", "@com_google_absl//absl/algorithm:container", @@ -573,6 +636,17 @@ cc_library( ], ) +cc_library( + name = "format_type_name", + srcs = ["format_type_name.cc"], + hdrs = ["format_type_name.h"], + deps = [ + ":type", + ":type_kind", + "@com_google_absl//absl/strings", + ], +) + cc_test( name = "type_test", srcs = glob([ @@ -595,6 +669,18 @@ cc_test( ], ) +cc_test( + name = "format_type_name_test", + srcs = ["format_type_name_test.cc"], + deps = [ + ":format_type_name", + ":type", + "//internal:testing", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "value", srcs = glob( @@ -991,9 +1077,6 @@ cc_library( deps = [ ":decl", ":decl_proto", - ":type", - ":type_proto", - "//internal:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1140,3 +1223,26 @@ cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "container", + srcs = ["container.cc"], + hdrs = ["container.h"], + deps = [ + "//internal:lexis", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "container_test", + srcs = ["container_test.cc"], + deps = [ + ":container", + "//internal:testing", + "@com_google_absl//absl/status", + ], +) diff --git a/common/ast.cc b/common/ast.cc index aea153197..48b6f5e0b 100644 --- a/common/ast.cc +++ b/common/ast.cc @@ -19,6 +19,7 @@ #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "common/ast/metadata.h" +#include "common/source.h" namespace cel { namespace { @@ -57,4 +58,41 @@ const Reference* absl_nullable Ast::GetReference(int64_t expr_id) const { return &iter->second; } +SourceLocation Ast::ComputeSourceLocation(int64_t expr_id) const { + const auto& source_info = this->source_info(); + auto iter = source_info.positions().find(expr_id); + if (iter == source_info.positions().end()) { + return SourceLocation{}; + } + int32_t absolute_position = iter->second; + if (absolute_position < 0) { + return SourceLocation{}; + } + + // Find the first line offset that is greater than the absolute position. + int32_t line_idx = -1; + int32_t offset = 0; + for (int32_t i = 0; i < source_info.line_offsets().size(); ++i) { + int32_t next_offset = source_info.line_offsets()[i]; + if (next_offset <= offset) { + // Line offset is not monotonically increasing, so line information is + // invalid. + return SourceLocation{}; + } + if (absolute_position < next_offset) { + line_idx = i; + break; + } + offset = next_offset; + } + + if (line_idx < 0 || line_idx >= source_info.line_offsets().size()) { + return SourceLocation{}; + } + + int32_t rel_position = absolute_position - offset; + + return SourceLocation{line_idx + 1, rel_position}; +} + } // namespace cel diff --git a/common/ast.h b/common/ast.h index 1b07b9878..afd0575ad 100644 --- a/common/ast.h +++ b/common/ast.h @@ -24,6 +24,7 @@ #include "absl/strings/string_view.h" #include "common/ast/metadata.h" // IWYU pragma: export #include "common/expr.h" +#include "common/source.h" namespace cel { @@ -135,6 +136,13 @@ class Ast final { expr_version_ = expr_version; } + // Computes the source location (line and column) for the given expression ID + // from the source info (which stores absolute positions). + // + // Returns a default (empty) source location if the expression ID is not found + // or the source info is not populated correctly. + SourceLocation ComputeSourceLocation(int64_t expr_id) const; + private: Expr root_expr_; SourceInfo source_info_; diff --git a/common/ast/BUILD b/common/ast/BUILD index 410d38c65..17456566b 100644 --- a/common/ast/BUILD +++ b/common/ast/BUILD @@ -98,10 +98,13 @@ cc_library( deps = [ "//common:constant", "//common:expr", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", ], diff --git a/common/ast/constant_proto.cc b/common/ast/constant_proto.cc index c0fe1c9f6..1982c05b4 100644 --- a/common/ast/constant_proto.cc +++ b/common/ast/constant_proto.cc @@ -35,7 +35,7 @@ using ConstantProto = cel::expr::Constant; absl::Status ConstantToProto(const Constant& constant, ConstantProto* absl_nonnull proto) { return absl::visit(absl::Overload( - [proto](absl::monostate) -> absl::Status { + [proto](std::monostate) -> absl::Status { proto->clear_constant_kind(); return absl::OkStatus(); }, diff --git a/common/ast/metadata.cc b/common/ast/metadata.cc index f744deb00..38f7ef610 100644 --- a/common/ast/metadata.cc +++ b/common/ast/metadata.cc @@ -14,11 +14,18 @@ #include "common/ast/metadata.h" +#include #include +#include +#include +#include #include #include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" #include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" #include "absl/types/variant.h" namespace cel { @@ -30,6 +37,96 @@ const TypeSpec& DefaultTypeSpec() { return *type; } +std::string FormatPrimitive(PrimitiveType t) { + switch (t) { + case PrimitiveType::kBool: + return "bool"; + case PrimitiveType::kInt64: + return "int"; + case PrimitiveType::kUint64: + return "uint"; + case PrimitiveType::kDouble: + return "double"; + case PrimitiveType::kString: + return "string"; + case PrimitiveType::kBytes: + return "bytes"; + default: + return "*unspecified primitive*"; + } +} + +std::string FormatWellKnown(WellKnownTypeSpec t) { + switch (t) { + case WellKnownTypeSpec::kAny: + return "google.protobuf.Any"; + case WellKnownTypeSpec::kDuration: + return "google.protobuf.Duration"; + case WellKnownTypeSpec::kTimestamp: + return "google.protobuf.Timestamp"; + default: + return "*unspecified well known*"; + } +} + +using FormatIns = std::variant; +using FormatStack = std::vector; + +void HandleFormatTypeSpec(const TypeSpec& t, FormatStack& stack, + std::string* out) { + if (t.has_dyn()) { + absl::StrAppend(out, "dyn"); + } else if (t.has_null()) { + absl::StrAppend(out, "null"); + } else if (t.has_primitive()) { + absl::StrAppend(out, FormatPrimitive(t.primitive())); + } else if (t.has_wrapper()) { + absl::StrAppend(out, "wrapper(", FormatPrimitive(t.wrapper()), ")"); + } else if (t.has_well_known()) { + absl::StrAppend(out, FormatWellKnown(t.well_known())); + return; + } else if (t.has_abstract_type()) { + const auto& abs_type = t.abstract_type(); + if (abs_type.parameter_types().empty()) { + absl::StrAppend(out, abs_type.name()); + return; + } + absl::StrAppend(out, abs_type.name(), "("); + stack.push_back(")"); + for (size_t i = abs_type.parameter_types().size(); i > 0; --i) { + stack.push_back(&abs_type.parameter_types()[i - 1]); + if (i > 1) { + stack.push_back(", "); + } + } + + } else if (t.has_type()) { + if (t.type() == TypeSpec()) { + absl::StrAppend(out, "type"); + return; + } + absl::StrAppend(out, "type("); + stack.push_back(")"); + stack.push_back(&t.type()); + } else if (t.has_message_type()) { + absl::StrAppend(out, t.message_type().type()); + } else if (t.has_type_param()) { + absl::StrAppend(out, t.type_param().type()); + } else if (t.has_list_type()) { + absl::StrAppend(out, "list("); + stack.push_back(")"); + stack.push_back(&t.list_type().elem_type()); + } else if (t.has_map_type()) { + absl::StrAppend(out, "map("); + stack.push_back(")"); + stack.push_back(&t.map_type().value_type()); + stack.push_back(", "); + stack.push_back(&t.map_type().key_type()); + } else { + absl::StrAppend(out, "*error*"); + } +} + TypeSpecKind CopyImpl(const TypeSpecKind& other) { return absl::visit( absl::Overload( @@ -61,12 +158,18 @@ const ExtensionSpec& ExtensionSpec::DefaultInstance() { ExtensionSpec::ExtensionSpec(const ExtensionSpec& other) : id_(other.id_), affected_components_(other.affected_components_), - version_(std::make_unique(*other.version_)) {} + version_(other.version_ == nullptr + ? nullptr + : std::make_unique(*other.version_)) {} ExtensionSpec& ExtensionSpec::operator=(const ExtensionSpec& other) { id_ = other.id_; affected_components_ = other.affected_components_; - version_ = std::make_unique(*other.version_); + if (other.version_ != nullptr) { + version_ = std::make_unique(other.version()); + } else { + version_ = nullptr; + } return *this; } @@ -136,4 +239,24 @@ FunctionTypeSpec& FunctionTypeSpec::operator=(const FunctionTypeSpec& other) { return *this; } +std::string FormatTypeSpec(const TypeSpec& t) { + // Use a stack to avoid recursion. + // Probably overly defensive, but fuzzers will often notice the recursion + // and try to trigger it. + std::string out; + FormatStack seq; + seq.push_back(&t); + while (!seq.empty()) { + FormatIns ins = std::move(seq.back()); + seq.pop_back(); + if (std::holds_alternative(ins)) { + absl::StrAppend(&out, std::get(ins)); + continue; + } + ABSL_DCHECK(std::holds_alternative(ins)); + HandleFormatTypeSpec(*std::get(ins), seq, &out); + } + return out; +} + } // namespace cel diff --git a/common/ast/metadata.h b/common/ast/metadata.h index a82e999f8..1a69b5b50 100644 --- a/common/ast/metadata.h +++ b/common/ast/metadata.h @@ -573,6 +573,10 @@ class TypeSpec { TypeSpecKind& mutable_type_kind() { return type_kind_; } + bool is_specified() const { + return !absl::holds_alternative(type_kind_); + } + bool has_dyn() const { return absl::holds_alternative(type_kind_); } @@ -740,6 +744,9 @@ class TypeSpec { TypeSpecKind type_kind_; }; +// Returns a string representation of the given TypeSpec. +std::string FormatTypeSpec(const TypeSpec& t); + // Describes a resolved reference to a declaration. class Reference { public: diff --git a/common/ast/metadata_test.cc b/common/ast/metadata_test.cc index 4afb0d07d..5553f4c8f 100644 --- a/common/ast/metadata_test.cc +++ b/common/ast/metadata_test.cc @@ -25,6 +25,8 @@ namespace cel { namespace { +using ::testing::ElementsAre; + TEST(AstTest, ListTypeSpecMutableConstruction) { ListTypeSpec type; type.mutable_elem_type() = TypeSpec(PrimitiveType::kBool); @@ -264,5 +266,34 @@ TEST(AstTest, ExtensionSpecEquality) { std::make_unique(0, 0), {})); } +TEST(AstTest, ExtensionCopyMove) { + ExtensionSpec a("constant_folding", nullptr, {}); + a.mutable_version().set_major(1); + a.mutable_version().set_minor(2); + a.mutable_affected_components().push_back(ExtensionSpec::Component::kRuntime); + + ExtensionSpec b(a); + + EXPECT_EQ(b.id(), "constant_folding"); + EXPECT_EQ(b.version().major(), 1); + EXPECT_EQ(b.version().minor(), 2); + EXPECT_THAT(b.affected_components(), + ElementsAre(ExtensionSpec::Component::kRuntime)); + + ExtensionSpec c(std::move(b)); + EXPECT_EQ(c, a); + + a.set_version(nullptr); + b = a; + EXPECT_EQ(b.id(), "constant_folding"); + EXPECT_EQ(b.version().major(), 0); + EXPECT_EQ(b.version().minor(), 0); + EXPECT_THAT(b.affected_components(), + ElementsAre(ExtensionSpec::Component::kRuntime)); + + c = std::move(b); + EXPECT_EQ(c, a); +} + } // namespace } // namespace cel diff --git a/common/ast_rewrite.cc b/common/ast_rewrite.cc index 14582f44f..b61e1fab6 100644 --- a/common/ast_rewrite.cc +++ b/common/ast_rewrite.cc @@ -54,7 +54,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/common/ast_test.cc b/common/ast_test.cc index 744b9e8d3..56e1bcd1e 100644 --- a/common/ast_test.cc +++ b/common/ast_test.cc @@ -18,6 +18,7 @@ #include "absl/container/flat_hash_map.h" #include "common/expr.h" +#include "common/source.h" #include "internal/testing.h" namespace cel { @@ -132,5 +133,56 @@ TEST(AstImpl, CheckedExprDeepCopy) { EXPECT_EQ(ast.source_info().syntax_version(), "1.0"); } +TEST(AstImpl, ComputeSourceLocation) { + SourceInfo source_info; + source_info.set_line_offsets({10, 20, 30}); + source_info.mutable_positions()[1] = 0; // Start of first line + source_info.mutable_positions()[2] = 5; // Middle of first line + source_info.mutable_positions()[3] = 10; // ... + source_info.mutable_positions()[4] = 15; + source_info.mutable_positions()[5] = 20; + source_info.mutable_positions()[6] = 25; + + Ast ast(Expr{}, std::move(source_info)); + + EXPECT_EQ(ast.ComputeSourceLocation(1), (SourceLocation{1, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(2), (SourceLocation{1, 5})); + EXPECT_EQ(ast.ComputeSourceLocation(3), (SourceLocation{2, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(4), (SourceLocation{2, 5})); + EXPECT_EQ(ast.ComputeSourceLocation(5), (SourceLocation{3, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(6), (SourceLocation{3, 5})); +} + +TEST(AstImpl, ComputeSourceLocationFailures) { + SourceInfo source_info; + source_info.set_line_offsets({10, 20}); + source_info.mutable_positions()[1] = -1; // Negative position + source_info.mutable_positions()[2] = 25; // Beyond last line offset + // ID 3 is missing + + Ast ast; + ast.mutable_source_info() = std::move(source_info); + + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + EXPECT_EQ(ast.ComputeSourceLocation(2), SourceLocation{}); + EXPECT_EQ(ast.ComputeSourceLocation(3), SourceLocation{}); +} + +TEST(AstImpl, ComputeSourceLocationInvalidLineOffsets) { + { + // Empty line offsets + Ast ast; + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + } + { + // Non-monotonic + SourceInfo source_info; + source_info.set_line_offsets({10, 5}); + source_info.mutable_positions()[1] = 12; + Ast ast(Expr{}, std::move(source_info)); + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + } +} + } // namespace } // namespace cel diff --git a/common/ast_traverse.cc b/common/ast_traverse.cc index a6ba0d1ba..fb4f9731e 100644 --- a/common/ast_traverse.cc +++ b/common/ast_traverse.cc @@ -53,7 +53,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/common/container.cc b/common/container.cc new file mode 100644 index 000000000..e1db8f86c --- /dev/null +++ b/common/container.cc @@ -0,0 +1,171 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/container.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "internal/lexis.h" + +namespace cel { +namespace { + +bool IsValidQualifiedName(absl::string_view name) { + auto dot_pos = name.find('.'); + while (dot_pos != absl::string_view::npos) { + if (!internal::LexisIsIdentifier(name.substr(0, dot_pos))) { + return false; + } + name = name.substr(dot_pos + 1); + dot_pos = name.find('.'); + } + return internal::LexisIsIdentifier(name); +} + +bool IsValidAlias(absl::string_view alias) { + return internal::LexisIsIdentifier(alias); +} + +bool IsAbbreviationImpl(absl::string_view alias, absl::string_view name) { + auto pos = name.rfind('.'); + return pos != std::string::npos && pos > 0 && pos < name.size() - 1 && + alias == name.substr(pos + 1); +} + +} // namespace + +bool ExpressionContainer::AliasListing::IsAbbreviation() const { + return IsAbbreviationImpl(alias, name); +} + +absl::StatusOr MakeExpressionContainer( + absl::string_view name) { + ExpressionContainer container; + + absl::Status status = container.SetContainer(name); + if (!status.ok()) { + return status; + } + return container; +} + +absl::Status ExpressionContainer::SetContainer(absl::string_view name) { + if (name.empty()) { + container_ = ""; + return absl::OkStatus(); + } + + if (!IsValidQualifiedName(name)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", name)); + } + + for (const auto& entry : aliases_) { + const std::string& alias = entry.first; + if (name == alias || + (name.size() > alias.size() && + absl::string_view(name).substr(0, alias.size()) == alias && + name.at(alias.size()) == '.')) { + return absl::InvalidArgumentError( + absl::StrCat("container name collides with alias: ", alias)); + } + } + + container_ = std::string(name); + return absl::OkStatus(); +} + +absl::Status ExpressionContainer::AddAbbreviation(absl::string_view abrev) { + abrev = absl::StripAsciiWhitespace(abrev); + if (!IsValidQualifiedName(abrev)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", abrev, + ", wanted name of the form 'qualified.name'")); + } + + auto pos = abrev.rfind('.'); + if (pos == 0 || pos == absl::string_view::npos || pos == abrev.size() - 1) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", abrev, + ", wanted name of the form 'qualified.name'")); + } + + absl::string_view alias = abrev.substr(pos + 1); + return AddAlias(alias, abrev); +} + +absl::Status ExpressionContainer::AddAlias(absl::string_view alias, + absl::string_view name) { + if (!IsValidAlias(alias)) { + return absl::InvalidArgumentError(absl::StrCat( + "alias must be non-empty and simple (not qualified): ", alias)); + } + + if (!IsValidQualifiedName(name)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", name)); + } + + if (auto it = aliases_.find(alias); it != aliases_.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "alias collides with existing reference: ", alias, " -> ", it->second)); + } + + if (container_ == alias || + (container_.size() > alias.size() && + absl::string_view(container_).substr(0, alias.size()) == alias && + container_.at(alias.size()) == '.')) { + return absl::InvalidArgumentError( + absl::StrCat("alias collides with container name: ", alias)); + } + + aliases_.insert({std::string(alias), std::string(name)}); + return absl::OkStatus(); +} + +absl::string_view ExpressionContainer::FindAlias( + absl::string_view alias) const { + auto it = aliases_.find(alias); + if (it != aliases_.end()) { + return it->second; + } + return ""; +} + +std::vector ExpressionContainer::ListAbbreviations() const { + std::vector res; + for (const auto& entry : aliases_) { + if (IsAbbreviationImpl(entry.first, entry.second)) { + res.push_back(entry.second); + } + } + return res; +} + +std::vector +ExpressionContainer::ListAliases() const { + std::vector res; + for (const auto& entry : aliases_) { + res.push_back({entry.first, entry.second}); + } + return res; +} + +} // namespace cel diff --git a/common/container.h b/common/container.h new file mode 100644 index 000000000..ad8d91c35 --- /dev/null +++ b/common/container.h @@ -0,0 +1,138 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace cel { + +// ExpressionContainer represents the namespace configuration for a CEL +// expression. +// +// The container defines the default resolution order for names referenced in +// the expression. It generally maps to a protobuf package and follows +// approximately the same resolution rules as protobuf or C++ namespaces. +// +// Aliases declare short names that can be referenced without resolving against +// the scopes defined by the container. An alias cannot be a prefix of the +// container name, (otherwise re-type-checking an expression could +// change the meaning). Aliases are always unqualified identifiers. +// +// An abbreviation is a special case of alias that behaves like an import or +// using declaration in other languages. (pkg.TypeName -> TypeName). +// +// For better traceability, prefer using abbreviations over aliases. +class ExpressionContainer { + public: + struct AliasListing { + std::string alias; + std::string name; + + bool IsAbbreviation() const; + }; + + ExpressionContainer() = default; + + ExpressionContainer(const ExpressionContainer&) = default; + ExpressionContainer(ExpressionContainer&&) = default; + ExpressionContainer& operator=(const ExpressionContainer&) = default; + ExpressionContainer& operator=(ExpressionContainer&&) = default; + + // Returns the full name of the container. + // + // The default value is an empty string meaning no container. + absl::string_view container() const { return container_; } + + // Sets the container name. + // + // Returns an error if the container name is malformed or conflicts with an + // existing alias. + absl::Status SetContainer(absl::string_view name); + + // Adds an abbreviation to the container. + // + // Returns an error if the abbreviation is malformed or conflicts with the + // container or an existing alias. + absl::Status AddAbbreviation(absl::string_view abrev); + + // Adds an alias to the container. + // + // Returns an error if the alias is malformed or conflicts with the container + // or an existing alias. + absl::Status AddAlias(absl::string_view alias, absl::string_view name); + + // Returns the full name of the alias or an empty string if not found. + // + // The returned string view may be invalidated by updates to the + // ExpressionContainer. + absl::string_view FindAlias(absl::string_view alias) const; + + // Utility method for listing the abbreviations in the container. + // Order is not guaranteed. + std::vector ListAbbreviations() const; + + // Utility method for listing the aliases in the container. + // Includes abbreviations. + // Order is not guaranteed. + std::vector ListAliases() const; + + // Removes all aliases and abbreviations from the container. + void clear() { + container_.clear(); + aliases_.clear(); + } + + private: + std::string container_; + + // alias -> full name. + absl::flat_hash_map aliases_; +}; + +// Factory function for creating an ExpressionContainer. +absl::StatusOr MakeExpressionContainer( + absl::string_view name); + +// Factory function for creating an ExpressionContainer with a list of +// abbreviations. +template +absl::StatusOr MakeExpressionContainer( + absl::string_view name, Args&&... abbrevs) { + ExpressionContainer container; + absl::Status status = container.SetContainer(name); + if (!status.ok()) { + return status; + } + absl::string_view abbrevs_view[] = {std::forward(abbrevs)...}; + for (absl::string_view abrev : abbrevs_view) { + status.Update(container.AddAbbreviation(abrev)); + if (!status.ok()) { + return status; + } + } + + return container; +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ diff --git a/common/container_test.cc b/common/container_test.cc new file mode 100644 index 000000000..e40814f54 --- /dev/null +++ b/common/container_test.cc @@ -0,0 +1,126 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/container.h" + +#include "absl/status/status.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +TEST(ExpressionContainerTest, DefaultConstructed) { + ExpressionContainer container; + EXPECT_THAT(container.container(), IsEmpty()); + EXPECT_THAT(container.FindAlias("foo"), IsEmpty()); +} + +TEST(ExpressionContainerTest, MakeExpressionContainer) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.container(), Eq("my.container")); + + EXPECT_THAT(MakeExpressionContainer("..invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, MakeExpressionContainerWithAbbrevs) { + ASSERT_OK_AND_ASSIGN( + ExpressionContainer container, + MakeExpressionContainer("my.container", "pkg.Abbr", "qual.pkg.Abbr2")); + EXPECT_THAT(container.container(), Eq("my.container")); + EXPECT_THAT(container.FindAlias("Abbr"), Eq("pkg.Abbr")); + EXPECT_THAT(container.FindAlias("Abbr2"), Eq("qual.pkg.Abbr2")); + + EXPECT_THAT(MakeExpressionContainer("my.container", "invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, SetContainer) { + ExpressionContainer container; + EXPECT_THAT(container.SetContainer("my.container.name"), IsOk()); + EXPECT_THAT(container.container(), Eq("my.container.name")); + EXPECT_THAT(container.SetContainer("..invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.SetContainer("foo.1invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, AddAlias) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.AddAlias("foo", "bar.baz"), IsOk()); + EXPECT_THAT(container.FindAlias("foo"), Eq("bar.baz")); +} + +TEST(ExpressionContainerTest, AddAbbreviation) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.AddAbbreviation("qual.pkg.TypeName"), IsOk()); + EXPECT_THAT(container.FindAlias("TypeName"), Eq("qual.pkg.TypeName")); +} + +TEST(ExpressionContainerTest, ListAbbreviationsAndAliases) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.AddAbbreviation("qual.pkg.Abbr"), IsOk()); + EXPECT_THAT(container.AddAlias("AliasSym", "some.long.name"), IsOk()); + + EXPECT_THAT(container.ListAbbreviations(), + UnorderedElementsAre("qual.pkg.Abbr")); + + auto aliases = container.ListAliases(); + EXPECT_THAT(aliases, SizeIs(2)); +} + +TEST(ExpressionContainerTest, InvalidAbbreviation) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.AddAbbreviation(""), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAbbreviation("pkg"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAbbreviation(".pkg"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAbbreviation("pkg."), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, InvalidAlias) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.AddAlias("", "bar"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAlias("foo.bar", "baz"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAlias("foo", ".baz"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, CollidesWithContainer) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.AddAlias("my", "bar"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace cel diff --git a/common/decl.cc b/common/decl.cc index 1e06cb703..858e6fb49 100644 --- a/common/decl.cc +++ b/common/decl.cc @@ -20,11 +20,13 @@ #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "common/signature.h" #include "common/type.h" #include "common/type_kind.h" @@ -104,215 +106,48 @@ bool SignaturesOverlap(const OverloadDecl& lhs, const OverloadDecl& rhs) { return args_overlap; } -void AppendEscaped(std::string* result, absl::string_view str, - bool escape_dot) { - for (char c : str) { - switch (c) { - case '\\': - case '(': - case ')': - case '<': - case '>': - case '"': - case ',': - result->push_back('\\'); - result->push_back(c); - break; - case '.': - if (escape_dot) { - result->push_back('\\'); - } - result->push_back(c); - break; - default: - result->push_back(c); - break; - } - } -} - -void AppendTypeParameters(std::string* result, const Type& type); - -// Recursively appends a string representation of the given `type` to `result`. -// Type parameters are enclosed in angle brackets and separated by commas. -void AppendTypeToOverloadId(std::string* result, const Type& type) { - switch (type.kind()) { - case TypeKind::kNull: - absl::StrAppend(result, "null"); - return; - case TypeKind::kBool: - absl::StrAppend(result, "bool"); - return; - case TypeKind::kInt: - absl::StrAppend(result, "int"); - return; - case TypeKind::kUint: - absl::StrAppend(result, "uint"); - return; - case TypeKind::kDouble: - absl::StrAppend(result, "double"); - return; - case TypeKind::kString: - absl::StrAppend(result, "string"); - return; - case TypeKind::kBytes: - absl::StrAppend(result, "bytes"); - return; - case TypeKind::kDuration: - absl::StrAppend(result, "duration"); - return; - case TypeKind::kTimestamp: - absl::StrAppend(result, "timestamp"); - return; - case TypeKind::kUnknown: - absl::StrAppend(result, "unknown"); - return; - case TypeKind::kError: - absl::StrAppend(result, "error"); - return; - case TypeKind::kAny: - absl::StrAppend(result, "any"); - return; - case TypeKind::kDyn: - absl::StrAppend(result, "dyn"); - return; - case TypeKind::kBoolWrapper: - absl::StrAppend(result, "bool_wrapper"); - return; - case TypeKind::kIntWrapper: - absl::StrAppend(result, "int_wrapper"); - return; - case TypeKind::kUintWrapper: - absl::StrAppend(result, "uint_wrapper"); - return; - case TypeKind::kDoubleWrapper: - absl::StrAppend(result, "double_wrapper"); - return; - case TypeKind::kStringWrapper: - absl::StrAppend(result, "string_wrapper"); - return; - case TypeKind::kBytesWrapper: - absl::StrAppend(result, "bytes_wrapper"); - return; - case TypeKind::kList: - absl::StrAppend(result, "list"); - AppendTypeParameters(result, type); - return; - case TypeKind::kMap: - absl::StrAppend(result, "map"); - AppendTypeParameters(result, type); - return; - case TypeKind::kFunction: - absl::StrAppend(result, "function"); - AppendTypeParameters(result, type); - return; - case TypeKind::kEnum: - absl::StrAppend(result, "enum"); - AppendTypeParameters(result, type); - return; - case TypeKind::kType: - absl::StrAppend(result, "type"); - AppendTypeParameters(result, type); - return; - case TypeKind::kOpaque: - result->push_back('"'); - AppendEscaped(result, type.name(), /*escape_dot=*/false); - result->push_back('"'); - AppendTypeParameters(result, type); - return; - default: // This includes TypeKind::kStruct aka TypeKind::kTypeMessage - AppendEscaped(result, type.name(), /*escape_dot=*/false); - return; - } -} - -void AppendTypeParameters(std::string* result, const Type& type) { - const auto& parameters = type.GetParameters(); - if (!parameters.empty()) { - result->push_back('<'); - for (size_t i = 0; i < parameters.size(); ++i) { - AppendTypeToOverloadId(result, parameters[i]); - if (i < parameters.size() - 1) { - result->push_back(','); - } - } - result->push_back('>'); - } -} - -// Generates an identifier for the overload based on the function name and -// the types of the arguments. If `member` is true, the first argument type -// is used as the receiver and is prepended to the function name, followed by -// a dot. -// -// Examples: -// -// - `foo()` -// - `foo(int)` -// - `bar.foo(int)` -// - `foo(int,string)` -// - `foo(list,list)` -// - `bar.foo(list,list<"my_type">)` -// -std::string GenerateOverloadId(std::string_view function_name, - const std::vector& args, bool member) { - std::string result; - if (member) { - if (!args.empty()) { - AppendTypeToOverloadId(&result, args[0]); - } else { - // This should never happen: a member function with no receiver. - absl::StrAppend(&result, "error"); - } - result.push_back('.'); - } - AppendEscaped(&result, function_name, /*escape_dot=*/true); - result.push_back('('); - for (size_t i = member ? 1 : 0; i < args.size(); ++i) { - AppendTypeToOverloadId(&result, args[i]); - if (i < args.size() - 1) { - result.push_back(','); - } - } - result.push_back(')'); - - return result; -} - template void AddOverloadInternal(std::string_view function_name, std::vector& insertion_order, - OverloadDeclHashSet& overloads, Overload&& overload, - absl::Status& status) { + absl::flat_hash_map& by_id, + absl::flat_hash_map& by_signature, + Overload&& overload, absl::Status& status) { if (!status.ok()) { return; } - if (overload.id().empty()) { - OverloadDecl overload_decl = overload; - overload_decl.set_id(GenerateOverloadId(function_name, overload_decl.args(), - overload_decl.member())); - AddOverloadInternal(function_name, insertion_order, overloads, - std::move(overload_decl), status); + absl::StatusOr signature = + MakeOverloadSignature(function_name, overload.args(), overload.member()); + if (!signature.ok()) { + status = signature.status(); return; } - if (auto it = overloads.find(overload.id()); it != overloads.end()) { + OverloadDecl mutable_overload = std::forward(overload); + + if (mutable_overload.id().empty()) { + mutable_overload.set_id(*signature); + } + + if (auto it = by_id.find(mutable_overload.id()); it != by_id.end()) { status = absl::AlreadyExistsError( - absl::StrCat("overload already exists: ", overload.id())); + absl::StrCat("overload exists: ", mutable_overload.id())); return; } - for (const auto& existing : overloads) { - if (SignaturesOverlap(overload, existing)) { + + for (const auto& existing : insertion_order) { + if (SignaturesOverlap(mutable_overload, existing)) { status = absl::InvalidArgumentError( absl::StrCat("overload signature collision: ", existing.id(), - " collides with ", overload.id())); + " collides with ", mutable_overload.id())); return; } } - const auto inserted = overloads.insert(std::forward(overload)); - ABSL_DCHECK(inserted.second); - insertion_order.push_back(*inserted.first); + + size_t index = insertion_order.size(); + by_id[mutable_overload.id()] = index; + by_signature[*signature] = index; + insertion_order.push_back(std::move(mutable_overload)); } void CollectTypeParams(absl::flat_hash_set& type_params, @@ -362,14 +197,25 @@ absl::flat_hash_set OverloadDecl::GetTypeParams() const { void FunctionDecl::AddOverloadImpl(const OverloadDecl& overload, absl::Status& status) { - AddOverloadInternal(name_, overloads_.insertion_order, overloads_.set, - overload, status); + AddOverloadInternal(name_, overloads_.insertion_order, overloads_.by_id, + overloads_.by_signature, overload, status); } void FunctionDecl::AddOverloadImpl(OverloadDecl&& overload, absl::Status& status) { - AddOverloadInternal(name_, overloads_.insertion_order, overloads_.set, - std::move(overload), status); + AddOverloadInternal(name_, overloads_.insertion_order, overloads_.by_id, + overloads_.by_signature, std::move(overload), status); +} + +const OverloadDecl* FunctionDecl::FindOverloadById(absl::string_view id) const { + if (auto it = overloads_.by_id.find(id); it != overloads_.by_id.end()) { + return &overloads_.insertion_order[it->second]; + } + if (auto it = overloads_.by_signature.find(id); + it != overloads_.by_signature.end()) { + return &overloads_.insertion_order[it->second]; + } + return nullptr; } } // namespace cel diff --git a/common/decl.h b/common/decl.h index 22ee8cbf0..b15645236 100644 --- a/common/decl.h +++ b/common/decl.h @@ -22,11 +22,10 @@ #include "absl/algorithm/container.h" #include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -264,39 +263,6 @@ OverloadDecl MakeMemberOverloadDecl(absl::string_view id, Type result, return overload_decl; } -struct OverloadDeclHash { - using is_transparent = void; - - size_t operator()(const OverloadDecl& overload_decl) const { - return (*this)(overload_decl.id()); - } - - size_t operator()(absl::string_view id) const { return absl::HashOf(id); } -}; - -struct OverloadDeclEqualTo { - using is_transparent = void; - - bool operator()(const OverloadDecl& lhs, const OverloadDecl& rhs) const { - return (*this)(lhs.id(), rhs.id()); - } - - bool operator()(const OverloadDecl& lhs, absl::string_view rhs) const { - return (*this)(lhs.id(), rhs); - } - - bool operator()(absl::string_view lhs, const OverloadDecl& rhs) const { - return (*this)(lhs, rhs.id()); - } - - bool operator()(absl::string_view lhs, absl::string_view rhs) const { - return lhs == rhs; - } -}; - -using OverloadDeclHashSet = - absl::flat_hash_set; - template absl::StatusOr MakeFunctionDecl(std::string name, Overloads&&... overloads); @@ -346,21 +312,27 @@ class FunctionDecl final { return overloads_.insertion_order; } + ABSL_MUST_USE_RESULT const OverloadDecl* FindOverloadById( + absl::string_view id) const; + std::vector release_overloads() { std::vector released = std::move(overloads_.insertion_order); overloads_.insertion_order.clear(); - overloads_.set.clear(); + overloads_.by_id.clear(); + overloads_.by_signature.clear(); return released; } private: struct Overloads { std::vector insertion_order; - OverloadDeclHashSet set; + absl::flat_hash_map by_id; + absl::flat_hash_map by_signature; void Reserve(size_t size) { insertion_order.reserve(size); - set.reserve(size); + by_id.reserve(size); + by_signature.reserve(size); } }; @@ -405,6 +377,70 @@ bool TypeIsAssignable(const Type& to, const Type& from); } // namespace common_internal +struct VariableDeclEqualTo { + using is_transparent = void; + + bool operator()(const cel::VariableDecl& lhs, + const cel::VariableDecl& rhs) const { + return lhs.name() == rhs.name(); + } + + bool operator()(const cel::VariableDecl& lhs, std::string_view rhs) const { + return lhs.name() == rhs; + } + + bool operator()(std::string_view lhs, const cel::VariableDecl& rhs) const { + return lhs == rhs.name(); + } +}; + +struct VariableDeclHash { + using is_transparent = void; + + size_t operator()(const cel::VariableDecl& decl) const { + return (*this)(decl.name()); + } + + size_t operator()(std::string_view name) const { return absl::HashOf(name); } +}; + +using VariableDeclSet = absl::flat_hash_set; + +struct FunctionDeclEqualTo { + using is_transparent = void; + + bool operator()(const cel::FunctionDecl& lhs, + const cel::FunctionDecl& rhs) const { + return (*this)(lhs.name(), rhs.name()); + } + + bool operator()(const cel::FunctionDecl& lhs, std::string_view rhs) const { + return (*this)(lhs.name(), rhs); + } + + bool operator()(std::string_view lhs, const cel::FunctionDecl& rhs) const { + return (*this)(lhs, rhs.name()); + } + + bool operator()(std::string_view lhs, std::string_view rhs) const { + return lhs == rhs; + } +}; + +struct FunctionDeclHash { + using is_transparent = void; + + size_t operator()(const cel::FunctionDecl& decl) const { + return absl::HashOf(decl.name()); + } + + size_t operator()(std::string_view name) const { return absl::HashOf(name); } +}; + +using FunctionDeclSet = absl::flat_hash_set; + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_DECL_H_ diff --git a/common/decl_proto.cc b/common/decl_proto.cc index 89f7f4453..098c5068c 100644 --- a/common/decl_proto.cc +++ b/common/decl_proto.cc @@ -69,7 +69,7 @@ absl::StatusOr FunctionDeclFromProto( return decl; } -absl::StatusOr> DeclFromProto( +absl::StatusOr> DeclFromProto( const cel::expr::Decl& decl, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { diff --git a/common/decl_proto_test.cc b/common/decl_proto_test.cc index 62215f07f..d72d97e09 100644 --- a/common/decl_proto_test.cc +++ b/common/decl_proto_test.cc @@ -49,7 +49,7 @@ TEST_P(DeclFromProtoTest, FromProtoWorks) { cel::expr::Decl decl_pb; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(test_case.proto_decl, &decl_pb)); - absl::StatusOr> decl_or = + absl::StatusOr> decl_or = DeclFromProto(decl_pb, descriptor_pool, &arena); switch (test_case.decl_type) { case DeclType::kVariable: { @@ -79,7 +79,7 @@ TEST_P(DeclFromProtoTest, FromV1Alpha1ProtoWorks) { google::api::expr::v1alpha1::Decl decl_pb; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(test_case.proto_decl, &decl_pb)); - absl::StatusOr> decl_or = + absl::StatusOr> decl_or = DeclFromV1Alpha1Proto(decl_pb, descriptor_pool, &arena); switch (test_case.decl_type) { case DeclType::kVariable: { diff --git a/common/decl_proto_v1alpha1.cc b/common/decl_proto_v1alpha1.cc index 2c6cfb6e4..a8d73e5c2 100644 --- a/common/decl_proto_v1alpha1.cc +++ b/common/decl_proto_v1alpha1.cc @@ -52,7 +52,7 @@ absl::StatusOr FunctionDeclFromV1Alpha1Proto( return FunctionDeclFromProto(name, unversioned, descriptor_pool, arena); } -absl::StatusOr> DeclFromV1Alpha1Proto( +absl::StatusOr> DeclFromV1Alpha1Proto( const google::api::expr::v1alpha1::Decl& decl, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { diff --git a/common/decl_test.cc b/common/decl_test.cc index 6e5710049..72e7f1b93 100644 --- a/common/decl_test.cc +++ b/common/decl_test.cc @@ -14,7 +14,10 @@ #include "common/decl.h" -#include "absl/log/die_if_null.h" +#include +#include + +#include "absl/log/die_if_null.h" // IWYU pragma: keep #include "absl/status/status.h" #include "common/constant.h" #include "common/type.h" @@ -162,6 +165,53 @@ TEST(FunctionDecl, Overloads) { StatusIs(absl::StatusCode::kInvalidArgument)); } +TEST(FunctionDecl, AddOverloadInvalidSignature) { + FunctionDecl function_decl; + function_decl.set_name("foo"); + // Member overload must have at least one argument (the receiver). + // This should fail to add because signature generation fails. + EXPECT_THAT(function_decl.AddOverload(MakeMemberOverloadDecl(StringType{})), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(FunctionDecl, AddOverloadDuplicateId) { + ASSERT_OK_AND_ASSIGN( + auto function_decl, + MakeFunctionDecl("hello", + MakeOverloadDecl("foo", StringType{}, StringType{}))); + // Adding another overload with the same ID "foo" should fail. + EXPECT_THAT( + function_decl.AddOverload(MakeOverloadDecl("foo", IntType{}, IntType{})), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +TEST(FunctionDecl, FindOverload) { + ASSERT_OK_AND_ASSIGN( + auto function_decl, + MakeFunctionDecl( + "hello", MakeOverloadDecl("foo", StringType{}, StringType{}), + MakeMemberOverloadDecl("bar", StringType{}, StringType{}), + MakeOverloadDecl(IntType{}, IntType{}))); + + // Find by explicit ID + const OverloadDecl* overload = function_decl.FindOverloadById("foo"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "foo"); + + // Find by ID fallback to signature + overload = function_decl.FindOverloadById("hello(string)"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "foo"); + + // Find implicit overload (where ID == signature) + overload = function_decl.FindOverloadById("hello(int)"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "hello(int)"); + + // Non-existent + EXPECT_EQ(function_decl.FindOverloadById("non_existent"), nullptr); +} + TEST(FunctionDecl, OverloadId) { google::protobuf::Arena arena; const auto* descriptor = @@ -186,7 +236,6 @@ TEST(FunctionDecl, OverloadId) { MakeOverloadDecl(IntType{}, TimestampType{}), MakeOverloadDecl(IntType{}, IntWrapperType{}), MakeOverloadDecl(IntType{}, MessageType(descriptor)), - MakeMemberOverloadDecl(IntType{}), MakeMemberOverloadDecl(StringType{}, StringType{}), MakeMemberOverloadDecl(StringType{}, StringType{}, ListType(&arena, BoolType{})), @@ -198,36 +247,20 @@ TEST(FunctionDecl, OverloadId) { ElementsAre(Property(&OverloadDecl::id, "hello()"), Property(&OverloadDecl::id, "hello(string)"), Property(&OverloadDecl::id, "hello(int,uint)"), - Property(&OverloadDecl::id, "hello(list)"), - Property(&OverloadDecl::id, "hello(map)"), - Property(&OverloadDecl::id, "hello(\"bar\">)"), + Property(&OverloadDecl::id, "hello(list<~A>)"), + Property(&OverloadDecl::id, "hello(map<~B,~C>)"), + Property(&OverloadDecl::id, "hello(bar>)"), Property(&OverloadDecl::id, "hello(any)"), Property(&OverloadDecl::id, "hello(duration)"), Property(&OverloadDecl::id, "hello(timestamp)"), Property(&OverloadDecl::id, "hello(int_wrapper)"), Property(&OverloadDecl::id, "hello(cel.expr.conformance.proto3.TestAllTypes)"), - Property(&OverloadDecl::id, "error.hello()"), Property(&OverloadDecl::id, "string.hello()"), Property(&OverloadDecl::id, "string.hello(list)"), Property(&OverloadDecl::id, "string.hello(bool,dyn)"))); } -TEST(FunctionDecl, OverloadIdEscaping) { - google::protobuf::Arena arena; - ASSERT_OK_AND_ASSIGN( - auto function_decl, - MakeFunctionDecl("h.(e),l\\o", - MakeMemberOverloadDecl( - StringType{}, StringType{}, - ListType(&arena, TypeParamType("a,b..(d)\\e"))))); - - EXPECT_THAT(function_decl.overloads(), - ElementsAre(Property(&OverloadDecl::id, - "string.h\\.\\(e\\)\\,l\\\\\\o(list<" - "a\\,b.\\.\\(d\\)\\\\e>)"))); -} - using common_internal::TypeIsAssignable; TEST(TypeIsAssignable, BoolWrapper) { diff --git a/common/expr.h b/common/expr.h index 9c6f508c6..7305c2c9f 100644 --- a/common/expr.h +++ b/common/expr.h @@ -45,7 +45,9 @@ class MapExprEntry; class MapExpr; class ComprehensionExpr; -inline constexpr absl::string_view kAccumulatorVariableName = "__result__"; +inline constexpr absl::string_view kAccumulatorVariableName = "@result"; +inline constexpr absl::string_view kDeprecatedAccumulatorVariableName = + "__result__"; bool operator==(const Expr& lhs, const Expr& rhs); diff --git a/common/expr_factory.h b/common/expr_factory.h index c8a9b831f..757318545 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -32,6 +32,11 @@ namespace cel { class MacroExprFactory; class ParserMacroExprFactory; +class OptimizerExprFactory; + +namespace tools { +class ProtoToPredicateBuilder; +} class ExprFactory { protected: @@ -352,12 +357,36 @@ class ExprFactory { return expr; } + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + Expr NewBind(NextIdFunc next_id, BindVar bind_var, BindExpr bind_expr, + RestExpr rest_expr) { + Expr expr; + expr.set_id(next_id()); + auto& comprehension_expr = expr.mutable_comprehension_expr(); + comprehension_expr.set_iter_var("#unused"); + comprehension_expr.set_iter_range( + NewList(next_id(), std::vector{})); + comprehension_expr.set_accu_var(bind_var); + comprehension_expr.set_accu_init(std::move(bind_expr)); + comprehension_expr.set_loop_condition(NewBoolConst(next_id(), false)); + comprehension_expr.set_loop_step(NewIdent(next_id(), bind_var)); + comprehension_expr.set_result(std::move(rest_expr)); + return expr; + } + private: friend class MacroExprFactory; friend class ParserMacroExprFactory; + friend class OptimizerExprFactory; + friend class tools::ProtoToPredicateBuilder; ExprFactory() : accu_var_(kAccumulatorVariableName) {} - explicit ExprFactory(absl::string_view accu_var) : accu_var_(accu_var) {} std::string accu_var_; }; diff --git a/checker/internal/format_type_name.cc b/common/format_type_name.cc similarity index 97% rename from checker/internal/format_type_name.cc rename to common/format_type_name.cc index 7cd17251f..4bd6c2e61 100644 --- a/checker/internal/format_type_name.cc +++ b/common/format_type_name.cc @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "checker/internal/format_type_name.h" +#include "common/format_type_name.h" #include #include @@ -20,7 +20,7 @@ #include "common/type.h" #include "common/type_kind.h" -namespace cel::checker_internal { +namespace cel { namespace { struct FormatImplRecord { @@ -177,4 +177,4 @@ std::string FormatTypeName(const Type& type) { return out; } -} // namespace cel::checker_internal +} // namespace cel diff --git a/checker/internal/format_type_name.h b/common/format_type_name.h similarity index 74% rename from checker/internal/format_type_name.h rename to common/format_type_name.h index c31e1c4d0..723ac20fd 100644 --- a/checker/internal/format_type_name.h +++ b/common/format_type_name.h @@ -12,19 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_FORMAT_TYPE_NAME_H_ -#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_FORMAT_TYPE_NAME_H_ +#ifndef THIRD_PARTY_CEL_CPP_COMMON_FORMAT_TYPE_NAME_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_FORMAT_TYPE_NAME_H_ #include #include "common/type.h" -namespace cel::checker_internal { +namespace cel { // Format the type name for presentation in error messages. Matches the // formatting used in github.com/cel-spec. std::string FormatTypeName(const Type& type); -} // namespace cel::checker_internal +} // namespace cel -#endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_FORMAT_TYPE_NAME_H_ +#endif // THIRD_PARTY_CEL_CPP_COMMON_FORMAT_TYPE_NAME_H_ diff --git a/checker/internal/format_type_name_test.cc b/common/format_type_name_test.cc similarity index 97% rename from checker/internal/format_type_name_test.cc rename to common/format_type_name_test.cc index 23bc2bda9..ca63f60b0 100644 --- a/checker/internal/format_type_name_test.cc +++ b/common/format_type_name_test.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "checker/internal/format_type_name.h" +#include "common/format_type_name.h" #include "common/type.h" #include "internal/testing.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "google/protobuf/arena.h" -namespace cel::checker_internal { +namespace cel { namespace { using ::cel::expr::conformance::proto2::GlobalEnum_descriptor; @@ -101,6 +101,7 @@ TEST(FormatTypeNameTest, Opaque) { "tuple(tuple(int, int), tuple(int, int), tuple(int, int))"); } +#ifndef __APPLE__ TEST(FormatTypeNameTest, ArbitraryNesting) { google::protobuf::Arena arena; Type type = IntType(); @@ -111,6 +112,7 @@ TEST(FormatTypeNameTest, ArbitraryNesting) { EXPECT_THAT(FormatTypeName(type), MatchesRegex(R"(^(ptype\(){1000}int(\)){1000})")); } +#endif } // namespace -} // namespace cel::checker_internal +} // namespace cel diff --git a/common/internal/BUILD b/common/internal/BUILD index c5ca63564..3be350754 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -21,10 +21,8 @@ cc_library( name = "casting", hdrs = ["casting.h"], deps = [ - "//common:native_type", "//internal:casts", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/types:optional", ], diff --git a/common/legacy_value.cc b/common/legacy_value.cc index 5c81fdacb..7fbf16732 100644 --- a/common/legacy_value.cc +++ b/common/legacy_value.cc @@ -700,7 +700,8 @@ absl::Status LegacyMapValue::Get( case ValueKind::kString: break; default: - return InvalidMapKeyTypeError(key.kind()); + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); + return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); auto cel_value = impl_->Get(arena, cel_key); @@ -732,7 +733,7 @@ absl::StatusOr LegacyMapValue::Find( case ValueKind::kString: break; default: - return InvalidMapKeyTypeError(key.kind()); + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); auto cel_value = impl_->Get(arena, cel_key); @@ -764,11 +765,17 @@ absl::Status LegacyMapValue::Has( case ValueKind::kString: break; default: - return InvalidMapKeyTypeError(key.kind()); + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); + return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); - CEL_ASSIGN_OR_RETURN(auto has, impl_->Has(cel_key)); - *result = BoolValue{has}; + absl::StatusOr has = impl_->Has(cel_key); + if (!has.ok()) { + *result = ErrorValue(std::move(has).status()); + return absl::OkStatus(); + } + + *result = BoolValue(*has); return absl::OkStatus(); } diff --git a/common/optional_ref.h b/common/optional_ref.h index 454926c80..c7ba580fc 100644 --- a/common/optional_ref.h +++ b/common/optional_ref.h @@ -84,7 +84,12 @@ class optional_ref final { constexpr T& value() const { return ABSL_PREDICT_TRUE(has_value()) ? *value_ - : (absl::optional().value(), *value_); + // Replicate the same error logic as in `absl::optional`'s + // `value()`. It either throws an exception or aborts the + // program. We intentionally ignore the return value of + // the constructed optional's value as we only need to run + // the code for error checking. + : ((void)absl::optional().value(), *value_); } constexpr T& operator*() const { diff --git a/common/signature.cc b/common/signature.cc new file mode 100644 index 000000000..e497e780d --- /dev/null +++ b/common/signature.cc @@ -0,0 +1,640 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/signature.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "common/ast.h" +#include "common/type.h" +#include "common/type_spec_resolver.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Signature generator helper functions. +namespace { + +void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) { + for (char c : str) { + switch (c) { + case '\\': + case '(': + case ')': + case '<': + case '>': + case '"': + case ',': + case '~': + result->push_back('\\'); + break; + case '.': + if (escape_dot) { + result->push_back('\\'); + } + break; + } + result->push_back(c); + } +} + +absl::Status AppendTypeDesc(std::string* result, const TypeSpec& type_spec); + +absl::Status AppendTypeSpecList(std::string* result, + const std::vector& params) { + if (!params.empty()) { + result->push_back('<'); + for (size_t i = 0; i < params.size(); ++i) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, params[i])); + if (i < params.size() - 1) { + result->push_back(','); + } + } + result->push_back('>'); + } + return absl::OkStatus(); +} + +absl::Status AppendTypeDesc(std::string* result, const TypeSpec& type_spec) { + if (type_spec.has_null()) { + absl::StrAppend(result, "null"); + } else if (type_spec.has_dyn()) { + absl::StrAppend(result, "dyn"); + } else if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + absl::StrAppend(result, "bool"); + break; + case PrimitiveType::kInt64: + absl::StrAppend(result, "int"); + break; + case PrimitiveType::kUint64: + absl::StrAppend(result, "uint"); + break; + case PrimitiveType::kDouble: + absl::StrAppend(result, "double"); + break; + case PrimitiveType::kString: + absl::StrAppend(result, "string"); + break; + case PrimitiveType::kBytes: + absl::StrAppend(result, "bytes"); + break; + default: + return absl::InvalidArgumentError("Unsupported primitive type"); + } + } else if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + absl::StrAppend(result, "any"); + break; + case WellKnownTypeSpec::kTimestamp: + absl::StrAppend(result, "timestamp"); + break; + case WellKnownTypeSpec::kDuration: + absl::StrAppend(result, "duration"); + break; + default: + return absl::InvalidArgumentError("Unsupported well-known type"); + } + } else if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + absl::StrAppend(result, "bool_wrapper"); + break; + case PrimitiveType::kInt64: + absl::StrAppend(result, "int_wrapper"); + break; + case PrimitiveType::kUint64: + absl::StrAppend(result, "uint_wrapper"); + break; + case PrimitiveType::kDouble: + absl::StrAppend(result, "double_wrapper"); + break; + case PrimitiveType::kString: + absl::StrAppend(result, "string_wrapper"); + break; + case PrimitiveType::kBytes: + absl::StrAppend(result, "bytes_wrapper"); + break; + default: + return absl::InvalidArgumentError("Unsupported wrapper type"); + } + } else if (type_spec.has_list_type()) { + absl::StrAppend(result, "list<"); + if (type_spec.list_type().elem_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.list_type().elem_type())); + } else { + absl::StrAppend(result, "dyn"); + } + result->push_back('>'); + } else if (type_spec.has_map_type()) { + absl::StrAppend(result, "map<"); + if (type_spec.map_type().key_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.map_type().key_type())); + } else { + absl::StrAppend(result, "dyn"); + } + result->push_back(','); + if (type_spec.map_type().value_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.map_type().value_type())); + } else { + absl::StrAppend(result, "dyn"); + } + result->push_back('>'); + } else if (type_spec.has_function()) { + absl::StrAppend(result, "function<"); + if (type_spec.function().result_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.function().result_type())); + } else { + absl::StrAppend(result, "dyn"); + } + for (const auto& arg : type_spec.function().arg_types()) { + result->push_back(','); + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, arg)); + } + result->push_back('>'); + } else if (type_spec.has_type()) { + absl::StrAppend(result, "type"); + result->push_back('<'); + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, type_spec.type())); + result->push_back('>'); + } else if (type_spec.has_type_param()) { + absl::StrAppend(result, "~"); + AppendEscaped(result, type_spec.type_param().type(), /*escape_dot=*/true); + } else if (type_spec.has_abstract_type()) { + AppendEscaped(result, type_spec.abstract_type().name(), + /*escape_dot=*/false); + CEL_RETURN_IF_ERROR(AppendTypeSpecList( + result, type_spec.abstract_type().parameter_types())); + } else if (type_spec.has_message_type()) { + AppendEscaped(result, type_spec.message_type().type(), + /*escape_dot=*/false); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported type in signature: ", FormatTypeSpec(type_spec))); + } + return absl::OkStatus(); +} +} // namespace + +absl::StatusOr MakeTypeSignature(const Type& type) { + std::string result; + CEL_ASSIGN_OR_RETURN(TypeSpec type_spec, ConvertTypeToTypeSpec(type)); + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type_spec)); + return result; +} + +absl::StatusOr MakeTypeSpecSignature(const TypeSpec& type_spec) { + std::string result; + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type_spec)); + return result; +} + +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member) { + std::vector arg_type_specs; + arg_type_specs.reserve(args.size()); + for (const auto& arg : args) { + CEL_ASSIGN_OR_RETURN(TypeSpec type_spec, ConvertTypeToTypeSpec(arg)); + arg_type_specs.push_back(type_spec); + } + return MakeOverloadSignature(function_name, arg_type_specs, is_member); +} + +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member) { + std::string result; + if (is_member) { + if (!args.empty()) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, args[0])); + } else { + return absl::InvalidArgumentError("Member function with no receiver"); + } + result.push_back('.'); + } + AppendEscaped(&result, function_name, /*escape_dot=*/true); + result.push_back('('); + for (size_t i = is_member ? 1 : 0; i < args.size(); ++i) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, args[i])); + if (i < args.size() - 1) { + result.push_back(','); + } + } + result.push_back(')'); + + return result; +} + +// Signature parser helper functions. +namespace { + +std::string StripUnescapedWhitespace(std::string_view str) { + std::string result; + result.reserve(str.size()); + bool escaped = false; + for (char c : str) { + if (escaped) { + result.push_back(c); + escaped = false; + continue; + } + if (c == '\\') { + result.push_back(c); + escaped = true; + continue; + } + if (c == ' ' || c == '\t' || c == '\n' || c == '\r') { + continue; + } + result.push_back(c); + } + return result; +} + +absl::optional ParseBuiltinOrWrapper(std::string_view name_str) { + if (name_str == "null") return TypeSpec(NullTypeSpec()); + if (name_str == "bool") return TypeSpec(PrimitiveType::kBool); + if (name_str == "int") return TypeSpec(PrimitiveType::kInt64); + if (name_str == "uint") return TypeSpec(PrimitiveType::kUint64); + if (name_str == "double") return TypeSpec(PrimitiveType::kDouble); + if (name_str == "string") return TypeSpec(PrimitiveType::kString); + if (name_str == "bytes") return TypeSpec(PrimitiveType::kBytes); + if (name_str == "any" || name_str == "google.protobuf.Any") + return TypeSpec(WellKnownTypeSpec::kAny); + if (name_str == "timestamp" || name_str == "google.protobuf.Timestamp") + return TypeSpec(WellKnownTypeSpec::kTimestamp); + if (name_str == "duration" || name_str == "google.protobuf.Duration") + return TypeSpec(WellKnownTypeSpec::kDuration); + if (name_str == "dyn" || name_str == "google.protobuf.Value") + return TypeSpec(DynTypeSpec()); + + // Handle standard Protobuf well-known wrapper types to preserve + // backward compatibility for users migrating YAML configuration files. + if (name_str == "bool_wrapper" || name_str == "google.protobuf.BoolValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + if (name_str == "int_wrapper" || name_str == "google.protobuf.Int64Value" || + name_str == "google.protobuf.Int32Value") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + if (name_str == "uint_wrapper" || name_str == "google.protobuf.UInt64Value" || + name_str == "google.protobuf.UInt32Value") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + if (name_str == "double_wrapper" || + name_str == "google.protobuf.DoubleValue" || + name_str == "google.protobuf.FloatValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + if (name_str == "string_wrapper" || name_str == "google.protobuf.StringValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + if (name_str == "bytes_wrapper" || name_str == "google.protobuf.BytesValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + + if (name_str == "google.protobuf.ListValue") { + return TypeSpec(ListTypeSpec(std::make_unique(DynTypeSpec()))); + } + if (name_str == "google.protobuf.Struct") { + return TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec()))); + } + + return absl::nullopt; +} + +std::string Unescape(std::string_view str) { + size_t first_escape = str.find('\\'); + if (first_escape == std::string_view::npos) { + return std::string(str); + } + std::string result; + result.reserve(str.size()); + result.append(str.substr(0, first_escape)); + bool escaped = false; + for (size_t i = first_escape; i < str.size(); ++i) { + char c = str[i]; + if (escaped) { + result.push_back(c); + escaped = false; + } else if (c == '\\') { + escaped = true; + } else { + result.push_back(c); + } + } + if (escaped) { + result.push_back('\\'); + } + return result; +} + +class SignatureScanner { + public: + explicit SignatureScanner(std::string_view input, + std::string_view error_prefix = "Invalid signature") + : input_(input), error_prefix_(error_prefix) {} + + absl::StatusOr FindTopLevelChar(char target, bool find_last = false) { + size_t found_idx = std::string_view::npos; + int nesting = 0; + bool escaped = false; + // Scanning str for delimiter boundaries while ensuring + // brackets are balanced and escape backslashes are bypassed. + for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + continue; + } + if (c == target && nesting == 0) { + if (find_last || found_idx == std::string_view::npos) { + found_idx = i; + } + } + if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + return found_idx; + } + + absl::StatusOr> SplitTopLevel(char delimiter) { + std::vector result; + int nesting = 0; + bool escaped = false; + size_t start = 0; + // Scanning str for delimiter while ensuring brackets are balanced and + // escape backslashes are bypassed. + for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + continue; + } + if (c == delimiter && nesting == 0) { + result.push_back(input_.substr(start, i - start)); + start = i + 1; + } + if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + result.push_back(input_.substr(start)); + return result; + } + + private: + std::string_view input_; + std::string_view error_prefix_; +}; + +absl::StatusOr> SplitTypeList( + std::string_view params) { + return SignatureScanner(params, "Invalid type signature").SplitTopLevel(','); +} + +absl::StatusOr ParseTypeSignature(std::string_view signature) { + if (signature.empty()) { + return absl::InvalidArgumentError("Empty type signature"); + } + + if (signature[0] == '~') { + std::string_view param_name = signature.substr(1); + if (param_name.empty()) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid type parameter name"); + } + CEL_ASSIGN_OR_RETURN(size_t less_idx, + SignatureScanner(param_name) + .FindTopLevelChar('<', /*find_last=*/false)); + CEL_ASSIGN_OR_RETURN(size_t comma_idx, + SignatureScanner(param_name) + .FindTopLevelChar(',', /*find_last=*/false)); + if (less_idx != std::string_view::npos || + comma_idx != std::string_view::npos) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid type parameter name"); + } + return TypeSpec(ParamTypeSpec(Unescape(param_name))); + } + + CEL_ASSIGN_OR_RETURN(size_t less_idx, + SignatureScanner(signature, "Invalid type signature") + .FindTopLevelChar('<', /*find_last=*/false)); + + std::string name_str; + std::vector params; + + if (less_idx != std::string_view::npos) { + // If the signature contains a '<', it must also contain a matching '>'. + if (signature.back() != '>') { + return absl::InvalidArgumentError( + "Invalid type signature: missing closing >"); + } + name_str = Unescape(signature.substr(0, less_idx)); + std::string_view params_str = + signature.substr(less_idx + 1, signature.size() - less_idx - 2); + CEL_ASSIGN_OR_RETURN(auto param_list, SplitTypeList(params_str)); + for (std::string_view param_str : param_list) { + CEL_ASSIGN_OR_RETURN(auto param, ParseTypeSignature(param_str)); + params.push_back(std::move(param)); + } + } else { + name_str = Unescape(signature); + } + + auto read_param_or_dyn = [¶ms](size_t index) { + auto spec = std::make_unique(DynTypeSpec()); + if (params.size() > index) { + *spec = std::move(params[index]); + } + return spec; + }; + + if (!params.empty()) { + if (ParseBuiltinOrWrapper(name_str).has_value()) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid type signature: ", name_str, + " cannot have type parameters")); + } + } else { + if (auto builtin = ParseBuiltinOrWrapper(name_str); builtin.has_value()) { + return *builtin; + } + } + + if (name_str == "type") { + if (params.size() > 1) { + return absl::InvalidArgumentError( + "Invalid type signature: type expects at most 1 parameter"); + } + return TypeSpec(read_param_or_dyn(0)); + } + + if (name_str == "list") { + if (params.size() > 1) { + return absl::InvalidArgumentError( + "Invalid type signature: list expects at most 1 parameter"); + } + return TypeSpec(ListTypeSpec(read_param_or_dyn(0))); + } + + if (name_str == "map") { + if (!params.empty() && params.size() != 2) { + return absl::InvalidArgumentError( + "Invalid type signature: map expects 0 or 2 parameters"); + } + auto key = read_param_or_dyn(0); + auto value = read_param_or_dyn(1); + return TypeSpec(MapTypeSpec(std::move(key), std::move(value))); + } + + if (name_str == "function") { + auto result_type = read_param_or_dyn(0); + std::vector arg_types; + for (size_t i = 1; i < params.size(); ++i) { + arg_types.push_back(std::move(params[i])); + } + return TypeSpec( + FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + } + + if (name_str.empty() || absl::StrContains(name_str, "..")) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid identifier"); + } + + return TypeSpec(AbstractType(name_str, std::move(params))); +} + +} // namespace + +absl::StatusOr ParseFunctionSignature( + std::string_view signature) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + if (stripped_sig.empty()) { + return absl::InvalidArgumentError("Empty function signature"); + } + + CEL_ASSIGN_OR_RETURN( + size_t paren_idx, + SignatureScanner(stripped_sig, "Invalid function signature") + .FindTopLevelChar('(', /*find_last=*/false)); + + if (paren_idx == std::string_view::npos || stripped_sig.back() != ')') { + return absl::InvalidArgumentError("Invalid function signature"); + } + + std::string_view prefix = std::string_view(stripped_sig).substr(0, paren_idx); + std::string_view args_str = + std::string_view(stripped_sig) + .substr(paren_idx + 1, stripped_sig.size() - paren_idx - 2); + + std::vector arg_types; + ParsedFunctionOverload out; + + CEL_ASSIGN_OR_RETURN(size_t dot_idx, + SignatureScanner(prefix, "Invalid function signature") + .FindTopLevelChar('.', /*find_last=*/true)); + + if (dot_idx != std::string_view::npos) { + out.is_member = true; + std::string_view receiver_str = prefix.substr(0, dot_idx); + std::string_view func_str = prefix.substr(dot_idx + 1); + + CEL_ASSIGN_OR_RETURN(auto receiver_param, ParseTypeSignature(receiver_str)); + arg_types.push_back(std::move(receiver_param)); + out.function_name = Unescape(func_str); + } else { + out.is_member = false; + out.function_name = Unescape(prefix); + } + + if (out.function_name.empty()) { + return absl::InvalidArgumentError( + "Invalid function signature: empty function name"); + } + + if (!args_str.empty()) { + CEL_ASSIGN_OR_RETURN(auto arg_list, SplitTypeList(args_str)); + for (std::string_view arg_str : arg_list) { + CEL_ASSIGN_OR_RETURN(auto arg_param, ParseTypeSignature(arg_str)); + arg_types.push_back(std::move(arg_param)); + } + } + + auto result_type = std::make_unique(DynTypeSpec()); + out.signature_type = + TypeSpec(FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + + return out; +} + +absl::StatusOr ParseTypeSpec(std::string_view signature) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + return ParseTypeSignature(stripped_sig); +} + +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool) { + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSpec(signature)); + return cel::ConvertTypeSpecToType(type_spec, arena, pool); +} + +} // namespace cel diff --git a/common/signature.h b/common/signature.h new file mode 100644 index 000000000..777f03439 --- /dev/null +++ b/common/signature.h @@ -0,0 +1,101 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_SIGNATURE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_SIGNATURE_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Generates a signature for a `cel::Type`, which is a string representation of +// the type. +// +// Examples: +// +// - `int` +// - `list` +// - `list>` +absl::StatusOr MakeTypeSignature(const Type& type); + +// Generates a signature for a `cel::TypeSpec`, which is a string +// representation of the type. +// +// Examples: +// +// - `int` +// - `list` +// - `list>` +absl::StatusOr MakeTypeSpecSignature(const TypeSpec& type_spec); + +// Generates a signature for a function overload based on the function name +// and the types of the arguments. If `is_member` is true, the first argument +// type is used as the receiver and is prepended to the function name, followed +// by a dollar sign. +// +// Examples: +// +// - `foo()` +// - `foo(int)` +// - `bar.foo(int)` +// - `foo(int,string)` +// - `foo(list,list)` +// - `bar.foo(list,list>)` +// +// If the function name contains a period, it is escaped with a backslash, e.g. +// `foo.bar` becomes `foo\.bar`. This allows to disambiguate between a member +// function and qualified target type name. +// +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member); + +// Generates a signature for a function overload based on the function name +// and the type specs of the arguments. See above for more details. +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member); + +// Parses a string type signature directly into a `cel::TypeSpec`. +absl::StatusOr ParseTypeSpec(std::string_view signature); + +// Parses a string type signature directly into a `cel::Type`. +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool); + +// A parsed function overload signature with the function name, flag for member +// function, and the function signature type. +struct ParsedFunctionOverload { + std::string function_name; + bool is_member = false; + // The function signature type, configured as a `FunctionTypeSpec`. + TypeSpec signature_type; +}; + +// Parses a string function overload signature directly into a +// `cel::TypeSpec` configured as a `FunctionTypeSpec`. +absl::StatusOr ParseFunctionSignature( + std::string_view signature); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_SIGNATURE_H_ diff --git a/common/signature_test.cc b/common/signature_test.cc new file mode 100644 index 000000000..ea51eb566 --- /dev/null +++ b/common/signature_test.cc @@ -0,0 +1,784 @@ +#include "common/signature.h" +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "common/type_spec_resolver.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::ValuesIn; + +google::protobuf::Arena* GetTestArena() { + static absl::NoDestructor arena; + return &*arena; +} + +void VerifyParsedMatchesType(const TypeSpec& parsed, const TypeSpec& expected) { + EXPECT_EQ(parsed, expected); +} +void VerifyTypesEqual(const Type& lhs, const Type& rhs) { + EXPECT_EQ(lhs.kind(), rhs.kind()); + if (lhs.kind() != rhs.kind()) return; + + if (lhs.kind() == TypeKind::kOpaque || lhs.kind() == TypeKind::kStruct || + lhs.kind() == TypeKind::kTypeParam) { + EXPECT_EQ(lhs.name(), rhs.name()); + } + + const auto& lhs_params = lhs.GetParameters(); + const auto& rhs_params = rhs.GetParameters(); + EXPECT_EQ(lhs_params.size(), rhs_params.size()); + if (lhs_params.size() == rhs_params.size()) { + for (size_t i = 0; i < lhs_params.size(); ++i) { + VerifyTypesEqual(lhs_params[i], rhs_params[i]); + } + } +} + +struct TypeSignatureTestCase { + TypeSpec type; + std::string expected_signature; + std::string expected_error; +}; + +using TypeSignatureTest = testing::TestWithParam; + +TEST_P(TypeSignatureTest, TypeSignature) { + const auto& param = GetParam(); + + absl::StatusOr signature = MakeTypeSpecSignature(param.type); + if (!param.expected_error.empty()) { + EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + } else { + EXPECT_THAT(signature, IsOkAndHolds(param.expected_signature)); + + absl::StatusOr type = ConvertTypeSpecToType( + param.type, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(type, ::absl_testing::IsOk()); + EXPECT_THAT(MakeTypeSignature(*type), + IsOkAndHolds(param.expected_signature)); + } +} + +std::vector GetTypeSignatureTestCases() { + return { + { + .type = TypeSpec(NullTypeSpec{}), + .expected_signature = "null", + }, + { + .type = TypeSpec(PrimitiveType::kBool), + .expected_signature = "bool", + }, + { + .type = TypeSpec(PrimitiveType::kInt64), + .expected_signature = "int", + }, + { + .type = TypeSpec(PrimitiveType::kUint64), + .expected_signature = "uint", + }, + { + .type = TypeSpec(PrimitiveType::kDouble), + .expected_signature = "double", + }, + { + .type = TypeSpec(PrimitiveType::kString), + .expected_signature = "string", + }, + { + .type = TypeSpec(PrimitiveType::kBytes), + .expected_signature = "bytes", + }, + { + .type = TypeSpec( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + .expected_signature = "cel.expr.conformance.proto3.TestAllTypes", + }, + { + .type = TypeSpec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", {})), + .expected_signature = "cel.expr.conformance.proto3.TestAllTypes", + }, + { + .type = TypeSpec(WellKnownTypeSpec::kDuration), + .expected_signature = "duration", + }, + { + .type = TypeSpec(WellKnownTypeSpec::kTimestamp), + .expected_signature = "timestamp", + }, + { + .type = TypeSpec( + ListTypeSpec(std::make_unique(PrimitiveType::kString))), + .expected_signature = "list", + }, + { + .type = TypeSpec( + ListTypeSpec(std::make_unique(ParamTypeSpec("A")))), + .expected_signature = "list<~A>", + }, + { + .type = TypeSpec( + ListTypeSpec(std::make_unique(ParamTypeSpec("A(ParamTypeSpec(R"(a,b..(d)\e)")))), + .expected_signature = R"(list<~a\,b\.\\.\(d\)\\e>)", + }, + { + .type = TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kInt64), + std::make_unique(DynTypeSpec()))), + .expected_signature = "map", + }, + { + .type = TypeSpec( + MapTypeSpec(std::make_unique(ParamTypeSpec("B")), + std::make_unique(ParamTypeSpec("C")))), + .expected_signature = "map<~B,~C>", + }, + { + .type = TypeSpec(MapTypeSpec( + std::make_unique(PrimitiveType::kInt64), nullptr)), + .expected_signature = "map", + }, + { + .type = TypeSpec(MapTypeSpec(nullptr, nullptr)), + .expected_signature = "map", + }, + { + .type = TypeSpec(std::make_unique(PrimitiveType::kInt64)), + .expected_signature = "type", + }, + { + .type = TypeSpec(WellKnownTypeSpec::kAny), + .expected_signature = "any", + }, + { + .type = TypeSpec(DynTypeSpec{}), + .expected_signature = "dyn", + }, + { + .type = TypeSpec(AbstractType( + "bar", {TypeSpec(FunctionTypeSpec( + std::make_unique(ParamTypeSpec("D")), + {TypeSpec(PrimitiveType::kString), + TypeSpec(PrimitiveType::kBool)}))})), + .expected_signature = "bar>", + }, + { + .type = + TypeSpec(AbstractType("bar", {TypeSpec(PrimitiveType::kInt64), + TypeSpec(PrimitiveType::kString)})), + .expected_signature = "bar", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)), + .expected_signature = "bool_wrapper", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + .expected_signature = "int_wrapper", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)), + .expected_signature = "uint_wrapper", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + .expected_signature = "double_wrapper", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)), + .expected_signature = "string_wrapper", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)), + .expected_signature = "bytes_wrapper", + }, + { + .type = TypeSpec( + FunctionTypeSpec(nullptr, {TypeSpec(PrimitiveType::kInt64)})), + .expected_signature = "function", + }, + { + .type = TypeSpec(FunctionTypeSpec( + std::make_unique(PrimitiveType::kInt64), {})), + .expected_signature = "function", + }, + { + .type = TypeSpec(FunctionTypeSpec(nullptr, {})), + .expected_signature = "function", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(TypeSignatureTest, TypeSignatureTest, + ValuesIn(GetTypeSignatureTestCases())); + +TEST(TypeSignatureTest, UnsupportedTypes) { + EXPECT_THAT(MakeTypeSignature(UnknownType{}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported Type kind: *unknown*"))); + + EXPECT_THAT(MakeTypeSignature(ErrorType{}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported type in signature: *error*"))); + + EXPECT_THAT(MakeTypeSpecSignature(TypeSpec(static_cast(999))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported primitive type"))); + + EXPECT_THAT( + MakeTypeSpecSignature(TypeSpec(static_cast(999))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported well-known type"))); + + EXPECT_THAT(MakeTypeSpecSignature(TypeSpec( + PrimitiveTypeWrapper(static_cast(999)))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported wrapper type"))); +} + +TEST_P(TypeSignatureTest, ParseTypeCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty() && param.expected_error.empty()) { + auto parsed = ParseType(param.expected_signature, GetTestArena(), + *GetTestingDescriptorPool()); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(auto expected_type, + ConvertTypeSpecToType(param.type, GetTestArena(), + *GetTestingDescriptorPool())); + VerifyTypesEqual(*parsed, expected_type); + } +} + +struct OverloadSignatureTestCase { + std::string function_name = "hello"; + std::vector args; + bool is_member = false; + std::string expected_signature; + std::string expected_error; +}; + +using OverloadSignatureTest = testing::TestWithParam; + +TEST_P(OverloadSignatureTest, OverloadSignature) { + const auto& param = GetParam(); + + absl::StatusOr signature = + MakeOverloadSignature(param.function_name, param.args, param.is_member); + if (!param.expected_error.empty()) { + EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + } else { + EXPECT_THAT(signature, IsOkAndHolds(param.expected_signature)); + } +} + +std::vector GetOverloadSignatureTestCases() { + return { + { + .args = {TypeSpec(PrimitiveType::kString)}, + .expected_signature = "hello(string)", + }, + { + .args = {TypeSpec(PrimitiveType::kInt64), + TypeSpec(PrimitiveType::kUint64)}, + .expected_signature = "hello(int,uint)", + }, + { + .args = {TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kString)))}, + .expected_signature = "hello(list)", + }, + { + .args = {TypeSpec( + ListTypeSpec(std::make_unique(ParamTypeSpec("A"))))}, + .expected_signature = "hello(list<~A>)", + }, + { + .args = {TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kInt64), + std::make_unique(DynTypeSpec())))}, + .expected_signature = "hello(map)", + }, + { + .args = {TypeSpec( + MapTypeSpec(std::make_unique(ParamTypeSpec("B")), + std::make_unique(ParamTypeSpec("C"))))}, + .expected_signature = "hello(map<~B,~C>)", + }, + + { + .args = {TypeSpec(AbstractType( + "bar", + {TypeSpec(FunctionTypeSpec( + std::make_unique(ParamTypeSpec("D")), {}))}))}, + .expected_signature = "hello(bar>)", + }, + { + .args = {TypeSpec(WellKnownTypeSpec::kAny)}, + .expected_signature = "hello(any)", + }, + { + .args = {TypeSpec(WellKnownTypeSpec::kDuration)}, + .expected_signature = "hello(duration)", + }, + { + .args = {TypeSpec(WellKnownTypeSpec::kTimestamp)}, + .expected_signature = "hello(timestamp)", + }, + { + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool))}, + .expected_signature = "hello(bool_wrapper)", + }, + { + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64))}, + .expected_signature = "hello(int_wrapper)", + }, + { + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64))}, + .expected_signature = "hello(uint_wrapper)", + }, + { + .args = {TypeSpec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", {}))}, + .expected_signature = + "hello(cel.expr.conformance.proto3.TestAllTypes)", + }, + { + .args = {TypeSpec(PrimitiveType::kString)}, + .is_member = true, + .expected_signature = "string.hello()", + }, + { + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kBool)))}, + .is_member = true, + .expected_signature = "string.hello(list)", + }, + { + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(PrimitiveType::kBool), TypeSpec(DynTypeSpec())}, + .is_member = true, + .expected_signature = "string.hello(bool,dyn)", + }, + { + .function_name = "hello", + .args = {TypeSpec( + AbstractType("bar", {TypeSpec(ParamTypeSpec("dummy.type"))}))}, + .is_member = true, + .expected_signature = R"(bar<~dummy\.type>.hello())", + }, + { + .function_name = "inspect", + .args = {TypeSpec( + std::make_unique(PrimitiveType::kString))}, + .expected_signature = "inspect(type)", + }, + { + .function_name = R"(h.(e),l\o)", + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(ListTypeSpec(std::make_unique( + ParamTypeSpec(R"(a,b..(d)\e)"))))}, + .is_member = true, + .expected_signature = + R"(string.h\.\(e\)\,l\\\o(list<~a\,b\.\\.\(d\)\\e>))", + }, + }; +} + +TEST(OverloadSignatureTest, MemberFunctionNoReceiverError) { + auto signature = + MakeOverloadSignature("hello", std::vector{}, true); + EXPECT_THAT(signature, + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Member function with no receiver"))); +} + +INSTANTIATE_TEST_SUITE_P(OverloadIdTest, OverloadSignatureTest, + ValuesIn(GetOverloadSignatureTestCases())); + +TEST_P(OverloadSignatureTest, ExhaustiveFunctionParseCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty()) { + auto parsed = ParseFunctionSignature(param.expected_signature); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + EXPECT_EQ(parsed->function_name, param.function_name); + EXPECT_EQ(parsed->is_member, param.is_member); + EXPECT_TRUE(parsed->signature_type.has_function()); + const auto& func = parsed->signature_type.function(); + for (size_t i = 0; i < param.args.size(); ++i) { + VerifyParsedMatchesType(func.arg_types()[i], param.args[i]); + } + } +} + +TEST(ParseSignatureTest, ProtoParsing) { + ASSERT_OK_AND_ASSIGN( + auto t1, ParseType("int", GetTestArena(), *GetTestingDescriptorPool())); + EXPECT_TRUE(t1.IsInt()); + + ASSERT_OK_AND_ASSIGN(auto t2, ParseType("list<~A>", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(t2.IsList()); + + ASSERT_OK_AND_ASSIGN(auto t3, ParseType(R"(~abc\)", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(t3.IsTypeParam()); + EXPECT_EQ(t3.GetTypeParam().name(), R"(abc\)"); + + ASSERT_OK_AND_ASSIGN(auto w1, + ParseType("google.protobuf.BoolValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w1.IsBoolWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w2, + ParseType("google.protobuf.Int64Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w2.IsIntWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w3, + ParseType("google.protobuf.Int32Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w3.IsIntWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w4, + ParseType("google.protobuf.UInt64Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w4.IsUintWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w5, + ParseType("google.protobuf.UInt32Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w5.IsUintWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w6, + ParseType("google.protobuf.DoubleValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w6.IsDoubleWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w7, + ParseType("google.protobuf.FloatValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w7.IsDoubleWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w8, + ParseType("google.protobuf.StringValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w8.IsStringWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w9, + ParseType("google.protobuf.BytesValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w9.IsBytesWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w10, ParseType("string_wrapper", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w10.IsStringWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w11, ParseType("bytes_wrapper", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w11.IsBytesWrapper()); + + ASSERT_OK_AND_ASSIGN(auto gp_any, + ParseType("google.protobuf.Any", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_any.IsAny()); + + ASSERT_OK_AND_ASSIGN(auto gp_timestamp, + ParseType("google.protobuf.Timestamp", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_timestamp.IsTimestamp()); + + ASSERT_OK_AND_ASSIGN(auto gp_duration, + ParseType("google.protobuf.Duration", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_duration.IsDuration()); + + ASSERT_OK_AND_ASSIGN(auto gp_value, + ParseType("google.protobuf.Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_value.IsDyn()); + + ASSERT_OK_AND_ASSIGN(auto gp_list_value, + ParseType("google.protobuf.ListValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_list_value.IsList()); + + ASSERT_OK_AND_ASSIGN(auto gp_struct, + ParseType("google.protobuf.Struct", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_struct.IsMap()); + + // Legal whitespace handling tests + ASSERT_OK_AND_ASSIGN(auto ws_type1, + ParseType("map < int , string > ", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(ws_type1.IsMap()); + + ASSERT_OK_AND_ASSIGN(auto ws_type2, + ParseType("map\t<\nint\r,\tstring\n>\r", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(ws_type2.IsMap()); +} + +TEST(ParseSignatureTest, FunctionParsing) { + ASSERT_OK_AND_ASSIGN(auto f1, ParseFunctionSignature("hello(string)")); + EXPECT_TRUE(f1.signature_type.has_function()); + EXPECT_EQ(f1.signature_type.function().arg_types().size(), 1); + + // Legal whitespace handling tests + ASSERT_OK_AND_ASSIGN(auto ws_func1, + ParseFunctionSignature(" hello ( string ) ")); + EXPECT_TRUE(ws_func1.signature_type.has_function()); + EXPECT_EQ(ws_func1.signature_type.function().arg_types().size(), 1); + + ASSERT_OK_AND_ASSIGN(auto ws_func2, + ParseFunctionSignature("\thello\n(\rstring\t)\n\r")); + EXPECT_TRUE(ws_func2.signature_type.has_function()); + EXPECT_EQ(ws_func2.signature_type.function().arg_types().size(), 1); + + ASSERT_OK_AND_ASSIGN(auto f2, ParseFunctionSignature("a.b.c()")); + EXPECT_TRUE(f2.is_member); + EXPECT_EQ(f2.function_name, "c"); +} + +TEST(ParseSignatureTest, ParsingErrors) { + // Mismatched template brackets and parentheses. + EXPECT_THAT( + ParseType("list>", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT( + ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseType("list><", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("hello(list>)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("hello(list)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("foo"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); + EXPECT_THAT( + ParseType("list b < c>", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + // Parameter count validations for list, map and type types. + EXPECT_THAT(ParseType("list", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("list expects at most 1 parameter"))); + EXPECT_THAT( + ParseType("map", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + EXPECT_THAT(ParseType("map", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + EXPECT_THAT(ParseType("type", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type expects at most 1 parameter"))); + + // Invalid parameter name validations. + EXPECT_THAT(ParseType("~", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid type parameter name"))); + EXPECT_THAT(ParseType("~A", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid type parameter name"))); + + // Enforcing valid function and identifier names. + EXPECT_THAT(ParseFunctionSignature("()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + EXPECT_THAT(ParseFunctionSignature("string.()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + + // Missing closing operators and boundary checks. + EXPECT_THAT( + ParseType("listfoo", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("missing closing >"))); + + EXPECT_THAT(ParseFunctionSignature("hello>(string)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT( + ParseType("list<", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("map", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("map int, string>", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("list", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + + EXPECT_THAT(ParseFunctionSignature("a..b.c()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + EXPECT_THAT( + ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); + + EXPECT_THAT( + ParseType("~list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + + // Checks that builtin types cannot have type parameters. + EXPECT_THAT( + ParseType("int", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(ParseSignatureTest, MessageTypeWithParamsError) { + EXPECT_THAT(ParseType("cel.expr.conformance.proto3.TestAllTypes", + GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(ParseSignatureTest, MissingClosingParenthesisError) { + EXPECT_THAT(ParseFunctionSignature("hello(string"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); + EXPECT_THAT(ParseFunctionSignature("hello)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); +} + +TEST(ParseSignatureTest, NestedDotsNonMember) { + auto f1 = ParseFunctionSignature( + "my_opaque()"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_FALSE(f1->is_member); + EXPECT_EQ(f1->function_name, + "my_opaque"); +} + +TEST(ParseSignatureTest, OverlyComplexSignatures) { + auto t1 = ParseType("map>,map>>", + GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t1, ::absl_testing::IsOk()); + EXPECT_TRUE(t1->IsMap()); + + auto t2 = ParseType(R"(~abc\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t2, ::absl_testing::IsOk()); + EXPECT_TRUE(t2->IsTypeParam()); + EXPECT_EQ(t2->GetTypeParam().name(), R"(abc\)"); + + auto t3 = + ParseType(R"(~abc\\\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t3, ::absl_testing::IsOk()); + EXPECT_TRUE(t3->IsTypeParam()); + EXPECT_EQ(t3->GetTypeParam().name(), R"(abc\\)"); + + auto f1 = ParseFunctionSignature( + "bar>,map>.func(string)"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_TRUE(f1->is_member); + EXPECT_EQ(f1->function_name, "func"); + EXPECT_TRUE(f1->signature_type.has_function()); + EXPECT_EQ(f1->signature_type.function().arg_types().size(), 2); +} + +TEST(ParseSignatureTest, EmptyOrWhitespaceErrors) { + EXPECT_THAT(ParseType("", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); + EXPECT_THAT(ParseFunctionSignature(""), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty function signature"))); + EXPECT_THAT(ParseType("list>", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); +} + +TEST(OverloadSignatureTest, ArgumentTypeVector) { + std::vector args; + args.push_back(Type(IntType())); + args.push_back(Type(StringType())); + args.push_back(Type(ListType(GetTestArena(), IntType()))); + args.push_back( + Type(MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))); + args.push_back(Type(OpaqueType(GetTestArena(), "Foo", {TypeParamType("T")}))); + ASSERT_OK_AND_ASSIGN(auto sig, MakeOverloadSignature("foo", args, false)); + EXPECT_EQ(sig, + "foo(int,string,list,cel.expr.conformance.proto3.TestAllTypes," + "Foo<~T>)"); +} + +} // namespace +} // namespace cel diff --git a/common/source.cc b/common/source.cc index 8c32ad6ba..5fa4cca0e 100644 --- a/common/source.cc +++ b/common/source.cc @@ -483,26 +483,26 @@ absl::optional Source::GetLocation( return SourceLocation{line_and_offset->first, position - line_and_offset->second}; } - return absl::nullopt; + return std::nullopt; } absl::optional Source::GetPosition( const SourceLocation& location) const { if (ABSL_PREDICT_FALSE(location.line < 1 || location.column < 0)) { - return absl::nullopt; + return std::nullopt; } if (auto position = FindLinePosition(location.line); ABSL_PREDICT_TRUE(position.has_value())) { return *position + location.column; } - return absl::nullopt; + return std::nullopt; } absl::optional Source::Snippet(int32_t line) const { auto content = this->content(); auto start = FindLinePosition(line); if (ABSL_PREDICT_FALSE(!start.has_value() || content.empty())) { - return absl::nullopt; + return std::nullopt; } auto end = FindLinePosition(line + 1); if (end.has_value()) { @@ -554,7 +554,7 @@ std::string Source::DisplayErrorLocation(SourceLocation location) const { absl::optional Source::FindLinePosition(int32_t line) const { if (ABSL_PREDICT_FALSE(line < 1)) { - return absl::nullopt; + return std::nullopt; } if (line == 1) { return SourcePosition{0}; @@ -563,13 +563,13 @@ absl::optional Source::FindLinePosition(int32_t line) const { if (ABSL_PREDICT_TRUE(line <= static_cast(line_offsets.size()))) { return line_offsets[static_cast(line - 2)]; } - return absl::nullopt; + return std::nullopt; } absl::optional> Source::FindLine( SourcePosition position) const { if (ABSL_PREDICT_FALSE(position < 0)) { - return absl::nullopt; + return std::nullopt; } int32_t line = 1; const auto line_offsets = this->line_offsets(); diff --git a/common/source_test.cc b/common/source_test.cc index 2a3b78893..30a2ce9b0 100644 --- a/common/source_test.cc +++ b/common/source_test.cc @@ -81,37 +81,37 @@ TEST(StringSource, PositionAndLocation) { Optional(Eq(SourceLocation{int32_t{1}, int32_t{2}}))); EXPECT_THAT(source->GetLocation(*end), Optional(Eq(SourceLocation{int32_t{3}, int32_t{2}}))); - EXPECT_THAT(source->GetLocation(-1), Eq(absl::nullopt)); + EXPECT_THAT(source->GetLocation(-1), Eq(std::nullopt)); EXPECT_THAT(source->content().ToString(*start, *end), Eq("d &&\n\t b.c.arg(10) &&\n\t ")); EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{0}, int32_t{0}}), - Eq(absl::nullopt)); + Eq(std::nullopt)); EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{1}, int32_t{-1}}), - Eq(absl::nullopt)); + Eq(std::nullopt)); EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{4}, int32_t{0}}), - Eq(absl::nullopt)); + Eq(std::nullopt)); } TEST(StringSource, SnippetSingle) { ASSERT_OK_AND_ASSIGN(auto source, NewSource("hello, world", "one-line-test")); EXPECT_THAT(source->Snippet(1), Optional(Eq("hello, world"))); - EXPECT_THAT(source->Snippet(2), Eq(absl::nullopt)); + EXPECT_THAT(source->Snippet(2), Eq(std::nullopt)); } TEST(StringSource, SnippetMulti) { ASSERT_OK_AND_ASSIGN(auto source, NewSource("hello\nworld\nmy\nbub\n", "four-line-test")); - EXPECT_THAT(source->Snippet(0), Eq(absl::nullopt)); + EXPECT_THAT(source->Snippet(0), Eq(std::nullopt)); EXPECT_THAT(source->Snippet(1), Optional(Eq("hello"))); EXPECT_THAT(source->Snippet(2), Optional(Eq("world"))); EXPECT_THAT(source->Snippet(3), Optional(Eq("my"))); EXPECT_THAT(source->Snippet(4), Optional(Eq("bub"))); EXPECT_THAT(source->Snippet(5), Optional(Eq(""))); - EXPECT_THAT(source->Snippet(6), Eq(absl::nullopt)); + EXPECT_THAT(source->Snippet(6), Eq(std::nullopt)); } TEST(CordSource, Description) { @@ -150,17 +150,17 @@ TEST(CordSource, PositionAndLocation) { Optional(Eq(SourceLocation{int32_t{1}, int32_t{2}}))); EXPECT_THAT(source->GetLocation(*end), Optional(Eq(SourceLocation{int32_t{3}, int32_t{2}}))); - EXPECT_THAT(source->GetLocation(-1), Eq(absl::nullopt)); + EXPECT_THAT(source->GetLocation(-1), Eq(std::nullopt)); EXPECT_THAT(source->content().ToString(*start, *end), Eq("d &&\n\t b.c.arg(10) &&\n\t ")); EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{0}, int32_t{0}}), - Eq(absl::nullopt)); + Eq(std::nullopt)); EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{1}, int32_t{-1}}), - Eq(absl::nullopt)); + Eq(std::nullopt)); EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{4}, int32_t{0}}), - Eq(absl::nullopt)); + Eq(std::nullopt)); } TEST(CordSource, SnippetSingle) { @@ -168,7 +168,7 @@ TEST(CordSource, SnippetSingle) { NewSource(absl::Cord("hello, world"), "one-line-test")); EXPECT_THAT(source->Snippet(1), Optional(Eq("hello, world"))); - EXPECT_THAT(source->Snippet(2), Eq(absl::nullopt)); + EXPECT_THAT(source->Snippet(2), Eq(std::nullopt)); } TEST(CordSource, SnippetMulti) { @@ -176,13 +176,13 @@ TEST(CordSource, SnippetMulti) { auto source, NewSource(absl::Cord("hello\nworld\nmy\nbub\n"), "four-line-test")); - EXPECT_THAT(source->Snippet(0), Eq(absl::nullopt)); + EXPECT_THAT(source->Snippet(0), Eq(std::nullopt)); EXPECT_THAT(source->Snippet(1), Optional(Eq("hello"))); EXPECT_THAT(source->Snippet(2), Optional(Eq("world"))); EXPECT_THAT(source->Snippet(3), Optional(Eq("my"))); EXPECT_THAT(source->Snippet(4), Optional(Eq("bub"))); EXPECT_THAT(source->Snippet(5), Optional(Eq(""))); - EXPECT_THAT(source->Snippet(6), Eq(absl::nullopt)); + EXPECT_THAT(source->Snippet(6), Eq(std::nullopt)); } TEST(Source, DisplayErrorLocationBasic) { diff --git a/common/type.cc b/common/type.cc index 2b81e39f8..9ea85954c 100644 --- a/common/type.cc +++ b/common/type.cc @@ -75,7 +75,9 @@ Type Type::Message(const Descriptor* absl_nonnull descriptor) { Type Type::Enum(const google::protobuf::EnumDescriptor* absl_nonnull descriptor) { if (descriptor->full_name() == "google.protobuf.NullValue") { - return NullType(); + // Special case NullValue to prevent the emebedder providing a different + // descriptor for it and it leaking. + return IntType(); } return EnumType(descriptor); } @@ -95,7 +97,7 @@ static constexpr std::array kTypeToKindArray = { TypeKind::kUnknown}; static_assert(kTypeToKindArray.size() == - absl::variant_size(), + std::variant_size(), "Kind indexer must match variant declaration for cel::Type."); } // namespace @@ -156,7 +158,7 @@ absl::optional GetOrNullopt(const common_internal::TypeVariant& variant) { if (const auto* alt = absl::get_if(&variant); alt != nullptr) { return *alt; } - return absl::nullopt; + return std::nullopt; } } // namespace @@ -241,7 +243,7 @@ absl::optional Type::AsOptional() const { if (auto maybe_opaque = AsOpaque(); maybe_opaque.has_value()) { return maybe_opaque->AsOptional(); } - return absl::nullopt; + return std::nullopt; } absl::optional Type::AsString() const { @@ -261,7 +263,7 @@ absl::optional Type::AsStruct() const { if (const auto* alt = absl::get_if(&variant_); alt != nullptr) { return *alt; } - return absl::nullopt; + return std::nullopt; } absl::optional Type::AsTimestamp() const { @@ -601,7 +603,7 @@ absl::optional StructTypeField::AsMessage() const { alternative != nullptr) { return *alternative; } - return absl::nullopt; + return std::nullopt; } StructTypeField::operator MessageTypeField() const { @@ -638,8 +640,6 @@ constexpr absl::string_view kUInt64TypeName = "uint"; constexpr absl::string_view kDoubleTypeName = "double"; constexpr absl::string_view kStringTypeName = "string"; constexpr absl::string_view kBytesTypeName = "bytes"; -constexpr absl::string_view kDurationTypeName = "google.protobuf.Duration"; -constexpr absl::string_view kTimestampTypeName = "google.protobuf.Timestamp"; constexpr absl::string_view kListTypeName = "list"; constexpr absl::string_view kMapTypeName = "map"; constexpr absl::string_view kCelTypeTypeName = "type"; @@ -668,12 +668,6 @@ Type LegacyRuntimeType(absl::string_view name) { if (name == kBytesTypeName) { return BytesType{}; } - if (name == kDurationTypeName) { - return DurationType{}; - } - if (name == kTimestampTypeName) { - return TimestampType{}; - } if (name == kListTypeName) { return ListType{}; } @@ -683,6 +677,53 @@ Type LegacyRuntimeType(absl::string_view name) { if (name == kCelTypeTypeName) { return TypeType{}; } + if (cel::IsWellKnownMessageType(name)) { + if (name == "google.protobuf.Any") { + return AnyType(); + } + if (name == "google.protobuf.BoolValue") { + return BoolWrapperType(); + } + if (name == "google.protobuf.BytesValue") { + return BytesWrapperType(); + } + if (name == "google.protobuf.DoubleValue") { + return DoubleWrapperType(); + } + if (name == "google.protobuf.Duration") { + return DurationType(); + } + if (name == "google.protobuf.FloatValue") { + return DoubleWrapperType(); + } + if (name == "google.protobuf.Int32Value") { + return IntWrapperType(); + } + if (name == "google.protobuf.Int64Value") { + return IntWrapperType(); + } + if (name == "google.protobuf.ListValue") { + return ListType(); + } + if (name == "google.protobuf.StringValue") { + return StringWrapperType(); + } + if (name == "google.protobuf.Struct") { + return JsonMapType(); + } + if (name == "google.protobuf.Timestamp") { + return TimestampType(); + } + if (name == "google.protobuf.UInt32Value") { + return UintWrapperType(); + } + if (name == "google.protobuf.UInt64Value") { + return UintWrapperType(); + } + if (name == "google.protobuf.Value") { + return DynType(); + } + } return common_internal::MakeBasicStructType(name); } diff --git a/common/type_introspector.cc b/common/type_introspector.cc index c69235b3b..3846ab58b 100644 --- a/common/type_introspector.cc +++ b/common/type_introspector.cc @@ -17,7 +17,9 @@ #include #include #include +#include +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" @@ -102,7 +104,7 @@ struct WellKnownType { auto it = std::lower_bound(fields_by_name.begin(), fields_by_name.end(), name, FieldNameComparer{}); if (it == fields_by_name.end() || it->name() != name) { - return absl::nullopt; + return std::nullopt; } return *it; } @@ -112,7 +114,7 @@ struct WellKnownType { auto it = std::lower_bound(fields_by_number.begin(), fields_by_number.end(), number, FieldNumberComparer{}); if (it == fields_by_number.end() || it->number() != number) { - return absl::nullopt; + return std::nullopt; } return *it; } @@ -173,7 +175,8 @@ const WellKnownTypesMap& GetWellKnownTypesMap() { "google.protobuf.Value", WellKnownType{ DynType{}, - {MakeBasicStructTypeField("null_value", NullType{}, 1), + {// NullValue enum is an int. Not normally referenced directly. + MakeBasicStructTypeField("null_value", IntType{}, 1), MakeBasicStructTypeField("number_value", DoubleType{}, 2), MakeBasicStructTypeField("string_value", StringType{}, 3), MakeBasicStructTypeField("bool_value", BoolType{}, 4), @@ -211,50 +214,64 @@ const WellKnownTypesMap& GetWellKnownTypesMap() { } // namespace -absl::StatusOr> TypeIntrospector::FindType( - absl::string_view name) const { +absl::StatusOr> TypeIntrospector::FindTypeImpl( + absl::string_view) const { + return std::nullopt; +} + +absl::StatusOr> +TypeIntrospector::FindEnumConstantImpl(absl::string_view, + absl::string_view) const { + return std::nullopt; +} + +absl::StatusOr> +TypeIntrospector::FindStructTypeFieldByNameImpl(absl::string_view, + absl::string_view) const { + return std::nullopt; +} + +absl::StatusOr< + absl::optional>> +TypeIntrospector::ListFieldsForStructTypeImpl(absl::string_view) const { + return std::nullopt; +} + +absl::optional FindWellKnownType(absl::string_view name) { const auto& well_known_types = GetWellKnownTypesMap(); if (auto it = well_known_types.find(name); it != well_known_types.end()) { return it->second.type; } - return FindTypeImpl(name); + return std::nullopt; } -absl::StatusOr> -TypeIntrospector::FindEnumConstant(absl::string_view type, - absl::string_view value) const { +absl::optional FindWellKnownTypeEnumConstant( + absl::string_view type, absl::string_view value) { if (type == "google.protobuf.NullValue" && value == "NULL_VALUE") { - return EnumConstant{NullType{}, "google.protobuf.NullValue", "NULL_VALUE", - 0}; + return TypeIntrospector::EnumConstant{ + IntType{}, "google.protobuf.NullValue", "NULL_VALUE", 0}; } - return FindEnumConstantImpl(type, value); + return std::nullopt; } -absl::StatusOr> -TypeIntrospector::FindStructTypeFieldByName(absl::string_view type, - absl::string_view name) const { +absl::optional FindWellKnownTypeFieldByName( + absl::string_view type, absl::string_view name) { const auto& well_known_types = GetWellKnownTypesMap(); if (auto it = well_known_types.find(type); it != well_known_types.end()) { return it->second.FieldByName(name); } - return FindStructTypeFieldByNameImpl(type, name); -} - -absl::StatusOr> TypeIntrospector::FindTypeImpl( - absl::string_view) const { - return absl::nullopt; -} - -absl::StatusOr> -TypeIntrospector::FindEnumConstantImpl(absl::string_view, - absl::string_view) const { - return absl::nullopt; + return std::nullopt; } -absl::StatusOr> -TypeIntrospector::FindStructTypeFieldByNameImpl(absl::string_view, - absl::string_view) const { - return absl::nullopt; +absl::optional> +ListFieldsForWellKnownType(absl::string_view type) { + const auto& well_known_types = GetWellKnownTypesMap(); + auto it = well_known_types.find(type); + if (it == well_known_types.end()) { + return std::nullopt; + } + // The fields are not normally gettable. + return {}; } } // namespace cel diff --git a/common/type_introspector.h b/common/type_introspector.h index 7f4a19a31..932fb108e 100644 --- a/common/type_introspector.h +++ b/common/type_introspector.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_COMMON_TYPE_INTROSPECTOR_H_ #include +#include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -24,8 +25,6 @@ namespace cel { -class TypeFactory; - // `TypeIntrospector` is an interface which allows querying type-related // information. It handles type introspection, but not type reflection. That is, // it is not capable of instantiating new values or understanding values. Its @@ -42,20 +41,47 @@ class TypeIntrospector { int32_t number; }; + struct StructTypeFieldListing { + // The name used to access the field in source CEL. + // This is assumed owned by the TypeIntrospector or a dependency that + // outlives it. + absl::string_view name; + // The field description. + StructTypeField field; + }; + virtual ~TypeIntrospector() = default; // `FindType` find the type corresponding to name `name`. - absl::StatusOr> FindType(absl::string_view name) const; + absl::StatusOr> FindType(absl::string_view name) const { + return FindTypeImpl(name); + } // `FindEnumConstant` find a fully qualified enumerator name `name` in enum // type `type`. absl::StatusOr> FindEnumConstant( - absl::string_view type, absl::string_view value) const; + absl::string_view type, absl::string_view value) const { + return FindEnumConstantImpl(type, value); + } // `FindStructTypeFieldByName` find the name, number, and type of the field // `name` in type `type`. absl::StatusOr> FindStructTypeFieldByName( - absl::string_view type, absl::string_view name) const; + absl::string_view type, absl::string_view name) const { + return FindStructTypeFieldByNameImpl(type, name); + } + + // `ListFieldsForStructType` returns the fields of struct type `type`. + // + // This is used when the struct is declared as a context type. + // + // If the type is not found, returns `absl::nullopt`. + // If the type exists but is not a struct or has no fields, returns an empty + // vector. + absl::StatusOr>> + ListFieldsForStructType(absl::string_view type) const { + return ListFieldsForStructTypeImpl(type); + } // `FindStructTypeFieldByName` find the name, number, and type of the field // `name` in struct type `type`. @@ -74,6 +100,56 @@ class TypeIntrospector { virtual absl::StatusOr> FindStructTypeFieldByNameImpl(absl::string_view type, absl::string_view name) const; + + virtual absl::StatusOr>> + ListFieldsForStructTypeImpl(absl::string_view type) const; +}; + +// Looks up a well-known type by name. +absl::optional FindWellKnownType(absl::string_view name); + +// Looks up a well-known enum constant by type and value. +absl::optional FindWellKnownTypeEnumConstant( + absl::string_view type, absl::string_view value); + +// Looks up a well-known struct type field by type and field name. +absl::optional FindWellKnownTypeFieldByName( + absl::string_view type, absl::string_view name); + +absl::optional> +ListFieldsForWellKnownType(absl::string_view type); + +// `WellKnownTypeIntrospector` is an implementation of `TypeIntrospector` which +// handles well known types that are treated specially by CEL. +// +// This also serves as a minimal implementation of a TypeInstrospector when no +// custom types are present. +// +// This class has no mutable state, so trivially thread-safe. +class WellKnownTypeIntrospector : public virtual TypeIntrospector { + public: + WellKnownTypeIntrospector() = default; + + private: + absl::StatusOr> FindTypeImpl( + absl::string_view name) const final { + return FindWellKnownType(name); + } + + absl::StatusOr> FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const final { + return FindWellKnownTypeEnumConstant(type, value); + } + + absl::StatusOr> FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const final { + return FindWellKnownTypeFieldByName(type, name); + } + + absl::StatusOr>> + ListFieldsForStructTypeImpl(absl::string_view type) const final { + return ListFieldsForWellKnownType(type); + } }; } // namespace cel diff --git a/common/type_manager.h b/common/type_manager.h deleted file mode 100644 index 354f4c9b8..000000000 --- a/common/type_manager.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_MANAGER_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_MANAGER_H_ - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "common/memory.h" -#include "common/type.h" -#include "common/type_factory.h" -#include "common/type_introspector.h" - -namespace cel { - -// `TypeManager` is an additional layer on top of `TypeFactory` and -// `TypeIntrospector` which combines the two and adds additional functionality. -class TypeManager : public virtual TypeFactory { - public: - virtual ~TypeManager() = default; - - // See `TypeIntrospector::FindType`. - absl::StatusOr> FindType(absl::string_view name) { - return GetTypeIntrospector().FindType(name); - } - - // See `TypeIntrospector::FindStructTypeFieldByName`. - absl::StatusOr> FindStructTypeFieldByName( - absl::string_view type, absl::string_view name) { - return GetTypeIntrospector().FindStructTypeFieldByName(type, name); - } - - // See `TypeIntrospector::FindStructTypeFieldByName`. - absl::StatusOr> FindStructTypeFieldByName( - const StructType& type, absl::string_view name) { - return GetTypeIntrospector().FindStructTypeFieldByName(type, name); - } - - protected: - virtual const TypeIntrospector& GetTypeIntrospector() const = 0; -}; - -} // namespace cel - -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_MANAGER_H_ diff --git a/common/type_proto.cc b/common/type_proto.cc index 66c16689d..b6b66f73a 100644 --- a/common/type_proto.cc +++ b/common/type_proto.cc @@ -71,7 +71,7 @@ absl::optional MaybeWellKnownType(absl::string_view type_name) { return it->second; } - return absl::nullopt; + return std::nullopt; } absl::Status TypeToProtoInternal(const cel::Type& type, diff --git a/common/type_reflector_test.cc b/common/type_reflector_test.cc index f2ff2c322..d9c855e4b 100644 --- a/common/type_reflector_test.cc +++ b/common/type_reflector_test.cc @@ -210,7 +210,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_BoolValue) { internal::GetTestingMessageFactory(), "google.protobuf.BoolValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -218,7 +218,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_BoolValue) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -236,7 +236,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Int32Value) { internal::GetTestingMessageFactory(), "google.protobuf.Int32Value"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -248,7 +248,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Int32Value) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -270,7 +270,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Int64Value) { internal::GetTestingMessageFactory(), "google.protobuf.Int64Value"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -278,7 +278,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Int64Value) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -296,7 +296,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_UInt32Value) { internal::GetTestingMessageFactory(), "google.protobuf.UInt32Value"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", UintValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", UintValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -308,7 +308,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_UInt32Value) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); EXPECT_THAT(builder->SetFieldByNumber(1, UintValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, UintValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -330,7 +330,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_UInt64Value) { internal::GetTestingMessageFactory(), "google.protobuf.UInt64Value"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", UintValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", UintValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -338,7 +338,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_UInt64Value) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, UintValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, UintValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -356,7 +356,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_FloatValue) { internal::GetTestingMessageFactory(), "google.protobuf.FloatValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", DoubleValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", DoubleValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -364,7 +364,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_FloatValue) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, DoubleValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, DoubleValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -382,7 +382,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_DoubleValue) { internal::GetTestingMessageFactory(), "google.protobuf.DoubleValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", DoubleValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", DoubleValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -390,7 +390,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_DoubleValue) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, DoubleValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, DoubleValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -408,7 +408,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_StringValue) { internal::GetTestingMessageFactory(), "google.protobuf.StringValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", StringValue("foo")), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", StringValue("foo")), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -416,7 +416,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_StringValue) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, StringValue("foo")), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, StringValue("foo")), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -434,7 +434,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_BytesValue) { internal::GetTestingMessageFactory(), "google.protobuf.BytesValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", BytesValue("foo")), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", BytesValue("foo")), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -442,7 +442,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_BytesValue) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, BytesValue("foo")), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, BytesValue("foo")), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -460,7 +460,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Duration) { internal::GetTestingMessageFactory(), "google.protobuf.Duration"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("seconds", IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -468,7 +468,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Duration) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByName("nanos", IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName( "nanos", IntValue(std::numeric_limits::max())), IsOkAndHolds(Optional( @@ -477,7 +477,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Duration) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(3, IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -485,7 +485,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Duration) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber( 2, IntValue(std::numeric_limits::max())), IsOkAndHolds(Optional( @@ -505,7 +505,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Timestamp) { internal::GetTestingMessageFactory(), "google.protobuf.Timestamp"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("seconds", IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -513,7 +513,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Timestamp) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByName("nanos", IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName( "nanos", IntValue(std::numeric_limits::max())), IsOkAndHolds(Optional( @@ -522,7 +522,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Timestamp) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(3, IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -530,7 +530,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Timestamp) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber( 2, IntValue(std::numeric_limits::max())), IsOkAndHolds(Optional( @@ -552,7 +552,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Any) { EXPECT_THAT(builder->SetFieldByName( "type_url", StringValue("type.googleapis.com/google.protobuf.BoolValue")), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -560,14 +560,14 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Any) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByName("value", BytesValue()), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT( builder->SetFieldByNumber( 1, StringValue("type.googleapis.com/google.protobuf.BoolValue")), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(3, IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -575,7 +575,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Any) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(2, BytesValue()), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); diff --git a/common/type_spec_resolver.cc b/common/type_spec_resolver.cc new file mode 100644 index 000000000..90c9930a8 --- /dev/null +++ b/common/type_spec_resolver.cc @@ -0,0 +1,301 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_spec_resolver.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool) { + if (type_spec.has_null()) return Type(NullType{}); + if (type_spec.has_dyn()) return Type(DynType{}); + + if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + return Type(BoolType{}); + case PrimitiveType::kInt64: + return Type(IntType{}); + case PrimitiveType::kUint64: + return Type(UintType{}); + case PrimitiveType::kDouble: + return Type(DoubleType{}); + case PrimitiveType::kString: + return Type(StringType{}); + case PrimitiveType::kBytes: + return Type(BytesType{}); + default: + return absl::InvalidArgumentError("Unsupported primitive type"); + } + } + + if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + return Type(AnyType{}); + case WellKnownTypeSpec::kTimestamp: + return Type(TimestampType{}); + case WellKnownTypeSpec::kDuration: + return Type(DurationType{}); + default: + return absl::InvalidArgumentError("Unsupported well-known type"); + } + } + + if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + return Type(BoolWrapperType{}); + case PrimitiveType::kInt64: + return Type(IntWrapperType{}); + case PrimitiveType::kUint64: + return Type(UintWrapperType{}); + case PrimitiveType::kDouble: + return Type(DoubleWrapperType{}); + case PrimitiveType::kString: + return Type(StringWrapperType{}); + case PrimitiveType::kBytes: + return Type(BytesWrapperType{}); + default: + return absl::InvalidArgumentError("Unsupported wrapper type"); + } + } + + if (type_spec.has_list_type()) { + Type elem_type; + if (type_spec.list_type().elem_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + elem_type, ConvertTypeSpecToType(type_spec.list_type().elem_type(), + arena, pool)); + } + return Type(ListType(arena, elem_type)); + } + + if (type_spec.has_map_type()) { + Type key_type; + if (type_spec.map_type().key_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + key_type, + ConvertTypeSpecToType(type_spec.map_type().key_type(), arena, pool)); + } + + Type value_type; + if (type_spec.map_type().value_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + value_type, ConvertTypeSpecToType(type_spec.map_type().value_type(), + arena, pool)); + } + return Type(MapType(arena, key_type, value_type)); + } + + if (type_spec.has_function()) { + const auto& func_spec = type_spec.function(); + Type result_type; + if (func_spec.result_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + result_type, + ConvertTypeSpecToType(func_spec.result_type(), arena, pool)); + } + std::vector arg_types; + arg_types.reserve(func_spec.arg_types().size()); + for (const auto& arg_spec : func_spec.arg_types()) { + CEL_ASSIGN_OR_RETURN(auto arg_type, + ConvertTypeSpecToType(arg_spec, arena, pool)); + arg_types.push_back(std::move(arg_type)); + } + return Type(FunctionType(arena, result_type, arg_types)); + } + + if (type_spec.has_type_param()) { + const std::string& name = type_spec.type_param().type(); + auto* allocated_name = google::protobuf::Arena::Create(arena, name); + return Type(TypeParamType(absl::string_view(*allocated_name))); + } + + if (type_spec.has_message_type()) { + const std::string& name = type_spec.message_type().type(); + const google::protobuf::Descriptor* descriptor = pool.FindMessageTypeByName(name); + if (descriptor == nullptr) { + return absl::InvalidArgumentError(absl::StrCat( + "Message type '", name, "' not found in descriptor pool")); + } + return Type::Message(descriptor); + } + + if (type_spec.has_abstract_type()) { + const std::string& name = type_spec.abstract_type().name(); + + // Check if it's a message type in the pool + const google::protobuf::Descriptor* descriptor = pool.FindMessageTypeByName(name); + if (descriptor != nullptr) { + if (!type_spec.abstract_type().parameter_types().empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Message type '", name, "' cannot have type parameters")); + } + return Type::Message(descriptor); + } + + // Check if it's an enum type in the pool + const google::protobuf::EnumDescriptor* enum_descriptor = + pool.FindEnumTypeByName(name); + if (enum_descriptor != nullptr) { + if (!type_spec.abstract_type().parameter_types().empty()) { + return absl::InvalidArgumentError( + absl::StrCat("Enum type '", name, "' cannot have type parameters")); + } + return Type::Enum(enum_descriptor); + } + + // Otherwise fallback to OpaqueType + std::vector params; + for (const auto& param_spec : type_spec.abstract_type().parameter_types()) { + CEL_ASSIGN_OR_RETURN(auto param, + ConvertTypeSpecToType(param_spec, arena, pool)); + params.push_back(std::move(param)); + } + auto* allocated_name = google::protobuf::Arena::Create(arena, name); + return Type(OpaqueType(arena, absl::string_view(*allocated_name), params)); + } + + if (type_spec.has_type()) { + CEL_ASSIGN_OR_RETURN(auto contained_type, + ConvertTypeSpecToType(type_spec.type(), arena, pool)); + return Type(TypeType(arena, contained_type)); + } + + if (type_spec.has_error()) { + return Type(ErrorType{}); + } + + return absl::InvalidArgumentError("Unknown TypeSpec kind"); +} + +absl::StatusOr ConvertTypeToTypeSpec(const Type& type) { + switch (type.kind()) { + case TypeKind::kNull: + return TypeSpec(NullTypeSpec{}); + case TypeKind::kDyn: + return TypeSpec(DynTypeSpec{}); + case TypeKind::kBool: + return TypeSpec(PrimitiveType::kBool); + case TypeKind::kInt: + return TypeSpec(PrimitiveType::kInt64); + case TypeKind::kUint: + return TypeSpec(PrimitiveType::kUint64); + case TypeKind::kDouble: + return TypeSpec(PrimitiveType::kDouble); + case TypeKind::kString: + return TypeSpec(PrimitiveType::kString); + case TypeKind::kBytes: + return TypeSpec(PrimitiveType::kBytes); + case TypeKind::kAny: + return TypeSpec(WellKnownTypeSpec::kAny); + case TypeKind::kTimestamp: + return TypeSpec(WellKnownTypeSpec::kTimestamp); + case TypeKind::kDuration: + return TypeSpec(WellKnownTypeSpec::kDuration); + case TypeKind::kBoolWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + case TypeKind::kIntWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + case TypeKind::kUintWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + case TypeKind::kDoubleWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + case TypeKind::kStringWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + case TypeKind::kBytesWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + case TypeKind::kList: { + CEL_ASSIGN_OR_RETURN(auto elem_type, + ConvertTypeToTypeSpec(type.GetList().element())); + return TypeSpec( + ListTypeSpec(std::make_unique(std::move(elem_type)))); + } + case TypeKind::kMap: { + CEL_ASSIGN_OR_RETURN(auto key_type, + ConvertTypeToTypeSpec(type.GetMap().key())); + CEL_ASSIGN_OR_RETURN(auto value_type, + ConvertTypeToTypeSpec(type.GetMap().value())); + return TypeSpec( + MapTypeSpec(std::make_unique(std::move(key_type)), + std::make_unique(std::move(value_type)))); + } + case TypeKind::kFunction: { + auto func_type = type.GetFunction(); + CEL_ASSIGN_OR_RETURN(auto result_type, + ConvertTypeToTypeSpec(func_type.result())); + std::vector arg_types; + arg_types.reserve(func_type.args().size()); + for (const auto& arg : func_type.args()) { + CEL_ASSIGN_OR_RETURN(auto arg_type, ConvertTypeToTypeSpec(arg)); + arg_types.push_back(std::move(arg_type)); + } + return TypeSpec( + FunctionTypeSpec(std::make_unique(std::move(result_type)), + std::move(arg_types))); + } + case TypeKind::kTypeParam: + return TypeSpec(ParamTypeSpec(std::string(type.GetTypeParam().name()))); + case TypeKind::kStruct: { + if (type.IsMessage()) { + return TypeSpec(MessageTypeSpec(std::string(type.GetMessage().name()))); + } + return absl::InvalidArgumentError("Unsupported struct type"); + } + case TypeKind::kOpaque: { + auto opaque_type = type.GetOpaque(); + std::vector params; + params.reserve(opaque_type.GetParameters().size()); + for (const auto& param : opaque_type.GetParameters()) { + CEL_ASSIGN_OR_RETURN(auto param_type, ConvertTypeToTypeSpec(param)); + params.push_back(std::move(param_type)); + } + return TypeSpec( + AbstractType(std::string(opaque_type.name()), std::move(params))); + } + case TypeKind::kType: { + CEL_ASSIGN_OR_RETURN(auto nested_type, + ConvertTypeToTypeSpec(type.GetType().GetType())); + return TypeSpec(std::make_unique(std::move(nested_type))); + } + case TypeKind::kError: + return TypeSpec(ErrorTypeSpec::kValue); + case TypeKind::kEnum: + return TypeSpec( + AbstractType(std::string(type.GetEnum().name()), /*params=*/{})); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported Type kind: ", TypeKindToString(type.kind()))); + } +} + +} // namespace cel diff --git a/common/type_spec_resolver.h b/common/type_spec_resolver.h new file mode 100644 index 000000000..edbfa3bde --- /dev/null +++ b/common/type_spec_resolver.h @@ -0,0 +1,40 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Resolves a `cel::TypeSpec` to a `cel::Type`. +// +// TypeSpec only specifies a type while Type provides support for inspecting +// properties of the type when used in CEL. Returns a status with code +// `InvalidArgument` if the input cannot be resolved to a type. +absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool); + +// Resolves a `cel::Type` to a `cel::TypeSpec`. +absl::StatusOr ConvertTypeToTypeSpec(const Type& type); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ diff --git a/common/type_spec_resolver_test.cc b/common/type_spec_resolver_test.cc new file mode 100644 index 000000000..1cda7280f --- /dev/null +++ b/common/type_spec_resolver_test.cc @@ -0,0 +1,284 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_spec_resolver.h" + +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/ast.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::TestWithParam; +using ::testing::Values; + +google::protobuf::Arena* GetTestArena() { + static absl::NoDestructor arena; + return &*arena; +} + +TEST(TypeSpecResolverTest, NullTypeSpec) { + TypeSpec spec(NullTypeSpec{}); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsNull()); +} + +TEST(TypeSpecResolverTest, DynTypeSpec) { + TypeSpec spec(DynTypeSpec{}); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsDyn()); +} + +using ConversionTest = testing::TestWithParam>; + +TEST_P(ConversionTest, TestTypeSpecConversion) { + ASSERT_OK_AND_ASSIGN( + auto t, ConvertTypeSpecToType(std::get<0>(GetParam()), GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_EQ(t.kind(), std::get<1>(GetParam())); + EXPECT_THAT(ConvertTypeToTypeSpec(t), IsOkAndHolds(std::get<0>(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P( + TypeSpecResolverTest, ConversionTest, + testing::Values( + std::make_tuple(TypeSpec(PrimitiveType::kBool), TypeKind::kBool), + std::make_tuple(TypeSpec(PrimitiveType::kInt64), TypeKind::kInt), + std::make_tuple(TypeSpec(PrimitiveType::kUint64), TypeKind::kUint), + std::make_tuple(TypeSpec(PrimitiveType::kDouble), TypeKind::kDouble), + std::make_tuple(TypeSpec(PrimitiveType::kString), TypeKind::kString), + std::make_tuple(TypeSpec(PrimitiveType::kBytes), TypeKind::kBytes), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kAny), TypeKind::kAny), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kTimestamp), + TypeKind::kTimestamp), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kDuration), + TypeKind::kDuration), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)), + TypeKind::kBoolWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + TypeKind::kIntWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)), + TypeKind::kUintWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + TypeKind::kDoubleWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)), + TypeKind::kStringWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)), + TypeKind::kBytesWrapper))); + +TEST(TypeSpecResolverTest, ListTypeConversion) { + auto elem = std::make_unique(PrimitiveType::kInt64); + TypeSpec spec(ListTypeSpec(std::move(elem))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsList()); + EXPECT_TRUE(t->GetList().element().IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, MapTypeConversion) { + auto key = std::make_unique(PrimitiveType::kString); + auto val = std::make_unique(PrimitiveType::kBytes); + TypeSpec spec(MapTypeSpec(std::move(key), std::move(val))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMap()); + EXPECT_TRUE(t->GetMap().key().IsString()); + EXPECT_TRUE(t->GetMap().value().IsBytes()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, FunctionTypeConversion) { + auto result = std::make_unique(PrimitiveType::kBool); + std::vector args; + args.push_back(TypeSpec(PrimitiveType::kString)); + TypeSpec spec(FunctionTypeSpec(std::move(result), std::move(args))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsFunction()); + EXPECT_EQ(t->GetFunction().args().size(), 1); + EXPECT_TRUE(t->GetFunction().result().IsBool()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, TypeParamConversion) { + TypeSpec spec(ParamTypeSpec("T")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsTypeParam()); + EXPECT_EQ(t->GetTypeParam().name(), "T"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, MessageTypeConversion) { + TypeSpec spec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", /*params=*/{})); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMessage()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ( + spec2, + TypeSpec(MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes"))); +} + +TEST(TypeSpecResolverTest, MessageTypeWithParamsError) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("cel.expr.conformance.proto3.TestAllTypes", + std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(TypeSpecResolverTest, UnresolvedAbstractTypeFallbackToOpaque) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("my.custom.OpaqueType", std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsOpaque()); + EXPECT_EQ(t->name(), "my.custom.OpaqueType"); + EXPECT_EQ(t->GetParameters().size(), 1); + EXPECT_TRUE(t->GetParameters()[0].IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, OptionalType) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("optional_type", std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsOpaque()); + EXPECT_EQ(t->name(), "optional_type"); + EXPECT_EQ(t->GetParameters().size(), 1); + EXPECT_TRUE(t->GetParameters()[0].IsInt()); + EXPECT_TRUE(t->IsOptional()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, TypeTypeConversion) { + auto nested = std::make_unique(PrimitiveType::kInt64); + TypeSpec spec(std::move(nested)); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsType()); + EXPECT_TRUE(t->GetType().GetType().IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, ErrorTypeConversion) { + TypeSpec spec(ErrorTypeSpec::kValue); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsError()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { + TypeSpec spec(MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMessage()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, MessageTypeSpecNotFoundError) { + TypeSpec spec(MessageTypeSpec("cel.expr.conformance.proto3.NonExistentType")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("not found in descriptor pool"))); +} + +TEST(TypeSpecResolverTest, EnumTypeConversion) { + TypeSpec spec(AbstractType( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum", /*params=*/{})); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsEnum()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); +} + +TEST(TypeSpecResolverTest, EnumTypeWithParamsError) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes.NestedEnum", + std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(TypeSpecResolverTest, UnknownTypeSpecKindError) { + TypeSpec spec; + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unknown TypeSpec kind"))); +} + +} // namespace +} // namespace cel diff --git a/common/type_test.cc b/common/type_test.cc index 119234fdc..d6a613c3c 100644 --- a/common/type_test.cc +++ b/common/type_test.cc @@ -45,7 +45,7 @@ TEST(Type, Enum) { EXPECT_EQ(Type::Enum( ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( "google.protobuf.NullValue"))), - NullType()); + IntType()); } TEST(Type, Field) { @@ -58,7 +58,7 @@ TEST(Type, Field) { BoolType()); EXPECT_EQ( Type::Field(ABSL_DIE_IF_NULL(descriptor->FindFieldByName("null_value"))), - NullType()); + IntType()); EXPECT_EQ(Type::Field( ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_int32"))), IntType()); @@ -638,5 +638,39 @@ TEST(Type, Wrap) { EXPECT_EQ(Type(AnyType()).Wrap(), AnyType()); } +TEST(Type, LegacyRuntimeType) { + EXPECT_EQ(common_internal::LegacyRuntimeType("bool"), BoolType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Any"), + AnyType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.BoolValue"), + BoolWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.BytesValue"), + BytesWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.DoubleValue"), + DoubleWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Duration"), + DurationType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.FloatValue"), + DoubleWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Int32Value"), + IntWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Int64Value"), + IntWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.ListValue"), + ListType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.StringValue"), + StringWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Struct"), + JsonMapType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Timestamp"), + TimestampType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.UInt32Value"), + UintWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.UInt64Value"), + UintWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Value"), + DynType()); +} + } // namespace } // namespace cel diff --git a/common/typeinfo.cc b/common/typeinfo.cc index 86bae1934..b07275712 100644 --- a/common/typeinfo.cc +++ b/common/typeinfo.cc @@ -57,18 +57,13 @@ std::string TypeInfo::DebugString() const { } return std::string(demangled.get()); #else - size_t length = 0; int status = 0; std::unique_ptr demangled( - abi::__cxa_demangle(rep_->name(), nullptr, &length, &status)); + abi::__cxa_demangle(rep_->name(), nullptr, nullptr, &status)); if (status != 0 || demangled == nullptr) { return std::string(rep_->name()); } - while (length != 0 && demangled.get()[length - 1] == '\0') { - // length includes the null terminator, remove it. - --length; - } - return std::string(demangled.get(), length); + return std::string(demangled.get()); #endif #else return absl::StrCat("0x", absl::Hex(absl::bit_cast(rep_))); diff --git a/common/types/legacy_type_manager.h b/common/types/legacy_type_manager.h deleted file mode 100644 index 238335b52..000000000 --- a/common/types/legacy_type_manager.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_MANAGER_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_MANAGER_H_ - -#include "common/memory.h" -#include "common/type_introspector.h" -#include "common/type_manager.h" - -namespace cel::common_internal { - -// `LegacyTypeManager` is an implementation which should be used when -// converting between `cel::Value` and `google::api::expr::runtime::CelValue` -// and only then. -class LegacyTypeManager : public virtual TypeManager { - public: - explicit LegacyTypeManager(const TypeIntrospector& type_introspector) - : type_introspector_(type_introspector) {} - - protected: - const TypeIntrospector& GetTypeIntrospector() const final { - return type_introspector_; - } - - private: - const TypeIntrospector& type_introspector_; -}; - -} // namespace cel::common_internal - -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_MANAGER_H_ diff --git a/common/types/opaque_type.cc b/common/types/opaque_type.cc index 002319d1d..9c58e8289 100644 --- a/common/types/opaque_type.cc +++ b/common/types/opaque_type.cc @@ -98,7 +98,7 @@ absl::optional OpaqueType::AsOptional() const { if (IsOptional()) { return OptionalType(absl::in_place, *this); } - return absl::nullopt; + return std::nullopt; } OptionalType OpaqueType::GetOptional() const { diff --git a/common/types/struct_type.cc b/common/types/struct_type.cc index 4540cec9c..69f531a2f 100644 --- a/common/types/struct_type.cc +++ b/common/types/struct_type.cc @@ -27,7 +27,7 @@ namespace cel { absl::string_view StructType::name() const { ABSL_DCHECK(*this); return absl::visit( - absl::Overload([](absl::monostate) { return absl::string_view(); }, + absl::Overload([](std::monostate) { return absl::string_view(); }, [](const common_internal::BasicStructType& alt) { return alt.name(); }, @@ -39,7 +39,7 @@ TypeParameters StructType::GetParameters() const { ABSL_DCHECK(*this); return absl::visit( absl::Overload( - [](absl::monostate) { return TypeParameters(); }, + [](std::monostate) { return TypeParameters(); }, [](const common_internal::BasicStructType& alt) { return alt.GetParameters(); }, @@ -49,7 +49,7 @@ TypeParameters StructType::GetParameters() const { std::string StructType::DebugString() const { return absl::visit( - absl::Overload([](absl::monostate) { return std::string(); }, + absl::Overload([](std::monostate) { return std::string(); }, [](common_internal::BasicStructType alt) { return alt.DebugString(); }, @@ -61,7 +61,7 @@ absl::optional StructType::AsMessage() const { if (const auto* alt = absl::get_if(&variant_); alt != nullptr) { return *alt; } - return absl::nullopt; + return std::nullopt; } MessageType StructType::GetMessage() const { @@ -72,7 +72,7 @@ MessageType StructType::GetMessage() const { common_internal::TypeVariant StructType::ToTypeVariant() const { return absl::visit( absl::Overload( - [](absl::monostate) { return common_internal::TypeVariant(); }, + [](std::monostate) { return common_internal::TypeVariant(); }, [](common_internal::BasicStructType alt) { return static_cast(alt) ? common_internal::TypeVariant(alt) : common_internal::TypeVariant(); diff --git a/common/types/thread_compatible_type_introspector.h b/common/types/thread_compatible_type_introspector.h deleted file mode 100644 index 870ea9054..000000000 --- a/common/types/thread_compatible_type_introspector.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_THREAD_COMPATIBLE_TYPE_INTROSPECTOR_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_THREAD_COMPATIBLE_TYPE_INTROSPECTOR_H_ - -#include "common/type_introspector.h" - -namespace cel::common_internal { - -// `ThreadCompatibleTypeIntrospector` is a basic implementation of -// `TypeIntrospector` which is thread compatible. By default this implementation -// just returns `NOT_FOUND` for most methods. -class ThreadCompatibleTypeIntrospector : public virtual TypeIntrospector { - public: - ThreadCompatibleTypeIntrospector() = default; -}; - -} // namespace cel::common_internal - -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_THREAD_COMPATIBLE_TYPE_INTROSPECTOR_H_ diff --git a/common/value.cc b/common/value.cc index 535ddead8..1cd3f54e1 100644 --- a/common/value.cc +++ b/common/value.cc @@ -115,7 +115,7 @@ Type Value::GetRuntimeType() const { namespace { template -struct IsMonostate : std::is_same, absl::monostate> {}; +struct IsMonostate : std::is_same, std::monostate> {}; } // namespace @@ -171,7 +171,7 @@ absl::Status Value::ConvertToJsonArray( google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); return variant_.Visit(absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError("use of invalid Value"); }, [descriptor_pool, message_factory, json]( @@ -212,7 +212,7 @@ absl::Status Value::ConvertToJsonObject( google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); return variant_.Visit(absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError("use of invalid Value"); }, [descriptor_pool, message_factory, json]( @@ -1363,7 +1363,7 @@ Value Value::FromMessage( return absl::visit( absl::Overload(OwningWellKnownTypesValueVisitor{ /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { auto* cloned = message.New(arena); cloned->CopyFrom(message); return ParsedMessageValue(cloned, arena); @@ -1391,7 +1391,7 @@ Value Value::FromMessage( return absl::visit( absl::Overload(OwningWellKnownTypesValueVisitor{ /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { auto* cloned = message.New(arena); cloned->GetReflection()->Swap(cloned, &message); return ParsedMessageValue(cloned, arena); @@ -1422,7 +1422,7 @@ Value Value::WrapMessage( absl::Overload(BorrowingWellKnownTypesValueVisitor{ /* .message = */ message, /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { if (message->GetArena() != arena) { auto* cloned = message->New(arena); cloned->CopyFrom(*message); @@ -1456,7 +1456,7 @@ Value Value::WrapMessageUnsafe( absl::Overload(BorrowingWellKnownTypesValueVisitor{ /* .message = */ message, /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { if (message->GetArena() != arena) { return UnsafeParsedMessageValue(message); } diff --git a/common/values/custom_map_value.h b/common/values/custom_map_value.h index 9e840e07f..ca6e1e025 100644 --- a/common/values/custom_map_value.h +++ b/common/values/custom_map_value.h @@ -225,7 +225,7 @@ class CustomMapValueInterface { // Returns the number of entries in this map. virtual size_t Size() const = 0; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. virtual absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -233,7 +233,7 @@ class CustomMapValueInterface { google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const = 0; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. virtual absl::Status ForEach( ForEachCallback callback, @@ -347,7 +347,7 @@ class CustomMapValue final size_t Size() const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -356,7 +356,7 @@ class CustomMapValue final Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, @@ -365,7 +365,7 @@ class CustomMapValue final google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -374,7 +374,7 @@ class CustomMapValue final Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -386,7 +386,7 @@ class CustomMapValue final // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ForEach( ForEachCallback callback, @@ -394,7 +394,7 @@ class CustomMapValue final google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr NewIterator() const; diff --git a/common/values/legacy_map_value.h b/common/values/legacy_map_value.h index 31865a873..c83b7fc2f 100644 --- a/common/values/legacy_map_value.h +++ b/common/values/legacy_map_value.h @@ -102,7 +102,7 @@ class LegacyMapValue final size_t Size() const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -111,7 +111,7 @@ class LegacyMapValue final Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, @@ -120,7 +120,7 @@ class LegacyMapValue final google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -129,7 +129,7 @@ class LegacyMapValue final Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -137,11 +137,11 @@ class LegacyMapValue final google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; - // See the corresponding type declaration of `MapValueInterface` for + // See the corresponding type declaration of `MapValue` for // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ForEach( ForEachCallback callback, diff --git a/common/values/map_value.h b/common/values/map_value.h index ffbdea6c9..b6e69ea57 100644 --- a/common/values/map_value.h +++ b/common/values/map_value.h @@ -15,10 +15,16 @@ // IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" -// `MapValue` represents values of the primitive `map` type. `MapValueView` -// is a non-owning view of `MapValue`. `MapValueInterface` is the abstract -// base class of implementations. `MapValue` and `MapValueView` act as smart -// pointers to `MapValueInterface`. +// `MapValue` represents values of the primitive `map` type. It provides a +// unified interface for accessing map contents, regardless of the underlying +// implementation (e.g., JSON, protobuf map field, or custom implementation). +// +// Public member functions: +// - `IsEmpty()` / `Size()`: Query map size. +// - `Get()` / `Find()` / `Has()`: Access entries by key. +// - `ListKeys()` / `NewIterator()` / `ForEach()`: Iterate over entries. +// - `ConvertToJson()` / `ConvertToJsonObject()`: JSON conversion. +// - `IsCustom()` / `AsCustom()` / `GetCustom()`: Access custom implementation. #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_H_ @@ -54,7 +60,6 @@ namespace cel { -class MapValueInterface; class MapValue; class Value; @@ -119,8 +124,13 @@ class MapValue final : private common_internal::MapValueMixin { absl::StatusOr Size() const; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `Get` sets the value `result` to (via `result`) the value associated with + // `key`. If `key` is not found, `no such key` is set to `result`. If an error + // occurs (e.g., invalid key type), an `no such key` is returned. + // + // A non-ok status may be returned if an unexpected error is encountered or to + // propagate an error from a custom implementation, in which case `result` is + // unspecified. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, @@ -128,8 +138,13 @@ class MapValue final : private common_internal::MapValueMixin { Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `Find` returns `true` if `key` is found in the map, and stores the + // associated value in `result`. If `key` is not found, `false` is returned + // and `result` is unchanged. + // + // A non-ok status may be returned if an unexpected error is encountered or to + // propagate an error from a custom implementation, in which case `result` is + // unspecified. absl::StatusOr Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -137,8 +152,13 @@ class MapValue final : private common_internal::MapValueMixin { google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `Has` returns `true` if `key` is found in the map, and stores the BoolValue + // result in `result`. In case of an error, the result is set to an + // ErrorValue. + // + // A non-ok status may be returned if an unexpected error is encountered or to + // propagate an error from a custom implementation, in which case `result` is + // unspecified. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, @@ -146,28 +166,25 @@ class MapValue final : private common_internal::MapValueMixin { Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `ListKeys` returns a `ListValue` containing all keys in the map. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; - // See the corresponding type declaration of `MapValueInterface` for - // documentation. + // `ForEachCallback` is the callback type for `ForEach`. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `ForEach` calls `callback` for each entry in the map. Iteration continues + // until all entries are visited or `callback` returns an error or `false`. absl::Status ForEach( ForEachCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `NewIterator` returns a new iterator for the map. absl::StatusOr NewIterator() const; // Returns `true` if this value is an instance of a custom map value. diff --git a/common/values/message_value.cc b/common/values/message_value.cc index e06206407..66dfd9511 100644 --- a/common/values/message_value.cc +++ b/common/values/message_value.cc @@ -46,7 +46,7 @@ const google::protobuf::Descriptor* absl_nonnull MessageValue::GetDescriptor() c ABSL_CHECK(*this); // Crash OK return absl::visit( absl::Overload( - [](absl::monostate) -> const google::protobuf::Descriptor* absl_nonnull { + [](std::monostate) -> const google::protobuf::Descriptor* absl_nonnull { ABSL_UNREACHABLE(); }, [](const ParsedMessageValue& alternative) @@ -58,7 +58,7 @@ const google::protobuf::Descriptor* absl_nonnull MessageValue::GetDescriptor() c std::string MessageValue::DebugString() const { return absl::visit( - absl::Overload([](absl::monostate) -> std::string { return "INVALID"; }, + absl::Overload([](std::monostate) -> std::string { return "INVALID"; }, [](const ParsedMessageValue& alternative) -> std::string { return alternative.DebugString(); }), @@ -68,7 +68,7 @@ std::string MessageValue::DebugString() const { bool MessageValue::IsZeroValue() const { ABSL_DCHECK(*this); return absl::visit( - absl::Overload([](absl::monostate) -> bool { return true; }, + absl::Overload([](std::monostate) -> bool { return true; }, [](const ParsedMessageValue& alternative) -> bool { return alternative.IsZeroValue(); }), @@ -81,7 +81,7 @@ absl::Status MessageValue::SerializeTo( google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ConvertToJson` on " "an invalid `MessageValue`"); @@ -99,7 +99,7 @@ absl::Status MessageValue::ConvertToJson( google::protobuf::Message* absl_nonnull json) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ConvertToJson` on " "an invalid `MessageValue`"); @@ -117,7 +117,7 @@ absl::Status MessageValue::ConvertToJsonObject( google::protobuf::Message* absl_nonnull json) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ConvertToJsonObject` on " "an invalid `MessageValue`"); @@ -136,7 +136,7 @@ absl::Status MessageValue::Equal( google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `Equal` on " "an invalid `MessageValue`"); @@ -155,7 +155,7 @@ absl::Status MessageValue::GetFieldByName( google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `GetFieldByName` on " "an invalid `MessageValue`"); @@ -175,7 +175,7 @@ absl::Status MessageValue::GetFieldByNumber( google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `GetFieldByNumber` on " "an invalid `MessageValue`"); @@ -192,7 +192,7 @@ absl::StatusOr MessageValue::HasFieldByName( absl::string_view name) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::StatusOr { + [](std::monostate) -> absl::StatusOr { return absl::InternalError( "unexpected attempt to invoke `HasFieldByName` on " "an invalid `MessageValue`"); @@ -206,7 +206,7 @@ absl::StatusOr MessageValue::HasFieldByName( absl::StatusOr MessageValue::HasFieldByNumber(int64_t number) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::StatusOr { + [](std::monostate) -> absl::StatusOr { return absl::InternalError( "unexpected attempt to invoke `HasFieldByNumber` on " "an invalid `MessageValue`"); @@ -224,7 +224,7 @@ absl::Status MessageValue::ForEachField( google::protobuf::Arena* absl_nonnull arena) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ForEachField` on " "an invalid `MessageValue`"); @@ -244,7 +244,7 @@ absl::Status MessageValue::Qualify( int* absl_nonnull count) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `Qualify` on " "an invalid `MessageValue`"); diff --git a/common/values/null_value.h b/common/values/null_value.h index 53c3161a1..d4d05dba3 100644 --- a/common/values/null_value.h +++ b/common/values/null_value.h @@ -37,8 +37,7 @@ namespace cel { class Value; class NullValue; -// `NullValue` represents values of the primitive `duration` type. - +// `NullValue` represents the CEL `null` value. class NullValue final : private common_internal::ValueMixin { public: static constexpr ValueKind kKind = ValueKind::kNull; diff --git a/common/values/opaque_value.h b/common/values/opaque_value.h index 273b7889a..57af78ae0 100644 --- a/common/values/opaque_value.h +++ b/common/values/opaque_value.h @@ -52,7 +52,7 @@ class Value; class OpaqueValueInterface; class OpaqueValueInterfaceIterator; class OpaqueValue; -class TypeFactory; + using OpaqueValueContent = CustomValueContent; struct OpaqueValueDispatcher { diff --git a/common/values/optional_value.cc b/common/values/optional_value.cc index ad0a65efb..7c214b9cb 100644 --- a/common/values/optional_value.cc +++ b/common/values/optional_value.cc @@ -122,200 +122,185 @@ absl::Status OptionalValueEqual( return absl::OkStatus(); } +google::protobuf::Arena* absl_nullable OptionalValueGetArenaNull( + const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) { + return nullptr; +} + +OpaqueValue OptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) { + return common_internal::MakeOptionalValue(dispatcher, content); +} + +bool OptionalValueHasNoValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content) { + return false; +} + +void EmptyOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = + ErrorValue(absl::FailedPreconditionError("optional.none() dereference")); +} + +void NullOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = NullValue(); +} + +void BoolOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = BoolValue(content.To()); +} + +void IntOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = IntValue(content.To()); +} + +void UintOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UintValue(content.To()); +} + +void DoubleOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = DoubleValue(content.To()); +} + +void DurationOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UnsafeDurationValue(content.To()); +} + +void TimestampOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UnsafeTimestampValue(content.To()); +} + ABSL_CONST_INIT const OptionalValueDispatcher empty_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, - }, - [](const OptionalValueDispatcher* absl_nonnull dispatcher, - CustomValueContent content) -> bool { return false; }, - [](const OptionalValueDispatcher* absl_nonnull dispatcher, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = ErrorValue( - absl::FailedPreconditionError("optional.none() dereference")); + .clone = &OptionalValueClone, }, + &OptionalValueHasNoValue, + &EmptyOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher null_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent, - cel::Value* absl_nonnull result) -> void { *result = NullValue(); }, + &NullOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher bool_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = BoolValue(content.To()); - }, + &BoolOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher int_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = IntValue(content.To()); - }, + &IntOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher uint_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = UintValue(content.To()); - }, + &UintOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher double_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = DoubleValue(content.To()); - }, + &DoubleOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher duration_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = UnsafeDurationValue(content.To()); - }, + &DurationOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher timestamp_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = UnsafeTimestampValue(content.To()); - }, + &TimestampOptionalValueValue, }; struct OptionalValueContent { @@ -323,43 +308,51 @@ struct OptionalValueContent { google::protobuf::Arena* absl_nonnull arena; }; +google::protobuf::Arena* absl_nullable GenericOptionalValueGetArena( + const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent content) { + return content.To().arena; +} + +OpaqueValue GenericOptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena); + +void GenericOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = *content.To().value; +} + ABSL_CONST_INIT const OptionalValueDispatcher optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = - [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent content) -> google::protobuf::Arena* absl_nullable { - return content.To().arena; - }, + .get_arena = &GenericOptionalValueGetArena, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - ABSL_DCHECK(arena != nullptr); - - cel::Value* absl_nonnull result = ::new ( - arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) - cel::Value( - content.To().value->Clone(arena)); - if (!ArenaTraits<>::trivially_destructible(result)) { - arena->OwnDestructor(result); - } - return common_internal::MakeOptionalValue( - &optional_value_dispatcher, - OpaqueValueContent::From( - OptionalValueContent{.value = result, .arena = arena})); - }, + .clone = &GenericOptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = *content.To().value; - }, + &GenericOptionalValueValue, }; +OpaqueValue GenericOptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(arena != nullptr); + + cel::Value* absl_nonnull result = + ::new (arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) + cel::Value(content.To().value->Clone(arena)); + if (!ArenaTraits<>::trivially_destructible(*result)) { + arena->OwnDestructor(result); + } + return common_internal::MakeOptionalValue( + &optional_value_dispatcher, OpaqueValueContent::From(OptionalValueContent{ + .value = result, .arena = arena})); +} + } // namespace OptionalValue OptionalValue::Of(cel::Value value, @@ -402,7 +395,7 @@ OptionalValue OptionalValue::Of(cel::Value value, cel::Value* absl_nonnull result = ::new ( arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) cel::Value(std::move(value)); - if (!ArenaTraits<>::trivially_destructible(result)) { + if (!ArenaTraits<>::trivially_destructible(*result)) { arena->OwnDestructor(result); } return OptionalValue(&optional_value_dispatcher, diff --git a/common/values/parsed_json_map_value.cc b/common/values/parsed_json_map_value.cc index 6072a0b21..ec8c91a4f 100644 --- a/common/values/parsed_json_map_value.cc +++ b/common/values/parsed_json_map_value.cc @@ -408,8 +408,8 @@ class ParsedJsonMapValueIterator final : public ValueIterator { private: const google::protobuf::Message* absl_nonnull const message_; const well_known_types::StructReflection reflection_; - google::protobuf::MapIterator begin_; - const google::protobuf::MapIterator end_; + google::protobuf::ConstMapIterator begin_; + const google::protobuf::ConstMapIterator end_; std::string scratch_; }; diff --git a/common/values/parsed_json_map_value.h b/common/values/parsed_json_map_value.h index b20fe032b..ba8d3490d 100644 --- a/common/values/parsed_json_map_value.h +++ b/common/values/parsed_json_map_value.h @@ -132,7 +132,7 @@ class ParsedJsonMapValue final size_t Size() const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -141,7 +141,7 @@ class ParsedJsonMapValue final Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, @@ -150,7 +150,7 @@ class ParsedJsonMapValue final google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -159,7 +159,7 @@ class ParsedJsonMapValue final Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -167,11 +167,11 @@ class ParsedJsonMapValue final google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; - // See the corresponding type declaration of `MapValueInterface` for + // See the corresponding type declaration of `MapValue` for // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ForEach( ForEachCallback callback, diff --git a/common/values/parsed_map_field_value.cc b/common/values/parsed_map_field_value.cc index 737593cca..47b737f82 100644 --- a/common/values/parsed_map_field_value.cc +++ b/common/values/parsed_map_field_value.cc @@ -415,10 +415,10 @@ absl::Status ParsedMapFieldValue::ListKeys( field_->message_type()->map_key())); auto builder = NewListValueBuilder(arena); builder->Reserve(Size()); - auto begin = - extensions::protobuf_internal::MapBegin(*reflection, *message_, *field_); - const auto end = - extensions::protobuf_internal::MapEnd(*reflection, *message_, *field_); + auto begin = extensions::protobuf_internal::ConstMapBegin(*reflection, + *message_, *field_); + const auto end = extensions::protobuf_internal::ConstMapEnd( + *reflection, *message_, *field_); for (; begin != end; ++begin) { Value scratch; (*key_accessor)(begin.GetKey(), message_, arena, &scratch); @@ -446,10 +446,10 @@ absl::Status ParsedMapFieldValue::ForEach( CEL_ASSIGN_OR_RETURN( auto value_accessor, common_internal::MapFieldValueAccessorFor(value_field)); - auto begin = extensions::protobuf_internal::MapBegin(*reflection, *message_, - *field_); - const auto end = - extensions::protobuf_internal::MapEnd(*reflection, *message_, *field_); + auto begin = extensions::protobuf_internal::ConstMapBegin( + *reflection, *message_, *field_); + const auto end = extensions::protobuf_internal::ConstMapEnd( + *reflection, *message_, *field_); Value key_scratch; Value value_scratch; for (; begin != end; ++begin) { @@ -479,10 +479,10 @@ class ParsedMapFieldValueIterator final : public ValueIterator { value_field_(field->message_type()->map_value()), key_accessor_(key_accessor), value_accessor_(value_accessor), - begin_(extensions::protobuf_internal::MapBegin( + begin_(extensions::protobuf_internal::ConstMapBegin( *message_->GetReflection(), *message_, *field)), - end_(extensions::protobuf_internal::MapEnd(*message_->GetReflection(), - *message_, *field)) {} + end_(extensions::protobuf_internal::ConstMapEnd( + *message_->GetReflection(), *message_, *field)) {} bool HasNext() override { return begin_ != end_; } @@ -545,8 +545,8 @@ class ParsedMapFieldValueIterator final : public ValueIterator { const google::protobuf::FieldDescriptor* absl_nonnull const value_field_; const absl_nonnull common_internal::MapFieldKeyAccessor key_accessor_; const absl_nonnull common_internal::MapFieldValueAccessor value_accessor_; - google::protobuf::MapIterator begin_; - const google::protobuf::MapIterator end_; + google::protobuf::ConstMapIterator begin_; + const google::protobuf::ConstMapIterator end_; }; } // namespace diff --git a/common/values/parsed_map_field_value.h b/common/values/parsed_map_field_value.h index 3478f75bc..21d686bfd 100644 --- a/common/values/parsed_map_field_value.h +++ b/common/values/parsed_map_field_value.h @@ -117,7 +117,7 @@ class ParsedMapFieldValue final size_t Size() const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -126,7 +126,7 @@ class ParsedMapFieldValue final Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, @@ -135,7 +135,7 @@ class ParsedMapFieldValue final google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -144,7 +144,7 @@ class ParsedMapFieldValue final Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -152,11 +152,11 @@ class ParsedMapFieldValue final google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; - // See the corresponding type declaration of `MapValueInterface` for + // See the corresponding type declaration of `MapValue` for // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ForEach( ForEachCallback callback, diff --git a/common/values/struct_value_builder.cc b/common/values/struct_value_builder.cc index 359596267..446b18421 100644 --- a/common/values/struct_value_builder.cc +++ b/common/values/struct_value_builder.cc @@ -812,6 +812,17 @@ ProtoMessageRepeatedFieldFromValueMutator( const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + // If the value is null and the target repeated field is anything except + // google.protobuf.{Any,ListValue,Struct,Value}, it should be pruned. + if (value.IsNull()) { + const auto well_known_type = field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY) { + return absl::nullopt; + } + } auto* element = reflection->AddMessage(message, field, factory); auto result = ProtoMessageFromValueImpl(value, pool, factory, well_known_types, element); @@ -945,6 +956,19 @@ class MessageValueBuilderImpl { if (error_value) { return false; } + if (map_value_field->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE && + entry_value.IsNull()) { + auto well_known_type = + map_value_field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + return true; + } + } google::protobuf::MapValueRef proto_value; extensions::protobuf_internal::InsertOrLookupMapValue( *reflection_, message_, *field, proto_key, &proto_value); @@ -978,6 +1002,16 @@ class MessageValueBuilderImpl { CEL_RETURN_IF_ERROR(list_value->ForEach( [this, field, accessor, &error_value](const Value& element) -> absl::StatusOr { + if (field->message_type() != nullptr && element.IsNull()) { + auto well_known_type = field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + return true; + } + } CEL_ASSIGN_OR_RETURN(error_value, (*accessor)(descriptor_pool_, message_factory_, &well_known_types_, reflection_, diff --git a/common/values/values.h b/common/values/values.h index c9703dcbb..aaa6f8659 100644 --- a/common/values/values.h +++ b/common/values/values.h @@ -48,7 +48,6 @@ namespace cel { class ValueInterface; class ListValueInterface; -class MapValueInterface; class StructValueInterface; class Value; diff --git a/compiler/BUILD b/compiler/BUILD index 02bbb37dd..d4a0ab4ac 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -27,9 +27,12 @@ cc_library( "//checker:validation_result", "//parser:options", "//parser:parser_interface", + "//validator", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", ], ) @@ -39,15 +42,18 @@ cc_library( hdrs = ["compiler_factory.h"], deps = [ ":compiler", + "//checker:type_check_issue", "//checker:type_checker", "//checker:type_checker_builder", "//checker:type_checker_builder_factory", "//checker:validation_result", + "//common:ast", "//common:source", "//internal:noop_delete", "//internal:status_macros", "//parser", "//parser:parser_interface", + "//validator", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -64,6 +70,7 @@ cc_test( deps = [ ":compiler", ":compiler_factory", + ":optional", ":standard_library", "//checker:optional", "//checker:standard_library", @@ -78,6 +85,7 @@ cc_test( "//parser:macro", "//parser:parser_interface", "//testutil:baseline_tests", + "//validator:timestamp_literal_validator", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", diff --git a/compiler/compiler.h b/compiler/compiler.h index 8b867cd60..27237df60 100644 --- a/compiler/compiler.h +++ b/compiler/compiler.h @@ -19,6 +19,7 @@ #include #include +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -28,6 +29,8 @@ #include "checker/validation_result.h" #include "parser/options.h" #include "parser/parser_interface.h" +#include "validator/validator.h" +#include "google/protobuf/arena.h" namespace cel { @@ -94,12 +97,14 @@ struct CompilerLibrarySubset { struct CompilerOptions { ParserOptions parser_options; CheckerOptions checker_options; + // If true, parse errors will be adapted to issues where possible. + bool adapt_parser_errors = false; }; // Interface for CEL CompilerBuilder objects. // -// Builder implementations are thread hostile, but should create -// thread-compatible Compiler instances. +// Builder implementations do not provide any synchronization themselves, +// but create thread-compatible Compiler instances. class CompilerBuilder { public: virtual ~CompilerBuilder() = default; @@ -109,6 +114,7 @@ class CompilerBuilder { virtual TypeCheckerBuilder& GetCheckerBuilder() = 0; virtual ParserBuilder& GetParserBuilder() = 0; + virtual Validator& GetValidator() = 0; virtual absl::StatusOr> Build() = 0; }; @@ -124,10 +130,16 @@ class Compiler { virtual ~Compiler() = default; virtual absl::StatusOr Compile( - absl::string_view source, absl::string_view description) const = 0; + absl::string_view source, absl::string_view description, + google::protobuf::Arena* absl_nullable arena) const = 0; absl::StatusOr Compile(absl::string_view source) const { - return Compile(source, ""); + return Compile(source, "", nullptr); + } + + absl::StatusOr Compile( + absl::string_view source, absl::string_view description) const { + return Compile(source, description, nullptr); } // Accessor for the underlying type checker. @@ -135,6 +147,18 @@ class Compiler { // Accessor for the underlying parser. virtual const Parser& GetParser() const = 0; + + // Accessor for the underlying validator. + virtual const Validator& GetValidator() const = 0; + + // Returns a builder initialized with the configuration of this compiler. + // + // The returned builder is a copy of the validated environment and may + // behave differently than the builder that created this compiler. + // + // The returned builder does not share state with the compiler and may be + // modified independently. + virtual std::unique_ptr ToBuilder() const = 0; }; } // namespace cel diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc index 6530dd816..ed22c5630 100644 --- a/compiler/compiler_factory.cc +++ b/compiler/compiler_factory.cc @@ -17,21 +17,26 @@ #include #include #include +#include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" #include "checker/type_checker_builder_factory.h" #include "checker/validation_result.h" +#include "common/ast.h" #include "common/source.h" #include "compiler/compiler.h" #include "internal/status_macros.h" #include "parser/parser.h" #include "parser/parser_interface.h" +#include "validator/validator.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { @@ -41,36 +46,70 @@ namespace { class CompilerImpl : public Compiler { public: CompilerImpl(std::unique_ptr type_checker, - std::unique_ptr parser) - : type_checker_(std::move(type_checker)), parser_(std::move(parser)) {} + std::unique_ptr parser, + // Copy the validator in case builder is reused. + Validator validator, CompilerOptions options) + : type_checker_(std::move(type_checker)), + parser_(std::move(parser)), + validator_(std::move(validator)), + options_(options) {} absl::StatusOr Compile( - absl::string_view expression, - absl::string_view description) const override { + absl::string_view expression, absl::string_view description, + google::protobuf::Arena* arena) const override { CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expression, std::string(description))); - CEL_ASSIGN_OR_RETURN(auto ast, parser_->Parse(*source)); + std::vector parse_issues; + absl::StatusOr> ast = + parser_->Parse(*source, &parse_issues); + if (!ast.ok()) { + if (!options_.adapt_parser_errors || + ast.status().code() != absl::StatusCode::kInvalidArgument || + parse_issues.empty()) { + return ast.status(); + } + std::vector check_issues; + check_issues.reserve(parse_issues.size()); + for (const auto& issue : parse_issues) { + check_issues.push_back(TypeCheckIssue::CreateError( + issue.location(), std::string(issue.message()))); + } + ValidationResult result(std::move(check_issues)); + result.SetSource(std::move(source)); + return result; + } CEL_ASSIGN_OR_RETURN(ValidationResult result, - type_checker_->Check(std::move(ast))); + type_checker_->Check(*std::move(ast), arena)); result.SetSource(std::move(source)); + if (!validator_.validations().empty()) { + validator_.UpdateValidationResult(result); + } return result; } + std::unique_ptr ToBuilder() const override; + const TypeChecker& GetTypeChecker() const override { return *type_checker_; } const Parser& GetParser() const override { return *parser_; } + const Validator& GetValidator() const override { return validator_; } private: std::unique_ptr type_checker_; std::unique_ptr parser_; + Validator validator_; + CompilerOptions options_; }; class CompilerBuilderImpl : public CompilerBuilder { public: CompilerBuilderImpl(std::unique_ptr type_checker_builder, - std::unique_ptr parser_builder) + std::unique_ptr parser_builder, + Validator validator, CompilerOptions options) : type_checker_builder_(std::move(type_checker_builder)), - parser_builder_(std::move(parser_builder)) {} + parser_builder_(std::move(parser_builder)), + validator_(std::move(validator)), + options_(options) {} absl::Status AddLibrary(CompilerLibrary library) override { if (!library.id.empty()) { @@ -126,22 +165,30 @@ class CompilerBuilderImpl : public CompilerBuilder { TypeCheckerBuilder& GetCheckerBuilder() override { return *type_checker_builder_; } + Validator& GetValidator() override { return validator_; } absl::StatusOr> Build() override { CEL_ASSIGN_OR_RETURN(auto parser, parser_builder_->Build()); CEL_ASSIGN_OR_RETURN(auto type_checker, type_checker_builder_->Build()); - return std::make_unique(std::move(type_checker), - std::move(parser)); + return std::make_unique( + std::move(type_checker), std::move(parser), validator_, options_); } private: std::unique_ptr type_checker_builder_; std::unique_ptr parser_builder_; + Validator validator_; + CompilerOptions options_; absl::flat_hash_set library_ids_; absl::flat_hash_set subsets_; }; +std::unique_ptr CompilerImpl::ToBuilder() const { + return std::make_unique( + type_checker_->ToBuilder(), parser_->ToBuilder(), validator_, options_); +} + } // namespace absl::StatusOr> NewCompilerBuilder( @@ -156,7 +203,8 @@ absl::StatusOr> NewCompilerBuilder( auto parser_builder = NewParserBuilder(options.parser_options); return std::make_unique(std::move(type_checker_builder), - std::move(parser_builder)); + std::move(parser_builder), + Validator(), options); } } // namespace cel diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc index 5df0f4794..035fd8aa6 100644 --- a/compiler/compiler_factory_test.cc +++ b/compiler/compiler_factory_test.cc @@ -29,12 +29,15 @@ #include "common/source.h" #include "common/type.h" #include "compiler/compiler.h" +#include "compiler/optional.h" #include "compiler/standard_library.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "parser/macro.h" #include "parser/parser_interface.h" #include "testutil/baseline_tests.h" +#include "validator/timestamp_literal_validator.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { @@ -287,6 +290,23 @@ TEST(CompilerFactoryTest, DisableStandardMacrosWithStdlib) { EXPECT_TRUE(result.IsValid()); } +TEST(CompilerFactoryTest, AddValidator) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + builder->GetValidator().AddValidation(TimestampLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("timestamp('invalid')")); + EXPECT_FALSE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(result, + compiler->Compile("timestamp('2024-01-01T00:00:00Z')")); + EXPECT_TRUE(result.IsValid()); +} + TEST(CompilerFactoryTest, FailsIfLibraryAddedTwice) { ASSERT_OK_AND_ASSIGN( auto builder, @@ -346,5 +366,66 @@ TEST(CompilerFactoryTest, FailsIfNullDescriptorPool) { HasSubstr("descriptor_pool must not be null"))); } +TEST(CompilerFactoryTest, ToBuilderWorks) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("a", MapType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + auto derived_builder = compiler->ToBuilder(); + + ASSERT_THAT(derived_builder->AddLibrary(OptionalCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto derived_compiler, derived_builder->Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult result, + derived_compiler->Compile("has(a.b) && a.?b.orValue('foo') == 'foo'")); + EXPECT_TRUE(result.IsValid()); +} + +TEST(CompilerFactoryTest, SpecifyArenaKeepsResolvedTypes) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(OptionalCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("[[1, 2, 3]][?0]", "", &arena)); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + auto it = result.GetResolvedTypeMap().find(ast->root_expr().id()); + ASSERT_TRUE(it != result.GetResolvedTypeMap().end()); + EXPECT_TRUE( + it->second.IsOptional() && + it->second.GetOptional().GetParameter().IsList() && + it->second.GetOptional().GetParameter().GetList().GetElement().IsInt()); +} + +TEST(CompilerFactoryTest, ReturnsIssuesFromParser) { + CompilerOptions opts; + opts.adapt_parser_errors = true; + ASSERT_OK_AND_ASSIGN( + auto builder, NewCompilerBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("a +")); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), testing::Not(testing::IsEmpty())); +} + } // namespace } // namespace cel diff --git a/conformance/BUILD b/conformance/BUILD index ba485f36d..35d554c7b 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -32,7 +32,6 @@ cc_library( "//common:ast", "//common:ast_proto", "//common:decl_proto_v1alpha1", - "//common:expr", "//common:source", "//common:value", "//common/internal:value_conversion", @@ -57,8 +56,6 @@ cc_library( "//extensions/protobuf:enum_adapter", "//internal:status_macros", "//parser", - "//parser:macro", - "//parser:macro_expr_factory", "//parser:macro_registry", "//parser:options", "//parser:standard_macros", @@ -69,19 +66,19 @@ cc_library( "//runtime:reference_resolver", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", + "//testutil:test_macros", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/rpc:code_cc_proto", + "@com_google_googleapis//google/rpc:status_cc_proto", "@com_google_protobuf//:duration_cc_proto", "@com_google_protobuf//:empty_cc_proto", "@com_google_protobuf//:protobuf", @@ -97,6 +94,7 @@ cc_library( deps = [ ":service", ":utils", + "//internal:runfiles", "//internal:testing_no_main", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log:absl_check", @@ -164,7 +162,7 @@ _ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/type_deduction.textproto", ] -_TESTS_TO_SKIP_MODERN = [ +_TESTS_TO_SKIP = [ # Tests which require spec changes. # TODO(issues/93): Deprecate Duration.getMilliseconds. "timestamps/duration_converters/get_milliseconds", @@ -197,45 +195,25 @@ _TESTS_TO_SKIP_MODERN = [ "timestamps/timestamp_selectors_tz/getDayOfMonth_name_pos", "timestamps/timestamp_selectors_tz/getDayOfYear", # These depend on using charconv (or equivalent) to format doubles with shortest possible - # precision to preserve value. Not available on older compilers. + # precision to preserve value. Not available on older compilers where we just use absl::Format. + # We should probably update the spec to allow different formats that parse to the same value. "conversions/string/double_hard", -] -_TESTS_TO_SKIP_MODERN_DASHBOARD = [ - # Future features for CEL 1.0 - # TODO(issues/119): Strong typing support for enums, specified but not implemented. - "enums/strong_proto2", - "enums/strong_proto3", + # Recent changes + "namespace/namespace_shadowing/basic", + "namespace/namespace_shadowing/comprehension_shadowing_namespaced_selector_disambiguation", ] -_TESTS_TO_SKIP_LEGACY = [ - # Tests which require spec changes. - # TODO(issues/93): Deprecate Duration.getMilliseconds. - "timestamps/duration_converters/get_milliseconds", - - # Broken test cases which should be supported. - # TODO(issues/112): Unbound functions result in empty eval response. - "basic/functions/unbound", - "basic/functions/unbound_is_runtime_error", - - # TODO(issues/97): Parse-only qualified variable lookup "x.y" with binding "x.y" or "y" within container "x" fails - "fields/qualified_identifier_resolution/qualified_ident,map_field_select,ident_with_longest_prefix_check,qualified_identifier_resolution_unchecked", - "namespace/qualified/self_eval_qualified_lookup", - "namespace/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", - # TODO(issues/117): Integer overflow on enum assignments should error. - "enums/legacy_proto2/select_big,select_neg", - - # Skip until fixed. - "wrappers/field_mask/to_json", - "wrappers/empty/to_json", - "fields/qualified_identifier_resolution/map_value_repeat_key_heterogeneous", - "parse/receiver_function_names", +_TESTS_TO_SKIP_MODERN = _TESTS_TO_SKIP +_TESTS_TO_SKIP_MODERN_DASHBOARD = [ # Future features for CEL 1.0 # TODO(issues/119): Strong typing support for enums, specified but not implemented. "enums/strong_proto2", "enums/strong_proto3", +] +_TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [ # Legacy value does not support optional_type. "optionals/optionals", @@ -245,15 +223,7 @@ _TESTS_TO_SKIP_LEGACY = [ "proto3/set_null/list_value", "proto3/set_null/single_struct", - # These depend on legacy US/ timezones. It's spotty if these are included with a normally - # configured timezone database. - "timestamps/timestamp_selectors_tz/getDayOfMonth_name_pos", - "timestamps/timestamp_selectors_tz/getDayOfYear", - # These depend on using charconv (or equivalent) to format doubles with shortest possible - # precision to preserve value. Not available on older compilers. - "conversions/string/double_hard", - - # cel.@block + # no optional support for legacy types "block_ext/basic/optional_list", "block_ext/basic/optional_map", "block_ext/basic/optional_map_chained", @@ -263,7 +233,7 @@ _TESTS_TO_SKIP_LEGACY = [ _TESTS_TO_SKIP_CHECKED = [ # block is a post-check optimization that inserts internal variables. The C++ type checker # needs support for a proper optimizer for this to work. - "block_ext", + # "block_ext", ] _TESTS_TO_SKIP_LEGACY_DASHBOARD = [ @@ -327,6 +297,24 @@ gen_conformance_tests( skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED, ) +gen_conformance_tests( + name = "conformance_variadic", + checked = True, + data = _ALL_TESTS, + enable_variadic_logical_operators = True, + modern = True, + skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED, +) + +gen_conformance_tests( + name = "conformance_legacy_variadic", + checked = True, + data = _ALL_TESTS, + enable_variadic_logical_operators = True, + modern = False, + skip_tests = _TESTS_TO_SKIP_LEGACY + _TESTS_TO_SKIP_CHECKED, +) + # Generates a bunch of `cc_test` whose names follow the pattern # `conformance_dashboard_..._{arena|refcount}_{optimized|unoptimized}_{recursive|iterative}`. gen_conformance_tests( diff --git a/conformance/policy/BUILD b/conformance/policy/BUILD new file mode 100644 index 000000000..29210e02d --- /dev/null +++ b/conformance/policy/BUILD @@ -0,0 +1,78 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load( + "//conformance/policy:policy_conformance_test.bzl", + "cel_policy_conformance_test", +) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "policy_conformance_test_lib", + testonly = True, + srcs = ["policy_conformance_test.cc"], + deps = [ + "//common:ast", + "//common:source", + "//common:value", + "//common/internal:value_conversion", + "//compiler", + "//env", + "//env:config", + "//env:env_runtime", + "//env:env_std_extensions", + "//env:env_yaml", + "//env:runtime_std_extensions", + "//extensions/protobuf:bind_proto_to_activation", + "//extensions/protobuf:enum_adapter", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing_descriptor_pool", + "//internal:testing_no_main", + "//policy:cel_policy", + "//policy:cel_policy_parser", + "//policy:cel_policy_validation_result", + "//policy:compiler", + "//policy:test_util", + "//policy:yaml_policy_parser", + "//runtime", + "//runtime:activation", + "//runtime:function_adapter", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:protobuf", + ], + alwayslink = True, +) + +cel_policy_conformance_test( + name = "policy_conformance_test", + example = "@cel_policy//conformance:testdata/nested_rule/policy.yaml", + skip_tests = [ + # TODO(b/506179116): Fix these. + # Need to add k8s custom yaml parser and mock runtime. + "k8s", + ], + test_files = [ + "@cel_policy//conformance:testdata", + ], +) diff --git a/conformance/policy/policy_conformance_test.bzl b/conformance/policy/policy_conformance_test.bzl new file mode 100644 index 000000000..0b4d1a4c6 --- /dev/null +++ b/conformance/policy/policy_conformance_test.bzl @@ -0,0 +1,46 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains build rules for generating policy conformance test targets. +""" + +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +def cel_policy_conformance_test(name, test_files, example, skip_tests = [], **kwargs): + """Generates a policy conformance test target. + + Args: + name: Name of the test target. + test_files: List of targets or files representing the test data. + example: A specific example file from test_files used for runfiles resolution. + skip_tests: List of test cases to skip. + testdata_dir: Path to testdata directory under runfiles. + **kwargs: Additional arguments passed to the underlying cc_test. + """ + args = ["--gunit_fail_if_no_test_linked"] + args.append("--testdata_example='$(rlocationpath {})'".format(example)) + + if skip_tests: + args.append("--skip_tests=" + ",".join(skip_tests)) + + cc_test( + name = name, + data = test_files + [example], + deps = [ + "//conformance/policy:policy_conformance_test_lib", + ], + args = args, + **kwargs + ) diff --git a/conformance/policy/policy_conformance_test.cc b/conformance/policy/policy_conformance_test.cc new file mode 100644 index 000000000..0d68f8abf --- /dev/null +++ b/conformance/policy/policy_conformance_test.cc @@ -0,0 +1,659 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +// NOLINTNEXTLINE(build/c++17) for OSS compatibility +#include + +#include "cel/expr/eval.pb.h" +#include "absl/flags/flag.h" +#include "absl/log/absl_check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/internal/value_conversion.h" +#include "common/source.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_runtime.h" +#include "env/env_std_extensions.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "extensions/protobuf/bind_proto_to_activation.h" +#include "extensions/protobuf/enum_adapter.h" +#include "internal/runfiles.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/compiler.h" +#include "policy/test_util.h" +#include "policy/yaml_policy_parser.h" +#include "runtime/activation.h" +#include "runtime/function_adapter.h" +#include "runtime/runtime.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +// Use a specific file to handle bazel runfiles resolution correctly. We find +// parent directory named 'testdata' to use as the root of the test cases. +ABSL_FLAG(std::string, testdata_example, "", + "Path to a specific example file."); +ABSL_FLAG(std::vector, skip_tests, {}, + "Comma-separated list of tests to skip."); + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::cel::expr::conformance::test::TestSuite; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::HasSubstr; + +// Implementations for extension functions referenced in conformance tests. +cel::Value LocationCode(const cel::StringValue& ip, + const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* factory, google::protobuf::Arena* arena) { + std::string ip_str = ip.ToString(); + if (ip_str == "10.0.0.1") return cel::StringValue(arena, "us"); + if (ip_str == "10.0.0.2") return cel::StringValue(arena, "de"); + return cel::StringValue(arena, "ir"); +} + +// TODO(uncreated-issue/92): This should be migrated to use the testrunner utility +// after adding support for reading the yaml specification for envs/tests. +class InputEvaluator { + public: + static absl::StatusOr> Create( + const std::shared_ptr& pool) { + cel::Env env; + env.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.mutable_runtime_options().enable_qualified_type_identifiers = + true; + + // Enable default extensions (optional, bindings) + cel::Config config; + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "optional", cel::Config::ExtensionConfig::kLatest)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "bindings", cel::Config::ExtensionConfig::kLatest)); + env.SetConfig(config); + env_runtime.SetConfig(config); + + auto compiler_builder_or = env.NewCompilerBuilder(); + CEL_ASSIGN_OR_RETURN(auto compiler_builder, std::move(compiler_builder_or)); + compiler_builder->GetParserBuilder().GetOptions().enable_optional_syntax = + true; + CEL_ASSIGN_OR_RETURN(auto compiler, compiler_builder->Build()); + + auto runtime_builder_or = env_runtime.CreateRuntimeBuilder(); + CEL_ASSIGN_OR_RETURN(auto runtime_builder, std::move(runtime_builder_or)); + + // Register conformance enums + for (const auto& enum_name : + {"cel.expr.conformance.proto2.GlobalEnum", + "cel.expr.conformance.proto3.GlobalEnum", + "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"}) { + auto* enum_desc = pool->FindEnumTypeByName(enum_name); + if (enum_desc != nullptr) { + CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtobufEnum( + runtime_builder.type_registry(), enum_desc)); + } + } + + CEL_ASSIGN_OR_RETURN(auto runtime, std::move(runtime_builder).Build()); + + return absl::WrapUnique( + new InputEvaluator(std::move(compiler), std::move(runtime))); + } + + absl::StatusOr Evaluate( + absl::string_view expr_str, google::protobuf::Arena* arena, + google::protobuf::MessageFactory* message_factory) const { + CEL_ASSIGN_OR_RETURN(auto validation_result, compiler_->Compile(expr_str)); + if (!validation_result.IsValid()) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to compile input expr: ", expr_str)); + } + CEL_ASSIGN_OR_RETURN(auto ast, validation_result.ReleaseAst()); + CEL_ASSIGN_OR_RETURN( + auto program, + runtime_->CreateProgram(std::make_unique(std::move(*ast)))); + cel::Activation activation; + EvaluateOptions options; + options.message_factory = message_factory; + return program->Evaluate(arena, activation, options); + } + + private: + InputEvaluator(std::unique_ptr compiler, + std::unique_ptr runtime) + : compiler_(std::move(compiler)), runtime_(std::move(runtime)) {} + + std::unique_ptr compiler_; + std::unique_ptr runtime_; +}; + +absl::StatusOr EvaluateInputValue( + const cel::expr::conformance::test::InputValue& input_val, + const InputEvaluator& evaluator, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, google::protobuf::Arena* arena) { + if (input_val.has_expr()) { + return evaluator.Evaluate(input_val.expr(), arena, message_factory); + } + if (input_val.has_value()) { + return cel::test::FromExprValue(input_val.value(), descriptor_pool, + message_factory, arena); + } + return absl::InvalidArgumentError("Empty InputValue"); +} + +class CelValueMatcherImpl + : public testing::MatcherInterface { + public: + CelValueMatcherImpl(cel::Value expected_val, + const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* message_factory, + google::protobuf::Arena* arena) + : expected_val_(std::move(expected_val)), + pool_(pool), + message_factory_(message_factory), + arena_(arena) {} + + bool MatchAndExplain(const cel::Value& actual_val, + testing::MatchResultListener* listener) const override { + cel::Value actual = actual_val; + if (actual.IsOptional() && !expected_val_.IsOptional()) { + auto opt_val = actual.AsOptional(); + if (opt_val->HasValue()) { + actual = opt_val->Value(); + } + } + cel::Value eq_result; + auto eq_status = actual.Equal(expected_val_, pool_, message_factory_, + arena_, &eq_result); + if (!eq_status.ok()) { + *listener << "equality check failed with status: " << eq_status; + return false; + } + if (!eq_result.IsTrue()) { + *listener << "expected: " << expected_val_.DebugString() + << "\nactual: " << actual.DebugString(); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os) const override { + *os << "is equal to " << expected_val_.DebugString(); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "is not equal to " << expected_val_.DebugString(); + } + + private: + cel::Value expected_val_; + const google::protobuf::DescriptorPool* pool_; + google::protobuf::MessageFactory* message_factory_; + google::protobuf::Arena* arena_; +}; + +absl::StatusOr> MakeExpectedValueMatcher( + const cel::expr::conformance::test::TestOutput& output, + const InputEvaluator& input_evaluator, const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* message_factory, google::protobuf::Arena* arena) { + cel::Value expected_val; + if (output.has_result_expr()) { + CEL_ASSIGN_OR_RETURN( + expected_val, + input_evaluator.Evaluate(output.result_expr(), arena, message_factory)); + } else if (output.has_result_value()) { + CEL_ASSIGN_OR_RETURN(expected_val, + cel::test::FromExprValue(output.result_value(), pool, + message_factory, arena)); + } else { + return absl::InvalidArgumentError("Unsupported output kind"); + } + return testing::Matcher( + new CelValueMatcherImpl(expected_val, pool, message_factory, arena)); +} + +bool ShouldRunTest(absl::string_view test_name, + const std::vector& skip_tests) { + for (const std::string& skip : skip_tests) { + if (absl::StartsWith(test_name, skip)) { + return false; + } + } + return true; +} + +absl::Status PopulateActivation( + const cel::expr::conformance::test::TestCase& test, + const InputEvaluator& input_evaluator, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, + absl::string_view context_msg_type_name, google::protobuf::Arena* arena, + Activation& activation) { + if (!test.has_input_context()) { + for (const auto& [var_name, input_val] : test.input()) { + CEL_ASSIGN_OR_RETURN( + auto val, + EvaluateInputValue(input_val, input_evaluator, descriptor_pool, + message_factory, arena)); + activation.InsertOrAssignValue(var_name, std::move(val)); + } + return absl::OkStatus(); + } + + const auto& input_context = test.input_context(); + const google::protobuf::Message* context_message = nullptr; + + if (input_context.has_context_message()) { + const google::protobuf::Any& any_msg = input_context.context_message(); + const google::protobuf::Descriptor* msg_descriptor = + descriptor_pool->FindMessageTypeByName(context_msg_type_name); + if (msg_descriptor == nullptr) { + return absl::NotFoundError(absl::StrCat( + "Failed to find message descriptor for: ", context_msg_type_name)); + } + const google::protobuf::Message* prototype = + message_factory->GetPrototype(msg_descriptor); + if (prototype == nullptr) { + return absl::NotFoundError( + absl::StrCat("Failed to get prototype for: ", context_msg_type_name)); + } + auto* buf = prototype->New(arena); + if (!any_msg.UnpackTo(buf)) { + return absl::InvalidArgumentError(absl::StrCat( + "Failed to unpack context message to ", context_msg_type_name)); + } + context_message = buf; + } else if (input_context.has_context_expr() && + !context_msg_type_name.empty()) { + CEL_ASSIGN_OR_RETURN(cel::Value evaluated_val, + input_evaluator.Evaluate(input_context.context_expr(), + arena, message_factory)); + + if (!evaluated_val.IsParsedMessage()) { + return absl::InvalidArgumentError( + absl::StrCat("Context expression did not evaluate to a message: ", + input_context.context_expr())); + } + if (evaluated_val.GetParsedMessage().GetDescriptor()->full_name() != + context_msg_type_name) { + return absl::InvalidArgumentError(absl::StrCat( + "Context expression evaluated to a message of type ", + evaluated_val.GetParsedMessage().GetDescriptor()->full_name(), + " which does not match the expected type ", context_msg_type_name)); + } + context_message = static_cast( + evaluated_val.GetParsedMessage().operator->()); + } + if (context_message == nullptr) { + return absl::InvalidArgumentError( + "Failed to resolve context message for test case"); + } + + return cel::extensions::BindProtoToActivation( + *context_message, + cel::extensions::BindProtoUnsetFieldBehavior::kBindDefaultValue, + descriptor_pool, message_factory, arena, &activation); +} + +class PolicyTestSuiteRunner { + public: + PolicyTestSuiteRunner(std::string suite_name, + std::unique_ptr compiler, + std::unique_ptr runtime, + std::shared_ptr policy_source, + CelPolicyValidationResult compile_result, + std::shared_ptr pool, + std::shared_ptr message_factory, + std::shared_ptr input_evaluator, + std::string context_msg_type_name, + bool expect_compile_fail = false) + : suite_name_(std::move(suite_name)), + compiler_(std::move(compiler)), + runtime_(std::move(runtime)), + policy_source_(std::move(policy_source)), + compile_result_(std::move(compile_result)), + pool_(std::move(pool)), + message_factory_(std::move(message_factory)), + input_evaluator_(std::move(input_evaluator)), + context_msg_type_name_(std::move(context_msg_type_name)), + expect_compile_fail_(expect_compile_fail) {} + + void RunTest(const cel::expr::conformance::test::TestCase& test, + absl::string_view full_test_name) { + const auto& output = test.output(); + + if (expect_compile_fail_) { + ASSERT_FALSE(compile_result_.IsValid()) + << "Expected compilation to fail in " << full_test_name; + ASSERT_TRUE(output.has_eval_error()) + << "Expected eval_error to be present in compile error test " + << full_test_name; + std::string err_msg = compile_result_.FormatIssues(); + for (const auto& expected_err : output.eval_error().errors()) { + EXPECT_THAT(err_msg, HasSubstr(expected_err.message())) + << "Did not find expected compile time error"; + } + return; + } + + // Compilation should have succeeded for evaluation tests + ASSERT_TRUE(compile_result_.IsValid()) + << "Compilation has validation errors in " << full_test_name << ": " + << compile_result_.FormatIssues(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime_->CreateProgram(std::make_unique( + *compile_result_.GetAst()))); + + // Parse Inputs and evaluate them + google::protobuf::Arena arena; + Activation activation; + ASSERT_THAT(PopulateActivation(test, *input_evaluator_, pool_.get(), + message_factory_.get(), + context_msg_type_name_, &arena, activation), + IsOk()); + + // Evaluate Policy + auto eval_result_or = program->Evaluate(&arena, activation); + ASSERT_THAT(eval_result_or.status(), IsOk()) + << "Evaluation failed in " << full_test_name; + cel::Value actual_val = *eval_result_or; + + ASSERT_OK_AND_ASSIGN( + auto matcher, + MakeExpectedValueMatcher(output, *input_evaluator_, pool_.get(), + message_factory_.get(), &arena)); + + // Apply matcher to the output of evaluation + EXPECT_THAT(actual_val, matcher) << "Test failed: " << full_test_name; + } + + private: + std::string suite_name_; + std::unique_ptr compiler_; + std::unique_ptr runtime_; + std::shared_ptr policy_source_; + CelPolicyValidationResult compile_result_; + std::shared_ptr pool_; + std::shared_ptr message_factory_; + std::shared_ptr input_evaluator_; + std::string context_msg_type_name_; + bool expect_compile_fail_; +}; + +class CelPolicyTest : public testing::Test { + public: + explicit CelPolicyTest(std::shared_ptr runner, + cel::expr::conformance::test::TestCase test_case, + std::string full_test_name, bool skip) + : runner_(std::move(runner)), + test_case_(std::move(test_case)), + full_test_name_(std::move(full_test_name)), + skip_(skip) {} + + void TestBody() override { + if (skip_) { + GTEST_SKIP() << "Skipping test: " << full_test_name_; + } + EXPECT_NO_FATAL_FAILURE(runner_->RunTest(test_case_, full_test_name_)); + } + + private: + std::shared_ptr runner_; + cel::expr::conformance::test::TestCase test_case_; + std::string full_test_name_; + bool skip_; +}; + + +absl::Status RegisterTestSuite( + const std::filesystem::path& dir_path, const std::string& suite_name, + const std::shared_ptr& input_evaluator, + const std::shared_ptr& pool, + const std::shared_ptr& message_factory, + const std::vector& skip_tests) { + // Check if the entire suite should be skipped (prefix match) + for (const auto& skip : skip_tests) { + if (suite_name == skip || + absl::StartsWith(suite_name, absl::StrCat(skip, "/"))) { + std::cout << "[ SKIPPED SUITE ] " << suite_name << std::endl; + return absl::OkStatus(); + } + } + + std::filesystem::path policy_path = dir_path / "policy.yaml"; + std::filesystem::path tests_path = dir_path / "tests.yaml"; + bool is_yaml = true; + if (!std::filesystem::exists(tests_path)) { + tests_path = dir_path / "tests.textproto"; + is_yaml = false; + } + std::filesystem::path config_path = dir_path / "config.yaml"; + + if (!std::filesystem::exists(policy_path) || + !std::filesystem::exists(tests_path)) { + // Not a valid test suite, assume it's a directory we don't care about. + return absl::OkStatus(); + } + + // Parse Environment Config + cel::Config config; + if (std::filesystem::exists(config_path)) { + std::string config_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(config_path.string(), &config_content)); + CEL_ASSIGN_OR_RETURN(config, cel::EnvConfigFromYaml(config_content)); + } + + // Enable default extensions (optional, bindings) in the config + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "optional", cel::Config::ExtensionConfig::kLatest)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "bindings", cel::Config::ExtensionConfig::kLatest)); + + // Set up compiler & runtime environments + cel::Env env; + env.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env); + env.SetConfig(config); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + env_runtime.mutable_runtime_options().enable_qualified_type_identifiers = + true; + + CEL_ASSIGN_OR_RETURN(auto compiler_builder, env.NewCompilerBuilder()); + compiler_builder->GetParserBuilder().GetOptions().enable_optional_syntax = + true; + + CEL_ASSIGN_OR_RETURN(auto compiler, compiler_builder->Build()); + + CEL_ASSIGN_OR_RETURN(auto runtime_builder, + env_runtime.CreateRuntimeBuilder()); + + // Register conformance enums + for (const auto& enum_name : + {"cel.expr.conformance.proto2.GlobalEnum", + "cel.expr.conformance.proto3.GlobalEnum", + "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"}) { + auto* enum_desc = pool->FindEnumTypeByName(enum_name); + if (enum_desc != nullptr) { + CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtobufEnum( + runtime_builder.type_registry(), enum_desc)); + } + } + + // Register locationCode in runtime + CEL_RETURN_IF_ERROR( + (cel::UnaryFunctionAdapter:: + RegisterGlobalOverload("locationCode", LocationCode, + runtime_builder.function_registry()))); + + CEL_ASSIGN_OR_RETURN(auto runtime, std::move(runtime_builder).Build()); + + // Parse Policy + std::string policy_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(policy_path.string(), &policy_content)); + CEL_ASSIGN_OR_RETURN(auto source, + cel::NewSource(policy_content, "policy.yaml")); + auto policy_source = std::make_shared(std::move(source)); + CEL_ASSIGN_OR_RETURN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + if (!parse_result.IsValid()) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse policy.yaml in ", suite_name, + "\nIssues:\n", parse_result.FormattedIssues())); + } + const CelPolicy* policy = parse_result.GetPolicy(); + + // Compile Policy (unexpected non-ok status represents a bug) + CEL_ASSIGN_OR_RETURN(CelPolicyValidationResult compile_result, + CompilePolicy(*compiler, *policy)); + + std::string tests_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(tests_path.string(), &tests_content)); + TestSuite test_suite; + if (is_yaml) { + CEL_ASSIGN_OR_RETURN(test_suite, + cel::test::ParsePolicyTestSuiteYaml(tests_content)); + } else { + if (!google::protobuf::TextFormat::ParseFromString(tests_content, &test_suite)) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse text proto in ", tests_path.string())); + } + } + + auto runner = std::make_shared( + suite_name, std::move(compiler), std::move(runtime), + std::move(policy_source), std::move(compile_result), pool, + message_factory, input_evaluator, config.GetContextType(), + /*expect_compile_fail=*/absl::StrContains(suite_name, "compile_errors")); + + for (const auto& section : test_suite.sections()) { + std::string section_name = section.name(); + for (const auto& test : section.tests()) { + std::string test_name = test.name(); + std::string full_test_name = + absl::StrCat(suite_name, "/", section_name, "/", test_name); + + bool skip = !ShouldRunTest(full_test_name, skip_tests); + + testing::RegisterTest( + suite_name.c_str(), + absl::StrCat(section_name, "/", test_name).c_str(), nullptr, + test_name.c_str(), __FILE__, __LINE__, + [runner, test, full_test_name, skip]() -> CelPolicyTest* { + return new CelPolicyTest(runner, test, full_test_name, skip); + }); + } + } + return absl::OkStatus(); +} + +void RegisterAllTests() { + // cel::google3-end + std::string testdata_example_flag = absl::GetFlag(FLAGS_testdata_example); + std::vector skip_tests = absl::GetFlag(FLAGS_skip_tests); + + std::string abs_testdata_example = + cel::internal::ResolveRunfilesPath(testdata_example_flag); + ABSL_CHECK(!abs_testdata_example.empty()) + << "Could not find testdata directory: " << testdata_example_flag; + + std::shared_ptr pool = + GetSharedTestingDescriptorPool(); + auto message_factory = + std::make_shared(pool.get()); + message_factory->SetDelegateToGeneratedFactory(true); + auto evaluator_or = InputEvaluator::Create(pool); + ABSL_CHECK_OK(evaluator_or.status()) << "Failed to create input evaluator"; + std::shared_ptr evaluator = std::move(evaluator_or.value()); + + std::filesystem::path testdata_path(abs_testdata_example); + ABSL_CHECK(std::filesystem::exists(testdata_path)) + << "Testdata path does not exist: " << testdata_path; + // walk up to find 'testdata' parent. A work around to portably + // get the expected directory from bazel. + while (!absl::EndsWith(testdata_path.string(), "testdata")) { + testdata_path = testdata_path.parent_path(); + ABSL_CHECK(testdata_path.string().size() > sizeof("testdata")) + << "could not resolve testdata directory"; + } + + for (const auto& entry : + std::filesystem::recursive_directory_iterator(testdata_path)) { + if (!entry.is_directory()) { + continue; + } + std::filesystem::path dir_path = entry.path(); + // Check if this directory has policy.yaml and tests.yaml (or + // tests.textproto) + if (std::filesystem::exists(dir_path / "policy.yaml") && + (std::filesystem::exists(dir_path / "tests.yaml") || + std::filesystem::exists(dir_path / "tests.textproto"))) { + std::string suite_name = absl::StrReplaceAll( + std::filesystem::relative(dir_path, testdata_path).string(), + {{"\\", "/"}}); + + ABSL_CHECK_OK(RegisterTestSuite(dir_path, suite_name, evaluator, pool, + message_factory, skip_tests)); + } + } +} + +} // namespace +} // namespace cel + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + cel::RegisterAllTests(); + return RUN_ALL_TESTS(); +} diff --git a/conformance/run.bzl b/conformance/run.bzl index 4fcf325c6..8faeb6c16 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -56,7 +56,7 @@ def _conformance_test_name(name, optimize, recursive): ], ) -def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, skip_tests, dashboard): +def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators): args = [] if modern: args.append("--modern") @@ -70,18 +70,20 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, args.append("--skip_check") else: args.append("--noskip_check") - args.append("--skip_tests={}".format(",".join(_expand_tests_to_skip(skip_tests)))) if dashboard: args.append("--dashboard") + if enable_variadic_logical_operators: + args.append("--enable_variadic_logical_operators") return args -def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard): +def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard, enable_variadic_logical_operators): cc_test( name = _conformance_test_name(name, optimize, recursive), - args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, skip_tests, dashboard) + ["$(location " + test + ")" for test in data] + select( + args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators) + ["$(rlocationpath {})".format(test) for test in data], + env = select( { - "@platforms//os:windows": ["--skip_tests={}".format(",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS))], - "//conditions:default": ["--skip_tests={}".format(",".join(skip_tests))], + "@platforms//os:windows": {"CEL_SKIP_TESTS": ",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS)}, + "//conditions:default": {"CEL_SKIP_TESTS": ",".join(skip_tests)}, }, ), data = data, @@ -89,18 +91,20 @@ def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_ tags = tags, ) -def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = []): +def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = [], enable_variadic_logical_operators = False): """Generates conformance tests. Args: name: prefix for all tests + data: textproto targets describing conformance tests modern: run using modern APIs checked: whether to apply type checking - data: textproto targets describing conformance tests + select_opt: enable select optimization + dashboard: enable dashboard mode skip_tests: tests to skip in the format of the cel-spec test runner. See documentation in github.com/google/cel-spec/tests/simple/simple_test.go tags: tags added to the generated targets - dashboard: enable dashboard mode + enable_variadic_logical_operators: enable variadic logical operators """ skip_check = not checked tests = [] @@ -119,6 +123,7 @@ def gen_conformance_tests(name, data, modern = False, checked = False, select_op skip_tests = _expand_tests_to_skip(skip_tests), tags = tags, dashboard = dashboard, + enable_variadic_logical_operators = enable_variadic_logical_operators, ) native.test_suite( name = name, diff --git a/conformance/run.cc b/conformance/run.cc index d5a919d76..1be16ba60 100644 --- a/conformance/run.cc +++ b/conformance/run.cc @@ -42,11 +42,13 @@ #include "absl/strings/cord.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "absl/types/span.h" #include "conformance/service.h" #include "conformance/utils.h" +#include "internal/runfiles.h" #include "internal/testing.h" #include "cel/expr/conformance/test/simple.pb.h" #include "google/protobuf/io/zero_copy_stream_impl.h" @@ -64,11 +66,12 @@ ABSL_FLAG(std::vector, skip_tests, {}, "Tests to skip"); ABSL_FLAG(bool, dashboard, false, "Dashboard mode, ignore test failures"); ABSL_FLAG(bool, skip_check, true, "Skip type checking the expressions"); ABSL_FLAG(bool, select_optimization, false, "Enable select optimization."); +ABSL_FLAG(bool, enable_variadic_logical_operators, false, + "Enable parsing logical AND & OR operators as a single flat variadic " + "call."); namespace { -using ::testing::IsEmpty; - using cel::expr::conformance::test::SimpleTest; using cel::expr::conformance::test::SimpleTestFile; using google::api::expr::conformance::v1alpha1::CheckRequest; @@ -77,6 +80,7 @@ using google::api::expr::conformance::v1alpha1::EvalRequest; using google::api::expr::conformance::v1alpha1::EvalResponse; using google::api::expr::conformance::v1alpha1::ParseRequest; using google::api::expr::conformance::v1alpha1::ParseResponse; +using ::testing::IsEmpty; google::rpc::Code ToGrpcCode(absl::StatusCode code) { return static_cast(code); @@ -260,6 +264,8 @@ NewConformanceServiceFromFlags() { .modern = absl::GetFlag(FLAGS_modern), .recursive = absl::GetFlag(FLAGS_recursive), .select_optimization = absl::GetFlag(FLAGS_select_optimization), + .enable_variadic_logical_operators = + absl::GetFlag(FLAGS_enable_variadic_logical_operators), }); ABSL_CHECK_OK(status_or_service); return std::shared_ptr( @@ -273,9 +279,17 @@ int main(int argc, char** argv) { { auto service = NewConformanceServiceFromFlags(); auto tests_to_skip = absl::GetFlag(FLAGS_skip_tests); + if (const char* env_skip = std::getenv("CEL_SKIP_TESTS"); + env_skip != nullptr) { + for (absl::string_view test : + absl::StrSplit(env_skip, ',', absl::SkipEmpty())) { + tests_to_skip.push_back(std::string(test)); + } + } for (int argi = 1; argi < argc; argi++) { + std::string path = cel::internal::ResolveRunfilesPath(argv[argi]); ABSL_CHECK_OK(RegisterTestsFromFile(service, tests_to_skip, - absl::string_view(argv[argi]))); + absl::string_view(path))); } } int exit_code = RUN_ALL_TESTS(); diff --git a/conformance/service.cc b/conformance/service.cc index 3edc214e6..d81200cad 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -14,7 +14,6 @@ #include "conformance/service.h" -#include #include #include #include @@ -31,16 +30,14 @@ #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/rpc/code.pb.h" +#include "google/rpc/status.pb.h" #include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" #include "checker/optional.h" #include "checker/standard_library.h" #include "checker/type_checker_builder.h" @@ -48,7 +45,6 @@ #include "common/ast.h" #include "common/ast_proto.h" #include "common/decl_proto_v1alpha1.h" -#include "common/expr.h" #include "common/internal/value_conversion.h" #include "common/source.h" #include "common/value.h" @@ -72,8 +68,6 @@ #include "extensions/select_optimization.h" #include "extensions/strings.h" #include "internal/status_macros.h" -#include "parser/macro.h" -#include "parser/macro_expr_factory.h" #include "parser/macro_registry.h" #include "parser/options.h" #include "parser/parser.h" @@ -85,6 +79,7 @@ #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" +#include "testutil/test_macros.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "cel/expr/conformance/proto2/test_all_types_extensions.pb.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" @@ -106,109 +101,6 @@ namespace google::api::expr::runtime { namespace { -bool IsCelNamespace(const cel::Expr& target) { - return target.has_ident_expr() && target.ident_expr().name() == "cel"; -} - -absl::optional CelBlockMacroExpander(cel::MacroExprFactory& factory, - cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& bindings_arg = args[0]; - if (!bindings_arg.has_list_expr()) { - return factory.ReportErrorAt( - bindings_arg, "cel.block requires the first arg to be a list literal"); - } - return factory.NewCall("cel.@block", args); -} - -absl::optional CelIndexMacroExpander(cel::MacroExprFactory& factory, - cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& index_arg = args[0]; - if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { - return factory.ReportErrorAt( - index_arg, "cel.index requires a single non-negative int constant arg"); - } - int64_t index = index_arg.const_expr().int_value(); - if (index < 0) { - return factory.ReportErrorAt( - index_arg, "cel.index requires a single non-negative int constant arg"); - } - return factory.NewIdent(absl::StrCat("@index", index)); -} - -absl::optional CelIterVarMacroExpander( - cel::MacroExprFactory& factory, cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& depth_arg = args[0]; - if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || - depth_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - depth_arg, "cel.iterVar requires two non-negative int constant args"); - } - cel::Expr& unique_arg = args[1]; - if (!unique_arg.has_const_expr() || - !unique_arg.const_expr().has_int_value() || - unique_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - unique_arg, "cel.iterVar requires two non-negative int constant args"); - } - return factory.NewIdent( - absl::StrCat("@it:", depth_arg.const_expr().int_value(), ":", - unique_arg.const_expr().int_value())); -} - -absl::optional CelAccuVarMacroExpander( - cel::MacroExprFactory& factory, cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& depth_arg = args[0]; - if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || - depth_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - depth_arg, "cel.accuVar requires two non-negative int constant args"); - } - cel::Expr& unique_arg = args[1]; - if (!unique_arg.has_const_expr() || - !unique_arg.const_expr().has_int_value() || - unique_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - unique_arg, "cel.accuVar requires two non-negative int constant args"); - } - return factory.NewIdent( - absl::StrCat("@ac:", depth_arg.const_expr().int_value(), ":", - unique_arg.const_expr().int_value())); -} - -absl::Status RegisterCelBlockMacros(cel::MacroRegistry& registry) { - CEL_ASSIGN_OR_RETURN(auto block_macro, - cel::Macro::Receiver("block", 2, CelBlockMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(block_macro)); - CEL_ASSIGN_OR_RETURN(auto index_macro, - cel::Macro::Receiver("index", 1, CelIndexMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(index_macro)); - CEL_ASSIGN_OR_RETURN( - auto iter_var_macro, - cel::Macro::Receiver("iterVar", 2, CelIterVarMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(iter_var_macro)); - CEL_ASSIGN_OR_RETURN( - auto accu_var_macro, - cel::Macro::Receiver("accuVar", 2, CelAccuVarMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(accu_var_macro)); - return absl::OkStatus(); -} - google::rpc::Code ToGrpcCode(absl::StatusCode code) { return static_cast(code); } @@ -236,13 +128,15 @@ cel::expr::Expr ExtractExpr( absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, conformance::v1alpha1::ParseResponse& response, - bool enable_optional_syntax) { + bool enable_optional_syntax, + bool enable_variadic_logical_operators) { if (request.cel_source().empty()) { return absl::InvalidArgumentError("no source code"); } cel::ParserOptions options; options.enable_optional_syntax = enable_optional_syntax; options.enable_quoted_identifiers = true; + options.enable_variadic_logical_operators = enable_variadic_logical_operators; cel::MacroRegistry macros; CEL_RETURN_IF_ERROR(cel::RegisterStandardMacros(macros, options)); CEL_RETURN_IF_ERROR( @@ -250,7 +144,7 @@ absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, CEL_RETURN_IF_ERROR(cel::extensions::RegisterBindingsMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtoMacros(macros, options)); - CEL_RETURN_IF_ERROR(RegisterCelBlockMacros(macros)); + CEL_RETURN_IF_ERROR(cel::test::RegisterTestMacros(macros)); CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(request.cel_source(), request.source_location())); CEL_ASSIGN_OR_RETURN(auto parsed_expr, @@ -285,6 +179,8 @@ absl::Status CheckImpl(google::protobuf::Arena* arena, if (!request.no_std_env()) { CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCheckerLibrary())); CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCheckerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCheckerLibrary())); CEL_RETURN_IF_ERROR( builder->AddLibrary(cel::extensions::StringsCheckerLibrary())); CEL_RETURN_IF_ERROR( @@ -342,7 +238,8 @@ absl::Status CheckImpl(google::protobuf::Arena* arena, class LegacyConformanceServiceImpl : public ConformanceServiceInterface { public: static absl::StatusOr> Create( - bool optimize, bool recursive, bool select_optimization) { + bool optimize, bool recursive, bool select_optimization, + bool enable_variadic_logical_operators) { static auto* constant_arena = new Arena(); google::protobuf::LinkMessageReflection< @@ -419,14 +316,15 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathExtensionFunctions( builder->GetRegistry(), options)); - return absl::WrapUnique( - new LegacyConformanceServiceImpl(std::move(builder))); + return absl::WrapUnique(new LegacyConformanceServiceImpl( + std::move(builder), enable_variadic_logical_operators)); } void Parse(const conformance::v1alpha1::ParseRequest& request, conformance::v1alpha1::ParseResponse& response) override { auto status = - LegacyParse(request, response, /*enable_optional_syntax=*/false); + LegacyParse(request, response, /*enable_optional_syntax=*/false, + enable_variadic_logical_operators_); if (!status.ok()) { auto* issue = response.add_issues(); issue->set_code(ToGrpcCode(status.code())); @@ -524,17 +422,20 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { } private: - explicit LegacyConformanceServiceImpl( - std::unique_ptr builder) - : builder_(std::move(builder)) {} + LegacyConformanceServiceImpl(std::unique_ptr builder, + bool enable_variadic_logical_operators) + : builder_(std::move(builder)), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} std::unique_ptr builder_; + bool enable_variadic_logical_operators_; }; class ModernConformanceServiceImpl : public ConformanceServiceInterface { public: static absl::StatusOr> Create( - bool optimize, bool recursive, bool select_optimization) { + bool optimize, bool recursive, bool select_optimization, + bool enable_variadic_logical_operators) { google::protobuf::LinkMessageReflection< cel::expr::conformance::proto3::TestAllTypes>(); google::protobuf::LinkMessageReflection< @@ -576,8 +477,9 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { options.max_recursion_depth = 48; } - return absl::WrapUnique(new ModernConformanceServiceImpl( - options, optimize, select_optimization)); + return absl::WrapUnique( + new ModernConformanceServiceImpl(options, optimize, select_optimization, + enable_variadic_logical_operators)); } absl::StatusOr> Setup( @@ -629,7 +531,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { void Parse(const conformance::v1alpha1::ParseRequest& request, conformance::v1alpha1::ParseResponse& response) override { auto status = - LegacyParse(request, response, /*enable_optional_syntax=*/true); + LegacyParse(request, response, /*enable_optional_syntax=*/true, + enable_variadic_logical_operators_); if (!status.ok()) { auto* issue = response.add_issues(); issue->set_code(ToGrpcCode(status.code())); @@ -720,10 +623,12 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { private: ModernConformanceServiceImpl(const RuntimeOptions& options, bool enable_optimizations, - bool enable_select_optimization) + bool enable_select_optimization, + bool enable_variadic_logical_operators) : options_(options), enable_optimizations_(enable_optimizations), - enable_select_optimization_(enable_select_optimization) {} + enable_select_optimization_(enable_select_optimization), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} static absl::StatusOr> Plan( const cel::Runtime& runtime, @@ -754,6 +659,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { RuntimeOptions options_; bool enable_optimizations_; bool enable_select_optimization_; + bool enable_variadic_logical_operators_; }; } // namespace @@ -766,10 +672,12 @@ absl::StatusOr> NewConformanceService(const ConformanceServiceOptions& options) { if (options.modern) { return google::api::expr::runtime::ModernConformanceServiceImpl::Create( - options.optimize, options.recursive, options.select_optimization); + options.optimize, options.recursive, options.select_optimization, + options.enable_variadic_logical_operators); } else { return google::api::expr::runtime::LegacyConformanceServiceImpl::Create( - options.optimize, options.recursive, options.select_optimization); + options.optimize, options.recursive, options.select_optimization, + options.enable_variadic_logical_operators); } } diff --git a/conformance/service.h b/conformance/service.h index 2dd2abf32..8eb97296e 100644 --- a/conformance/service.h +++ b/conformance/service.h @@ -46,6 +46,7 @@ struct ConformanceServiceOptions { bool arena; bool recursive; bool select_optimization; + bool enable_variadic_logical_operators = false; }; absl::StatusOr> diff --git a/env/BUILD b/env/BUILD new file mode 100644 index 000000000..0c17d6305 --- /dev/null +++ b/env/BUILD @@ -0,0 +1,320 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "config", + srcs = [ + "config.cc", + "type_info.cc", + ], + hdrs = [ + "config.h", + "type_info.h", + ], + deps = [ + "//common:ast", + "//common:constant", + "//common:type", + "//common:type_kind", + "//internal:status_macros", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "env", + srcs = ["env.cc"], + hdrs = ["env.h"], + deps = [ + ":config", + "//checker:type_checker_builder", + "//common:constant", + "//common:container", + "//common:decl", + "//common:signature", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//env/internal:ext_registry", + "//internal:status_macros", + "//parser:macro", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "env_runtime", + srcs = ["env_runtime.cc"], + hdrs = ["env_runtime.h"], + deps = [ + ":config", + "//env/internal:runtime_ext_registry", + "//internal:status_macros", + "//runtime", + "//runtime:runtime_builder", + "//runtime:runtime_builder_factory", + "//runtime:runtime_options", + "//runtime:standard_functions", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "env_std_extensions", + srcs = ["env_std_extensions.cc"], + hdrs = ["env_std_extensions.h"], + deps = [ + ":env", + "//checker:optional", + "//compiler:optional", + "//extensions:bindings_ext", + "//extensions:comprehensions_v2", + "//extensions:encoders", + "//extensions:lists_functions", + "//extensions:math_ext_decls", + "//extensions:proto_ext", + "//extensions:regex_ext", + "//extensions:sets_functions", + "//extensions:strings", + ], +) + +cc_library( + name = "env_yaml", + srcs = ["env_yaml.cc"], + hdrs = ["env_yaml.h"], + copts = [ + "-fexceptions", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + "//common:ast", + "//common:constant", + "//common:signature", + "//internal:status_macros", + "//internal:strings", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@yaml-cpp", + ], +) + +cc_library( + name = "runtime_std_extensions", + srcs = ["runtime_std_extensions.cc"], + hdrs = ["runtime_std_extensions.h"], + deps = [ + ":env_runtime", + "//checker:optional", + "//env/internal:runtime_ext_registry", + "//extensions:encoders", + "//extensions:lists_functions", + "//extensions:math_ext", + "//extensions:math_ext_decls", + "//extensions:regex_ext", + "//extensions:sets_functions", + "//extensions:strings", + "//runtime:optional_types", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "config_test", + srcs = ["config_test.cc"], + deps = [ + ":config", + "//common:constant", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) + +cc_test( + name = "type_info_test", + srcs = ["type_info_test.cc"], + deps = [ + ":config", + "//common:type", + "//common:type_proto", + "//common/ast:metadata", + "//internal:proto_matchers", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "env_test", + srcs = ["env_test.cc"], + deps = [ + ":config", + ":env", + "//checker:type_check_issue", + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:ast", + "//common:constant", + "//common:decl", + "//common:expr", + "//common:type", + "//common:value", + "//compiler", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:parser_interface", + "//runtime", + "//runtime:activation", + "//runtime:reference_resolver", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "env_runtime_test", + srcs = ["env_runtime_test.cc"], + deps = [ + ":config", + ":env", + ":env_runtime", + ":env_std_extensions", + ":env_yaml", + ":runtime_std_extensions", + "//checker:validation_result", + "//common:ast", + "//common:source", + "//common:value", + "//compiler", + "//extensions:math_ext", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "env_std_extensions_test", + srcs = ["env_std_extensions_test.cc"], + deps = [ + ":config", + ":env", + ":env_std_extensions", + "//compiler", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "env_yaml_test", + srcs = ["env_yaml_test.cc"], + deps = [ + ":config", + ":env_yaml", + "//common:constant", + "//internal:status_macros", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "runtime_std_extensions_test", + srcs = ["runtime_std_extensions_test.cc"], + deps = [ + ":config", + ":env", + ":env_runtime", + ":env_std_extensions", + ":runtime_std_extensions", + "//checker:optional", + "//checker:validation_result", + "//common:ast", + "//common:value", + "//compiler", + "//extensions:lists_functions", + "//extensions:math_ext_decls", + "//extensions:strings", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/env/config.cc b/env/config.cc new file mode 100644 index 000000000..202a607bf --- /dev/null +++ b/env/config.cc @@ -0,0 +1,196 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/config.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/overload.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "internal/status_macros.h" + +namespace cel { + +namespace { + +const char* ConstantKindToTypeName(const ConstantKind& kind) { + return std::visit(absl::Overload{ + [](const std::monostate& arg) { return "dyn"; }, + [](const std::nullptr_t& arg) { return "null"; }, + [](bool arg) { return "bool"; }, + [](int64_t arg) { return "int"; }, + [](uint64_t arg) { return "uint"; }, + [](double arg) { return "double"; }, + [](const BytesConstant& arg) { return "bytes"; }, + [](const StringConstant& arg) { return "string"; }, + [](absl::Duration arg) { return "duration"; }, + [](absl::Time arg) { return "timestamp"; }, + }, + kind); +} +} // namespace + +absl::Status Config::AddExtensionConfig(std::string name, int version) { + for (const ExtensionConfig& extension_config : extension_configs_) { + if (extension_config.name == name) { + if (extension_config.version == version) { + return absl::OkStatus(); + } + std::string version_str; + if (version == ExtensionConfig::kLatest) { + version_str = "'latest'"; + } else { + version_str = absl::StrCat(version); + } + return absl::AlreadyExistsError(absl::StrCat( + "Extension '", name, "' version ", extension_config.version, + " is already included. Cannot also include version ", version_str)); + } + } + extension_configs_.push_back( + ExtensionConfig{.name = std::move(name), .version = version}); + return absl::OkStatus(); +} + +absl::Status Config::SetStandardLibraryConfig( + const Config::StandardLibraryConfig& standard_library_config) { + if (!standard_library_config.included_macros.empty() && + !standard_library_config.excluded_macros.empty()) { + return absl::InvalidArgumentError( + "Cannot set both included and excluded macros."); + } + + if (!standard_library_config.included_functions.empty() && + !standard_library_config.excluded_functions.empty()) { + return absl::InvalidArgumentError( + "Cannot set both included and excluded functions."); + } + + absl::flat_hash_set included_function_names; + for (const auto& function : standard_library_config.included_functions) { + if (function.second.empty()) { + included_function_names.insert(function.first); + } + } + for (const auto& function : standard_library_config.included_functions) { + if (included_function_names.contains(function.first) && + !function.second.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Cannot include function '", function.first, + "' and also its specific overload '", function.second, "'")); + } + } + + absl::flat_hash_set excluded_function_names; + for (const auto& function : standard_library_config.excluded_functions) { + if (function.second.empty()) { + excluded_function_names.insert(function.first); + } + } + for (const auto& function : standard_library_config.excluded_functions) { + if (excluded_function_names.contains(function.first) && + !function.second.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Cannot exclude function '", function.first, + "' and also its specific overload '", function.second, "'")); + } + } + + standard_library_config_ = standard_library_config; + return absl::OkStatus(); +} + +absl::Status Config::AddVariableConfig(const VariableConfig& variable_config) { + for (const VariableConfig& existing_variable_config : variable_configs_) { + if (existing_variable_config.name == variable_config.name) { + return absl::AlreadyExistsError(absl::StrCat( + "Variable '", variable_config.name, "' is already included.")); + } + } + if (variable_config.value.has_value()) { + absl::string_view constant_type_name = + ConstantKindToTypeName(variable_config.value.kind()); + if (constant_type_name != variable_config.type_info.name) { + return absl::InvalidArgumentError( + absl::StrCat("Variable '", variable_config.name, "' has type ", + variable_config.type_info.name, + " but is assigned a constant value of type ", + constant_type_name, ".")); + } + } + variable_configs_.push_back(variable_config); + return absl::OkStatus(); +} + +absl::Status Config::ValidateFunctionConfig( + const FunctionConfig& function_config) { + for (const auto& overload : function_config.overload_configs) { + if (overload.is_member_function && overload.parameters.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Function '", function_config.name, "' overload '", + overload.overload_id, + "' is marked as a member function but has no parameters. Member " + "functions must have at least one parameter (target).")); + } + } + return absl::OkStatus(); +} + +absl::Status Config::AddFunctionConfig(const FunctionConfig& function_config) { + CEL_RETURN_IF_ERROR(ValidateFunctionConfig(function_config)); + function_configs_.push_back(function_config); + return absl::OkStatus(); +} + +std::ostream& operator<<(std::ostream& os, + const Config::StandardLibraryConfig& config) { + os << "StandardLibraryConfig("; + if (!config.included_macros.empty()) { + os << "\n included_macros=" << absl::StrJoin(config.included_macros, ", "); + } + if (!config.excluded_macros.empty()) { + os << "\n excluded_macros=" << absl::StrJoin(config.excluded_macros, ", "); + } + if (!config.included_functions.empty()) { + os << "\n included_functions=" + << absl::StrJoin(config.included_functions, ", ", + [](std::string* out, + const std::pair& p) { + absl::StrAppend(out, p.first, ":", p.second); + }); + } + if (!config.excluded_functions.empty()) { + os << "\n excluded_functions=" + << absl::StrJoin(config.excluded_functions, ", ", + [](std::string* out, + const std::pair& p) { + absl::StrAppend(out, p.first, ":", p.second); + }); + } + os << "\n)"; + return os; +} + +} // namespace cel diff --git a/env/config.h b/env/config.h new file mode 100644 index 000000000..68e4a1dd9 --- /dev/null +++ b/env/config.h @@ -0,0 +1,173 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ +#define THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "common/constant.h" + +namespace cel { + +class Config { + public: + void SetName(std::string name) { name_ = std::move(name); } + std::string GetName() const { return name_; } + + void SetContextType(std::string context_type) { + context_type_ = std::move(context_type); + } + std::string GetContextType() const { return context_type_; } + + struct ContainerConfig { + std::string name; + std::vector abbreviations; + struct Alias { + std::string alias; + std::string qualified_name; + }; + std::vector aliases; + + bool IsEmpty() const { + return name.empty() && abbreviations.empty() && aliases.empty(); + } + }; + + void SetContainerConfig(ContainerConfig container_config) { + container_config_ = std::move(container_config); + } + + const ContainerConfig& GetContainerConfig() const { + return container_config_; + } + + struct ExtensionConfig { + static constexpr int kLatest = std::numeric_limits::max(); + + std::string name; + int version = kLatest; + }; + + absl::Status AddExtensionConfig(std::string name, + int version = ExtensionConfig::kLatest); + + const std::vector& GetExtensionConfigs() const { + return extension_configs_; + } + + struct StandardLibraryConfig { + // Exclude the entire standard library. + bool disable = false; + + // Exclude all standard library macros. + bool disable_macros = false; + + // Either included or excluded macros can be set, not both. If neither are + // set, all standard library macros are included. + absl::flat_hash_set included_macros; + absl::flat_hash_set excluded_macros; + + // Sets of pairs of function name and overload id to include or exclude. + // Either included or excluded functions can be set, not both. If neither + // are set, all standard library functions are included. + // If an overload is specified, only that overload is included or excluded. + // If no overload is specified (empty second element of pair), all overloads + // are included or excluded. + absl::flat_hash_set> included_functions; + absl::flat_hash_set> excluded_functions; + + bool IsEmpty() const { + return !disable && !disable_macros && included_macros.empty() && + excluded_macros.empty() && included_functions.empty() && + excluded_functions.empty(); + } + }; + + absl::Status SetStandardLibraryConfig( + const StandardLibraryConfig& standard_library_config); + + const StandardLibraryConfig& GetStandardLibraryConfig() const { + return standard_library_config_; + } + + struct TypeInfo { + std::string name; + std::vector params; + bool is_type_param = false; + }; + + struct VariableConfig { + std::string name; + std::string description; + TypeInfo type_info; + Constant value; + }; + + // Adds a variable config to the environment. The variable name and type + // are used by the CEL type checker to validate expressions. The variable + // value is used as an input value at runtime. + // + // Returns an error if a variable with the same name already exists, or if the + // type of the constant value does not match the specified type. + absl::Status AddVariableConfig(const VariableConfig& variable_config); + + const std::vector& GetVariableConfigs() const { + return variable_configs_; + } + + struct FunctionOverloadConfig { + std::string overload_id; + std::vector examples; + bool is_member_function = false; + std::vector parameters; + TypeInfo return_type; + }; + + struct FunctionConfig { + std::string name; + std::string description; + std::vector overload_configs; + }; + + absl::Status AddFunctionConfig(const FunctionConfig& function_config); + + const std::vector& GetFunctionConfigs() const { + return function_configs_; + } + + private: + std::string name_; + std::string context_type_; + ContainerConfig container_config_; + std::vector extension_configs_; + StandardLibraryConfig standard_library_config_; + std::vector variable_configs_; + std::vector function_configs_; + + absl::Status ValidateFunctionConfig(const FunctionConfig& function_config); +}; + +std::ostream& operator<<(std::ostream& os, + const Config::StandardLibraryConfig& config); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ diff --git a/env/config_test.cc b/env/config_test.cc new file mode 100644 index 000000000..8cfc3cf7f --- /dev/null +++ b/env/config_test.cc @@ -0,0 +1,277 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/config.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/constant.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::AllOf; +using ::testing::ElementsAre; +using ::testing::Field; +using ::testing::HasSubstr; +using ::testing::UnorderedElementsAre; + +TEST(EnvConfigTest, ExtensionConfigs) { + Config config; + ASSERT_THAT( + config.AddExtensionConfig("math", Config::ExtensionConfig::kLatest), + IsOk()); + ASSERT_THAT(config.AddExtensionConfig("optional", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("strings"), IsOk()); + + EXPECT_THAT(config.GetExtensionConfigs(), + UnorderedElementsAre( + AllOf(Field(&Config::ExtensionConfig::name, "math"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)), + AllOf(Field(&Config::ExtensionConfig::name, "optional"), + Field(&Config::ExtensionConfig::version, 2)), + AllOf(Field(&Config::ExtensionConfig::name, "strings"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)))); +} + +TEST(EnvConfigTest, ExtensionConfigConflict) { + Config config; + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("math", 3), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +struct StandardLibraryConfigTestCase { + Config::StandardLibraryConfig standard_library_config; + std::string expected_error; // Empty if no error is expected. +}; + +class StandardLibraryConfigTest + : public testing::TestWithParam {}; + +TEST_P(StandardLibraryConfigTest, StandardLibraryConfig) { + const StandardLibraryConfigTestCase& param = GetParam(); + + Config config; + absl::Status status = + config.SetStandardLibraryConfig(param.standard_library_config); + if (param.expected_error.empty()) { + EXPECT_THAT(status, IsOk()); + } else { + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + } +} + +INSTANTIATE_TEST_SUITE_P( + StandardLibraryConfigTest, StandardLibraryConfigTest, + ::testing::Values( + StandardLibraryConfigTestCase{ + .standard_library_config = {}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add(int,int)"}, + {"_+_", "add(list,list)"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", "add(int,int)"}, + {"_+_", "add(list,list)"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_macros = {"all", "exists"}, + .excluded_macros = {"map", "filter"}, + }, + .expected_error = "Cannot set both included and excluded macros.", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + .excluded_functions = {{"_-_", ""}}, + }, + .expected_error = + "Cannot set both included and excluded functions.", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add(int,int)"}}, + .excluded_functions = {{"_-_", ""}}, + }, + .expected_error = + "Cannot set both included and excluded functions.", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", ""}, {"_+_", "add_list"}}, + }, + .expected_error = "Cannot include function '_+_' and also its " + "specific overload 'add_list'", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", ""}, + {"_+_", "add(int,int)"}}, + }, + .expected_error = "Cannot include function '_+_' and also its " + "specific overload 'add(int,int)'", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", ""}, {"_+_", "add_list"}}, + }, + .expected_error = "Cannot exclude function '_+_' and also its " + "specific overload 'add_list'", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", ""}, + {"_+_", "add(int,int)"}}, + }, + .expected_error = "Cannot exclude function '_+_' and also its " + "specific overload 'add(int,int)'", + })); + +TEST(VariableConfigTest, VariableConfig) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = + { + .name = "mytype", + .params = {{.name = "int"}, {.name = "A", .is_type_param = true}}, + }, + }; + ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); + + ASSERT_EQ(config.GetVariableConfigs().size(), 1); + const auto& added_config = config.GetVariableConfigs()[0]; + EXPECT_EQ(added_config.type_info.name, "mytype"); + ASSERT_THAT(added_config.type_info.params.size(), 2); + EXPECT_EQ(added_config.type_info.params[0].name, "int"); + EXPECT_FALSE(added_config.type_info.params[0].is_type_param); + EXPECT_EQ(added_config.type_info.params[1].name, "A"); + EXPECT_TRUE(added_config.type_info.params[1].is_type_param); +} + +TEST(VariableConfigTest, VariableConfigConflict) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = {.name = "int"}, + }; + EXPECT_THAT(config.AddVariableConfig(variable_config), IsOk()); + EXPECT_THAT(config.AddVariableConfig(variable_config), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +TEST(VariableConfigTest, VariableConfigValueTypeMismatch) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = {.name = "int"}, + .value = Constant(StringConstant("hello")), + }; + EXPECT_THAT(config.AddVariableConfig(variable_config), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Variable 'test' has type int but is assigned " + "a constant value of type string."))); +} + +TEST(FunctionConfigTest, FunctionConfig) { + Config config; + Config::FunctionConfig function_config; + function_config.name = "test"; + function_config.description = "Ultimate test"; + function_config.overload_configs.push_back(Config::FunctionOverloadConfig{ + .overload_id = "test_with_pill", + .examples = {"oracle.isTheOne('Neo', RED)"}, + .is_member_function = true, + .parameters = {{.name = "string"}, {.name = "Choice"}}, + .return_type = {.name = "bool"}, + }); + ASSERT_THAT(config.AddFunctionConfig(function_config), IsOk()); + ASSERT_EQ(config.GetFunctionConfigs().size(), 1); + const auto& added_config = config.GetFunctionConfigs()[0]; + EXPECT_EQ(added_config.name, "test"); + EXPECT_EQ(added_config.description, "Ultimate test"); + EXPECT_EQ(added_config.overload_configs.size(), 1); + + const auto& overload_config = added_config.overload_configs[0]; + EXPECT_EQ(overload_config.overload_id, "test_with_pill"); + EXPECT_THAT(overload_config.examples, + ElementsAre("oracle.isTheOne('Neo', RED)")); + EXPECT_TRUE(overload_config.is_member_function); + EXPECT_THAT( + overload_config.parameters, + ElementsAre(AllOf(Field(&Config::TypeInfo::name, "string"), + Field(&Config::TypeInfo::is_type_param, false)), + AllOf(Field(&Config::TypeInfo::name, "Choice"), + Field(&Config::TypeInfo::is_type_param, false)))); + EXPECT_THAT(overload_config.return_type, + Field(&Config::TypeInfo::name, "bool")); +} + +TEST(FunctionConfigTest, FunctionConfigInvalidMember) { + Config config; + Config::FunctionConfig function_config; + function_config.name = "test"; + function_config.overload_configs.push_back(Config::FunctionOverloadConfig{ + .overload_id = "test_member_no_params", + .is_member_function = true, + .parameters = {}, + }); + EXPECT_THAT(config.AddFunctionConfig(function_config), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("is marked as a member function but has no " + "parameters"))); +} + +} // namespace +} // namespace cel diff --git a/env/env.cc b/env/env.cc new file mode 100644 index 000000000..85c5139da --- /dev/null +++ b/env/env.cc @@ -0,0 +1,222 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/type_checker_builder.h" +#include "common/constant.h" +#include "common/container.h" +#include "common/decl.h" +#include "common/signature.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "env/config.h" +#include "env/type_info.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +bool ShouldIncludeMacro(const Config::StandardLibraryConfig& config, + absl::string_view macro) { + if (config.disable_macros) { + return false; + } + if (config.excluded_macros.contains(macro)) { + return false; + } + if (!config.included_macros.empty() && + !config.included_macros.contains(macro)) { + return false; + } + return true; +} + +bool ShouldIncludeFunction(const Config::StandardLibraryConfig& config, + absl::string_view function, + const OverloadDecl& overload) { + if (config.excluded_functions.empty() && config.included_functions.empty()) { + return true; + } + + if (!config.excluded_functions.empty()) { + if (config.excluded_functions.contains(std::make_pair( + std::string(function), std::string(overload.id()))) || + config.excluded_functions.contains( + std::make_pair(std::string(function), ""))) { + return false; + } + absl::StatusOr signature = + MakeOverloadSignature(function, overload.args(), overload.member()); + if (signature.ok() && config.excluded_functions.contains(std::make_pair( + std::string(function), *std::move(signature)))) { + return false; + } + } + + if (!config.included_functions.empty()) { + if (config.included_functions.contains(std::make_pair( + std::string(function), std::string(overload.id()))) || + config.included_functions.contains( + std::make_pair(std::string(function), ""))) { + return true; + } + // Ok to call MakeOverloadSignature() again, because in practice either + // included or excluded functions may be specified, but not both. + absl::StatusOr signature = + MakeOverloadSignature(function, overload.args(), overload.member()); + if (signature.ok() && config.included_functions.contains(std::make_pair( + std::string(function), *std::move(signature)))) { + return true; + } + return false; + } + + return true; // Never reached +} + +absl::StatusOr MakeStdlibSubset( + const Config::StandardLibraryConfig& standard_library_config) { + CompilerLibrarySubset subset; + subset.library_id = "stdlib"; + // Capturing by reference is safe. The returned CompilerLibrarySubset's + // callbacks are only used during CompilerBuilder::Build() to configure + // contributed functions and macros. They are not retained by the constructed + // Compiler instance. The referenced config outlives the Build() call. + subset.should_include_macro = [&standard_library_config](const Macro& macro) { + return ShouldIncludeMacro(standard_library_config, macro.function()); + }; + subset.should_include_overload = [&standard_library_config]( + absl::string_view function, + const OverloadDecl& overload) { + return ShouldIncludeFunction(standard_library_config, function, overload); + }; + return subset; +} + +absl::StatusOr FunctionConfigToFunctionDecl( + const Config::FunctionConfig& function_config, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool* descriptor_pool) { + FunctionDecl function_decl; + function_decl.set_name(function_config.name); + for (const Config::FunctionOverloadConfig& overload_config : + function_config.overload_configs) { + OverloadDecl overload_decl; + overload_decl.set_id(overload_config.overload_id); + overload_decl.set_member(overload_config.is_member_function); + for (const Config::TypeInfo& parameter : overload_config.parameters) { + CEL_ASSIGN_OR_RETURN(Type parameter_type, + TypeInfoToType(parameter, descriptor_pool, arena)); + overload_decl.mutable_args().push_back(parameter_type); + } + CEL_ASSIGN_OR_RETURN( + Type return_type, + TypeInfoToType(overload_config.return_type, descriptor_pool, arena)); + overload_decl.set_result(return_type); + CEL_RETURN_IF_ERROR(function_decl.AddOverload(overload_decl)); + } + return function_decl; +} + +} // namespace + +Env::Env() { + compiler_options_.parser_options.enable_quoted_identifiers = true; + compiler_options_.adapt_parser_errors = true; +} + +absl::StatusOr> Env::NewCompilerBuilder() { + CEL_ASSIGN_OR_RETURN( + std::unique_ptr compiler_builder, + cel::NewCompilerBuilder(descriptor_pool_, compiler_options_)); + cel::TypeCheckerBuilder& checker_builder = + compiler_builder->GetCheckerBuilder(); + + ExpressionContainer container; + CEL_RETURN_IF_ERROR( + container.SetContainer(config_.GetContainerConfig().name)); + for (const auto& abbr : config_.GetContainerConfig().abbreviations) { + CEL_RETURN_IF_ERROR(container.AddAbbreviation(abbr)); + } + + if (!config_.GetContextType().empty()) { + CEL_RETURN_IF_ERROR( + checker_builder.AddContextDeclaration(config_.GetContextType())); + } + for (const auto& alias : config_.GetContainerConfig().aliases) { + CEL_RETURN_IF_ERROR(container.AddAlias(alias.alias, alias.qualified_name)); + } + checker_builder.SetExpressionContainer(std::move(container)); + + if (!config_.GetStandardLibraryConfig().disable) { + CEL_RETURN_IF_ERROR( + compiler_builder->AddLibrary(StandardCompilerLibrary())); + CEL_ASSIGN_OR_RETURN(CompilerLibrarySubset standard_library_subset, + MakeStdlibSubset(config_.GetStandardLibraryConfig())); + CEL_RETURN_IF_ERROR( + compiler_builder->AddLibrarySubset(std::move(standard_library_subset))); + } + for (const Config::ExtensionConfig& extension_config : + config_.GetExtensionConfigs()) { + CEL_ASSIGN_OR_RETURN(CompilerLibrary library, + extension_registry_.GetCompilerLibrary( + extension_config.name, extension_config.version)); + CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(std::move(library))); + } + + google::protobuf::Arena* arena = checker_builder.arena(); + for (const Config::VariableConfig& variable_config : + config_.GetVariableConfigs()) { + VariableDecl variable_decl; + variable_decl.set_name(variable_config.name); + CEL_ASSIGN_OR_RETURN(Type type, + TypeInfoToType(variable_config.type_info, + descriptor_pool_.get(), arena)); + variable_decl.set_type(type); + if (variable_config.value.has_value()) { + variable_decl.set_value(variable_config.value); + } + CEL_RETURN_IF_ERROR(checker_builder.AddVariable(variable_decl)); + } + + for (const Config::FunctionConfig& function_config : + config_.GetFunctionConfigs()) { + CEL_ASSIGN_OR_RETURN(FunctionDecl function_decl, + FunctionConfigToFunctionDecl(function_config, arena, + descriptor_pool_.get())); + CEL_RETURN_IF_ERROR(checker_builder.AddFunction(function_decl)); + } + + return compiler_builder; +} + +absl::StatusOr> Env::NewCompiler() { + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler_builder, + NewCompilerBuilder()); + return compiler_builder->Build(); +} +} // namespace cel diff --git a/env/env.h b/env/env.h new file mode 100644 index 000000000..9830b67d7 --- /dev/null +++ b/env/env.h @@ -0,0 +1,76 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_H_ +#define THIRD_PARTY_CEL_CPP_ENV_ENV_H_ + +#include + +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/internal/ext_registry.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Env class establishes the environment for compiling CEL expressions. +// +// It is used to configure compiler options, extension functions, and other +// customizable CEL features. +class Env { + public: + Env(); + + // Registers a `CompilerLibrary` with the environment. Note that the library + // does not automatically get added to a `Compiler`. `NewCompiler` relies + // on `Config` to determine which libraries to load. + void RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) { + extension_registry_.RegisterCompilerLibrary(name, alias, version, + std::move(library_factory)); + } + + void SetDescriptorPool( + std::shared_ptr descriptor_pool) { + descriptor_pool_ = std::move(descriptor_pool); + } + + const google::protobuf::DescriptorPool* GetDescriptorPool() const { + return descriptor_pool_.get(); + } + + void SetConfig(const Config& config) { config_ = config; } + + absl::StatusOr> NewCompilerBuilder(); + + // Shortcut for NewCompilerBuilder() followed by Build(). + absl::StatusOr> NewCompiler(); + + private: + cel::env_internal::ExtensionRegistry extension_registry_; + std::shared_ptr descriptor_pool_; + CompilerOptions compiler_options_; + Config config_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_ENV_H_ diff --git a/env/env_runtime.cc b/env/env_runtime.cc new file mode 100644 index 000000000..33e0747cc --- /dev/null +++ b/env/env_runtime.cc @@ -0,0 +1,89 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_runtime.h" + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_builder_factory.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_functions.h" + +namespace cel { + +void EnvRuntime::RegisterExtensionFunctions( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable + function_registration_callback) { + extension_registry_.AddFunctionRegistration( + name, alias, version, std::move(function_registration_callback)); +} + +absl::StatusOr EnvRuntime::CreateRuntimeBuilder() { + const std::vector& extension_configs = + config_.GetExtensionConfigs(); + const Config::ExtensionConfig* optional_extension_config = nullptr; + for (const Config::ExtensionConfig& extension_config : extension_configs) { + if (extension_config.name == "optional") { + optional_extension_config = &extension_config; + runtime_options_.enable_qualified_type_identifiers = true; + break; + } + } + + CEL_ASSIGN_OR_RETURN( + RuntimeBuilder runtime_builder, + cel::CreateRuntimeBuilder(descriptor_pool_, runtime_options_)); + + if (!config_.GetStandardLibraryConfig().disable) { + CEL_RETURN_IF_ERROR(RegisterStandardFunctions( + runtime_builder.function_registry(), runtime_options_)); + } + + // Register optional extension functions first, because other extensions + // depend on it (e.g. regex). + if (optional_extension_config != nullptr) { + CEL_RETURN_IF_ERROR(extension_registry_.RegisterExtensionFunctions( + runtime_builder, runtime_options_, optional_extension_config->name, + optional_extension_config->version)); + } + + for (const Config::ExtensionConfig& extension_config : extension_configs) { + if (&extension_config == optional_extension_config) { + continue; + } + CEL_RETURN_IF_ERROR(extension_registry_.RegisterExtensionFunctions( + runtime_builder, runtime_options_, extension_config.name, + extension_config.version)); + } + return runtime_builder; +} + +absl::StatusOr> EnvRuntime::NewRuntime() { + CEL_ASSIGN_OR_RETURN(RuntimeBuilder runtime_builder, CreateRuntimeBuilder()); + return std::move(runtime_builder).Build(); +} + +} // namespace cel diff --git a/env/env_runtime.h b/env/env_runtime.h new file mode 100644 index 000000000..63473c295 --- /dev/null +++ b/env/env_runtime.h @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ +#define THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ + +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "env/config.h" +#include "env/internal/runtime_ext_registry.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// EnvRuntime class establishes the environment for creating CEL runtimes. +// +// It is used to configure runtime options, extension functions, and other +// customizable CEL runtime features. +// +// EnvRuntime is separate from Env to avoid a dependency on the compiler for +// binaries that only use the runtime. +// +// Even though EnvRuntime is separate from Env, the Config and DescriptorPool +// passed to EnvRuntime are expected to be the same as those passed to Env for +// compilation. This ensures consistency between compilation and runtime. +class EnvRuntime { + public: + // Registers a function registration callback for an extension. The callback + // is invoked when a runtime is created, if the corresponding functions are + // enabled in the runtime config. + void RegisterExtensionFunctions( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable + function_registration_callback); + + void SetDescriptorPool( + std::shared_ptr descriptor_pool) { + descriptor_pool_ = std::move(descriptor_pool); + } + + void SetConfig(const Config& config) { config_ = config; } + + RuntimeOptions& mutable_runtime_options() { return runtime_options_; } + + absl::StatusOr CreateRuntimeBuilder(); + + // Shortcut for CreateRuntimeBuilder() followed by Build(). + absl::StatusOr> NewRuntime(); + + private: + cel::env_internal::RuntimeExtensionRegistry& GetRuntimeExtensionRegistry() { + return extension_registry_; + } + + friend void RegisterStandardExtensions(EnvRuntime& env_runtime); + + cel::env_internal::RuntimeExtensionRegistry extension_registry_; + std::shared_ptr descriptor_pool_; + Config config_; + RuntimeOptions runtime_options_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ diff --git a/env/env_runtime_test.cc b/env/env_runtime_test.cc new file mode 100644 index 000000000..47892772c --- /dev/null +++ b/env/env_runtime_test.cc @@ -0,0 +1,199 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_runtime.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/source.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_std_extensions.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "extensions/math_ext.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::IsEmpty; +using ::testing::ValuesIn; + +struct TestCase { + std::string config_yaml; + std::string expr; + bool expected_to_fail = false; +}; + +class EnvRuntimeTest : public testing::TestWithParam {}; + +TEST_P(EnvRuntimeTest, EndToEnd) { + const TestCase& param = GetParam(); + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(param.config_yaml)); + + Env env; + env.SetDescriptorPool(descriptor_pool); + RegisterStandardExtensions(env); + env.SetConfig(config); + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + std::unique_ptr ast; + if (!param.expected_to_fail) { + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile(param.expr)); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(ast, result.ReleaseAst()); + } else { + // Bypass type checking to allow compilation to succeed since we expect the + // runtime to fail. + ASSERT_OK_AND_ASSIGN(std::unique_ptr source, + NewSource(param.expr, "")); + ASSERT_OK_AND_ASSIGN(ast, compiler->GetParser().Parse(*source)); + } + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + + absl::StatusOr> program_or = + runtime->CreateProgram(std::move(ast)); + if (param.expected_to_fail) { + EXPECT_THAT(program_or, StatusIs(absl::StatusCode::kInvalidArgument)) + << " expr: " << param.expr; + return; + } + + ASSERT_THAT(program_or, IsOk()) << " expr: " << param.expr; + + std::unique_ptr program = *std::move(program_or); + ASSERT_NE(program, nullptr); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()) << " expr: " << param.expr; +} + +std::vector GetEnvRuntimeTestCases() { + return { + TestCase{ + .config_yaml = R"yaml( + extensions: + - name: "encoders" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8='", + }, + TestCase{ + .config_yaml = R"yaml( + extensions: + - name: "encoders" + - name: "optional" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8=' && " + "optional.of(1).hasValue()", + }, + TestCase{ + .config_yaml = R"yaml( + extensions: + - name: "encoders" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8=' && " + "optional.of(1).hasValue()", + .expected_to_fail = true, + }, + TestCase{ + .config_yaml = R"yaml( + stdlib: + disable: true + )yaml", + .expr = "1 + 2 == 3", + .expected_to_fail = true, + }, + TestCase{ + .config_yaml = R"yaml( + stdlib: + disable: true + extensions: + - name: "encoders" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8=' && " + "1 + 2 == 3", + .expected_to_fail = true, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvRuntimeTest, EnvRuntimeTest, + ValuesIn(GetEnvRuntimeTestCases())); + +TEST(EnvRuntimeTest, RegisterExtensionFunctions) { + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + Config config; + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + + Env env; + env.SetDescriptorPool(descriptor_pool); + RegisterStandardExtensions(env); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("math.sqrt(4) == 2.0")); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + env_runtime.RegisterExtensionFunctions( + "cel.lib.math", "math", 2, + [](cel::RuntimeBuilder& runtime_builder, + const cel::RuntimeOptions& opts) -> absl::Status { + return cel::extensions::RegisterMathExtensionFunctions( + runtime_builder.function_registry(), opts, 2); + }); + env_runtime.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + ASSERT_NE(program, nullptr); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()); +} +} // namespace +} // namespace cel diff --git a/env/env_std_extensions.cc b/env/env_std_extensions.cc new file mode 100644 index 000000000..f2041b979 --- /dev/null +++ b/env/env_std_extensions.cc @@ -0,0 +1,76 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_std_extensions.h" + +#include "checker/optional.h" +#include "compiler/optional.h" +#include "env/env.h" +#include "extensions/bindings_ext.h" +#include "extensions/comprehensions_v2.h" +#include "extensions/encoders.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext_decls.h" +#include "extensions/proto_ext.h" +#include "extensions/regex_ext.h" +#include "extensions/sets_functions.h" +#include "extensions/strings.h" + +namespace cel { + +void RegisterStandardExtensions(Env& env) { + env.RegisterCompilerLibrary("cel.lib.ext.bindings", "bindings", 0, []() { + return extensions::BindingsCompilerLibrary(); + }); + env.RegisterCompilerLibrary("cel.lib.ext.encoders", "encoders", 0, []() { + return extensions::EncodersCompilerLibrary(); + }); + for (int version = 0; version <= extensions::kListsExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.lists", "lists", version, + [version]() { return extensions::ListsCompilerLibrary(version); }); + } + for (int version = 0; version <= extensions::kMathExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.math", "math", version, + [version]() { return extensions::MathCompilerLibrary(version); }); + } + for (int version = 0; version <= kOptionalExtensionLatestVersion; ++version) { + env.RegisterCompilerLibrary("optional", "", version, [version]() { + return OptionalCompilerLibrary(version); + }); + } + env.RegisterCompilerLibrary("cel.lib.ext.protos", "protos", 0, []() { + return extensions::ProtoExtCompilerLibrary(); + }); + env.RegisterCompilerLibrary("cel.lib.ext.sets", "sets", 0, []() { + return extensions::SetsCompilerLibrary(); + }); + for (int version = 0; version <= extensions::kStringsExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.strings", "strings", version, + [version]() { return extensions::StringsCompilerLibrary(version); }); + } + env.RegisterCompilerLibrary( + "cel.lib.ext.comprev2", "two-var-comprehensions", 0, + []() { return extensions::ComprehensionsV2CompilerLibrary(); }); + env.RegisterCompilerLibrary("cel.lib.ext.regex", "regex", 0, []() { + return extensions::RegexExtCompilerLibrary(); + }); +} + +} // namespace cel diff --git a/env/env_std_extensions.h b/env/env_std_extensions.h new file mode 100644 index 000000000..79cf37dbf --- /dev/null +++ b/env/env_std_extensions.h @@ -0,0 +1,42 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ + +#include "env/env.h" + +namespace cel { + +// Registers the standard CEL extensions with the given environment. This makes +// them available, but does not enable them. See Env::Config for how to enable +// extensions. +// +// Extensions are registered under the following names: +// +// - cel.lib.ext.bindings (alias: "bindings") +// - cel.lib.ext.encoders (alias: "encoders") +// - cel.lib.ext.lists (alias: "lists") +// - cel.lib.ext.math (alias: "math") +// - optional +// - cel.lib.ext.protos (alias: "protos") +// - cel.lib.ext.sets (alias: "sets") +// - cel.lib.ext.strings (alias: "strings") +// - cel.lib.ext.comprev2 (alias: "two-var-comprehensions") +// - cel.lib.ext.regex (alias: "regex") +void RegisterStandardExtensions(Env& env); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ diff --git a/env/env_std_extensions_test.cc b/env/env_std_extensions_test.cc new file mode 100644 index 000000000..7d9572cc0 --- /dev/null +++ b/env/env_std_extensions_test.cc @@ -0,0 +1,116 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_std_extensions.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::testing::TestWithParam; + +struct TestCase { + std::string extension; + std::string expr; +}; + +class EnvStdExtensions : public testing::TestWithParam {}; + +TEST_P(EnvStdExtensions, RegistrationTest) { + const TestCase& param = GetParam(); + + Env env; + RegisterStandardExtensions(env); + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + + Config config; + ASSERT_THAT(config.AddExtensionConfig(param.extension), IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(param.expr)); + ASSERT_TRUE(result.IsValid()) << "Expected no issues for expr: " << param.expr + << " but got: " << result.FormatError(); +} + +INSTANTIATE_TEST_SUITE_P( + RegistrationTest, EnvStdExtensions, + ::testing::Values( + TestCase{ + .extension = "cel.lib.ext.bindings", // official name + .expr = "cel.bind(t, true, t)", + }, + TestCase{ + .extension = "bindings", // alias + .expr = "cel.bind(t, true, t)", + }, + TestCase{ + .extension = "encoders", + .expr = "base64.encode(b'hello')", + }, + TestCase{ + .extension = "lists", + .expr = "[1, 2, 3].sort()", + }, + TestCase{ + .extension = "lists", + .expr = "['a'].sortBy(e, e)", + }, + TestCase{ + .extension = "math", + .expr = "math.sqrt(-1)", + }, + TestCase{ + .extension = "optional", + .expr = "[1, 2].first()", + }, + TestCase{ + .extension = "optional", + .expr = "[0][?1]", // optional syntax auto-enabled + }, + TestCase{ + .extension = "protos", + .expr = "!proto.hasExt(cel.expr.conformance.proto2.TestAllTypes{}, " + "cel.expr.conformance.proto2.nested_ext)", + }, + TestCase{ + .extension = "sets", + .expr = "sets.contains([1], [1])", + }, + TestCase{ + .extension = "strings", + .expr = "'foo'.reverse()", + }, + TestCase{ + .extension = "two-var-comprehensions", + .expr = "[1, 2, 3, 4].all(i, v, i < v)", + }, + TestCase{ + .extension = "regex", + .expr = "regex.replace('abc', '$', '_end')", + })); + +} // namespace +} // namespace cel diff --git a/env/env_test.cc b/env/env_test.cc new file mode 100644 index 000000000..00143a857 --- /dev/null +++ b/env/env_test.cc @@ -0,0 +1,666 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env.h" + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/constant.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/parser_interface.h" +#include "runtime/activation.h" +#include "runtime/reference_resolver.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Property; +using ::testing::UnorderedElementsAre; +using ::testing::Values; +using ::testing::ValuesIn; + +Expr TestMacroExpander(MacroExprFactory& factory, absl::Span args) { + return factory.NewStringConst("Hello"); +} + +class TestLibrary : public CompilerLibrary { + public: + explicit TestLibrary(int version) + : CompilerLibrary( + "testlib", + [version](ParserBuilder& builder) { + absl::Status status; + CEL_ASSIGN_OR_RETURN( + auto macro1, + cel::Macro::Global("testMacro1", 0, TestMacroExpander)); + status.Update(builder.AddMacro(macro1)); + if (version == 2) { + CEL_ASSIGN_OR_RETURN( + auto macro2, + cel::Macro::Global("testMacro2", 0, TestMacroExpander)); + status.Update(builder.AddMacro(macro2)); + } + return status; + }, + [version](TypeCheckerBuilder& builder) { + absl::Status status; + CEL_ASSIGN_OR_RETURN( + auto func1, cel::MakeFunctionDecl( + "testFunc1", MakeOverloadDecl(StringType()))); + status.Update(builder.AddFunction(func1)); + if (version == 2) { + CEL_ASSIGN_OR_RETURN( + auto func2, + cel::MakeFunctionDecl("testFunc2", + MakeOverloadDecl(StringType()))); + status.Update(builder.AddFunction(func2)); + } + return status; + }) {}; +}; + +absl::StatusOr CompileAndEvalExpr( + Env& env, absl::string_view expr, + const Activation& activation = Activation()) { + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, env.NewCompiler()); + if (compiler == nullptr) { + return absl::InternalError("Failed to create compiler"); + } + CEL_ASSIGN_OR_RETURN(ValidationResult result, compiler->Compile(expr)); + if (!result.GetIssues().empty()) { + return absl::InvalidArgumentError(result.FormatError()); + } + + cel::RuntimeOptions opts; + CEL_ASSIGN_OR_RETURN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder(env.GetDescriptorPool(), opts)); + CEL_RETURN_IF_ERROR(cel::EnableReferenceResolver( + rt_builder, cel::ReferenceResolverEnabled::kAlways)); + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(rt_builder).Build()); + if (runtime == nullptr) { + return absl::InternalError("Failed to create runtime"); + } + + CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, result.ReleaseAst()); + if (ast == nullptr) { + return absl::InternalError("Failed to create AST"); + } + google::protobuf::Arena arena; + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + if (program == nullptr) { + return absl::InternalError("Failed to create program"); + } + CEL_ASSIGN_OR_RETURN(Value value, program->Evaluate(&arena, activation)); + return value; +} + +absl::StatusOr CompileAndEvalBooleanExpr( + Env& env, absl::string_view expr, + const Activation& activation = Activation()) { + CEL_ASSIGN_OR_RETURN(auto value, CompileAndEvalExpr(env, expr, activation)); + return value.GetBool(); +} + +class LibraryConfigTest : public testing::Test { + protected: + void SetUp() override { + env_.RegisterCompilerLibrary("testlib", "ml", 1, + []() { return TestLibrary(1); }); + env_.RegisterCompilerLibrary("testlib", "ml", 2, + []() { return TestLibrary(2); }); + env_.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + } + + Env env_; +}; + +TEST_F(LibraryConfigTest, DefaultVersion) { + Config config; + ASSERT_THAT(config.AddExtensionConfig("testlib"), IsOk()); + + env_.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env_.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile("testMacro1()")); + ASSERT_OK_AND_ASSIGN(auto result2, compiler->Compile("testFunc1()")); + ASSERT_OK_AND_ASSIGN(auto result3, compiler->Compile("testMacro2()")); + ASSERT_OK_AND_ASSIGN(auto result4, compiler->Compile("testFunc2()")); + + EXPECT_THAT(result1.GetIssues(), IsEmpty()); + EXPECT_THAT(result2.GetIssues(), IsEmpty()); + EXPECT_THAT(result3.GetIssues(), IsEmpty()); + EXPECT_THAT(result4.GetIssues(), IsEmpty()); +} + +TEST_F(LibraryConfigTest, SpecificVersion) { + Config config; + ASSERT_THAT(config.AddExtensionConfig("testlib", 1), IsOk()); + + env_.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env_.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile("testMacro1()")); + ASSERT_OK_AND_ASSIGN(auto result2, compiler->Compile("testFunc1()")); + ASSERT_OK_AND_ASSIGN(auto result3, compiler->Compile("testMacro2()")); + ASSERT_OK_AND_ASSIGN(auto result4, compiler->Compile("testFunc2()")); + + EXPECT_THAT(result1.GetIssues(), IsEmpty()); + EXPECT_THAT(result2.GetIssues(), IsEmpty()); + EXPECT_THAT(result3.GetIssues(), + UnorderedElementsAre( + Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference to 'testMacro2'")))); + EXPECT_THAT(result4.GetIssues(), + UnorderedElementsAre( + Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference to 'testFunc2'")))); +} + +struct StandardLibraryConfigTestCase { + Config::StandardLibraryConfig standard_library_config; + std::vector expected_valid_expressions; + std::vector expected_invalid_expressions; +}; + +class StandardLibraryConfigTest + : public testing::TestWithParam {}; + +TEST_P(StandardLibraryConfigTest, StandardLibraryConfig) { + const StandardLibraryConfigTestCase& param = GetParam(); + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + + Config config; + ASSERT_THAT(config.SetStandardLibraryConfig(param.standard_library_config), + IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + + for (const std::string& expr : param.expected_valid_expressions) { + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile(expr)); + EXPECT_THAT(result1.GetIssues(), IsEmpty()) + << "With config: " << param.standard_library_config + << ", expected no issues for expr: " << expr + << " but got: " << result1.FormatError(); + } + for (const std::string& expr : param.expected_invalid_expressions) { + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile(expr)); + EXPECT_THAT(result1.GetIssues(), Not(IsEmpty())) + << "With config: " << param.standard_library_config + << ", expected compilation error for expr: " << expr << " but got: \'" + << result1.FormatError() << "\'"; + } +} + +INSTANTIATE_TEST_SUITE_P( + StandardLibraryConfigTest, StandardLibraryConfigTest, + Values( + StandardLibraryConfigTestCase{ + .standard_library_config = {}, + .expected_valid_expressions = {"1 + 2", + "[1, 2, 3].exists(x, x == 1)", + "[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.disable = true}, + .expected_invalid_expressions = {"1 + 2", + "[1, 2, 3].exists(x, x == 1)", + "[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.disable_macros = true}, + .expected_valid_expressions = {"1 + 2"}, + .expected_invalid_expressions = {"[1, 2, 3].exists(x, x == 1)", + "[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.excluded_macros = {"map", "all"}}, + .expected_valid_expressions = {"[1, 2, 3].exists(x, x == 1)"}, + .expected_invalid_expressions = {"[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.included_macros = {"map", "all"}}, + .expected_valid_expressions = {"[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + .expected_invalid_expressions = {"[1, 2, 3].exists(x, x == 1)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.excluded_functions = {{"_+_", ""}}}, + .expected_invalid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.excluded_functions = {{"_+_", "_+_(bytes,bytes)"}, + {"_+_", "_+_(list<~A>,list<~A>)"}, + {"_+_", "_+_(string,string)"}}}, + .expected_valid_expressions = {"1 + 2"}, + .expected_invalid_expressions = {"[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.excluded_functions = {{"_+_", "add_bytes"}, + {"_+_", "add_list"}, + {"_+_", "add_string"}}}, + .expected_valid_expressions = {"1 + 2"}, + .expected_invalid_expressions = {"[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.included_functions = {{"_+_", ""}}}, + .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.included_functions = {{"_+_", "_+_(int,int)"}, + {"_+_", "_+_(list<~A>,list<~A>)"}}}, + .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]"}, + .expected_invalid_expressions = {"'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}}, + .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]"}, + .expected_invalid_expressions = {"'hello' + 'world'"}, + })); + +TEST(ContainerConfigTest, ContainerConfig) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig({.name = "cel.expr.conformance.proto2"}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("TestAllTypes{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + +TEST(ContainerConfigTest, ContainerConfigWithAbbreviations) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig( + {.name = "cel.expr.conformance", + .abbreviations = {"cel.expr.conformance.proto2.TestAllTypes"}}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("TestAllTypes{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + +TEST(ContainerConfigTest, ContainerConfigWithAliases) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig( + {.name = "cel.expr.conformance", + .aliases = { + {.alias = "MyTestType", + .qualified_name = "cel.expr.conformance.proto2.TestAllTypes"}}}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("MyTestType{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + +TEST(ContextVariableConfigTest, Basic) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContextType("cel.expr.conformance.proto3.TestAllTypes"); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + + // Top-level fields of TestAllTypes like "single_int32" should resolve + // successfully. + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("single_int32 > 10")); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + ASSERT_OK_AND_ASSIGN(auto result_invalid, + compiler->Compile("non_existent_field > 10")); + EXPECT_THAT(result_invalid.GetIssues(), Not(IsEmpty())); +} + +struct VariableConfigWithValueTestCase { + Config::VariableConfig variable_config; + std::string validate_type_expr; + std::string validate_value_expr; +}; + +class VariableConfigWithValueTest + : public testing::TestWithParam {}; + +TEST_P(VariableConfigWithValueTest, VariableConfigWithValue) { + const VariableConfigWithValueTestCase& param = GetParam(); + + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + ASSERT_THAT(config.AddVariableConfig(param.variable_config), IsOk()); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN( + bool type_as_expected, + CompileAndEvalBooleanExpr(env, param.validate_type_expr)); + ASSERT_TRUE(type_as_expected) << " expr: " << param.validate_type_expr; + if (!param.validate_value_expr.empty()) { + ASSERT_OK_AND_ASSIGN( + bool value_as_expected, + CompileAndEvalBooleanExpr(env, param.validate_value_expr)); + ASSERT_TRUE(value_as_expected) << " expr: " << param.validate_value_expr; + } +} + +Config::VariableConfig MakeConstant( + absl::string_view variable_name, absl::string_view type_name, + absl::AnyInvocable setter) { + Config::VariableConfig variable_config; + variable_config.name = variable_name; + Constant c; + setter(c); + variable_config.type_info.name = type_name; + variable_config.value = c; + return variable_config; +} + +std::vector +GetVariableConfigWithValueTestCases() { + return { + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "null", [](auto& c) { c.set_null_value(nullptr); }), + .validate_type_expr = "type(x) == type(null)", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "bool", [](auto& c) { c.set_bool_value(true); }), + .validate_type_expr = "type(x) == bool", + .validate_value_expr = "x == true", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "int", [](Constant& c) { c.set_int_value(42); }), + .validate_type_expr = "type(x) == int", + .validate_value_expr = "x == 42", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "uint", [](Constant& c) { c.set_uint_value(777); }), + .validate_type_expr = "type(x) == uint", + .validate_value_expr = "x == 777u", + }, + VariableConfigWithValueTestCase{ + .variable_config = + MakeConstant("x", "double", + [](Constant& c) { c.set_double_value(1.0 / 3.0); }), + .validate_type_expr = "type(x) == double", + .validate_value_expr = "x > 0.333 && x < 0.334", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant("x", "bytes", + [](Constant& c) { + c.set_bytes_value(absl::string_view( + "\xff\x00\x01", 3)); + }), + .validate_type_expr = "type(x) == bytes", + .validate_value_expr = "x == b'\\xff\\x00\\x01'", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "string", [](Constant& c) { c.set_string_value("hello"); }), + .validate_type_expr = "type(x) == string", + .validate_value_expr = "x == 'hello'", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "timestamp", + [](Constant& c) { + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + c.set_timestamp_value(absl::FromUnixSeconds(1767323045)); + }), + .validate_type_expr = + "type(x) == type(timestamp('2026-01-02T03:04:05Z'))", + .validate_value_expr = "x == timestamp('2026-01-02T03:04:05Z')", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "duration", + [](Constant& c) { + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + c.set_duration_value(absl::Hours(1) + absl::Minutes(2) + + absl::Seconds(3)); + }), + .validate_type_expr = "type(x) == type(duration('1h2m3s'))", + .validate_value_expr = "x == duration('1h2m3s')", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(VariableConfigTest, VariableConfigWithValueTest, + ValuesIn(GetVariableConfigWithValueTestCases())); + +struct FunctionConfigTestCase { + Config::FunctionConfig function_config; + std::vector variable_configs; + std::string expr; + std::string expected_error; +}; + +class FunctionConfigTest + : public testing::TestWithParam {}; + +TEST_P(FunctionConfigTest, FunctionConfig) { + const FunctionConfigTestCase& param = GetParam(); + + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + for (const Config::VariableConfig& variable_config : param.variable_configs) { + ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); + } + ASSERT_THAT(config.AddFunctionConfig(param.function_config), IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(param.expr)); + if (param.expected_error.empty()) { + EXPECT_TRUE(result.GetIssues().empty()) + << " expr: " << param.expr << " error: " << result.FormatError(); + } else { + EXPECT_THAT(result.GetIssues(), + UnorderedElementsAre(Property(&TypeCheckIssue::message, + HasSubstr(param.expected_error)))) + << " expr: " << param.expr << " error: " << result.FormatError(); + } +} + +std::vector GetFunctionConfigTestCases() { + return {{ + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "plus(int,int)", + .examples = {"add(1, 2) -> 3"}, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "int"}, + }, + }, + }, + .expr = "add(1, 2)", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "int.plus(int)", + .examples = {"1.add(2) -> 3"}, + .is_member_function = true, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "int"}, + }, + }, + }, + .expr = "1.add(2) == 3", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "plus(string,string)", + .examples = + {"add('hello', 'world') -> 'hello world'"}, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "string"}, + }, + }, + }, + .expr = "add('hello', 'world')", + .expected_error = "found no matching overload for 'add' applied to " + "'(string, string)'", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "int.plus(int)", + .examples = {"1.add(2) -> 'three'"}, + .is_member_function = true, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "string"}, + }, + }, + }, + .expr = "1.add(2) == 3", + .expected_error = "found no matching overload for '_==_' applied to " + "'(string, int)'", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "sum", + .description = "Sum a collection, which is an opaque type.", + .overload_configs = + { + { + .overload_id = "sum(collection)", + .examples = {"sum(my_collection) -> 100"}, + .parameters = {{.name = "collection", + .params = {{.name = "double"}}}}, + .return_type = {.name = "double"}, + }, + }, + }, + .variable_configs = + { + {.name = "my_collection", + .description = "Matching opaque type.", + .type_info = {.name = "collection", + .params = {{.name = "double"}}}}, + }, + .expr = "sum(my_collection) / 3.0", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "sum", + .description = "Sum a collection, which is an opaque type.", + .overload_configs = + { + { + .overload_id = "sum(collection)", + .examples = {"sum(my_collection) -> 100"}, + .parameters = {{.name = "collection", + .params = {{.name = "int"}}}}, + .return_type = {.name = "double"}, + }, + }, + }, + .variable_configs = + { + {.name = "my_collection", + .description = "Mismatched opaque type.", + .type_info = {.name = "collection", + .params = {{.name = "double"}}}}, + }, + .expr = "sum(my_collection) / 3.0", + .expected_error = "found no matching overload for 'sum' applied to " + "'(collection(double))'", + }, + }}; +} + +INSTANTIATE_TEST_SUITE_P(FunctionConfigTest, FunctionConfigTest, + ::testing::ValuesIn(GetFunctionConfigTestCases())); + +} // namespace +} // namespace cel diff --git a/env/env_yaml.cc b/env/env_yaml.cc new file mode 100644 index 000000000..281cf3ff1 --- /dev/null +++ b/env/env_yaml.cc @@ -0,0 +1,1322 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_yaml.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/ast.h" +#include "common/constant.h" +#include "common/signature.h" +#include "env/config.h" +#include "env/type_info.h" +#include "internal/status_macros.h" +#include "internal/strings.h" +#include "yaml-cpp/emitter.h" +#include "yaml-cpp/emittermanip.h" +#include "yaml-cpp/exceptions.h" +#include "yaml-cpp/mark.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/node/parse.h" +#include "yaml-cpp/null.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel { + +namespace { + +std::string FormatYamlErrorMessage(absl::string_view yaml, + absl::string_view error, + const YAML::Mark& mark) { + if (mark.is_null()) { + return std::string(error); + } + std::string message; + absl::StrAppend(&message, mark.line + 1, ":", mark.column + 1, ": ", error, + "\n|"); + size_t start = mark.pos - mark.column; + size_t end = yaml.find('\n', mark.pos); + if (end == std::string::npos) { + end = yaml.size(); + } + + absl::StrAppend(&message, yaml.substr(start, end - start), "\n|", + std::string(mark.column, ' '), "^"); + + return message; +} + +absl::StatusOr LoadYaml(const std::string& yaml) { + try { + return YAML::Load(yaml); + } catch (YAML::ParserException& e) { + return absl::InvalidArgumentError( + FormatYamlErrorMessage(yaml, e.msg, e.mark)); + } +} + +absl::Status YamlError(absl::string_view yaml, const YAML::Node& node, + absl::string_view error) { + return absl::InvalidArgumentError( + FormatYamlErrorMessage(yaml, error, node.Mark())); +} + +std::string GetString(absl::string_view yaml, const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar()) { + return ""; + } + try { + return node.as(); + } catch (YAML::Exception& e) { + // This should never happen since we already checked that the node is a + // scalar and all scalars can be converted to strings. + return ""; + } +} + +bool IsBinary(const YAML::Node& node) { + return node.Tag() == "!!binary" || node.Tag() == "tag:yaml.org,2002:binary"; +} + +absl::StatusOr GetBinary(absl::string_view yaml, + const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar() || !IsBinary(node)) { + return ""; + } + std::string binary; + // Instead of using the YAML::Binary type, we use absl::Base64Unescape + // because YAML::Binary is lenient to Base64 decoding errors. + if (absl::Base64Unescape(GetString(yaml, node), &binary)) { + return binary; + } else { + return YamlError(yaml, node, + absl::StrCat("Node '", GetString(yaml, node), + "' is not a valid Base64 encoded binary")); + } +} + +absl::StatusOr GetBool(absl::string_view yaml, absl::string_view key, + const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar()) { + return false; + } + try { + return node.as(); + } catch (YAML::Exception& e) { + return YamlError(yaml, node, + absl::StrCat("Node '", key, "' is not a boolean")); + } +} + +// Returns the key in the map `node` that has the given `value_node` as its +// value. If no such key exists, returns `value_node` itself. +YAML::Node GetContextNodeForKeyValue(const YAML::Node& node, + const YAML::Node& value_node) { + for (const auto& kv : node) { + if (kv.second.IsDefined() && kv.second.is(value_node)) { + return kv.first; + } + } + return value_node; +} + +absl::Status ParseName(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node name = root["name"]; + if (name.IsDefined()) { + if (!name.IsScalar()) { + return YamlError(yaml, name, "Node 'name' is not a string"); + } + config.SetName(GetString(yaml, name)); + } + return absl::OkStatus(); +} + +absl::Status ParseContainerConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node container = root["container"]; + if (!container.IsDefined()) { + return absl::OkStatus(); + } + + if (container.IsScalar()) { + config.SetContainerConfig({.name = GetString(yaml, container)}); + return absl::OkStatus(); + } + + if (!container.IsMap()) { + return YamlError(yaml, container, + "Node 'container' is neither a string nor a map"); + } + + Config::ContainerConfig container_config; + + const YAML::Node name = container["name"]; + if (name.IsDefined()) { + if (!name.IsScalar()) { + return YamlError(yaml, name, "Node 'name' in container is not a string"); + } + container_config.name = GetString(yaml, name); + } + + const YAML::Node abbreviations = container["abbreviations"]; + if (abbreviations.IsDefined()) { + if (!abbreviations.IsSequence()) { + return YamlError(yaml, abbreviations, + "Node 'abbreviations' is not a sequence"); + } + for (const YAML::Node& abbr : abbreviations) { + if (!abbr.IsScalar()) { + return YamlError(yaml, abbr, "Abbreviation is not a string"); + } + container_config.abbreviations.push_back(GetString(yaml, abbr)); + } + } + + const YAML::Node aliases = container["aliases"]; + if (aliases.IsDefined()) { + if (!aliases.IsSequence()) { + return YamlError(yaml, aliases, "Node 'aliases' is not a sequence"); + } + for (const YAML::Node& alias_node : aliases) { + if (!alias_node.IsMap()) { + return YamlError(yaml, alias_node, "Alias entry is not a map"); + } + const YAML::Node alias_key = alias_node["alias"]; + const YAML::Node qualified_name_key = alias_node["qualified_name"]; + + if (!alias_key.IsDefined() || !alias_key.IsScalar()) { + return YamlError(yaml, alias_node, + "Alias entry missing 'alias' string"); + } + if (!qualified_name_key.IsDefined() || !qualified_name_key.IsScalar()) { + return YamlError(yaml, alias_node, + "Alias entry missing 'qualified_name' string"); + } + + container_config.aliases.push_back( + {.alias = GetString(yaml, alias_key), + .qualified_name = GetString(yaml, qualified_name_key)}); + } + } + + config.SetContainerConfig(std::move(container_config)); + return absl::OkStatus(); +} + +absl::Status ParseExtensionConfigs(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node extensions = root["extensions"]; + if (!extensions.IsDefined()) { + return absl::OkStatus(); + } + if (!extensions.IsSequence()) { + return YamlError(yaml, extensions, "Node 'extensions' is not a sequence"); + } + + for (const YAML::Node& extension : extensions) { + if (!extension || !extension.IsMap()) { + return YamlError(yaml, extension, "Extension is not a map"); + } + const YAML::Node name = extension["name"]; + if (!name || !name.IsScalar()) { + return YamlError(yaml, name, "Extension name is not a string"); + } + std::string name_str = GetString(yaml, name); + + const YAML::Node version = extension["version"]; + std::string version_str = GetString(yaml, version); + int extension_version; + if (version.IsDefined()) { + bool is_valid_version = false; + if (version.IsScalar()) { + if (version_str == "latest") { + extension_version = Config::ExtensionConfig::kLatest; + is_valid_version = true; + } else { + if (absl::SimpleAtoi(version_str, &extension_version) && + extension_version >= 0) { + is_valid_version = true; + } + } + } + if (!is_valid_version) { + return YamlError( + yaml, version, + absl::StrCat("Extension '", name_str, + "' version is not a valid number or 'latest'")); + } + } else { + extension_version = Config::ExtensionConfig::kLatest; + } + absl::Status add_status = + config.AddExtensionConfig(name_str, extension_version); + if (!add_status.ok()) { + return YamlError(yaml, extension, add_status.message()); + } + } + return absl::OkStatus(); +} + +absl::StatusOr> ParseMacroList( + absl::string_view yaml, const YAML::Node& standard_library, + absl::string_view key) { + absl::flat_hash_set macro_set; + const YAML::Node macros = standard_library[std::string(key)]; + if (!macros.IsDefined()) { + return macro_set; + } + if (!macros.IsSequence()) { + return YamlError(yaml, macros, + absl::StrCat("Node '", key, "' is not a sequence")); + } + for (const YAML::Node& macro : macros) { + if (!macro.IsScalar()) { + return YamlError(yaml, macro, + absl::StrCat("Entry in '", key, "' is not a string")); + } + macro_set.insert(GetString(yaml, macro)); + } + return macro_set; +} + +absl::StatusOr>> +ParseFunctionList(absl::string_view yaml, const YAML::Node& standard_library, + absl::string_view key) { + absl::flat_hash_set> function_set; + const YAML::Node functions = standard_library[std::string(key)]; + if (!functions.IsDefined()) { + return function_set; + } + if (!functions.IsSequence()) { + return YamlError(yaml, functions, + absl::StrCat("Node '", key, "' is not a sequence")); + } + for (const YAML::Node& function : functions) { + if (!function.IsMap()) { + return YamlError(yaml, function, + absl::StrCat("Entry in '", key, "' is not a map")); + } + const YAML::Node name = function["name"]; + if (!name.IsDefined()) { + return YamlError( + yaml, function, + absl::StrCat("Function name in not specified in '", key, "'")); + } + if (!name.IsScalar()) { + return YamlError( + yaml, name, + absl::StrCat("Function name in '", key, "' entry is not a string")); + } + std::string name_str = GetString(yaml, name); + const YAML::Node overloads = function["overloads"]; + if (!overloads.IsDefined()) { + function_set.insert(std::make_pair(name_str, "")); + } else { + if (!overloads.IsSequence()) { + return YamlError( + yaml, overloads, + absl::StrCat("Overloads in '", key, "' entry is not a sequence")); + } + for (const YAML::Node& overload : overloads) { + if (!overload.IsMap()) { + return YamlError( + yaml, overload, + absl::StrCat("Overload in '", key, "' entry is not a map")); + } + const YAML::Node id = overload["id"]; + if (!id || !id.IsScalar()) { + return YamlError( + yaml, id, + absl::StrCat("Overload id in '", key, "' entry is not a string")); + } + function_set.insert(std::make_pair(name_str, GetString(yaml, id))); + } + } + } + return function_set; +} + +absl::Status ParseStandardLibraryConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node standard_library = root["stdlib"]; + if (!standard_library.IsDefined()) { + return absl::OkStatus(); + } + + if (!standard_library.IsMap()) { + return YamlError(yaml, standard_library, + "Standard library config ('stdlib') is not a map"); + } + + Config::StandardLibraryConfig standard_library_config; + + const YAML::Node disable = standard_library["disable"]; + if (disable.IsDefined()) { + if (!disable.IsScalar()) { + return YamlError(yaml, disable, "Node 'disable' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(standard_library_config.disable, + GetBool(yaml, "disable", disable)); + } + + const YAML::Node disable_macros = standard_library["disable_macros"]; + if (disable_macros.IsDefined()) { + if (!disable_macros.IsScalar()) { + return YamlError(yaml, disable_macros, + "Node 'disable_macros' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(standard_library_config.disable_macros, + GetBool(yaml, "disable_macros", disable_macros)); + } + + CEL_ASSIGN_OR_RETURN( + standard_library_config.included_macros, + ParseMacroList(yaml, standard_library, "include_macros")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.excluded_macros, + ParseMacroList(yaml, standard_library, "exclude_macros")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.included_functions, + ParseFunctionList(yaml, standard_library, "include_functions")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.excluded_functions, + ParseFunctionList(yaml, standard_library, "exclude_functions")); + + return config.SetStandardLibraryConfig(standard_library_config); +} + +absl::StatusOr ParseTypeInfo(const YAML::Node& node, + absl::string_view yaml) { + Config::TypeInfo type_config; + const YAML::Node type = node["type"]; + const YAML::Node type_name = node["type_name"]; + if (type.IsDefined() && type_name.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(node, type_name), + "Node 'type' and 'type_name' are mutually exclusive"); + } + + if (type.IsDefined()) { + if (!type.IsScalar()) { + return YamlError(yaml, type, "Node 'type' is not a string"); + } + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSpec(GetString(yaml, type))); + CEL_ASSIGN_OR_RETURN(auto type_config, TypeSpecToTypeInfo(type_spec)); + return type_config; + } + + if (!type_name.IsDefined()) { + return type_config; + } + if (!type_name || !type_name.IsScalar()) { + return YamlError(yaml, type_name, "Node 'type_name' is not a string"); + } + type_config.name = GetString(yaml, type_name); + + const YAML::Node is_type_param = node["is_type_param"]; + if (is_type_param.IsDefined()) { + if (!is_type_param.IsScalar()) { + return YamlError(yaml, is_type_param, + "Node 'is_type_param' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(type_config.is_type_param, + GetBool(yaml, "is_type_param", is_type_param)); + } + + const YAML::Node params = node["params"]; + if (!params.IsDefined()) { + return type_config; + } + if (!params.IsSequence()) { + return YamlError(yaml, params, "Node 'params' is not a sequence"); + } + for (const YAML::Node& param : params) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param_config, + ParseTypeInfo(param, yaml)); + type_config.params.push_back(param_config); + } + + return type_config; +} + +bool CompareTypeInfo(const Config::TypeInfo& a, const Config::TypeInfo& b) { + if (a.name != b.name) { + return a.name < b.name; + } + if (a.params.size() != b.params.size()) { + return a.params.size() < b.params.size(); + } + for (size_t i = 0; i < a.params.size(); ++i) { + if (CompareTypeInfo(a.params[i], b.params[i])) { + return true; + } + if (CompareTypeInfo(b.params[i], a.params[i])) { + return false; + } + } + return false; // They are equal +} + +ConstantKindCase GetConstantKindCase(absl::string_view type_name) { + static const auto kTypeNameToConstantKindCase = + absl::NoDestructor>({ + {"null", ConstantKindCase::kNull}, + {"bool", ConstantKindCase::kBool}, + {"int", ConstantKindCase::kInt}, + {"uint", ConstantKindCase::kUint}, + {"double", ConstantKindCase::kDouble}, + {"string", ConstantKindCase::kString}, + {"bytes", ConstantKindCase::kBytes}, + {"duration", ConstantKindCase::kDuration}, + {"timestamp", ConstantKindCase::kTimestamp}, + }); + if (auto it = kTypeNameToConstantKindCase->find(type_name); + it != kTypeNameToConstantKindCase->end()) { + return it->second; + } + return ConstantKindCase::kUnspecified; +} + +absl::StatusOr ParseConstantValue(absl::string_view yaml, + const YAML::Node& node, + ConstantKindCase constant_kind_case, + absl::string_view value) { + switch (constant_kind_case) { + case ConstantKindCase::kNull: + if (!value.empty()) { + return YamlError(yaml, node, "Failed to parse null constant"); + } + return Constant(nullptr); + case ConstantKindCase::kBool: + if (absl::EqualsIgnoreCase(value, "true")) { + return Constant(true); + } else if (absl::EqualsIgnoreCase(value, "false")) { + return Constant(false); + } else { + return YamlError(yaml, node, "Failed to parse bool constant"); + } + case ConstantKindCase::kInt: + int64_t int_value; + if (!absl::SimpleAtoi(value, &int_value)) { + return YamlError(yaml, node, "Failed to parse int constant"); + } + return Constant(int_value); + case ConstantKindCase::kUint: + uint64_t uint_value; + if (absl::EndsWith(value, "u")) { + value = value.substr(0, value.size() - 1); + } + if (!absl::SimpleAtoi(value, &uint_value)) { + return YamlError(yaml, node, "Failed to parse uint constant"); + } + return Constant(uint_value); + case ConstantKindCase::kDouble: + double double_value; + if (!absl::SimpleAtod(value, &double_value)) { + return YamlError(yaml, node, "Failed to parse double constant"); + } + return Constant(double_value); + case ConstantKindCase::kBytes: { + if (!IsBinary(node)) { + absl::StatusOr bytes_literal = + internal::ParseBytesLiteral(value); + if (bytes_literal.ok()) { + return Constant(BytesConstant(*bytes_literal)); + } + } + return Constant(BytesConstant(value)); + } + case ConstantKindCase::kString: + return Constant(StringConstant(value)); + case ConstantKindCase::kDuration: { + // Duration is deprecated as a builtin type, but still supported for + // compatibility. + absl::Duration duration_value; + if (!absl::ParseDuration(value, &duration_value)) { + return YamlError(yaml, node, "Failed to parse duration constant"); + } + return Constant(duration_value); + } + case ConstantKindCase::kTimestamp: { + // Timestamp is deprecated as a builtin type, but still supported for + // compatibility. + absl::Time timestamp_value; + std::string error; + // Format: YYYY-MM-DDThh:mm:ssZ + if (!absl::ParseTime("%Y-%m-%d%ET%H:%M:%E*SZ", value, ×tamp_value, + &error)) { + return YamlError( + yaml, node, + absl::StrCat("Failed to parse timestamp constant: ", error, + " supported format: YYYY-MM-DDThh:mm:ssZ")); + } + return Constant(timestamp_value); + } + default: + // This should never happen. + return YamlError(yaml, node, "Constant type is not supported"); + } +} + +absl::Status ParseVariableConfigs(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node variables = root["variables"]; + if (!variables.IsDefined()) { + return absl::OkStatus(); + } + if (!variables.IsSequence()) { + return YamlError(yaml, variables, "Node 'variables' is not a sequence"); + } + + for (const YAML::Node& variable : variables) { + Config::VariableConfig variable_config; + if (!variable || !variable.IsMap()) { + return YamlError(yaml, variable, "Variable is not a map"); + } + const YAML::Node name = variable["name"]; + if (!name || !name.IsScalar()) { + return YamlError(yaml, name, "Variable name is not a string"); + } + variable_config.name = GetString(yaml, name); + const YAML::Node description = variable["description"]; + if (description.IsDefined()) { + if (!description.IsScalar()) { + return YamlError(yaml, description, + "Variable description is not a string"); + } + variable_config.description = GetString(yaml, description); + } + const YAML::Node type = variable["type"]; + Config::TypeInfo type_info; + if (type.IsDefined() && !type.IsScalar()) { + // Old format, type spec is in 'type' instead of directly embedded. + CEL_ASSIGN_OR_RETURN(type_info, ParseTypeInfo(variable["type"], yaml)); + } else { + CEL_ASSIGN_OR_RETURN(type_info, ParseTypeInfo(variable, yaml)); + } + ConstantKindCase constant_kind_case = GetConstantKindCase(type_info.name); + std::string value_str; + YAML::Node value = variable["value"]; + if (value.IsDefined()) { + if (constant_kind_case == ConstantKindCase::kUnspecified) { + return YamlError(yaml, value, + absl::StrCat("Constant type '", type_info.name, + "' is not supported")); + } + if (!value.IsScalar()) { + return YamlError(yaml, value, "Variable value is not a scalar"); + } + if (IsBinary(value)) { + CEL_ASSIGN_OR_RETURN(value_str, GetBinary(yaml, value)); + } else { + value_str = GetString(yaml, value); + } + } + + variable_config.type_info = type_info; + + if (constant_kind_case != ConstantKindCase::kUnspecified && + !value_str.empty()) { + CEL_ASSIGN_OR_RETURN( + variable_config.value, + ParseConstantValue(yaml, value, constant_kind_case, value_str)); + } else if (constant_kind_case == ConstantKindCase::kNull) { + variable_config.value = Constant(nullptr); + } + + CEL_RETURN_IF_ERROR(config.AddVariableConfig(variable_config)); + } + return absl::OkStatus(); +} + +absl::StatusOr ParseFunctionOverloadConfig( + absl::string_view yaml, const YAML::Node& overload, + absl::string_view function_name) { + Config::FunctionOverloadConfig overload_config; + if (!overload || !overload.IsMap()) { + return YamlError(yaml, overload, "Function overload is not a map"); + } + const YAML::Node id = overload["id"]; + if (id.IsDefined()) { + if (!id.IsScalar()) { + return YamlError(yaml, id, "Function overload id is not a string"); + } + overload_config.overload_id = GetString(yaml, id); + } + const YAML::Node examples = overload["examples"]; + if (examples.IsDefined()) { + if (!examples.IsSequence()) { + return YamlError(yaml, examples, + "Function overload examples is not a sequence"); + } + for (const YAML::Node& example : examples) { + if (!example.IsScalar()) { + return YamlError(yaml, example, + "Function overload example is not a string"); + } + overload_config.examples.push_back(GetString(yaml, example)); + } + } + + const YAML::Node signature_node = overload["signature"]; + const YAML::Node target = overload["target"]; + const YAML::Node args = overload["args"]; + if (signature_node.IsDefined()) { + if (!signature_node.IsScalar()) { + return YamlError(yaml, signature_node, + "Function overload signature is not a string"); + } + + if (target.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(overload, target), + "Function overload signature and target are mutually " + "exclusive"); + } + if (args.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(overload, args), + "Function overload signature and args are mutually " + "exclusive"); + } + + std::string signature = GetString(yaml, signature_node); + CEL_ASSIGN_OR_RETURN(ParsedFunctionOverload parsed_signature, + ParseFunctionSignature(signature)); + if (parsed_signature.function_name != function_name) { + return YamlError(yaml, signature_node, + absl::StrCat("Function overload name \"", + parsed_signature.function_name, + "\" does not match function name \"", + function_name, "\"")); + } + overload_config.is_member_function = parsed_signature.is_member; + if (overload_config.overload_id.empty()) { + overload_config.overload_id = signature; + } + if (!parsed_signature.signature_type.has_function()) { + return absl::InternalError(absl::StrCat( + "Function overload signature has no function type: ", signature)); + } + const FunctionTypeSpec& function_type_spec = + parsed_signature.signature_type.function(); + for (const auto& arg : function_type_spec.arg_types()) { + CEL_ASSIGN_OR_RETURN(auto type_info, TypeSpecToTypeInfo(arg)); + overload_config.parameters.push_back(std::move(type_info)); + } + } else { + if (target.IsDefined()) { + if (!target.IsMap()) { + return YamlError(yaml, target, "Function overload target is not a map"); + } + CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, + ParseTypeInfo(target, yaml)); + overload_config.is_member_function = true; + overload_config.parameters.push_back(type_info); + } + + if (args.IsDefined()) { + if (!args.IsSequence()) { + return YamlError(yaml, args, + "Function overload args is not a sequence"); + } + for (const YAML::Node& arg : args) { + if (!arg.IsMap()) { + return YamlError(yaml, arg, "Function overload arg is not a map"); + } + CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, + ParseTypeInfo(arg, yaml)); + overload_config.parameters.push_back(type_info); + } + } + } + const YAML::Node return_type = overload["return"]; + if (return_type.IsDefined()) { + if (return_type.IsScalar()) { + CEL_ASSIGN_OR_RETURN(auto type_spec, + ParseTypeSpec(GetString(yaml, return_type))); + CEL_ASSIGN_OR_RETURN(overload_config.return_type, + TypeSpecToTypeInfo(type_spec)); + } else if (return_type.IsMap()) { + CEL_ASSIGN_OR_RETURN(overload_config.return_type, + ParseTypeInfo(return_type, yaml)); + } else { + return YamlError( + yaml, return_type, + "Function overload return type is neither a string nor a map"); + } + } + return overload_config; +} + +absl::Status ParseFunctionConfigs(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node functions = root["functions"]; + if (!functions.IsDefined()) { + return absl::OkStatus(); + } + if (!functions.IsSequence()) { + return YamlError(yaml, functions, "Node 'functions' is not a sequence"); + } + + for (const YAML::Node& function : functions) { + Config::FunctionConfig function_config; + if (!function || !function.IsMap()) { + return YamlError(yaml, function, "Function is not a map"); + } + const YAML::Node name = function["name"]; + if (!name || !name.IsScalar()) { + return YamlError(yaml, name, "Function name is not a string"); + } + function_config.name = GetString(yaml, name); + const YAML::Node description = function["description"]; + if (description.IsDefined()) { + if (!description.IsScalar()) { + return YamlError(yaml, description, + "Function description is not a string"); + } + function_config.description = GetString(yaml, description); + } + const YAML::Node overloads = function["overloads"]; + if (overloads.IsDefined()) { + if (!overloads.IsSequence()) { + return YamlError(yaml, overloads, + "Function 'overloads' item is not a sequence"); + } + + for (const YAML::Node& overload : overloads) { + CEL_ASSIGN_OR_RETURN( + Config::FunctionOverloadConfig overload_config, + ParseFunctionOverloadConfig(yaml, overload, function_config.name)); + function_config.overload_configs.push_back(std::move(overload_config)); + } + } + + CEL_RETURN_IF_ERROR(config.AddFunctionConfig(function_config)); + } + return absl::OkStatus(); +} + +void EmitContainerConfig(const Config& env_config, YAML::Emitter& out) { + const auto& container_config = env_config.GetContainerConfig(); + if (container_config.IsEmpty()) { + return; + } + + out << YAML::Key << "container"; + if (container_config.abbreviations.empty() && + container_config.aliases.empty()) { + out << YAML::Value << YAML::DoubleQuoted << container_config.name; + } else { + out << YAML::Value << YAML::BeginMap; + if (!container_config.name.empty()) { + out << YAML::Key << "name" << YAML::Value << YAML::DoubleQuoted + << container_config.name; + } + if (!container_config.abbreviations.empty()) { + std::vector sorted_abbrs = container_config.abbreviations; + absl::c_sort(sorted_abbrs); + out << YAML::Key << "abbreviations" << YAML::Value << YAML::BeginSeq; + for (const auto& abbr : sorted_abbrs) { + out << YAML::Value << YAML::DoubleQuoted << abbr; + } + out << YAML::EndSeq; + } + if (!container_config.aliases.empty()) { + std::vector sorted_aliases = + container_config.aliases; + absl::c_sort(sorted_aliases, [](const Config::ContainerConfig::Alias& a, + const Config::ContainerConfig::Alias& b) { + return a.alias < b.alias; + }); + out << YAML::Key << "aliases" << YAML::Value << YAML::BeginSeq; + for (const auto& alias : sorted_aliases) { + out << YAML::BeginMap; + out << YAML::Key << "alias" << YAML::Value << YAML::DoubleQuoted + << alias.alias; + out << YAML::Key << "qualified_name" << YAML::Value + << YAML::DoubleQuoted << alias.qualified_name; + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + out << YAML::EndMap; + } +} + +void EmitExtensionConfigs(const Config& env_config, YAML::Emitter& out) { + if (env_config.GetExtensionConfigs().empty()) { + return; + } + + // Sort the extensions to make the output deterministic. + std::vector sorted_extensions = + env_config.GetExtensionConfigs(); + absl::c_sort(sorted_extensions, [](const Config::ExtensionConfig& a, + const Config::ExtensionConfig& b) { + return a.name < b.name; + }); + out << YAML::Key << "extensions"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::ExtensionConfig& extension_config : sorted_extensions) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << extension_config.name; + if (extension_config.version != Config::ExtensionConfig::kLatest) { + out << YAML::Key << "version"; + out << YAML::Value << extension_config.version; + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +void EmitMacroList(YAML::Emitter& out, absl::string_view key, + const absl::flat_hash_set& macros) { + if (macros.empty()) { + return; + } + out << YAML::Key << std::string(key); + out << YAML::Value << YAML::BeginSeq; + std::vector sorted_macros(macros.begin(), macros.end()); + absl::c_sort(sorted_macros); + for (const std::string& macro : sorted_macros) { + out << YAML::Value << YAML::DoubleQuoted << macro; + } + out << YAML::EndSeq; +} + +void EmitFunctionList( + YAML::Emitter& out, absl::string_view key, + const absl::flat_hash_set>& functions) { + if (functions.empty()) { + return; + } + + // Build a map from function name to a vector of overload ids. + // Using std::map ensures function names are sorted. + std::map> function_overloads; + for (const auto& pair : functions) { + function_overloads[pair.first].push_back(pair.second); + } + + out << YAML::Key << std::string(key) << YAML::Value << YAML::BeginSeq; + for (auto const& [name, overloads] : function_overloads) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << name; + + // If the only overload is the empty string, it signifies that all overloads + // of the function are included/excluded. In this case, we don't emit the + // "overloads" key. Otherwise, emit the specific overloads. + if (!(overloads.size() == 1 && overloads[0].empty())) { + // Sort overloads for deterministic output. + std::vector sorted_overloads = overloads; + absl::c_sort(sorted_overloads); + + out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq; + for (const std::string& overload : sorted_overloads) { + out << YAML::BeginMap; + out << YAML::Key << "id"; + out << YAML::Value << YAML::DoubleQuoted << overload; + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +void EmitStandardLibraryConfig(const Config& env_config, YAML::Emitter& out) { + const Config::StandardLibraryConfig& standard_library_config = + env_config.GetStandardLibraryConfig(); + if (standard_library_config.IsEmpty()) { + return; + } + + out << YAML::Key << "stdlib" << YAML::Value << YAML::BeginMap; + if (standard_library_config.disable) { + out << YAML::Key << "disable" << YAML::Value << true; + } + if (standard_library_config.disable_macros) { + out << YAML::Key << "disable_macros" << YAML::Value << true; + } + EmitMacroList(out, "include_macros", standard_library_config.included_macros); + EmitMacroList(out, "exclude_macros", standard_library_config.excluded_macros); + EmitFunctionList(out, "include_functions", + standard_library_config.included_functions); + EmitFunctionList(out, "exclude_functions", + standard_library_config.excluded_functions); + out << YAML::EndMap; +} + +void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { + // Note: the map is already started when this is called, so we don't emit + // BeginMap here or EndMap at the end. + bool signature_generated = false; + if (options.use_type_signatures) { + absl::StatusOr type_spec = TypeInfoToTypeSpec(type_info); + if (type_spec.ok()) { + absl::StatusOr signature = MakeTypeSpecSignature(*type_spec); + if (signature.ok()) { + out << YAML::Key << "type"; + out << YAML::Value << YAML::DoubleQuoted << *signature; + signature_generated = true; + } + } + } + if (!signature_generated) { + out << YAML::Key << "type_name"; + out << YAML::Value << YAML::DoubleQuoted << type_info.name; + if (type_info.is_type_param) { + out << YAML::Key << "is_type_param" << YAML::Value << true; + } + if (!type_info.params.empty()) { + out << YAML::Key << "params" << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& param : type_info.params) { + out << YAML::BeginMap; + EmitTypeInfo(param, out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + } +} + +void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { + const auto& variable_configs = env_config.GetVariableConfigs(); + if (variable_configs.empty()) { + return; + } + + // Sort variable_configs by name to ensure deterministic output. + std::vector sorted_variable_configs = + variable_configs; + absl::c_sort(sorted_variable_configs, + [](const Config::VariableConfig& a, + const Config::VariableConfig& b) { return a.name < b.name; }); + + out << YAML::Key << "variables"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::VariableConfig& variable_config : + sorted_variable_configs) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << variable_config.name; + if (!variable_config.description.empty()) { + out << YAML::Key << "description"; + out << YAML::Value << YAML::DoubleQuoted << variable_config.description; + } + EmitTypeInfo(variable_config.type_info, out, options); + if (variable_config.value.has_value()) { + const Constant& constant = variable_config.value; + switch (constant.kind_case()) { + case ConstantKindCase::kUnspecified: + case ConstantKindCase::kNull: + break; + case ConstantKindCase::kBool: + out << YAML::Key << "value" << YAML::Value << constant.bool_value(); + break; + case ConstantKindCase::kInt: + out << YAML::Key << "value" << YAML::Value << constant.int_value(); + break; + case ConstantKindCase::kUint: + out << YAML::Key << "value" << YAML::Value << constant.uint_value(); + break; + case ConstantKindCase::kDouble: + out << YAML::Key << "value" << YAML::Value << constant.double_value(); + break; + case ConstantKindCase::kBytes: { + out << YAML::Key << "value"; + const std::string& bytes_value = constant.bytes_value(); + std::string hex_escaped = "b\""; + for (unsigned char byte : bytes_value) { + absl::StrAppend(&hex_escaped, "\\x"); + absl::StrAppendFormat(&hex_escaped, "%02x", byte); + } + absl::StrAppend(&hex_escaped, "\""); + out << YAML::Value << hex_escaped; + break; + } + case ConstantKindCase::kString: + out << YAML::Key << "value"; + out << YAML::Value << YAML::DoubleQuoted << constant.string_value(); + break; + case ConstantKindCase::kDuration: + out << YAML::Key << "value" << YAML::Value; + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + out << absl::FormatDuration(constant.duration_value()); + break; + case ConstantKindCase::kTimestamp: + out << YAML::Key << "value" << YAML::Value; + out << absl::FormatTime( + "%Y-%m-%d%ET%H:%M:%E*SZ", + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + constant.timestamp_value(), absl::UTCTimeZone()); + break; + } + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +void EmitFunctionOverloadConfig( + absl::string_view function_name, + const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { + out << YAML::BeginMap; + bool signature_generated = false; + std::string signature_str; + if (options.use_type_signatures) { + bool param_type_spec_generated = true; + std::vector params; + params.reserve(overload_config.parameters.size()); + for (const auto& parameter : overload_config.parameters) { + absl::StatusOr type_spec = TypeInfoToTypeSpec(parameter); + if (!type_spec.ok()) { + param_type_spec_generated = false; + break; + } + params.push_back(std::move(*type_spec)); + } + if (param_type_spec_generated) { + absl::StatusOr signature = MakeOverloadSignature( + function_name, params, overload_config.is_member_function); + if (signature.ok()) { + signature_str = std::move(*signature); + signature_generated = true; + } + } + } + if (!overload_config.overload_id.empty()) { + if (!signature_generated || overload_config.overload_id != signature_str) { + out << YAML::Key << "id"; + out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; + } + } + if (signature_generated) { + out << YAML::Key << "signature"; + out << YAML::Value << YAML::DoubleQuoted << signature_str; + } + if (!signature_generated) { + if (overload_config.is_member_function) { + out << YAML::Key << "target" << YAML::Value; + out << YAML::BeginMap; + if (overload_config.parameters.empty()) { + // This should never happen, but if it does, emit a dynamic type. + EmitTypeInfo({.name = "dyn"}, out, options); + } else { + EmitTypeInfo(overload_config.parameters[0], out, options); + } + out << YAML::EndMap; + if (overload_config.parameters.size() > 1) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (size_t i = 1; i < overload_config.parameters.size(); ++i) { + out << YAML::BeginMap; + EmitTypeInfo(overload_config.parameters[i], out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + } else { + if (!overload_config.parameters.empty()) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& parameter : overload_config.parameters) { + out << YAML::BeginMap; + EmitTypeInfo(parameter, out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + } + } + bool return_type_signature_generated = false; + if (options.use_type_signatures) { + absl::StatusOr type_spec = + TypeInfoToTypeSpec(overload_config.return_type); + if (type_spec.ok()) { + absl::StatusOr signature = MakeTypeSpecSignature(*type_spec); + if (signature.ok()) { + out << YAML::Key << "return"; + out << YAML::Value << YAML::DoubleQuoted << *signature; + return_type_signature_generated = true; + } + } + } + if (!return_type_signature_generated) { + out << YAML::Key << "return"; + out << YAML::Value << YAML::BeginMap; + EmitTypeInfo(overload_config.return_type, out, options); + out << YAML::EndMap; + } + out << YAML::EndMap; +} + +void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { + const std::vector& function_configs = + env_config.GetFunctionConfigs(); + if (function_configs.empty()) { + return; + } + + // Sort function_configs by name to ensure deterministic output. + std::vector sorted_function_configs = + function_configs; + absl::c_sort(sorted_function_configs, + [](const Config::FunctionConfig& a, + const Config::FunctionConfig& b) { return a.name < b.name; }); + + out << YAML::Key << "functions"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::FunctionConfig& function_config : + sorted_function_configs) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << function_config.name; + if (!function_config.description.empty()) { + out << YAML::Key << "description"; + out << YAML::Value << YAML::DoubleQuoted << function_config.description; + } + if (!function_config.overload_configs.empty()) { + // Sort overloads for deterministic output. + std::vector sorted_overloads = + function_config.overload_configs; + absl::c_sort(sorted_overloads, + [](const Config::FunctionOverloadConfig& a, + const Config::FunctionOverloadConfig& b) { + for (size_t i = 0; i < a.parameters.size(); ++i) { + // Order like this: foo(a), foo(a, b) + if (i >= b.parameters.size()) { + return false; + } + if (CompareTypeInfo(a.parameters[i], b.parameters[i])) { + return true; + } + if (CompareTypeInfo(b.parameters[i], a.parameters[i])) { + return false; + } + } + return false; + }); + + out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq; + for (const Config::FunctionOverloadConfig& overload_config : + sorted_overloads) { + EmitFunctionOverloadConfig(function_config.name, overload_config, out, + options); + } + out << YAML::EndSeq; + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +absl::Status ParseContextVariableConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node context_variable = root["context_variable"]; + if (!context_variable.IsDefined()) { + return absl::OkStatus(); + } + if (!context_variable.IsMap()) { + return YamlError(yaml, context_variable, + "Node 'context_variable' is not a map"); + } + + const YAML::Node type_name = context_variable["type_name"]; + const YAML::Node type = context_variable["type"]; + const YAML::Node* type_node = nullptr; + if (type.IsDefined() && type.IsScalar()) { + type_node = &type; + } else if (type_name.IsDefined() && type_name.IsScalar()) { + type_node = &type_name; + } else { + return YamlError(yaml, context_variable, + "Node 'context_variable' does not have a valid type"); + } + ABSL_DCHECK(type_node != nullptr); + config.SetContextType(GetString(yaml, *type_node)); + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { + Config config; + CEL_ASSIGN_OR_RETURN(YAML::Node root, LoadYaml(yaml)); + if (!root.IsDefined() || root.IsNull()) { + return config; + } + + if (!root.IsMap()) { + return absl::InvalidArgumentError(FormatYamlErrorMessage( + yaml, "Invalid CEL environment config YAML", root.Mark())); + } + + CEL_RETURN_IF_ERROR(ParseName(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseContainerConfig(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseExtensionConfigs(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseStandardLibraryConfig(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseContextVariableConfig(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseVariableConfigs(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseFunctionConfigs(config, yaml, root)); + return config; +} + +void EnvConfigToYaml(const Config& env_config, std::ostream& os, + const EnvConfigToYamlOptions& options) { + YAML::Emitter out(os); + out.SetIndent(2); + out << YAML::BeginMap; + if (!env_config.GetName().empty()) { + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << env_config.GetName(); + } + EmitContainerConfig(env_config, out); + EmitExtensionConfigs(env_config, out); + EmitStandardLibraryConfig(env_config, out); + EmitVariableConfigs(env_config, out, options); + EmitFunctionConfigs(env_config, out, options); + out << YAML::EndMap; +} + +} // namespace cel diff --git a/env/env_yaml.h b/env/env_yaml.h new file mode 100644 index 000000000..7bf7bf6b4 --- /dev/null +++ b/env/env_yaml.h @@ -0,0 +1,74 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ +#define THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "env/config.h" + +namespace cel { + +// EnvConfigFromYaml creates an environment configuration from a YAML string. +// +// To ensure safety, only pass trusted YAML input. yaml-cpp has some fuzz +// coverage, but its security model is unclear. Additionally, callers should be +// aware that improper CEL configuration can lead to unsafe or unpredictably +// expensive expressions. +absl::StatusOr EnvConfigFromYaml(const std::string& yaml); + +struct EnvConfigToYamlOptions { + // Whether to use type and overload signatures instead of arg/return types in + // the output YAML. + // Example of type signature: "map>" vs + // type_name: "map" + // params: + // - type_name: "int" + // - type_name: "A" + // params: + // - type_name: "B" + // is_type_param: true + // + // Example of overload signature config: + // name: "foo" + // overloads: + // - signature: "timestamp.foo(A<~B>)" + // return: "int" + // vs + // name: "foo" + // overloads: + // - id: "foo_id" + // target: + // type_name: "timestamp" + // args: + // - type_name: "A" + // params: + // - type_name: "B" + // is_type_param: true + // return: + // type_name: "int" + // TODO(uncreated-issue/91): default to true after all dependencies are updated + bool use_type_signatures = false; +}; + +// EnvConfigToYaml serializes an environment configuration as a YAML string. +void EnvConfigToYaml(const Config& env_config, std::ostream& os, + const EnvConfigToYamlOptions& options = {}); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc new file mode 100644 index 000000000..c5bd1b787 --- /dev/null +++ b/env/env_yaml_test.cc @@ -0,0 +1,1949 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_yaml.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; +using ::testing::AllOf; +using ::testing::ElementsAreArray; +using ::testing::Field; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +TEST(EnvYamlTest, ParseContainerConfig) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + container: "test.container" + )yaml")); + + EXPECT_THAT(config.GetContainerConfig(), + Field(&Config::ContainerConfig::name, "test.container")); +} + +TEST(EnvYamlTest, ParseContainerConfig_AlternativeSyntax) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + container: + name: test.container + abbreviations: + - abbr1.Abbr1 + - abbr2.Abbr2 + aliases: + - alias: alias1 + qualified_name: qual.name1 + - alias: alias2 + qualified_name: qual.name2 + )yaml")); + + const auto& container_config = config.GetContainerConfig(); + EXPECT_EQ(container_config.name, "test.container"); + EXPECT_THAT(container_config.abbreviations, + UnorderedElementsAre("abbr1.Abbr1", "abbr2.Abbr2")); + ASSERT_THAT(container_config.aliases, SizeIs(2)); + EXPECT_EQ(container_config.aliases[0].alias, "alias1"); + EXPECT_EQ(container_config.aliases[0].qualified_name, "qual.name1"); + EXPECT_EQ(container_config.aliases[1].alias, "alias2"); + EXPECT_EQ(container_config.aliases[1].qualified_name, "qual.name2"); +} + +TEST(EnvYamlTest, ParseExtensionConfigs) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + extensions: + - name: "math" + version: latest + - name: "optional" + version: 2 + - name: "strings" + )yaml")); + + EXPECT_THAT(config.GetExtensionConfigs(), + UnorderedElementsAre( + AllOf(Field(&Config::ExtensionConfig::name, "math"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)), + AllOf(Field(&Config::ExtensionConfig::name, "optional"), + Field(&Config::ExtensionConfig::version, 2)), + AllOf(Field(&Config::ExtensionConfig::name, "strings"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)))); +} + +TEST(EnvYamlTest, DefaultExtensionConfigs) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + )yaml")); + + EXPECT_THAT(config.GetExtensionConfigs(), IsEmpty()); +} + +TEST(EnvYamlTest, ParseStdlibConfig_ExclusionStyle) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + stdlib: + disable: true + disable_macros: true + exclude_macros: + - map + - filter + exclude_functions: + - name: "_+_" + overloads: + - id: add_bytes + - id: add_list + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml")); + + const auto& stdlib_config = config.GetStandardLibraryConfig(); + EXPECT_TRUE(stdlib_config.disable); + EXPECT_TRUE(stdlib_config.disable_macros); + EXPECT_THAT(stdlib_config.excluded_macros, + UnorderedElementsAre("map", "filter")); + EXPECT_THAT(stdlib_config.included_macros, IsEmpty()); + EXPECT_THAT( + stdlib_config.excluded_functions, + UnorderedElementsAre(std::make_pair("_+_", "add_bytes"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("timestamp", "string_to_timestamp"))) + << " Actual stdlib config: " << stdlib_config; +} + +TEST(EnvYamlTest, ParseStdlibConfig_InclusionStyle) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + stdlib: + include_macros: + - map + - filter + include_functions: + - name: "_+_" + overloads: + - id: add_bytes + - id: "_+_(list<~A>,list<~A>)" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml")); + + const auto& stdlib_config = config.GetStandardLibraryConfig(); + EXPECT_THAT(stdlib_config.included_macros, + UnorderedElementsAre("map", "filter")); + EXPECT_THAT( + stdlib_config.included_functions, + UnorderedElementsAre(std::make_pair("_+_", "add_bytes"), + std::make_pair("_+_", "_+_(list<~A>,list<~A>)"), + std::make_pair("matches", ""), + std::make_pair("timestamp", "string_to_timestamp"))) + << " Actual stdlib config: " << stdlib_config; +} + +TEST(EnvYamlTest, ParseVariableConfigs) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "msg" + type_name: "google.expr.proto3.test.TestAllTypes" + description: >- + msg represents all possible type permutation which + CEL understands from a proto perspective + )yaml")); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "msg"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "google.expr.proto3.test.TestAllTypes"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, IsEmpty()); + EXPECT_EQ(variable_config.description, + "msg represents all possible type permutation which CEL " + "understands from a proto perspective"); +} + +TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "dict" + type: "map" + )yaml")); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "dict"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "map"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, SizeIs(2)); + EXPECT_EQ(type_info.params[0].name, "string"); + EXPECT_FALSE(type_info.params[0].is_type_param); + EXPECT_THAT(type_info.params[0].params, IsEmpty()); + EXPECT_EQ(type_info.params[1].name, "A"); + EXPECT_TRUE(type_info.params[1].is_type_param); + EXPECT_THAT(type_info.params[1].params, IsEmpty()); +} + +TEST(EnvYamlTest, ParseContextVariableConfig) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + context_variable: + type_name: "cel.expr.conformance.proto3.TestAllTypes" + )yaml")); + + EXPECT_EQ(config.GetContextType(), + "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(EnvYamlTest, ParseContextVariableConfigAlternativeSyntax) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + context_variable: + type: "cel.expr.conformance.proto3.TestAllTypes" + )yaml")); + + EXPECT_EQ(config.GetContextType(), + "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(EnvYamlTest, ParseContextVariableMalformedContextVariable) { + EXPECT_THAT(EnvConfigFromYaml(R"yaml( + context_variable: 123 + + )yaml"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Node 'context_variable' is not a map"))); +} + +TEST(EnvYamlTest, ParseContextVariableMalformedContextVariable2) { + EXPECT_THAT( + EnvConfigFromYaml(R"yaml( + context_variable: + type: + foo: bar + )yaml"), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("Node 'context_variable' does not have a valid type"))); +} + +TEST(EnvYamlTest, ParseVariableConfigWithTypeParamsLegacySyntax) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "dict" + type_name: "map" + params: + - type_name: "string" + - type_name: "A" + is_type_param: true + )yaml")); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "dict"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "map"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, SizeIs(2)); + EXPECT_EQ(type_info.params[0].name, "string"); + EXPECT_FALSE(type_info.params[0].is_type_param); + EXPECT_THAT(type_info.params[0].params, IsEmpty()); + EXPECT_EQ(type_info.params[1].name, "A"); + EXPECT_TRUE(type_info.params[1].is_type_param); + EXPECT_THAT(type_info.params[1].params, IsEmpty()); +} + +TEST(EnvYamlTest, ParseVariableConfigWithNestedRuleOldFormat) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "x" + type: + type_name: "int" + )yaml")); + + ASSERT_THAT(config.GetVariableConfigs(), SizeIs(1)); + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "x"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "int"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, IsEmpty()); +} + +struct ParseConstantTestCase { + std::string type; + std::string value; + std::string expected_error; // Empty if no error. + Constant expected_constant; +}; + +class EnvYamlParseConstantTest + : public testing::TestWithParam {}; + +TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { + const ParseConstantTestCase& param = GetParam(); + const std::string yaml = absl::StrFormat( + R"yaml( + variables: + - name: "const" + type: "%s" + value: %s + )yaml", + param.type, param.value); + absl::StatusOr status_or_config = EnvConfigFromYaml(yaml); + if (!param.expected_error.empty()) { + EXPECT_THAT(status_or_config, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + return; + } + ASSERT_OK_AND_ASSIGN(Config config, status_or_config); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "const"); + EXPECT_EQ(variable_config.type_info.name, param.type) << " yaml: " << yaml; + EXPECT_EQ(variable_config.value, param.expected_constant) + << " yaml: " << yaml; +} + +std::vector GetParseConstantTestCases() { + return { + ParseConstantTestCase{ + .type = "null", + .value = "\"\"", + .expected_constant = Constant(nullptr), + }, + ParseConstantTestCase{ + .type = "null", + .value = "anything", + .expected_error = "Failed to parse null constant", + }, + ParseConstantTestCase{ + .type = "bool", + .value = "TRUE", + .expected_constant = Constant(true), + }, + ParseConstantTestCase{ + .type = "bool", + .value = "false", + .expected_constant = Constant(false), + }, + ParseConstantTestCase{ + .type = "bool", + .value = "yes", + .expected_error = "Failed to parse bool constant", + }, + ParseConstantTestCase{ + .type = "int", + .value = "42", + .expected_constant = Constant(int64_t{42}), + }, + ParseConstantTestCase{ + .type = "int", + .value = "41.999", + .expected_error = "Failed to parse int constant", + }, + ParseConstantTestCase{ + .type = "uint", + .value = "42", + .expected_constant = Constant(uint64_t{42}), + }, + ParseConstantTestCase{ + .type = "uint", + .value = "42u", + .expected_constant = Constant(uint64_t{42}), + }, + ParseConstantTestCase{ + .type = "uint", + .value = "-1", + .expected_error = "Failed to parse uint constant", + }, + ParseConstantTestCase{ + .type = "double", + .value = "42.42", + .expected_constant = Constant(42.42), + }, + ParseConstantTestCase{ + .type = "double", + .value = "abc", + .expected_error = "Failed to parse double constant", + }, + ParseConstantTestCase{ + .type = "bytes", + .value = "abc", + .expected_constant = Constant(BytesConstant("abc")), + }, + ParseConstantTestCase{ + .type = "bytes", + .value = "b\"\\xFF\\x00\\x01\"", + .expected_constant = + Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), + }, + ParseConstantTestCase{ + .type = "bytes", + .value = "!!binary /wAB", + .expected_constant = + Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), + }, + ParseConstantTestCase{ + .type = "bytes", + .value = "!!binary YWJj=", + .expected_error = "Node 'YWJj=' is not a valid Base64 encoded binary", + }, + ParseConstantTestCase{ + .type = "bytes", + .value = "abc", + .expected_constant = Constant(BytesConstant("abc")), + }, + ParseConstantTestCase{ + .type = "string", + .value = "abc", + .expected_constant = Constant(StringConstant("abc")), + }, + ParseConstantTestCase{ + .type = "string", + .value = "\"\\\"abc\\\"\"", + .expected_constant = Constant(StringConstant("\"abc\"")), + }, + ParseConstantTestCase{ + .type = "duration", + .value = "1s", + .expected_constant = Constant(absl::Seconds(1)), + }, + ParseConstantTestCase{ + .type = "duration", + .value = "abc", + .expected_error = "Failed to parse duration constant", + }, + ParseConstantTestCase{ + .type = "timestamp", + .value = "2023-01-01T00:00:00Z", + .expected_constant = Constant(absl::FromUnixSeconds(1672531200)), + }, + ParseConstantTestCase{ + .type = "timestamp", + .value = "abc", + .expected_error = "Failed to parse timestamp constant", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlParseConstantTest, EnvYamlParseConstantTest, + ::testing::ValuesIn(GetParseConstantTestCases())); + +struct ParseFunctionTestCase { + std::string yaml; + Config::FunctionConfig expected_function_config; +}; + +class EnvYamlParseFunctionTest + : public testing::TestWithParam {}; + +void ExpectTypeInfoEqual(const Config::TypeInfo& actual, + const Config::TypeInfo& expected) { + EXPECT_EQ(actual.name, expected.name); + EXPECT_EQ(actual.is_type_param, expected.is_type_param); + ASSERT_THAT(actual.params, SizeIs(expected.params.size())); + for (size_t i = 0; i < expected.params.size(); ++i) { + ExpectTypeInfoEqual(actual.params[i], expected.params[i]); + } +} + +TEST_P(EnvYamlParseFunctionTest, EnvYamlParseFunction) { + const ParseFunctionTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(param.yaml)); + + ASSERT_THAT(config.GetFunctionConfigs(), SizeIs(1)); + const Config::FunctionConfig& function_config = + config.GetFunctionConfigs()[0]; + const Config::FunctionConfig& expected = param.expected_function_config; + + EXPECT_EQ(function_config.name, expected.name); + EXPECT_EQ(function_config.description, expected.description); + + ASSERT_THAT(function_config.overload_configs, + SizeIs(expected.overload_configs.size())); + + for (size_t i = 0; i < expected.overload_configs.size(); ++i) { + const auto& actual_overload = function_config.overload_configs[i]; + const auto& expected_overload = expected.overload_configs[i]; + + EXPECT_EQ(actual_overload.overload_id, expected_overload.overload_id); + EXPECT_THAT(actual_overload.examples, + ElementsAreArray(expected_overload.examples)); + EXPECT_EQ(actual_overload.is_member_function, + expected_overload.is_member_function); + + ASSERT_THAT(actual_overload.parameters, + SizeIs(expected_overload.parameters.size())); + for (size_t j = 0; j < expected_overload.parameters.size(); ++j) { + ExpectTypeInfoEqual(actual_overload.parameters[j], + expected_overload.parameters[j]); + } + + ExpectTypeInfoEqual(actual_overload.return_type, + expected_overload.return_type); + } +} + +std::vector GetParseFunctionTestCases() { + return { + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "isEmpty" + description: |- + determines whether a list is empty, + or a string has no characters + overloads: + - signature: "google.protobuf.StringValue.isEmpty()" + examples: + - "''.isEmpty() // true" + return: "bool" + - signature: "list<~T>.isEmpty()" + examples: + - "[].isEmpty() // true" + - "[1].isEmpty() // false" + return: "bool" + )yaml", + .expected_function_config = + { + .name = "isEmpty", + .description = "determines whether a list is empty,\nor a " + "string has no characters", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .overload_id = + "google.protobuf.StringValue.isEmpty()", + .examples = {"''.isEmpty() // true"}, + .is_member_function = true, + .parameters = {{.name = "string_wrapper"}}, + .return_type = {.name = "bool"}, + }, + Config::FunctionOverloadConfig{ + .overload_id = "list<~T>.isEmpty()", + .examples = {"[].isEmpty() // true", + "[1].isEmpty() // false"}, + .is_member_function = true, + .parameters = {{.name = "list", + .params = {{.name = "T", + .is_type_param = + true}}}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "isEmpty" + description: |- + determines whether a list is empty, + or a string has no characters + overloads: + - id: "wrapper_string_isEmpty" + examples: + - "''.isEmpty() // true" + target: + type_name: "google.protobuf.StringValue" + return: + type_name: "bool" + - id: "list_isEmpty" + examples: + - "[].isEmpty() // true" + - "[1].isEmpty() // false" + target: + type_name: "list" + params: + - type_name: "T" + is_type_param: true + return: + type_name: "bool" + )yaml", + .expected_function_config = + { + .name = "isEmpty", + .description = "determines whether a list is empty,\nor a " + "string has no characters", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .overload_id = "wrapper_string_isEmpty", + .examples = {"''.isEmpty() // true"}, + .is_member_function = true, + .parameters = + {{.name = "google.protobuf.StringValue"}}, + .return_type = {.name = "bool"}, + }, + Config::FunctionOverloadConfig{ + .overload_id = "list_isEmpty", + .examples = {"[].isEmpty() // true", + "[1].isEmpty() // false"}, + .is_member_function = true, + .parameters = {{.name = "list", + .params = {{.name = "T", + .is_type_param = + true}}}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "contains" + overloads: + - signature: "contains(list<~T>, ~T)" + examples: + - "contains([1, 2, 3], 2) // true" + return: "bool" + )yaml", + .expected_function_config = + { + .name = "contains", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .overload_id = "contains(list<~T>, ~T)", + .examples = {"contains([1, 2, 3], 2) // true"}, + .is_member_function = false, + .parameters = + {{.name = "list", + .params = {{.name = "T", + .is_type_param = true}}}, + {.name = "T", .is_type_param = true}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "contains" + overloads: + - id: "global_contains" + examples: + - "contains([1, 2, 3], 2) // true" + args: + - type_name: "list" + params: + - type_name: "T" + is_type_param: true + - type_name: "T" + is_type_param: true + return: + type_name: "bool" + )yaml", + .expected_function_config = + { + .name = "contains", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .overload_id = "global_contains", + .examples = {"contains([1, 2, 3], 2) // true"}, + .is_member_function = false, + .parameters = + {{.name = "list", + .params = {{.name = "T", + .is_type_param = true}}}, + {.name = "T", .is_type_param = true}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlParseFunctionTest, EnvYamlParseFunctionTest, + ::testing::ValuesIn(GetParseFunctionTestCases())); + +struct ParseTestCase { + std::string yaml; + std::string expected_error; +}; + +class EnvYamlParseTest : public testing::TestWithParam {}; + +TEST_P(EnvYamlParseTest, EnvYamlSyntaxError) { + const ParseTestCase& param = GetParam(); + absl::StatusOr config = EnvConfigFromYaml(param.yaml); + EXPECT_THAT(config, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); +} + +INSTANTIATE_TEST_SUITE_P( + EnvYamlParseTest, EnvYamlParseTest, + ::testing::Values( + ParseTestCase{ + .yaml = R"yaml( invalid yaml )yaml", + .expected_error = "1:2: Invalid CEL environment config YAML\n" + "| invalid yaml \n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + name: + - error: "error" + )yaml", + .expected_error = "3:19: Node 'name' is not a string\n" + "| - error: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + - error: "error" + )yaml", + .expected_error = + "3:19: Node 'container' is neither a string nor a map\n" + "| - error: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + name: [] + )yaml", + .expected_error = "3:25: Node 'name' in container is not a string\n" + "| name: []\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + abbreviations: "abbr" + )yaml", + .expected_error = "3:34: Node 'abbreviations' is not a sequence\n" + "| abbreviations: \"abbr\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + abbreviations: + - [] + )yaml", + .expected_error = "4:21: Abbreviation is not a string\n" + "| - []\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: "not a sequence" + )yaml", + .expected_error = "3:28: Node 'aliases' is not a sequence\n" + "| aliases: \"not a sequence\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - "not a map" + )yaml", + .expected_error = "4:21: Alias entry is not a map\n" + "| - \"not a map\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - qualified_name: "qual" + )yaml", + .expected_error = "4:21: Alias entry missing 'alias' string\n" + "| - qualified_name: \"qual\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - alias: "my_alias" + )yaml", + .expected_error = "4:21: Alias entry missing" + " 'qualified_name' string\n" + "| - alias: \"my_alias\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + -name: "optional" + - name: "other" + )yaml", + .expected_error = "5:21: end of map not found\n" + "| - name: \"other\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: "bar" + )yaml", + .expected_error = "2:27: Node 'extensions' is not a sequence\n" + "| extensions: \"bar\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: + - something: "bar" + )yaml", + .expected_error = "4:19: Extension name is not a string\n" + "| - something: \"bar\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: last + )yaml", + .expected_error = "4:28: Extension 'math' version is not a valid " + "number or 'latest'\n" + "| version: last\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: -15 + )yaml", + .expected_error = "4:28: Extension 'math' version is not a valid " + "number or 'latest'\n" + "| version: -15\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: 1 + - name: "math" + version: 2 + )yaml", + .expected_error = "5:19: Extension 'math' version 1 is already " + "included. Cannot also include version 2\n" + "| - name: \"math\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: "error" + )yaml", + .expected_error = "2:23: Standard library config ('stdlib') " + "is not a map\n" + "| stdlib: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + disable: "error" + )yaml", + .expected_error = "3:26: Node 'disable' is not a boolean\n" + "| disable: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + disable_macros: "error" + )yaml", + .expected_error = "3:33: Node 'disable_macros' is not a boolean\n" + "| disable_macros: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + exclude_macros: "error" + )yaml", + .expected_error = "3:33: Node 'exclude_macros' is not a sequence\n" + "| exclude_macros: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + exclude_macros: + - foo: "error" + )yaml", + .expected_error = "4:19: Entry in 'exclude_macros' " + "is not a string\n" + "| - foo: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: "error" + )yaml", + .expected_error = "3:36: Node 'include_functions' " + "is not a sequence\n" + "| include_functions: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - "error" + )yaml", + .expected_error = "4:19: Entry in 'include_functions' " + "is not a map\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - foo: "error" + )yaml", + .expected_error = "4:19: Function name in not specified in " + "'include_functions'\n" + "| - foo: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: "error" + )yaml", + .expected_error = "5:30: Overloads in 'include_functions' entry " + "is not a sequence\n" + "| overloads: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: + - foo_string + )yaml", + .expected_error = "6:21: Overload in 'include_functions' entry " + "is not a map\n" + "| - foo_string\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: + - id: + - foo_int64 + )yaml", + .expected_error = "7:21: Overload id in 'include_functions' entry " + "is not a string\n" + "| - foo_int64\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: + - type_name: "opaque" + )yaml", + .expected_error = "4:19: Variable name is not a string\n" + "| - type_name: \"opaque\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: + - params: + )yaml", + .expected_error = "5:21: Node 'type_name' is not a string\n" + "| - params:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "opaque" + params: + - type_name: "int" + - type_name: "A" + is_type_param: maybe + )yaml", + .expected_error = "8:38: Node 'is_type_param' is not a boolean\n" + "| is_type_param: maybe\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "opaque" + type: "opaque" + )yaml", + .expected_error = "4:19: Node 'type' and 'type_name'" + " are mutually exclusive\n" + "| type_name: \"opaque\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "uint" + value: -1 + )yaml", + .expected_error = "5:26: Failed to parse uint constant\n" + "| value: -1\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: many + )yaml", + .expected_error = "2:26: Node 'functions' is not a sequence\n" + "| functions: many\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: + - overloads: + )yaml", + .expected_error = "4:19: Function name is not a string\n" + "| - overloads:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: "error" + )yaml", + .expected_error = "4:30: Function 'overloads' item " + "is not a sequence\n" + "| overloads: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: + - "error" + )yaml", + .expected_error = "6:25: Function overload id is not a string\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + target: + - "error" + )yaml", + .expected_error = "7:25: Function overload target is not a map\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + target: + type_name: "Foo" + params: + - type_name: + - is_type_param: true + )yaml", + .expected_error = "10:31: Node 'type_name' is not a string\n" + "| " + "- is_type_param: true\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + args: "a bunch" + )yaml", + .expected_error = "6:29: Function overload args is not a sequence\n" + "| args: \"a bunch\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + return: [1] + )yaml", + .expected_error = "6:31: Function overload return type" + " is neither a string nor a map\n" + "| return: [1]\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + signature: "bar()" + )yaml", + .expected_error = "6:34: Function overload name \"bar\" " + "does not match function name \"foo\"\n" + "| signature: \"bar()\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: [ "foo()" ] + )yaml", + .expected_error = + "5:34: Function overload signature is not a string\n" + "| - signature: [ \"foo()\" ]\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: "foo()" + target: + type_name: "int" + )yaml", + .expected_error = "6:23: Function overload signature and target " + "are mutually exclusive\n" + "| target:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: "foo()" + args: + - type_name: "int" + )yaml", + .expected_error = "6:23: Function overload signature and args are " + "mutually exclusive\n" + "| args:\n" + "| ^", + })); + +std::string Unindent(std::string_view yaml) { + absl::string_view yaml_view = yaml; + std::vector lines = absl::StrSplit(yaml_view, '\n'); + int indent = -1; + std::vector unindented_lines; + for (auto& line : lines) { + std::size_t pos = line.find_first_not_of(" \t"); + if (pos == std::string::npos) { + // Skip blank lines. + continue; + } + if (indent == -1) { + indent = pos; + } + if (pos >= indent) { + unindented_lines.push_back(line.substr(indent)); + } else { + unindented_lines.push_back(line); + } + } + return absl::StrJoin(unindented_lines, "\n"); +} + +struct ExportTestCase { + absl::StatusOr config; + std::string expected_yaml; + std::string expected_alt_yaml; +}; + +class EnvYamlExportTest : public testing::TestWithParam {}; + +TEST_P(EnvYamlExportTest, EnvYamlExport) { + const ExportTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config config, param.config); + std::stringstream ss; + EnvConfigToYaml(config, ss, {.use_type_signatures = true}); + std::string yaml_output = Unindent(ss.str()); + std::string expected_yaml = Unindent(param.expected_yaml); + EXPECT_EQ(yaml_output, expected_yaml); + + if (!param.expected_alt_yaml.empty()) { + std::stringstream alt_ss; + EnvConfigToYaml(config, alt_ss, {.use_type_signatures = false}); + std::string alt_yaml_output = Unindent(alt_ss.str()); + std::string expected_alt_yaml = Unindent(param.expected_alt_yaml); + EXPECT_EQ(alt_yaml_output, expected_alt_yaml); + } +} + +std::vector GetExportTestCases() { + return { + ExportTestCase{ + .config = + []() { + Config config; + config.SetName("test.env"); + config.SetContainerConfig({.name = "test.container"}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: "test.container" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + config.SetName("test.env"); + config.SetContainerConfig({.name = "test.container"}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: "test.container" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + config.SetName("test.env"); + config.SetContainerConfig( + {.name = "test.container", + .abbreviations = {"foo", "bar"}, + .aliases = { + {.alias = "foo", .qualified_name = "test.foo"}, + {.alias = "bar", .qualified_name = "test.bar"}, + }}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: + name: "test.container" + abbreviations: + - "bar" + - "foo" + aliases: + - alias: "bar" + qualified_name: "test.bar" + - alias: "foo" + qualified_name: "test.foo" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("math")); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("optional", 2)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("bindings")); + return config; + }(), + .expected_yaml = R"yaml( + extensions: + - name: "bindings" + - name: "math" + - name: "optional" + version: 2 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .disable = true, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + disable: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .disable_macros = true, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + disable_macros: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .excluded_macros = {"map", "filter"}, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + exclude_macros: + - "filter" + - "map" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .included_macros = {"map", "filter"}, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + include_macros: + - "filter" + - "map" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .excluded_functions = + { + std::make_pair("timestamp", "string_to_timestamp"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("_+_", "add_bytes"), + }, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + exclude_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .included_functions = + { + std::make_pair("timestamp", "string_to_timestamp"), + std::make_pair("_+_", "_+_(list<~A>,list<~A>)"), + std::make_pair("matches", ""), + std::make_pair("_+_", "_+_(bytes,bytes)"), + }, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + include_functions: + - name: "_+_" + overloads: + - id: "_+_(bytes,bytes)" + - id: "_+_(list<~A>,list<~A>)" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "null"}, + .value = Constant(nullptr)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "null" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "bool"}, + .value = Constant(true)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "bool" + value: true + )yaml", + .expected_alt_yaml = R"yaml( + variables: + - name: "foo" + type_name: "bool" + value: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "int"}, + .value = Constant(int64_t{42})})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "int" + value: 42 + )yaml", + .expected_alt_yaml = R"yaml( + variables: + - name: "foo" + type_name: "int" + value: 42 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "uint"}, + .value = Constant(uint64_t{777})})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "uint" + value: 777 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "double"}, + .value = Constant(0.75)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "double" + value: 0.75 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "bytes"}, + .value = Constant( + BytesConstant(absl::string_view("\xff\x00\x01", 3)))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "bytes" + value: b"\xff\x00\x01" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + Constant c; + c.set_string_value("'single' \"double\""); + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "string"}, + .value = Constant(StringConstant("'single' \"double\""))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "string" + value: "'single' \"double\"" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "duration"}, + .value = Constant(absl::Hours(1) + absl::Minutes(2) + + absl::Seconds(3))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "duration" + value: 1h2m3s + )yaml", + .expected_alt_yaml = R"yaml( + variables: + - name: "foo" + type_name: "duration" + value: 1h2m3s + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "timestamp"}, + .value = Constant(absl::FromUnixSeconds(1767323045))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + .expected_alt_yaml = R"yaml( + variables: + - name: "foo" + type_name: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = + "google.expr.proto3.test.TestAllTypes"}})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "google.expr.proto3.test.TestAllTypes" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = { + .name = "A", + .params = {{.name = "int"}, + {.name = "B", .is_type_param = true}}}})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "A" + )yaml", + .expected_alt_yaml = R"yaml( + variables: + - name: "foo" + type_name: "A" + params: + - type_name: "int" + - type_name: "B" + is_type_param: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig({.name = "foo"})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .overload_configs = { + {.overload_id = "foo_overload_id", + .is_member_function = true, + .parameters = {{.name = "timestamp"}, + {.name = "A", + .params = {{.name = "B", + .is_type_param = true}}}}, + .return_type = {.name = "int"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", + .expected_alt_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + target: + type_name: "timestamp" + args: + - type_name: "A" + params: + - type_name: "B" + is_type_param: true + return: + type_name: "int" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .description = "my desc", + .overload_configs = { + {.overload_id = "foo_overload_a", + .parameters = {{.name = "timestamp"}}, + .return_type = {.name = "list", + .params = {{.name = "int"}}}}, + {.overload_id = "foo_overload_b", + .parameters = {{.name = "double"}, + {.name = "A", .params = {{.name = "B"}}}}, + .return_type = {.name = "string"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + description: "my desc" + overloads: + - id: "foo_overload_b" + signature: "foo(double,A)" + return: "string" + - id: "foo_overload_a" + signature: "foo(timestamp)" + return: "list" + )yaml", + .expected_alt_yaml = R"yaml( + functions: + - name: "foo" + description: "my desc" + overloads: + - id: "foo_overload_b" + args: + - type_name: "double" + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "string" + - id: "foo_overload_a" + args: + - type_name: "timestamp" + return: + type_name: "list" + params: + - type_name: "int" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .overload_configs = { + {.overload_id = "timestamp.foo(A<~B>)", + .is_member_function = true, + .parameters = {{.name = "timestamp"}, + {.name = "A", + .params = {{.name = "B", + .is_type_param = true}}}}, + .return_type = {.name = "int"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", + .expected_alt_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "timestamp.foo(A<~B>)" + target: + type_name: "timestamp" + args: + - type_name: "A" + params: + - type_name: "B" + is_type_param: true + return: + type_name: "int" + )yaml", + }, + }; +}; + +INSTANTIATE_TEST_SUITE_P(EnvYamlExportTest, EnvYamlExportTest, + ::testing::ValuesIn(GetExportTestCases())); + +class EnvYamlStructuredRoundTripTest + : public testing::TestWithParam {}; + +TEST_P(EnvYamlStructuredRoundTripTest, EnvYamlRoundTrip) { + const std::string& yaml = Unindent(GetParam()); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(yaml)); + + std::stringstream ss; + EnvConfigToYaml(config, ss); + EXPECT_EQ(ss.str(), yaml); +} + +std::vector GetStructuredRoundTripTestCases() { + return { + R"yaml( + stdlib: + disable: true + disable_macros: true + )yaml", + R"yaml( + name: "test.env" + container: "common.proto.prefix" + extensions: + - name: "math" + version: 0 + - name: "optional" + version: 2 + stdlib: + include_macros: + - "filter" + - "map" + include_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + R"yaml( + container: + name: "test.container" + abbreviations: + - "abbr1.Abbr1" + - "abbr2.Abbr2" + aliases: + - alias: "alias1" + qualified_name: "qual.name1" + - alias: "alias2" + qualified_name: "qual.name2" + )yaml", + R"yaml( + extensions: + - name: "bindings" + - name: "math" + stdlib: + exclude_macros: + - "filter" + - "map" + exclude_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + R"yaml( + functions: + - name: "bar" + - name: "foo" + )yaml", + }; +} + +INSTANTIATE_TEST_SUITE_P( + EnvYamlStructuredRoundTripTest, EnvYamlStructuredRoundTripTest, + ::testing::ValuesIn(GetStructuredRoundTripTestCases())); + +class EnvYamlSignatureRoundTripTest + : public testing::TestWithParam {}; + +TEST_P(EnvYamlSignatureRoundTripTest, EnvYamlRoundTrip) { + const std::string& yaml = Unindent(GetParam()); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(yaml)); + + std::stringstream ss; + EnvConfigToYaml(config, ss, {.use_type_signatures = true}); + EXPECT_EQ(ss.str(), yaml); +} + +std::vector GetSignatureRoundTripTestCases() { + return { + R"yaml( + variables: + - name: "a" + type: "null" + - name: "b" + type: "bool" + value: true + - name: "c" + type: "int" + value: 42 + - name: "d" + type: "uint" + value: 777 + - name: "e" + type: "double" + value: 0.75 + - name: "f" + type: "bytes" + value: b"\xff\x00\x01" + - name: "g" + type: "string" + value: "plain 'single' \"double\"" + - name: "h" + type: "duration" + value: 1h2m3s + - name: "i" + type: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", + R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + signature: "foo(timestamp,A<~B>)" + return: "list" + )yaml", + R"yaml( + functions: + - name: "foo" + overloads: + - signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlSignatureRoundTripTest, + EnvYamlSignatureRoundTripTest, + ::testing::ValuesIn(GetSignatureRoundTripTestCases())); + +} // namespace +} // namespace cel diff --git a/env/internal/BUILD b/env/internal/BUILD new file mode 100644 index 000000000..ec4a0b15c --- /dev/null +++ b/env/internal/BUILD @@ -0,0 +1,87 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "ext_registry", + srcs = ["ext_registry.cc"], + hdrs = ["ext_registry.h"], + deps = [ + "//compiler", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "runtime_ext_registry", + srcs = ["runtime_ext_registry.cc"], + hdrs = ["runtime_ext_registry.h"], + deps = [ + "//runtime:runtime_builder", + "//runtime:runtime_options", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "ext_registry_test", + srcs = ["ext_registry_test.cc"], + deps = [ + ":ext_registry", + "//checker:type_checker_builder", + "//compiler", + "//internal:testing", + "//parser:parser_interface", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) + +cc_test( + name = "runtime_ext_registry_test", + srcs = ["runtime_ext_registry_test.cc"], + deps = [ + ":runtime_ext_registry", + "//common:ast", + "//common:source", + "//common:value", + "//common:value_testing", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//parser:options", + "//parser:parser_interface", + "//runtime", + "//runtime:activation", + "//runtime:function", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_builder", + "//runtime:runtime_builder_factory", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/env/internal/ext_registry.cc b/env/internal/ext_registry.cc new file mode 100644 index 000000000..b32239ac3 --- /dev/null +++ b/env/internal/ext_registry.cc @@ -0,0 +1,63 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/ext_registry.h" + +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" + +namespace cel { +namespace env_internal { + +void ExtensionRegistry::RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) { + library_registry_.push_back( + LibraryRegistration(name, alias, version, std::move(library_factory))); +} + +absl::StatusOr ExtensionRegistry::GetCompilerLibrary( + absl::string_view name, int version) const { + if (version == kLatest) { + int max_version = -1; + for (const auto& registration : library_registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ > max_version) { + max_version = registration.version_; + } + } + if (max_version == -1) { + return absl::NotFoundError( + absl::StrCat("CompilerLibrary not registered: ", name)); + } + version = max_version; + } + for (const auto& registration : library_registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ == version) { + return registration.GetLibrary(); + } + } + + return absl::NotFoundError( + absl::StrCat("CompilerLibrary not registered: ", name, "#", version)); +} +} // namespace env_internal +} // namespace cel diff --git a/env/internal/ext_registry.h b/env/internal/ext_registry.h new file mode 100644 index 000000000..ab5b67a24 --- /dev/null +++ b/env/internal/ext_registry.h @@ -0,0 +1,74 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" + +namespace cel { +namespace env_internal { + +// A registry for CEL compiler extension libraries. +// +// Used to register and retrieve CompilerLibraries by name (or alias) and +// version. +class ExtensionRegistry { + public: + static constexpr int kLatest = std::numeric_limits::max(); + + void RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory); + + absl::StatusOr GetCompilerLibrary(absl::string_view name, + int version) const; + + private: + class LibraryRegistration final { + public: + LibraryRegistration( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) + : name_(name), + alias_(!alias.empty() ? alias : name), + version_(version), + factory_(std::move(library_factory)) {} + + CompilerLibrary GetLibrary() const { return factory_(); } + + private: + std::string name_; + std::string alias_; + int version_; + absl::AnyInvocable factory_; + + friend class ExtensionRegistry; + }; + + std::vector library_registry_; +}; + +} // namespace env_internal +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ diff --git a/env/internal/ext_registry_test.cc b/env/internal/ext_registry_test.cc new file mode 100644 index 000000000..9e345c781 --- /dev/null +++ b/env/internal/ext_registry_test.cc @@ -0,0 +1,73 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/ext_registry.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" +#include "internal/testing.h" +#include "parser/parser_interface.h" + +namespace cel::env_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::testing::Field; +using ::testing::HasSubstr; + +TEST(ExtensionRegistryTest, GetCompilerLibrary) { + ExtensionRegistry registry; + registry.RegisterCompilerLibrary("foo1", "f", 1, []() { + return CompilerLibrary("foo1_1", nullptr, nullptr); + }); + registry.RegisterCompilerLibrary("foo1", "f", 2, []() { + return CompilerLibrary("foo1_2", nullptr, nullptr); + }); + registry.RegisterCompilerLibrary("foo2", "", 1, []() { + return CompilerLibrary("foo2_1", nullptr, nullptr); + }); + + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("f", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 2), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo1", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("f", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo2", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo2_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo2", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo2_1"))); + + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 3), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo1#3"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo3", 1), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo3"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo3", ExtensionRegistry::kLatest), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo3"))); +} + +} // namespace +} // namespace cel::env_internal diff --git a/env/internal/runtime_ext_registry.cc b/env/internal/runtime_ext_registry.cc new file mode 100644 index 000000000..dc78a38e3 --- /dev/null +++ b/env/internal/runtime_ext_registry.cc @@ -0,0 +1,64 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/runtime_ext_registry.h" + +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace env_internal { + +void RuntimeExtensionRegistry::AddFunctionRegistration( + absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback) { + registry_.push_back(Registration(name, alias, version, + std::move(function_registration_callback))); +} + +absl::Status RuntimeExtensionRegistry::RegisterExtensionFunctions( + RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options, + absl::string_view name, int version) const { + if (version == kLatest) { + int max_version = -1; + for (const Registration& registration : registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ > max_version) { + max_version = registration.version_; + } + } + if (max_version == -1) { + return absl::NotFoundError(absl::StrCat( + "Runtime functions are not registered for extension: ", name)); + } + version = max_version; + } + for (const Registration& registration : registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ == version) { + return registration.RegisterExtensionFunctions(runtime_builder, + runtime_options); + } + } + + return absl::NotFoundError(absl::StrCat( + "Runtime functions are not registered for extension: ", name)); +} +} // namespace env_internal +} // namespace cel diff --git a/env/internal/runtime_ext_registry.h b/env/internal/runtime_ext_registry.h new file mode 100644 index 000000000..67838519f --- /dev/null +++ b/env/internal/runtime_ext_registry.h @@ -0,0 +1,84 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace env_internal { + +using FunctionRegistrationCallback = absl::AnyInvocable; + +// A registry for CEL runtime extension functions. +// +// Used to register runtime functions for extensions by name (or alias) and +// version. +class RuntimeExtensionRegistry { + public: + static constexpr int kLatest = std::numeric_limits::max(); + + void AddFunctionRegistration( + absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback); + + absl::Status RegisterExtensionFunctions(RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options, + absl::string_view name, + int version) const; + + private: + class Registration final { + public: + Registration(absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback) + : name_(name), + alias_(!alias.empty() ? alias : name), + version_(version), + function_registration_callback_( + std::move(function_registration_callback)) {} + + absl::Status RegisterExtensionFunctions( + RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) const { + return function_registration_callback_(runtime_builder, runtime_options); + } + + private: + std::string name_; + std::string alias_; + int version_; + FunctionRegistrationCallback function_registration_callback_; + + friend class RuntimeExtensionRegistry; + }; + + std::vector registry_; +}; + +} // namespace env_internal +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ diff --git a/env/internal/runtime_ext_registry_test.cc b/env/internal/runtime_ext_registry_test.cc new file mode 100644 index 000000000..c6125d20f --- /dev/null +++ b/env/internal/runtime_ext_registry_test.cc @@ -0,0 +1,126 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/runtime_ext_registry.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/source.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "parser/parser_interface.h" +#include "runtime/activation.h" +#include "runtime/function.h" +#include "runtime/function_adapter.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_builder_factory.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace cel::env_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::cel::test::StringValueIs; + +Value Hello1(const StringValue& input, const Function::InvokeContext& context) { + return StringValue::From("Hello, old " + input.ToString() + "!", + context.arena()); +} + +Value Hello2(const StringValue& input, const Function::InvokeContext& context) { + return StringValue::From("Hello, new " + input.ToString() + "!", + context.arena()); +} + +RuntimeExtensionRegistry GetRuntimeExtensionRegistry() { + RuntimeExtensionRegistry registry; + registry.AddFunctionRegistration( + "hello_extension", "hello_extension_alias", 1, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::UnaryFunctionAdapter:: + RegisterGlobalOverload("hello", &Hello1, + runtime_builder.function_registry()); + }); + registry.AddFunctionRegistration( + "hello_extension", "hello_extension_alias", 2, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::UnaryFunctionAdapter:: + RegisterMemberOverload("hello", &Hello2, + runtime_builder.function_registry()); + }); + return registry; +} + +class RuntimeExtensionRegistryTest : public testing::Test { + protected: + absl::StatusOr Run(std::string_view extension_name, int version, + std::string_view expr) { + const RuntimeExtensionRegistry registry = GetRuntimeExtensionRegistry(); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr parser, + NewParserBuilder(ParserOptions())->Build()); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr source, NewSource(expr, "")); + CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, parser->Parse(*source)); + + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + cel::RuntimeOptions runtime_options; + CEL_ASSIGN_OR_RETURN( + cel::RuntimeBuilder runtime_builder, + cel::CreateRuntimeBuilder(descriptor_pool, runtime_options)); + + CEL_RETURN_IF_ERROR(registry.RegisterExtensionFunctions( + runtime_builder, runtime_options, extension_name, version)); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(runtime_builder).Build()); + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + Activation activation; + return program->Evaluate(&arena_, activation); + } + + private: + google::protobuf::Arena arena_; +}; + +TEST_F(RuntimeExtensionRegistryTest, SpecificExtensionVersion) { + EXPECT_THAT(Run("hello_extension", 1, "hello('world')"), + IsOkAndHolds(StringValueIs("Hello, old world!"))); +} + +TEST_F(RuntimeExtensionRegistryTest, LatestExtensionVersion) { + EXPECT_THAT(Run("hello_extension_alias", RuntimeExtensionRegistry::kLatest, + "'world'.hello()"), + IsOkAndHolds(StringValueIs("Hello, new world!"))); +} + +} // namespace +} // namespace cel::env_internal diff --git a/env/runtime_std_extensions.cc b/env/runtime_std_extensions.cc new file mode 100644 index 000000000..b866a5965 --- /dev/null +++ b/env/runtime_std_extensions.cc @@ -0,0 +1,133 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/runtime_std_extensions.h" + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "checker/optional.h" +#include "env/env_runtime.h" +#include "env/internal/runtime_ext_registry.h" +#include "extensions/encoders.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext.h" +#include "extensions/math_ext_decls.h" +#include "extensions/regex_ext.h" +#include "extensions/sets_functions.h" +#include "extensions/strings.h" +#include "runtime/optional_types.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { + +void RegisterStandardExtensions(EnvRuntime& env_runtime) { + env_internal::RuntimeExtensionRegistry& registry = + env_runtime.GetRuntimeExtensionRegistry(); + registry.AddFunctionRegistration( + "cel.lib.ext.bindings", "bindings", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. + return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.encoders", "encoders", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterEncodersFunctions( + runtime_builder.function_registry(), runtime_options); + }); + + for (int version = 0; version <= extensions::kListsExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.lists", "lists", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterListsFunctions( + runtime_builder.function_registry(), runtime_options, version); + }); + } + + for (int version = 0; version <= extensions::kMathExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.math", "math", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterMathExtensionFunctions( + runtime_builder.function_registry(), runtime_options, version); + }); + } + + for (int version = 0; version <= cel::kOptionalExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.optional", "optional", version, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::EnableOptionalTypes(runtime_builder); + }); + } + + registry.AddFunctionRegistration( + "cel.lib.ext.protos", "protos", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. + return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.sets", "sets", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterSetsFunctions( + runtime_builder.function_registry(), runtime_options); + }); + + for (int version = 0; version <= extensions::kStringsExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.strings", "strings", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + cel::extensions::StringsExtensionOptions strings_options; + strings_options.version = version; + return cel::extensions::RegisterStringsFunctions( + runtime_builder.function_registry(), runtime_options, + strings_options); + }); + } + + registry.AddFunctionRegistration( + "cel.lib.ext.comprev2", "two-var-comprehensions", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. + return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.regex", "regex", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterRegexExtensionFunctions( + runtime_builder); + }); +} + +} // namespace cel diff --git a/env/runtime_std_extensions.h b/env/runtime_std_extensions.h new file mode 100644 index 000000000..d7f714226 --- /dev/null +++ b/env/runtime_std_extensions.h @@ -0,0 +1,46 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ + +#include "env/env_runtime.h" + +namespace cel { + +// Registers the standard CEL extension functions with the given environment +// runtime. This makes them available, but does not enable them. See Env::Config +// for how to enable extensions. +// +// Included in the standard runtime environment: +// +// - cel.lib.ext.bindings (alias: "bindings") +// - cel.lib.ext.encoders (alias: "encoders") +// - cel.lib.ext.lists (alias: "lists") +// - cel.lib.ext.math (alias: "math") +// - optional +// - cel.lib.ext.protos (alias: "protos") +// - cel.lib.ext.sets (alias: "sets") +// - cel.lib.ext.strings (alias: "strings") +// - cel.lib.ext.comprev2 (alias: "two-var-comprehensions") +// +// NOTE: Not included in the standard runtime environment yet - include manually +// if needed: +// - cel.lib.ext.regex (alias: "regex") +// +void RegisterStandardExtensions(EnvRuntime& env_runtime); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ diff --git a/env/runtime_std_extensions_test.cc b/env/runtime_std_extensions_test.cc new file mode 100644 index 000000000..4c7cb9829 --- /dev/null +++ b/env/runtime_std_extensions_test.cc @@ -0,0 +1,229 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/runtime_std_extensions.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "checker/optional.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_runtime.h" +#include "env/env_std_extensions.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext_decls.h" +#include "extensions/strings.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::IsEmpty; +using ::testing::ValuesIn; + +struct TestCase { + std::string extension_name; + std::vector extension_versions = {0}; + int latest_extension_version = 0; + std::string expr; + bool requires_optional_extension = false; +}; + +using RuntimeStdExtensionTest = testing::TestWithParam; + +TEST_P(RuntimeStdExtensionTest, RegisterStandardExtensions) { + const TestCase& param = GetParam(); + Env env; + env.SetDescriptorPool(cel::internal::GetSharedTestingDescriptorPool()); + RegisterStandardExtensions(env); + + Config compiler_config; + // For the compilation step, assume latest version of the extension to ensure + // a successful compilation. Later, we will test the runtime with different + // extension versions. + ASSERT_THAT(compiler_config.AddExtensionConfig( + param.extension_name, Config::ExtensionConfig::kLatest), + IsOk()); + env.SetConfig(compiler_config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(param.expr)); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + for (int version = 0; version <= param.latest_extension_version; ++version) { + Config runtime_config; + // Request a specific version of the extension to be configured in the + // runtime. + ASSERT_THAT( + runtime_config.AddExtensionConfig(param.extension_name, version), + IsOk()); + if (param.requires_optional_extension) { + ASSERT_THAT(runtime_config.AddExtensionConfig("optional"), IsOk()); + } + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool( + cel::internal::GetSharedTestingDescriptorPool()); + RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(runtime_config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + absl::StatusOr> program_or = + runtime->CreateProgram(std::make_unique(*ast)); + + // If the function is not supported in this extension version, check that + // the program creation returned an error. + if (!absl::c_contains(param.extension_versions, version)) { + EXPECT_THAT(program_or, StatusIs(absl::StatusCode::kInvalidArgument)) + << " expr: " << param.expr << " version: " << version; + continue; + } + + ASSERT_THAT(program_or, IsOk()) + << " expr: " << param.expr << " version: " << version; + std::unique_ptr program = *std::move(program_or); + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()) + << " expr: " << param.expr << " version: " << version; + } +} + +std::vector GetRuntimeStdExtensionTestCases() { + return { + TestCase{ + // The "bindings" extension does not register any runtime functions - + // only macros. + .extension_name = "bindings", + .expr = "cel.bind(t, 42, t + 1) == 43", + }, + TestCase{ + .extension_name = "encoders", + .expr = "base64.encode(b'hello') == 'aGVsbG8='", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {0, 1, 2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[3, 2, 1].slice(0, 1) == [3]", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {1, 2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[[1, 2], 3].flatten() == [1, 2, 3]", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[3, 2, 1].sort() == [1, 2, 3]", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {0, 1, 2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.least([1, -2, 3]) == -2", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {1, 2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.floor(42.9) == 42.0", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.sqrt(4) == 2.0", + }, + TestCase{ + .extension_name = "optional", + .extension_versions = {0, 1, 2}, + .latest_extension_version = kOptionalExtensionLatestVersion, + .expr = "optional.of(1).hasValue()", + }, + TestCase{ + // No runtime functions. + .extension_name = "protos", + .expr = "!proto.hasExt(cel.expr.conformance.proto2.TestAllTypes{}, " + "cel.expr.conformance.proto2.nested_ext)", + }, + TestCase{ + .extension_name = "sets", + .expr = "sets.contains([1], [1])", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {0, 1, 2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "'Hello, who!'.replace('who', 'World') == 'Hello, World!'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {1, 2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "strings.quote('hello') == '\"hello\"'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "['hello', 'world'].join(', ') == 'hello, world'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "'stressed'.reverse() == 'desserts'", + }, + TestCase{ + // No runtime functions. + .extension_name = "cel.lib.ext.comprev2", + .expr = "[1, 2, 3].map(i, i * 2) == [2, 4, 6]", + }, + TestCase{ + .extension_name = "cel.lib.ext.regex", + .expr = "regex.replace('abc', '$', '_end') == 'abc_end'", + .requires_optional_extension = true, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(RuntimeStdExtensionTest, RuntimeStdExtensionTest, + ValuesIn(GetRuntimeStdExtensionTestCases())); + +} // namespace +} // namespace cel diff --git a/env/type_info.cc b/env/type_info.cc new file mode 100644 index 000000000..f49fab9f4 --- /dev/null +++ b/env/type_info.cc @@ -0,0 +1,410 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/type_info.h" + +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +std::optional TypeNameToTypeKind(absl::string_view type_name) { + // Excluded types: + // kUnknown + // kError + // kTypeParam + // kFunction + // kEnum + + static const absl::NoDestructor< + absl::flat_hash_map> + kTypeNameToTypeKind({ + {"null", TypeKind::kNull}, + {"bool", TypeKind::kBool}, + {"int", TypeKind::kInt}, + {"uint", TypeKind::kUint}, + {"double", TypeKind::kDouble}, + {"string", TypeKind::kString}, + {"bytes", TypeKind::kBytes}, + {"timestamp", TypeKind::kTimestamp}, + {TimestampType::kName, TypeKind::kTimestamp}, + {"duration", TypeKind::kDuration}, + {DurationType::kName, TypeKind::kDuration}, + {"list", TypeKind::kList}, + {"map", TypeKind::kMap}, + {"", TypeKind::kDyn}, + {"any", TypeKind::kAny}, + {"dyn", TypeKind::kDyn}, + {BoolWrapperType::kName, TypeKind::kBoolWrapper}, + {"bool_wrapper", TypeKind::kBoolWrapper}, + {IntWrapperType::kName, TypeKind::kIntWrapper}, + {"int_wrapper", TypeKind::kIntWrapper}, + {UintWrapperType::kName, TypeKind::kUintWrapper}, + {"uint_wrapper", TypeKind::kUintWrapper}, + {DoubleWrapperType::kName, TypeKind::kDoubleWrapper}, + {"double_wrapper", TypeKind::kDoubleWrapper}, + {StringWrapperType::kName, TypeKind::kStringWrapper}, + {"string_wrapper", TypeKind::kStringWrapper}, + {BytesWrapperType::kName, TypeKind::kBytesWrapper}, + {"bytes_wrapper", TypeKind::kBytesWrapper}, + {"type", TypeKind::kType}, + }); + if (auto it = kTypeNameToTypeKind->find(type_name); + it != kTypeNameToTypeKind->end()) { + return it->second; + } + + return std::nullopt; +} +} // namespace + +absl::StatusOr TypeInfoToType( + const Config::TypeInfo& type_info, + const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::Arena* arena) { + if (type_info.is_type_param) { + return TypeParamType(type_info.name); + } + + std::optional type_kind = TypeNameToTypeKind(type_info.name); + if (!type_kind.has_value()) { + if (type_info.params.empty() && descriptor_pool != nullptr) { + const google::protobuf::Descriptor* type = + descriptor_pool->FindMessageTypeByName(type_info.name); + if (type != nullptr) { + return Type::Message(type); + } + } + // TODO(uncreated-issue/88): use a TypeIntrospector to validate opaque types + std::vector parameter_types; + for (const Config::TypeInfo& param : type_info.params) { + CEL_ASSIGN_OR_RETURN(Type parameter_type, + TypeInfoToType(param, descriptor_pool, arena)); + parameter_types.push_back(parameter_type); + } + + return OpaqueType(arena, type_info.name, parameter_types); + } + + switch (*type_kind) { + case TypeKind::kNull: + return NullType(); + case TypeKind::kBool: + return BoolType(); + case TypeKind::kInt: + return IntType(); + case TypeKind::kUint: + return UintType(); + case TypeKind::kDouble: + return DoubleType(); + case TypeKind::kString: + return StringType(); + case TypeKind::kBytes: + return BytesType(); + case TypeKind::kDuration: + return DurationType(); + case TypeKind::kTimestamp: + return TimestampType(); + case TypeKind::kList: { + Type element_type; + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN( + element_type, + TypeInfoToType(type_info.params[0], descriptor_pool, arena)); + } else { + element_type = DynType(); + } + return ListType(arena, element_type); + } + case TypeKind::kMap: { + Type key_type = DynType(); + Type value_type = DynType(); + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN(key_type, TypeInfoToType(type_info.params[0], + descriptor_pool, arena)); + } + if (type_info.params.size() > 1) { + CEL_ASSIGN_OR_RETURN( + value_type, + TypeInfoToType(type_info.params[1], descriptor_pool, arena)); + } + return MapType(arena, key_type, value_type); + } + case TypeKind::kDyn: + return DynType(); + case TypeKind::kAny: + return AnyType(); + case TypeKind::kBoolWrapper: + return BoolWrapperType(); + case TypeKind::kIntWrapper: + return IntWrapperType(); + case TypeKind::kUintWrapper: + return UintWrapperType(); + case TypeKind::kDoubleWrapper: + return DoubleWrapperType(); + case TypeKind::kStringWrapper: + return StringWrapperType(); + case TypeKind::kBytesWrapper: + return BytesWrapperType(); + case TypeKind::kType: { + if (type_info.params.empty()) { + return TypeType(arena, DynType()); + } + CEL_ASSIGN_OR_RETURN(Type type, TypeInfoToType(type_info.params[0], + descriptor_pool, arena)); + return TypeType(arena, type); + } + default: + return DynType(); + } +} +absl::StatusOr TypeInfoToTypeSpec(const Config::TypeInfo& type_info) { + if (type_info.is_type_param) { + return TypeSpec(ParamTypeSpec(type_info.name)); + } + + std::optional type_kind = TypeNameToTypeKind(type_info.name); + if (!type_kind.has_value()) { + if (type_info.params.empty()) { + return TypeSpec(MessageTypeSpec(type_info.name)); + } else { + std::vector param_specs; + param_specs.reserve(type_info.params.size()); + for (const Config::TypeInfo& param : type_info.params) { + CEL_ASSIGN_OR_RETURN(TypeSpec param_spec, TypeInfoToTypeSpec(param)); + param_specs.push_back(std::move(param_spec)); + } + return TypeSpec(AbstractType(type_info.name, std::move(param_specs))); + } + } + + switch (*type_kind) { + case TypeKind::kNull: + return TypeSpec(NullTypeSpec()); + case TypeKind::kBool: + return TypeSpec(PrimitiveType::kBool); + case TypeKind::kInt: + return TypeSpec(PrimitiveType::kInt64); + case TypeKind::kUint: + return TypeSpec(PrimitiveType::kUint64); + case TypeKind::kDouble: + return TypeSpec(PrimitiveType::kDouble); + case TypeKind::kString: + return TypeSpec(PrimitiveType::kString); + case TypeKind::kBytes: + return TypeSpec(PrimitiveType::kBytes); + case TypeKind::kTimestamp: + return TypeSpec(WellKnownTypeSpec::kTimestamp); + case TypeKind::kDuration: + return TypeSpec(WellKnownTypeSpec::kDuration); + case TypeKind::kList: { + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN(TypeSpec elem_type, + TypeInfoToTypeSpec(type_info.params[0])); + return TypeSpec( + ListTypeSpec(std::make_unique(std::move(elem_type)))); + } else { + return TypeSpec(ListTypeSpec()); + } + } + case TypeKind::kMap: { + if (type_info.params.empty()) { + return TypeSpec(MapTypeSpec()); + } + CEL_ASSIGN_OR_RETURN(TypeSpec key_type, + TypeInfoToTypeSpec(type_info.params[0])); + if (type_info.params.size() > 1) { + CEL_ASSIGN_OR_RETURN(TypeSpec value_type, + TypeInfoToTypeSpec(type_info.params[1])); + return TypeSpec( + MapTypeSpec(std::make_unique(std::move(key_type)), + std::make_unique(std::move(value_type)))); + } + return TypeSpec(MapTypeSpec( + std::make_unique(std::move(key_type)), nullptr)); + } + case TypeKind::kDyn: + return TypeSpec(DynTypeSpec()); + case TypeKind::kAny: + return TypeSpec(WellKnownTypeSpec::kAny); + case TypeKind::kBoolWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + case TypeKind::kIntWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + case TypeKind::kUintWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + case TypeKind::kDoubleWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + case TypeKind::kStringWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + case TypeKind::kBytesWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + case TypeKind::kType: { + if (type_info.params.empty()) { + return TypeSpec(std::make_unique(DynTypeSpec())); + } + CEL_ASSIGN_OR_RETURN(TypeSpec type_param, + TypeInfoToTypeSpec(type_info.params[0])); + return TypeSpec(std::make_unique(std::move(type_param))); + } + default: + return TypeSpec(DynTypeSpec()); + } +} + +absl::StatusOr TypeSpecToTypeInfo(const TypeSpec& type_spec) { + Config::TypeInfo type_info; + + if (type_spec.has_dyn()) { + type_info.name = "dyn"; + } else if (type_spec.has_null()) { + type_info.name = "null"; + } else if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + type_info.name = "bool"; + break; + case PrimitiveType::kInt64: + type_info.name = "int"; + break; + case PrimitiveType::kUint64: + type_info.name = "uint"; + break; + case PrimitiveType::kDouble: + type_info.name = "double"; + break; + case PrimitiveType::kString: + type_info.name = "string"; + break; + case PrimitiveType::kBytes: + type_info.name = "bytes"; + break; + default: + return absl::InvalidArgumentError("Unspecified primitive type"); + } + } else if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + type_info.name = "bool_wrapper"; + break; + case PrimitiveType::kInt64: + type_info.name = "int_wrapper"; + break; + case PrimitiveType::kUint64: + type_info.name = "uint_wrapper"; + break; + case PrimitiveType::kDouble: + type_info.name = "double_wrapper"; + break; + case PrimitiveType::kString: + type_info.name = "string_wrapper"; + break; + case PrimitiveType::kBytes: + type_info.name = "bytes_wrapper"; + break; + default: + return absl::InvalidArgumentError("Unspecified wrapper type"); + } + } else if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + type_info.name = "any"; + break; + case WellKnownTypeSpec::kTimestamp: + type_info.name = "timestamp"; + break; + case WellKnownTypeSpec::kDuration: + type_info.name = "duration"; + break; + default: + return absl::InvalidArgumentError("Unspecified well known type"); + } + } else if (type_spec.has_list_type()) { + type_info.name = "list"; + const ListTypeSpec& list_type = type_spec.list_type(); + if (list_type.has_elem_type() && list_type.elem_type().is_specified()) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(list_type.elem_type())); + type_info.params.push_back(std::move(param)); + } + } else if (type_spec.has_map_type()) { + type_info.name = "map"; + const MapTypeSpec& map_type = type_spec.map_type(); + bool has_key = + map_type.has_key_type() && map_type.key_type().is_specified(); + bool has_value = + map_type.has_value_type() && map_type.value_type().is_specified(); + if (has_key || has_value) { + if (has_key) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(map_type.key_type())); + type_info.params.push_back(std::move(param)); + } else { + type_info.params.push_back(Config::TypeInfo{.name = "dyn"}); + } + if (has_value) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param_value, + TypeSpecToTypeInfo(map_type.value_type())); + type_info.params.push_back(std::move(param_value)); + } else { + type_info.params.push_back(Config::TypeInfo{.name = "dyn"}); + } + } + } else if (type_spec.has_message_type()) { + type_info.name = type_spec.message_type().type(); + } else if (type_spec.has_type_param()) { + type_info.name = type_spec.type_param().type(); + type_info.is_type_param = true; + } else if (type_spec.has_type()) { + type_info.name = "type"; + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(type_spec.type())); + type_info.params.push_back(std::move(param)); + } else if (type_spec.has_abstract_type()) { + type_info.name = type_spec.abstract_type().name(); + for (const TypeSpec& param_spec : + type_spec.abstract_type().parameter_types()) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(param_spec)); + type_info.params.push_back(std::move(param)); + } + } else if (type_spec.has_error()) { + return absl::InvalidArgumentError( + "ErrorType cannot be converted to TypeInfo"); + } else if (type_spec.has_function()) { + return absl::InvalidArgumentError( + "FunctionType cannot be converted to TypeInfo"); + } else { + return absl::InvalidArgumentError("Unknown TypeSpec kind"); + } + + return type_info; +} + +} // namespace cel diff --git a/env/type_info.h b/env/type_info.h new file mode 100644 index 000000000..3f802ce1a --- /dev/null +++ b/env/type_info.h @@ -0,0 +1,42 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ +#define THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/type.h" +#include "env/config.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Converts a Config::TypeInfo to a cel::Type. Returns an error if the type_info +// cannot be converted to a known cel::Type, a list configured with more than +// one parameter. +absl::StatusOr TypeInfoToType( + const Config::TypeInfo& type_info, + const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::Arena* arena); + +// Converts a Config::TypeInfo to a cel::TypeSpec. +absl::StatusOr TypeInfoToTypeSpec(const Config::TypeInfo& type_info); + +// Converts a cel::TypeSpec to a Config::TypeInfo. +absl::StatusOr TypeSpecToTypeInfo(const TypeSpec& type_spec); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ diff --git a/env/type_info_test.cc b/env/type_info_test.cc new file mode 100644 index 000000000..f9d46f9a9 --- /dev/null +++ b/env/type_info_test.cc @@ -0,0 +1,300 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/type_info.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "common/ast/metadata.h" +#include "common/type.h" +#include "common/type_proto.h" +#include "env/config.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/text_format.h" + +namespace cel { + +std::ostream& operator<<(std::ostream& os, const Config::TypeInfo& type_info) { + if (type_info.is_type_param) { + os << "?"; + } + os << type_info.name; + if (!type_info.params.empty()) { + os << "<"; + for (size_t i = 0; i < type_info.params.size(); ++i) { + if (i > 0) os << ", "; + os << type_info.params[i]; + } + os << ">"; + } + return os; +} + +namespace { + +using absl_testing::IsOk; +using absl_testing::StatusIs; +using testing::ValuesIn; + +struct TestCase { + Config::TypeInfo type_info; + std::string expected_type_pb; +}; + +using TypeInfoTest = testing::TestWithParam; + +TEST_P(TypeInfoTest, TypeInfo) { + const TestCase& param = GetParam(); + cel::expr::Type expected_type_pb; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(param.expected_type_pb, + &expected_type_pb)); + + google::protobuf::Arena arena; + const google::protobuf::DescriptorPool* descriptor_pool = + cel::internal::GetTestingDescriptorPool(); + ASSERT_OK_AND_ASSIGN( + cel::Type actual_type, + cel::TypeInfoToType(param.type_info, descriptor_pool, &arena)); + + cel::expr::Type actual_type_pb; + ASSERT_THAT(cel::TypeToProto(actual_type, &actual_type_pb), IsOk()); + EXPECT_THAT(actual_type_pb, + cel::internal::test::EqualsProto(expected_type_pb)); +} + +std::vector GetTestCases() { + return { + TestCase{ + .type_info = {.name = "int"}, + .expected_type_pb = "primitive: INT64", + }, + TestCase{ + .type_info = {.name = "list", + .params = {Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "list_type { elem_type { primitive: INT64 } }", + }, + TestCase{ + .type_info = {.name = "list"}, + .expected_type_pb = "list_type { elem_type { dyn {} }}", + }, + TestCase{ + .type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "map_type { key_type { primitive: STRING } " + "value_type { primitive: INT64 }}", + }, + TestCase{ + .type_info = {.name = "cel.expr.conformance.proto2.TestAllTypes"}, + .expected_type_pb = + "message_type: 'cel.expr.conformance.proto2.TestAllTypes'", + }, + TestCase{ + .type_info = {.name = "A", + .params = {Config::TypeInfo{.name = "B", + .is_type_param = true}}}, + .expected_type_pb = + "abstract_type { name: 'A' parameter_types { type_param: 'B' } }", + }, + TestCase{ + .type_info = {.name = "any"}, + .expected_type_pb = "well_known: ANY", + }, + TestCase{ + .type_info = {.name = "timestamp"}, + .expected_type_pb = "well_known: TIMESTAMP", + }, + TestCase{ + .type_info = {.name = "google.protobuf.DoubleValue"}, + .expected_type_pb = "wrapper: DOUBLE", + }, + TestCase{ + .type_info = {.name = "double_wrapper"}, + .expected_type_pb = "wrapper: DOUBLE", + }, + TestCase{ + .type_info = {.name = "type", + .params = {Config::TypeInfo{.name = "duration"}}}, + .expected_type_pb = "type: { well_known: DURATION }", + }, + TestCase{ + .type_info = {.name = "parameterized", + .params = {{.name = "A", .is_type_param = true}, + {.name = "double"}}}, + .expected_type_pb = "abstract_type { name: 'parameterized' " + "parameter_types { type_param: 'A' } " + "parameter_types { primitive: DOUBLE } }", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(TypeInfoTest, TypeInfoTest, ValuesIn(GetTestCases())); + +bool TypeInfoEqImpl(const Config::TypeInfo& actual, + const Config::TypeInfo& expected) { + if (actual.name != expected.name) return false; + if (actual.is_type_param != expected.is_type_param) return false; + if (actual.params.size() != expected.params.size()) return false; + for (size_t i = 0; i < actual.params.size(); ++i) { + if (!TypeInfoEqImpl(actual.params[i], expected.params[i])) return false; + } + return true; +} + +MATCHER_P(TypeInfoEq, expected, "") { return TypeInfoEqImpl(arg, expected); } + +struct TypeSpecTestCase { + TypeSpec type_spec; + Config::TypeInfo expected_type_info; +}; + +using TypeSpecToTypeInfoTest = testing::TestWithParam; + +TEST_P(TypeSpecToTypeInfoTest, Convert) { + const TypeSpecTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config::TypeInfo actual_type_info, + TypeSpecToTypeInfo(param.type_spec)); + EXPECT_THAT(actual_type_info, TypeInfoEq(param.expected_type_info)); +} + +std::vector GetTypeSpecTestCases() { + return { + TypeSpecTestCase{ + .type_spec = TypeSpec(PrimitiveType::kInt64), + .expected_type_info = {.name = "int"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + ListTypeSpec(std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "list", + .params = {Config::TypeInfo{.name = "int"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(ListTypeSpec()), + .expected_type_info = {.name = "list"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "int"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(MapTypeSpec()), + .expected_type_info = {.name = "map"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MessageTypeSpec("cel.expr.conformance.proto2.TestAllTypes")), + .expected_type_info = + {.name = "cel.expr.conformance.proto2.TestAllTypes"}, + }, + TypeSpecTestCase{ + .type_spec = + TypeSpec(AbstractType("A", {TypeSpec(ParamTypeSpec("B"))})), + .expected_type_info = {.name = "A", + .params = {Config::TypeInfo{ + .name = "B", .is_type_param = true}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(WellKnownTypeSpec::kAny), + .expected_type_info = {.name = "any"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(WellKnownTypeSpec::kTimestamp), + .expected_type_info = {.name = "timestamp"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + .expected_type_info = {.name = "double_wrapper"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + std::make_unique(WellKnownTypeSpec::kDuration)), + .expected_type_info = {.name = "type", + .params = {Config::TypeInfo{.name = + "duration"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(std::make_unique(DynTypeSpec())), + .expected_type_info = {.name = "type", + .params = {Config::TypeInfo{.name = "dyn"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(DynTypeSpec{}), + .expected_type_info = {.name = "dyn"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(NullTypeSpec{}), + .expected_type_info = {.name = "null"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec()))), + .expected_type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "dyn"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MapTypeSpec(std::make_unique(DynTypeSpec()), + std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "dyn"}, + Config::TypeInfo{.name = "int"}}}, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(TypeSpecToTypeInfoTest, TypeSpecToTypeInfoTest, + ValuesIn(GetTypeSpecTestCases())); + +using TypeInfoToTypeSpecTest = testing::TestWithParam; + +TEST_P(TypeInfoToTypeSpecTest, Convert) { + const TypeSpecTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(TypeSpec actual_type_spec, + TypeInfoToTypeSpec(param.expected_type_info)); + EXPECT_EQ(actual_type_spec, param.type_spec); +} + +INSTANTIATE_TEST_SUITE_P(TypeInfoToTypeSpecTest, TypeInfoToTypeSpecTest, + ValuesIn(GetTypeSpecTestCases())); + +TEST(TypeSpecToTypeInfoTest, ErrorConversions) { + EXPECT_THAT(TypeSpecToTypeInfo(TypeSpec(ErrorTypeSpec::kValue)), + StatusIs(absl::StatusCode::kInvalidArgument, + "ErrorType cannot be converted to TypeInfo")); + EXPECT_THAT(TypeSpecToTypeInfo(TypeSpec(FunctionTypeSpec())), + StatusIs(absl::StatusCode::kInvalidArgument, + "FunctionType cannot be converted to TypeInfo")); + EXPECT_THAT( + TypeSpecToTypeInfo(TypeSpec(UnsetTypeSpec())), + StatusIs(absl::StatusCode::kInvalidArgument, "Unknown TypeSpec kind")); +} + +} // namespace +} // namespace cel diff --git a/eval/README.md b/eval/README.md index ee6fd0798..32fa4bda4 100644 --- a/eval/README.md +++ b/eval/README.md @@ -3,4 +3,4 @@ A C++ implementation of a [Common Expression Language][1] evaluator. -[1]: https://github.com/google/cel-spec +[1]: https://github.com/cel-expr/cel-spec diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index e82b0ce13..f7300cb58 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -95,6 +95,7 @@ cc_library( "flat_expr_builder.h", ], deps = [ + ":check_ast_extensions", ":flat_expr_builder_extensions", ":resolver", "//base:ast", @@ -192,6 +193,7 @@ cc_test( "//internal:status_macros", "//internal:testing", "//parser", + "//parser:options", "//runtime:function", "//runtime:function_adapter", "//runtime:runtime_options", @@ -412,6 +414,33 @@ cc_library( ], ) +cc_library( + name = "check_ast_extensions", + srcs = ["check_ast_extensions.cc"], + hdrs = ["check_ast_extensions.h"], + deps = [ + "//common:ast", + "//common/ast:metadata", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "check_ast_extensions_test", + srcs = ["check_ast_extensions_test.cc"], + deps = [ + ":check_ast_extensions", + "//common:ast", + "//common:expr", + "//common/ast:metadata", + "//internal:testing", + "@com_google_absl//absl/status", + ], +) + cc_library( name = "resolver", srcs = ["resolver.cc"], diff --git a/eval/compiler/check_ast_extensions.cc b/eval/compiler/check_ast_extensions.cc new file mode 100644 index 000000000..37181b535 --- /dev/null +++ b/eval/compiler/check_ast_extensions.cc @@ -0,0 +1,58 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/check_ast_extensions.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/ast/metadata.h" + +namespace google::api::expr::runtime { + +absl::StatusOr> +ExtractAndValidateRuntimeExtensions(const cel::Ast& ast) { + std::vector runtime_extensions; + absl::flat_hash_set seen_extension_ids; + + for (const cel::ExtensionSpec& extension : ast.source_info().extensions()) { + bool is_runtime = false; + for (const cel::ExtensionSpec::Component& component : + extension.affected_components()) { + if (component == cel::ExtensionSpec::Component::kRuntime) { + is_runtime = true; + break; + } + } + + if (!is_runtime) { + continue; + } + + if (!seen_extension_ids.insert(extension.id()).second) { + return absl::InvalidArgumentError( + absl::StrCat("duplicate extension ID: ", extension.id())); + } + runtime_extensions.push_back(extension); + } + + return runtime_extensions; +} + +} // namespace google::api::expr::runtime diff --git a/eval/compiler/check_ast_extensions.h b/eval/compiler/check_ast_extensions.h new file mode 100644 index 000000000..443c6ac09 --- /dev/null +++ b/eval/compiler/check_ast_extensions.h @@ -0,0 +1,34 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ + +#include + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/ast/metadata.h" + +namespace google::api::expr::runtime { + +// Extracts and validates extension tags from the AST `ast` that affect the +// runtime component. Returns the validated list of runtime extensions, or an +// error if there are multiple runtime extensions with the same ID. +absl::StatusOr> +ExtractAndValidateRuntimeExtensions(const cel::Ast& ast); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ diff --git a/eval/compiler/check_ast_extensions_test.cc b/eval/compiler/check_ast_extensions_test.cc new file mode 100644 index 000000000..9e5838905 --- /dev/null +++ b/eval/compiler/check_ast_extensions_test.cc @@ -0,0 +1,110 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/check_ast_extensions.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "common/ast.h" +#include "common/ast/metadata.h" +#include "common/expr.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::Ast; +using ::cel::Expr; +using ::cel::ExtensionSpec; +using ::cel::SourceInfo; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Property; +using ::testing::SizeIs; + +TEST(ExtractAndValidateRuntimeExtensionsTest, EmptyExtensions) { + Ast ast(Expr{}, SourceInfo{}); + EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast), + IsOkAndHolds(SizeIs(0))); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, FiltersNonRuntimeExtensions) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kParser})); + source_info.mutable_extensions().push_back( + ExtensionSpec("ext2", nullptr, {ExtensionSpec::Component::kTypeChecker})); + + Ast ast(Expr(), std::move(source_info)); + + EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast), + IsOkAndHolds(SizeIs(0))); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, ExtractsRuntimeExtensions) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back(ExtensionSpec( + "ext2", nullptr, + {ExtensionSpec::Component::kParser, ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back( + ExtensionSpec("ext3", nullptr, {ExtensionSpec::Component::kParser})); + + Ast ast(Expr(), std::move(source_info)); + + auto result = ExtractAndValidateRuntimeExtensions(ast); + ASSERT_THAT(result, IsOk()); + EXPECT_THAT(*result, ElementsAre(Property(&ExtensionSpec::id, Eq("ext1")), + Property(&ExtensionSpec::id, Eq("ext2")))); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, FailsOnDuplicateRuntimeID) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back(ExtensionSpec( + "ext1", nullptr, + {ExtensionSpec::Component::kParser, ExtensionSpec::Component::kRuntime})); + + Ast ast(Expr(), std::move(source_info)); + + EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast), + StatusIs(absl::StatusCode::kInvalidArgument, + "duplicate extension ID: ext1")); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, IgnoresDuplicateNonRuntimeID) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kParser})); + + Ast ast(Expr(), std::move(source_info)); + + auto result = ExtractAndValidateRuntimeExtensions(ast); + ASSERT_THAT(result, IsOk()); + EXPECT_THAT(*result, ElementsAre(Property(&ExtensionSpec::id, Eq("ext1")))); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index a0fd427bd..aa9a8858c 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -59,6 +60,7 @@ #include "common/kind.h" #include "common/type.h" #include "common/value.h" +#include "eval/compiler/check_ast_extensions.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" #include "eval/eval/comprehension_step.h" @@ -109,6 +111,13 @@ constexpr absl::string_view kBlock = "cel.@block"; // Forward declare to resolve circular dependency for short_circuiting visitors. class FlatExprVisitor; +// Error code for failed recursive program building. Generally indicates an +// optimization doesn't support recursive programs. +absl::Status FailedRecursivePlanning() { + return absl::InternalError( + "failed to build recursive program. check for unsupported optimizations"); +} + // Helper for bookkeeping variables mapped to indexes. class IndexManager { public: @@ -223,7 +232,7 @@ class BinaryCondVisitor : public CondVisitor { private: FlatExprVisitor* visitor_; const BinaryCond cond_; - Jump jump_step_; + std::vector jump_steps_; bool short_circuiting_; }; @@ -577,6 +586,12 @@ class FlatExprVisitor : public cel::AstVisitor { } } + void SetMaxRecursionDepth(int max_recursion_depth) { + max_recursion_depth_ = max_recursion_depth; + } + + bool PlanRecursiveProgram() const { return max_recursion_depth_ > 0; } + void PreVisitExpr(const cel::Expr& expr) override { ValidateOrError(!absl::holds_alternative(expr.kind()), "Invalid empty expression"); @@ -607,7 +622,7 @@ class FlatExprVisitor : public cel::AstVisitor { program_optimizers_) { absl::Status status = optimizer->OnPreVisit(extension_context_, expr); if (!status.ok()) { - SetProgressStatusError(status); + SetProgressStatusIfError(status); } } } @@ -624,7 +639,7 @@ class FlatExprVisitor : public cel::AstVisitor { program_optimizers_) { absl::Status status = optimizer->OnPostVisit(extension_context_, expr); if (!status.ok()) { - SetProgressStatusError(status); + SetProgressStatusIfError(status); return; } } @@ -642,7 +657,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (!comprehension_stack_.empty() && comprehension_stack_.back().is_optimizable_bind && (&comprehension_stack_.back().comprehension->accu_init() == &expr)) { - SetProgressStatusError( + SetProgressStatusIfError( MaybeExtractSubexpression(&expr, comprehension_stack_.back())); } @@ -651,7 +666,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (block.current_binding == &expr) { int index = program_builder_.ExtractSubexpression(&expr); if (index == -1) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("failed to extract subexpression")); return; } @@ -671,7 +686,7 @@ class FlatExprVisitor : public cel::AstVisitor { ConvertConstant(const_expr, cel::NewDeleteAllocator()); if (!converted_value.ok()) { - SetProgressStatusError(converted_value.status()); + SetProgressStatusIfError(converted_value.status()); return; } @@ -707,13 +722,13 @@ class FlatExprVisitor : public cel::AstVisitor { if (absl::ConsumePrefix(&index_suffix, "@index")) { size_t index; if (!absl::SimpleAtoi(index_suffix, &index)) { - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError("bad @index")))); return {-1, -1}; } if (index >= block.size) { - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError(absl::StrCat( "invalid @index greater than number of bindings: ", @@ -721,7 +736,7 @@ class FlatExprVisitor : public cel::AstVisitor { return {-1, -1}; } if (index >= block.current_index) { - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError(absl::StrCat( "@index references current or future binding: ", index, @@ -739,7 +754,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (record.iter_var_in_scope && record.comprehension->iter_var() == path) { if (record.is_optimizable_bind) { - SetProgressStatusError(issue_collector_.AddIssue( + SetProgressStatusIfError(issue_collector_.AddIssue( RuntimeIssue::CreateWarning(absl::InvalidArgumentError( "Unexpected iter_var access in trivial comprehension")))); return {-1, -1}; @@ -766,7 +781,7 @@ class FlatExprVisitor : public cel::AstVisitor { // If we see a CSE generated comprehension variable that was not // resolvable through the normal comprehension scope resolution, reject it // now rather than surfacing errors at activation time. - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError("out of scope reference to CSE " "generated comprehension variable")))); @@ -796,7 +811,7 @@ class FlatExprVisitor : public cel::AstVisitor { auto* subexpression = program_builder_.GetExtractedSubexpression(slot.subexpression); if (subexpression == nullptr) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InternalError("bad subexpression reference")); return; } @@ -825,7 +840,7 @@ class FlatExprVisitor : public cel::AstVisitor { // Attempt to resolve a select expression as a namespaced identifier for an // enum or type constant value. - absl::optional const_value; + std::optional const_value; int64_t select_root_id = -1; std::string path_candidate; @@ -947,11 +962,10 @@ class FlatExprVisitor : public cel::AstVisitor { return; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 1) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "unexpected number of dependencies for select operation.")); return; } @@ -1008,7 +1022,7 @@ class FlatExprVisitor : public cel::AstVisitor { // cel.@block if (block_.has_value()) { // There can only be one for now. - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("multiple cel.@block are not allowed")); return; } @@ -1016,32 +1030,28 @@ class FlatExprVisitor : public cel::AstVisitor { BlockInfo& block = *block_; block.in = true; if (call_expr.args().empty()) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "malformed cel.@block: missing list of bound expressions")); return; } if (call_expr.args().size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "malformed cel.@block: missing bound expression")); return; } if (!call_expr.args()[0].has_list_expr()) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("malformed cel.@block: first argument " "is not a list of bound expressions")); return; } const auto& list_expr = call_expr.args().front().list_expr(); block.size = list_expr.elements().size(); - if (block.size == 0) { - SetProgressStatusError(absl::InvalidArgumentError( - "malformed cel.@block: list of bound expressions is empty")); - return; - } + block.bindings_set.reserve(block.size); for (const auto& list_expr_element : list_expr.elements()) { if (list_expr_element.optional()) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("malformed cel.@block: list of bound " "expressions contains an optional")); return; @@ -1064,21 +1074,13 @@ class FlatExprVisitor : public cel::AstVisitor { } } - absl::optional RecursionEligible() { - if (program_builder_.current() == nullptr) { - return absl::nullopt; - } - absl::optional depth = - program_builder_.current()->RecursiveDependencyDepth(); - if (!depth.has_value()) { - // one or more of the dependencies isn't eligible. - return depth; - } - if (options_.max_recursion_depth < 0 || - *depth < options_.max_recursion_depth) { - return depth; + // Returns the maximum recursion depth of the current program if it is + // eligible for recursion, or nullopt if it is not. + std::optional RecursionEligible() { + if (!PlanRecursiveProgram() || program_builder_.current() == nullptr) { + return std::nullopt; } - return absl::nullopt; + return program_builder_.current()->RecursiveDependencyDepth(); } std::vector> @@ -1089,12 +1091,9 @@ class FlatExprVisitor : public cel::AstVisitor { return program_builder_.current()->ExtractRecursiveDependencies(); } - void MaybeMakeTernaryRecursive(const cel::Expr* expr) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeTernaryRecursive(const cel::Expr* expr) { if (expr->call_expr().args().size() != 3) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin ternary")); return; } @@ -1107,26 +1106,16 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if (condition_plan == nullptr || !condition_plan->IsRecursive()) { + if (condition_plan == nullptr || !condition_plan->IsRecursive() || + left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); return; } - max_depth = std::max(max_depth, condition_plan->recursive_program().depth); - if (left_plan == nullptr || !left_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - - if (right_plan == nullptr || !right_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { - return; - } + int max_depth = std::max({0, condition_plan->recursive_program().depth, + left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); SetRecursiveStep( CreateDirectTernaryStep(condition_plan->ExtractRecursiveProgram().step, @@ -1136,60 +1125,53 @@ class FlatExprVisitor : public cel::AstVisitor { max_depth + 1); } - void MaybeMakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { - if (options_.max_recursion_depth == 0) { - return; - } - if (expr->call_expr().args().size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + void MakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { + int args_size = expr->call_expr().args().size(); + if (args_size < 2) { + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin boolean operator &&/||")); return; } - const cel::Expr* left_expr = &expr->call_expr().args()[0]; - const cel::Expr* right_expr = &expr->call_expr().args()[1]; - - auto* left_plan = program_builder_.GetSubexpression(left_expr); - auto* right_plan = program_builder_.GetSubexpression(right_expr); - - int max_depth = 0; - if (left_plan == nullptr || !left_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - - if (right_plan == nullptr || !right_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { + auto* current_plan = + program_builder_.GetSubexpression(&expr->call_expr().args()[0]); + if (current_plan == nullptr || !current_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); return; } + int current_depth = current_plan->recursive_program().depth; + std::unique_ptr current_step = + current_plan->ExtractRecursiveProgram().step; - if (is_or) { - SetRecursiveStep( - CreateDirectOrStep(left_plan->ExtractRecursiveProgram().step, - right_plan->ExtractRecursiveProgram().step, - expr->id(), options_.short_circuiting), - max_depth + 1); - } else { - SetRecursiveStep( - CreateDirectAndStep(left_plan->ExtractRecursiveProgram().step, - right_plan->ExtractRecursiveProgram().step, - expr->id(), options_.short_circuiting), - max_depth + 1); + for (int i = 1; i < args_size; ++i) { + auto* next_plan = + program_builder_.GetSubexpression(&expr->call_expr().args()[i]); + if (next_plan == nullptr || !next_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); + return; + } + current_depth = + std::max(current_depth, next_plan->recursive_program().depth); + std::unique_ptr next_step = + next_plan->ExtractRecursiveProgram().step; + if (is_or) { + current_step = + CreateDirectOrStep(std::move(current_step), std::move(next_step), + expr->id(), options_.short_circuiting); + } else { + current_step = + CreateDirectAndStep(std::move(current_step), std::move(next_step), + expr->id(), options_.short_circuiting); + } + current_depth++; } + SetRecursiveStep(std::move(current_step), current_depth); } - void MaybeMakeOptionalShortcircuitRecursive(const cel::Expr* expr, - bool is_or_value) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeOptionalShortcircuit(const cel::Expr* expr, bool is_or_value) { if (!expr->call_expr().has_target() || expr->call_expr().args().size() != 1) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for optional.or{Value}")); return; } @@ -1199,21 +1181,13 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if (left_plan == nullptr || !left_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - - if (right_plan == nullptr || !right_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { + if (left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); return; } + int max_depth = std::max({0, left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); SetRecursiveStep(CreateDirectOptionalOrStep( expr->id(), left_plan->ExtractRecursiveProgram().step, @@ -1225,7 +1199,7 @@ class FlatExprVisitor : public cel::AstVisitor { void MaybeMakeBindRecursive(const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, size_t accu_slot) { - if (options_.max_recursion_depth == 0) { + if (!PlanRecursiveProgram()) { return; } @@ -1233,16 +1207,12 @@ class FlatExprVisitor : public cel::AstVisitor { program_builder_.GetSubexpression(&comprehension->result()); if (result_plan == nullptr || !result_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); return; } int result_depth = result_plan->recursive_program().depth; - if (options_.max_recursion_depth > 0 && - result_depth >= options_.max_recursion_depth) { - return; - } - auto program = result_plan->ExtractRecursiveProgram(); SetRecursiveStep( CreateDirectBindStep(accu_slot, std::move(program.step), expr->id()), @@ -1252,42 +1222,26 @@ class FlatExprVisitor : public cel::AstVisitor { void MaybeMakeComprehensionRecursive( const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, size_t iter_slot, size_t iter2_slot, size_t accu_slot) { - if (options_.max_recursion_depth == 0) { + if (!PlanRecursiveProgram()) { return; } auto* accu_plan = program_builder_.GetSubexpression(&comprehension->accu_init()); - - if (accu_plan == nullptr || !accu_plan->IsRecursive()) { - return; - } - auto* range_plan = program_builder_.GetSubexpression(&comprehension->iter_range()); - - if (range_plan == nullptr || !range_plan->IsRecursive()) { - return; - } - auto* loop_plan = program_builder_.GetSubexpression(&comprehension->loop_step()); - - if (loop_plan == nullptr || !loop_plan->IsRecursive()) { - return; - } - auto* condition_plan = program_builder_.GetSubexpression(&comprehension->loop_condition()); - - if (condition_plan == nullptr || !condition_plan->IsRecursive()) { - return; - } - auto* result_plan = program_builder_.GetSubexpression(&comprehension->result()); - - if (result_plan == nullptr || !result_plan->IsRecursive()) { + if (accu_plan == nullptr || !accu_plan->IsRecursive() || + range_plan == nullptr || !range_plan->IsRecursive() || + loop_plan == nullptr || !loop_plan->IsRecursive() || + condition_plan == nullptr || !condition_plan->IsRecursive() || + result_plan == nullptr || !result_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); return; } @@ -1298,11 +1252,6 @@ class FlatExprVisitor : public cel::AstVisitor { max_depth = std::max(max_depth, condition_plan->recursive_program().depth); max_depth = std::max(max_depth, result_plan->recursive_program().depth); - if (options_.max_recursion_depth > 0 && - max_depth >= options_.max_recursion_depth) { - return; - } - auto step = CreateDirectComprehensionStep( iter_slot, iter2_slot, accu_slot, range_plan->ExtractRecursiveProgram().step, @@ -1520,7 +1469,7 @@ class FlatExprVisitor : public cel::AstVisitor { return; } - SetProgressStatusError(comprehension_stack_.back().visitor->PostVisitArg( + SetProgressStatusIfError(comprehension_stack_.back().visitor->PostVisitArg( comprehension_arg, comprehension_stack_.back().expr)); } @@ -1566,7 +1515,7 @@ class FlatExprVisitor : public cel::AstVisitor { comprehension_stack_.back(); if (comprehension.is_optimizable_list_append) { if (&(comprehension.comprehension->accu_init()) == &expr) { - if (options_.max_recursion_depth != 0) { + if (PlanRecursiveProgram()) { SetRecursiveStep(CreateDirectMutableListStep(expr.id()), 1); return; } @@ -1579,11 +1528,10 @@ class FlatExprVisitor : public cel::AstVisitor { } } } - absl::optional depth = RecursionEligible(); - if (depth.has_value()) { + if (std::optional depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != list_expr.elements().size()) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "Unexpected number of plan elements for CreateList expr")); return; } @@ -1606,7 +1554,7 @@ class FlatExprVisitor : public cel::AstVisitor { auto status_or_resolved_fields = ResolveCreateStructFields(struct_expr, expr.id()); if (!status_or_resolved_fields.ok()) { - SetProgressStatusError(status_or_resolved_fields.status()); + SetProgressStatusIfError(status_or_resolved_fields.status()); return; } std::string resolved_name = @@ -1614,11 +1562,10 @@ class FlatExprVisitor : public cel::AstVisitor { std::vector fields = std::move(status_or_resolved_fields.value().second); - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != struct_expr.fields().size()) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "Unexpected number of plan elements for CreateStruct expr")); return; } @@ -1646,7 +1593,7 @@ class FlatExprVisitor : public cel::AstVisitor { comprehension_stack_.back(); if (comprehension.is_optimizable_map_insert) { if (&(comprehension.comprehension->accu_init()) == &expr) { - if (options_.max_recursion_depth != 0) { + if (PlanRecursiveProgram()) { SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1); return; } @@ -1656,11 +1603,10 @@ class FlatExprVisitor : public cel::AstVisitor { } } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 2 * map_expr.entries().size()) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "Unexpected number of plan elements for CreateStruct expr")); return; } @@ -1696,8 +1642,7 @@ class FlatExprVisitor : public cel::AstVisitor { auto lazy_overloads = resolver_.FindLazyOverloads( function, call_expr->has_target(), num_args, expr->id()); if (!lazy_overloads.empty()) { - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = program_builder_.current()->ExtractRecursiveDependencies(); SetRecursiveStep(CreateDirectLazyFunctionStep( expr->id(), *call_expr, std::move(args), @@ -1723,12 +1668,13 @@ class FlatExprVisitor : public cel::AstVisitor { "No overloads provided for FunctionStep creation"), RuntimeIssue::ErrorCode::kNoMatchingOverload)); if (!status.ok()) { - SetProgressStatusError(status); + SetProgressStatusIfError(status); return; } } - auto recursion_depth = RecursionEligible(); - if (recursion_depth.has_value()) { + + if (auto recursion_depth = RecursionEligible(); + recursion_depth.has_value()) { // Nonnull while active -- nullptr indicates logic error elsewhere in the // builder. ABSL_DCHECK(program_builder_.current() != nullptr); @@ -1753,7 +1699,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (step.ok()) { return AddStep(*std::move(step)); } else { - SetProgressStatusError(step.status()); + SetProgressStatusIfError(step.status()); } return nullptr; } @@ -1772,14 +1718,19 @@ class FlatExprVisitor : public cel::AstVisitor { return; } if (program_builder_.current() == nullptr) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "CEL AST traversal out of order in flat_expr_builder.")); return; } program_builder_.current()->set_recursive_program(std::move(step), depth); + if (depth > max_recursion_depth_) { + SetProgressStatusIfError(absl::InvalidArgumentError( + absl::StrCat("Maximum recursion depth of ", + options_.max_recursion_depth, " exceeded"))); + } } - void SetProgressStatusError(const absl::Status& status) { + void SetProgressStatusIfError(const absl::Status& status) { if (progress_status_.ok() && !status.ok()) { progress_status_ = status; } @@ -1821,7 +1772,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (valid_expression) { return true; } - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( absl::StrCat(error_message, message_parts...))); return false; } @@ -1907,7 +1858,7 @@ class FlatExprVisitor : public cel::AstVisitor { int64_t expr_id) { absl::string_view ast_name = create_struct_expr.name(); - absl::optional> type; + std::optional> type; CEL_ASSIGN_OR_RETURN(type, resolver_.FindType(ast_name, expr_id)); if (!type.has_value()) { @@ -1980,17 +1931,17 @@ class FlatExprVisitor : public cel::AstVisitor { IssueCollector& issue_collector_; ProgramBuilder& program_builder_; - PlannerContext extension_context_; + PlannerContext& extension_context_; IndexManager index_manager_; bool enable_optional_types_; - absl::optional block_; + std::optional block_; + int max_recursion_depth_ = 0; }; FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( const cel::Expr& expr, const cel::CallExpr& call_expr) { ABSL_DCHECK(call_expr.function() == cel::builtin::kIndex); - auto depth = RecursionEligible(); if (!ValidateOrError( (call_expr.args().size() == 2 && !call_expr.has_target()) || // TODO(uncreated-issue/79): A few clients use the index operator with a @@ -2000,10 +1951,10 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( return CallHandlerResult::kIntercepted; } - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin index operator")); return CallHandlerResult::kIntercepted; } @@ -2027,12 +1978,10 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin not operator")); return CallHandlerResult::kIntercepted; } @@ -2046,18 +1995,16 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNotStrictlyFalse( const cel::Expr& expr, const cel::CallExpr& call_expr) { - auto depth = RecursionEligible(); - if (!ValidateOrError(call_expr.args().size() == 1 && !call_expr.has_target(), "unexpected number of args for builtin " "not_strictly_false operator")) { return CallHandlerResult::kIntercepted; } - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("unexpected number of args for builtin " "@not_strictly_false operator")); return CallHandlerResult::kIntercepted; @@ -2076,7 +2023,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock( ABSL_DCHECK(call_expr.function() == kBlock); if (!block_.has_value() || block_->expr != &expr || call_expr.args().size() != 2 || call_expr.has_target()) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("unexpected call to internal cel.@block")); return CallHandlerResult::kIntercepted; } @@ -2108,7 +2055,9 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock( } // Otherwise, iterative plan. - AddStep(CreateClearSlotsStep(block.index, block.slot_count, expr.id())); + if (block.slot_count > 0) { + AddStep(CreateClearSlotsStep(block.index, block.slot_count, expr.id())); + } return CallHandlerResult::kIntercepted; } @@ -2155,12 +2104,11 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleHeterogeneousEquality( "unexpected number of args for builtin equality operator")) { return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin equality operator")); return CallHandlerResult::kIntercepted; } @@ -2182,11 +2130,10 @@ FlatExprVisitor::HandleHeterogeneousEqualityIn(const cel::Expr& expr, return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin 'in' operator")); return CallHandlerResult::kIntercepted; } @@ -2207,7 +2154,7 @@ void BinaryCondVisitor::PreVisit(const cel::Expr* expr) { case BinaryCond::kOr: visitor_->ValidateOrError( !expr->call_expr().has_target() && - expr->call_expr().args().size() == 2, + expr->call_expr().args().size() >= 2, "Invalid argument count for a binary function call."); break; case BinaryCond::kOptionalOr: @@ -2221,33 +2168,52 @@ void BinaryCondVisitor::PreVisit(const cel::Expr* expr) { } void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { - if (short_circuiting_ && arg_num == 0 && - (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) { - // If first branch evaluation result is enough to determine output, - // jump over the second branch and provide result of the first argument as - // final output. - // Retain a pointer to the jump step so we can update the target after - // planning the second argument. - std::unique_ptr jump_step; - switch (cond_) { - case BinaryCond::kAnd: - jump_step = CreateCondJumpStep(false, true, {}, expr->id()); - break; - case BinaryCond::kOr: - jump_step = CreateCondJumpStep(true, true, {}, expr->id()); - break; - default: - ABSL_UNREACHABLE(); + if (visitor_->PlanRecursiveProgram()) { + return; + } + const int last_arg_index = expr->call_expr().args().size() - 1; + if (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr) { + if (arg_num > 0) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->AddStep(CreateAndStep(expr->id())); + break; + case BinaryCond::kOr: + visitor_->AddStep(CreateOrStep(expr->id())); + break; + default: + break; + } + if (short_circuiting_ && !jump_steps_.empty()) { + visitor_->SetProgressStatusIfError( + jump_steps_.back().set_target(visitor_->GetCurrentIndex())); + } } - ProgramStepIndex index = visitor_->GetCurrentIndex(); - if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); - jump_step_ptr) { - jump_step_ = Jump(index, jump_step_ptr); + if (short_circuiting_ && arg_num < last_arg_index) { + std::unique_ptr jump_step; + switch (cond_) { + case BinaryCond::kAnd: + jump_step = CreateCondJumpStep(false, true, {}, expr->id()); + break; + case BinaryCond::kOr: + jump_step = CreateCondJumpStep(true, true, {}, expr->id()); + break; + default: + ABSL_UNREACHABLE(); + } + ProgramStepIndex index = visitor_->GetCurrentIndex(); + if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); + jump_step_ptr) { + jump_steps_.push_back(Jump(index, jump_step_ptr)); + } } } } void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } if (short_circuiting_ && (cond_ == BinaryCond::kOptionalOr || cond_ == BinaryCond::kOptionalOrValue)) { // If first branch evaluation result is enough to determine output, @@ -2269,54 +2235,54 @@ void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { ProgramStepIndex index = visitor_->GetCurrentIndex(); if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); jump_step_ptr) { - jump_step_ = Jump(index, jump_step_ptr); + jump_steps_.push_back(Jump(index, jump_step_ptr)); } } } void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { - switch (cond_) { - case BinaryCond::kAnd: - visitor_->AddStep(CreateAndStep(expr->id())); - break; - case BinaryCond::kOr: - visitor_->AddStep(CreateOrStep(expr->id())); - break; - case BinaryCond::kOptionalOr: - visitor_->AddStep( - CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); - break; - case BinaryCond::kOptionalOrValue: - visitor_->AddStep(CreateOptionalOrStep(/*is_or_value=*/true, expr->id())); - break; - default: - ABSL_UNREACHABLE(); - } - if (short_circuiting_) { - // If short-circuiting is enabled, point the conditional jump past the - // boolean operator step. - visitor_->SetProgressStatusError( - jump_step_.set_target(visitor_->GetCurrentIndex())); + if (visitor_->PlanRecursiveProgram()) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/false); + break; + case BinaryCond::kOr: + visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/true); + break; + case BinaryCond::kOptionalOr: + visitor_->MakeOptionalShortcircuit(expr, + /*is_or_value=*/false); + break; + case BinaryCond::kOptionalOrValue: + visitor_->MakeOptionalShortcircuit(expr, + /*is_or_value=*/true); + break; + default: + ABSL_UNREACHABLE(); + } + return; } - // Handle maybe replacing the subprogram with a recursive version. This needs - // to happen after the jump step is updated (though it may get overwritten). - switch (cond_) { - case BinaryCond::kAnd: - visitor_->MaybeMakeShortcircuitRecursive(expr, /*is_or=*/false); - break; - case BinaryCond::kOr: - visitor_->MaybeMakeShortcircuitRecursive(expr, /*is_or=*/true); - break; - case BinaryCond::kOptionalOr: - visitor_->MaybeMakeOptionalShortcircuitRecursive(expr, - /*is_or_value=*/false); - break; - case BinaryCond::kOptionalOrValue: - visitor_->MaybeMakeOptionalShortcircuitRecursive(expr, - /*is_or_value=*/true); - break; - default: - ABSL_UNREACHABLE(); + + if (cond_ == BinaryCond::kOptionalOr || + cond_ == BinaryCond::kOptionalOrValue) { + switch (cond_) { + case BinaryCond::kOptionalOr: + visitor_->AddStep( + CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); + break; + case BinaryCond::kOptionalOrValue: + visitor_->AddStep( + CreateOptionalOrStep(/*is_or_value=*/true, expr->id())); + break; + default: + ABSL_UNREACHABLE(); + } + if (short_circuiting_) { + for (auto& jump : jump_steps_) { + visitor_->SetProgressStatusIfError( + jump.set_target(visitor_->GetCurrentIndex())); + } + } } } @@ -2327,6 +2293,9 @@ void TernaryCondVisitor::PreVisit(const cel::Expr* expr) { } void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } // Ternary operator "_?_:_" requires a special handing. // In contrary to regular function call, its execution affects the control // flow of the overall CEL expression. @@ -2370,7 +2339,7 @@ void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { if (visitor_->ValidateOrError( jump_to_second_.exists(), "Error configuring ternary operator: jump_to_second_ is null")) { - visitor_->SetProgressStatusError( + visitor_->SetProgressStatusIfError( jump_to_second_.set_target(visitor_->GetCurrentIndex())); } } @@ -2380,20 +2349,23 @@ void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { } void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + visitor_->MakeTernaryRecursive(expr); + return; + } // Determine and set jump offset in jump instruction. if (visitor_->ValidateOrError( error_jump_.exists(), "Error configuring ternary operator: error_jump_ is null")) { - visitor_->SetProgressStatusError( + visitor_->SetProgressStatusIfError( error_jump_.set_target(visitor_->GetCurrentIndex())); } if (visitor_->ValidateOrError( jump_after_first_.exists(), "Error configuring ternary operator: jump_after_first_ is null")) { - visitor_->SetProgressStatusError( + visitor_->SetProgressStatusIfError( jump_after_first_.set_target(visitor_->GetCurrentIndex())); } - visitor_->MaybeMakeTernaryRecursive(expr); } void ExhaustiveTernaryCondVisitor::PreVisit(const cel::Expr* expr) { @@ -2403,8 +2375,11 @@ void ExhaustiveTernaryCondVisitor::PreVisit(const cel::Expr* expr) { } void ExhaustiveTernaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + visitor_->MakeTernaryRecursive(expr); + return; + } visitor_->AddStep(CreateTernaryStep(expr->id())); - visitor_->MaybeMakeTernaryRecursive(expr); } void ComprehensionVisitor::PreVisit(const cel::Expr* expr) { @@ -2417,6 +2392,9 @@ void ComprehensionVisitor::PreVisit(const cel::Expr* expr) { absl::Status ComprehensionVisitor::PostVisitArgDefault( cel::ComprehensionArg arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return absl::OkStatus(); + } switch (arg_num) { case cel::ITER_RANGE: { init_step_pos_ = visitor_->GetCurrentIndex(); @@ -2443,7 +2421,8 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault( break; } Jump jump_helper(index, jump_to_next); - visitor_->SetProgressStatusError(jump_helper.set_target(next_step_pos_)); + visitor_->SetProgressStatusIfError( + jump_helper.set_target(next_step_pos_)); // Set offsets jumping to the result step. if (cond_step_) { @@ -2491,6 +2470,9 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault( void ComprehensionVisitor::PostVisitArgTrivial(cel::ComprehensionArg arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } switch (arg_num) { case cel::ITER_RANGE: { break; @@ -2548,6 +2530,22 @@ std::vector FlattenExpressionTable( return subexpression_indexes; } +absl::Status CheckAstExtensions( + const std::vector& extensions) { + for (const cel::ExtensionSpec& extension : extensions) { + if (extension.id() == "cel_block" && extension.version().major() == 1) { + // cel_block v1 is always supported. + continue; + } + + // TODO(uncreated-issue/89): Add support for json field names. + return absl::InvalidArgumentError(absl::StrCat( + "unsupported CEL extension: ", extension.id(), "@", + extension.version().major(), ".", extension.version().minor())); + } + return absl::OkStatus(); +} + } // namespace absl::StatusOr FlatExprBuilder::CreateExpressionImpl( @@ -2561,6 +2559,21 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( ? RuntimeIssue::Severity::kWarning : RuntimeIssue::Severity::kError; IssueCollector issue_collector(max_severity); + + absl::StatusOr> runtime_extensions = + ExtractAndValidateRuntimeExtensions(*ast); + + if (!runtime_extensions.ok()) { + CEL_RETURN_IF_ERROR(issue_collector.AddIssue( + RuntimeIssue::CreateError(runtime_extensions.status()))); + } + + auto status = CheckAstExtensions(*runtime_extensions); + if (!status.ok()) { + CEL_RETURN_IF_ERROR( + issue_collector.AddIssue(RuntimeIssue::CreateError(status))); + } + Resolver resolver(container_, function_registry_, type_registry_, GetTypeProvider(), options_.enable_qualified_type_identifiers); @@ -2590,6 +2603,13 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( issue_collector, program_builder, extension_context, enable_optional_types_); + if (options_.max_recursion_depth == -1 || options_.max_recursion_depth > 0) { + int depth_limit = options_.max_recursion_depth == -1 + ? std::numeric_limits::max() + : options_.max_recursion_depth; + visitor.SetMaxRecursionDepth(depth_limit); + } + cel::TraversalOptions opts; opts.use_comprehension_callbacks = true; AstTraverse(ast->root_expr(), visitor, opts); diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index eab1e7ff8..aa4d0b4e5 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -23,12 +23,10 @@ #include #include "absl/base/nullability.h" -#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "base/ast.h" #include "base/type_provider.h" -#include "common/value.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/eval/evaluator_core.h" #include "runtime/function_registry.h" @@ -53,18 +51,6 @@ class FlatExprBuilder { type_registry_(env_->type_registry), use_legacy_type_provider_(use_legacy_type_provider) {} - FlatExprBuilder( - absl_nonnull std::shared_ptr env, - const cel::FunctionRegistry& function_registry, - const cel::TypeRegistry& type_registry, - const cel::RuntimeOptions& options, bool use_legacy_type_provider = false) - : env_(std::move(env)), - options_(options), - container_(options.container), - function_registry_(function_registry), - type_registry_(type_registry), - use_legacy_type_provider_(use_legacy_type_provider) {} - void AddAstTransform(std::unique_ptr transform) { ast_transforms_.push_back(std::move(transform)); } diff --git a/eval/compiler/flat_expr_builder_extensions.cc b/eval/compiler/flat_expr_builder_extensions.cc index 463b48425..ee106ff4a 100644 --- a/eval/compiler/flat_expr_builder_extensions.cc +++ b/eval/compiler/flat_expr_builder_extensions.cc @@ -98,19 +98,19 @@ size_t Subexpression::ComputeSize() const { return size; } -absl::optional Subexpression::RecursiveDependencyDepth() const { +std::optional Subexpression::RecursiveDependencyDepth() const { auto* tree = absl::get_if(&program_); int depth = 0; if (tree == nullptr) { - return absl::nullopt; + return std::nullopt; } for (const auto& element : *tree) { auto* subexpression = absl::get_if(&element); if (subexpression == nullptr) { - return absl::nullopt; + return std::nullopt; } if (!(*subexpression)->IsRecursive()) { - return absl::nullopt; + return std::nullopt; } depth = std::max(depth, (*subexpression)->recursive_program().depth); } diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 2b705398a..105060282 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -1,18 +1,16 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "eval/compiler/flat_expr_builder.h" @@ -66,6 +64,7 @@ #include "internal/proto_matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "parser/options.h" #include "parser/parser.h" #include "runtime/function.h" #include "runtime/function_adapter.h" @@ -187,6 +186,20 @@ TEST(FlatExprBuilderTest, ExprUnset) { HasSubstr("Invalid empty expression"))); } +TEST(FlatExprBuilderTest, RuntimeExtensionsError) { + Expr expr; + SourceInfo source_info; + auto* ext = source_info.add_extensions(); + ext->set_id("ext1"); + ext->add_affected_components( + cel::expr::SourceInfo_Extension_Component_COMPONENT_RUNTIME); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("unsupported CEL extension: ext1"))); +} + TEST(FlatExprBuilderTest, ConstValueUnset) { Expr expr; SourceInfo source_info; @@ -457,7 +470,7 @@ TEST(FlatExprBuilderTest, IdentExprUnsetName) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(ident_expr {})", &expr); + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(ident_expr {})", &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -470,10 +483,10 @@ TEST(FlatExprBuilderTest, SelectExprUnsetField) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(select_expr{ + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(select_expr{ operand{ ident_expr {name: 'var'} } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -486,11 +499,11 @@ TEST(FlatExprBuilderTest, SelectExprUnsetOperand) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(select_expr{ + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(select_expr{ field: 'field' operand { id: 1 } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -503,7 +516,8 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuVar) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(comprehension_expr{})", &expr); + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"(comprehension_expr{})", &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -515,10 +529,10 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetIterVar) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{accu_var: "a"} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -530,12 +544,12 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuInit) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: "a" iter_var: "b"} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -547,7 +561,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopCondition) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' @@ -555,7 +569,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopCondition) { const_expr {bool_value: true} }} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -567,7 +581,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopStep) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' @@ -578,7 +592,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopStep) { const_expr {bool_value: true} }} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -590,7 +604,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetResult) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' @@ -604,7 +618,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetResult) { const_expr {bool_value: false} }} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -616,7 +630,7 @@ TEST(FlatExprBuilderTest, MapComprehension) { Expr expr; SourceInfo source_info; // {1: "", 2: ""}.all(x, x > 0) - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr { iter_var: "k" accu_var: "accu" @@ -653,7 +667,7 @@ TEST(FlatExprBuilderTest, MapComprehension) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -671,7 +685,7 @@ TEST(FlatExprBuilderTest, InvalidContainer) { Expr expr; SourceInfo source_info; // foo && bar - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( call_expr { function: "_&&_" args { @@ -685,7 +699,7 @@ TEST(FlatExprBuilderTest, InvalidContainer) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -897,7 +911,7 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportDisabled) { TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { CheckedExpr expr; // foo && bar - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( expr { id: 1 call_expr { @@ -916,7 +930,7 @@ TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -934,7 +948,7 @@ TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { CheckedExpr expr; // `foo.var1` && `bar.var2` - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 2 value { @@ -976,7 +990,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -996,7 +1010,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { CheckedExpr expr; // ext.and(var1, bar.var2) - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 1 value { @@ -1045,7 +1059,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1070,7 +1084,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { CheckedExpr expr; // && . - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 2 value { @@ -1113,7 +1127,7 @@ TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1148,7 +1162,7 @@ TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { CheckedExpr expr; // {`var1`: 'hello'} - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 3 value { @@ -1178,7 +1192,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1201,7 +1215,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForError) { Expr expr; SourceInfo source_info; // {}[0].all(x, x) should evaluate OK but return an error value - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 4 comprehension_expr { iter_var: "x" @@ -1266,7 +1280,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForError) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -1283,7 +1297,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { Expr expr; SourceInfo source_info; // 0.all(x, x) should evaluate OK but return an error value. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 4 comprehension_expr { iter_var: "x" @@ -1337,7 +1351,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -1709,7 +1723,7 @@ TEST(FlatExprBuilderTest, NameCollisionWithComprehensionVarLeadingDot) { TEST(FlatExprBuilderTest, MapFieldPresence) { Expr expr; SourceInfo source_info; - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 1, select_expr{ operand { @@ -1719,7 +1733,7 @@ TEST(FlatExprBuilderTest, MapFieldPresence) { field: "string_int32_map" test_only: true })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, @@ -1753,7 +1767,7 @@ TEST(FlatExprBuilderTest, MapFieldPresence) { TEST(FlatExprBuilderTest, RepeatedFieldPresence) { Expr expr; SourceInfo source_info; - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 1, select_expr{ operand { @@ -1763,7 +1777,7 @@ TEST(FlatExprBuilderTest, RepeatedFieldPresence) { field: "int32_list" test_only: true })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, @@ -2808,6 +2822,7 @@ TEST(FlatExprBuilderTest, BlockNotListOfBoundExpressions) { TEST(FlatExprBuilderTest, BlockEmptyListOfBoundExpressions) { ParsedExpr parsed_expr; + // Allowed, but degenerate case. ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr: { @@ -2823,10 +2838,8 @@ TEST(FlatExprBuilderTest, BlockEmptyListOfBoundExpressions) { CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), - StatusIs( - absl::StatusCode::kInvalidArgument, - HasSubstr( - "malformed cel.@block: list of bound expressions is empty"))); + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid @index greater than number of bindings:"))); } TEST(FlatExprBuilderTest, BlockOptionalListOfBoundExpressions) { @@ -2889,6 +2902,252 @@ TEST(FlatExprBuilderTest, BlockNested) { HasSubstr("multiple cel.@block are not allowed"))); } +struct VariadicLogicalEvalTestCase { + std::string label; + std::string expr; + std::string a_val; + std::string b_val; + std::string c_val; + std::string expected_type; // "bool", "error", "unknown" + bool expected_bool = false; +}; + +class FlatExprBuilderVariadicLogicalTest + : public testing::TestWithParam {}; + +TEST_P(FlatExprBuilderVariadicLogicalTest, Evaluate) { + const auto& test_case = GetParam(); + parser::ParserOptions parser_options; + parser_options.enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN( + ParsedExpr parsed_expr, + parser::Parse(test_case.expr, test_case.label, parser_options)); + + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + std::vector unknown_patterns; + + // Set up variables: + auto insert_value = [&](absl::string_view name, const std::string& val) { + if (val == "true") { + activation.InsertValue(name, CelValue::CreateBool(true)); + } else if (val == "false") { + activation.InsertValue(name, CelValue::CreateBool(false)); + } else if (val == "error") { + activation.InsertValue(name, CreateErrorValue(&arena, "test error")); + } else if (val == "unknown1" || val == "unknown2") { + activation.InsertValue(name, CelValue::CreateBool(true)); + unknown_patterns.push_back(CreateCelAttributePattern(name, {})); + } + }; + + insert_value("a", test_case.a_val); + insert_value("b", test_case.b_val); + insert_value("c", test_case.c_val); + + if (!unknown_patterns.empty()) { + activation.set_unknown_attribute_patterns(std::move(unknown_patterns)); + } + + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + + if (test_case.expected_type == "bool") { + ASSERT_TRUE(result.IsBool()) << result.DebugString(); + EXPECT_EQ(result.BoolOrDie(), test_case.expected_bool); + } else if (test_case.expected_type == "error") { + EXPECT_TRUE(result.IsError()) << result.DebugString(); + } else if (test_case.expected_type == "unknown") { + EXPECT_TRUE(result.IsUnknownSet()) << result.DebugString(); + } +} + +INSTANTIATE_TEST_SUITE_P( + FlatExprBuilderVariadicLogicalTest, FlatExprBuilderVariadicLogicalTest, + testing::Values( + VariadicLogicalEvalTestCase{"AND_AllTrue", "a && b && c", "true", + "true", "true", "bool", true}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitFalse", "a && b && c", + "true", "false", "unset", "bool", false}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitFirstFalse", "a && b && c", + "false", "unset", "unset", "bool", false}, + VariadicLogicalEvalTestCase{"OR_AllFalse", "a || b || c", "false", + "false", "false", "bool", false}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitTrue", "a || b || c", + "false", "true", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitFirstTrue", "a || b || c", + "true", "unset", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"AND_Error", "a && b && c", "true", "error", + "true", "error"}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitBeforeError", + "a && b && c", "false", "error", "unset", + "bool", false}, + VariadicLogicalEvalTestCase{"OR_Error", "a || b || c", "false", "error", + "false", "error"}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitBeforeError", "a || b || c", + "true", "error", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"AND_Unknown", "a && b && c", "true", + "unknown1", "true", "unknown"}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitBeforeUnknown", + "a && b && c", "false", "unknown1", "unset", + "bool", false}, + VariadicLogicalEvalTestCase{"OR_Unknown", "a || b || c", "false", + "unknown1", "false", "unknown"}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitBeforeUnknown", + "a || b || c", "true", "unknown1", "unset", + "bool", true}, + VariadicLogicalEvalTestCase{"AND_UnknownAggregation", "a && b && c", + "unknown1", "unknown2", "true", "unknown"}, + VariadicLogicalEvalTestCase{"OR_UnknownAggregation", "a || b || c", + "unknown1", "unknown2", "false", "unknown"}, + VariadicLogicalEvalTestCase{"Exists_True", "[a, b, c].exists(x, x)", + "false", "false", "true", "bool", true}, + VariadicLogicalEvalTestCase{"Exists_Unknown", "[a, b, c].exists(x, x)", + "false", "unknown1", "false", "unknown"}, + VariadicLogicalEvalTestCase{"All_False", "[a, b, c].all(x, x)", "true", + "true", "false", "bool", false}, + VariadicLogicalEvalTestCase{"All_Unknown", "[a, b, c].all(x, x)", + "true", "unknown1", "true", "unknown"})); + +struct RecursionDepthTestCase { + std::string label; + std::string expr; + int max_recursion_depth; + absl::StatusCode expected_status_code; + std::string expected_error_msg; +}; + +class FlatExprBuilderRecursionDepthTest + : public testing::TestWithParam {}; + +TEST_P(FlatExprBuilderRecursionDepthTest, CheckRecursionLimit) { + const auto& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(test_case.expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = test_case.max_recursion_depth; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + auto result = + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()); + if (test_case.expected_status_code == absl::StatusCode::kOk) { + EXPECT_THAT(result, IsOk()); + } else { + EXPECT_THAT(result, StatusIs(test_case.expected_status_code, + HasSubstr(test_case.expected_error_msg))); + } +} + +INSTANTIATE_TEST_SUITE_P( + FlatExprBuilderRecursionDepthTest, FlatExprBuilderRecursionDepthTest, + testing::Values( + RecursionDepthTestCase{"AndChildLimitExceeded", "(1 + 1) && true", 1, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 1 exceeded"}, + RecursionDepthTestCase{"AndParentLimitExceeded", "(1 + 1) && true", 2, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 2 exceeded"}, + RecursionDepthTestCase{"AndLimitSuccess", "(1 + 1) && true", 3, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndLimitSuccessGenerous", "(1 + 1) && true", 10, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndLimitSuccessUnlimited", "(1 + 1) && true", + -1, absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"OrChildLimitExceeded", "(1 + 1) || true", 1, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 1 exceeded"}, + RecursionDepthTestCase{"OrParentLimitExceeded", "(1 + 1) || true", 2, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 2 exceeded"}, + RecursionDepthTestCase{"OrLimitSuccess", "(1 + 1) || true", 3, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"OrLimitSuccessGenerous", + "(1 + 1) || false || false || false || false || " + "(true && true && true && true && false)", + 10, absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"OrLimitSuccessUnlimited", "(1 + 1) || true", -1, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndDepthUpdateFromSubsequentArg", + "true && (1 + 1 + 1 + 1)", 4, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 4 exceeded"}, + RecursionDepthTestCase{"OrDepthUpdateFromSubsequentArg", + "true || (1 + 1 + 1 + 1)", 4, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 4 exceeded"})); + +TEST(FlatExprBuilderTest, NonRecursiveChildBlockAndError) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "_&&_" + args { const_expr: { bool_value: true } } + args { + call_expr: { + function: "cel.@block" + args { + list_expr { elements { const_expr: { int64_value: 1 } } } + } + args { ident_expr: { name: "@index0" } } + } + } + } + } + )pb", + &parsed_expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = 2; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("failed to build recursive program"))); +} + +TEST(FlatExprBuilderTest, NonRecursiveChildBlockOrError) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "_||_" + args { const_expr: { bool_value: true } } + args { + call_expr: { + function: "cel.@block" + args { + list_expr { elements { const_expr: { int64_value: 1 } } } + } + args { ident_expr: { name: "@index0" } } + } + } + } + } + )pb", + &parsed_expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = 2; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("failed to build recursive program"))); +} + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 09950bfe8..158e492be 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -81,9 +81,9 @@ bool OverloadExists(const Resolver& resolver, absl::string_view name, // Return the qualified name of the most qualified matching overload, or // nullopt if no matches are found. -absl::optional BestOverloadMatch(const Resolver& resolver, - absl::string_view base_name, - int argument_count) { +std::optional BestOverloadMatch(const Resolver& resolver, + absl::string_view base_name, + int argument_count) { if (IsSpecialFunction(base_name)) { return std::string(base_name); } @@ -99,7 +99,7 @@ absl::optional BestOverloadMatch(const Resolver& resolver, return *name; } } - return absl::nullopt; + return std::nullopt; } // Rewriter visitor for resolving references. @@ -135,8 +135,17 @@ class ReferenceResolver : public cel::AstRewriterBase { expr.mutable_const_expr().set_int64_value( reference->value().int64_value()); return true; + } else if (expr.has_ident_expr()) { + // "google.protobuf.NullValue.NULL_VALUE" is a special case: sometimes + // it is interpreted as null value and sometimes as an enum constant. + if (reference->value().has_null_value() && + expr.ident_expr().name() == + "google.protobuf.NullValue.NULL_VALUE") { + return false; + } + expr.set_const_expr(reference->value()); + return true; } else { - // No update if the constant reference isn't an int (an enum value). return false; } } @@ -253,27 +262,27 @@ class ReferenceResolver : public cel::AstRewriterBase { // Convert a select expr sub tree into a namespace name if possible. // If any operand of the top element is a not a select or an ident node, // return nullopt. - absl::optional ToNamespace(const Expr& expr) { - absl::optional maybe_parent_namespace; + std::optional ToNamespace(const Expr& expr) { + std::optional maybe_parent_namespace; if (rewritten_reference_.find(expr.id()) != rewritten_reference_.end()) { // The target expr matches a reference (resolved to an ident decl). // This should not be treated as a function qualifier. - return absl::nullopt; + return std::nullopt; } if (expr.has_ident_expr()) { return expr.ident_expr().name(); } else if (expr.has_select_expr()) { if (expr.select_expr().test_only()) { - return absl::nullopt; + return std::nullopt; } maybe_parent_namespace = ToNamespace(expr.select_expr().operand()); if (!maybe_parent_namespace.has_value()) { - return absl::nullopt; + return std::nullopt; } return absl::StrCat(*maybe_parent_namespace, ".", expr.select_expr().field()); } else { - return absl::nullopt; + return std::nullopt; } } diff --git a/eval/compiler/qualified_reference_resolver_test.cc b/eval/compiler/qualified_reference_resolver_test.cc index 0d710a465..3fa7fca21 100644 --- a/eval/compiler/qualified_reference_resolver_test.cc +++ b/eval/compiler/qualified_reference_resolver_test.cc @@ -45,6 +45,7 @@ namespace google::api::expr::runtime { namespace { +using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::Ast; @@ -343,6 +344,60 @@ TEST(ResolveReferences, EnumConstReferenceUsedSelect) { })pb")); } +// foo && bar +constexpr char kConstReferenceExpr[] = R"( + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + ident_expr { + name: "foo" + } + } + args { + id: 5 + ident_expr { + name: "bar" + } + } + } +)"; + +TEST(ResolveReferences, ConstReferenceFolded) { + std::unique_ptr expr_ast = ParseTestProto(kConstReferenceExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_THAT(RegisterBuiltinFunctions(&func_registry), IsOk()); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[2].set_name("foo"); + expr_ast->mutable_reference_map()[2].mutable_value().set_bool_value(true); + expr_ast->mutable_reference_map()[5].set_name("bar"); + expr_ast->mutable_reference_map()[5].mutable_value().set_bool_value(false); + IssueCollector issues(RuntimeIssue::Severity::kError); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(true)); + + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + const_expr { bool_value: true } + } + args { + id: 5 + const_expr { bool_value: false } + } + })pb")); +} + TEST(ResolveReferences, ConstReferenceSkipped) { std::unique_ptr expr_ast = ParseTestProto(kExpr); SourceInfo source_info; @@ -388,6 +443,42 @@ TEST(ResolveReferences, ConstReferenceSkipped) { })pb")); } +constexpr char kNullValueReferenceExpr[] = R"( + id: 1 + call_expr { + function: "_+_" + args { + id: 2 + ident_expr { + name: "google.protobuf.NullValue.NULL_VALUE" + } + } + args { + id: 5 + const_expr { int64_value: 1 } + } + } +)"; + +TEST(ResolveReferences, NullValueReferenceSkipped) { + std::unique_ptr expr_ast = ParseTestProto(kNullValueReferenceExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_THAT(RegisterBuiltinFunctions(&func_registry), IsOk()); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[2].set_name( + "google.protobuf.NullValue.NULL_VALUE"); + expr_ast->mutable_reference_map()[2].mutable_value().set_null_value(nullptr); + IssueCollector issues(RuntimeIssue::Severity::kError); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(/*was_rewritten=*/false)); +} + constexpr char kExtensionAndExpr[] = R"( id: 1 call_expr { diff --git a/eval/compiler/regex_precompilation_optimization.cc b/eval/compiler/regex_precompilation_optimization.cc index b94cae383..38ef842b9 100644 --- a/eval/compiler/regex_precompilation_optimization.cc +++ b/eval/compiler/regex_precompilation_optimization.cc @@ -145,7 +145,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { // Try to check if the regex is valid, whether or not we can actually update // the plan. - absl::optional pattern = + std::optional pattern = GetConstantString(context, subexpression, node, pattern_expr); if (!pattern.has_value()) { return absl::OkStatus(); @@ -168,7 +168,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { } private: - absl::optional GetConstantString( + std::optional GetConstantString( PlannerContext& context, ProgramBuilder::Subexpression* absl_nullable subexpression, const Expr& call_expr, const Expr& re_expr) const { @@ -178,9 +178,9 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { if (subexpression == nullptr || subexpression->IsFlattened()) { // Already modified, can't recover the input pattern. - return absl::nullopt; + return std::nullopt; } - absl::optional constant; + std::optional constant; if (subexpression->IsRecursive()) { const auto& program = subexpression->recursive_program(); auto deps = program.step->GetDependencies(); @@ -206,7 +206,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { return Cast(*constant).ToString(); } - return absl::nullopt; + return std::nullopt; } absl::Status RewritePlan( diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index 4e3fa3841..cca72964a 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -102,8 +102,8 @@ absl::Span Resolver::GetPrefixesFor( return namespace_prefixes_; } -absl::optional Resolver::FindConstant(absl::string_view name, - int64_t expr_id) const { +std::optional Resolver::FindConstant(absl::string_view name, + int64_t expr_id) const { auto prefixes = GetPrefixesFor(name); for (const auto& prefix : prefixes) { std::string qualified_name = absl::StrCat(prefix, name); @@ -128,7 +128,7 @@ absl::optional Resolver::FindConstant(absl::string_view name, return TypeValue(**type_value); } } - return absl::nullopt; + return std::nullopt; } std::vector Resolver::FindOverloads( @@ -205,7 +205,7 @@ std::vector Resolver::FindLazyOverloads( return funcs; } -absl::StatusOr>> +absl::StatusOr>> Resolver::FindType(absl::string_view name, int64_t expr_id) const { auto prefixes = GetPrefixesFor(name); for (auto& prefix : prefixes) { @@ -216,7 +216,7 @@ Resolver::FindType(absl::string_view name, int64_t expr_id) const { return std::make_pair(std::move(qualified_name), std::move(*maybe_type)); } } - return absl::nullopt; + return std::nullopt; } } // namespace google::api::expr::runtime diff --git a/eval/eval/attribute_utility.cc b/eval/eval/attribute_utility.cc index 117516caf..1e044627e 100644 --- a/eval/eval/attribute_utility.cc +++ b/eval/eval/attribute_utility.cc @@ -123,7 +123,7 @@ absl::optional AttributeUtility::MergeUnknowns( } if (!result_set.has_value()) { - return absl::nullopt; + return std::nullopt; } return UnknownValue(cel::Unknown(result_set->unknown_attributes(), diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index fda51e34f..4cf4ebf4d 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -55,7 +55,7 @@ absl::optional CelNumberFromValue(const Value& value) { case ValueKind::kDouble: return Number::FromDouble(value.GetDouble().NativeValue()); default: - return absl::nullopt; + return std::nullopt; } } diff --git a/eval/eval/equality_steps.cc b/eval/eval/equality_steps.cc index e134069d5..d720302e4 100644 --- a/eval/eval/equality_steps.cc +++ b/eval/eval/equality_steps.cc @@ -132,15 +132,11 @@ class IterativeEqualityStep : public ExpressionStepBase { absl::StatusOr EvaluateInMap(ExecutionFrameBase& frame, const Value& item, const MapValue& container) { - absl::StatusOr result = {BoolValue(false)}; switch (item.kind()) { case ValueKind::kBool: case ValueKind::kString: case ValueKind::kInt: case ValueKind::kUint: - result = container.Has(item, frame.descriptor_pool(), - frame.message_factory(), frame.arena()); - break; case ValueKind::kDouble: break; default: @@ -148,9 +144,12 @@ absl::StatusOr EvaluateInMap(ExecutionFrameBase& frame, cel::runtime_internal::CreateNoMatchingOverloadError( cel::builtin::kIn)); } + Value result; + CEL_RETURN_IF_ERROR(container.Has(item, frame.descriptor_pool(), + frame.message_factory(), frame.arena(), + &result)); - if (result.ok() && result.value().IsBool() && - result.value().GetBool().NativeValue()) { + if (result.IsTrue()) { return result; } @@ -159,10 +158,10 @@ absl::StatusOr EvaluateInMap(ExecutionFrameBase& frame, ? Number::FromDouble(item.GetDouble().NativeValue()) : Number::FromUint64(item.GetUint().NativeValue()); if (number.LosslessConvertibleToInt()) { - result = container.Has(IntValue(number.AsInt()), frame.descriptor_pool(), - frame.message_factory(), frame.arena()); - if (result.ok() && result.value().IsBool() && - result.value().GetBool().NativeValue()) { + CEL_RETURN_IF_ERROR( + container.Has(IntValue(number.AsInt()), frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result)); + if (result.IsTrue()) { return result; } } @@ -173,21 +172,16 @@ absl::StatusOr EvaluateInMap(ExecutionFrameBase& frame, ? Number::FromDouble(item.GetDouble().NativeValue()) : Number::FromInt64(item.GetInt().NativeValue()); if (number.LosslessConvertibleToUint()) { - result = + CEL_RETURN_IF_ERROR( container.Has(UintValue(number.AsUint()), frame.descriptor_pool(), - frame.message_factory(), frame.arena()); - if (result.ok() && result.value().IsBool() && - result.value().GetBool().NativeValue()) { + frame.message_factory(), frame.arena(), &result)); + if (result.IsTrue()) { return result; } } } - if (!result.ok()) { - return BoolValue(false); - } - - return result; + return BoolValue(false); } absl::StatusOr EvaluateIn(ExecutionFrameBase& frame, const Value& item, diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index 2a10e9674..12c5af8a7 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -286,20 +286,12 @@ absl::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { absl::StatusOr ResolveStatic( absl::Span input_args, absl::Span overloads) { - ResolveResult result = absl::nullopt; - for (const auto& overload : overloads) { if (ArgumentKindsMatch(overload.descriptor, input_args)) { - // More than one overload matches our arguments. - if (result.has_value()) { - return absl::Status(absl::StatusCode::kInternal, - "Cannot resolve overloads"); - } - - result.emplace(overload); + return overload; } } - return result; + return std::nullopt; } absl::StatusOr ResolveLazy( @@ -307,7 +299,7 @@ absl::StatusOr ResolveLazy( bool receiver_style, absl::Span providers, const ExecutionFrameBase& frame) { - ResolveResult result = absl::nullopt; + ResolveResult result = std::nullopt; std::vector arg_types(input_args.size()); @@ -315,7 +307,7 @@ absl::StatusOr ResolveLazy( input_args.begin(), input_args.end(), arg_types.begin(), [](const cel::Value& value) { return ValueKindToKind(value->kind()); }); - cel::FunctionDescriptor matcher{name, receiver_style, arg_types}; + cel::FunctionDescriptor matcher{name, receiver_style, std::move(arg_types)}; const cel::ActivationInterface& activation = frame.activation(); for (auto provider : providers) { diff --git a/eval/eval/regex_match_step_test.cc b/eval/eval/regex_match_step_test.cc index 8d54a0188..53b955b25 100644 --- a/eval/eval/regex_match_step_test.cc +++ b/eval/eval/regex_match_step_test.cc @@ -94,7 +94,7 @@ TEST(RegexMatchStep, PrecompiledInvalidProgramTooLarge) { ASSERT_OK(RegisterBuiltinFunctions(expr_builder->GetRegistry(), options)); EXPECT_THAT(expr_builder->CreateExpression(&checked_expr), StatusIs(absl::StatusCode::kInvalidArgument, - Eq("regular expressions exceeds max allowed size"))); + Eq("regular expression exceeds max allowed size"))); } } // namespace diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index b95915145..b815f5d87 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -19,7 +19,6 @@ #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" -#include "eval/internal/errors.h" #include "internal/status_macros.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" @@ -73,7 +72,7 @@ absl::optional CheckForMarkedAttributes(const AttributeTrail& trail, return cel::ErrorValue(std::move(result).status()); } - return absl::nullopt; + return std::nullopt; } void TestOnlySelect(const StructValue& msg, const std::string& field, @@ -158,13 +157,6 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { result_trail = trail.Step(&field_); } - if (arg->Is()) { - frame->value_stack().PopAndPush( - cel::ErrorValue(cel::runtime_internal::CreateError("Message is NULL")), - std::move(result_trail)); - return absl::OkStatus(); - } - absl::optional optional_arg; if (enable_optional_types_ && arg.IsOptional()) { @@ -354,10 +346,6 @@ class DirectSelectStep : public DirectExpressionStep { case ValueKind::kStruct: case ValueKind::kMap: break; - case ValueKind::kNull: - result = cel::ErrorValue( - cel::runtime_internal::CreateError("Message is NULL")); - return absl::OkStatus(); default: if (optional_arg) { break; diff --git a/eval/internal/cel_value_equal_test.cc b/eval/internal/cel_value_equal_test.cc index f52f38916..109a63795 100644 --- a/eval/internal/cel_value_equal_test.cc +++ b/eval/internal/cel_value_equal_test.cc @@ -67,7 +67,7 @@ using ::testing::ValuesIn; struct EqualityTestCase { enum class ErrorKind { kMissingOverload, kMissingIdentifier }; absl::string_view expr; - absl::variant result; + std::variant result; CelValue lhs = CelValue::CreateNull(); CelValue rhs = CelValue::CreateNull(); }; diff --git a/eval/public/ast_rewrite.cc b/eval/public/ast_rewrite.cc index 3c210e607..87c667eb5 100644 --- a/eval/public/ast_rewrite.cc +++ b/eval/public/ast_rewrite.cc @@ -68,7 +68,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/eval/public/ast_traverse.cc b/eval/public/ast_traverse.cc index a86923c67..c18b806b9 100644 --- a/eval/public/ast_traverse.cc +++ b/eval/public/ast_traverse.cc @@ -67,7 +67,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 1eeb07193..b73a2dc55 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -544,13 +544,14 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { ref.set_seconds(93541L); ref.set_nanos(11000000L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), 25L); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), + int64_t{25L}); TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateDuration(&ref), - 1559L); + int64_t{1559L}); TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateDuration(&ref), - 93541L); + int64_t{93541L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), - 11L); + int64_t{11L}); std::string result = "93541.011s"; TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), @@ -560,13 +561,14 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { ref.set_seconds(-93541L); ref.set_nanos(-11000000L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), -25L); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), + int64_t{-25L}); TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateDuration(&ref), - -1559L); + int64_t{-1559L}); TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateDuration(&ref), - -93541L); + int64_t{-93541L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), - -11L); + int64_t{-11L}); result = "-93541.011s"; TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), @@ -595,23 +597,28 @@ TEST_F(BuiltinsTest, TestTimestampFunctions) { ref.set_seconds(1L); ref.set_nanos(11000000L); TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), - 1970L); - TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), 0L); + int64_t{1970L}); + TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{0L}); TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref), - 0L); + int64_t{0L}); TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref), - 0L); - TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), 1L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), 0L); - TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), 0L); - TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), 1L); + int64_t{0L}); + TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{1L}); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{0L}); + TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{0L}); + TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{1L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateTimestamp(&ref), - 11L); + int64_t{11L}); ref.set_seconds(259200L); ref.set_nanos(0L); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), - 0L); + int64_t{0L}); } TEST_F(BuiltinsTest, TestTimestampConversionToString) { @@ -640,46 +647,60 @@ TEST_F(BuiltinsTest, TestTimestampFunctionsWithTimeZone) { TestFunctionsWithParams(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), params, - 1969L); + int64_t{1969L}); TestFunctionsWithParams(builtin::kMonth, - CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); TestFunctionsWithParams(builtin::kDayOfYear, - CelProtoWrapper::CreateTimestamp(&ref), params, 364L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{364L}); TestFunctionsWithParams(builtin::kDayOfMonth, - CelProtoWrapper::CreateTimestamp(&ref), params, 30L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{30L}); TestFunctionsWithParams(builtin::kDate, - CelProtoWrapper::CreateTimestamp(&ref), params, 31L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{31L}); TestFunctionsWithParams(builtin::kHours, - CelProtoWrapper::CreateTimestamp(&ref), params, 16L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{16L}); TestFunctionsWithParams(builtin::kMinutes, - CelProtoWrapper::CreateTimestamp(&ref), params, 0L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{0L}); TestFunctionsWithParams(builtin::kSeconds, - CelProtoWrapper::CreateTimestamp(&ref), params, 1L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{1L}); TestFunctionsWithParams(builtin::kMilliseconds, - CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); ref.set_seconds(259200L); ref.set_nanos(0L); TestFunctionsWithParams(builtin::kDayOfWeek, - CelProtoWrapper::CreateTimestamp(&ref), params, 6L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{6L}); // Test timestamp functions with negative value ref.set_seconds(-1L); ref.set_nanos(0L); TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), - 1969L); - TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), 11L); + int64_t{1969L}); + TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{11L}); TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref), - 364L); + int64_t{364L}); TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref), - 30L); - TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), 31L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), 23L); - TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), 59L); - TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), 59L); + int64_t{30L}); + TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{31L}); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{23L}); + TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); + TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), - 3L); + int64_t{3L}); // Test timestamp functions w/ fixed timezone ref.set_seconds(1L); @@ -690,46 +711,60 @@ TEST_F(BuiltinsTest, TestTimestampFunctionsWithTimeZone) { TestFunctionsWithParams(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), params, - 1969L); + int64_t{1969L}); TestFunctionsWithParams(builtin::kMonth, - CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); TestFunctionsWithParams(builtin::kDayOfYear, - CelProtoWrapper::CreateTimestamp(&ref), params, 364L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{364L}); TestFunctionsWithParams(builtin::kDayOfMonth, - CelProtoWrapper::CreateTimestamp(&ref), params, 30L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{30L}); TestFunctionsWithParams(builtin::kDate, - CelProtoWrapper::CreateTimestamp(&ref), params, 31L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{31L}); TestFunctionsWithParams(builtin::kHours, - CelProtoWrapper::CreateTimestamp(&ref), params, 16L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{16L}); TestFunctionsWithParams(builtin::kMinutes, - CelProtoWrapper::CreateTimestamp(&ref), params, 0L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{0L}); TestFunctionsWithParams(builtin::kSeconds, - CelProtoWrapper::CreateTimestamp(&ref), params, 1L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{1L}); TestFunctionsWithParams(builtin::kMilliseconds, - CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); ref.set_seconds(259200L); ref.set_nanos(0L); TestFunctionsWithParams(builtin::kDayOfWeek, - CelProtoWrapper::CreateTimestamp(&ref), params, 6L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{6L}); // Test timestamp functions with negative value ref.set_seconds(-1L); ref.set_nanos(0L); TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), - 1969L); - TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), 11L); + int64_t{1969L}); + TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{11L}); TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref), - 364L); + int64_t{364L}); TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref), - 30L); - TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), 31L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), 23L); - TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), 59L); - TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), 59L); + int64_t{30L}); + TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{31L}); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{23L}); + TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); + TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), - 3L); + int64_t{3L}); TestTypeConversionError( builtin::kString, @@ -750,22 +785,25 @@ TEST_F(BuiltinsTest, TestBytesConversions_string) { TEST_F(BuiltinsTest, TestDoubleConversions_double) { double ref = 100.1; - TestTypeConverts(builtin::kDouble, CelValue::CreateDouble(ref), 100.1); + TestTypeConverts(builtin::kDouble, CelValue::CreateDouble(ref), + double{100.1}); } TEST_F(BuiltinsTest, TestDoubleConversions_int) { int64_t ref = 100L; - TestTypeConverts(builtin::kDouble, CelValue::CreateInt64(ref), 100.0); + TestTypeConverts(builtin::kDouble, CelValue::CreateInt64(ref), double{100.0}); } TEST_F(BuiltinsTest, TestDoubleConversions_string) { std::string ref = "-100.1"; - TestTypeConverts(builtin::kDouble, CelValue::CreateString(&ref), -100.1); + TestTypeConverts(builtin::kDouble, CelValue::CreateString(&ref), + double{-100.1}); } TEST_F(BuiltinsTest, TestDoubleConversions_uint) { uint64_t ref = 100UL; - TestTypeConverts(builtin::kDouble, CelValue::CreateUint64(ref), 100.0); + TestTypeConverts(builtin::kDouble, CelValue::CreateUint64(ref), + double{100.0}); } TEST_F(BuiltinsTest, TestDoubleConversionError_stringInvalid) { @@ -774,34 +812,36 @@ TEST_F(BuiltinsTest, TestDoubleConversionError_stringInvalid) { } TEST_F(BuiltinsTest, TestDynConversions) { - TestTypeConverts(builtin::kDyn, CelValue::CreateDouble(100.1), 100.1); - TestTypeConverts(builtin::kDyn, CelValue::CreateInt64(100L), 100L); - TestTypeConverts(builtin::kDyn, CelValue::CreateUint64(100UL), 100UL); + TestTypeConverts(builtin::kDyn, CelValue::CreateDouble(100.1), double{100.1}); + TestTypeConverts(builtin::kDyn, CelValue::CreateInt64(100L), int64_t{100L}); + TestTypeConverts(builtin::kDyn, CelValue::CreateUint64(100UL), + uint64_t{100UL}); } TEST_F(BuiltinsTest, TestIntConversions_int) { - TestTypeConverts(builtin::kInt, CelValue::CreateInt64(100L), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateInt64(100L), int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_Timestamp) { Timestamp ref; ref.set_seconds(100); - TestTypeConverts(builtin::kInt, CelProtoWrapper::CreateTimestamp(&ref), 100L); + TestTypeConverts(builtin::kInt, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_double) { double ref = 100.1; - TestTypeConverts(builtin::kInt, CelValue::CreateDouble(ref), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateDouble(ref), int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_string) { std::string ref = "100"; - TestTypeConverts(builtin::kInt, CelValue::CreateString(&ref), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateString(&ref), int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_uint) { uint64_t ref = 100; - TestTypeConverts(builtin::kInt, CelValue::CreateUint64(ref), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateUint64(ref), int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_doubleIntMin) { @@ -823,10 +863,10 @@ TEST_F(BuiltinsTest, TestIntConversions_doubleIntMinMinus1024) { TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMaxMinus512) { // Converting int64_t max - 512 to a double will not roundtrip to the original - // value, but it will rountrip to a valid 64-bit integer. + // value, but it will roundtrip to a valid 64-bit integer. double range = std::numeric_limits::max() - 512; TestTypeConverts(builtin::kInt, CelValue::CreateDouble(range), - std::numeric_limits::max() - 1023); + int64_t{std::numeric_limits::max() - 1023}); } TEST_F(BuiltinsTest, TestIntConversionError_doubleNegRange) { @@ -874,21 +914,24 @@ TEST_F(BuiltinsTest, TestIntConversionError_uintRange) { TEST_F(BuiltinsTest, TestUintConversions_double) { double ref = 100.1; - TestTypeConverts(builtin::kUint, CelValue::CreateDouble(ref), 100UL); + TestTypeConverts(builtin::kUint, CelValue::CreateDouble(ref), + uint64_t{100UL}); } TEST_F(BuiltinsTest, TestUintConversions_int) { int64_t ref = 100L; - TestTypeConverts(builtin::kUint, CelValue::CreateInt64(ref), 100UL); + TestTypeConverts(builtin::kUint, CelValue::CreateInt64(ref), uint64_t{100UL}); } TEST_F(BuiltinsTest, TestUintConversions_string) { std::string ref = "100"; - TestTypeConverts(builtin::kUint, CelValue::CreateString(&ref), 100UL); + TestTypeConverts(builtin::kUint, CelValue::CreateString(&ref), + uint64_t{100UL}); } TEST_F(BuiltinsTest, TestUintConversions_uint) { - TestTypeConverts(builtin::kUint, CelValue::CreateUint64(100UL), 100UL); + TestTypeConverts(builtin::kUint, CelValue::CreateUint64(uint64_t{100UL}), + uint64_t{100UL}); } TEST_F(BuiltinsTest, TestUintConversionError_doubleNegRange) { @@ -1589,7 +1632,8 @@ TEST_F(BuiltinsTest, TestMapInError) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( builtin::kIn, {}, {key, CelValue::CreateMap(&cel_map)}, &result_value)); - EXPECT_TRUE(result_value.IsBool()); + ASSERT_TRUE(result_value.IsBool()) + << key.DebugString() << " : " << result_value.DebugString(); EXPECT_FALSE(result_value.BoolOrDie()); } diff --git a/eval/public/cel_attribute.cc b/eval/public/cel_attribute.cc index 015289bed..70525a04d 100644 --- a/eval/public/cel_attribute.cc +++ b/eval/public/cel_attribute.cc @@ -76,8 +76,8 @@ CelAttributeQualifier CreateCelAttributeQualifier(const CelValue& value) { CelAttributePattern CreateCelAttributePattern( absl::string_view variable, - std::initializer_list> + std::initializer_list> path_spec) { std::vector path; path.reserve(path_spec.size()); diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index 3f52ad60d..4cf029e89 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -80,10 +80,10 @@ class CelExpressionBuilder { virtual ~CelExpressionBuilder() = default; // Creates CelExpression object from AST tree. - // expr specifies root of AST tree - // - // IMPORTANT: The `expr` and `source_info` must outlive the resulting - // CelExpression. + // expr specifies root of AST tree. + // Method implementation is expected to create copies of expr and source_info, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::Expr* expr, const cel::expr::SourceInfo* source_info) const = 0; @@ -91,9 +91,9 @@ class CelExpressionBuilder { // Creates CelExpression object from AST tree. // expr specifies root of AST tree. // non-fatal build warnings are written to warnings if encountered. - // - // IMPORTANT: The `expr` and `source_info` must outlive the resulting - // CelExpression. + // Method implementation is expected to create copies of expr and source_info, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::Expr* expr, const cel::expr::SourceInfo* source_info, @@ -101,8 +101,9 @@ class CelExpressionBuilder { // Creates CelExpression object from a checked expression. // This includes an AST, source info, type hints and ident hints. - // - // IMPORTANT: The `checked_expr` must outlive the resulting CelExpression. + // Method implementation is expected to create copy of checked_expr, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::CheckedExpr* checked_expr) const { // Default implementation just passes through the expr and source info. @@ -113,8 +114,9 @@ class CelExpressionBuilder { // Creates CelExpression object from a checked expression. // This includes an AST, source info, type hints and ident hints. // non-fatal build warnings are written to warnings if encountered. - // - // IMPORTANT: The `checked_expr` must outlive the resulting CelExpression. + // Method implementation is expected to create copy of checked_expr, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::CheckedExpr* checked_expr, std::vector* warnings) const { diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index 779839583..4d81eb8a7 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -171,17 +171,23 @@ struct InterpreterOptions { // removed in a later update. bool enable_lazy_bind_initialization = true; - // Maximum recursion depth for evaluable programs. + // Enable recursive planning with a maximum recursion depth for evaluable + // programs. // - // This is proportional to the maximum number of recursive Evaluate calls that - // a single expression program might require while evaluating. This is - // coarse -- the actual C++ stack requirements will vary depending on the + // This limit is proportional to the maximum number of recursive Evaluate + // calls that a single expression program might require while evaluating. This + // is coarse -- the actual C++ stack requirements will vary depending on the // expression. // // This does not account for re-entrant evaluation in a client's extension - // function. + // function (i.e. a CEL function that calls Evaluate on another CEL program) + // + // If the limit is exceeded, the planner will return an error instead of + // planning the program. // // -1 means unbounded. + // 0 means disabled (using a heap-based stack machine instead), which is the + // default. int max_recursion_depth = 0; // Enable tracing support for recursively planned programs. diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index 290726bfe..3fb80bcea 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -28,7 +28,6 @@ #include "base/type_provider.h" #include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/structs/legacy_type_provider.h" -#include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "runtime/type_registry.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -86,7 +85,7 @@ class CelTypeRegistry { // registry. // // This is a composited type provider that should check in order: - // - builtins (via TypeManager) + // - builtins // - custom enumerations // - registered extension type providers in the order registered. const cel::TypeProvider& GetTypeProvider() const { @@ -139,11 +138,6 @@ class CelTypeRegistry { private: // Internal modern registry. cel::TypeRegistry modern_type_registry_; - - // TODO(uncreated-issue/44): This is needed to inspect the registered legacy type - // providers for client tests. This can be removed when they are migrated to - // use the modern APIs. - std::shared_ptr legacy_type_provider_; }; } // namespace google::api::expr::runtime diff --git a/eval/public/equality_function_registrar_test.cc b/eval/public/equality_function_registrar_test.cc index 772ddfeba..a77a92734 100644 --- a/eval/public/equality_function_registrar_test.cc +++ b/eval/public/equality_function_registrar_test.cc @@ -86,7 +86,7 @@ MATCHER_P2(DefinesHomogenousOverload, name, argument_type, struct EqualityTestCase { enum class ErrorKind { kMissingOverload, kMissingIdentifier }; absl::string_view expr; - absl::variant result; + std::variant result; CelValue lhs = CelValue::CreateNull(); CelValue rhs = CelValue::CreateNull(); }; @@ -204,7 +204,7 @@ std::string CelValueEqualTestName( } TEST_P(CelValueEqualImplTypesTest, Basic) { - absl::optional result = CelValueEqualImpl(lhs(), rhs()); + std::optional result = CelValueEqualImpl(lhs(), rhs()); if (lhs().IsNull() || rhs().IsNull()) { if (lhs().IsNull() && rhs().IsNull()) { @@ -286,7 +286,7 @@ const std::vector& NumericValuesNotEqualExample() { using NumericInequalityTest = testing::TestWithParam; TEST_P(NumericInequalityTest, NumericValues) { NumericInequalityTestCase test_case = GetParam(); - absl::optional result = CelValueEqualImpl(test_case.a, test_case.b); + std::optional result = CelValueEqualImpl(test_case.a, test_case.b); EXPECT_TRUE(result.has_value()); EXPECT_EQ(*result, false); } @@ -299,7 +299,7 @@ INSTANTIATE_TEST_SUITE_P( }); TEST(CelValueEqualImplTest, LossyNumericEquality) { - absl::optional result = CelValueEqualImpl( + std::optional result = CelValueEqualImpl( CelValue::CreateDouble( static_cast(std::numeric_limits::max()) - 1), CelValue::CreateInt64(std::numeric_limits::max())); diff --git a/eval/public/message_wrapper_test.cc b/eval/public/message_wrapper_test.cc index ff0e691ab..15e5e88da 100644 --- a/eval/public/message_wrapper_test.cc +++ b/eval/public/message_wrapper_test.cc @@ -18,7 +18,6 @@ #include "eval/public/structs/trivial_legacy_type_info.h" #include "eval/testutil/test_message.pb.h" -#include "internal/casts.h" #include "internal/testing.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" @@ -60,7 +59,7 @@ TEST(MessageWrapperBuilder, Builder) { static_cast(&test_message)); auto mutable_message = - cel::internal::down_cast(builder.message_ptr()); + google::protobuf::DownCastMessage(builder.message_ptr()); mutable_message->set_int64_value(20); mutable_message->set_double_value(12.3); diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index d301ff0ca..d722559e3 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -442,3 +442,23 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_test( + name = "field_access_impl_benchmark_test", + srcs = ["field_access_impl_benchmark_test.cc"], + tags = [ + "benchmark", + "manual", + ], + deps = [ + ":cel_proto_wrapper", + ":field_access_impl", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//extensions/protobuf/internal:map_reflection", + "//internal:benchmark", + "//internal:testing", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/eval/public/structs/cel_proto_wrap_util.cc b/eval/public/structs/cel_proto_wrap_util.cc index d0f80171f..7bfe81fe6 100644 --- a/eval/public/structs/cel_proto_wrap_util.cc +++ b/eval/public/structs/cel_proto_wrap_util.cc @@ -691,7 +691,7 @@ class ValueFromMessageMaker { return CreateWellknownTypeValue(message, factory, arena); // WELLKNOWNTYPE_FIELDMASK has no special CelValue type default: - return absl::nullopt; + return std::nullopt; } } @@ -716,7 +716,7 @@ absl::optional DynamicMap::operator[](CelValue key) const { auto it = values_->fields().find(std::string(str_key.value())); if (it == values_->fields().end()) { - return absl::nullopt; + return std::nullopt; } return ValueManager(factory_, arena_).ValueFromMessage(&it->second); diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index a1dc83ade..6fad6aee3 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -53,7 +53,7 @@ absl::optional CelProtoWrapper::MaybeWrapValue( if (msg != nullptr) { return InternalWrapMessage(msg); } else { - return absl::nullopt; + return std::nullopt; } } diff --git a/eval/public/structs/field_access_impl.cc b/eval/public/structs/field_access_impl.cc index 3b3cb9847..2bd9fff9d 100644 --- a/eval/public/structs/field_access_impl.cc +++ b/eval/public/structs/field_access_impl.cc @@ -139,8 +139,7 @@ class FieldAccessor { case FieldDescriptor::TYPE_BYTES: return CelValue::CreateBytesView(value); default: - return absl::Status(absl::StatusCode::kInvalidArgument, - "Error handling C++ string conversion"); + break; } break; } @@ -153,8 +152,7 @@ class FieldAccessor { return CelValue::CreateInt64(enum_value); } default: - return absl::Status(absl::StatusCode::kInvalidArgument, - "Unhandled C++ type conversion"); + break; } return absl::Status(absl::StatusCode::kInvalidArgument, "Unhandled C++ type conversion"); diff --git a/eval/public/structs/field_access_impl_benchmark_test.cc b/eval/public/structs/field_access_impl_benchmark_test.cc new file mode 100644 index 000000000..888e424b1 --- /dev/null +++ b/eval/public/structs/field_access_impl_benchmark_test.cc @@ -0,0 +1,239 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/field_access_impl.h" +#include "extensions/protobuf/internal/map_reflection.h" +#include "internal/benchmark.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime::internal { +namespace { + +using ::cel::expr::conformance::proto3::TestAllTypes; + +void BM_CreateValueFromSingleField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.set_single_int64(42); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_int64"); + + for (auto _ : state) { + auto value = CreateValueFromSingleField( + &msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromSingleField_Int64); + +void BM_CreateValueFromSingleField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.set_single_string("hello world"); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_string"); + + for (auto _ : state) { + auto value = CreateValueFromSingleField( + &msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromSingleField_String); + +void BM_CreateValueFromSingleField_Message(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.mutable_standalone_message()->set_bb(123); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + + for (auto _ : state) { + auto value = CreateValueFromSingleField( + &msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromSingleField_Message); + +void BM_CreateValueFromRepeatedField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.add_repeated_int64(42); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_int64"); + + for (auto _ : state) { + auto value = CreateValueFromRepeatedField( + &msg, desc, 0, &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromRepeatedField_Int64); + +void BM_CreateValueFromRepeatedField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.add_repeated_string("hello world"); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string"); + + for (auto _ : state) { + auto value = CreateValueFromRepeatedField( + &msg, desc, 0, &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromRepeatedField_String); + +void BM_CreateValueFromMapValue_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + (*msg.mutable_map_int64_int64())[42] = 100; + const google::protobuf::FieldDescriptor* map_desc = + TestAllTypes::descriptor()->FindFieldByName("map_int64_int64"); + const google::protobuf::FieldDescriptor* value_desc = + map_desc->message_type()->FindFieldByName("value"); + + google::protobuf::ConstMapIterator iter = + cel::extensions::protobuf_internal::ConstMapBegin(*msg.GetReflection(), + msg, *map_desc); + google::protobuf::MapValueConstRef value_ref = iter.GetValueRef(); + + for (auto _ : state) { + auto value = + CreateValueFromMapValue(&msg, value_desc, &value_ref, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromMapValue_Int64); + +void BM_SetValueToSingleField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_int64"); + CelValue val = CelValue::CreateInt64(42); + + for (auto _ : state) { + auto status = SetValueToSingleField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_SetValueToSingleField_Int64); + +void BM_SetValueToSingleField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_string"); + CelValue val = CelValue::CreateStringView("hello world"); + + for (auto _ : state) { + auto status = SetValueToSingleField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_SetValueToSingleField_String); + +void BM_SetValueToSingleField_Message(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + + TestAllTypes::NestedMessage nested_msg; + nested_msg.set_bb(123); + CelValue val = CelProtoWrapper::CreateMessage(&nested_msg, &arena); + + for (auto _ : state) { + auto status = SetValueToSingleField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_SetValueToSingleField_Message); + +void BM_AddValueToRepeatedField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_int64"); + CelValue val = CelValue::CreateInt64(42); + + for (auto _ : state) { + msg.clear_repeated_int64(); + auto status = AddValueToRepeatedField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_AddValueToRepeatedField_Int64); + +void BM_AddValueToRepeatedField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string"); + CelValue val = CelValue::CreateStringView("hello world"); + + for (auto _ : state) { + msg.clear_repeated_string(); + auto status = AddValueToRepeatedField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_AddValueToRepeatedField_String); + +void BM_CreateValueFromRepeatedField_StringPiece(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.add_repeated_string_piece("hello world"); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string_piece"); + + for (auto _ : state) { + auto value = CreateValueFromRepeatedField( + &msg, desc, 0, &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromRepeatedField_StringPiece); + +void BM_AddValueToRepeatedField_StringPiece(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string_piece"); + CelValue val = CelValue::CreateStringView("hello world"); + + for (auto _ : state) { + msg.clear_repeated_string_piece(); + auto status = AddValueToRepeatedField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_AddValueToRepeatedField_StringPiece); + +} // namespace +} // namespace google::api::expr::runtime::internal diff --git a/eval/public/structs/legacy_type_provider.cc b/eval/public/structs/legacy_type_provider.cc index a85f08911..f8db92298 100644 --- a/eval/public/structs/legacy_type_provider.cc +++ b/eval/public/structs/legacy_type_provider.cc @@ -27,6 +27,7 @@ #include "common/legacy_value.h" #include "common/memory.h" #include "common/type.h" +#include "common/type_introspector.h" #include "common/value.h" #include "eval/public/message_wrapper.h" #include "eval/public/structs/legacy_type_adapter.h" @@ -62,7 +63,7 @@ class LegacyStructValueBuilder final : public cel::StructValueBuilder { CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetField( name, legacy_value, memory_manager_, builder_)) .With(cel::ErrorValueReturn()); - return absl::nullopt; + return std::nullopt; } absl::StatusOr> SetFieldByNumber( @@ -75,7 +76,7 @@ class LegacyStructValueBuilder final : public cel::StructValueBuilder { CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetFieldByNumber( number, legacy_value, memory_manager_, builder_)) .With(cel::ErrorValueReturn()); - return absl::nullopt; + return std::nullopt; } absl::StatusOr Build() && override { @@ -115,7 +116,7 @@ class LegacyValueBuilder final : public cel::ValueBuilder { CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetField( name, legacy_value, memory_manager_, builder_)) .With(cel::ErrorValueReturn()); - return absl::nullopt; + return std::nullopt; } absl::StatusOr> SetFieldByNumber( @@ -128,7 +129,7 @@ class LegacyValueBuilder final : public cel::ValueBuilder { CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetFieldByNumber( number, legacy_value, memory_manager_, builder_)) .With(cel::ErrorValueReturn()); - return absl::nullopt; + return std::nullopt; } absl::StatusOr Build() && override { @@ -175,6 +176,9 @@ LegacyTypeProvider::NewValueBuilder( absl::StatusOr> LegacyTypeProvider::FindTypeImpl( absl::string_view name) const { + if (auto type = cel::FindWellKnownType(name); type.has_value()) { + return type; + } if (auto type_info = ProvideLegacyTypeInfo(name); type_info.has_value()) { const auto* descriptor = (*type_info)->GetDescriptor(MessageWrapper()); if (descriptor != nullptr) { @@ -183,12 +187,16 @@ absl::StatusOr> LegacyTypeProvider::FindTypeImpl( return cel::common_internal::MakeBasicStructType( (*type_info)->GetTypename(MessageWrapper())); } - return absl::nullopt; + return std::nullopt; } absl::StatusOr> LegacyTypeProvider::FindStructTypeFieldByNameImpl( absl::string_view type, absl::string_view name) const { + if (auto result = cel::FindWellKnownTypeFieldByName(type, name); + result.has_value()) { + return result; + } if (auto type_info = ProvideLegacyTypeInfo(type); type_info.has_value()) { if (auto field_desc = (*type_info)->FindFieldByName(name); field_desc.has_value()) { @@ -198,13 +206,13 @@ LegacyTypeProvider::FindStructTypeFieldByNameImpl( const auto* mutation_apis = (*type_info)->GetMutationApis(MessageWrapper()); if (mutation_apis == nullptr || !mutation_apis->DefinesField(name)) { - return absl::nullopt; + return std::nullopt; } return cel::common_internal::BasicStructTypeField(name, 0, cel::DynType{}); } } - return absl::nullopt; + return std::nullopt; } } // namespace google::api::expr::runtime diff --git a/eval/public/structs/legacy_type_provider_test.cc b/eval/public/structs/legacy_type_provider_test.cc index 160ac49f3..8de2aba01 100644 --- a/eval/public/structs/legacy_type_provider_test.cc +++ b/eval/public/structs/legacy_type_provider_test.cc @@ -28,7 +28,7 @@ class LegacyTypeProviderTestEmpty : public LegacyTypeProvider { public: absl::optional ProvideLegacyType( absl::string_view name) const override { - return absl::nullopt; + return std::nullopt; } }; @@ -60,14 +60,14 @@ class LegacyTypeProviderTestImpl : public LegacyTypeProvider { if (name == "test") { return LegacyTypeAdapter(nullptr, nullptr); } - return absl::nullopt; + return std::nullopt; } absl::optional ProvideLegacyTypeInfo( absl::string_view name) const override { if (name == "test") { return test_type_info_; } - return absl::nullopt; + return std::nullopt; } private: @@ -76,8 +76,8 @@ class LegacyTypeProviderTestImpl : public LegacyTypeProvider { TEST(LegacyTypeProviderTest, EmptyTypeProviderHasProvideTypeInfo) { LegacyTypeProviderTestEmpty provider; - EXPECT_EQ(provider.ProvideLegacyType("test"), absl::nullopt); - EXPECT_EQ(provider.ProvideLegacyTypeInfo("test"), absl::nullopt); + EXPECT_EQ(provider.ProvideLegacyType("test"), std::nullopt); + EXPECT_EQ(provider.ProvideLegacyTypeInfo("test"), std::nullopt); } TEST(LegacyTypeProviderTest, NonEmptyTypeProviderProvidesSomeTypes) { @@ -85,8 +85,8 @@ TEST(LegacyTypeProviderTest, NonEmptyTypeProviderProvidesSomeTypes) { LegacyTypeProviderTestImpl provider(&test_type_info); EXPECT_TRUE(provider.ProvideLegacyType("test").has_value()); EXPECT_TRUE(provider.ProvideLegacyTypeInfo("test").has_value()); - EXPECT_EQ(provider.ProvideLegacyType("other"), absl::nullopt); - EXPECT_EQ(provider.ProvideLegacyTypeInfo("other"), absl::nullopt); + EXPECT_EQ(provider.ProvideLegacyType("other"), std::nullopt); + EXPECT_EQ(provider.ProvideLegacyTypeInfo("other"), std::nullopt); } } // namespace diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index a351890c2..8c140c0c7 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -472,14 +472,14 @@ const LegacyTypeAccessApis* ProtoMessageTypeAdapter::GetAccessApis( absl::optional ProtoMessageTypeAdapter::FindFieldByName(absl::string_view field_name) const { if (descriptor_ == nullptr) { - return absl::nullopt; + return std::nullopt; } const google::protobuf::FieldDescriptor* field_descriptor = descriptor_->FindFieldByName(field_name); if (field_descriptor == nullptr) { - return absl::nullopt; + return std::nullopt; } return LegacyTypeInfoApis::FieldDescription{field_descriptor->number(), @@ -582,6 +582,19 @@ absl::Status ProtoMessageTypeAdapter::SetField( ValidateSetFieldOp(value_field_descriptor != nullptr, field->name(), "failed to find value field descriptor")); + bool prune_when_null = false; + if (value_field_descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + auto well_known_type = + value_field_descriptor->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + prune_when_null = true; + } + } + CEL_ASSIGN_OR_RETURN(const CelList* key_list, cel_map->ListKeys(arena)); for (int i = 0; i < key_list->size(); i++) { CelValue key = (*key_list).Get(arena, i); @@ -589,6 +602,9 @@ absl::Status ProtoMessageTypeAdapter::SetField( auto value = (*cel_map).Get(arena, key); CEL_RETURN_IF_ERROR(ValidateSetFieldOp(value.has_value(), field->name(), "error serializing CelMap")); + if (prune_when_null && value->IsNull()) { + continue; + } Message* entry_msg = message->GetReflection()->AddMessage(message, field); CEL_RETURN_IF_ERROR(internal::SetValueToSingleField( key, key_field_descriptor, entry_msg, arena)); diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index 088d20d48..e28d76102 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -36,6 +36,7 @@ #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" namespace google::api::expr::runtime { namespace { @@ -69,8 +70,8 @@ class ProtoMessageTypeAccessorTest : public testing::TestWithParam { bool use_generic_instance = GetParam(); if (use_generic_instance) { // implementation detail: in general, type info implementations may - // return a different accessor object based on the messsage instance, but - // this implemenation returns the same one no matter the message. + // return a different accessor object based on the message instance, but + // this implementation returns the same one no matter the message. return *GetGenericProtoTypeInfoInstance().GetAccessApis(dummy_); } else { @@ -709,7 +710,7 @@ TEST(ProtoMesssageTypeAdapter, FindFieldNotFound) { "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); - EXPECT_EQ(adapter.FindFieldByName("foo_not_a_field"), absl::nullopt); + EXPECT_EQ(adapter.FindFieldByName("foo_not_a_field"), std::nullopt); } TEST(ProtoMesssageTypeAdapter, TypeInfoMutator) { @@ -725,7 +726,8 @@ TEST(ProtoMesssageTypeAdapter, TypeInfoMutator) { ASSERT_OK_AND_ASSIGN(MessageWrapper::Builder builder, api->NewInstance(manager)); - EXPECT_NE(dynamic_cast(builder.message_ptr()), nullptr); + EXPECT_NE(google::protobuf::DynamicCastMessage(builder.message_ptr()), + nullptr); } TEST(ProtoMesssageTypeAdapter, TypeInfoAccesor) { diff --git a/eval/public/structs/protobuf_descriptor_type_provider.cc b/eval/public/structs/protobuf_descriptor_type_provider.cc index 68b39c643..b5746523e 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider.cc +++ b/eval/public/structs/protobuf_descriptor_type_provider.cc @@ -27,7 +27,7 @@ absl::optional ProtobufDescriptorProvider::ProvideLegacyType( absl::string_view name) const { const ProtoMessageTypeAdapter* result = GetTypeAdapter(name); if (result == nullptr) { - return absl::nullopt; + return std::nullopt; } // ProtoMessageTypeAdapter provides apis for both access and mutation. return LegacyTypeAdapter(result, result); @@ -38,7 +38,7 @@ ProtobufDescriptorProvider::ProvideLegacyTypeInfo( absl::string_view name) const { const ProtoMessageTypeAdapter* result = GetTypeAdapter(name); if (result == nullptr) { - return absl::nullopt; + return std::nullopt; } return result; } diff --git a/eval/public/structs/trivial_legacy_type_info_test.cc b/eval/public/structs/trivial_legacy_type_info_test.cc index 9b4840373..9cc6e4916 100644 --- a/eval/public/structs/trivial_legacy_type_info_test.cc +++ b/eval/public/structs/trivial_legacy_type_info_test.cc @@ -56,9 +56,9 @@ TEST(TrivialTypeInfo, FindFieldByName) { TrivialTypeInfo info; MessageWrapper wrapper; - EXPECT_EQ(info.FindFieldByName("foo"), absl::nullopt); + EXPECT_EQ(info.FindFieldByName("foo"), std::nullopt); EXPECT_EQ(TrivialTypeInfo::GetInstance()->FindFieldByName("foo"), - absl::nullopt); + std::nullopt); } } // namespace diff --git a/eval/public/testing/matchers.cc b/eval/public/testing/matchers.cc index f79071fce..4f728c730 100644 --- a/eval/public/testing/matchers.cc +++ b/eval/public/testing/matchers.cc @@ -7,7 +7,6 @@ #include "absl/strings/string_view.h" #include "eval/public/cel_value.h" #include "eval/public/set_util.h" -#include "internal/casts.h" #include "internal/testing.h" #include "google/protobuf/message.h" @@ -76,8 +75,7 @@ class CelValueMatcherImpl CelValue::MessageWrapper arg; return v.GetValue(&arg) && arg.HasFullProto() && underlying_type_matcher_.Matches( - cel::internal::down_cast( - arg.message_ptr())); + google::protobuf::DownCastMessage(arg.message_ptr())); } void DescribeTo(std::ostream* os) const override { diff --git a/eval/public/value_export_util.cc b/eval/public/value_export_util.cc index edb6e83e0..bca8a8d65 100644 --- a/eval/public/value_export_util.cc +++ b/eval/public/value_export_util.cc @@ -67,8 +67,8 @@ absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value, break; } case CelValue::Type::kBytes: { - absl::Base64Escape(in_value.BytesOrDie().value(), - out_value->mutable_string_value()); + *out_value->mutable_string_value() = + absl::Base64Escape(in_value.BytesOrDie().value()); break; } case CelValue::Type::kDuration: { diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 8eeafd521..9163548d1 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -18,7 +18,10 @@ cc_test( srcs = [ "benchmark_test.cc", ], - tags = ["benchmark"], + tags = [ + "benchmark", + "manual", + ], deps = [ ":request_context_cc_proto", "//eval/public:activation", @@ -52,7 +55,10 @@ cc_test( srcs = [ "modern_benchmark_test.cc", ], - tags = ["benchmark"], + tags = [ + "benchmark", + "manual", + ], deps = [ ":request_context_cc_proto", "//common:allocator", @@ -102,7 +108,10 @@ cc_test( srcs = [ "allocation_benchmark_test.cc", ], - tags = ["benchmark"], + tags = [ + "benchmark", + "manual", + ], deps = [ ":request_context_cc_proto", "//eval/public:activation", @@ -151,7 +160,10 @@ cc_test( srcs = [ "expression_builder_benchmark_test.cc", ], - tags = ["benchmark"], + tags = [ + "benchmark", + "manual", + ], deps = [ ":request_context_cc_proto", "//common:minimal_descriptor_pool", diff --git a/eval/tests/allocation_benchmark_test.cc b/eval/tests/allocation_benchmark_test.cc index 5364d3fc0..425355e3a 100644 --- a/eval/tests/allocation_benchmark_test.cc +++ b/eval/tests/allocation_benchmark_test.cc @@ -169,6 +169,9 @@ static void BM_AllocateMessage(benchmark::State& state) { "google.api.expr.runtime.RequestContext{" "ip: '192.168.0.1'," "path: '/root'}"); + // Make sure RequestContext is loaded in the generated descriptor pool. + RequestContext context; + static_cast(context); auto builder = CreateCelExpressionBuilder(); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); diff --git a/eval/tests/benchmark_test.cc b/eval/tests/benchmark_test.cc index fc0c39294..f188dc0b7 100644 --- a/eval/tests/benchmark_test.cc +++ b/eval/tests/benchmark_test.cc @@ -317,7 +317,7 @@ BENCHMARK(BM_PolicySymbolic); class RequestMap : public CelMap { public: - absl::optional operator[](CelValue key) const override { + std::optional operator[](CelValue key) const override { if (!key.IsString()) { return {}; } diff --git a/eval/tests/expression_builder_benchmark_test.cc b/eval/tests/expression_builder_benchmark_test.cc index c26a7cd5c..410df8902 100644 --- a/eval/tests/expression_builder_benchmark_test.cc +++ b/eval/tests/expression_builder_benchmark_test.cc @@ -1,18 +1,16 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -50,8 +48,24 @@ using google::api::expr::parser::Parse; enum BenchmarkParam : int { kDefault = 0, kFoldConstants = 1, + kRecursivePlanning = 2, + kRecursivePlanningWithConstantFolding = 3, }; +std::string LabelForParam(BenchmarkParam param) { + switch (param) { + case BenchmarkParam::kDefault: + return "default"; + case BenchmarkParam::kFoldConstants: + return "fold_constants"; + case BenchmarkParam::kRecursivePlanning: + return "recursive_planning"; + case BenchmarkParam::kRecursivePlanningWithConstantFolding: + return "recursive_planning_with_constant_folding"; + } + return "unknown"; +} + void BM_RegisterBuiltins(benchmark::State& state) { for (auto _ : state) { auto builder = CreateCelExpressionBuilder(); @@ -64,21 +78,33 @@ BENCHMARK(BM_RegisterBuiltins); InterpreterOptions OptionsForParam(BenchmarkParam param, google::protobuf::Arena& arena) { InterpreterOptions options; - switch (param) { case BenchmarkParam::kFoldConstants: + case BenchmarkParam::kRecursivePlanningWithConstantFolding: options.constant_arena = &arena; options.constant_folding = true; break; case BenchmarkParam::kDefault: + case BenchmarkParam::kRecursivePlanning: options.constant_folding = false; break; } + switch (param) { + case BenchmarkParam::kRecursivePlanning: + case BenchmarkParam::kRecursivePlanningWithConstantFolding: + options.max_recursion_depth = 48; + break; + case BenchmarkParam::kDefault: + case BenchmarkParam::kFoldConstants: + options.max_recursion_depth = 0; + break; + } return options; } void BM_SymbolicPolicy(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && @@ -105,7 +131,9 @@ void BM_SymbolicPolicy(benchmark::State& state) { BENCHMARK(BM_SymbolicPolicy) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); absl::StatusOr> MakeBuilderForEnums( absl::string_view container, absl::string_view enum_type, @@ -209,6 +237,7 @@ BENCHMARK(BM_EnumResolution256Candidate)->ThreadRange(1, 32); void BM_NestedComprehension(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( [4, 5, 6].all(x, [1, 2, 3].all(y, x > y) && [7, 8, 9].all(z, x < z)) @@ -231,10 +260,13 @@ void BM_NestedComprehension(benchmark::State& state) { BENCHMARK(BM_NestedComprehension) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_Comparisons(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( v11 < v12 && v12 < v13 @@ -260,7 +292,9 @@ void BM_Comparisons(benchmark::State& state) { BENCHMARK(BM_Comparisons) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_ComparisonsConcurrent(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( @@ -290,6 +324,8 @@ BENCHMARK(BM_ComparisonsConcurrent)->ThreadRange(1, 32); void RegexPrecompilationBench(bool enabled, benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(absl::StrCat(LabelForParam(param), "_", + enabled ? "enabled" : "disabled")); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( input_str.matches(r'192\.168\.' + '[0-9]{1,3}' + r'\.' + '[0-9]{1,3}') || @@ -325,7 +361,9 @@ void BM_RegexPrecompilationDisabled(benchmark::State& state) { BENCHMARK(BM_RegexPrecompilationDisabled) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_RegexPrecompilationEnabled(benchmark::State& state) { RegexPrecompilationBench(true, state); @@ -333,10 +371,13 @@ void BM_RegexPrecompilationEnabled(benchmark::State& state) { BENCHMARK(BM_RegexPrecompilationEnabled) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_StringConcat(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); auto size = state.range(1); std::string source = "'1234567890' + '1234567890'"; @@ -377,7 +418,17 @@ BENCHMARK(BM_StringConcat) ->Args({BenchmarkParam::kFoldConstants, 4}) ->Args({BenchmarkParam::kFoldConstants, 8}) ->Args({BenchmarkParam::kFoldConstants, 16}) - ->Args({BenchmarkParam::kFoldConstants, 32}); + ->Args({BenchmarkParam::kFoldConstants, 32}) + ->Args({BenchmarkParam::kRecursivePlanning, 2}) + ->Args({BenchmarkParam::kRecursivePlanning, 4}) + ->Args({BenchmarkParam::kRecursivePlanning, 8}) + ->Args({BenchmarkParam::kRecursivePlanning, 16}) + ->Args({BenchmarkParam::kRecursivePlanning, 32}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 2}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 4}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 8}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 16}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 32}); void BM_StringConcat32Concurrent(benchmark::State& state) { std::string source = "'1234567890' + '1234567890'"; diff --git a/eval/tests/memory_safety_test.cc b/eval/tests/memory_safety_test.cc index 9c0a683e4..a88844fed 100644 --- a/eval/tests/memory_safety_test.cc +++ b/eval/tests/memory_safety_test.cc @@ -51,7 +51,12 @@ struct TestCase { bool reference_resolver_enabled = false; }; -enum Options { kDefault, kExhaustive, kFoldConstants }; +enum Options { + kDefault, + kExhaustive, + kFoldConstants, + kFoldConstantsManagedArena +}; using ParamType = std::tuple; @@ -68,6 +73,9 @@ std::string TestCaseName(const testing::TestParamInfo& param_info) { case Options::kFoldConstants: opt = "opt"; break; + case Options::kFoldConstantsManagedArena: + opt = "opt_managed_arena"; + break; } return absl::StrCat(std::get<0>(param).name, "_", opt); @@ -110,6 +118,14 @@ class EvaluatorMemorySafetyTest : public testing::TestWithParam { options.enable_comprehension_vulnerability_check = false; options.short_circuiting = true; break; + case Options::kFoldConstantsManagedArena: + options.enable_regex_precompilation = true; + options.constant_folding = true; + options.enable_comprehension_list_append = true; + options.enable_comprehension_vulnerability_check = false; + options.short_circuiting = true; + options.constant_arena = nullptr; + break; } options.enable_qualified_identifier_rewrites = @@ -295,7 +311,8 @@ INSTANTIATE_TEST_SUITE_P( test::IsCelBool(true), }}), testing::Values(Options::kDefault, Options::kExhaustive, - Options::kFoldConstants)), + Options::kFoldConstants, + Options::kFoldConstantsManagedArena)), &TestCaseName); } // namespace diff --git a/extensions/BUILD b/extensions/BUILD index 1e6e9204a..05104a4a5 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -75,6 +75,7 @@ cc_library( srcs = ["math_ext.cc"], hdrs = ["math_ext.h"], deps = [ + ":math_ext_decls", "//common:casting", "//common:value", "//eval/public:cel_function_registry", @@ -214,7 +215,10 @@ cc_library( srcs = ["bindings_ext.cc"], hdrs = ["bindings_ext.h"], deps = [ - "//common:ast", + "//checker:type_checker_builder", + "//common:decl", + "//common:expr", + "//common:type", "//compiler", "//internal:status_macros", "//parser:macro", @@ -608,6 +612,7 @@ cc_test( "//checker:type_check_issue", "//checker:type_checker_builder", "//checker:validation_result", + "//common:ast", "//common:decl", "//common:type", "//common:value", @@ -772,6 +777,8 @@ cc_library( "//runtime:runtime_builder", "//runtime/internal:runtime_friend_access", "//runtime/internal:runtime_impl", + "//validator", + "//validator:regex_validator", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:bind_front", @@ -812,6 +819,7 @@ cc_test( "//runtime:reference_resolver", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", + "//validator", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", diff --git a/extensions/bindings_ext.cc b/extensions/bindings_ext.cc index f097709ca..4823c077c 100644 --- a/extensions/bindings_ext.cc +++ b/extensions/bindings_ext.cc @@ -21,7 +21,10 @@ #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" -#include "common/ast.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/type.h" #include "compiler/compiler.h" #include "internal/status_macros.h" #include "parser/macro.h" @@ -34,6 +37,8 @@ namespace { static constexpr char kCelNamespace[] = "cel"; static constexpr char kBind[] = "bind"; +static constexpr char kBlock[] = "cel.@block"; +static constexpr char kBlockOverloadId[] = "cel_block_list"; static constexpr char kUnusedIterVar[] = "#unused"; bool IsTargetNamespace(const Expr& target) { @@ -47,6 +52,19 @@ inline absl::Status ConfigureParser(ParserBuilder& parser_builder) { return absl::OkStatus(); } +absl::Status ConfigureChecker(int version, + TypeCheckerBuilder& type_checker_builder) { + if (version < 1) { + return absl::OkStatus(); + } + static Type kParam(TypeParamType("T")); + CEL_ASSIGN_OR_RETURN( + auto decl, + MakeFunctionDecl(kBlock, MakeOverloadDecl(kBlockOverloadId, kParam, + ListType(), kParam))); + return type_checker_builder.AddFunction(std::move(decl)); +} + } // namespace std::vector bindings_macros() { @@ -55,7 +73,7 @@ std::vector bindings_macros() { [](MacroExprFactory& factory, Expr& target, absl::Span args) -> absl::optional { if (!IsTargetNamespace(target)) { - return absl::nullopt; + return std::nullopt; } if (!args[0].has_ident_expr()) { return factory.ReportErrorAt( @@ -70,8 +88,16 @@ std::vector bindings_macros() { return {*cel_bind}; } -CompilerLibrary BindingsCompilerLibrary() { - return CompilerLibrary("cel.lib.ext.bindings", &ConfigureParser); +CompilerLibrary BindingsCompilerLibrary(int version) { + return CompilerLibrary( + "cel.lib.ext.bindings", &ConfigureParser, + [version](auto& b) { return ConfigureChecker(version, b); }); +} + +CheckerLibrary BindingsCheckerLibrary(int version) { + return CheckerLibrary{"cel.lib.ext.bindings", [version](auto& b) { + return ConfigureChecker(version, b); + }}; } } // namespace cel::extensions diff --git a/extensions/bindings_ext.h b/extensions/bindings_ext.h index a338b24f6..40b83a37f 100644 --- a/extensions/bindings_ext.h +++ b/extensions/bindings_ext.h @@ -25,6 +25,7 @@ namespace cel::extensions { +constexpr int kBindingsVersionLatest = 1; // bindings_macros() returns a macro for cel.bind() which can be used to support // local variable bindings within expressions. std::vector bindings_macros(); @@ -35,7 +36,10 @@ inline absl::Status RegisterBindingsMacros(MacroRegistry& registry, } // Declarations for the bindings extension library. -CompilerLibrary BindingsCompilerLibrary(); +CompilerLibrary BindingsCompilerLibrary(int version = kBindingsVersionLatest); + +// Declarations for the bindings extension library. +CheckerLibrary BindingsCheckerLibrary(int version = kBindingsVersionLatest); } // namespace cel::extensions diff --git a/extensions/comprehensions_v2_macros.cc b/extensions/comprehensions_v2_macros.cc index a8de3a103..a054626f9 100644 --- a/extensions/comprehensions_v2_macros.cc +++ b/extensions/comprehensions_v2_macros.cc @@ -14,12 +14,14 @@ #include "extensions/comprehensions_v2_macros.h" +#include #include #include #include "absl/base/no_destructor.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -38,16 +40,21 @@ namespace { using ::google::api::expr::common::CelOperator; +bool IsSimpleIdentifier(const Expr& expr) { + return expr.has_ident_expr() && !expr.ident_expr().name().empty() && + !absl::StartsWith(expr.ident_expr().name(), "."); +} + absl::optional ExpandAllMacro2(MacroExprFactory& factory, Expr& target, absl::Span args) { if (args.size() != 3) { return factory.ReportError("all() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "all() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "all() second variable name must be a simple identifier"); } @@ -56,15 +63,15 @@ absl::optional ExpandAllMacro2(MacroExprFactory& factory, Expr& target, args[0], "all() second variable must be different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("all() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("all() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(true); auto condition = @@ -89,11 +96,11 @@ absl::optional ExpandExistsMacro2(MacroExprFactory& factory, Expr& target, if (args.size() != 3) { return factory.ReportError("exists() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "exists() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "exists() second variable name must be a simple identifier"); } @@ -102,15 +109,15 @@ absl::optional ExpandExistsMacro2(MacroExprFactory& factory, Expr& target, args[0], "exists() second variable must be different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("exists() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(false); auto condition = factory.NewCall( @@ -138,11 +145,11 @@ absl::optional ExpandExistsOneMacro2(MacroExprFactory& factory, if (args.size() != 3) { return factory.ReportError("existsOne() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "existsOne() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "existsOne() second variable name must be a simple identifier"); @@ -153,15 +160,15 @@ absl::optional ExpandExistsOneMacro2(MacroExprFactory& factory, "existsOne() second variable must be different " "from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("existsOne() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("existsOne() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewIntConst(0); auto condition = factory.NewBoolConst(true); @@ -190,12 +197,12 @@ absl::optional ExpandTransformList3Macro(MacroExprFactory& factory, if (args.size() != 3) { return factory.ReportError("transformList() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformList() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformList() second variable name must be a simple identifier"); @@ -205,15 +212,15 @@ absl::optional ExpandTransformList3Macro(MacroExprFactory& factory, "transformList() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformList() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformList() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -239,12 +246,12 @@ absl::optional ExpandTransformList4Macro(MacroExprFactory& factory, if (args.size() != 4) { return factory.ReportError("transformList() requires 4 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformList() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformList() second variable name must be a simple identifier"); @@ -254,15 +261,15 @@ absl::optional ExpandTransformList4Macro(MacroExprFactory& factory, "transformList() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformList() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformList() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -290,12 +297,12 @@ absl::optional ExpandTransformMap3Macro(MacroExprFactory& factory, if (args.size() != 3) { return factory.ReportError("transformMap() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformMap() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformMap() second variable name must be a simple identifier"); @@ -305,15 +312,15 @@ absl::optional ExpandTransformMap3Macro(MacroExprFactory& factory, "transformMap() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMap() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMap() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -338,12 +345,12 @@ absl::optional ExpandTransformMap4Macro(MacroExprFactory& factory, if (args.size() != 4) { return factory.ReportError("transformMap() requires 4 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformMap() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformMap() second variable name must be a simple identifier"); @@ -353,15 +360,15 @@ absl::optional ExpandTransformMap4Macro(MacroExprFactory& factory, "transformMap() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMap() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMap() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -388,12 +395,12 @@ absl::optional ExpandTransformMapEntry3Macro(MacroExprFactory& factory, if (args.size() != 3) { return factory.ReportError("transformMapEntry() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformMapEntry() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformMapEntry() second variable name must be a simple identifier"); @@ -403,17 +410,17 @@ absl::optional ExpandTransformMapEntry3Macro(MacroExprFactory& factory, "transformMapEntry() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMapEntry() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMapEntry() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -438,12 +445,12 @@ absl::optional ExpandTransformMapEntry4Macro(MacroExprFactory& factory, if (args.size() != 4) { return factory.ReportError("transformMapEntry() requires 4 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformMapEntry() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformMapEntry() second variable name must be a simple identifier"); @@ -453,17 +460,17 @@ absl::optional ExpandTransformMapEntry4Macro(MacroExprFactory& factory, "transformMapEntry() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMapEntry() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMapEntry() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); diff --git a/extensions/formatting.cc b/extensions/formatting.cc index 970cc6388..252fdc7bd 100644 --- a/extensions/formatting.cc +++ b/extensions/formatting.cc @@ -14,20 +14,19 @@ #include "extensions/formatting.h" +#include #include #include #include #include #include #include -#include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/btree_map.h" -#include "absl/memory/memory.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -54,6 +53,7 @@ namespace { static constexpr int32_t kNanosPerMillisecond = 1000000; static constexpr int32_t kNanosPerMicrosecond = 1000; +static constexpr int kMaxPrecision = 1000; absl::StatusOr FormatString( const Value& value, @@ -63,7 +63,7 @@ absl::StatusOr FormatString( std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); absl::StatusOr>> ParsePrecision( - absl::string_view format) { + absl::string_view format, int max_precision) { if (format.empty() || format[0] != '.') return std::pair{0, std::nullopt}; int64_t i = 1; @@ -79,6 +79,10 @@ absl::StatusOr>> ParsePrecision( return absl::InvalidArgumentError( "unable to convert precision specifier to integer"); } + if (precision > max_precision) { + return absl::InvalidArgumentError( + absl::StrCat("precision specifier exceeds maximum of ", max_precision)); + } return std::pair{i, precision}; } @@ -415,6 +419,12 @@ absl::StatusOr GetDouble(const Value& value, std::string& scratch) { str)); } } + if (value.kind() == ValueKind::kInt) { + return static_cast(value.GetInt().NativeValue()); + } + if (value.kind() == ValueKind::kUint) { + return static_cast(value.GetUint().NativeValue()); + } if (value.kind() != ValueKind::kDouble) { return absl::InvalidArgumentError( absl::StrCat("expected a double but got a ", value.GetTypeName())); @@ -439,12 +449,13 @@ absl::StatusOr FormatScientific( } absl::StatusOr> ParseAndFormatClause( - absl::string_view format, const Value& value, + absl::string_view format, const Value& value, int max_precision, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { - CEL_ASSIGN_OR_RETURN(auto precision_pair, ParsePrecision(format)); + CEL_ASSIGN_OR_RETURN(auto precision_pair, + ParsePrecision(format, max_precision)); auto [read, precision] = precision_pair; switch (format[read]) { case 's': { @@ -489,7 +500,7 @@ absl::StatusOr> ParseAndFormatClause( } absl::StatusOr Format( - const StringValue& format_value, const ListValue& args, + const StringValue& format_value, const ListValue& args, int max_precision, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { @@ -507,43 +518,51 @@ absl::StatusOr Format( } ++i; if (i >= format.size()) { - return absl::InvalidArgumentError("unexpected end of format string"); + return ErrorValue( + absl::InvalidArgumentError("unexpected end of format string")); } if (format[i] == '%') { result.push_back('%'); continue; } if (arg_index >= args_size) { - return absl::InvalidArgumentError( - absl::StrFormat("index %d out of range", arg_index)); + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("index %d out of range", arg_index))); } CEL_ASSIGN_OR_RETURN(auto value, args.Get(arg_index++, descriptor_pool, message_factory, arena)); - CEL_ASSIGN_OR_RETURN( - auto clause, - ParseAndFormatClause(format.substr(i), value, descriptor_pool, - message_factory, arena, clause_scratch)); - absl::StrAppend(&result, clause.second); - i += clause.first; + + auto clause = ParseAndFormatClause(format.substr(i), value, max_precision, + descriptor_pool, message_factory, arena, + clause_scratch); + if (!clause.ok()) { + return ErrorValue(std::move(clause).status()); + } + absl::StrAppend(&result, clause->second); + i += clause->first; } - return StringValue(arena, std::move(result)); + return StringValue::From(std::move(result), arena); } } // namespace -absl::Status RegisterStringFormattingFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { +absl::Status RegisterStringFormattingFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + StringsExtensionFormatOptions format_options) { + const int max_precision = + std::clamp(format_options.max_precision, 0, kMaxPrecision); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, ListValue>:: CreateDescriptor("format", /*receiver_style=*/true), BinaryFunctionAdapter, StringValue, ListValue>:: WrapFunction( - [](const StringValue& format, const ListValue& args, - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, - google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) { - return Format(format, args, descriptor_pool, message_factory, - arena); + [max_precision]( + const StringValue& format, const ListValue& args, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return Format(format, args, max_precision, descriptor_pool, + message_factory, arena); }))); return absl::OkStatus(); } diff --git a/extensions/formatting.h b/extensions/formatting.h index bc2002006..88954857b 100644 --- a/extensions/formatting.h +++ b/extensions/formatting.h @@ -21,9 +21,18 @@ namespace cel::extensions { +struct StringsExtensionFormatOptions { + // The maximum precision to permit for formatting floating-point numbers. + int max_precision = 1000; +}; + // Register extension functions for string formatting. -absl::Status RegisterStringFormattingFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); +// +// This implements (string).format([args...]) in the strings extension. Most +// users should add these functions via `extensions/strings.h` instead. +absl::Status RegisterStringFormattingFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + StringsExtensionFormatOptions format_options = {}); } // namespace cel::extensions diff --git a/extensions/formatting_test.cc b/extensions/formatting_test.cc index 433e4ae24..6a7fb300b 100644 --- a/extensions/formatting_test.cc +++ b/extensions/formatting_test.cc @@ -59,6 +59,75 @@ using ::testing::HasSubstr; using ::testing::TestWithParam; using ::testing::ValuesIn; +using StringFormatLimitsTest = TestWithParam; + +// Check that formatted floating points are reversible. +TEST_P(StringFormatLimitsTest, FormatLimits) { + google::protobuf::Arena arena; + const RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT( + RegisterStringFormattingFunctions(builder.function_registry(), options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse(GetParam(), "", ParserOptions{})); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + Activation activation; + + static_assert(std::numeric_limits::min_exponent == -1021); + for (double x : { + 0x1p-1021, + 0x3p-1021, + std::numeric_limits::epsilon() * 0x1p-3, + std::numeric_limits::epsilon() * 0x7p-3, + 1.1 / 7.0 * 1e-101, + 1.2 / 7.0 * 1e-101, + }) { + activation.InsertOrAssignValue("x", DoubleValue(x)); + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + ASSERT_TRUE(value.Is()); + EXPECT_TRUE(value.GetBool().NativeValue()); + } +} + +TEST(StringFormatLimitsTest, MaxPrecisionOption) { + google::protobuf::Arena arena; + const RuntimeOptions options; + StringsExtensionFormatOptions format_options; + format_options.max_precision = 99; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT(RegisterStringFormattingFunctions(builder.function_registry(), + options, format_options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("'%.100f'.format([1.123])", + "", ParserOptions{})); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + ASSERT_TRUE(value.Is()); + EXPECT_THAT(value.GetError().ToStatus().message(), + HasSubstr("precision specifier exceeds maximum of 99")); +} + +INSTANTIATE_TEST_SUITE_P(StringFormatLimitsTest, StringFormatLimitsTest, + ValuesIn({ + "double('%.326f'.format([x])) == x", + "double('%.17e'.format([x])) == x", + })); + struct FormattingTestCase { std::string name; std::string format; @@ -207,6 +276,12 @@ INSTANTIATE_TEST_SUITE_P( .format_args = "'hello'", .error = "unable to find end of precision specifier", }, + { + .name = "InvalidPrecisionOutOfRange", + .format = "%.1001f", + .format_args = "1.2345", + .error = "precision specifier exceeds maximum of 100", + }, { .name = "DecimalFormatingClause", .format = "int %d, uint %d", @@ -478,6 +553,18 @@ INSTANTIATE_TEST_SUITE_P( .format_args = "2.71828", .expected = "2.718280e+00", }, + { + .name = "FixedPointClauseWithInt", + .format = "%f", + .format_args = "3", + .expected = "3.000000", + }, + { + .name = "ScientificNotationWithUint", + .format = "%e", + .format_args = "uint(3)", + .expected = "3.000000e+00", + }, { .name = "NaNSupportForFixedPoint", .format = "%f", diff --git a/extensions/lists_functions.cc b/extensions/lists_functions.cc index 10bc717ed..bfe05d887 100644 --- a/extensions/lists_functions.cc +++ b/extensions/lists_functions.cc @@ -454,7 +454,7 @@ Macro ListSortByMacro() { MakeMapComprehension(factory, factory.Copy(sortby_input_ident), std::move(key_ident), std::move(key_expr)); if (!map_compr.has_value()) { - return absl::nullopt; + return std::nullopt; } // Build the call expression: diff --git a/extensions/math_ext.cc b/extensions/math_ext.cc index 7b3655de3..a7773da19 100644 --- a/extensions/math_ext.cc +++ b/extensions/math_ext.cc @@ -266,7 +266,11 @@ Value BitShiftLeftInt(int64_t lhs, int64_t rhs) { if (rhs > 63) { return IntValue(0); } - return IntValue(lhs << static_cast(rhs)); + // Shift in the unsigned domain to avoid undefined behaviour when lhs is + // negative or the shift moves bits into the sign bit, matching the bit + // pattern semantics already used by bitShiftRight. + return IntValue(absl::bit_cast(absl::bit_cast(lhs) + << static_cast(rhs))); } Value BitShiftLeftUint(uint64_t lhs, int64_t rhs) { @@ -308,7 +312,8 @@ Value BitShiftRightUint(uint64_t lhs, int64_t rhs) { } // namespace absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { + const RuntimeOptions& options, + int version) { CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( kMathMin, Identity, registry))); @@ -360,6 +365,9 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, UnaryFunctionAdapter, ListValue>::RegisterGlobalOverload(kMathMax, MaxList, registry))); + if (version == 0) { + return absl::OkStatus(); + } CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( @@ -370,15 +378,6 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "math.round", RoundDouble, registry))); - CEL_RETURN_IF_ERROR( - (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtDouble, registry))); - CEL_RETURN_IF_ERROR( - (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtInt, registry))); - CEL_RETURN_IF_ERROR( - (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtUint, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "math.trunc", TruncDouble, registry))); @@ -453,6 +452,20 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, (BinaryFunctionAdapter::RegisterGlobalOverload( "math.bitShiftRight", BitShiftRightUint, registry))); + if (version == 1) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtInt, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtUint, registry))); + return absl::OkStatus(); } diff --git a/extensions/math_ext.h b/extensions/math_ext.h index 63d9e964b..fe000e476 100644 --- a/extensions/math_ext.h +++ b/extensions/math_ext.h @@ -18,6 +18,7 @@ #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" +#include "extensions/math_ext_decls.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" @@ -25,8 +26,9 @@ namespace cel::extensions { // Register extension functions for supporting mathematical operations above // and beyond the set defined in the CEL standard environment. -absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); +absl::Status RegisterMathExtensionFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + int version = kMathExtensionLatestVersion); absl::Status RegisterMathExtensionFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, diff --git a/extensions/math_ext_macros.cc b/extensions/math_ext_macros.cc index a66720a60..08b163132 100644 --- a/extensions/math_ext_macros.cc +++ b/extensions/math_ext_macros.cc @@ -72,7 +72,7 @@ absl::optional CheckInvalidArgs(MacroExprFactory &factory, } } - return absl::nullopt; + return std::nullopt; } bool IsListLiteralWithValidArgs(const Expr &arg) { @@ -99,7 +99,7 @@ std::vector math_macros() { [](MacroExprFactory &factory, Expr &target, absl::Span arguments) -> absl::optional { if (!IsTargetNamespace(target)) { - return absl::nullopt; + return std::nullopt; } switch (arguments.size()) { @@ -143,7 +143,7 @@ std::vector math_macros() { [](MacroExprFactory &factory, Expr &target, absl::Span arguments) -> absl::optional { if (!IsTargetNamespace(target)) { - return absl::nullopt; + return std::nullopt; } switch (arguments.size()) { diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc index 3088e6fa8..ce05ae6ed 100644 --- a/extensions/math_ext_test.cc +++ b/extensions/math_ext_test.cc @@ -23,7 +23,6 @@ #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -94,7 +93,7 @@ TestCase MinCase(CelValue v1, CelValue v2, CelValue result) { } TestCase MinCase(CelValue list, CelValue result) { - return TestCase{kMathMin, list, absl::nullopt, result}; + return TestCase{kMathMin, list, std::nullopt, result}; } TestCase MaxCase(CelValue v1, CelValue v2, CelValue result) { @@ -102,7 +101,7 @@ TestCase MaxCase(CelValue v1, CelValue v2, CelValue result) { } TestCase MaxCase(CelValue list, CelValue result) { - return TestCase{kMathMax, list, absl::nullopt, result}; + return TestCase{kMathMax, list, std::nullopt, result}; } struct MacroTestCase { @@ -110,19 +109,6 @@ struct MacroTestCase { absl::string_view err = ""; }; -std::string FormatIssues(const cel::ValidationResult& result) { - std::string issues; - for (const auto& issue : result.GetIssues()) { - if (!issues.empty()) { - absl::StrAppend(&issues, "\n", - issue.ToDisplayString(*result.GetSource())); - } else { - issues = issue.ToDisplayString(*result.GetSource()); - } - } - return issues; -} - class TestFunction : public CelFunction { public: explicit TestFunction(absl::string_view name) @@ -352,10 +338,11 @@ TEST_P(MathExtMacroParamsTest, ParserTests) { TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { const MacroTestCase& test_case = GetParam(); - - ASSERT_OK_AND_ASSIGN( - auto compiler_builder, - cel::NewCompilerBuilder(internal::GetTestingDescriptorPool())); + CompilerOptions compile_opts; + compile_opts.adapt_parser_errors = true; + ASSERT_OK_AND_ASSIGN(auto compiler_builder, + cel::NewCompilerBuilder( + internal::GetTestingDescriptorPool(), compile_opts)); ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_THAT(compiler_builder->AddLibrary(MathCompilerLibrary()), IsOk()); @@ -381,16 +368,16 @@ TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*compiler_builder).Build()); - auto result = compiler->Compile(test_case.expr, ""); + ASSERT_OK_AND_ASSIGN(auto result, + compiler->Compile(test_case.expr, "")); if (!test_case.err.empty()) { - EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr(test_case.err))); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.err)); return; } - ASSERT_THAT(result, IsOk()); - ASSERT_TRUE(result->IsValid()) << FormatIssues(*result); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); RuntimeOptions opts; ASSERT_OK_AND_ASSIGN( @@ -411,9 +398,8 @@ TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); - - ASSERT_OK_AND_ASSIGN(auto program, - runtime->CreateProgram(*result->ReleaseAst())); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); google::protobuf::Arena arena; cel::Activation activation; @@ -577,6 +563,8 @@ INSTANTIATE_TEST_SUITE_P( {"math.bitNot(2) == -3"}, {"math.bitAnd(math.bitNot(0x3u), 0xFFu) == 0xFCu"}, {"math.bitShiftLeft(1, 1) == 2"}, + {"math.bitShiftLeft(-1, 1) == -2"}, + {"math.bitShiftLeft(-4, 2) == -16"}, {"math.bitShiftLeft(1u, 1) == 2u"}, {"math.bitShiftRight(4, 1) == 2"}, {"math.bitShiftRight(4u, 1) == 2u"}})); diff --git a/extensions/proto_ext.cc b/extensions/proto_ext.cc index f38039002..48618f7ae 100644 --- a/extensions/proto_ext.cc +++ b/extensions/proto_ext.cc @@ -45,11 +45,11 @@ absl::optional ValidateExtensionIdentifier(const Expr& expr) { absl::Overload( [](const SelectExpr& select_expr) -> absl::optional { if (select_expr.test_only()) { - return absl::nullopt; + return std::nullopt; } auto op_name = ValidateExtensionIdentifier(select_expr.operand()); if (!op_name.has_value()) { - return absl::nullopt; + return std::nullopt; } return absl::StrCat(*op_name, ".", select_expr.field()); }, @@ -57,7 +57,7 @@ absl::optional ValidateExtensionIdentifier(const Expr& expr) { return ident_expr.name(); }, [](const auto&) -> absl::optional { - return absl::nullopt; + return std::nullopt; }), expr.kind()); } @@ -68,7 +68,7 @@ absl::optional GetExtensionFieldName(const Expr& expr) { select_expr) { return ValidateExtensionIdentifier(expr); } - return absl::nullopt; + return std::nullopt; } bool IsExtensionCall(const Expr& target) { @@ -95,7 +95,7 @@ std::vector proto_macros() { [](MacroExprFactory& factory, Expr& target, absl::Span arguments) -> absl::optional { if (!IsExtensionCall(target)) { - return absl::nullopt; + return std::nullopt; } auto extFieldName = GetExtensionFieldName(arguments[1]); if (!extFieldName.has_value()) { @@ -109,7 +109,7 @@ std::vector proto_macros() { [](MacroExprFactory& factory, Expr& target, absl::Span arguments) -> absl::optional { if (!IsExtensionCall(target)) { - return absl::nullopt; + return std::nullopt; } auto extFieldName = GetExtensionFieldName(arguments[1]); if (!extFieldName.has_value()) { diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 6c3f654f9..3f4081b09 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -87,48 +87,12 @@ cc_library( ], ) -cc_library( - name = "type", - srcs = [ - "type_introspector.cc", - ], - hdrs = [ - "type_introspector.h", - ], - deps = [ - "//common:type", - "@com_google_absl//absl/base:nullability", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:optional", - "@com_google_protobuf//:protobuf", - ], -) - -cc_test( - name = "type_test", - srcs = [ - "type_introspector_test.cc", - ], - deps = [ - ":type", - "//common:type", - "//common:type_kind", - "//internal:testing", - "@com_google_absl//absl/types:optional", - "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", - "@com_google_protobuf//:protobuf", - ], -) - cc_library( name = "value", hdrs = [ - "type_reflector.h", "value.h", ], deps = [ - ":type", "//common:memory", "//common:type", "//common:value", diff --git a/extensions/protobuf/bind_proto_to_activation_test.cc b/extensions/protobuf/bind_proto_to_activation_test.cc index fd79508ac..680b4b353 100644 --- a/extensions/protobuf/bind_proto_to_activation_test.cc +++ b/extensions/protobuf/bind_proto_to_activation_test.cc @@ -76,10 +76,10 @@ TEST_F(BindProtoToActivationTest, BindProtoToActivationSkip) { EXPECT_THAT(activation.FindVariable("single_int32", descriptor_pool(), message_factory(), arena()), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(activation.FindVariable("single_sint32", descriptor_pool(), message_factory(), arena()), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); } TEST_F(BindProtoToActivationTest, BindProtoToActivationDefault) { diff --git a/extensions/protobuf/internal/map_reflection.cc b/extensions/protobuf/internal/map_reflection.cc index 22a6dc23c..605e4437d 100644 --- a/extensions/protobuf/internal/map_reflection.cc +++ b/extensions/protobuf/internal/map_reflection.cc @@ -42,22 +42,16 @@ class CelMapReflectionFriend final { return reflection.MapSize(message, &field); } - static google::protobuf::MapIterator MapBegin(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field) { - return reflection.MapBegin( - const_cast< // NOLINT(google3-runtime-proto-const-cast) - google::protobuf::Message*>(&message), - &field); + static google::protobuf::ConstMapIterator ConstMapBegin( + const google::protobuf::Reflection& reflection, const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return reflection.ConstMapBegin(&message, &field); } - static google::protobuf::MapIterator MapEnd(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field) { - return reflection.MapEnd( - const_cast< // NOLINT(google3-runtime-proto-const-cast) - google::protobuf::Message*>(&message), - &field); + static google::protobuf::ConstMapIterator ConstMapEnd( + const google::protobuf::Reflection& reflection, const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return reflection.ConstMapEnd(&message, &field); } static bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection, @@ -104,18 +98,18 @@ int MapSize(const google::protobuf::Reflection& reflection, field); } -google::protobuf::MapIterator MapBegin(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field) { - return google::protobuf::expr::CelMapReflectionFriend::MapBegin(reflection, message, - field); +google::protobuf::ConstMapIterator ConstMapBegin(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return google::protobuf::expr::CelMapReflectionFriend::ConstMapBegin(reflection, + message, field); } -google::protobuf::MapIterator MapEnd(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field) { - return google::protobuf::expr::CelMapReflectionFriend::MapEnd(reflection, message, - field); +google::protobuf::ConstMapIterator ConstMapEnd(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return google::protobuf::expr::CelMapReflectionFriend::ConstMapEnd(reflection, message, + field); } bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection, diff --git a/extensions/protobuf/internal/map_reflection.h b/extensions/protobuf/internal/map_reflection.h index 6e696bbe3..681d7693d 100644 --- a/extensions/protobuf/internal/map_reflection.h +++ b/extensions/protobuf/internal/map_reflection.h @@ -42,13 +42,13 @@ int MapSize(const google::protobuf::Reflection& reflection, const google::protobuf::Message& message, const google::protobuf::FieldDescriptor& field); -google::protobuf::MapIterator MapBegin(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field); +google::protobuf::ConstMapIterator ConstMapBegin(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field); -google::protobuf::MapIterator MapEnd(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field); +google::protobuf::ConstMapIterator ConstMapEnd(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field); bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection, google::protobuf::Message* message, diff --git a/extensions/protobuf/internal/qualify.cc b/extensions/protobuf/internal/qualify.cc index dba4f44ae..37ad30011 100644 --- a/extensions/protobuf/internal/qualify.cc +++ b/extensions/protobuf/internal/qualify.cc @@ -145,7 +145,7 @@ absl::StatusOr> LookupMapValu bool found = cel::extensions::protobuf_internal::LookupMapValue( *reflection, *message, *field_desc, proto_key, &value_ref); if (!found) { - return absl::nullopt; + return std::nullopt; } return value_ref; } diff --git a/extensions/protobuf/type_introspector.cc b/extensions/protobuf/type_introspector.cc deleted file mode 100644 index 8b445c359..000000000 --- a/extensions/protobuf/type_introspector.cc +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "extensions/protobuf/type_introspector.h" - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "common/type.h" -#include "common/type_introspector.h" -#include "google/protobuf/descriptor.h" - -namespace cel::extensions { - -absl::StatusOr> ProtoTypeIntrospector::FindTypeImpl( - absl::string_view name) const { - // We do not have to worry about well known types here. - // `TypeIntrospector::FindType` handles those directly. - const auto* desc = descriptor_pool()->FindMessageTypeByName(name); - if (desc == nullptr) { - return absl::nullopt; - } - return MessageType(desc); -} - -absl::StatusOr> -ProtoTypeIntrospector::FindEnumConstantImpl(absl::string_view type, - absl::string_view value) const { - const google::protobuf::EnumDescriptor* enum_desc = - descriptor_pool()->FindEnumTypeByName(type); - // google.protobuf.NullValue is special cased in the base class. - if (enum_desc == nullptr) { - return absl::nullopt; - } - - // Note: we don't support strong enum typing at this time so only the fully - // qualified enum values are meaningful, so we don't provide any signal if the - // enum type is found but can't match the value name. - const google::protobuf::EnumValueDescriptor* value_desc = - enum_desc->FindValueByName(value); - if (value_desc == nullptr) { - return absl::nullopt; - } - - return TypeIntrospector::EnumConstant{ - EnumType(enum_desc), enum_desc->full_name(), value_desc->name(), - value_desc->number()}; -} - -absl::StatusOr> -ProtoTypeIntrospector::FindStructTypeFieldByNameImpl( - absl::string_view type, absl::string_view name) const { - // We do not have to worry about well known types here. - // `TypeIntrospector::FindStructTypeFieldByName` handles those directly. - const auto* desc = descriptor_pool()->FindMessageTypeByName(type); - if (desc == nullptr) { - return absl::nullopt; - } - const auto* field_desc = desc->FindFieldByName(name); - if (field_desc == nullptr) { - field_desc = descriptor_pool()->FindExtensionByPrintableName(desc, name); - if (field_desc == nullptr) { - return absl::nullopt; - } - } - return MessageTypeField(field_desc); -} - -} // namespace cel::extensions diff --git a/extensions/protobuf/type_introspector.h b/extensions/protobuf/type_introspector.h deleted file mode 100644 index 5eb9c3ddc..000000000 --- a/extensions/protobuf/type_introspector.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_INTROSPECTOR_H_ -#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_INTROSPECTOR_H_ - -#include "absl/base/nullability.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "common/type.h" -#include "common/type_introspector.h" -#include "google/protobuf/descriptor.h" - -namespace cel::extensions { - -class ProtoTypeIntrospector : public virtual TypeIntrospector { - public: - ProtoTypeIntrospector() - : ProtoTypeIntrospector(google::protobuf::DescriptorPool::generated_pool()) {} - - explicit ProtoTypeIntrospector( - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) - : descriptor_pool_(descriptor_pool) {} - - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { - return descriptor_pool_; - } - - protected: - absl::StatusOr> FindTypeImpl( - absl::string_view name) const final; - - absl::StatusOr> - FindEnumConstantImpl(absl::string_view type, - absl::string_view value) const final; - - absl::StatusOr> FindStructTypeFieldByNameImpl( - absl::string_view type, absl::string_view name) const final; - - private: - const google::protobuf::DescriptorPool* absl_nonnull const descriptor_pool_; -}; - -} // namespace cel::extensions - -#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_INTROSPECTOR_H_ diff --git a/extensions/protobuf/type_introspector_test.cc b/extensions/protobuf/type_introspector_test.cc deleted file mode 100644 index 0a7b21524..000000000 --- a/extensions/protobuf/type_introspector_test.cc +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "extensions/protobuf/type_introspector.h" - -#include "absl/types/optional.h" -#include "common/type.h" -#include "common/type_kind.h" -#include "internal/testing.h" -#include "cel/expr/conformance/proto2/test_all_types.pb.h" -#include "google/protobuf/descriptor.h" - -namespace cel::extensions { -namespace { - -using ::absl_testing::IsOkAndHolds; -using ::cel::expr::conformance::proto2::TestAllTypes; -using ::testing::Eq; -using ::testing::Optional; - -TEST(ProtoTypeIntrospector, FindType) { - ProtoTypeIntrospector introspector; - EXPECT_THAT( - introspector.FindType(TestAllTypes::descriptor()->full_name()), - IsOkAndHolds(Optional(Eq(MessageType(TestAllTypes::GetDescriptor()))))); - EXPECT_THAT(introspector.FindType("type.that.does.not.Exist"), - IsOkAndHolds(Eq(absl::nullopt))); -} - -TEST(ProtoTypeIntrospector, FindStructTypeFieldByName) { - ProtoTypeIntrospector introspector; - ASSERT_OK_AND_ASSIGN( - auto field, introspector.FindStructTypeFieldByName( - TestAllTypes::descriptor()->full_name(), "single_int32")); - ASSERT_TRUE(field.has_value()); - EXPECT_THAT(field->name(), Eq("single_int32")); - EXPECT_THAT(field->number(), Eq(1)); - EXPECT_THAT( - introspector.FindStructTypeFieldByName( - TestAllTypes::descriptor()->full_name(), "field_that_does_not_exist"), - IsOkAndHolds(Eq(absl::nullopt))); - EXPECT_THAT(introspector.FindStructTypeFieldByName("type.that.does.not.Exist", - "does_not_matter"), - IsOkAndHolds(Eq(absl::nullopt))); -} - -TEST(ProtoTypeIntrospector, FindEnumConstant) { - ProtoTypeIntrospector introspector; - const auto* enum_desc = TestAllTypes::NestedEnum_descriptor(); - ASSERT_OK_AND_ASSIGN( - auto enum_constant, - introspector.FindEnumConstant( - "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", "BAZ")); - ASSERT_TRUE(enum_constant.has_value()); - EXPECT_EQ(enum_constant->type.kind(), TypeKind::kEnum); - EXPECT_EQ(enum_constant->type_full_name, enum_desc->full_name()); - EXPECT_EQ(enum_constant->value_name, "BAZ"); - EXPECT_EQ(enum_constant->number, 2); -} - -TEST(ProtoTypeIntrospector, FindEnumConstantNull) { - ProtoTypeIntrospector introspector; - ASSERT_OK_AND_ASSIGN( - auto enum_constant, - introspector.FindEnumConstant("google.protobuf.NullValue", "NULL_VALUE")); - ASSERT_TRUE(enum_constant.has_value()); - EXPECT_EQ(enum_constant->type.kind(), TypeKind::kNull); - EXPECT_EQ(enum_constant->type_full_name, "google.protobuf.NullValue"); - EXPECT_EQ(enum_constant->value_name, "NULL_VALUE"); - EXPECT_EQ(enum_constant->number, 0); -} - -TEST(ProtoTypeIntrospector, FindEnumConstantUnknownEnum) { - ProtoTypeIntrospector introspector; - - ASSERT_OK_AND_ASSIGN(auto enum_constant, - introspector.FindEnumConstant("NotARealEnum", "BAZ")); - EXPECT_FALSE(enum_constant.has_value()); -} - -TEST(ProtoTypeIntrospector, FindEnumConstantUnknownValue) { - ProtoTypeIntrospector introspector; - - ASSERT_OK_AND_ASSIGN( - auto enum_constant, - introspector.FindEnumConstant( - "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", "QUX")); - ASSERT_FALSE(enum_constant.has_value()); -} - -} // namespace -} // namespace cel::extensions diff --git a/extensions/protobuf/type_reflector.h b/extensions/protobuf/type_reflector.h deleted file mode 100644 index 4665235fe..000000000 --- a/extensions/protobuf/type_reflector.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_REFLECTOR_H_ -#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_REFLECTOR_H_ - -#include "absl/base/nullability.h" -#include "common/type_reflector.h" -#include "extensions/protobuf/type_introspector.h" -#include "google/protobuf/descriptor.h" - -namespace cel::extensions { - -class ProtoTypeReflector : public TypeReflector, public ProtoTypeIntrospector { - public: - ProtoTypeReflector() - : ProtoTypeReflector(google::protobuf::DescriptorPool::generated_pool()) {} - - explicit ProtoTypeReflector( - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) - : ProtoTypeIntrospector(descriptor_pool) {} - - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { - return ProtoTypeIntrospector::descriptor_pool(); - } -}; - -} // namespace cel::extensions - -#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_REFLECTOR_H_ diff --git a/extensions/regex_ext.cc b/extensions/regex_ext.cc index c3d7cae53..9c06d90c2 100644 --- a/extensions/regex_ext.cc +++ b/extensions/regex_ext.cc @@ -42,6 +42,8 @@ #include "runtime/internal/runtime_friend_access.h" #include "runtime/internal/runtime_impl.h" #include "runtime/runtime_builder.h" +#include "validator/regex_validator.h" +#include "validator/validator.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -341,4 +343,10 @@ CompilerLibrary RegexExtCompilerLibrary() { return CompilerLibrary::FromCheckerLibrary(RegexExtCheckerLibrary()); } +Validation RegexExtValidator() { + return RegexPatternValidator( + /*id=*/"", + {{"regex.extract", 1}, {"regex.extractAll", 1}, {"regex.replace", 1}}); +} + } // namespace cel::extensions diff --git a/extensions/regex_ext.h b/extensions/regex_ext.h index dc401f5bd..7b32aee00 100644 --- a/extensions/regex_ext.h +++ b/extensions/regex_ext.h @@ -81,6 +81,7 @@ #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "runtime/runtime_builder.h" +#include "validator/validator.h" namespace cel::extensions { @@ -119,5 +120,12 @@ CheckerLibrary RegexExtCheckerLibrary(); // regex.extractAll(target: str, pattern: str) -> list CompilerLibrary RegexExtCompilerLibrary(); +// Returns a `Validation` that checks all calls to the CEL regex extension +// functions. +// +// It validates that if the pattern is a literal string, it is a valid regular +// expression. +Validation RegexExtValidator(); + } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ diff --git a/extensions/regex_ext_test.cc b/extensions/regex_ext_test.cc index e69f7cce1..26d9936aa 100644 --- a/extensions/regex_ext_test.cc +++ b/extensions/regex_ext_test.cc @@ -46,6 +46,7 @@ #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" +#include "validator/validator.h" #include "google/protobuf/arena.h" #include "google/protobuf/extension_set.h" @@ -497,5 +498,44 @@ std::vector createRegexCheckerParams() { INSTANTIATE_TEST_SUITE_P(RegexExtCheckerLibraryTest, RegexExtCheckerLibraryTest, ValuesIn(createRegexCheckerParams())); + +absl::StatusOr> CreateRegexExtCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, NewCompilerBuilder(internal::GetTestingDescriptorPool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(StandardCheckerLibrary())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(RegexExtCompilerLibrary())); + return std::move(*builder).Build(); +} + +class RegexExtValidatorTest : public TestWithParam {}; + +TEST_P(RegexExtValidatorTest, Basic) { + ASSERT_OK_AND_ASSIGN(auto compiler, CreateRegexExtCompiler()); + + Validator validator; + validator.AddValidation(RegexExtValidator()); + + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(GetParam().expr_string)); + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), GetParam().error_substr.empty()) + << "Expression: " << GetParam().expr_string; + if (!GetParam().error_substr.empty()) { + EXPECT_THAT(result.FormatError(), HasSubstr(GetParam().error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P(RegexExtValidatorTest, RegexExtValidatorTest, + testing::ValuesIn(std::vector{ + {"regex.extract('hello world', 'hello (.*)')"}, + {"regex.extract('hello world', 'hello ([') ", + "invalid regular expression"}, + {"regex.extractAll('hello world', 'hello (.*)')"}, + {"regex.extractAll('hello world', 'hello ([') ", + "invalid regular expression"}, + {"regex.replace('hello world', 'hello', 'hi')"}, + {"regex.replace('hello world', 'he([', 'hi') ", + "invalid regular expression"}, + })); } // namespace } // namespace cel::extensions diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index 0f09773ae..0cc64311a 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -92,7 +92,7 @@ struct SelectInstruction { // Represents a single qualifier in a traversal path. // TODO(uncreated-issue/51): support variable indexes. using QualifierInstruction = - absl::variant; + std::variant; struct SelectPath { Expr* operand; @@ -153,16 +153,16 @@ Expr MakeSelectPathExpr( // Returns a single select operation based on the inferred type of the operand // and the field name. If the operand type doesn't define the field, returns // nullopt. -absl::optional GetSelectInstruction( +std::optional GetSelectInstruction( const StructType& runtime_type, PlannerContext& planner_context, absl::string_view field_name) { auto field_or = planner_context.type_reflector() .FindStructTypeFieldByName(runtime_type, field_name) - .value_or(absl::nullopt); + .value_or(std::nullopt); if (field_or.has_value()) { return SelectInstruction{field_or->number(), std::string(field_or->name())}; } - return absl::nullopt; + return std::nullopt; } absl::StatusOr SelectQualifierFromList(const ListExpr& list) { @@ -407,13 +407,13 @@ class RewriterImpl : public AstRewriterBase { // support message traversal. const TypeSpec checker_type = ast_.GetTypeOrDyn(operand.id()); - absl::optional rt_type = + std::optional rt_type = (checker_type.has_message_type()) ? GetRuntimeType(checker_type.message_type().type()) - : absl::nullopt; + : std::nullopt; if (rt_type.has_value() && (*rt_type).Is()) { const StructType& runtime_type = rt_type->GetStruct(); - absl::optional field_or = + std::optional field_or = GetSelectInstruction(runtime_type, planner_context_, field_name); if (field_or.has_value()) { candidates_[&expr] = std::move(field_or).value(); @@ -538,9 +538,9 @@ class RewriterImpl : public AstRewriterBase { return candidates_.find(operand) != candidates_.end(); } - absl::optional GetRuntimeType(absl::string_view type_name) { + std::optional GetRuntimeType(absl::string_view type_name) { return planner_context_.type_reflector().FindType(type_name).value_or( - absl::nullopt); + std::nullopt); } void SetProgressStatus(const absl::Status& status) { @@ -582,14 +582,14 @@ class OptimizedSelectImpl { AttributeTrail GetAttributeTrail(const AttributeTrail& operand_trail) const; - absl::optional attribute() const { return attribute_; } + std::optional attribute() const { return attribute_; } const std::vector& qualifiers() const { return qualifiers_; } private: - absl::optional attribute_; + std::optional attribute_; std::vector select_path_; std::vector qualifiers_; bool presence_test_; @@ -597,10 +597,10 @@ class OptimizedSelectImpl { }; // Check for unknowns or missing attributes. -absl::StatusOr> CheckForMarkedAttributes( +absl::StatusOr> CheckForMarkedAttributes( ExecutionFrameBase& frame, const AttributeTrail& attribute_trail) { if (attribute_trail.empty()) { - return absl::nullopt; + return std::nullopt; } if (frame.unknown_processing_enabled() && @@ -624,7 +624,7 @@ absl::StatusOr> CheckForMarkedAttributes( attribute_trail.attribute()); } - return absl::nullopt; + return std::nullopt; } absl::StatusOr OptimizedSelectImpl::ApplySelect( @@ -715,7 +715,7 @@ absl::Status StackMachineImpl::Evaluate(ExecutionFrame* frame) const { // select arguments. // TODO(uncreated-issue/51): add support variable qualifiers attribute_trail = GetAttributeTrail(frame); - CEL_ASSIGN_OR_RETURN(absl::optional value, + CEL_ASSIGN_OR_RETURN(std::optional value, CheckForMarkedAttributes(*frame, attribute_trail)); if (value.has_value()) { frame->value_stack().Pop(kStackInputs); diff --git a/extensions/select_optimization_test.cc b/extensions/select_optimization_test.cc index c07f4c6ad..c14c4d461 100644 --- a/extensions/select_optimization_test.cc +++ b/extensions/select_optimization_test.cc @@ -254,9 +254,10 @@ class MockAccessApis : public LegacyTypeInfoApis, public LegacyTypeAccessApis { return nullptr; } - absl::optional FindFieldByName( - absl::string_view field_name) const override { - return absl::nullopt; + std::optional< + google::api::expr::runtime::LegacyTypeInfoApis::FieldDescription> + FindFieldByName(absl::string_view field_name) const override { + return std::nullopt; } MOCK_METHOD(absl::StatusOr, GetField, diff --git a/extensions/strings.cc b/extensions/strings.cc index 652c72572..54fda20d6 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -305,18 +305,10 @@ absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder, int version) { } // namespace -absl::Status RegisterStringsFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter, ListValue>::CreateDescriptor( - "join", /*receiver_style=*/true), - UnaryFunctionAdapter, ListValue>::WrapFunction( - Join1))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter, ListValue, StringValue>:: - CreateDescriptor("join", /*receiver_style=*/true), - BinaryFunctionAdapter, ListValue, - StringValue>::WrapFunction(Join2))); +absl::Status RegisterStringsFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + const StringsExtensionOptions& extension_options) { + const int version = extension_options.version; CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, StringValue>:: CreateDescriptor("split", /*receiver_style=*/true), @@ -350,7 +342,6 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, int64_t>::CreateDescriptor("replace", /*receiver_style=*/true), QuaternaryFunctionAdapter, StringValue, StringValue, StringValue, int64_t>::WrapFunction(Replace2))); - CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(registry, options)); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterMemberOverload("charAt", &CharAt, @@ -388,9 +379,33 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterMemberOverload( "trim", &Trim, registry))); + if (version == 0) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions( + registry, options, {extension_options.max_precision})); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "strings.quote", &Quote, registry))); + if (version == 1) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, ListValue>::CreateDescriptor( + "join", /*receiver_style=*/true), + UnaryFunctionAdapter, ListValue>::WrapFunction( + Join1))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, ListValue, StringValue>:: + CreateDescriptor("join", /*receiver_style=*/true), + BinaryFunctionAdapter, ListValue, + StringValue>::WrapFunction(Join2))); + if (version == 2) { + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterMemberOverload( "reverse", &Reverse, registry))); @@ -399,13 +414,16 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, absl::Status RegisterStringsFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, - const google::api::expr::runtime::InterpreterOptions& options) { + const google::api::expr::runtime::InterpreterOptions& options, + const StringsExtensionOptions& extension_options) { return RegisterStringsFunctions( registry->InternalGetRegistry(), - google::api::expr::runtime::ConvertToRuntimeOptions(options)); + google::api::expr::runtime::ConvertToRuntimeOptions(options), + extension_options); } -CheckerLibrary StringsCheckerLibrary(int version) { +CheckerLibrary StringsCheckerLibrary(const StringsExtensionOptions& options) { + const int version = options.version; return {"strings", [version](TypeCheckerBuilder& builder) { return RegisterStringsDecls(builder, version); }}; diff --git a/extensions/strings.h b/extensions/strings.h index 5dab33c5d..3ec92d603 100644 --- a/extensions/strings.h +++ b/extensions/strings.h @@ -27,20 +27,45 @@ namespace cel::extensions { constexpr int kStringsExtensionLatestVersion = 4; +struct StringsExtensionOptions { + int version = kStringsExtensionLatestVersion; + + // Maximum precision allowed for floating point format specifiers in + // format() function. This is used for both fixed and scientific notations. + // Value must be in the range [0, 1000], otherwise clamped. + // + // Does not affect default precisions for %e and %f format specifiers. + int max_precision = 1000; +}; + // Register extension functions for strings. -absl::Status RegisterStringsFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); +absl::Status RegisterStringsFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + const StringsExtensionOptions& extension_options = {}); absl::Status RegisterStringsFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, - const google::api::expr::runtime::InterpreterOptions& options); + const google::api::expr::runtime::InterpreterOptions& options, + const StringsExtensionOptions& extension_options = {}); CheckerLibrary StringsCheckerLibrary( - int version = kStringsExtensionLatestVersion); + const StringsExtensionOptions& extension_options = {}); + +inline CheckerLibrary StringsCheckerLibrary(int version) { + StringsExtensionOptions options; + options.version = version; + return StringsCheckerLibrary(options); +} inline CompilerLibrary StringsCompilerLibrary( - int version = kStringsExtensionLatestVersion) { - return CompilerLibrary::FromCheckerLibrary(StringsCheckerLibrary(version)); + const StringsExtensionOptions& options = {}) { + return CompilerLibrary::FromCheckerLibrary(StringsCheckerLibrary(options)); +} + +inline CompilerLibrary StringsCompilerLibrary(int version) { + StringsExtensionOptions options; + options.version = version; + return StringsCompilerLibrary(options); } } // namespace cel::extensions diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc index a5d56eaed..c3059808f 100644 --- a/extensions/strings_test.cc +++ b/extensions/strings_test.cc @@ -27,6 +27,7 @@ #include "checker/type_check_issue.h" #include "checker/type_checker_builder.h" #include "checker/validation_result.h" +#include "common/ast.h" #include "common/decl.h" #include "common/type.h" #include "common/value.h" @@ -50,6 +51,7 @@ namespace cel::extensions { namespace { using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::parser::ParserOptions; @@ -85,6 +87,48 @@ TEST(StringsCheckerLibrary, SmokeTest) { )~bool^equals)"); } +TEST(StringsExtTest, MaxPrecisionOption) { + StringsExtensionOptions extension_options; + extension_options.max_precision = 99; + + ASSERT_OK_AND_ASSIGN( + auto compiler_builder, + NewCompilerBuilder(internal::GetTestingDescriptorPool())); + + ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(StringsCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + compiler_builder->Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult result, + compiler->Compile("'abc %.100f'.format([2.0])", "")); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + RuntimeOptions opts; + ASSERT_OK_AND_ASSIGN( + auto runtime_builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_THAT(RegisterStringsFunctions(runtime_builder.function_registry(), + opts, extension_options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + google::protobuf::Arena arena; + cel::Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value.Is()); + EXPECT_THAT(value.GetError().ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("precision specifier exceeds maximum of 99"))); +} + using StringsExtFunctionsTest = testing::TestWithParam; TEST_P(StringsExtFunctionsTest, ParserAndCheckerTests) { diff --git a/internal/BUILD b/internal/BUILD index 59f68df9b..6d0efab72 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -86,6 +86,18 @@ cc_library( ], ) +cc_library( + name = "runfiles", + srcs = ["runfiles.cc"], + hdrs = ["runfiles.h"], + deps = [ + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@rules_cc//cc/runfiles", + ], +) + cc_library( name = "status_builder", hdrs = ["status_builder.h"], @@ -296,6 +308,7 @@ cc_library( deps = [ ":status_macros", "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], ) @@ -312,6 +325,7 @@ cc_library( deps = [ ":status_macros", "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", ], ) @@ -523,6 +537,7 @@ cel_proto_transitive_descriptor_set( deps = [ "//eval/testutil:test_extensions_proto", "//eval/testutil:test_message_proto", + "//testutil:test_json_names_proto", "@com_google_cel_spec//proto/cel/expr:checked_proto", "@com_google_cel_spec//proto/cel/expr:expr_proto", "@com_google_cel_spec//proto/cel/expr:syntax_proto", diff --git a/internal/json.cc b/internal/json.cc index 200d18bfb..cdd4c1a5d 100644 --- a/internal/json.cc +++ b/internal/json.cc @@ -803,10 +803,10 @@ class MessageToJsonState { const auto* value_descriptor = field->message_type()->map_value(); CEL_ASSIGN_OR_RETURN(const auto value_to_value, GetMapFieldValueToValue(value_descriptor)); - auto begin = - extensions::protobuf_internal::MapBegin(*reflection, message, *field); - const auto end = - extensions::protobuf_internal::MapEnd(*reflection, message, *field); + auto begin = extensions::protobuf_internal::ConstMapBegin(*reflection, + message, *field); + const auto end = extensions::protobuf_internal::ConstMapEnd( + *reflection, message, *field); for (; begin != end; ++begin) { auto key = (*key_to_string)(begin.GetKey()); CEL_RETURN_IF_ERROR((this->*value_to_value)( @@ -1381,7 +1381,7 @@ class JsonMapIterator final { using Generated = typename google::protobuf::Map::const_iterator; - using Dynamic = google::protobuf::MapIterator; + using Dynamic = google::protobuf::ConstMapIterator; using Value = std::pair; @@ -1417,7 +1417,7 @@ class JsonMapIterator final { } private: - absl::variant variant_; + std::variant variant_; }; class JsonAccessor { diff --git a/internal/json.h b/internal/json.h index d32c42741..e35909d0e 100644 --- a/internal/json.h +++ b/internal/json.h @@ -26,7 +26,7 @@ namespace cel::internal { // Converts the given message to its `google.protobuf.Value` equivalent -// representation. This is similar to `proto2::json::MessageToJsonString()`, +// representation. This is similar to `google::protobuf::json::MessageToJsonString()`, // except that this results in structured serialization. absl::Status MessageToJson( const google::protobuf::Message& message, @@ -45,7 +45,7 @@ absl::Status MessageToJson( google::protobuf::Message* absl_nonnull result); // Converts the given message field to its `google.protobuf.Value` equivalent -// representation. This is similar to `proto2::json::MessageToJsonString()`, +// representation. This is similar to `google::protobuf::json::MessageToJsonString()`, // except that this results in structured serialization. absl::Status MessageFieldToJson( const google::protobuf::Message& message, diff --git a/internal/message_equality.cc b/internal/message_equality.cc index 628432d66..33ef78089 100644 --- a/internal/message_equality.cc +++ b/internal/message_equality.cc @@ -50,9 +50,9 @@ namespace cel::internal { namespace { +using ::cel::extensions::protobuf_internal::ConstMapBegin; +using ::cel::extensions::protobuf_internal::ConstMapEnd; using ::cel::extensions::protobuf_internal::LookupMapValue; -using ::cel::extensions::protobuf_internal::MapBegin; -using ::cel::extensions::protobuf_internal::MapEnd; using ::cel::extensions::protobuf_internal::MapSize; using ::google::protobuf::Descriptor; using ::google::protobuf::DescriptorPool; @@ -86,10 +86,10 @@ class EquatableMessage final }; using EquatableValue = - absl::variant; + std::variant; struct NullValueEqualer { bool operator()(std::nullptr_t, std::nullptr_t) const { return true; } @@ -904,8 +904,8 @@ class MessageEqualsState final { MapSize(*rhs_reflection, rhs, *rhs_field)) { return false; } - auto lhs_begin = MapBegin(*lhs_reflection, lhs, *lhs_field); - const auto lhs_end = MapEnd(*lhs_reflection, lhs, *lhs_field); + auto lhs_begin = ConstMapBegin(*lhs_reflection, lhs, *lhs_field); + const auto lhs_end = ConstMapEnd(*lhs_reflection, lhs, *lhs_field); Unique lhs_unpacked; EquatableValue lhs_value; Unique rhs_unpacked; diff --git a/internal/message_equality_test.cc b/internal/message_equality_test.cc index bc5914bef..092edd71b 100644 --- a/internal/message_equality_test.cc +++ b/internal/message_equality_test.cc @@ -110,22 +110,22 @@ TEST_P(UnaryMessageEqualsTest, Equals) { } EXPECT_THAT(MessageEquals(*lhs, *rhs, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs << " " << *rhs; + << lhs->ShortDebugString() << " " << rhs->ShortDebugString(); EXPECT_THAT(MessageEquals(*rhs, *lhs, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs << " " << *rhs; + << lhs->ShortDebugString() << " " << rhs->ShortDebugString(); // Test any. auto lhs_any = PackMessage(*lhs); auto rhs_any = PackMessage(*rhs); EXPECT_THAT(MessageEquals(*lhs_any, *rhs, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any << " " << *rhs; + << lhs_any->ShortDebugString() << " " << rhs->ShortDebugString(); EXPECT_THAT(MessageEquals(*lhs, *rhs_any, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs << " " << *rhs_any; + << lhs->ShortDebugString() << " " << rhs_any->ShortDebugString(); EXPECT_THAT(MessageEquals(*lhs_any, *rhs_any, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any << " " << *rhs_any; + << lhs_any->ShortDebugString() << " " << rhs_any->ShortDebugString(); } } } @@ -399,7 +399,7 @@ absl::optional, PackTestAllTypesProto3Field(const google::protobuf::Message& message, const google::protobuf::FieldDescriptor* absl_nonnull field) { if (field->is_map()) { - return absl::nullopt; + return std::nullopt; } if (field->is_repeated() && field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { @@ -425,7 +425,7 @@ PackTestAllTypesProto3Field(const google::protobuf::Message& message, cel::to_address(packed), any_field)); return std::pair{packed, any_field}; } - return absl::nullopt; + return std::nullopt; } TEST_P(UnaryMessageFieldEqualsTest, Equals) { @@ -455,28 +455,30 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() << " " + << rhs_message->ShortDebugString() << " " << rhs_field->name(); EXPECT_THAT(MessageFieldEquals(*rhs_message, rhs_field, *lhs_message, lhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() << " " + << rhs_message->ShortDebugString() << " " << rhs_field->name(); if (!lhs_field->is_repeated() && lhs_field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { EXPECT_THAT(MessageFieldEquals(lhs_message->GetReflection()->GetMessage( *lhs_message, lhs_field), *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); EXPECT_THAT(MessageFieldEquals(*rhs_message, rhs_field, lhs_message->GetReflection()->GetMessage( *lhs_message, lhs_field), pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); } if (!rhs_field->is_repeated() && rhs_field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { @@ -485,14 +487,16 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { *rhs_message, rhs_field), pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); EXPECT_THAT(MessageFieldEquals(rhs_message->GetReflection()->GetMessage( *rhs_message, rhs_field), *lhs_message, lhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); } // Test `google.protobuf.Any`. absl::optional, @@ -505,21 +509,24 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { EXPECT_THAT(MessageFieldEquals(*lhs_any->first, lhs_any->second, *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any->first << " " << *rhs_message; + << lhs_any->first->ShortDebugString() << " " + << rhs_message->ShortDebugString(); if (!lhs_any->second->is_repeated()) { EXPECT_THAT( MessageFieldEquals(lhs_any->first->GetReflection()->GetMessage( *lhs_any->first, lhs_any->second), *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any->first << " " << *rhs_message; + << lhs_any->first->ShortDebugString() << " " + << rhs_message->ShortDebugString(); } } if (rhs_any) { EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, *rhs_any->first, rhs_any->second, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << *rhs_any->first; + << lhs_message->ShortDebugString() << " " + << rhs_any->first->ShortDebugString(); if (!rhs_any->second->is_repeated()) { EXPECT_THAT( MessageFieldEquals(*lhs_message, lhs_field, @@ -527,7 +534,8 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { *rhs_any->first, rhs_any->second), pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << *rhs_any->first; + << lhs_message->ShortDebugString() << " " + << rhs_any->first->ShortDebugString(); } } if (lhs_any && rhs_any) { @@ -535,7 +543,8 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { MessageFieldEquals(*lhs_any->first, lhs_any->second, *rhs_any->first, rhs_any->second, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any->first << " " << *rhs_any->second; + << lhs_any->first->ShortDebugString() << " " + << rhs_any->first->ShortDebugString(); } } } diff --git a/internal/new.cc b/internal/new.cc index 5bd9e8158..31ec82a08 100644 --- a/internal/new.cc +++ b/internal/new.cc @@ -67,6 +67,13 @@ void* AlignedNew(size_t size, std::align_val_t alignment) { ThrowStdBadAlloc(); } return ptr; +#elif defined(__APPLE__) + void* ptr; + if (ABSL_PREDICT_FALSE( + posix_memalign(&ptr, static_cast(alignment), size) != 0)) { + ThrowStdBadAlloc(); + } + return ptr; #else void* ptr = std::aligned_alloc(static_cast(alignment), size); if (ABSL_PREDICT_FALSE(size != 0 && ptr == nullptr)) { @@ -107,7 +114,7 @@ void AlignedDelete(void* ptr, std::align_val_t alignment) noexcept { ::operator delete(ptr, alignment); #else if (static_cast(alignment) <= kDefaultNewAlignment) { - Delete(ptr, size); + ::operator delete(ptr); } else { #if defined(_MSC_VER) _aligned_free(ptr); diff --git a/internal/overflow_test.cc b/internal/overflow_test.cc index 38c5fa750..213e7a79d 100644 --- a/internal/overflow_test.cc +++ b/internal/overflow_test.cc @@ -57,25 +57,30 @@ INSTANTIATE_TEST_SUITE_P( CheckedIntMathTest, CheckedIntResultTest, ValuesIn(std::vector{ // Addition tests. - {"OneAddOne", [] { return CheckedAdd(1L, 1L); }, 2L}, - {"ZeroAddOne", [] { return CheckedAdd(0, 1L); }, 1L}, - {"ZeroAddMinusOne", [] { return CheckedAdd(0, -1L); }, -1L}, - {"OneAddZero", [] { return CheckedAdd(1L, 0); }, 1L}, - {"MinusOneAddZero", [] { return CheckedAdd(-1L, 0); }, -1L}, + {"OneAddOne", [] { return CheckedAdd(int64_t{1L}, 1L); }, 2L}, + {"ZeroAddOne", [] { return CheckedAdd(int64_t{0}, 1L); }, 1L}, + {"ZeroAddMinusOne", [] { return CheckedAdd(int64_t{0}, -1L); }, -1L}, + {"OneAddZero", [] { return CheckedAdd(int64_t{1L}, 0); }, 1L}, + {"MinusOneAddZero", [] { return CheckedAdd(int64_t{-1L}, 0); }, -1L}, {"OneAddIntMax", - [] { return CheckedAdd(1L, std::numeric_limits::max()); }, + [] { + return CheckedAdd(int64_t{1L}, std::numeric_limits::max()); + }, absl::OutOfRangeError("integer overflow")}, {"MinusOneAddIntMin", - [] { return CheckedAdd(-1L, std::numeric_limits::lowest()); }, + [] { + return CheckedAdd(int64_t{-1L}, + std::numeric_limits::lowest()); + }, absl::OutOfRangeError("integer overflow")}, // Subtraction tests. - {"TwoSubThree", [] { return CheckedSub(2L, 3L); }, -1L}, - {"TwoSubZero", [] { return CheckedSub(2L, 0); }, 2L}, - {"ZeroSubTwo", [] { return CheckedSub(0, 2L); }, -2L}, - {"MinusTwoSubThree", [] { return CheckedSub(-2L, 3L); }, -5L}, - {"MinusTwoSubZero", [] { return CheckedSub(-2L, 0); }, -2L}, - {"ZeroSubMinusTwo", [] { return CheckedSub(0, -2L); }, 2L}, + {"TwoSubThree", [] { return CheckedSub(int64_t{2L}, 3L); }, -1L}, + {"TwoSubZero", [] { return CheckedSub(int64_t{2L}, 0); }, 2L}, + {"ZeroSubTwo", [] { return CheckedSub(int64_t{0}, 2L); }, -2L}, + {"MinusTwoSubThree", [] { return CheckedSub(int64_t{-2L}, 3L); }, -5L}, + {"MinusTwoSubZero", [] { return CheckedSub(int64_t{-2L}, 0); }, -2L}, + {"ZeroSubMinusTwo", [] { return CheckedSub(int64_t{0}, -2L); }, 2L}, {"IntMinSubIntMax", [] { return CheckedSub(std::numeric_limits::max(), @@ -84,66 +89,100 @@ INSTANTIATE_TEST_SUITE_P( absl::OutOfRangeError("integer overflow")}, // Multiplication tests. - {"TwoMulThree", [] { return CheckedMul(2L, 3L); }, 6L}, - {"MinusTwoMulThree", [] { return CheckedMul(-2L, 3L); }, -6L}, - {"MinusTwoMulMinusThree", [] { return CheckedMul(-2L, -3L); }, 6L}, - {"TwoMulMinusThree", [] { return CheckedMul(2L, -3L); }, -6L}, + {"TwoMulThree", [] { return CheckedMul(int64_t{2L}, 3L); }, 6L}, + {"MinusTwoMulThree", [] { return CheckedMul(int64_t{-2L}, 3L); }, -6L}, + {"MinusTwoMulMinusThree", [] { return CheckedMul(int64_t{-2L}, -3L); }, + 6L}, + {"TwoMulMinusThree", [] { return CheckedMul(int64_t{2L}, -3L); }, -6L}, {"TwoMulIntMax", - [] { return CheckedMul(2L, std::numeric_limits::max()); }, + [] { + return CheckedMul(int64_t{2L}, std::numeric_limits::max()); + }, absl::OutOfRangeError("integer overflow")}, {"MinusOneMulIntMin", - [] { return CheckedMul(-1L, std::numeric_limits::lowest()); }, + [] { + return CheckedMul(int64_t{-1L}, + std::numeric_limits::lowest()); + }, absl::OutOfRangeError("integer overflow")}, {"IntMinMulMinusOne", - [] { return CheckedMul(std::numeric_limits::lowest(), -1L); }, + [] { + return CheckedMul(std::numeric_limits::lowest(), + int64_t{-1L}); + }, absl::OutOfRangeError("integer overflow")}, {"IntMinMulZero", - [] { return CheckedMul(std::numeric_limits::lowest(), 0); }, + [] { + return CheckedMul(std::numeric_limits::lowest(), + int64_t{0}); + }, 0}, {"ZeroMulIntMin", - [] { return CheckedMul(0, std::numeric_limits::lowest()); }, + [] { + return CheckedMul(int64_t{0}, + std::numeric_limits::lowest()); + }, 0}, {"IntMaxMulZero", - [] { return CheckedMul(std::numeric_limits::max(), 0); }, 0}, + [] { + return CheckedMul(std::numeric_limits::max(), int64_t{0}); + }, + 0}, {"ZeroMulIntMax", - [] { return CheckedMul(0, std::numeric_limits::max()); }, 0}, + [] { + return CheckedMul(int64_t{0}, std::numeric_limits::max()); + }, + 0}, // Division cases. - {"ZeroDivOne", [] { return CheckedDiv(0, 1L); }, 0}, - {"TenDivTwo", [] { return CheckedDiv(10L, 2L); }, 5}, - {"TenDivMinusOne", [] { return CheckedDiv(10L, -1L); }, -10}, - {"MinusTenDivMinusOne", [] { return CheckedDiv(-10L, -1L); }, 10}, - {"MinusTenDivTwo", [] { return CheckedDiv(-10L, 2L); }, -5}, - {"OneDivZero", [] { return CheckedDiv(1L, 0L); }, + {"ZeroDivOne", [] { return CheckedDiv(int64_t{0}, 1L); }, 0}, + {"TenDivTwo", [] { return CheckedDiv(int64_t{10L}, 2L); }, 5}, + {"TenDivMinusOne", [] { return CheckedDiv(int64_t{10L}, -1L); }, -10}, + {"MinusTenDivMinusOne", [] { return CheckedDiv(int64_t{-10L}, -1L); }, + 10}, + {"MinusTenDivTwo", [] { return CheckedDiv(int64_t{-10L}, 2L); }, -5}, + {"OneDivZero", [] { return CheckedDiv(int64_t{1L}, 0L); }, absl::InvalidArgumentError("divide by zero")}, {"IntMinDivMinusOne", - [] { return CheckedDiv(std::numeric_limits::lowest(), -1L); }, + [] { + return CheckedDiv(std::numeric_limits::lowest(), + int64_t{-1L}); + }, absl::OutOfRangeError("integer overflow")}, // Modulus cases. - {"ZeroModTwo", [] { return CheckedMod(0, 2L); }, 0}, - {"TwoModTwo", [] { return CheckedMod(2L, 2L); }, 0}, - {"ThreeModTwo", [] { return CheckedMod(3L, 2L); }, 1L}, - {"TwoModZero", [] { return CheckedMod(2L, 0); }, + {"ZeroModTwo", [] { return CheckedMod(int64_t{0}, 2L); }, 0}, + {"TwoModTwo", [] { return CheckedMod(int64_t{2L}, 2L); }, 0}, + {"ThreeModTwo", [] { return CheckedMod(int64_t{3L}, 2L); }, 1L}, + {"TwoModZero", [] { return CheckedMod(int64_t{2L}, 0); }, absl::InvalidArgumentError("modulus by zero")}, {"IntMinModTwo", - [] { return CheckedMod(std::numeric_limits::lowest(), 2L); }, + [] { + return CheckedMod(std::numeric_limits::lowest(), + int64_t{2L}); + }, 0}, {"IntMaxModMinusOne", - [] { return CheckedMod(std::numeric_limits::max(), -1L); }, + [] { + return CheckedMod(std::numeric_limits::max(), int64_t{-1L}); + }, 0}, {"IntMinModMinusOne", - [] { return CheckedMod(std::numeric_limits::lowest(), -1L); }, + [] { + return CheckedMod(std::numeric_limits::lowest(), + int64_t{-1L}); + }, absl::OutOfRangeError("integer overflow")}, // Negation cases. - {"NegateOne", [] { return CheckedNegation(1L); }, -1L}, + {"NegateOne", [] { return CheckedNegation(int64_t{1L}); }, -1L}, {"NegateMinInt64", [] { return CheckedNegation(std::numeric_limits::lowest()); }, absl::OutOfRangeError("integer overflow")}, // Numeric conversion cases for uint -> int, double -> int - {"Uint64Conversion", [] { return CheckedUint64ToInt64(1UL); }, 1L}, + {"Uint64Conversion", [] { return CheckedUint64ToInt64(uint64_t{1UL}); }, + 1L}, {"Uint32MaxConversion", [] { return CheckedUint64ToInt64( @@ -156,7 +195,8 @@ INSTANTIATE_TEST_SUITE_P( static_cast(std::numeric_limits::max())); }, absl::OutOfRangeError("out of int64 range")}, - {"DoubleConversion", [] { return CheckedDoubleToInt64(100.1); }, 100L}, + {"DoubleConversion", [] { return CheckedDoubleToInt64(double{100.1}); }, + 100L}, {"DoubleInt64MaxConversionError", [] { return CheckedDoubleToInt64( @@ -201,9 +241,10 @@ INSTANTIATE_TEST_SUITE_P( }, absl::OutOfRangeError("out of int64 range")}, {"NegRangeConversionError", - [] { return CheckedDoubleToInt64(-1.0e99); }, + [] { return CheckedDoubleToInt64(double{-1.0e99}); }, absl::OutOfRangeError("out of int64 range")}, - {"PosRangeConversionError", [] { return CheckedDoubleToInt64(1.0e99); }, + {"PosRangeConversionError", + [] { return CheckedDoubleToInt64(double{1.0e99}); }, absl::OutOfRangeError("out of int64 range")}, }), [](const testing::TestParamInfo& info) { @@ -218,51 +259,58 @@ INSTANTIATE_TEST_SUITE_P( CheckedUintMathTest, CheckedUintResultTest, ValuesIn(std::vector{ // Addition tests. - {"OneAddOne", [] { return CheckedAdd(1UL, 1UL); }, 2UL}, - {"ZeroAddOne", [] { return CheckedAdd(0, 1UL); }, 1UL}, - {"OneAddZero", [] { return CheckedAdd(1UL, 0); }, 1UL}, + {"OneAddOne", [] { return CheckedAdd(uint64_t{1UL}, 1UL); }, 2UL}, + {"ZeroAddOne", [] { return CheckedAdd(uint64_t{0}, 1UL); }, 1UL}, + {"OneAddZero", [] { return CheckedAdd(uint64_t{1UL}, 0); }, 1UL}, {"OneAddIntMax", - [] { return CheckedAdd(1UL, std::numeric_limits::max()); }, + [] { + return CheckedAdd(uint64_t{1UL}, + std::numeric_limits::max()); + }, absl::OutOfRangeError("unsigned integer overflow")}, // Subtraction tests. - {"OneSubOne", [] { return CheckedSub(1UL, 1UL); }, 0}, - {"ZeroSubOne", [] { return CheckedSub(0, 1UL); }, + {"OneSubOne", [] { return CheckedSub(uint64_t{1UL}, 1UL); }, 0}, + {"ZeroSubOne", [] { return CheckedSub(uint64_t{0}, 1UL); }, absl::OutOfRangeError("unsigned integer overflow")}, - {"OneSubZero", [] { return CheckedSub(1UL, 0); }, 1UL}, + {"OneSubZero", [] { return CheckedSub(uint64_t{1UL}, 0); }, 1UL}, // Multiplication tests. - {"OneMulOne", [] { return CheckedMul(1UL, 1UL); }, 1UL}, - {"ZeroMulOne", [] { return CheckedMul(0, 1UL); }, 0}, - {"OneMulZero", [] { return CheckedMul(1UL, 0); }, 0}, + {"OneMulOne", [] { return CheckedMul(uint64_t{1UL}, 1UL); }, 1UL}, + {"ZeroMulOne", [] { return CheckedMul(uint64_t{0}, 1UL); }, 0}, + {"OneMulZero", [] { return CheckedMul(uint64_t{1UL}, 0); }, 0}, {"TwoMulUintMax", - [] { return CheckedMul(2UL, std::numeric_limits::max()); }, + [] { + return CheckedMul(uint64_t{2UL}, + std::numeric_limits::max()); + }, absl::OutOfRangeError("unsigned integer overflow")}, // Division tests. - {"TwoDivTwo", [] { return CheckedDiv(2UL, 2UL); }, 1UL}, - {"TwoDivFour", [] { return CheckedDiv(2UL, 4UL); }, 0}, - {"OneDivZero", [] { return CheckedDiv(1UL, 0); }, + {"TwoDivTwo", [] { return CheckedDiv(uint64_t{2UL}, 2UL); }, 1UL}, + {"TwoDivFour", [] { return CheckedDiv(uint64_t{2UL}, 4UL); }, 0}, + {"OneDivZero", [] { return CheckedDiv(uint64_t{1UL}, 0); }, absl::InvalidArgumentError("divide by zero")}, // Modulus tests. - {"TwoModTwo", [] { return CheckedMod(2UL, 2UL); }, 0}, - {"TwoModFour", [] { return CheckedMod(2UL, 4UL); }, 2UL}, - {"OneModZero", [] { return CheckedMod(1UL, 0); }, + {"TwoModTwo", [] { return CheckedMod(uint64_t{2UL}, 2UL); }, 0}, + {"TwoModFour", [] { return CheckedMod(uint64_t{2UL}, 4UL); }, 2UL}, + {"OneModZero", [] { return CheckedMod(uint64_t{1UL}, 0); }, absl::InvalidArgumentError("modulus by zero")}, // Conversion test cases for int -> uint, double -> uint. - {"Int64Conversion", [] { return CheckedInt64ToUint64(1L); }, 1UL}, + {"Int64Conversion", [] { return CheckedInt64ToUint64(int64_t{1L}); }, + 1UL}, {"Int64MaxConversion", [] { return CheckedInt64ToUint64(std::numeric_limits::max()); }, static_cast(std::numeric_limits::max())}, {"NegativeInt64ConversionError", - [] { return CheckedInt64ToUint64(-1L); }, + [] { return CheckedInt64ToUint64(int64_t{-1L}); }, absl::OutOfRangeError("out of uint64 range")}, - {"DoubleConversion", [] { return CheckedDoubleToUint64(100.1); }, - 100UL}, + {"DoubleConversion", + [] { return CheckedDoubleToUint64(double{100.1}); }, 100UL}, {"DoubleUint64MaxConversionError", [] { return CheckedDoubleToUint64( @@ -287,13 +335,14 @@ INSTANTIATE_TEST_SUITE_P( std::numeric_limits::infinity()); }, absl::OutOfRangeError("out of uint64 range")}, - {"NegConversionError", [] { return CheckedDoubleToUint64(-1.1); }, + {"NegConversionError", + [] { return CheckedDoubleToUint64(double{-1.1}); }, absl::OutOfRangeError("out of uint64 range")}, {"NegRangeConversionError", - [] { return CheckedDoubleToUint64(-1.0e99); }, + [] { return CheckedDoubleToUint64(double{-1.0e99}); }, absl::OutOfRangeError("out of uint64 range")}, {"PosRangeConversionError", - [] { return CheckedDoubleToUint64(1.0e99); }, + [] { return CheckedDoubleToUint64(double{1.0e99}); }, absl::OutOfRangeError("out of uint64 range")}, }), [](const testing::TestParamInfo& info) { @@ -571,7 +620,8 @@ TEST_P(CheckedConvertInt64Int32Test, Conversions) { ExpectResult(GetParam()); } INSTANTIATE_TEST_SUITE_P( CheckedConvertInt64Int32Test, CheckedConvertInt64Int32Test, ValuesIn(std::vector{ - {"SimpleConversion", [] { return CheckedInt64ToInt32(1L); }, 1}, + {"SimpleConversion", [] { return CheckedInt64ToInt32(int64_t{1L}); }, + 1}, {"Int32MaxConversion", [] { return CheckedInt64ToInt32( @@ -610,7 +660,8 @@ TEST_P(CheckedConvertUint64Uint32Test, Conversions) { INSTANTIATE_TEST_SUITE_P( CheckedConvertUint64Uint32Test, CheckedConvertUint64Uint32Test, ValuesIn(std::vector{ - {"SimpleConversion", [] { return CheckedUint64ToUint32(1UL); }, 1U}, + {"SimpleConversion", + [] { return CheckedUint64ToUint32(uint64_t{1UL}); }, 1U}, {"Uint32MaxConversion", [] { return CheckedUint64ToUint32( diff --git a/internal/proto_matchers.h b/internal/proto_matchers.h index 76d844036..02250634b 100644 --- a/internal/proto_matchers.h +++ b/internal/proto_matchers.h @@ -21,7 +21,6 @@ #include "absl/log/absl_check.h" #include "absl/memory/memory.h" -#include "internal/casts.h" #include "internal/testing.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" @@ -43,13 +42,13 @@ class TextProtoMatcher { bool MatchAndExplain(const google::protobuf::MessageLite& p, ::testing::MatchResultListener* listener) const { - return MatchAndExplain(cel::internal::down_cast(p), + return MatchAndExplain(google::protobuf::DownCastMessage(p), listener); } bool MatchAndExplain(const google::protobuf::MessageLite* p, ::testing::MatchResultListener* listener) const { - return MatchAndExplain(cel::internal::down_cast(p), + return MatchAndExplain(google::protobuf::DownCastMessage(p), listener); } @@ -58,7 +57,7 @@ class TextProtoMatcher { auto message = absl::WrapUnique(p.New()); ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(expected_, message.get())); return google::protobuf::util::MessageDifferencer::Equals( - *message, cel::internal::down_cast(p)); + *message, google::protobuf::DownCastMessage(p)); } bool MatchAndExplain(const google::protobuf::Message* p, @@ -66,7 +65,7 @@ class TextProtoMatcher { auto message = absl::WrapUnique(p->New()); ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(expected_, message.get())); return google::protobuf::util::MessageDifferencer::Equals( - *message, cel::internal::down_cast(*p)); + *message, google::protobuf::DownCastMessage(*p)); } inline void DescribeTo(::std::ostream* os) const { *os << expected_; } @@ -93,13 +92,13 @@ class ProtoMatcher { bool MatchAndExplain(const google::protobuf::MessageLite& p, ::testing::MatchResultListener* listener) const { - return MatchAndExplain(cel::internal::down_cast(p), + return MatchAndExplain(google::protobuf::DownCastMessage(p), listener); } bool MatchAndExplain(const google::protobuf::MessageLite* p, ::testing::MatchResultListener* listener) const { - return MatchAndExplain(cel::internal::down_cast(p), + return MatchAndExplain(google::protobuf::DownCastMessage(p), listener); } diff --git a/internal/re2_options.h b/internal/re2_options.h index 9c20ceb63..25a30f6bd 100644 --- a/internal/re2_options.h +++ b/internal/re2_options.h @@ -45,13 +45,13 @@ inline absl::Status CheckRE2(const RE2& re, int max_program_size) { if (max_program_size > 0 && program_size > 0 && program_size > max_program_size) { return absl::InvalidArgumentError( - "regular expressions exceeds max allowed size"); + "regular expression exceeds max allowed size"); } int reverse_program_size = re.ReverseProgramSize(); if (max_program_size > 0 && reverse_program_size > 0 && reverse_program_size > max_program_size) { return absl::InvalidArgumentError( - "regular expressions exceeds max allowed size"); + "regular expression exceeds max allowed size"); } return absl::OkStatus(); } diff --git a/internal/runfiles.cc b/internal/runfiles.cc new file mode 100644 index 000000000..bffbfa9d1 --- /dev/null +++ b/internal/runfiles.cc @@ -0,0 +1,53 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/runfiles.h" + +#include +#include +#include + +#include "rules_cc/cc/runfiles/runfiles.h" +#include "absl/log/absl_check.h" + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace cel::internal { + +std::string ResolveRunfilesPath(absl::string_view path) { + using ::rules_cc::cc::runfiles::Runfiles; + static Runfiles* runfiles = []() { + std::string error; + auto runfiles = + Runfiles::CreateForTest(BAZEL_CURRENT_REPOSITORY, &error); + ABSL_QCHECK(runfiles != nullptr) + << absl::StrCat("failed to init runfiles", error); + return runfiles; + }(); + return runfiles->Rlocation(std::string(path)); +} + +absl::Status GetFileContents(absl::string_view path, std::string* out) { + std::ifstream file{std::string(path)}; + if (!file.is_open()) { + return absl::NotFoundError(absl::StrCat("Failed to open file: ", path)); + } + out->append((std::istreambuf_iterator(file)), + std::istreambuf_iterator()); + return absl::OkStatus(); +} + +} // namespace cel::internal diff --git a/internal/runfiles.h b/internal/runfiles.h new file mode 100644 index 000000000..11fdcf337 --- /dev/null +++ b/internal/runfiles.h @@ -0,0 +1,36 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Utilities for working with bazel runfiles. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ + +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" + +namespace cel::internal { + +// Resolves a path relative to the runfiles directory. +// Intended for resolving test cases from cel-spec and cel-policy. +std::string ResolveRunfilesPath(absl::string_view path); + +// Read contents of a file at a resolved path to a string. +absl::Status GetFileContents(absl::string_view path, std::string* out); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ diff --git a/internal/strings_test.cc b/internal/strings_test.cc index d6c90473e..fcdb6d4ec 100644 --- a/internal/strings_test.cc +++ b/internal/strings_test.cc @@ -24,6 +24,7 @@ #include "absl/strings/cord.h" #include "absl/strings/cord_test_helpers.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "internal/testing.h" diff --git a/internal/testing.cc b/internal/testing.cc index 77e4c65b4..84aa58cce 100644 --- a/internal/testing.cc +++ b/internal/testing.cc @@ -14,6 +14,8 @@ #include "internal/testing.h" +#include "absl/strings/str_cat.h" // IWYU pragma: keep + namespace cel::internal { void AddFatalFailure(const char* file, int line, absl::string_view expression, diff --git a/internal/to_address.h b/internal/to_address.h index 5dffef3c1..36e7eeb60 100644 --- a/internal/to_address.h +++ b/internal/to_address.h @@ -49,7 +49,7 @@ struct PointerTraitsToAddress { template struct PointerTraitsToAddress< - T, absl::void_t::to_address( + T, std::void_t::to_address( std::declval()))> > { static constexpr auto Dispatch( const T& p ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index f66a9360b..02e50c3e3 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -71,6 +71,17 @@ using ::google::protobuf::util::TimeUtil; using CppStringType = ::google::protobuf::FieldDescriptor::CppStringType; +FieldDescriptor::Label GetFieldLabel( + const FieldDescriptor* absl_nonnull field) { + if (field->is_required()) { + return FieldDescriptor::LABEL_REQUIRED; + } else if (field->is_repeated()) { + return FieldDescriptor::LABEL_REPEATED; + } else { + return FieldDescriptor::LABEL_OPTIONAL; + } +} + absl::string_view FlatStringValue( const StringValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { @@ -264,11 +275,11 @@ absl::string_view LabelToString(FieldDescriptor::Label label) { absl::Status CheckFieldCardinality(const FieldDescriptor* absl_nonnull field, FieldDescriptor::Label label) { - if (ABSL_PREDICT_FALSE(field->label() != label)) { - return absl::InvalidArgumentError( - absl::StrCat("unexpected field cardinality for protocol buffer message " - "well known type: ", - field->full_name(), " ", LabelToString(field->label()))); + if (ABSL_PREDICT_FALSE(GetFieldLabel(field) != label)) { + return absl::InvalidArgumentError(absl::StrCat( + "unexpected field cardinality for protocol buffer message " + "well known type: ", + field->full_name(), " ", LabelToString(GetFieldLabel(field)))); } return absl::OkStatus(); } @@ -1632,20 +1643,20 @@ int StructReflection::FieldsSize(const google::protobuf::Message& message) const message, *fields_field_); } -google::protobuf::MapIterator StructReflection::BeginFields( +google::protobuf::ConstMapIterator StructReflection::BeginFields( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); - return cel::extensions::protobuf_internal::MapBegin(*message.GetReflection(), - message, *fields_field_); + return cel::extensions::protobuf_internal::ConstMapBegin( + *message.GetReflection(), message, *fields_field_); } -google::protobuf::MapIterator StructReflection::EndFields( +google::protobuf::ConstMapIterator StructReflection::EndFields( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); - return cel::extensions::protobuf_internal::MapEnd(*message.GetReflection(), - message, *fields_field_); + return cel::extensions::protobuf_internal::ConstMapEnd( + *message.GetReflection(), message, *fields_field_); } bool StructReflection::ContainsField(const google::protobuf::Message& message, @@ -2163,7 +2174,7 @@ absl::StatusOr AdaptFromMessage( if (adapted) { return adapted; } - return absl::monostate{}; + return std::monostate{}; } } diff --git a/internal/well_known_types.h b/internal/well_known_types.h index dce88a420..f63e5e76b 100644 --- a/internal/well_known_types.h +++ b/internal/well_known_types.h @@ -698,8 +698,9 @@ absl::StatusOr GetAnyReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); -AnyReflection GetAnyReflectionOrDie(const google::protobuf::Descriptor* absl_nonnull - descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); +AnyReflection GetAnyReflectionOrDie( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); class DurationReflection final { public: @@ -1193,10 +1194,10 @@ class StructReflection final { int FieldsSize(const google::protobuf::Message& message) const; - google::protobuf::MapIterator BeginFields( + google::protobuf::ConstMapIterator BeginFields( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; - google::protobuf::MapIterator EndFields( + google::protobuf::ConstMapIterator EndFields( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; bool ContainsField(const google::protobuf::Message& message, diff --git a/internal/well_known_types_test.cc b/internal/well_known_types_test.cc index 0d2c9fe33..afc8ce396 100644 --- a/internal/well_known_types_test.cc +++ b/internal/well_known_types_test.cc @@ -806,7 +806,7 @@ TEST_F(AdaptFromMessageTest, Struct) { TEST_F(AdaptFromMessageTest, TestAllTypesProto3) { auto message = DynamicParseTextProto(R"pb()pb"); EXPECT_THAT(AdaptFromMessage(*message), - IsOkAndHolds(VariantWith(absl::monostate()))); + IsOkAndHolds(VariantWith(std::monostate()))); } TEST_F(AdaptFromMessageTest, Any_BoolValue) { diff --git a/parser/BUILD b/parser/BUILD index 63813bb59..6650d9fe9 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -244,9 +244,11 @@ cc_library( ":options", "//common:ast", "//common:source", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) diff --git a/parser/macro.cc b/parser/macro.cc index eaa1ebd1a..815b07401 100644 --- a/parser/macro.cc +++ b/parser/macro.cc @@ -25,6 +25,7 @@ #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -40,6 +41,11 @@ namespace { using google::api::expr::common::CelOperator; +bool IsSimpleIdentifier(const Expr& expr) { + return expr.has_ident_expr() && !expr.ident_expr().name().empty() && + !absl::StartsWith(expr.ident_expr().name(), "."); +} + inline MacroExpander ToMacroExpander(GlobalMacroExpander expander) { ABSL_DCHECK(expander); return [expander = std::move(expander)]( @@ -87,14 +93,14 @@ absl::optional ExpandAllMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("all() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "all() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { - return factory.ReportErrorAt(args[1], - absl::StrCat("all() variable name cannot be ", - kAccumulatorVariableName)); + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("all() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(true); auto condition = @@ -119,14 +125,14 @@ absl::optional ExpandExistsMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("exists() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "exists() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(false); auto condition = factory.NewCall( @@ -153,14 +159,14 @@ absl::optional ExpandExistsOneMacro(MacroExprFactory& factory, if (args.size() != 2) { return factory.ReportError("exists_one() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "exists_one() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists_one() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewIntConst(0); auto condition = factory.NewBoolConst(true); @@ -192,14 +198,14 @@ absl::optional ExpandMap2Macro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("map() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { - return factory.ReportErrorAt(args[1], - absl::StrCat("map() variable name cannot be ", - kAccumulatorVariableName)); + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("map() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewList(); auto condition = factory.NewBoolConst(true); @@ -225,14 +231,14 @@ absl::optional ExpandMap3Macro(MacroExprFactory& factory, Expr& target, if (args.size() != 3) { return factory.ReportError("map() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { - return factory.ReportErrorAt(args[1], - absl::StrCat("map() variable name cannot be ", - kAccumulatorVariableName)); + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("map() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewList(); auto condition = factory.NewBoolConst(true); @@ -260,14 +266,14 @@ absl::optional ExpandFilterMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("filter() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "filter() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("filter() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto name = args[0].ident_expr().name(); @@ -298,14 +304,14 @@ absl::optional ExpandOptMapMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("optMap() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "optMap() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("optMap() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto var_name = args[0].ident_expr().name(); @@ -337,14 +343,14 @@ absl::optional ExpandOptFlatMapMacro(MacroExprFactory& factory, if (args.size() != 2) { return factory.ReportError("optFlatMap() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "optFlatMap() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("optFlatMap() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto var_name = args[0].ident_expr().name(); diff --git a/parser/macro_expr_factory.h b/parser/macro_expr_factory.h index ffba5e2f2..c66aa4fe0 100644 --- a/parser/macro_expr_factory.h +++ b/parser/macro_expr_factory.h @@ -319,8 +319,7 @@ class MacroExprFactory : protected ExprFactory { friend class ParserMacroExprFactory; friend class TestMacroExprFactory; - explicit MacroExprFactory(absl::string_view accu_var) - : ExprFactory(accu_var) {} + explicit MacroExprFactory() = default; }; } // namespace cel diff --git a/parser/macro_expr_factory_test.cc b/parser/macro_expr_factory_test.cc index 04705eec6..b95cbe16f 100644 --- a/parser/macro_expr_factory_test.cc +++ b/parser/macro_expr_factory_test.cc @@ -15,6 +15,7 @@ #include "parser/macro_expr_factory.h" #include +#include #include #include "absl/strings/string_view.h" @@ -27,7 +28,7 @@ namespace cel { class TestMacroExprFactory final : public MacroExprFactory { public: - TestMacroExprFactory() : MacroExprFactory(kAccumulatorVariableName) {} + TestMacroExprFactory() = default; ExprId id() const { return id_; } @@ -39,6 +40,7 @@ class TestMacroExprFactory final : public MacroExprFactory { return NewUnspecified(NextId()); } + using MacroExprFactory::NewBind; using MacroExprFactory::NewBoolConst; using MacroExprFactory::NewCall; using MacroExprFactory::NewComprehension; @@ -69,6 +71,8 @@ class TestMacroExprFactory final : public MacroExprFactory { namespace { +using ::testing::IsEmpty; + TEST(MacroExprFactory, CopyUnspecified) { TestMacroExprFactory factory; EXPECT_EQ(factory.Copy(factory.NewUnspecified()), factory.NewUnspecified(2)); @@ -147,5 +151,52 @@ TEST(MacroExprFactory, CopyComprehension) { factory.NewIdent(11, "foo"), factory.NewIdent(12, "bar"))); } +TEST(MacroExprFactory, NewBind) { + TestMacroExprFactory factory; + Expr bind_expr = factory.NewIdent(10, "x"); + Expr rest_expr = factory.NewIdent(20, "y"); + + auto next_id = [id = 100]() mutable { return id++; }; + + Expr expr = + factory.NewBind(next_id, "a", std::move(bind_expr), std::move(rest_expr)); + + EXPECT_EQ(expr.id(), 100); + ASSERT_TRUE(expr.has_comprehension_expr()); + + const auto& comp = expr.comprehension_expr(); + EXPECT_EQ(comp.iter_var(), "#unused"); + + ASSERT_TRUE(comp.has_iter_range()); + EXPECT_EQ(comp.iter_range().id(), 101); + EXPECT_EQ(comp.iter_range().kind_case(), ExprKindCase::kListExpr); + EXPECT_THAT(comp.iter_range().list_expr().elements(), IsEmpty()); + + EXPECT_EQ(comp.accu_var(), "a"); + + ASSERT_TRUE(comp.has_accu_init()); + Expr expected_bind_expr; + expected_bind_expr.set_id(10); + expected_bind_expr.mutable_ident_expr().set_name("x"); + EXPECT_EQ(comp.accu_init(), expected_bind_expr); + + ASSERT_TRUE(comp.has_loop_condition()); + EXPECT_EQ(comp.loop_condition().id(), 102); + EXPECT_EQ(comp.loop_condition().kind_case(), ExprKindCase::kConstant); + EXPECT_TRUE(comp.loop_condition().const_expr().has_bool_value()); + EXPECT_FALSE(comp.loop_condition().const_expr().bool_value()); + + ASSERT_TRUE(comp.has_loop_step()); + EXPECT_EQ(comp.loop_step().id(), 103); + EXPECT_EQ(comp.loop_step().kind_case(), ExprKindCase::kIdentExpr); + EXPECT_EQ(comp.loop_step().ident_expr().name(), "a"); + + ASSERT_TRUE(comp.has_result()); + Expr expected_rest_expr; + expected_rest_expr.set_id(20); + expected_rest_expr.mutable_ident_expr().set_name("y"); + EXPECT_EQ(comp.result(), expected_rest_expr); +} + } // namespace } // namespace cel diff --git a/parser/macro_registry.cc b/parser/macro_registry.cc index 3fc77f18c..d36761e87 100644 --- a/parser/macro_registry.cc +++ b/parser/macro_registry.cc @@ -16,6 +16,7 @@ #include #include +#include #include "absl/status/status.h" #include "absl/strings/match.h" @@ -54,7 +55,7 @@ absl::optional MacroRegistry::FindMacro(absl::string_view name, bool receiver_style) const { // :: if (name.empty() || absl::StrContains(name, ':')) { - return absl::nullopt; + return std::nullopt; } // Try argument count specific key first. auto key = absl::StrCat(name, ":", arg_count, ":", @@ -67,7 +68,16 @@ absl::optional MacroRegistry::FindMacro(absl::string_view name, if (auto it = macros_.find(key); it != macros_.end()) { return it->second; } - return absl::nullopt; + return std::nullopt; +} + +std::vector MacroRegistry::ListMacros() const { + std::vector macros; + macros.reserve(macros_.size()); + for (auto it = macros_.begin(); it != macros_.end(); ++it) { + macros.push_back(it->second); + } + return macros; } bool MacroRegistry::RegisterMacroImpl(const Macro& macro) { diff --git a/parser/macro_registry.h b/parser/macro_registry.h index 51899bade..01a0634ef 100644 --- a/parser/macro_registry.h +++ b/parser/macro_registry.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_PARSER_MACRO_REGISTRY_H_ #include +#include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" @@ -44,6 +45,9 @@ class MacroRegistry final { absl::optional FindMacro(absl::string_view name, size_t arg_count, bool receiver_style) const; + // Returns a copy of all registered macros. + std::vector ListMacros() const; + private: bool RegisterMacroImpl(const Macro& macro); diff --git a/parser/macro_registry_test.cc b/parser/macro_registry_test.cc index 9e6da87a4..db8a99ab2 100644 --- a/parser/macro_registry_test.cc +++ b/parser/macro_registry_test.cc @@ -30,14 +30,14 @@ using ::testing::Ne; TEST(MacroRegistry, RegisterAndFind) { MacroRegistry macros; EXPECT_THAT(macros.RegisterMacro(HasMacro()), IsOk()); - EXPECT_THAT(macros.FindMacro("has", 1, false), Ne(absl::nullopt)); + EXPECT_THAT(macros.FindMacro("has", 1, false), Ne(std::nullopt)); } TEST(MacroRegistry, RegisterRollsback) { MacroRegistry macros; EXPECT_THAT(macros.RegisterMacros({HasMacro(), AllMacro(), AllMacro()}), StatusIs(absl::StatusCode::kAlreadyExists)); - EXPECT_THAT(macros.FindMacro("has", 1, false), Eq(absl::nullopt)); + EXPECT_THAT(macros.FindMacro("has", 1, false), Eq(std::nullopt)); } } // namespace diff --git a/parser/options.h b/parser/options.h index ad03102e8..719bed454 100644 --- a/parser/options.h +++ b/parser/options.h @@ -51,14 +51,21 @@ struct ParserOptions final { // Disable standard macros (has, all, exists, exists_one, filter, map). bool disable_standard_macros = false; - // Enable hidden accumulator variable '@result' for builtin comprehensions. + // Deprecated: The builtin and extension macros now always use the new + // accumulator variable name. + // This option has no effect. bool enable_hidden_accumulator_var = true; // Enables support for identifier quoting syntax: // "message.`skewer-case-field`" // - // Limited to field specifiers in select and message creation. - bool enable_quoted_identifiers = false; + // Limited to field specifiers in select and message creation, + // enabled by default + bool enable_quoted_identifiers = true; + + // Enables parsing logical AND & OR operators as a single flat variadic call + // instead of a balanced/nested binary AST structure. + bool enable_variadic_logical_operators = false; }; } // namespace cel diff --git a/parser/parser.cc b/parser/parser.cc index d430e3169..24b4ca079 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -112,13 +112,12 @@ struct ParserError { }; std::string DisplayParserError(const cel::Source& source, - const ParserError& error) { - auto location = - source.GetLocation(error.range.begin).value_or(SourceLocation{}); + SourceLocation location, + absl::string_view message) { return absl::StrCat(absl::StrFormat("ERROR: %s:%zu:%zu: %s", source.description(), location.line, // add one to the 0-based column - location.column + 1, error.message), + location.column + 1, message), source.DisplayErrorLocation(location)); } @@ -163,9 +162,8 @@ SourceRange SourceRangeFromParserRuleContext( class ParserMacroExprFactory final : public MacroExprFactory { public: - explicit ParserMacroExprFactory(const cel::Source& source, - absl::string_view accu_var) - : MacroExprFactory(accu_var), source_(source) {} + explicit ParserMacroExprFactory(const cel::Source& source) + : source_(source) {} void BeginMacro(SourceRange macro_position) { macro_position_ = macro_position; @@ -210,7 +208,7 @@ class ParserMacroExprFactory final : public MacroExprFactory { bool HasErrors() const { return error_count_ != 0; } - std::string ErrorMessage() { + std::vector CollectIssues() { // Errors are collected as they are encountered, not by their location // within the source. To have a more stable error message as implementation // details change, we sort the collected errors by their source location @@ -227,20 +225,23 @@ class ParserMacroExprFactory final : public MacroExprFactory { }); // Build the summary error message using the sorted errors. bool errors_truncated = error_count_ > 100; - std::vector messages; - messages.reserve( + std::vector issues; + issues.reserve( errors_.size() + errors_truncated); // Reserve space for the transform and an // additional element when truncation occurs. - std::transform(errors_.begin(), errors_.end(), std::back_inserter(messages), - [this](const ParserError& error) { - return cel::DisplayParserError(source_, error); - }); + std::transform( + errors_.begin(), errors_.end(), std::back_inserter(issues), + [this](const ParserError& error) { + auto location = + source_.GetLocation(error.range.begin).value_or(SourceLocation{}); + return cel::ParseIssue(location, error.message); + }); if (errors_truncated) { - messages.emplace_back( - absl::StrCat(error_count_ - 100, " more errors were truncated.")); + issues.push_back(cel::ParseIssue( + absl::StrCat(error_count_ - 100, " more errors were truncated."))); } - return absl::StrJoin(messages, "\n"); + return issues; } void AddMacroCall(int64_t macro_id, absl::string_view function, @@ -551,7 +552,7 @@ class ExpressionBalancer final { // balance creates a balanced tree from the sub-terms and returns the final // Expr value. - Expr Balance(); + Expr Balance(bool enable_variadic = false); private: // balancedTree recursively balances the terms provided to a commutative @@ -576,10 +577,13 @@ void ExpressionBalancer::AddTerm(int64_t op, Expr term) { ops_.push_back(op); } -Expr ExpressionBalancer::Balance() { +Expr ExpressionBalancer::Balance(bool enable_variadic) { if (terms_.size() == 1) { return std::move(terms_[0]); } + if (enable_variadic) { + return factory_.NewCall(ops_[0], function_, std::move(terms_)); + } return BalancedTree(0, ops_.size() - 1); } @@ -603,23 +607,33 @@ Expr ExpressionBalancer::BalancedTree(int lo, int hi) { return factory_.NewCall(ops_[mid], function_, std::move(arguments)); } +std::string FormatIssues(const cel::Source& source, + absl::Span issues) { + return absl::StrJoin( + issues, "\n", [&source](std::string* out, const cel::ParseIssue& issue) { + absl::StrAppend(out, cel::DisplayParserError(source, issue.location(), + issue.message())); + }); +} + class ParserVisitor final : public CelBaseVisitor, public antlr4::BaseErrorListener { public: ParserVisitor(const cel::Source& source, int max_recursion_depth, - absl::string_view accu_var, const cel::MacroRegistry& macro_registry, bool add_macro_calls = false, bool enable_optional_syntax = false, - bool enable_quoted_identifiers = false) + bool enable_quoted_identifiers = false, + bool enable_variadic_logical_operators = false) : source_(source), - factory_(source_, accu_var), + factory_(source_), macro_registry_(macro_registry), recursion_depth_(0), max_recursion_depth_(max_recursion_depth), add_macro_calls_(add_macro_calls), enable_optional_syntax_(enable_optional_syntax), - enable_quoted_identifiers_(enable_quoted_identifiers) {} + enable_quoted_identifiers_(enable_quoted_identifiers), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} ~ParserVisitor() override = default; @@ -675,7 +689,7 @@ class ParserVisitor final : public CelBaseVisitor, const std::string& msg, std::exception_ptr e) override; bool HasErrored() const; - std::string ErrorMessage(); + std::vector CollectIssues(); private: template @@ -710,6 +724,7 @@ class ParserVisitor final : public CelBaseVisitor, const bool add_macro_calls_; const bool enable_optional_syntax_; const bool enable_quoted_identifiers_; + const bool enable_variadic_logical_operators_; }; template op)); + int64_t obj_id = factory_.NextId(SourceRangeFromParserRuleContext(ctx)); std::vector fields; if (ctx->entries) { fields = visitFields(ctx->entries); @@ -1191,7 +1206,7 @@ std::any ParserVisitor::visitNested(CelParser::NestedContext* ctx) { } std::any ParserVisitor::visitCreateList(CelParser::CreateListContext* ctx) { - int64_t list_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + int64_t list_id = factory_.NextId(SourceRangeFromParserRuleContext(ctx)); auto elems = visitList(ctx->elems); return ExprToAny(factory_.NewList(list_id, std::move(elems))); } @@ -1229,7 +1244,7 @@ std::vector ParserVisitor::visitList(CelParser::ExprListContext* ctx) { } std::any ParserVisitor::visitCreateMap(CelParser::CreateMapContext* ctx) { - int64_t struct_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + int64_t struct_id = factory_.NextId(SourceRangeFromParserRuleContext(ctx)); std::vector entries; if (ctx->entries) { entries = visitEntries(ctx->entries); @@ -1436,7 +1451,9 @@ void ParserVisitor::syntaxError(antlr4::Recognizer* recognizer, bool ParserVisitor::HasErrored() const { return factory_.HasErrors(); } -std::string ParserVisitor::ErrorMessage() { return factory_.ErrorMessage(); } +std::vector ParserVisitor::CollectIssues() { + return factory_.CollectIssues(); +} Expr ParserVisitor::GlobalCallOrMacroImpl(int64_t expr_id, absl::string_view function, @@ -1451,11 +1468,11 @@ Expr ParserVisitor::GlobalCallOrMacroImpl(int64_t expr_id, } } factory_.BeginMacro(factory_.GetSourceRange(expr_id)); - auto expr = macro->Expand(factory_, absl::nullopt, absl::MakeSpan(args)); + auto expr = macro->Expand(factory_, std::nullopt, absl::MakeSpan(args)); factory_.EndMacro(); if (expr) { if (add_macro_calls_) { - factory_.AddMacroCall(expr->id(), function, absl::nullopt, + factory_.AddMacroCall(expr->id(), function, std::nullopt, std::move(macro_args)); } // We did not end up using `expr_id`. Delete metadata. @@ -1640,9 +1657,10 @@ struct ParseResult { EnrichedSourceInfo enriched_source_info; }; -absl::StatusOr ParseImpl(const cel::Source& source, - const cel::MacroRegistry& registry, - const ParserOptions& options) { +absl::StatusOr ParseImpl( + const cel::Source& source, const cel::MacroRegistry& registry, + const ParserOptions& options, + std::vector* parse_issues = nullptr) { try { CodePointStream input(source.content(), source.description()); if (input.size() > options.expression_size_codepoint_limit) { @@ -1654,14 +1672,10 @@ absl::StatusOr ParseImpl(const cel::Source& source, CommonTokenStream tokens(&lexer); CelParser parser(&tokens); ExprRecursionListener listener(options.max_recursion_depth); - absl::string_view accu_var = cel::kAccumulatorVariableName; - if (options.enable_hidden_accumulator_var) { - accu_var = cel::kHiddenAccumulatorVariableName; - } - ParserVisitor visitor(source, options.max_recursion_depth, accu_var, - registry, options.add_macro_calls, - options.enable_optional_syntax, - options.enable_quoted_identifiers); + ParserVisitor visitor( + source, options.max_recursion_depth, registry, options.add_macro_calls, + options.enable_optional_syntax, options.enable_quoted_identifiers, + options.enable_variadic_logical_operators); lexer.removeErrorListeners(); parser.removeErrorListeners(); @@ -1680,13 +1694,23 @@ absl::StatusOr ParseImpl(const cel::Source& source, expr = ExprFromAny(visitor.visit(parser.start())); } catch (const ParseCancellationException& e) { if (visitor.HasErrored()) { - return absl::InvalidArgumentError(visitor.ErrorMessage()); + auto issues = visitor.CollectIssues(); + std::string error_message = FormatIssues(source, issues); + if (parse_issues != nullptr) { + *parse_issues = std::move(issues); + } + return absl::InvalidArgumentError(error_message); } return absl::CancelledError(e.what()); } if (visitor.HasErrored()) { - return absl::InvalidArgumentError(visitor.ErrorMessage()); + auto issues = visitor.CollectIssues(); + std::string error_message = FormatIssues(source, issues); + if (parse_issues != nullptr) { + *parse_issues = std::move(issues); + } + return absl::InvalidArgumentError(error_message); } return { @@ -1707,19 +1731,28 @@ absl::StatusOr ParseImpl(const cel::Source& source, class ParserImpl : public cel::Parser { public: explicit ParserImpl(const ParserOptions& options, - cel::MacroRegistry macro_registry) - : options_(options), macro_registry_(std::move(macro_registry)) {} - absl::StatusOr> Parse( - const cel::Source& source) const override { + cel::MacroRegistry macro_registry, + absl::flat_hash_set library_ids) + : options_(options), + macro_registry_(std::move(macro_registry)), + library_ids_(std::move(library_ids)) {} + + absl::StatusOr> ParseImpl( + const cel::Source& source, + std::vector* parse_issues) const override { CEL_ASSIGN_OR_RETURN(auto parse_result, - ParseImpl(source, macro_registry_, options_)); + ::google::api::expr::parser::ParseImpl( + source, macro_registry_, options_, parse_issues)); return std::make_unique(std::move(parse_result.expr), std::move(parse_result.source_info)); } + std::unique_ptr ToBuilder() const override; + private: const ParserOptions options_; const cel::MacroRegistry macro_registry_; + absl::flat_hash_set library_ids_; }; class ParserBuilderImpl : public cel::ParserBuilder { @@ -1796,21 +1829,28 @@ class ParserBuilderImpl : public cel::ParserBuilder { macros_.clear(); } + absl::flat_hash_set library_ids(library_ids_); + // Hack to support adding the standard library macros either by option or // with a library configurer. if (!options_.disable_standard_macros && !library_ids_.contains("stdlib")) { CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(Macro::AllMacros())); + library_ids.insert("stdlib"); } if (options_.enable_optional_syntax && !library_ids_.contains("optional")) { CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(cel::OptMapMacro())); CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(cel::OptFlatMapMacro())); + library_ids.insert("optional"); } CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(individual_macros)); - return std::make_unique(options_, std::move(macro_registry)); + return std::make_unique(options_, std::move(macro_registry), + std::move(library_ids)); } private: + friend class ParserImpl; + ParserOptions options_; std::vector macros_; absl::flat_hash_set library_ids_; @@ -1818,6 +1858,13 @@ class ParserBuilderImpl : public cel::ParserBuilder { absl::flat_hash_map library_subsets_; }; +std::unique_ptr ParserImpl::ToBuilder() const { + auto ins = std::make_unique(options_); + ins->library_ids_ = library_ids_; + ins->macros_ = macro_registry_.ListMacros(); + return ins; +} + } // namespace absl::StatusOr Parse(absl::string_view expression, diff --git a/parser/parser_interface.h b/parser/parser_interface.h index 0992385f7..ad6e8ca84 100644 --- a/parser/parser_interface.h +++ b/parser/parser_interface.h @@ -16,10 +16,14 @@ #include #include +#include +#include +#include "absl/base/nullability.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "common/ast.h" #include "common/source.h" #include "parser/macro.h" @@ -73,6 +77,26 @@ class ParserBuilder { virtual absl::StatusOr> Build() = 0; }; +// Information about a parse failure. +class ParseIssue { + public: + explicit ParseIssue(std::string message) : message_(std::move(message)) {} + ParseIssue(SourceLocation location, std::string message) + : location_(location), message_(std::move(message)) {} + + ParseIssue(const ParseIssue& other) = default; + ParseIssue& operator=(const ParseIssue& other) = default; + ParseIssue(ParseIssue&& other) = default; + ParseIssue& operator=(ParseIssue&& other) = default; + + SourceLocation location() const { return location_; } + absl::string_view message() const { return message_; } + + private: + SourceLocation location_; + std::string message_; +}; + // Interface for stateful CEL parser objects for use with a `Compiler` // (bundled parse and type check). This is not needed for most users: // prefer using the free functions in `parser.h` for more flexibility. @@ -81,10 +105,35 @@ class Parser { virtual ~Parser() = default; // Parses the given source into a CEL AST. - virtual absl::StatusOr> Parse( - const cel::Source& source) const = 0; + absl::StatusOr> Parse( + const cel::Source& source) const; + + // Parses the given source into a CEL AST, collecting parse errors in + // `issues`. If `issues` is non-null, it will be cleared and all parse + // issues will be appended to it. + absl::StatusOr> Parse( + const cel::Source& source, std::vector* issues) const; + + // Returns a builder initialized with the configuration of this parser. + virtual std::unique_ptr ToBuilder() const = 0; + + protected: + virtual absl::StatusOr> ParseImpl( + const cel::Source& source, + std::vector* absl_nullable parse_issues) const = 0; }; +inline absl::StatusOr> Parser::Parse( + const cel::Source& source) const { + return ParseImpl(source, nullptr); +} + +inline absl::StatusOr> Parser::Parse( + const cel::Source& source, std::vector* issues) const { + if (issues != nullptr) issues->clear(); + return ParseImpl(source, issues); +} + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ diff --git a/parser/parser_test.cc b/parser/parser_test.cc index aee121051..35f11b413 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -631,6 +631,27 @@ std::vector test_cases = { "ERROR: :1:7: all() variable name must be a simple identifier\n" " | 1.all(2, 3)\n" " | ......^"}, + {"[].all(.x, x)", "", + "ERROR: :1:9: all() variable name must be a simple identifier\n" + " | [].all(.x, x)\n" + " | ........^"}, + {"[].exists(.x, x)", "", + "ERROR: :1:12: exists() variable name must be a simple identifier\n" + " | [].exists(.x, x)\n" + " | ...........^"}, + {"[].exists_one(.x, x)", "", + "ERROR: :1:16: exists_one() variable name must be a simple " + "identifier\n" + " | [].exists_one(.x, x)\n" + " | ...............^"}, + {"[].map(.x, x, x)", "", + "ERROR: :1:9: map() variable name must be a simple identifier\n" + " | [].map(.x, x, x)\n" + " | ........^"}, + {"[].filter(.x, x)", "", + "ERROR: :1:12: filter() variable name must be a simple identifier\n" + " | [].filter(.x, x)\n" + " | ...........^"}, {"x[\"a\"].single_int32 == 23", "_==_(\n" " _[_](\n" @@ -1473,7 +1494,6 @@ class ExpressionTest : public testing::TestWithParam {}; TEST_P(ExpressionTest, Parse) { const TestInfo& test_info = GetParam(); ParserOptions options; - options.enable_hidden_accumulator_var = true; if (!test_info.M.empty()) { options.add_macro_calls = true; } @@ -1495,14 +1515,16 @@ TEST_P(ExpressionTest, Parse) { KindAndIdAdorner kind_and_id_adorner; ExprPrinter w(kind_and_id_adorner); std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.P, adorned_string) << result->parsed_expr(); + EXPECT_EQ(test_info.P, adorned_string) + << result->parsed_expr().ShortDebugString(); } if (!test_info.L.empty()) { LocationAdorner location_adorner(result->parsed_expr().source_info()); ExprPrinter w(location_adorner); std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.L, adorned_string) << result->parsed_expr(); + EXPECT_EQ(test_info.L, adorned_string) + << result->parsed_expr().ShortDebugString(); ; } @@ -1514,11 +1536,34 @@ TEST_P(ExpressionTest, Parse) { if (!test_info.M.empty()) { EXPECT_EQ(test_info.M, ConvertMacroCallsToString( result.value().parsed_expr().source_info())) - << result->parsed_expr(); + << result->parsed_expr().ShortDebugString(); ; } } +TEST(ExpressionTest, CompositeExpressionOffsets) { + ParserOptions options; + std::vector macros = Macro::AllMacros(); + + std::string list_expr = "[1, 2]"; + auto list_result = EnrichedParse(list_expr, macros, "", options); + ASSERT_THAT(list_result, IsOk()); + auto list_offsets = list_result->enriched_source_info().offsets(); + EXPECT_EQ(list_offsets.at(1), std::make_pair(0, 5)); + + std::string map_expr = "{'a': 1}"; + auto map_result = EnrichedParse(map_expr, macros, "", options); + ASSERT_THAT(map_result, IsOk()); + auto map_offsets = map_result->enriched_source_info().offsets(); + EXPECT_EQ(map_offsets.at(1), std::make_pair(0, 7)); + + std::string msg_expr = "Msg{f: 1}"; + auto msg_result = EnrichedParse(msg_expr, macros, "", options); + ASSERT_THAT(msg_result, IsOk()); + auto msg_offsets = msg_result->enriched_source_info().offsets(); + EXPECT_EQ(msg_offsets.at(1), std::make_pair(0, 8)); +} + TEST(ExpressionTest, TsanOom) { Parse( "[[a([[???[a[[??[a([[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" @@ -1626,269 +1671,6 @@ TEST(ExpressionTest, RecursionDepthIgnoresParentheses) { EXPECT_THAT(result, IsOk()); } -const std::vector& UpdatedAccuVarTestCases() { - static const std::vector* kInstance = new std::vector{ - {"[].exists(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " false^#7:bool#,\n" - " // LoopCondition\n" - " @not_strictly_false(\n" - " !_(\n" - " __result__^#8:Expr.Ident#\n" - " )^#9:Expr.Call#\n" - " )^#10:Expr.Call#,\n" - " // LoopStep\n" - " _||_(\n" - " __result__^#11:Expr.Ident#,\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#\n" - " )^#12:Expr.Call#,\n" - " // Result\n" - " __result__^#13:Expr.Ident#)^#14:Expr.Comprehension#"}, - {"[].exists_one(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " 0^#7:int64#,\n" - " // LoopCondition\n" - " true^#8:bool#,\n" - " // LoopStep\n" - " _?_:_(\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#,\n" - " _+_(\n" - " __result__^#9:Expr.Ident#,\n" - " 1^#10:int64#\n" - " )^#11:Expr.Call#,\n" - " __result__^#12:Expr.Ident#\n" - " )^#13:Expr.Call#,\n" - " // Result\n" - " _==_(\n" - " __result__^#14:Expr.Ident#,\n" - " 1^#15:int64#\n" - " )^#16:Expr.Call#)^#17:Expr.Comprehension#"}, - {"[].all(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " true^#7:bool#,\n" - " // LoopCondition\n" - " @not_strictly_false(\n" - " __result__^#8:Expr.Ident#\n" - " )^#9:Expr.Call#,\n" - " // LoopStep\n" - " _&&_(\n" - " __result__^#10:Expr.Ident#,\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#\n" - " )^#11:Expr.Call#,\n" - " // Result\n" - " __result__^#12:Expr.Ident#)^#13:Expr.Comprehension#"}, - {"[].map(x, x + 1)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " []^#7:Expr.CreateList#,\n" - " // LoopCondition\n" - " true^#8:bool#,\n" - " // LoopStep\n" - " _+_(\n" - " __result__^#9:Expr.Ident#,\n" - " [\n" - " _+_(\n" - " x^#4:Expr.Ident#,\n" - " 1^#6:int64#\n" - " )^#5:Expr.Call#\n" - " ]^#10:Expr.CreateList#\n" - " )^#11:Expr.Call#,\n" - " // Result\n" - " __result__^#12:Expr.Ident#)^#13:Expr.Comprehension#"}, - {"[].map(x, x > 0, x + 1)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " []^#10:Expr.CreateList#,\n" - " // LoopCondition\n" - " true^#11:bool#,\n" - " // LoopStep\n" - " _?_:_(\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#,\n" - " _+_(\n" - " __result__^#12:Expr.Ident#,\n" - " [\n" - " _+_(\n" - " x^#7:Expr.Ident#,\n" - " 1^#9:int64#\n" - " )^#8:Expr.Call#\n" - " ]^#13:Expr.CreateList#\n" - " )^#14:Expr.Call#,\n" - " __result__^#15:Expr.Ident#\n" - " )^#16:Expr.Call#,\n" - " // Result\n" - " __result__^#17:Expr.Ident#)^#18:Expr.Comprehension#"}, - {"[].filter(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " []^#7:Expr.CreateList#,\n" - " // LoopCondition\n" - " true^#8:bool#,\n" - " // LoopStep\n" - " _?_:_(\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#,\n" - " _+_(\n" - " __result__^#9:Expr.Ident#,\n" - " [\n" - " x^#3:Expr.Ident#\n" - " ]^#10:Expr.CreateList#\n" - " )^#11:Expr.Call#,\n" - " __result__^#12:Expr.Ident#\n" - " )^#13:Expr.Call#,\n" - " // Result\n" - " __result__^#14:Expr.Ident#)^#15:Expr.Comprehension#"}, - // Maintain restriction on '__result__' variable name until the default is - // changed everywhere. - { - "[].map(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:20: map() variable name cannot be __result__\n" - " | [].map(__result__, true)\n" - " | ...................^", - }, - { - "[].map(__result__, true, false)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:20: map() variable name cannot be __result__\n" - " | [].map(__result__, true, false)\n" - " | ...................^", - }, - { - "[].filter(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:23: filter() variable name cannot be __result__\n" - " | [].filter(__result__, true)\n" - " | ......................^", - }, - { - "[].exists(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:23: exists() variable name cannot be __result__\n" - " | [].exists(__result__, true)\n" - " | ......................^", - }, - { - "[].all(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:20: all() variable name cannot be __result__\n" - " | [].all(__result__, true)\n" - " | ...................^", - }, - { - "[].exists_one(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:27: exists_one() variable name cannot be " - "__result__\n" - " | [].exists_one(__result__, true)\n" - " | ..........................^", - }}; - return *kInstance; -} - -class UpdatedAccuVarDisabledTest : public testing::TestWithParam {}; - -TEST_P(UpdatedAccuVarDisabledTest, Parse) { - const TestInfo& test_info = GetParam(); - ParserOptions options; - options.enable_hidden_accumulator_var = false; - if (!test_info.M.empty()) { - options.add_macro_calls = true; - } - - auto result = - EnrichedParse(test_info.I, Macro::AllMacros(), "", options); - if (test_info.E.empty()) { - EXPECT_THAT(result, IsOk()); - } else { - EXPECT_THAT(result, Not(IsOk())); - EXPECT_EQ(test_info.E, result.status().message()); - } - - if (!test_info.P.empty()) { - KindAndIdAdorner kind_and_id_adorner; - ExprPrinter w(kind_and_id_adorner); - std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.P, adorned_string) << result->parsed_expr(); - } - - if (!test_info.L.empty()) { - LocationAdorner location_adorner(result->parsed_expr().source_info()); - ExprPrinter w(location_adorner); - std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.L, adorned_string) << result->parsed_expr(); - } - - if (!test_info.R.empty()) { - EXPECT_EQ(test_info.R, ConvertEnrichedSourceInfoToString( - result->enriched_source_info())); - } - - if (!test_info.M.empty()) { - EXPECT_EQ(test_info.M, ConvertMacroCallsToString( - result.value().parsed_expr().source_info())) - << result->parsed_expr(); - } -} - TEST(NewParserBuilderTest, Defaults) { auto builder = cel::NewParserBuilder(); ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); @@ -1969,6 +1751,153 @@ TEST(NewParserBuilderTest, ForwardsOptions) { StatusIs(absl::StatusCode::kInvalidArgument)); } +TEST(NewParserBuilderTest, ToBuilderCopiesConfig) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_optional_syntax = true; + builder->GetOptions().disable_standard_macros = true; + ASSERT_THAT(builder->AddLibrary({"custom_lib", + [](cel::ParserBuilder& b) { + return b.AddMacro(cel::HasMacro()); + }}), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + auto derived_builder = parser->ToBuilder(); + EXPECT_TRUE(derived_builder->GetOptions().enable_optional_syntax); + + ASSERT_OK_AND_ASSIGN(auto derived_parser, + std::move(*derived_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a.?b && has(a.b)")); + ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); + EXPECT_FALSE(ast->IsChecked()); +} + +TEST(NewParserBuilderTest, ToBuilderHandlesStdlibAndOptionalByLibrary) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().disable_standard_macros = true; + builder->GetOptions().enable_optional_syntax = false; + + // Abusing the library ids for testing. Real uses should use subsetting. + ASSERT_THAT( + builder->AddLibrary( + {"stdlib", [](cel::ParserBuilder& b) { return absl::OkStatus(); }}), + IsOk()); + ASSERT_THAT( + builder->AddLibrary( + {"optional", [](cel::ParserBuilder& b) { return absl::OkStatus(); }}), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + auto derived_builder = parser->ToBuilder(); + // Should be ignored now. + derived_builder->GetOptions().disable_standard_macros = false; + derived_builder->GetOptions().enable_optional_syntax = true; + + ASSERT_OK_AND_ASSIGN(auto derived_parser, + std::move(*derived_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b)")); + ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); + + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + EXPECT_EQ(w.Print(ast->root_expr()), + "has(\n" + " a^#2:Expr.Ident#.b^#3:Expr.Select#\n" + ")^#1:Expr.Call#"); +} + +TEST(NewParserBuilderTest, ToBuilderPreservesStdlibAndOptionalFromOptions) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().disable_standard_macros = false; + builder->GetOptions().enable_optional_syntax = true; + + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + auto derived_builder = parser->ToBuilder(); + + ASSERT_OK_AND_ASSIGN(auto derived_parser, + std::move(*derived_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b) && [?a]")); + ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); + EXPECT_FALSE(ast->IsChecked()); +} + +struct VariadicLogicalOperatorsTestCase { + std::string input; + std::string expected_adorned_string; +}; + +class VariadicLogicalOperatorsTest + : public testing::TestWithParam {}; + +TEST_P(VariadicLogicalOperatorsTest, Parse) { + const auto& test_case = GetParam(); + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource(test_case.input)); + ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); + + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + std::string adorned_string = w.Print(ast->root_expr()); + EXPECT_EQ(adorned_string, test_case.expected_adorned_string); +} + +INSTANTIATE_TEST_SUITE_P( + VariadicLogicalOperators, VariadicLogicalOperatorsTest, + testing::Values( + VariadicLogicalOperatorsTestCase{ + .input = "a && b && c && d", + .expected_adorned_string = "_&&_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " c^#4:Expr.Ident#,\n" + " d^#6:Expr.Ident#\n" + ")^#3:Expr.Call#"}, + VariadicLogicalOperatorsTestCase{ + .input = "a || b || c || d", + .expected_adorned_string = "_||_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " c^#4:Expr.Ident#,\n" + " d^#6:Expr.Ident#\n" + ")^#3:Expr.Call#"}, + VariadicLogicalOperatorsTestCase{ + .input = "a && b && (c || d || e)", + .expected_adorned_string = "_&&_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " _||_(\n" + " c^#4:Expr.Ident#,\n" + " d^#5:Expr.Ident#,\n" + " e^#7:Expr.Ident#\n" + " )^#6:Expr.Call#\n" + ")^#3:Expr.Call#"})); + +TEST(ParserTest, ParseFailurePopulatesIssues) { + auto builder = cel::NewParserBuilder(); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a +", "test.cel")); + std::vector issues; + auto ast_result = parser->Parse(*source, &issues); + EXPECT_THAT(ast_result, Not(IsOk())); + ASSERT_THAT(issues, testing::SizeIs(1)); + EXPECT_THAT(ast_result.status().message(), + HasSubstr("ERROR: test.cel:1:4: Syntax error: mismatched input " + "'' expecting")); + EXPECT_THAT(issues[0].message(), + HasSubstr("Syntax error: mismatched input '' expecting")); + EXPECT_EQ(issues[0].location().line, 1); + // 0-based, but adjusted to 1-based in error message. + EXPECT_EQ(issues[0].location().column, 3); +} + std::string TestName(const testing::TestParamInfo& test_info) { std::string name = absl::StrCat(test_info.index, "-", test_info.param.I); absl::c_replace_if(name, [](char c) { return !absl::ascii_isalnum(c); }, '_'); @@ -1979,9 +1908,5 @@ std::string TestName(const testing::TestParamInfo& test_info) { INSTANTIATE_TEST_SUITE_P(CelParserTest, ExpressionTest, testing::ValuesIn(test_cases), TestName); -INSTANTIATE_TEST_SUITE_P(UpdatedAccuVarTest, UpdatedAccuVarDisabledTest, - testing::ValuesIn(UpdatedAccuVarTestCases()), - TestName); - } // namespace } // namespace google::api::expr::parser diff --git a/policy/BUILD b/policy/BUILD new file mode 100644 index 000000000..19195be2b --- /dev/null +++ b/policy/BUILD @@ -0,0 +1,239 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "cel_policy", + srcs = [ + "cel_policy.cc", + ], + hdrs = [ + "cel_policy.h", + ], + deps = [ + "//common:source", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "cel_policy_test", + srcs = ["cel_policy_test.cc"], + deps = [ + ":cel_policy", + "//common:source", + "//internal:testing", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "cel_policy_parser", + srcs = [ + "cel_policy_parse_context.cc", + "cel_policy_parse_result.cc", + ], + hdrs = [ + "cel_policy_parse_context.h", + "cel_policy_parse_result.h", + "cel_policy_parser.h", + ], + deps = [ + ":cel_policy", + "//common:source", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "yaml_policy_parser", + srcs = [ + "yaml_policy_parser.cc", + ], + hdrs = ["yaml_policy_parser.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + ":cel_policy", + ":cel_policy_parser", + "//common:source", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@yaml-cpp", + ], +) + +cc_library( + name = "cel_policy_validation_result", + srcs = [ + "cel_policy_validation_result.cc", + ], + hdrs = [ + "cel_policy_validation_result.h", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + "//common:ast", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "compiler", + srcs = ["compiler.cc"], + hdrs = ["compiler.h"], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":cel_policy_validation_result", + "//checker:type_check_issue", + "//checker:validation_result", + "//common:ast", + "//common:ast_rewrite", + "//common:constant", + "//common:container", + "//common:decl", + "//common:expr", + "//common:format_type_name", + "//common:navigable_ast", + "//common:source", + "//common:type", + "//common:type_kind", + "//compiler", + "//internal:status_macros", + "//policy/internal:issue_reporter", + "//policy/internal:optimizer_expr_factory", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "yaml_policy_parser_test", + srcs = [ + "test_custom_yaml_policy_parser.cc", + "yaml_policy_parser_test.cc", + ], + data = [ + "//policy/testdata:policy_testdata", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":yaml_policy_parser", + "//common:source", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@yaml-cpp", + ], +) + +cc_test( + name = "compiler_test", + srcs = ["compiler_test.cc"], + data = [ + "//policy/testdata:policy_testdata", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":cel_policy_validation_result", + ":compiler", + ":yaml_policy_parser", + "//common:ast", + "//common:decl", + "//common:navigable_ast", + "//common:source", + "//common:type", + "//common:value", + "//common:value_testing", + "//compiler", + "//compiler:compiler_factory", + "//compiler:optional", + "//compiler:standard_library", + "//extensions:bindings_ext", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "//runtime:optional_types", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "test_util", + testonly = True, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:struct_cc_proto", + "@yaml-cpp", + ], +) diff --git a/policy/cel_policy.cc b/policy/cel_policy.cc new file mode 100644 index 000000000..c2d97edeb --- /dev/null +++ b/policy/cel_policy.cc @@ -0,0 +1,273 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "common/source.h" + +namespace cel { + +namespace { + +std::string IdDebugString(CelPolicyElementId id) { + if (id == -1) { + return ""; + } + return absl::StrCat("#", id, "> "); +} + +std::string IndentBlock(absl::string_view text) { + if (text.empty()) { + return ""; + } + std::vector lines; + for (absl::string_view line : absl::StrSplit(text, '\n')) { + if (line.empty()) { + lines.push_back(""); + } else { + lines.push_back(absl::StrCat(" ", line)); + } + } + return absl::StrJoin(lines, "\n"); +} + +} // namespace + +void CelPolicySource::NoteSourcePosition(CelPolicyElementId id, + SourcePosition position) { + source_positions_[id] = position; +} + +std::optional CelPolicySource::GetSourcePosition( + CelPolicyElementId id) const { + auto it = source_positions_.find(id); + if (it == source_positions_.end()) { + return std::nullopt; + } + return it->second; +} + +std::optional CelPolicySource::GetSourceLocation( + CelPolicyElementId id) const { + auto it = source_positions_.find(id); + if (it == source_positions_.end()) { + return std::nullopt; + } + return policy_source_->GetLocation(it->second); +} + +std::string CelPolicySource::DebugString() const { + std::string result; + + // Sort the source elements in descending order of position + std::vector> sorted_positions; + for (const auto& pair : source_positions_) { + sorted_positions.push_back(pair); + } + std::sort(sorted_positions.begin(), sorted_positions.end(), + [](const auto& a, const auto& b) { + if (a.second == b.second) { + return a.first < b.first; + } + return a.second > b.second; + }); + + result = policy_source_->content().ToString(); + for (const auto& [id, position] : sorted_positions) { + result.insert(position, IdDebugString(id)); + } + return result; +} + +std::string ValueString::DebugString() const { + return absl::StrCat(IdDebugString(id_), "\"", value_, "\""); +} + +std::string Import::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "name: ", name_.DebugString()); + return result; +} + +std::string OutputBlock::DebugString() const { + std::string result; + absl::StrAppend(&result, "output: ", output_.DebugString()); + if (explanation_.has_value()) { + absl::StrAppend(&result, "\nexplanation: ", explanation_->DebugString()); + } + return result; +} + +Match::Match(const Match& other) + : id_(other.id_), condition_(other.condition_) { + if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else { + result_ = + std::make_unique(*std::get>(other.result_)); + } +} + +Match& Match::operator=(const Match& other) { + if (this != &other) { + id_ = other.id_; + condition_ = other.condition_; + if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else { + result_ = std::make_unique( + *std::get>(other.result_)); + } + } + return *this; +} + +std::string Match::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "match: {\n"); + if (condition_.has_value()) { + absl::StrAppend(&result, " condition: ", condition_->DebugString(), "\n"); + } + if (has_rule()) { + absl::StrAppend(&result, " result:\n", + IndentBlock(IndentBlock(rule().DebugString())), "\n"); + } else { + absl::StrAppend(&result, " result: {\n", + IndentBlock(IndentBlock(output_block().DebugString())), + "\n }\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string Variable::DebugString() const { + std::string result; + absl::StrAppend(&result, "variable: {\n"); + absl::StrAppend(&result, " name: ", name_.DebugString(), "\n"); + absl::StrAppend(&result, " expression: ", expression_.DebugString(), "\n"); + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + if (display_name_.has_value()) { + absl::StrAppend(&result, " display_name: ", display_name_->DebugString(), + "\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string Rule::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "rule: {\n"); + if (rule_id_.has_value()) { + absl::StrAppend(&result, " rule_id: ", rule_id_->DebugString(), "\n"); + } + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + for (const Variable& variable : variables_) { + absl::StrAppend(&result, IndentBlock(variable.DebugString()), "\n"); + } + for (const Match& match : matches_) { + absl::StrAppend(&result, IndentBlock(match.DebugString()), "\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string MetadataValueDebugString(std::any value) { + if (value.type() == typeid(std::monostate)) { + return "null"; + } + if (value.type() == typeid(ValueString)) { + return std::any_cast(value).DebugString(); + } + if (value.type() == typeid(bool)) { + return std::any_cast(value) ? "true" : "false"; + } + if (value.type() == typeid(int)) { + return absl::StrCat(std::any_cast(value)); + } + if (value.type() == typeid(std::string)) { + return std::any_cast(value); + } + return absl::StrCat("typeid: ", value.type().name()); +} + +std::string CelPolicy::DebugString() const { + std::string result; + absl::StrAppend(&result, "CelPolicy{\n"); + absl::StrAppend( + &result, + " ===========================================================\n"); + absl::StrAppend(&result, IndentBlock(IndentBlock(source_->DebugString())), + "\n"); + absl::StrAppend( + &result, + " ===========================================================\n"); + absl::StrAppend(&result, " name: ", name_.DebugString(), "\n"); + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + if (display_name_.has_value()) { + absl::StrAppend(&result, " display_name: ", display_name_->DebugString(), + "\n"); + } + if (!metadata_.empty()) { + std::vector sorted_keys; + for (const auto& [key, _] : metadata_) { + sorted_keys.push_back(key); + } + std::sort(sorted_keys.begin(), sorted_keys.end()); + + absl::StrAppend(&result, " metadata: {\n"); + for (const auto& key : sorted_keys) { + const auto& value = metadata_.at(key); + absl::StrAppend(&result, " ", key, ": ", + MetadataValueDebugString(value), "\n"); + } + absl::StrAppend(&result, " }\n"); + } + if (!imports_.empty()) { + absl::StrAppend(&result, " imports:\n"); + for (const Import& import : imports_) { + absl::StrAppend(&result, " ", import.DebugString(), "\n"); + } + } + absl::StrAppend(&result, IndentBlock(rule_.DebugString()), "\n"); + absl::StrAppend(&result, "}"); + return result; +} + +} // namespace cel diff --git a/policy/cel_policy.h b/policy/cel_policy.h new file mode 100644 index 000000000..af8f7c977 --- /dev/null +++ b/policy/cel_policy.h @@ -0,0 +1,320 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "common/source.h" + +namespace cel { + +using CelPolicyElementId = int32_t; + +class CelPolicySource { + public: + explicit CelPolicySource(cel::SourcePtr policy_source) + : policy_source_(std::move(policy_source)) {} + + const Source* absl_nonnull content() const { return policy_source_.get(); } + + void NoteSourcePosition(CelPolicyElementId id, SourcePosition position); + + std::optional GetSourcePosition(CelPolicyElementId id) const; + + std::optional GetSourceLocation(CelPolicyElementId id) const; + + std::string DebugString() const; + + private: + cel::SourcePtr policy_source_; + absl::flat_hash_map source_positions_; +}; + +class ValueString { + public: + ValueString() : id_(-1) {} + + explicit ValueString(CelPolicyElementId id, absl::string_view value) + : id_(id), value_(value) {} + + CelPolicyElementId id() const { return id_; } + absl::string_view value() const { return value_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_; + std::string value_; +}; + +class Import { + public: + Import(CelPolicyElementId id, ValueString name) + : id_(id), name_(std::move(name)) {} + CelPolicyElementId id() const { return id_; } + const ValueString& name() const { return name_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_; + ValueString name_; +}; + +// Defines a variable that can be used in CEL expressions within the policy. +// Variables are evaluated once and stored in the activation context. +class Variable { + public: + const ValueString& name() const { return name_; } + void set_name(ValueString name) { name_ = std::move(name); } + + const ValueString& expression() const { return expression_; } + void set_expression(ValueString expression) { + expression_ = std::move(expression); + } + + std::optional description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + + std::optional display_name() const { return display_name_; } + void set_display_name(ValueString display_name) { + display_name_ = std::move(display_name); + } + + std::string DebugString() const; + + private: + ValueString name_; + ValueString expression_; + std::optional description_; + std::optional display_name_; +}; + +class Rule; + +class OutputBlock { + public: + OutputBlock() = default; + OutputBlock(ValueString output, std::optional explanation) + : output_(std::move(output)), explanation_(std::move(explanation)) {} + + const ValueString& output() const { return output_; } + void set_output(ValueString output) { output_ = std::move(output); } + + const std::optional& explanation() const { return explanation_; } + void set_explanation(ValueString explanation) { + explanation_ = std::move(explanation); + } + + std::string DebugString() const; + + private: + ValueString output_; + std::optional explanation_; +}; + +// Defines a match condition and result. +// If the result is a Rule, it is considered a sub-rule and will be evaluated +// only if the match condition evaluates to true. +class Match { + public: + Match() = default; + Match(const Match& other); + Match& operator=(const Match& other); + + CelPolicyElementId id() const; + void set_id(CelPolicyElementId id); + + bool has_condition() const; + std::optional condition() const; + void set_condition(ValueString condition); + + bool has_output_block() const; + const OutputBlock& output_block() const; + OutputBlock& mutable_output_block(); + + bool has_rule() const; + const Rule& rule() const; + Rule& mutable_rule(); + + void set_result(OutputBlock result); + void set_result(std::unique_ptr result); + + std::string DebugString() const; + + private: + CelPolicyElementId id_ = -1; + std::optional condition_; + std::variant> result_; +}; + +// Rule is the body of the policy and contains a list of variables and matches. +// Variables are evaluated once and stored in the activation context. +// Matches are evaluated in order and the first match is returned. If the +// match contains a sub-rule, the sub-rule is evaluated only if the match +// condition evaluates to true. +class Rule { + public: + Rule() = default; + Rule(const Rule& other) = default; + + CelPolicyElementId id() const { return id_; } + void set_id(CelPolicyElementId id) { id_ = id; } + + const std::optional& rule_id() const { return rule_id_; } + void set_rule_id(ValueString rule_id) { rule_id_ = std::move(rule_id); } + + const std::optional& description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + + const std::vector& variables() const { return variables_; } + std::vector& mutable_variables() { return variables_; } + + const std::vector& matches() const { return matches_; } + std::vector& mutable_matches() { return matches_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_ = -1; + std::optional rule_id_; + std::optional description_; + std::vector variables_; + std::vector matches_; +}; + +// CelPolicy is the top-level policy object. +// It contains a source, name, description, display name, imports, and a rule. +// The source is the CEL policy source code. +// The name, description, and display name are metadata about the policy. +// The rule is the main body of the policy. +class CelPolicy { + public: + explicit CelPolicy(std::shared_ptr source) + : source_(std::move(source)) {} + + CelPolicy(const CelPolicy& other) = default; + CelPolicy& operator=(const CelPolicy& other) = default; + + const CelPolicySource* absl_nullable source() const { return source_.get(); } + const std::shared_ptr& source_ptr() const { return source_; } + + const ValueString& name() const { return name_; } + void set_name(ValueString name) { name_ = std::move(name); } + + std::optional description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + std::optional display_name() const { return display_name_; } + void set_display_name(ValueString display_name) { + display_name_ = std::move(display_name); + } + const absl::flat_hash_map& metadata() const { + return metadata_; + } + absl::flat_hash_map& mutable_metadata() { + return metadata_; + } + const std::vector& imports() const { return imports_; } + std::vector& mutable_imports() { return imports_; } + + const Rule& rule() const { return rule_; } + Rule& mutable_rule() { return rule_; } + + std::string DebugString() const; + + private: + std::shared_ptr source_; + ValueString name_; + std::optional description_; + std::optional display_name_; + absl::flat_hash_map metadata_; + std::vector imports_; + Rule rule_; +}; + +// Implementation details. + +inline CelPolicyElementId Match::id() const { return id_; } +inline void Match::set_id(CelPolicyElementId id) { id_ = id; } + +inline bool Match::has_condition() const { return condition_.has_value(); } + +inline std::optional Match::condition() const { + return condition_; +} + +inline void Match::set_condition(ValueString condition) { + condition_ = std::move(condition); +} + +inline bool Match::has_output_block() const { + return std::holds_alternative(result_); +} + +inline const OutputBlock& Match::output_block() const { + ABSL_DCHECK(std::holds_alternative(result_)); + return std::get(result_); +} + +inline OutputBlock& Match::mutable_output_block() { + if (!std::holds_alternative(result_)) { + result_ = OutputBlock(); + } + return std::get(result_); +} + +inline bool Match::has_rule() const { + return std::holds_alternative>(result_); +} + +inline const Rule& Match::rule() const { + ABSL_DCHECK(std::holds_alternative>(result_)); + return *std::get>(result_); +} + +inline Rule& Match::mutable_rule() { + ABSL_DCHECK(std::holds_alternative>(result_)); + return *std::get>(result_); +} + +inline void Match::set_result(OutputBlock result) { + result_ = std::move(result); +} + +inline void Match::set_result(std::unique_ptr result) { + result_ = std::move(result); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ diff --git a/policy/cel_policy_parse_context.cc b/policy/cel_policy_parse_context.cc new file mode 100644 index 000000000..66861d085 --- /dev/null +++ b/policy/cel_policy_parse_context.cc @@ -0,0 +1,49 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_parse_context.h" + +#include +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +CelPolicy& CelPolicyParseContext::policy() const { + ABSL_CHECK(policy_ != nullptr) + << "CelPolicyParseContext::policy() called after GetResult()"; + return *policy_; +} + +CelPolicyParseResult CelPolicyParseContext::GetResult() { + if (policy_ != nullptr && issues_.empty()) { + return CelPolicyParseResult(std::move(policy_source_), std::move(policy_), + std::move(issues_)); + } + policy_.reset(); + return CelPolicyParseResult(std::move(policy_source_), nullptr, + std::move(issues_)); +} + +void CelPolicyParseContext::ReportError(CelPolicyElementId element_id, + std::string_view message) { + issues_.push_back(CelPolicyIssue(element_id, std::string(message))); +} + +} // namespace cel diff --git a/policy/cel_policy_parse_context.h b/policy/cel_policy_parse_context.h new file mode 100644 index 000000000..6482fa1ae --- /dev/null +++ b/policy/cel_policy_parse_context.h @@ -0,0 +1,65 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +// A mutable context for parsing a CelPolicy. An instance of this class is +// created for each policy parse and is passed to the parser, which is meant to +// be stateless. +// +// Parsers call methods on this class to report issues and populate the policy +// being parsed. Call GetResult() to obtain the resulting CelPolicyParseResult, +// which takes ownership of the parsed policy. Do not use the context after +// calling GetResult(). +class CelPolicyParseContext { + public: + explicit CelPolicyParseContext(std::shared_ptr policy_source) + : policy_source_(std::move(policy_source)), + policy_(std::make_unique(policy_source_)) {} + + CelPolicySource& policy_source() const { return *policy_source_; } + + // Returns the policy being parsed. It should not be used after + // calling GetResult(). + CelPolicy& policy() const; + + // The context should not be used after calling GetResult(). + CelPolicyParseResult GetResult(); + + // Reports an error for the given element with the given error message. + void ReportError(CelPolicyElementId id, std::string_view message); + + CelPolicyElementId next_element_id() { return next_element_id_++; } + + private: + std::shared_ptr policy_source_; + CelPolicyElementId next_element_id_ = 0; + std::vector issues_; + std::unique_ptr policy_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ diff --git a/policy/cel_policy_parse_result.cc b/policy/cel_policy_parse_result.cc new file mode 100644 index 000000000..32d6431bb --- /dev/null +++ b/policy/cel_policy_parse_result.cc @@ -0,0 +1,91 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_parse_result.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel { +namespace { + +absl::string_view SeverityString(CelPolicyIssue::Severity severity) { + switch (severity) { + case CelPolicyIssue::Severity::kInformation: + return "INFORMATION"; + case CelPolicyIssue::Severity::kWarning: + return "WARNING"; + case CelPolicyIssue::Severity::kError: + return "ERROR"; + case CelPolicyIssue::Severity::kDeprecated: + return "DEPRECATED"; + default: + return "SEVERITY_UNSPECIFIED"; + } +} + +} // namespace + +std::string CelPolicyIssue::ToDisplayString( + const CelPolicySource* absl_nullable source) const { + SourceLocation location; + std::string description; + std::string snippet; + if (source != nullptr) { + if (relative_position_) { + std::optional base = + source->GetSourcePosition(element_id_); + if (element_id_ == -1) { + base.emplace(0); + } + if (base) { + location = source->content() + ->GetLocation(*base + *relative_position_) + .value_or(SourceLocation{}); + } + } else { + location = + source->GetSourceLocation(element_id_).value_or(SourceLocation{}); + } + description = std::string(source->content()->description()); + snippet = source->content()->DisplayErrorLocation(location); + } + + const int display_column = location.column >= 0 ? location.column + 1 : -1; + + return absl::StrFormat("%s: %s:%d:%d: %s%s", SeverityString(severity_), + description, location.line, display_column, message_, + snippet); +} + +std::string CelPolicyParseResult::FormattedIssues() const { + std::string formatted_issues; + for (const CelPolicyIssue& issue : issues_) { + if (!formatted_issues.empty()) { + absl::StrAppend(&formatted_issues, "\n"); + } + absl::StrAppend(&formatted_issues, issue.ToDisplayString(*policy_source_)); + } + return formatted_issues; +} + +} // namespace cel diff --git a/policy/cel_policy_parse_result.h b/policy/cel_policy_parse_result.h new file mode 100644 index 000000000..2bf80b1ce --- /dev/null +++ b/policy/cel_policy_parse_result.h @@ -0,0 +1,105 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel { + +class CelPolicyIssue { + public: + enum class Severity { kInformation, kDeprecated, kWarning, kError }; + + CelPolicyIssue(CelPolicyElementId element_id, absl::string_view message) + : element_id_(element_id), message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, Severity severity, + absl::string_view message) + : element_id_(element_id), severity_(severity), message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, + SourcePosition relative_position, absl::string_view message) + : element_id_(element_id), + relative_position_(relative_position), + message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, + SourcePosition relative_position, Severity severity, + absl::string_view message) + : element_id_(element_id), + relative_position_(relative_position), + severity_(severity), + message_(message) {} + + std::string ToDisplayString( + const CelPolicySource* absl_nullable source) const; + std::string ToDisplayString(const CelPolicySource& source) const { + return ToDisplayString(&source); + } + + Severity severity() const { return severity_; } + absl::string_view message() const { return message_; } + + private: + CelPolicyElementId element_id_; + std::optional relative_position_; + Severity severity_ = Severity::kError; + std::string message_; +}; + +class CelPolicyParseResult { + public: + explicit CelPolicyParseResult(std::shared_ptr policy_source, + std::unique_ptr policy, + std::vector issues) + : policy_source_(std::move(policy_source)), + policy_(std::move(policy)), + issues_(std::move(issues)) {} + + bool IsValid() const { return policy_ != nullptr; } + + const CelPolicy* absl_nullable GetPolicy() const { return policy_.get(); } + + absl::StatusOr> ReleasePolicy() { + if (policy_ == nullptr) { + return absl::FailedPreconditionError( + "CelPolicyParseResult is empty. Check for Issues."); + } + return std::move(policy_); + } + + absl::Span GetIssues() const { return issues_; } + + std::string FormattedIssues() const; + + private: + std::shared_ptr policy_source_; + absl_nullable std::unique_ptr policy_; + std::vector issues_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ diff --git a/policy/cel_policy_parser.h b/policy/cel_policy_parser.h new file mode 100644 index 000000000..0a11c9e68 --- /dev/null +++ b/policy/cel_policy_parser.h @@ -0,0 +1,40 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ + +#include "absl/status/status.h" +#include "policy/cel_policy_parse_context.h" + +namespace cel { + +// A policy parser for a given policy format. The type `T` parameter is the +// representation of the input file format, such as `` for YAML. +// +// Parsers are intended to be stateless: all state, including the resulting +// policy and any issues encountered, should be kept in the context passed to +// the `ParsePolicy` method. +template +class CelPolicyParser { + public: + virtual ~CelPolicyParser() = default; + + // Parses the input and populates a CelPolicy in the context. + virtual absl::Status ParsePolicy(CelPolicyParseContext& ctx) const = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ diff --git a/policy/cel_policy_test.cc b/policy/cel_policy_test.cc new file mode 100644 index 000000000..640247e7f --- /dev/null +++ b/policy/cel_policy_test.cc @@ -0,0 +1,220 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy.h" + +#include +#include +#include +#include + +#include "absl/strings/str_replace.h" +#include "common/source.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using testing::Field; +using testing::Optional; +using testing::SizeIs; + +TEST(CelPolicyBuilderTest, Build) { + CelPolicyElementId next_id = 1; + ASSERT_OK_AND_ASSIGN(SourcePtr source, NewSource("CEL\n policy\n source")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + CelPolicy policy(policy_source); + policy.set_name(ValueString(next_id++, "test_policy")); + policy.set_description(ValueString(next_id++, "test_description")); + policy.set_display_name(ValueString(next_id++, "test_display_name")); + ValueString import1_name = ValueString(next_id++, "test_import1"); + policy.mutable_imports().push_back(Import(next_id++, import1_name)); + ValueString import2_name = ValueString(next_id++, "test_import2"); + policy.mutable_imports().push_back(Import(next_id++, import2_name)); + + Rule& rule = policy.mutable_rule(); + rule.set_id(next_id++); + rule.set_rule_id(ValueString(next_id++, "test_rule_id")); + rule.set_description(ValueString(next_id++, "test_rule_description")); + + Variable variable; + variable.set_name(ValueString(next_id++, "test_variable")); + variable.set_expression(ValueString(next_id++, "test_expression")); + variable.set_description(ValueString(next_id++, "test_variable_description")); + variable.set_display_name( + ValueString(next_id++, "test_variable_display_name")); + + Match match1; + match1.set_id(next_id++); + match1.set_condition(ValueString(next_id++, "test_condition")); + CelPolicyElementId output_id = next_id++; + CelPolicyElementId explanation_id = next_id++; + match1.set_result( + OutputBlock(ValueString(output_id, "test_result"), + ValueString(explanation_id, "test_explanation"))); + + Match match2; + match2.set_id(next_id++); + match2.set_condition(ValueString(next_id++, "test_condition2")); + + auto sub_rule = std::make_unique(); + sub_rule->set_id(next_id++); + sub_rule->set_rule_id(ValueString(next_id++, "sub_rule_id")); + sub_rule->set_description(ValueString(next_id++, "sub_rule_description")); + Match sub_rule_match; + sub_rule_match.set_id(next_id++); + sub_rule_match.set_condition(ValueString(next_id++, "sub_rule_condition")); + sub_rule_match.set_result( + OutputBlock(ValueString(next_id++, "sub_rule_result"), std::nullopt)); + sub_rule->mutable_matches().push_back(sub_rule_match); + + match2.set_result(std::move(sub_rule)); + + rule.mutable_variables().push_back(variable); + rule.mutable_matches().push_back(match1); + rule.mutable_matches().push_back(match2); + + EXPECT_EQ(policy.name().value(), "test_policy"); + ASSERT_TRUE(policy.description().has_value()); + EXPECT_EQ(policy.description()->value(), "test_description"); + ASSERT_TRUE(policy.display_name().has_value()); + EXPECT_EQ(policy.display_name()->value(), "test_display_name"); + + ASSERT_THAT(policy.imports(), SizeIs(2)); + + EXPECT_EQ(policy.imports()[0].name().value(), "test_import1"); + EXPECT_EQ(policy.imports()[1].name().value(), "test_import2"); + ASSERT_TRUE(policy.rule().rule_id().has_value()); + EXPECT_EQ(policy.rule().rule_id()->value(), "test_rule_id"); + ASSERT_TRUE(policy.rule().description().has_value()); + EXPECT_EQ(policy.rule().description()->value(), "test_rule_description"); + + ASSERT_THAT(policy.rule().variables(), SizeIs(1)); + + EXPECT_EQ(policy.rule().variables()[0].name().value(), "test_variable"); + EXPECT_EQ(policy.rule().variables()[0].expression().value(), + "test_expression"); + ASSERT_TRUE(policy.rule().variables()[0].description().has_value()); + EXPECT_EQ(policy.rule().variables()[0].description()->value(), + "test_variable_description"); + ASSERT_TRUE(policy.rule().variables()[0].display_name().has_value()); + EXPECT_EQ(policy.rule().variables()[0].display_name()->value(), + "test_variable_display_name"); + + ASSERT_THAT(policy.rule().matches(), SizeIs(2)); + + EXPECT_EQ(policy.rule().matches()[0].condition().value().value(), + "test_condition"); + ASSERT_TRUE(policy.rule().matches()[0].has_output_block()); + EXPECT_EQ(policy.rule().matches()[0].output_block().output().value(), + "test_result"); + ASSERT_TRUE( + policy.rule().matches()[0].output_block().explanation().has_value()); + EXPECT_EQ(policy.rule().matches()[0].output_block().explanation()->value(), + "test_explanation"); + + EXPECT_EQ(policy.rule().matches()[1].condition().value().value(), + "test_condition2"); + ASSERT_TRUE(policy.rule().matches()[1].has_rule()); + ASSERT_TRUE(policy.rule().matches()[1].rule().rule_id().has_value()); + EXPECT_EQ(policy.rule().matches()[1].rule().rule_id()->value(), + "sub_rule_id"); + ASSERT_TRUE(policy.rule().matches()[1].rule().description().has_value()); + EXPECT_EQ(policy.rule().matches()[1].rule().description()->value(), + "sub_rule_description"); + ASSERT_THAT(policy.rule().matches()[1].rule().matches(), SizeIs(1)); + EXPECT_EQ(policy.rule() + .matches()[1] + .rule() + .matches()[0] + .condition() + .value() + .value(), + "sub_rule_condition"); + + std::string actual = policy.DebugString(); + EXPECT_EQ(actual, absl::StrReplaceAll(R"(CelPolicy{ + =========================================================== + CEL + policy + source + =========================================================== + name: #1> "test_policy" + description: #2> "test_description" + display_name: #3> "test_display_name" + imports: + #5> name: #4> "test_import1" + #7> name: #6> "test_import2" + #8> rule: { + rule_id: #9> "test_rule_id" + description: #10> "test_rule_description" + variable: { + name: #11> "test_variable" + expression: #12> "test_expression" + description: #13> "test_variable_description" + display_name: #14> "test_variable_display_name" + } + #15> match: { + condition: #16> "test_condition" + result: { + output: #17> "test_result" + explanation: #18> "test_explanation" + } + } + #19> match: { + condition: #20> "test_condition2" + result: + #21> rule: { + rule_id: #22> "sub_rule_id" + description: #23> "sub_rule_description" + #24> match: { + condition: #25> "sub_rule_condition" + result: { + output: #26> "sub_rule_result" + } + } + } + } + } + })", + {{"\n ", "\n"}})); +} + +TEST(CelPolicySourceTest, Build) { + std::string source = + "name: test_policy\n imports:\n - name: test_import\n"; + + ASSERT_OK_AND_ASSIGN(SourcePtr source_ptr, NewSource(source)); + CelPolicySource policy_source(std::move(source_ptr)); + policy_source.NoteSourcePosition(1, source.find("test_policy")); + policy_source.NoteSourcePosition(2, source.find("test_import")); + + EXPECT_THAT(policy_source.GetSourcePosition(1), Optional(6)); + EXPECT_THAT(policy_source.GetSourceLocation(1), + Optional(AllOf(Field(&SourceLocation::line, 1), + Field(&SourceLocation::column, 6)))); + EXPECT_THAT(policy_source.GetSourcePosition(2), Optional(44)); + EXPECT_THAT(policy_source.GetSourceLocation(2), + Optional(AllOf(Field(&SourceLocation::line, 3), + Field(&SourceLocation::column, 13)))); + EXPECT_EQ(policy_source.GetSourcePosition(3), std::nullopt); + EXPECT_EQ(policy_source.GetSourceLocation(3), std::nullopt); + EXPECT_EQ( + policy_source.DebugString(), + "name: #1> test_policy\n imports:\n - name: #2> test_import\n"); +} + +} // namespace +} // namespace cel diff --git a/policy/cel_policy_validation_result.cc b/policy/cel_policy_validation_result.cc new file mode 100644 index 000000000..e257f064c --- /dev/null +++ b/policy/cel_policy_validation_result.cc @@ -0,0 +1,32 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_validation_result.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +std::string CelPolicyValidationResult::FormatIssues() const { + return absl::StrJoin( + issues_, "\n", [this](std::string* out, const CelPolicyIssue& issue) { + absl::StrAppend(out, issue.ToDisplayString(source_.get())); + }); +} + +} // namespace cel diff --git a/policy/cel_policy_validation_result.h b/policy/cel_policy_validation_result.h new file mode 100644 index 000000000..bddb9a3ca --- /dev/null +++ b/policy/cel_policy_validation_result.h @@ -0,0 +1,84 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +// CelPolicyValidationResult holds the result of policy compilation. +// +// Policy compilation/validation errors are captured in issues. +class CelPolicyValidationResult { + public: + CelPolicyValidationResult( + std::unique_ptr ast, std::vector issues, + std::shared_ptr source = nullptr) + : ast_(std::move(ast)), + issues_(std::move(issues)), + source_(std::move(source)) {} + + explicit CelPolicyValidationResult( + std::vector issues, + std::shared_ptr source = nullptr) + : ast_(nullptr), issues_(std::move(issues)), source_(std::move(source)) {} + + // Returns true if validation succeeded and an AST is present. + bool IsValid() const { return ast_ != nullptr; } + + // Returns the AST if validation was successful. + const Ast* absl_nullable GetAst() const { return ast_.get(); } + + // Moves out and returns the AST. + absl::StatusOr> ReleaseAst() { + if (ast_ == nullptr) { + return absl::FailedPreconditionError( + "CelPolicyValidationResult is empty. Check for CelPolicyIssues."); + } + return std::move(ast_); + } + + // Returns the list of issues encountered during compilation. + absl::Span GetIssues() const { return issues_; } + + // Returns the contained policy source, if any. + const CelPolicySource* absl_nullable GetSource() const { + return source_.get(); + } + + // Returns a formatted error string of the compiled issues. + std::string FormatIssues() const; + + private: + absl_nullable std::unique_ptr ast_; + std::vector issues_; + std::shared_ptr source_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ diff --git a/policy/compiler.cc b/policy/compiler.cc new file mode 100644 index 000000000..7a892447c --- /dev/null +++ b/policy/compiler.cc @@ -0,0 +1,1058 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/compiler.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/ast_rewrite.h" +#include "common/constant.h" +#include "common/container.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/format_type_name.h" +#include "common/navigable_ast.h" +#include "common/source.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "compiler/compiler.h" +#include "internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/internal/issue_reporter.h" +#include "policy/internal/optimizer_expr_factory.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +constexpr absl::string_view kCelBlock = "cel.@block"; + +enum class RuleSemantics { + // TODO(b/506179116): will also need "aggregate" or similar concept. + kFirstMatch, + + kNotForUseWithExhaustiveSwitchStatements, +}; + +template +void AbslStringify(Sink& s, RuleSemantics semantics) { + switch (semantics) { + case RuleSemantics::kFirstMatch: + s.Append("first_match"); + return; + default: + s.Append(""); + return; + } +} + +struct EmbeddedAst { + CelPolicyElementId id; + std::unique_ptr ast; +}; + +struct CompiledVariable { + std::string ident; + EmbeddedAst ast; +}; + +struct CompiledOutputBlock { + EmbeddedAst output_ast; + cel::Type result_type; + std::optional explanation_ast; +}; + +struct CompiledRule; + +struct CompiledMatch { + using Production = + std::variant absl_nonnull, + CompiledOutputBlock>; + + CelPolicyElementId id; + std::optional condition; + Production production; +}; + +struct CompiledRule { + CelPolicyElementId id; + std::vector variables; + std::vector matches; + // Not set if cannot be determined. + std::optional result_type; +}; + +std::optional GetOutputType( + const CompiledMatch::Production& production) { + return std::visit( + [](const auto& production) -> std::optional { + if constexpr (std::is_same_v, + CompiledOutputBlock>) { + return production.result_type; + } else if constexpr (std::is_same_v, + std::unique_ptr>) { + return production->result_type; + } + return std::nullopt; + }, + production); +} + +// Internal representation of the compiled policy elements. +// +// This is used for checking the component expression before composing into the +// final AST based on the provided rule semantics. +class IntermediateCompiledPolicy { + public: + CompiledRule& mutable_root_rule() { return root_rule_; } + + const CompiledRule& root_rule() const { return root_rule_; } + + void set_name(absl::string_view name) { name_ = name; } + absl::string_view name() const { return name_; } + void set_display_name(absl::string_view display_name) { + display_name_ = display_name; + } + absl::string_view display_name() const { return display_name_; } + void set_description(absl::string_view description) { + description_ = description; + } + absl::string_view description() const { return description_; } + + void set_semantics(RuleSemantics semantics) { semantics_ = semantics; } + RuleSemantics semantics() const { return semantics_; } + + private: + std::string name_; + std::string display_name_; + std::string description_; + RuleSemantics semantics_ = RuleSemantics::kFirstMatch; + + CompiledRule root_rule_; +}; + +CelPolicyIssue::Severity MapSeverity(cel::TypeCheckIssue::Severity severity) { + switch (severity) { + case cel::TypeCheckIssue::Severity::kError: + return CelPolicyIssue::Severity::kError; + case cel::TypeCheckIssue::Severity::kWarning: + return CelPolicyIssue::Severity::kWarning; + case cel::TypeCheckIssue::Severity::kDeprecated: + return CelPolicyIssue::Severity::kDeprecated; + default: + return CelPolicyIssue::Severity::kError; + } +} + +bool IsWrapperOf(cel::TypeKind wrapper_kind, cel::TypeKind primitive_kind) { + switch (wrapper_kind) { + case cel::TypeKind::kBoolWrapper: + return primitive_kind == cel::TypeKind::kBool; + case cel::TypeKind::kIntWrapper: + return primitive_kind == cel::TypeKind::kInt; + case cel::TypeKind::kUintWrapper: + return primitive_kind == cel::TypeKind::kUint; + case cel::TypeKind::kDoubleWrapper: + return primitive_kind == cel::TypeKind::kDouble; + case cel::TypeKind::kStringWrapper: + return primitive_kind == cel::TypeKind::kString; + case cel::TypeKind::kBytesWrapper: + return primitive_kind == cel::TypeKind::kBytes; + default: + return false; + } +} + +cel::Type FilterSpecialTypes(cel::Type type) { + if (type.IsTypeParam()) { + // Free type param should not appear in the output type, but if it does, + // force it to dyn. + return DynType(); + } + if (type.IsEnum()) { + return IntType{}; + } + if (type.IsError()) { + return DynType(); + } + if (type.IsType()) { + // drop parameters so all type types are compatible. + return TypeType{}; + } + return type; +} + +// Returns true if `from` is assignable to `to`. +// +// Slightly adjusted from the standard routine to cover some edge cases around +// null and wrappers. +// +// TODO(b/522391716): try to standardize assignability checks. +bool OutputTypeIsAssignable(cel::Type from, cel::Type to) { + from = FilterSpecialTypes(from); + to = FilterSpecialTypes(to); + + // Any and dyn are assignable to/from everything. + if (from.kind() == cel::TypeKind::kAny || + from.kind() == cel::TypeKind::kDyn || to.kind() == cel::TypeKind::kAny || + to.kind() == cel::TypeKind::kDyn) { + return true; + } + + // Wrappers auto-unwrap. + if (IsWrapperOf(from.kind(), to.kind()) || + IsWrapperOf(to.kind(), from.kind())) { + return true; + } + + // Null is assignable to anything that is message-like. + if (from.kind() == cel::TypeKind::kNull) { + switch (to.kind()) { + case cel::TypeKind::kNull: + case cel::TypeKind::kStruct: + case cel::TypeKind::kOpaque: + case cel::TypeKind::kTimestamp: + case cel::TypeKind::kDuration: + case cel::TypeKind::kBytesWrapper: + case cel::TypeKind::kBoolWrapper: + case cel::TypeKind::kIntWrapper: + case cel::TypeKind::kUintWrapper: + case cel::TypeKind::kDoubleWrapper: + case cel::TypeKind::kStringWrapper: + return true; + default: + return false; + } + } + + if (from.kind() != to.kind()) { + return false; + } + + if (from.name() != to.name()) { + return false; + } + + if (from.GetParameters().size() != to.GetParameters().size()) { + return false; + } + + for (int i = 0; i < from.GetParameters().size(); ++i) { + if (!OutputTypeIsAssignable(from.GetParameters()[i], + to.GetParameters()[i])) { + return false; + } + } + + return true; +} + +bool OutputTypeIsCompatible(cel::Type from, cel::Type to) { + // We don't handle widening like in a self-contained CEL expression, but + // permit some cases where one type is more specific than the other. + return OutputTypeIsAssignable(from, to) || OutputTypeIsAssignable(to, from); +} + +bool HasErrors(const policy_internal::IssueReporter& issues) { + for (const auto& issue : issues.issues()) { + if (issue.severity() == CelPolicyIssue::Severity::kError) { + return true; + } + } + return false; +} + +// Note on lifetime safety: +// +// The output policy will contain references to types that are owned by the +// arena member of this class. This is safe as long as the policy compiler lives +// as long as the output policies. +class PolicyCompiler { + public: + explicit PolicyCompiler(policy_internal::IssueReporter* issues, + std::unique_ptr base_compiler) + : issues_(*issues), base_compiler_(std::move(base_compiler)) {} + + absl::string_view GetSourceDescription() const { + if (src_ == nullptr) { + return ""; + } + return src_->content()->description(); + } + + void AdaptTypeCheckIssues(CelPolicyElementId id, const ValidationResult& r) { + const Source* source = r.GetSource(); + + for (const auto& iss : r.GetIssues()) { + std::optional offset; + if (source != nullptr) { + offset = source->GetPosition(iss.location()); + } + if (offset.has_value()) { + issues_.ReportOffsetIssue(id, offset.value(), + MapSeverity(iss.severity()), iss.message()); + continue; + } + issues_.ReportIssue(id, MapSeverity(iss.severity()), iss.message()); + } + } + + absl::StatusOr CompileOutputBlock( + const cel::OutputBlock& output_block, const Compiler* env) { + CompiledOutputBlock output; + CEL_ASSIGN_OR_RETURN(auto output_validation, + env->Compile(output_block.output().value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(output_block.output().id(), output_validation); + + cel::Type result_type = DynType(); + if (output_validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, output_validation.ReleaseAst()); + auto root_expr_id = ast->root_expr().id(); + output.output_ast = + EmbeddedAst{output_block.output().id(), std::move(ast)}; + if (auto it = output_validation.GetResolvedTypeMap().find(root_expr_id); + it != output_validation.GetResolvedTypeMap().end()) { + result_type = it->second; + } + } + if (output_block.explanation().has_value()) { + CEL_ASSIGN_OR_RETURN(auto explanation_validation, + env->Compile(output_block.explanation()->value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(output_block.explanation()->id(), + explanation_validation); + if (explanation_validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, explanation_validation.ReleaseAst()); + if (ast->GetReturnType().primitive() != PrimitiveType::kString) { + issues_.ReportError(output_block.explanation()->id(), + "explanation must evaluate to string"); + } else { + output.explanation_ast = + EmbeddedAst{output_block.explanation()->id(), std::move(ast)}; + } + } + } + output.result_type = result_type; + return output; + } + + absl::Status CompileMatch(const Match& match, const Compiler* env, + CompiledRule* out) { + CompiledMatch c_match; + c_match.id = match.id(); + if (match.condition().has_value()) { + CEL_ASSIGN_OR_RETURN(auto validation, + env->Compile(match.condition()->value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(match.condition()->id(), validation); + if (validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, validation.ReleaseAst()); + if (ast->GetReturnType().primitive() != PrimitiveType::kBool) { + issues_.ReportError(match.condition()->id(), + "condition must evaluate to bool"); + } + c_match.condition = + EmbeddedAst{match.condition()->id(), std::move(ast)}; + } + } + + if (match.has_output_block()) { + CEL_ASSIGN_OR_RETURN(c_match.production, + CompileOutputBlock(match.output_block(), env)); + } else if (match.has_rule()) { + auto rule = std::make_unique(); + CEL_RETURN_IF_ERROR(CompileRule(match.rule(), env, rule.get())); + c_match.production = std::move(rule); + } else { + issues_.ReportError(match.id(), "match must specify an output or rule"); + } + out->matches.push_back(std::move(c_match)); + return absl::OkStatus(); + } + + absl::Status CompileRule(const Rule& rule, const cel::Compiler* env, + CompiledRule* out) { + out->id = rule.id(); + std::unique_ptr buf; + + absl::flat_hash_set seen_variables; + for (const auto& variable : rule.variables()) { + std::string name(variable.name().value()); + if (!seen_variables.insert(name).second) { + issues_.ReportError( + variable.expression().id(), + absl::StrCat("overlapping identifier for name 'variables.", name, + "'")); + continue; + } + std::string ident = absl::StrCat("variables.", name); + CEL_ASSIGN_OR_RETURN(auto validation, + env->Compile(variable.expression().value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(variable.expression().id(), validation); + if (!validation.IsValid()) { + continue; + } + CEL_ASSIGN_OR_RETURN(auto ast, validation.ReleaseAst()); + cel::Type result_type = DynType(); + + if (auto it = validation.GetResolvedTypeMap().find(ast->root_expr().id()); + it != validation.GetResolvedTypeMap().end()) { + result_type = it->second; + } + out->variables.push_back(CompiledVariable{ + ident, + EmbeddedAst{variable.expression().id(), std::move(ast)}, + }); + auto next = env->ToBuilder(); + auto status = next->GetCheckerBuilder().AddOrReplaceVariable( + MakeVariableDecl(ident, result_type)); + if (!status.ok()) { + issues_.ReportError(variable.expression().id(), status.message()); + continue; + } + CEL_ASSIGN_OR_RETURN(buf, next->Build()); + env = buf.get(); + } + + std::optional overall_type; + for (const auto& match : rule.matches()) { + CEL_RETURN_IF_ERROR(CompileMatch(match, env, out)); + if (!overall_type.has_value()) { + overall_type = GetOutputType(out->matches.back().production); + continue; + } + + if (std::optional match_type = + GetOutputType(out->matches.back().production); + match_type.has_value()) { + if (!OutputTypeIsCompatible(*match_type, *overall_type)) { + issues_.ReportError( + match.id(), + absl::StrCat("incompatible output types: block has output type ", + FormatTypeName(*match_type), + ", but previous outputs have type ", + FormatTypeName(*overall_type))); + } + } + } + + out->result_type = overall_type; + return absl::OkStatus(); + } + + absl::Status CompilePolicy(const CelPolicy& policy, + IntermediateCompiledPolicy* out) { + src_ = policy.source(); + out->set_semantics(RuleSemantics::kFirstMatch); + out->set_name(policy.name().value()); + out->set_display_name( + policy.display_name().value_or(ValueString{}).value()); + out->set_description(policy.description().value_or(ValueString{}).value()); + + return CompileRule(policy.rule(), base_compiler_.get(), + &out->mutable_root_rule()); + } + + private: + google::protobuf::Arena arena_; + const CelPolicySource* absl_nullable src_; + policy_internal::IssueReporter& issues_; + std::unique_ptr base_compiler_; +}; + +bool IsExhaustive(const CompiledRule& rule); + +class FirstMatchComposer { + public: + FirstMatchComposer(const IntermediateCompiledPolicy& icp, + const Compiler& compiler, + policy_internal::IssueReporter& issues) + : issues_(issues), icp_(icp), compiler_(compiler) {} + + absl::Status Compose(); + + bool success() const { return ast_ != nullptr; } + + std::unique_ptr ReleaseAst() { return std::move(ast_); } + + private: + using VariableScope = absl::flat_hash_map; + + std::optional ResolvePolicyVariable(absl::string_view reference); + + absl::flat_hash_map ResolveBlockIndexes(const Ast& ast); + + bool CheckMatchStructure(const CompiledRule& rule); + + // Returns true if already optional wrapped. + absl::StatusOr ComposeRule(const CompiledRule& rule, Expr& init, + Expr& insertion_expr); + + // returns true if already optional wrapped. + absl::StatusOr ComposeProduction( + const CompiledRule& rule, const CompiledMatch::Production& production, + Expr& init, Expr& insertion_expr); + + void MapVariables(Ast& ast); + + void ComposeRuleVariables(const CompiledRule& rule, Expr& init, + Expr& insertion_expr); + + policy_internal::IssueReporter& issues_; + OptimizerExprFactory factory_; + const IntermediateCompiledPolicy& icp_; + const Compiler& compiler_; + std::vector scopes_; + bool optionalize_ = false; + std::unique_ptr ast_; +}; + +absl::Status FirstMatchComposer::Compose() { + ABSL_DCHECK(icp_.semantics() == RuleSemantics::kFirstMatch); + + factory_.mutable_ast().mutable_root_expr() = factory_.NewCall( + "cel.@block", factory_.NewList(), factory_.NewUnspecified()); + auto& block_init_list = factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[0]; + auto& insertion_expr = factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[1]; + optionalize_ = !IsExhaustive(icp_.root_rule()); + if (!CheckMatchStructure(icp_.root_rule())) { + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN( + bool optional_wrapped, + ComposeRule(icp_.root_rule(), block_init_list, insertion_expr)); + + if (optional_wrapped != optionalize_) { + return absl::InternalError( + "composition failed to handle non-exhaustive rules"); + } + + CEL_ASSIGN_OR_RETURN(cel::ValidationResult result, + compiler_.GetTypeChecker().Check(factory_.ast())); + if (!result.IsValid()) { + for (const auto& iss : result.GetIssues()) { + issues_.ReportError(icp_.root_rule().id, iss.message()); + } + return absl::OkStatus(); + } + + CEL_ASSIGN_OR_RETURN(ast_, result.ReleaseAst()); + + return absl::OkStatus(); +} + +bool IsTriviallyTrueCondition(const CompiledMatch& match) { + if (!match.condition.has_value() || match.condition->ast == nullptr) { + return true; + } + const cel::Expr& expr = match.condition->ast->root_expr(); + if (expr.has_const_expr()) { + const cel::Constant& const_expr = expr.const_expr(); + if (const_expr.has_bool_value() && const_expr.bool_value()) { + return true; + } + } + return false; +} + +bool IsExhaustive(const CompiledRule& rule); + +bool IsExhaustive(const CompiledMatch& match) { + if (std::holds_alternative(match.production)) { + return true; + } + + const auto* nested_rule_ptr = + std::get_if>(&match.production); + ABSL_DCHECK(nested_rule_ptr != nullptr); + const CompiledRule& nested_rule = **nested_rule_ptr; + return IsExhaustive(nested_rule); +} + +bool IsExhaustive(const CompiledRule& rule) { + if (rule.matches.empty()) { + // Validation should fail, but generalization would be false. + return false; + } + bool has_default = false; + for (const auto& match : rule.matches) { + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + // If this isn't the last match in the rule, it should get flagged + // during validation since it means there are trivially unreachable + // matches. + has_default = true; + } + if (!IsTriviallyTrueCondition(match) && !IsExhaustive(match)) { + // There is a nested rule that might return an optional.none(). + return false; + } + } + // Otherwise, everything in this branch is exhaustive so we can defer + // wrapping. + return has_default; +} + +bool FirstMatchComposer::CheckMatchStructure(const CompiledRule& rule) { + if (rule.matches.empty()) { + issues_.ReportError(rule.id, "rule does not specify match conditions"); + return false; + } + + bool valid = true; + bool seen_trivially_true = false; + + for (const auto& match : rule.matches) { + if (seen_trivially_true) { + if (std::holds_alternative(match.production)) { + issues_.ReportError(match.id, "match creates unreachable outputs"); + } else if (std::holds_alternative>( + match.production)) { + issues_.ReportError(match.id, "rule creates unreachable outputs"); + } + valid = false; + } + + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + seen_trivially_true = true; + } + + if (auto* nested_rule = + std::get_if>(&match.production); + nested_rule != nullptr) { + ABSL_DCHECK(*nested_rule != nullptr); + if (!CheckMatchStructure(**nested_rule)) { + valid = false; + } + } + } + + return valid; +} + +std::optional FirstMatchComposer::ResolvePolicyVariable( + absl::string_view reference) { + for (auto scope_iter = scopes_.rbegin(); scope_iter != scopes_.rend(); + ++scope_iter) { + if (auto it = scope_iter->find(reference); it != scope_iter->end()) { + return it->second; + } + } + return std::nullopt; +} + +class IndexRewrite : public AstRewriterBase { + public: + explicit IndexRewrite(absl::flat_hash_map expr_id_to_index, + OptimizerExprFactory& factory) + : expr_id_to_index_(std::move(expr_id_to_index)), factory_(factory) {} + + bool PreVisitRewrite(Expr& e) override { + if (auto it = expr_id_to_index_.find(e.id()); + it != expr_id_to_index_.end()) { + e.mutable_ident_expr().set_name(absl::StrCat("@index", it->second)); + factory_.RecordReplacement(e.id(), e); + return true; + } + return false; + } + + private: + absl::flat_hash_map expr_id_to_index_; + OptimizerExprFactory& factory_; +}; + +absl::StatusOr FirstMatchComposer::ComposeRule(const CompiledRule& rule, + Expr& init, + Expr& insertion_expr) { + scopes_.emplace_back(); + auto pop_scope = absl::MakeCleanup([this]() { scopes_.pop_back(); }); + ComposeRuleVariables(rule, init, insertion_expr); + Expr* insertion_point = &insertion_expr; + const bool has_default = IsTriviallyTrueCondition(rule.matches.back()); + const bool needs_wrap = !IsExhaustive(rule); + size_t end = rule.matches.size() - (has_default ? 1 : 0); + for (size_t i = 0; i < end; i++) { + const auto& match = rule.matches[i]; + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + return absl::InternalError("detected unreachable match after validation"); + } + + Expr production; + CEL_ASSIGN_OR_RETURN( + bool is_wrapped, + ComposeProduction(rule, match.production, init, production)); + if (needs_wrap && !is_wrapped) { + production = factory_.NewCall("optional.of", std::move(production)); + } + + if (!IsTriviallyTrueCondition(match)) { + Ast condition = *match.condition->ast; + MapVariables(condition); + factory_.StartCopyContext(); + auto copy = factory_.Copy(condition.root_expr()); + auto source_info = factory_.RemapSourceInfo(condition.source_info()); + factory_.MergeSourceInfo(source_info); + *insertion_point = factory_.NewCall("_?_:_", std::move(copy)); + insertion_point->mutable_call_expr().mutable_args().push_back( + std::move(production)); + ABSL_DCHECK(!(!needs_wrap && is_wrapped)) + << "unexpected wrapping in exhaustive policy."; + insertion_point = &insertion_point->mutable_call_expr().add_args(); + continue; + } + + if (!is_wrapped) { + return absl::InternalError( + "composition failed. expected optional wrapped rule but got a plain " + "value"); + } + auto fn = needs_wrap ? "or" : "orValue"; + *insertion_point = factory_.NewMemberCall(fn, std::move(production)); + insertion_point = &insertion_point->mutable_call_expr().add_args(); + } + + if (has_default) { + const auto& match = rule.matches.back(); + Expr production; + CEL_ASSIGN_OR_RETURN( + bool is_wrapped, + ComposeProduction(rule, match.production, init, production)); + if (needs_wrap && !is_wrapped) { + production = factory_.NewCall("optional.of", std::move(production)); + } + *insertion_point = std::move(production); + ABSL_DCHECK(!(!needs_wrap && is_wrapped)) + << "unexpected wrapping in exhaustive policy."; + + return needs_wrap; + } + + // Otherwise, we fell through a non-exhaustive rule. + *insertion_point = factory_.NewCall("optional.none"); + return true; +} + +absl::StatusOr FirstMatchComposer::ComposeProduction( + const CompiledRule& rule, const CompiledMatch::Production& production, + Expr& init, Expr& insertion_expr) { + if (auto* nested_rule = + std::get_if>(&production); + nested_rule != nullptr) { + return ComposeRule(**nested_rule, init, insertion_expr); + } + auto* output = std::get_if(&production); + if (output == nullptr) { + return absl::InternalError("unexpected rule production type"); + } + const EmbeddedAst& output_ast = output->output_ast; + Ast ast = *output_ast.ast; + MapVariables(ast); + factory_.StartCopyContext(); + Expr to_insert = factory_.Copy(ast.root_expr()); + auto source_info = factory_.RemapSourceInfo(ast.source_info()); + factory_.MergeSourceInfo(source_info); + insertion_expr = std::move(to_insert); + + return false; +} + +absl::flat_hash_map FirstMatchComposer::ResolveBlockIndexes( + const Ast& ast) { + absl::flat_hash_map out; + for (auto it = ast.reference_map().begin(); it != ast.reference_map().end(); + it++) { + const Reference& ref = it->second; + if (!it->second.overload_id().empty()) { + continue; + } + if (!absl::StartsWith(ref.name(), "variable")) { + continue; + } + if (auto index = ResolvePolicyVariable(ref.name()); index.has_value()) { + out[it->first] = *index; + } + } + return out; +} + +void FirstMatchComposer::MapVariables(Ast& ast) { + absl::flat_hash_map edit_map = ResolveBlockIndexes(ast); + IndexRewrite rewriter(std::move(edit_map), factory_); + AstRewrite(ast.mutable_root_expr(), rewriter); +} + +void FirstMatchComposer::ComposeRuleVariables(const CompiledRule& rule, + Expr& init, + Expr& insertion_expr) { + for (const auto& variable : rule.variables) { + Ast ast = *variable.ast.ast; + MapVariables(ast); + factory_.StartCopyContext(); + auto insertion = factory_.Copy(ast.root_expr()); + // TODO(b/506179116): apply the position offsets here. + auto info = factory_.RemapSourceInfo(ast.source_info()); + ABSL_DCHECK(init.has_list_expr()); + int index = init.mutable_list_expr().elements().size(); + init.mutable_list_expr().mutable_elements().push_back( + factory_.NewListElement(std::move(insertion))); + scopes_.back()[variable.ident] = index; + } +} + +bool HasComprehensionParent(const NavigableAstNode& node) { + const NavigableAstNode* curr = &node; + while (curr != nullptr) { + if (curr->node_kind() == NodeKind::kComprehension) { + return true; + } + curr = curr->parent(); + } + return false; +} + +// Unnester implementation. +class Unnester { + public: + Unnester(Ast ast, int height, policy_internal::IssueReporter& issues) + : factory_(std::move(ast)), height_(height), issues_(issues) {} + + // Run the unnesting. + // The class cannot be reused after this is called. + absl::StatusOr Unnest() { + if (height_ > 0) { + CEL_RETURN_IF_ERROR(Slice()); + } + CEL_RETURN_IF_ERROR(Cleanup()); + return std::move(factory_.mutable_ast()); + } + + private: + // The core unnest routine. + absl::Status Slice(); + // Fixup the AST post-unnesting. + absl::Status Cleanup(); + + void ReportErrorAtId(int64_t id, absl::string_view message); + + OptimizerExprFactory factory_; + int height_; + policy_internal::IssueReporter& issues_; +}; + +class UnnestRewriter : public AstRewriterBase { + public: + explicit UnnestRewriter(OptimizerExprFactory& f, Expr& block_list_expr, + absl::Span cuts) + : factory_(f), cuts_(cuts), block_list_expr_(block_list_expr) {} + + bool PostVisitRewrite(Expr& expr) override { + using std::swap; + // Post order so we always see children before parents. + // No need to copy metadata since we're only moving exprs or minting + // new ones. + if (absl::c_contains(cuts_, expr.id())) { + size_t idx = block_list_expr_.list_expr().elements().size(); + Expr value = factory_.NewIdent(absl::StrCat("@index", idx)); + factory_.RecordReplacement(expr.id(), value, /*keep_metadata=*/true); + swap(value, expr); + block_list_expr_.mutable_list_expr().mutable_elements().push_back( + factory_.NewListElement(std::move(value))); + return true; + } + return false; + } + + private: + OptimizerExprFactory& factory_; + absl::Span cuts_; + Expr& block_list_expr_; +}; + +absl::Status Unnester::Slice() { + Expr& root = factory_.mutable_ast().mutable_root_expr(); + if (root.call_expr().function() != kCelBlock || + root.call_expr().args().size() != 2 || + !root.call_expr().args()[0].has_list_expr()) { + return absl::InternalError("malformed AST detected during unnesting"); + } + // Two passes, we identify the slice points (bottom up), then cut + // and paste the leaves into the block list. + NavigableAst nav_ast = NavigableAst::Build(factory_.ast().root_expr()); + + ABSL_DCHECK(nav_ast.IdsAreUnique()); + bool can_cut = true; + std::vector cuts; + for (const NavigableAstNode& node : nav_ast.Root().DescendantsPostorder()) { + // Subsequent cuts will be height_ + 1 in the block, indices. Within the + // error margin we specified. + if (node.height() % height_ == 0) { + if (HasComprehensionParent(node)) { + ReportErrorAtId( + node.expr()->id(), + absl::StrCat( + "cannot unnest AST due to comprehension. cannot accommodate " + "height limit of ", + height_)); + can_cut = false; + continue; + } + if (&node == &nav_ast.Root()) { + // If evenly divisible by height, don't cut since it will net a taller + // AST. + continue; + } + cuts.push_back(node.expr()->id()); + } + } + + if (!can_cut || cuts.empty()) { + return absl::OkStatus(); + } + + Expr& block_list_expr = root.mutable_call_expr().mutable_args()[0]; + Expr& insertion_expr = root.mutable_call_expr().mutable_args()[1]; + + UnnestRewriter rewriter(factory_, block_list_expr, cuts); + AstRewrite(insertion_expr, rewriter); + + return absl::OkStatus(); +} + +absl::Status Unnester::Cleanup() { + using std::swap; + + const auto& ast = factory_.ast(); + if (ast.root_expr().call_expr().function() != kCelBlock || + ast.root_expr().call_expr().args().size() != 2 || + !ast.root_expr().call_expr().args()[0].has_list_expr()) { + return absl::InternalError("malformed AST detected during unnesting"); + } + if (ast.root_expr().call_expr().args()[0].list_expr().elements().empty()) { + Expr value = std::move(factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[1]); + factory_.mutable_ast().mutable_root_expr() = std::move(value); + } + + return absl::OkStatus(); +} + +void Unnester::ReportErrorAtId(int64_t id, absl::string_view message) { + int32_t position = 0; + auto it = factory_.ast().source_info().positions().find(id); + if (it != factory_.ast().source_info().positions().end()) { + position = it->second; + } + issues_.ReportError(-1, position, message); +} +} // namespace + +// Compiles a CEL policy using the provided CEL compiler as a base environment. +absl::StatusOr CompilePolicy( + const Compiler& compiler, const CelPolicy& policy, + const CompilePolicyOptions& options) { + policy_internal::IssueReporter issues; + if (options.unnesting_height_limit != 0 && + options.unnesting_height_limit < 2) { + return absl::InvalidArgumentError( + "unnesting_height_limit must be at least 2"); + } + auto builder = compiler.ToBuilder(); + ExpressionContainer cont; + for (const auto& import : policy.imports()) { + auto status = cont.AddAbbreviation(import.name().value()); + if (!status.ok()) { + issues.ReportError( + import.name().id(), + absl::StrCat("'", import.name().value(), "': ", status.message())); + } + } + + builder->GetCheckerBuilder().SetExpressionContainer(cont); + CEL_ASSIGN_OR_RETURN(auto base_compiler, builder->Build()); + + PolicyCompiler policy_compiler(&issues, std::move(base_compiler)); + + IntermediateCompiledPolicy icp; + CEL_RETURN_IF_ERROR(policy_compiler.CompilePolicy(policy, &icp)); + + if (HasErrors(issues)) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + CEL_ASSIGN_OR_RETURN(base_compiler, builder->Build()); + switch (icp.semantics()) { + case RuleSemantics::kFirstMatch: { + FirstMatchComposer composer(icp, *base_compiler, issues); + CEL_RETURN_IF_ERROR(composer.Compose()); + if (!composer.success()) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + auto ast = composer.ReleaseAst(); + Unnester unnester(std::move(*ast), options.unnesting_height_limit, + issues); + CEL_ASSIGN_OR_RETURN(Ast unnested_ast, unnester.Unnest()); + + if (HasErrors(issues)) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + return CelPolicyValidationResult( + std::make_unique(std::move(unnested_ast)), {}, + policy.source_ptr()); + } + default: + return absl::UnimplementedError( + absl::StrCat("Unsupported RuleSemantics: ", icp.semantics())); + } +} + +} // namespace cel diff --git a/policy/compiler.h b/policy/compiler.h new file mode 100644 index 000000000..0187bd1a2 --- /dev/null +++ b/policy/compiler.h @@ -0,0 +1,50 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ + +#include "absl/status/statusor.h" +#include "compiler/compiler.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_validation_result.h" + +namespace cel { + +struct CompilePolicyOptions { + // If greater than 0, the compiler will attempt to unnest rule branches + // at the specified height. The overall height of the final AST may exceed + // this by a small, fixed margin. + // + // To avoid slicing comprehensions, subexpressions within comprehensions + // are not eligible for unnesting. If the height limit cannot be accommodated, + // an error with code InvalidArgument is returned. + // + // If the AST is converted to proto, even relatively low levels of nesting + // can cause problems in serialization/deserialization. This does not apply + // if the AST is used directly by the runtime. + int unnesting_height_limit = 0; +}; + +// Compiles a CEL policy using the provided CEL compiler as a base environment. +// +// TODO(b/506179116): Implementation in progress. Functionally complete, +// but errors are not consistent with other implementations. +absl::StatusOr CompilePolicy( + const Compiler& compiler, const CelPolicy& policy, + const CompilePolicyOptions& options = {}); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ diff --git a/policy/compiler_test.cc b/policy/compiler_test.cc new file mode 100644 index 000000000..8db494b45 --- /dev/null +++ b/policy/compiler_test.cc @@ -0,0 +1,946 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/compiler.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/decl.h" +#include "common/navigable_ast.h" +#include "common/source.h" +#include "common/type.h" +#include "common/types/message_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/optional.h" +#include "compiler/standard_library.h" +#include "extensions/bindings_ext.h" +#include "internal/runfiles.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/yaml_policy_parser.h" +#include "runtime/activation.h" +#include "runtime/optional_types.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::test::IntValueIs; +using ::cel::test::OptionalValueIs; +using ::cel::test::OptionalValueIsEmpty; +using ::cel::test::StringValueIs; +using ::cel::test::ValueMatcher; + +constexpr absl::string_view kTestPolicyFilePath = +"_main/policy/testdata/cel_policy.yaml"; + +absl::StatusOr> BuildTestCompiler() { + CompilerOptions opts; + opts.adapt_parser_errors = true; + opts.parser_options.enable_optional_syntax = true; + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool(), opts)); + + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCompilerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCompilerLibrary())); + + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("x", IntType()))); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("y", IntType()))); + + const google::protobuf::Descriptor* descriptor = + cel::internal::GetSharedTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes"); + if (descriptor == nullptr) { + return absl::InternalError("Failed to find TestAllTypes descriptor"); + } + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("spec", cel::MessageType(descriptor)))); + + return builder->Build(); +} + +absl::StatusOr> ParsePolicyFromYaml( + absl::string_view yaml_content) { + CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(yaml_content, "test.yaml")); + + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + CEL_ASSIGN_OR_RETURN(auto parse_result, + cel::ParseYamlCelPolicy(policy_source)); + + if (!parse_result.IsValid()) { + return absl::InvalidArgumentError("Invalid policy YAML structure"); + } + return parse_result.ReleasePolicy(); +} + +TEST(CompilerTest, SmokeTest) { + std::string contents; + std::string test_file = + cel::internal::ResolveRunfilesPath(kTestPolicyFilePath); + auto read_status = cel::internal::GetFileContents(test_file, &contents); + ASSERT_THAT(read_status, IsOk()); + + auto source_or = cel::NewSource(contents, "cel_policy.yaml"); + ASSERT_THAT(source_or.status(), IsOk()); + auto source = *std::move(source_or); + + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + auto parse_result_or = cel::ParseYamlCelPolicy(policy_source); + ASSERT_THAT(parse_result_or.status(), IsOk()); + auto parse_result = *std::move(parse_result_or); + + ASSERT_TRUE(parse_result.IsValid()); + const CelPolicy* policy = parse_result.GetPolicy(); + + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); +} + +TEST(CompilerTest, VariableOutOfScopeReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: variables.non_existent == 10 + output: '"error"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("undeclared reference")); +} + +TEST(CompilerTest, ConditionNotBoolReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: 10 + output: '"error"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("condition must evaluate to bool")); +} + +TEST(CompilerTest, InvalidOutputExpressionReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: undeclared_var +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("undeclared reference")); +} + +TEST(CompilerTest, UnreachableMatchAfterTriviallyTrueCondition) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"first"' + - condition: true + output: '"second"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match creates unreachable outputs")); +} + +TEST(CompilerTest, UnreachableMatchAfterUnconditionalExhaustiveSubRule) { + absl::string_view yaml = R"yaml( +name: dead_branch +rule: + match: + - rule: + match: + - output: 1 + - output: 2 +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match creates unreachable outputs")); +} + +TEST(CompilerTest, RuleWithoutMatchesReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("rule does not specify match conditions")); +} + +TEST(CompilerTest, ExhaustivePolicyCompiles) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + variables: + - name: test_var + expression: 10 + match: + - condition: variables.test_var > 15 + output: '"greater than 15"' + - condition: variables.test_var > 5 + output: '"greater than 5"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); + EXPECT_TRUE(result.GetAst()->is_checked()); +} + +TEST(CompilerTest, NonExhaustivePolicyCompiles) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + variables: + - name: test_var + expression: 10 + match: + - condition: variables.test_var > 5 + output: '"greater than 5"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); +} + +TEST(CompilerTest, PolicyReferencesEnvInput) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: spec.single_int32 > 10 + output: '"greater than 10"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); + EXPECT_TRUE(result.GetAst()->is_checked()); +} + +struct EvaluationTestCase { + std::string name; + std::string yaml_policy; + struct Input { + int64_t x; + int64_t y; + } input; + ValueMatcher expected_result_matcher; +}; + +class PolicyEvaluationTest : public testing::TestWithParam { +}; + +TEST_P(PolicyEvaluationTest, Evaluate) { + const auto& test_case = GetParam(); + + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(test_case.yaml_policy)); + ASSERT_OK_AND_ASSIGN(auto validation_result, + CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(validation_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, validation_result.ReleaseAst()); + + // Set up runtime + cel::RuntimeOptions opts; + opts.enable_qualified_type_identifiers = true; + ASSERT_OK_AND_ASSIGN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + ASSERT_THAT(cel::extensions::EnableOptionalTypes(rt_builder), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + std::move(rt_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + // Set up activation + cel::Activation activation; + activation.InsertOrAssignValue("x", cel::IntValue(test_case.input.x)); + activation.InsertOrAssignValue("y", cel::IntValue(test_case.input.y)); + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(cel::Value result, + program->Evaluate(&arena, activation)); + + EXPECT_THAT(result, test_case.expected_result_matcher); +} + +constexpr absl::string_view kEvalPolicyYaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: x > 10 && y > 10 + output: '"both greater than 10"' + - condition: x > 10 + output: '"x greater than 10"' + - condition: y > 10 + output: '"y greater than 10"' + - output: '"default"' +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + PolicyEvaluationTest, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "BothGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 15, .y = 15}, + .expected_result_matcher = StringValueIs("both greater than 10"), + }, + EvaluationTestCase{ + .name = "XGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 15, .y = 5}, + .expected_result_matcher = StringValueIs("x greater than 10"), + }, + EvaluationTestCase{ + .name = "YGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 5, .y = 15}, + .expected_result_matcher = StringValueIs("y greater than 10"), + }, + EvaluationTestCase{ + .name = "Default", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 5, .y = 5}, + .expected_result_matcher = StringValueIs("default"), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kNonExhaustivePolicyYaml = R"yaml( +name: nested_rule4 +rule: + match: + - condition: x > 0 + rule: + match: + - condition: x < 3 + output: 1 + - condition: x < 5 + output: 2 + - condition: x < 0 + rule: + match: + - condition: x > -2 + output: 3 + - condition: x > -4 + output: 4 + - output: 5 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + NonExhaustivePolicyEvaluation, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals0_FallthroughTopLevel", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }, + EvaluationTestCase{ + .name = "XEquals2_MatchesFirstNested", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 2, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEquals6_FallthroughNested", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 6, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1_MatchesMinus2", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(3)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus3_MatchesMinus4", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -3, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(4)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus5_MatchesDefault", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -5, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(5)), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kNestedVariablePolicyYaml = R"yaml( +name: nested_rule4 +rule: + variables: + - name: i + expression: "1" + - name: j + expression: "2" + match: + - condition: x > 0 + rule: + variables: + - name: k + expression: "3" + match: + - output: "variables.i + variables.j + variables.k" + - condition: x < 0 + rule: + variables: + - name: j + expression: "5" + - name: k + expression: "4" + match: + - output: "variables.i + variables.j + variables.k" + - output: "variables.i + variables.j" +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + NestedVariablePolicyEvaluation, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XGreaterThan0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(6), + }, + EvaluationTestCase{ + .name = "XLessThan0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = IntValueIs(10), + }, + EvaluationTestCase{ + .name = "XEquals0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = IntValueIs(3), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view + kOptionalChainingUnconditionalSubRuleOptionalParentYaml = R"yaml( +name: optional_chaining +rule: + match: + - rule: + id: r2 + match: + - condition: x > 0 + output: 1 + - output: 2 + condition: x < 0 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalSubRuleOptionalParent, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = std::string( + kOptionalChainingUnconditionalSubRuleOptionalParentYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = std::string( + kOptionalChainingUnconditionalSubRuleOptionalParentYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(2)), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kOptionalChainingUnconditionalSubRuleYaml = R"yaml( +name: optional_chaining +rule: + id: r1 + match: + - rule: + id: r2 + match: + - condition: x > 0 + output: 1 + - output: 2 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalSubRule, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalSubRuleYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(1), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalSubRuleYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = IntValueIs(2), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kOptionalChainingUnconditionalComplexYaml = R"yaml( +name: optional_chaining +rule: + match: + - condition: x > 0 + rule: + match: + - rule: + match: + - condition: x == 1 + output: 1 + - output: 2 + - rule: + match: + - condition: x == -1 + output: 3 + - condition: x == -2 + output: 4 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalComplex, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(3)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus2", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -2, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(4)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus3", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -3, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kUnconditionalExhaustiveSubRuleAsLastMatchYaml = + R"yaml( +name: exhaustive_unconditional_subrule +rule: + match: + - condition: x > 0 + output: 1 + - rule: + match: + - condition: y > 0 + output: 2 + - output: 3 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + UnconditionalExhaustiveSubRuleAsLastMatch, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(1), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals1", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 1}, + .expected_result_matcher = IntValueIs(2), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals0", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = IntValueIs(3), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml = + R"yaml( +name: non_exhaustive_unconditional_subrule +rule: + match: + - condition: x > 0 + output: 1 + - rule: + match: + - condition: y > 0 + output: 2 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + UnconditionalNonExhaustiveSubRuleAsLastMatch, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals1", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 1}, + .expected_result_matcher = OptionalValueIs(IntValueIs(2)), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals0", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +TEST(CompilerTest, ImportsAndAbbreviations) { + absl::string_view yaml = R"yaml( +name: imports_test +imports: + - name: cel.expr.conformance.proto3.TestAllTypes +rule: + match: + - condition: 'spec == TestAllTypes{single_int32: 10}' + output: '"matched"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + auto ast_or = CompilePolicy(*compiler, *policy); + ASSERT_THAT(ast_or, IsOk()); +} + +TEST(CompilerTest, MatchWithoutProductionReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match must specify an output or rule")); +} + +int GetAstHeight(const cel::Ast& ast) { + auto nav_ast = cel::NavigableAst::Build(ast.root_expr()); + return nav_ast.Root().height(); +} + +TEST(CompilerTest, UnnestHeightValidation) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"ok"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 1; + auto status_or = CompilePolicy(*compiler, *policy, options); + EXPECT_THAT(status_or.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + testing::HasSubstr( + "unnesting_height_limit must be at least 2"))); + + options.unnesting_height_limit = 2; + EXPECT_THAT(CompilePolicy(*compiler, *policy, options), IsOk()); +} + +constexpr absl::string_view kDeepPolicyYaml = R"yaml( +name: deep_policy +rule: + match: + - condition: x > 0 + rule: + match: + - condition: x > 1 + rule: + match: + - condition: x > 2 + rule: + match: + - condition: x > 3 + rule: + match: + - condition: x > 4 + rule: + match: + - condition: x > 5 + output: 6 + - output: 5 + - output: 4 + - output: 3 + - output: 2 + - output: 1 + - output: 0 +)yaml"; + +TEST(CompilerTest, UnnestHeightReduction) { + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(kDeepPolicyYaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + // Compile without unnesting + CompilePolicyOptions options_no_unnest; + options_no_unnest.unnesting_height_limit = 0; + ASSERT_OK_AND_ASSIGN(auto result_no_unnest, + CompilePolicy(*compiler, *policy, options_no_unnest)); + ASSERT_TRUE(result_no_unnest.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast_no_unnest, result_no_unnest.ReleaseAst()); + int height_no_unnest = GetAstHeight(*ast_no_unnest); + + CompilePolicyOptions options_unnest; + options_unnest.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result_unnest, + CompilePolicy(*compiler, *policy, options_unnest)); + ASSERT_TRUE(result_unnest.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast_unnest, result_unnest.ReleaseAst()); + int height_unnest = GetAstHeight(*ast_unnest); + + EXPECT_EQ(height_no_unnest, 8); + EXPECT_EQ(height_unnest, 5); + EXPECT_LT(height_unnest, height_no_unnest); +} + +TEST(CompilerTest, UnnestComprehensionFailure) { + absl::string_view yaml = R"yaml( +name: comprehension_policy +rule: + match: + - condition: x > 0 + rule: + match: + - condition: "[1, 2].all(i, i > x)" + output: 1 + - output: 2 + - output: 0 +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("cannot unnest AST due to comprehension")); +} + +struct UnnestEvaluationTestCase { + std::string name; + int64_t x; + ValueMatcher expected; +}; + +class UnnestedDeepPolicyEvaluationTest + : public testing::TestWithParam {}; + +TEST_P(UnnestedDeepPolicyEvaluationTest, Evaluate) { + const auto& tc = GetParam(); + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(kDeepPolicyYaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + + // Set up runtime + cel::RuntimeOptions opts; + opts.enable_qualified_type_identifiers = true; + ASSERT_OK_AND_ASSIGN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + ASSERT_THAT(cel::extensions::EnableOptionalTypes(rt_builder), IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + std::move(rt_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + cel::Activation activation; + activation.InsertOrAssignValue("x", cel::IntValue(tc.x)); + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(cel::Value res, program->Evaluate(&arena, activation)); + + EXPECT_THAT(res, tc.expected); +} + +INSTANTIATE_TEST_SUITE_P( + UnnestedDeepPolicyEvaluation, UnnestedDeepPolicyEvaluationTest, + testing::Values(UnnestEvaluationTestCase{"XEquals6", 6, IntValueIs(6)}, + UnnestEvaluationTestCase{"XEquals5", 5, IntValueIs(5)}, + UnnestEvaluationTestCase{"XEquals4", 4, IntValueIs(4)}, + UnnestEvaluationTestCase{"XEquals3", 3, IntValueIs(3)}, + UnnestEvaluationTestCase{"XEquals2", 2, IntValueIs(2)}, + UnnestEvaluationTestCase{"XEquals1", 1, IntValueIs(1)}, + UnnestEvaluationTestCase{"XEquals0", 0, IntValueIs(0)}, + UnnestEvaluationTestCase{"XEqualsMinus1", -1, + IntValueIs(0)}), + [](const testing::TestParamInfo< + UnnestedDeepPolicyEvaluationTest::ParamType>& info) { + return info.param.name; + }); + +TEST(CompilerTest, UnnestCleanupRunsWhenDisabled) { + // A policy without variables and without nesting. + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"ok"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 0; // Disabled + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + + // If cleanup ran, it should have optimized away the trivial `cel.@block`. + // So the root expression should NOT be a call to `cel.@block`. + // It should be just the constant `"ok"`. + auto nav_ast = cel::NavigableAst::Build(ast->root_expr()); + EXPECT_FALSE(nav_ast.Root().expr()->has_call_expr() && + nav_ast.Root().expr()->call_expr().function() == "cel.@block"); + EXPECT_TRUE(nav_ast.Root().expr()->has_const_expr()); +} +} // namespace +} // namespace cel diff --git a/policy/internal/BUILD b/policy/internal/BUILD new file mode 100644 index 000000000..30f43d431 --- /dev/null +++ b/policy/internal/BUILD @@ -0,0 +1,68 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "issue_reporter", + srcs = ["issue_reporter.cc"], + hdrs = ["issue_reporter.h"], + deps = [ + "//common:source", + "//policy:cel_policy", + "//policy:cel_policy_parser", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "optimizer_expr_factory", + srcs = ["optimizer_expr_factory.cc"], + hdrs = ["optimizer_expr_factory.h"], + deps = [ + "//common:ast", + "//common:ast_rewrite", + "//common:ast_traverse", + "//common:ast_visitor_base", + "//common:constant", + "//common:expr", + "//common:expr_factory", + "//common:source", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "optimizer_expr_factory_test", + srcs = ["optimizer_expr_factory_test.cc"], + deps = [ + ":optimizer_expr_factory", + "//common:ast", + "//common:ast_proto", + "//common:ast_rewrite", + "//common:decl", + "//common:expr", + "//common:expr_factory", + "//common:source", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//testutil:expr_printer", + "//tools:cel_unparser", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) diff --git a/policy/internal/issue_reporter.cc b/policy/internal/issue_reporter.cc new file mode 100644 index 000000000..944e687d6 --- /dev/null +++ b/policy/internal/issue_reporter.cc @@ -0,0 +1,45 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/internal/issue_reporter.h" + +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel::policy_internal { + +void IssueReporter::ReportIssue(CelPolicyElementId element, Severity severity, + absl::string_view message) { + issues_.push_back({element, severity, message}); +} + +void IssueReporter::ReportOffsetIssue(CelPolicyElementId element, + cel::SourcePosition relative_position, + Severity severity, + absl::string_view message) { + issues_.push_back({element, relative_position, severity, message}); +} + +void IssueReporter::ReportError(CelPolicyElementId element, + absl::string_view message) { + ReportIssue(element, Severity::kError, message); +} + +void IssueReporter::ReportError(CelPolicyElementId element, SourcePosition pos, + absl::string_view message) { + ReportOffsetIssue(element, pos, Severity::kError, message); +} + +} // namespace cel::policy_internal diff --git a/policy/internal/issue_reporter.h b/policy/internal/issue_reporter.h new file mode 100644 index 000000000..3f88806ef --- /dev/null +++ b/policy/internal/issue_reporter.h @@ -0,0 +1,57 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel::policy_internal { + +class IssueReporter { + private: + using Severity = CelPolicyIssue::Severity; + + public: + void ReportIssue(CelPolicyElementId element, Severity severity, + absl::string_view message); + + void ReportOffsetIssue(CelPolicyElementId element, + cel::SourcePosition relative_position, + Severity severity, absl::string_view message); + + void ReportError(CelPolicyElementId element, absl::string_view message); + void ReportError(CelPolicyElementId element, SourcePosition relative_pos, + absl::string_view message); + + std::vector ReleaseIssues() { + using std::swap; + std::vector out; + swap(out, issues_); + return out; + } + const std::vector& issues() const { return issues_; } + + private: + std::vector issues_; +}; + +} // namespace cel::policy_internal + +#endif // THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ diff --git a/policy/internal/optimizer_expr_factory.cc b/policy/internal/optimizer_expr_factory.cc new file mode 100644 index 000000000..6c89ae958 --- /dev/null +++ b/policy/internal/optimizer_expr_factory.cc @@ -0,0 +1,373 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/internal/optimizer_expr_factory.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/ast.h" +#include "common/ast_rewrite.h" +#include "common/ast_traverse.h" +#include "common/ast_visitor_base.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" + +namespace cel { + +namespace { + +class MaxIdVisitor final : public AstVisitorBase { + public: + ExprId max_id() const { return max_id_; } + + void PreVisitExpr(const Expr& expr) override { + max_id_ = std::max(max_id_, expr.id()); + } + + void PostVisitExpr(const Expr&) override {} + + void PostVisitStruct(const Expr&, const StructExpr& struct_expr) override { + for (const auto& field : struct_expr.fields()) { + max_id_ = std::max(max_id_, field.id()); + } + } + + void PostVisitMap(const Expr&, const MapExpr& map_expr) override { + for (const auto& entry : map_expr.entries()) { + max_id_ = std::max(max_id_, entry.id()); + } + } + + private: + ExprId max_id_ = 0; +}; + +ExprId GetMaxId(const Expr& expr) { + MaxIdVisitor visitor; + AstTraverse(expr, visitor); + return visitor.max_id(); +} + +ExprId GetMaxId(const Ast& ast) { + ExprId max_id = GetMaxId(ast.root_expr()); + for (const auto& [id, _] : ast.source_info().positions()) { + max_id = std::max(max_id, id); + } + for (const auto& [id, expr] : ast.source_info().macro_calls()) { + max_id = std::max(max_id, id); + max_id = std::max(max_id, GetMaxId(expr)); + } + return max_id; +} + +// Replaces nested macros in a macro_calls expr with reference nodes. +// +// The macro_calls map is used for retaining the original structure of the +// parsed expression before macro expansion. When a macro appears inside another +// macro, the parser will replace the inner macro expr node with an unspecified +// expr with the inner macro's ID in the macro_calls map to save space. +class MakeMacroCallRewrite final : public AstRewriterBase { + public: + explicit MakeMacroCallRewrite(const SourceInfo& source_info) + : source_info_(source_info) {} + + bool PreVisitRewrite(Expr& expr) override { + if (source_info_.macro_calls().find(expr.id()) != + source_info_.macro_calls().end()) { + ExprId id = expr.id(); + expr.mutable_kind() = UnspecifiedExpr(); + expr.set_id(id); + return true; + } + return false; + } + + private: + const SourceInfo& source_info_; +}; + +// Updates macro_calls map entries to reflect a replaced expression in the +// main AST. +class ReplaceMacroCallRewrite final : public AstRewriterBase { + public: + ReplaceMacroCallRewrite(ExprId old_id, const Expr& replacement, + const SourceInfo& source_info) + : old_id_(old_id), replacement_(replacement), source_info_(source_info) {} + + bool PreVisitRewrite(Expr& expr) override { + if (expr.id() == old_id_) { + expr = macro_replacement(); + return true; + } + return false; + } + + Expr macro_replacement() { + if (!macro_replacement_) { + macro_replacement_.emplace(replacement_); + MakeMacroCallRewrite hole_creator(source_info_); + AstRewrite(*macro_replacement_, hole_creator); + } + return *macro_replacement_; + } + + private: + ExprId old_id_; + const Expr& replacement_; + absl::optional macro_replacement_; + const SourceInfo& source_info_; +}; + +void ReplaceSubExpr(Expr& expr, ExprId old_id, const Expr& replacement, + const SourceInfo& source_info) { + ReplaceMacroCallRewrite rewriter(old_id, replacement, source_info); + AstRewrite(expr, rewriter); +} + +class IdRewriter : public AstRewriterBase { + using CopyIdFn = absl::AnyInvocable; + + public: + explicit IdRewriter(CopyIdFn copy_id) : copy_id_(std::move(copy_id)) {} + + // No structure changes just ids. + bool PreVisitRewrite(Expr& expr) override { + expr.set_id(copy_id_(expr.id())); + if (expr.has_struct_expr()) { + for (auto& field : expr.mutable_struct_expr().mutable_fields()) { + field.set_id(copy_id_(field.id())); + } + } else if (expr.has_map_expr()) { + for (auto& entry : expr.mutable_map_expr().mutable_entries()) { + entry.set_id(copy_id_(entry.id())); + } + } + return false; + } + + private: + CopyIdFn copy_id_; +}; + +} // namespace + +OptimizerExprFactory::OptimizerExprFactory(Ast basis) + : ast_(std::move(basis)), next_id_(GetMaxId(ast_) + 1) {} + +OptimizerExprFactory::OptimizerExprFactory() : next_id_(1) {} + +Expr OptimizerExprFactory::Copy(const Expr& expr) { + Expr copied = expr; + IdRewriter rewriter([this](ExprId id) { return CopyId(id); }); + AstRewrite(copied, rewriter); + return copied; +} + +ListExprElement OptimizerExprFactory::Copy(const ListExprElement& element) { + return NewListElement(Copy(element.expr()), element.optional()); +} + +StructExprField OptimizerExprFactory::Copy(const StructExprField& field) { + auto field_id = CopyId(field.id()); + auto field_value = Copy(field.value()); + return NewStructField(field_id, field.name(), std::move(field_value), + field.optional()); +} + +MapExprEntry OptimizerExprFactory::Copy(const MapExprEntry& entry) { + auto entry_id = CopyId(entry.id()); + auto entry_key = Copy(entry.key()); + auto entry_value = Copy(entry.value()); + return NewMapEntry(entry_id, std::move(entry_key), std::move(entry_value), + entry.optional()); +} + +ExprId OptimizerExprFactory::NextId() { return next_id_++; } + +ExprId OptimizerExprFactory::CopyId(ExprId id) { + if (id == 0) { + return 0; + } + auto it = renumbers_.find(id); + if (it != renumbers_.end()) { + return it->second; + } + ExprId new_id = NextId(); + renumbers_[id] = new_id; + return new_id; +} + +SourceInfo OptimizerExprFactory::RemapSourceInfo(const SourceInfo& info, + SourcePosition offset) { + SourceInfo out; + + for (const auto& [old_id, macro_expr] : info.macro_calls()) { + if (auto it = renumbers_.find(old_id); it != renumbers_.end()) { + ExprId new_id = it->second; + out.mutable_macro_calls()[new_id] = Copy(macro_expr); + } + } + + for (const auto& [old_id, new_id] : renumbers_) { + if (auto it = info.positions().find(old_id); it != info.positions().end()) { + out.mutable_positions()[new_id] = it->second + offset; + } + } + + return out; +} + +void OptimizerExprFactory::MergeSourceInfo(const SourceInfo& info) { + auto& target_info = ast_.mutable_source_info(); + + for (const auto& [id, pos] : info.positions()) { + auto [it, inserted] = target_info.mutable_positions().insert({id, pos}); + if (!inserted) { + issues_.push_back(Issue{id, "conflicting ID in positions merge"}); + } + } + + for (const auto& [id, expr] : info.macro_calls()) { + auto [it, inserted] = target_info.mutable_macro_calls().insert({id, expr}); + if (!inserted) { + issues_.push_back(Issue{id, "conflicting ID in macro calls merge"}); + } + } + + // TODO(b/506179116): need to add some check that we aren't + // introducing incompatible tags. Not possible in the policy compiler right + // now. + for (const auto& ext : info.extensions()) { + auto& target_exts = target_info.mutable_extensions(); + if (!absl::c_linear_search(target_exts, ext)) { + target_exts.push_back(ext); + } + } +} + +void OptimizerExprFactory::RecordReplacement(ExprId id, const Expr& replacement, + bool keep_metadata) { + auto& source_info = ast_.mutable_source_info(); + if (!keep_metadata) { + source_info.mutable_positions().erase(id); + source_info.mutable_macro_calls().erase(id); + } + + for (auto& [macro_id, macro_expr] : source_info.mutable_macro_calls()) { + ReplaceSubExpr(macro_expr, id, replacement, source_info); + } +} + +Expr OptimizerExprFactory::ReportError(absl::string_view message) { + ExprId id = NextId(); + issues_.push_back(Issue{id, std::string(message)}); + return NewUnspecified(id); +} + +Expr OptimizerExprFactory::ReportErrorAt(const Expr& expr, + absl::string_view message) { + issues_.push_back(Issue{expr.id(), std::string(message)}); + return NewUnspecified(NextId()); +} + +Expr OptimizerExprFactory::ReportErrorAtCopy(const Expr& expr, + absl::string_view message) { + issues_.push_back(Issue{CopyId(expr.id()), std::string(message)}); + return NewUnspecified(NextId()); +} + +Expr OptimizerExprFactory::NewUnspecified() { return NewUnspecified(NextId()); } + +Expr OptimizerExprFactory::NewNullConst() { return NewNullConst(NextId()); } + +Expr OptimizerExprFactory::NewBoolConst(bool value) { + return NewBoolConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewIntConst(int64_t value) { + return NewIntConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewUintConst(uint64_t value) { + return NewUintConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewDoubleConst(double value) { + return NewDoubleConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewBytesConst(std::string value) { + return NewBytesConst(NextId(), std::move(value)); +} + +Expr OptimizerExprFactory::NewBytesConst(absl::string_view value) { + return NewBytesConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewBytesConst(const char* value) { + return NewBytesConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewStringConst(std::string value) { + return NewStringConst(NextId(), std::move(value)); +} + +Expr OptimizerExprFactory::NewStringConst(absl::string_view value) { + return NewStringConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewStringConst(const char* value) { + return NewStringConst(NextId(), value); +} + +absl::flat_hash_map OptimizerExprFactory::ConsumeRenumbers() { + using std::swap; + absl::flat_hash_map out; + swap(out, renumbers_); + return out; +} + +void OptimizerExprFactory::StartCopyContext() { renumbers_.clear(); } + +const std::vector& OptimizerExprFactory::issues() + const { + return issues_; +} + +const Ast& OptimizerExprFactory::ast() const { return ast_; } + +Ast& OptimizerExprFactory::mutable_ast() { return ast_; } + +absl::string_view OptimizerExprFactory::AccuVarName() { + return ExprFactory::AccuVarName(); +} + +Expr OptimizerExprFactory::NewAccuIdent() { return NewAccuIdent(NextId()); } + +ExprId OptimizerExprFactory::CopyId(const Expr& expr) { + return CopyId(expr.id()); +} + +} // namespace cel diff --git a/policy/internal/optimizer_expr_factory.h b/policy/internal/optimizer_expr_factory.h new file mode 100644 index 000000000..6f63f1485 --- /dev/null +++ b/policy/internal/optimizer_expr_factory.h @@ -0,0 +1,419 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" + +namespace cel { + +class ParserMacroExprFactory; +class TestOptimizerExprFactory; + +// `OptimizerExprFactory` is a specialization of `ExprFactory` used for AST +// optimization. It provides utilities for correcting metadata for modified +// ASTs. +class OptimizerExprFactory : protected ExprFactory { + public: + struct Issue { + ExprId location = 0; + std::string message; + }; + + explicit OptimizerExprFactory(Ast basis); + OptimizerExprFactory(); + + protected: + using ExprFactory::IsArrayLike; + using ExprFactory::IsExprLike; + using ExprFactory::IsStringLike; + + template + struct IsRValue + : std::bool_constant< + std::disjunction_v, std::is_same>> {}; + + public: + // Consume the current set of renumberings. + absl::flat_hash_map ConsumeRenumbers(); + + // Starts a new copy context. The current set of renumberings are cleared. + void StartCopyContext(); + + const std::vector& issues() const; + + // Record that a node in the working AST was replaced. This is used to correct + // metadata referencing the old ID. + void RecordReplacement(ExprId id, const Expr& replacement, + bool keep_metadata = false); + + // Makes a copy of source metadata that is remapped to new expr Ids using + // current renumberings. This is suitable for merging into the main source + // info. + SourceInfo RemapSourceInfo(const SourceInfo& info, SourcePosition offset = 0); + + // Merge a remapped SourceInfo into the current one. + void MergeSourceInfo(const SourceInfo& info); + + const Ast& ast() const; + Ast& mutable_ast(); + + absl::string_view AccuVarName(); + + ABSL_MUST_USE_RESULT Expr Copy(const Expr& expr); + + ABSL_MUST_USE_RESULT ListExprElement Copy(const ListExprElement& element); + + ABSL_MUST_USE_RESULT StructExprField Copy(const StructExprField& field); + + ABSL_MUST_USE_RESULT MapExprEntry Copy(const MapExprEntry& entry); + + ABSL_MUST_USE_RESULT Expr NewUnspecified(); + + ABSL_MUST_USE_RESULT Expr NewNullConst(); + + ABSL_MUST_USE_RESULT Expr NewBoolConst(bool value); + + ABSL_MUST_USE_RESULT Expr NewIntConst(int64_t value); + + ABSL_MUST_USE_RESULT Expr NewUintConst(uint64_t value); + + ABSL_MUST_USE_RESULT Expr NewDoubleConst(double value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(std::string value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(absl::string_view value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(const char* absl_nullable value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(std::string value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(absl::string_view value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(const char* absl_nullable value); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewIdent(Name name); + + ABSL_MUST_USE_RESULT Expr NewAccuIdent(); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewSelect(Operand operand, Field field); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewPresenceTest(Operand operand, Field field); + + template < + typename Function, typename... Args, + typename = std::enable_if_t::value>, + typename = std::enable_if_t...>>> + ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args&&... args); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args args); + + template < + typename Function, typename Target, typename... Args, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t...>>> + ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, + Args&&... args); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, + Args args); + + using ExprFactory::NewListElement; + + template ...>>> + ABSL_MUST_USE_RESULT Expr NewList(Elements&&... elements); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewList(Elements elements); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT StructExprField NewStructField(Name name, Value value, + bool optional = false); + + template ::value>, + typename = std::enable_if_t< + std::conjunction_v...>>> + ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields&&... fields); + + template < + typename Name, typename Fields, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields fields); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT MapExprEntry NewMapEntry(Key key, Value value, + bool optional = false); + + template ...>>> + ABSL_MUST_USE_RESULT Expr NewMap(Entries&&... entries); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewMap(Entries entries); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewComprehension(IterVar iter_var, + IterRange iter_range, + AccuVar accu_var, + AccuInit accu_init, + LoopCondition loop_condition, + LoopStep loop_step, Result result); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewComprehension( + IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result); + + ABSL_MUST_USE_RESULT Expr ReportError(absl::string_view message); + + // Reports an error at the id in the optimized AST. + ABSL_MUST_USE_RESULT Expr ReportErrorAt(const Expr& expr, + absl::string_view message); + // Reports an error at the mapped id of the copy of expr in the optimized AST. + ABSL_MUST_USE_RESULT Expr ReportErrorAtCopy(const Expr& expr, + absl::string_view message); + + protected: + ABSL_MUST_USE_RESULT ExprId NextId(); + + ABSL_MUST_USE_RESULT ExprId CopyId(ExprId id); + + ABSL_MUST_USE_RESULT ExprId CopyId(const Expr& expr); + + using ExprFactory::AccuVarName; + using ExprFactory::NewAccuIdent; + using ExprFactory::NewBoolConst; + using ExprFactory::NewBytesConst; + using ExprFactory::NewCall; + using ExprFactory::NewComprehension; + using ExprFactory::NewConst; + using ExprFactory::NewDoubleConst; + using ExprFactory::NewIdent; + using ExprFactory::NewIntConst; + using ExprFactory::NewList; + using ExprFactory::NewMap; + using ExprFactory::NewMapEntry; + using ExprFactory::NewMemberCall; + using ExprFactory::NewNullConst; + using ExprFactory::NewPresenceTest; + using ExprFactory::NewSelect; + using ExprFactory::NewStringConst; + using ExprFactory::NewStruct; + using ExprFactory::NewStructField; + using ExprFactory::NewUintConst; + using ExprFactory::NewUnspecified; + + private: + Ast ast_; + absl::flat_hash_map renumbers_; + std::vector issues_; + + ExprId next_id_ = 1; +}; + +// Implementation details. + +template +Expr OptimizerExprFactory::NewIdent(Name name) { + return NewIdent(NextId(), std::move(name)); +} + +template +Expr OptimizerExprFactory::NewSelect(Operand operand, Field field) { + return NewSelect(NextId(), std::move(operand), std::move(field)); +} + +template +Expr OptimizerExprFactory::NewPresenceTest(Operand operand, Field field) { + return NewPresenceTest(NextId(), std::move(operand), std::move(field)); +} + +template +Expr OptimizerExprFactory::NewCall(Function function, Args&&... args) { + std::vector array; + array.reserve(sizeof...(Args)); + (array.push_back(std::forward(args)), ...); + return NewCall(NextId(), std::move(function), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewCall(Function function, Args args) { + return NewCall(NextId(), std::move(function), std::move(args)); +} + +template +Expr OptimizerExprFactory::NewMemberCall(Function function, Target target, + Args&&... args) { + std::vector array; + array.reserve(sizeof...(Args)); + (array.push_back(std::forward(args)), ...); + return NewMemberCall(NextId(), std::move(function), std::move(target), + std::move(array)); +} + +template +Expr OptimizerExprFactory::NewMemberCall(Function function, Target target, + Args args) { + return NewMemberCall(NextId(), std::move(function), std::move(target), + std::move(args)); +} + +template +Expr OptimizerExprFactory::NewList(Elements&&... elements) { + std::vector array; + array.reserve(sizeof...(Elements)); + (array.push_back(std::forward(elements)), ...); + return NewList(NextId(), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewList(Elements elements) { + return NewList(NextId(), std::move(elements)); +} + +template +StructExprField OptimizerExprFactory::NewStructField(Name name, Value value, + bool optional) { + return NewStructField(NextId(), std::move(name), std::move(value), optional); +} + +template +Expr OptimizerExprFactory::NewStruct(Name name, Fields&&... fields) { + std::vector array; + array.reserve(sizeof...(Fields)); + (array.push_back(std::forward(fields)), ...); + return NewStruct(NextId(), std::move(name), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewStruct(Name name, Fields fields) { + return NewStruct(NextId(), std::move(name), std::move(fields)); +} + +template +MapExprEntry OptimizerExprFactory::NewMapEntry(Key key, Value value, + bool optional) { + return NewMapEntry(NextId(), std::move(key), std::move(value), optional); +} + +template +Expr OptimizerExprFactory::NewMap(Entries&&... entries) { + std::vector array; + array.reserve(sizeof...(Entries)); + (array.push_back(std::forward(entries)), ...); + return NewMap(NextId(), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewMap(Entries entries) { + return NewMap(NextId(), std::move(entries)); +} + +template +Expr OptimizerExprFactory::NewComprehension(IterVar iter_var, + IterRange iter_range, + AccuVar accu_var, + AccuInit accu_init, + LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), std::move(iter_range), + std::move(accu_var), std::move(accu_init), + std::move(loop_condition), std::move(loop_step), + std::move(result)); +} + +template +Expr OptimizerExprFactory::NewComprehension( + IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), std::move(iter_var2), + std::move(iter_range), std::move(accu_var), + std::move(accu_init), std::move(loop_condition), + std::move(loop_step), std::move(result)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ diff --git a/policy/internal/optimizer_expr_factory_test.cc b/policy/internal/optimizer_expr_factory_test.cc new file mode 100644 index 000000000..1b14b5628 --- /dev/null +++ b/policy/internal/optimizer_expr_factory_test.cc @@ -0,0 +1,570 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/internal/optimizer_expr_factory.h" + +#include +#include +#include + +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/ast_rewrite.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "testutil/expr_printer.h" +#include "tools/cel_unparser.h" + +namespace cel { + +using ::testing::SizeIs; + +// Expose protected members of OptimizerExprFactory for use in tests +// +// These allow setting explicit IDs which is not safe for the optimizing +// factory. +class TestOptimizerExprFactory final : public OptimizerExprFactory { + public: + using OptimizerExprFactory::OptimizerExprFactory; + + using OptimizerExprFactory::NewBoolConst; + using OptimizerExprFactory::NewCall; + using OptimizerExprFactory::NewComprehension; + using OptimizerExprFactory::NewIdent; + using OptimizerExprFactory::NewList; + using OptimizerExprFactory::NewListElement; + using OptimizerExprFactory::NewMap; + using OptimizerExprFactory::NewMapEntry; + using OptimizerExprFactory::NewMemberCall; + using OptimizerExprFactory::NewSelect; + using OptimizerExprFactory::NewStruct; + using OptimizerExprFactory::NewStructField; + using OptimizerExprFactory::NewUnspecified; + using OptimizerExprFactory::NextId; +}; + +namespace { + +class ReplaceExprRewriter final : public AstRewriterBase { + public: + ReplaceExprRewriter(ExprId old_id, const Expr& replacement) + : old_id_(old_id), replacement_(replacement) {} + + bool PreVisitRewrite(Expr& expr) override { + if (expr.id() == old_id_) { + expr = replacement_; + return true; + } + return false; + } + + private: + ExprId old_id_; + const Expr& replacement_; +}; + +void ReplaceExprInTree(Expr& expr, ExprId old_id, const Expr& replacement) { + ReplaceExprRewriter rewriter(old_id, replacement); + AstRewrite(expr, rewriter); +} + +absl::StatusOr> CreateTestCompiler() { + CompilerOptions opts; + opts.parser_options.add_macro_calls = true; + CEL_ASSIGN_OR_RETURN( + auto builder, cel::NewCompilerBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("to_replace", cel::DynType()))); + return builder->Build(); +} + +TEST(OptimizerExprFactory, CopyUnspecified) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewUnspecified()), factory.NewUnspecified(2)); +} + +TEST(OptimizerExprFactory, CopyIdent) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewIdent("foo")), factory.NewIdent(2, "foo")); +} + +TEST(OptimizerExprFactory, CopyConst) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewBoolConst(true)), + factory.NewBoolConst(2, true)); +} + +TEST(OptimizerExprFactory, CopySelect) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewSelect(factory.NewIdent("foo"), "bar")), + factory.NewSelect(3, factory.NewIdent(4, "foo"), "bar")); +} + +TEST(OptimizerExprFactory, CopyCall) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_args; + copied_args.reserve(1); + copied_args.push_back(factory.NewIdent(6, "baz")); + EXPECT_EQ(factory.Copy(factory.NewMemberCall("bar", factory.NewIdent("foo"), + factory.NewIdent("baz"))), + factory.NewMemberCall(4, "bar", factory.NewIdent(5, "foo"), + absl::MakeSpan(copied_args))); +} + +TEST(OptimizerExprFactory, CopyList) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_elements; + copied_elements.reserve(1); + copied_elements.push_back(factory.NewListElement(factory.NewIdent(4, "foo"))); + EXPECT_EQ(factory.Copy(factory.NewList( + factory.NewListElement(factory.NewIdent("foo")))), + factory.NewList(3, absl::MakeSpan(copied_elements))); +} + +TEST(OptimizerExprFactory, CopyStruct) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_fields; + copied_fields.reserve(1); + copied_fields.push_back( + factory.NewStructField(5, "bar", factory.NewIdent(6, "baz"))); + EXPECT_EQ(factory.Copy(factory.NewStruct( + "foo", factory.NewStructField("bar", factory.NewIdent("baz")))), + factory.NewStruct(4, "foo", absl::MakeSpan(copied_fields))); +} + +TEST(OptimizerExprFactory, CopyMap) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_entries; + copied_entries.reserve(1); + copied_entries.push_back(factory.NewMapEntry(6, factory.NewIdent(7, "bar"), + factory.NewIdent(8, "baz"))); + EXPECT_EQ(factory.Copy(factory.NewMap(factory.NewMapEntry( + factory.NewIdent("bar"), factory.NewIdent("baz")))), + factory.NewMap(5, absl::MakeSpan(copied_entries))); +} + +TEST(OptimizerExprFactory, CopyComprehension) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ( + factory.Copy(factory.NewComprehension( + "foo", factory.NewList(), "bar", factory.NewBoolConst(true), + factory.NewIdent("baz"), factory.NewIdent("foo"), + factory.NewIdent("bar"))), + factory.NewComprehension( + 7, "foo", factory.NewList(8, std::vector()), "bar", + factory.NewBoolConst(9, true), factory.NewIdent(10, "baz"), + factory.NewIdent(11, "foo"), factory.NewIdent(12, "bar"))); +} + +TEST(OptimizerExprFactory, RemapSourceInfo) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + SourceInfo info; + info.mutable_positions()[1] = 42; // old ID 1 has position 42 + + SourceInfo remapped = factory.RemapSourceInfo(info, 10); + + // remapped should have ID 2 mapped to position 42 + 10 = 52 + auto it = remapped.positions().find(2); + ASSERT_NE(it, remapped.positions().end()); + EXPECT_EQ(it->second, 52); +} + +TEST(OptimizerExprFactory, RemapSourceInfoWithMacroCalls) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + SourceInfo info; + // old ID 1 has macro call with ID 3 + info.mutable_macro_calls()[1] = factory.NewIdent("bar"); + + SourceInfo remapped = factory.RemapSourceInfo(info, 10); + + // remapped should have ID 2 mapped to the copied macro call + // since "bar" has ID 3, Copy(bar) should map ID 3 to ID 4 + + auto it = remapped.macro_calls().find(2); + ASSERT_NE(it, remapped.macro_calls().end()); + + // The macro call should be an Ident with new ID 4 + EXPECT_EQ(it->second.id(), 4); + EXPECT_TRUE(it->second.has_ident_expr()); + EXPECT_EQ(it->second.ident_expr().name(), "bar"); +} + +TEST(OptimizerExprFactory, ReportError) { + TestOptimizerExprFactory factory{Ast()}; + Expr err_expr = factory.ReportError("something went wrong"); + + // err_expr should be unspecified with ID 1 + EXPECT_EQ(err_expr.id(), 1); + EXPECT_EQ(err_expr.kind_case(), ExprKindCase::kUnspecifiedExpr); + + // issues_ should have 1 entry with ID 1 and correct message + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "something went wrong"); +} + +TEST(OptimizerExprFactory, ReportErrorAt) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + Expr err_expr = factory.ReportErrorAtCopy(orig, "error on foo"); + + // err_expr should be unspecified with ID 3 (NextId) + EXPECT_EQ(err_expr.id(), 3); + EXPECT_EQ(err_expr.kind_case(), ExprKindCase::kUnspecifiedExpr); + + // issues_ should have 1 entry with mapped ID 2 and correct message + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 2); + EXPECT_EQ(factory.issues()[0].message, "error on foo"); +} + +TEST(OptimizerExprFactory, MergeSourceInfo) { + // Create a base AST with some source info + SourceInfo base_info; + base_info.set_syntax_version("cel1"); + base_info.set_location("test.cel"); + base_info.mutable_positions()[1] = 10; + + Ast base_ast(Expr(), std::move(base_info)); + + TestOptimizerExprFactory factory{std::move(base_ast)}; + + // Create a new source info to merge + SourceInfo new_info; + new_info.mutable_positions()[2] = 20; + + factory.MergeSourceInfo(new_info); + + // The merged source info should have both positions + const auto& merged_info = factory.ast().source_info(); + EXPECT_EQ(merged_info.syntax_version(), "cel1"); + EXPECT_EQ(merged_info.location(), "test.cel"); + + auto it1 = merged_info.positions().find(1); + ASSERT_NE(it1, merged_info.positions().end()); + EXPECT_EQ(it1->second, 10); + + auto it2 = merged_info.positions().find(2); + ASSERT_NE(it2, merged_info.positions().end()); + EXPECT_EQ(it2->second, 20); +} + +TEST(OptimizerExprFactory, MergeSourceInfoConflict) { + SourceInfo base_info; + base_info.mutable_positions()[1] = 10; + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory factory{std::move(base_ast)}; + + SourceInfo new_info; + new_info.mutable_positions()[1] = 20; // conflicting ID 1 + + factory.MergeSourceInfo(new_info); + + // Should report an error for the conflict + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "conflicting ID in positions merge"); +} + +TEST(OptimizerExprFactory, RecordReplacement) { + SourceInfo base_info; + base_info.mutable_positions()[1] = 10; + base_info.mutable_positions()[2] = 20; + + TestOptimizerExprFactory factory{Ast()}; + + // macro_calls[1] maps ID 1 to macro call "bar(foo)" (where "foo" has ID 1) + base_info.mutable_macro_calls()[1] = + factory.NewCall("bar", factory.NewIdent(1, "foo")); + + // macro_calls[2] maps ID 2 to macro call "baz(foo)" (where "foo" has ID 1) + base_info.mutable_macro_calls()[2] = + factory.NewCall("baz", factory.NewIdent(1, "foo")); + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory optimizer{std::move(base_ast)}; + + // Record the replacement of ID 1 by a new Ident "replacement" with ID 3 + optimizer.RecordReplacement(1, factory.NewIdent(3, "replacement")); + + const auto& result_info = optimizer.ast().source_info(); + + // 1. ID 1 should be erased from positions + EXPECT_EQ(result_info.positions().find(1), result_info.positions().end()); + EXPECT_NE(result_info.positions().find(2), result_info.positions().end()); + + // 2. ID 1 should be erased from macro_calls keys + EXPECT_EQ(result_info.macro_calls().find(1), result_info.macro_calls().end()); + + // 3. macro_calls[2] should still exist, but its argument referencing ID 1 + // should be replaced with the Ident "replacement" with ID 3 inline + auto it = result_info.macro_calls().find(2); + ASSERT_NE(it, result_info.macro_calls().end()); + + const Expr& macro_expr = it->second; + ASSERT_TRUE(macro_expr.has_call_expr()); + ASSERT_EQ(macro_expr.call_expr().args().size(), 1); + + const Expr& arg = macro_expr.call_expr().args()[0]; + EXPECT_EQ(arg.id(), 3); + EXPECT_TRUE(arg.has_ident_expr()); + EXPECT_EQ(arg.ident_expr().name(), "replacement"); +} + +class IdAdorner : public cel::test::ExpressionAdorner { + public: + std::string Adorn(const cel::Expr& e) const override { + return absl::StrCat("#", e.id()); + } + + std::string AdornStructField(const cel::StructExprField& e) const override { + return absl::StrCat("#", e.id()); + } + + std::string AdornMapEntry(const cel::MapExprEntry& e) const override { + return absl::StrCat("#", e.id()); + } +}; + +TEST(OptimizerExprFactory, UnparseCopiedMacroCall) { + // Arrange: create an template expression and one to inline. + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto basis_result, + compiler->Compile("[1].map(x, x + to_replace)")); + ASSERT_TRUE(basis_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto basis_ast, basis_result.ReleaseAst()); + + ASSERT_OK_AND_ASSIGN(auto copy_result, + compiler->Compile("[1].filter(x, x > 2).size()")); + ASSERT_TRUE(copy_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto copy_ast, copy_result.ReleaseAst()); + + // Locate the "to_replace" IdentExpr node in reference_map + ExprId to_replace_id = 0; + for (const auto& [id, ref] : basis_ast->reference_map()) { + if (ref.name() == "to_replace") { + to_replace_id = id; + break; + } + } + ASSERT_NE(to_replace_id, 0); + + // Act: implement the optimization. + TestOptimizerExprFactory factory{std::move(*basis_ast)}; + Expr copied_expr = factory.Copy(copy_ast->root_expr()); + SourceInfo remapped_info = factory.RemapSourceInfo(copy_ast->source_info()); + factory.MergeSourceInfo(remapped_info); + + ReplaceExprInTree(factory.mutable_ast().mutable_root_expr(), to_replace_id, + copied_expr); + factory.RecordReplacement(to_replace_id, copied_expr); + + // Test AST structure. + EXPECT_EQ( + cel::test::ExprPrinter(IdAdorner()).Print(factory.ast().root_expr()), + R"(__comprehension__( + // Variable + x, + // Target + [ + 1#2 + ]#1, + // Accumulator + @result, + // Init + []#8, + // LoopCondition + true#9, + // LoopStep + _+_( + @result#10, + [ + _+_( + x#5, + __comprehension__( + // Variable + x, + // Target + [ + 1#18 + ]#17, + // Accumulator + @result, + // Init + []#19, + // LoopCondition + true#20, + // LoopStep + _?_:_( + _>_( + x#23, + 2#24 + )#22, + _+_( + @result#26, + [ + x#28 + ]#27 + )#25, + @result#29 + )#21, + // Result + @result#30)#16.size()#15 + )#6 + ]#11 + )#12, + // Result + @result#13)#14)"); + + // Check that the structure is compatible with unparser. + cel::expr::ParsedExpr optimized_parsed; + auto status = AstToParsedExpr(factory.ast(), &optimized_parsed); + ASSERT_THAT(status, absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(std::string unparsed, + google::api::expr::Unparse(optimized_parsed)); + + EXPECT_EQ(unparsed, "[1].map(x, x + [1].filter(x, x > 2).size())"); + + const CallExpr& call_expr = factory.mutable_ast() + .mutable_source_info() + .mutable_macro_calls()[14] + .mutable_call_expr(); + ASSERT_THAT(call_expr.args(), SizeIs(2)); + ASSERT_THAT(call_expr.args()[1].call_expr().args(), SizeIs(2)); + EXPECT_EQ(call_expr.args()[1].call_expr().args()[1].id(), 15); + + EXPECT_EQ(call_expr.args()[1].call_expr().args()[1].call_expr().target().id(), + 16); + EXPECT_EQ(call_expr.args()[1] + .call_expr() + .args()[1] + .call_expr() + .target() + .kind_case(), + ExprKindCase::kUnspecifiedExpr); +} + +TEST(OptimizerExprFactory, CopyMultipleAstsWithConsumeRenumbers) { + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto ast1_result, compiler->Compile("[1]")); + ASSERT_TRUE(ast1_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast1, ast1_result.ReleaseAst()); + + ASSERT_OK_AND_ASSIGN(auto ast2_result, compiler->Compile("2")); + ASSERT_TRUE(ast2_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast2, ast2_result.ReleaseAst()); + + TestOptimizerExprFactory factory{Ast()}; + + Expr copied1 = factory.Copy(ast1->root_expr()); + auto renumbers1 = factory.ConsumeRenumbers(); + + Expr copied2 = factory.Copy(ast2->root_expr()); + auto renumbers2 = factory.ConsumeRenumbers(); + + EXPECT_EQ(renumbers1.size(), 2); + EXPECT_EQ(renumbers2.size(), 1); + + EXPECT_NE(copied1.id(), copied2.id()); + EXPECT_GT(copied2.id(), copied1.id()); +} + +TEST(OptimizerExprFactory, MaxIdVisitorExprKinds) { + ASSERT_OK_AND_ASSIGN(auto compiler, CreateTestCompiler()); + + // Expression that covers all the kinds. + ASSERT_OK_AND_ASSIGN(auto source, NewSource(R"cel( + Struct{field : 1} || + {'key' : 'value'} || [1].exists(x, x) || foo(bar))cel")); + ASSERT_OK_AND_ASSIGN(auto ast, compiler->GetParser().Parse(*source)); + + TestOptimizerExprFactory factory{std::move(*ast)}; + + EXPECT_EQ(factory.NextId(), 26); +} + +TEST(OptimizerExprFactory, CopyListElement) { + TestOptimizerExprFactory factory{Ast()}; + ListExprElement orig = factory.NewListElement(factory.NewIdent("foo")); + ListExprElement copied = factory.Copy(orig); + EXPECT_EQ(copied.expr(), factory.NewIdent(2, "foo")); +} + +TEST(OptimizerExprFactory, CopyStructField) { + TestOptimizerExprFactory factory{Ast()}; + StructExprField orig = factory.NewStructField("bar", factory.NewIdent("baz")); + StructExprField copied = factory.Copy(orig); + EXPECT_EQ(copied.id(), 3); + EXPECT_EQ(copied.name(), "bar"); + EXPECT_EQ(copied.value(), factory.NewIdent(4, "baz")); +} + +TEST(OptimizerExprFactory, CopyMapEntry) { + TestOptimizerExprFactory factory{Ast()}; + MapExprEntry orig = + factory.NewMapEntry(factory.NewIdent("bar"), factory.NewIdent("baz")); + MapExprEntry copied = factory.Copy(orig); + EXPECT_EQ(copied.id(), 4); + EXPECT_EQ(copied.key(), factory.NewIdent(5, "bar")); + EXPECT_EQ(copied.value(), factory.NewIdent(6, "baz")); +} + +TEST(OptimizerExprFactory, MergeSourceInfoMacroConflict) { + SourceInfo base_info; + base_info.mutable_macro_calls()[1] = Expr(); + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory factory{std::move(base_ast)}; + + SourceInfo new_info; + new_info.mutable_macro_calls()[1] = Expr(); + + factory.MergeSourceInfo(new_info); + + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "conflicting ID in macro calls merge"); +} + +} // namespace +} // namespace cel diff --git a/policy/test_custom_yaml_policy_parser.cc b/policy/test_custom_yaml_policy_parser.cc new file mode 100644 index 000000000..faced6952 --- /dev/null +++ b/policy/test_custom_yaml_policy_parser.cc @@ -0,0 +1,188 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_context.h" +#include "policy/cel_policy_parser.h" +#include "policy/yaml_policy_parser.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel::internal { + +// TestCustomYamlPolicyParser is used to support unit tests for custom tags +// and custom policy structures. It demonstrates the versatility of the +// cel::YamlPolicyParser framework API by implementing custom tag and block +// parsing without needing to modify the core parser. +class TestCustomYamlPolicyParser : public cel::YamlPolicyParser { + absl::StatusOr ParsePolicyTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node) const override { + if (tag_name.value() == "name" || tag_name.value() == "description" || + tag_name.value() == "imports") { + return cel::YamlPolicyParser::ParsePolicyTag(ctx, tag_name, node); + } + if (tag_name.value() == "purpose") { + std::optional purpose = + GetValueString(ctx, node, "Policy purpose is not a string"); + if (purpose.has_value()) { + ctx.policy().mutable_metadata()["purpose"] = *purpose; + } + return true; + } + if (tag_name.value() == "version") { + std::optional version = + GetValueString(ctx, node, "Policy version is not a string"); + if (!version.has_value()) { + return true; + } + int version_int; + if (!absl::SimpleAtoi(version->value(), &version_int)) { + ctx.ReportError(version->id(), + absl::StrCat("Policy version is not an integer: ", + version->value())); + return true; + } + ctx.policy().mutable_metadata()["version"] = version_int; + return true; + } + + if (tag_name.value() == "conditions") { + if (!node.IsSequence()) { + ctx.ReportError(tag_name.id(), "Policy 'conditions' is not a sequence"); + return true; + } + for (const YAML::Node& condition : node) { + // Track the number of existing matches before parsing. When ParseMatch + // evaluates an 'else' block, it recursively triggers parsing and adds + // internal inner matches directly to the rule's match vector. + // Inserting the outer match at begin() + size_before ensures that the + // primary outer 'if' condition is always evaluated before its nested + // 'else' fallbacks. + // + // Example: + // if: x > 0 + // then: "positive" + // else: "negative" + // + // The inner "negative" match is parsed and appended to rule.matches() + // by the inner recursive call, before the outer "x > 0" match finishes. + // Inserting at size_before places the "x > 0" match ahead of the inner + // one. + size_t size_before = ctx.policy().rule().matches().size(); + CEL_ASSIGN_OR_RETURN(Match match, + cel::YamlPolicyParser::ParseMatch( + ctx, condition, ctx.policy().mutable_rule())); + ctx.policy().mutable_rule().mutable_matches().insert( + ctx.policy().mutable_rule().mutable_matches().begin() + size_before, + std::move(match)); + } + + return true; + } + return false; + } + + absl::Status ParseThenBlock(CelPolicyParseContext& ctx, + const YAML::Node& value_node, + Match& match) const { + if (value_node.IsScalar()) { + std::optional val = GetValueString( + ctx, value_node, "Policy condition 'then' is not a string"); + if (val.has_value()) { + OutputBlock output; + output.set_output(*val); + match.set_result(output); + } + } else if (value_node.IsMap()) { + auto nested_rule = std::make_unique(); + CEL_ASSIGN_OR_RETURN( + Match nested_match, + cel::YamlPolicyParser::ParseMatch(ctx, value_node, *nested_rule)); + nested_rule->mutable_matches().insert( + nested_rule->mutable_matches().begin(), std::move(nested_match)); + match.set_result(std::move(nested_rule)); + } else { + ctx.ReportError(CollectMetadata(ctx, value_node), + "Bad syntax in 'if/then' block"); + } + return absl::OkStatus(); + } + + absl::Status ParseElseBlock(CelPolicyParseContext& ctx, + const YAML::Node& value_node, Rule& rule) const { + if (value_node.IsScalar()) { + std::optional val = GetValueString( + ctx, value_node, "Policy condition 'else' is not a string"); + if (val.has_value()) { + Match else_match; + else_match.set_id(CollectMetadata(ctx, value_node)); + OutputBlock output; + output.set_output(*val); + else_match.set_result(output); + rule.mutable_matches().push_back(std::move(else_match)); + } + } else if (value_node.IsMap()) { + size_t size_before = rule.matches().size(); + CEL_ASSIGN_OR_RETURN(Match match, cel::YamlPolicyParser::ParseMatch( + ctx, value_node, rule)); + rule.mutable_matches().insert( + rule.mutable_matches().begin() + size_before, std::move(match)); + } else { + ctx.ReportError(CollectMetadata(ctx, value_node), + "Bad syntax in 'if/then' block"); + } + return absl::OkStatus(); + } + + absl::StatusOr ParseMatchTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, Match& match, + Rule& rule) const override { + if (tag_name.value() == "if") { + std::optional condition = + GetValueString(ctx, node, "Policy 'if' condition is not a string"); + if (condition.has_value()) { + match.set_condition(*condition); + } + return true; + } + if (tag_name.value() == "then") { + CEL_RETURN_IF_ERROR(ParseThenBlock(ctx, node, match)); + return true; + } + if (tag_name.value() == "else") { + CEL_RETURN_IF_ERROR(ParseElseBlock(ctx, node, rule)); + return true; + } + return false; + } +}; + +const CelPolicyParser& GetTestCustomYamlPolicyParser() { + static const auto* const parser = new TestCustomYamlPolicyParser(); + return *parser; +} + +} // namespace cel::internal diff --git a/policy/test_util.cc b/policy/test_util.cc new file mode 100644 index 000000000..9fe1e43d1 --- /dev/null +++ b/policy/test_util.cc @@ -0,0 +1,221 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +#include "policy/test_util.h" + +#include +#include +#include +#include + +#include "cel/expr/eval.pb.h" +#include "cel/expr/value.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "internal/status_macros.h" +#include "yaml-cpp/yaml.h" + +namespace cel::test { + +namespace { + +absl::Status YamlToExprValue(const YAML::Node& node, + cel::expr::Value* proto) { + if (node.IsNull()) { + proto->set_null_value(google::protobuf::NULL_VALUE); + return absl::OkStatus(); + } + if (node.IsScalar()) { + // Try bool + try { + proto->set_bool_value(node.as()); + return absl::OkStatus(); + } catch (...) { + } + // Try int64 + try { + int64_t val; + if (YAML::convert::decode(node, val)) { + proto->set_int64_value(val); + return absl::OkStatus(); + } + } catch (...) { + } + // Try double + try { + double val; + if (YAML::convert::decode(node, val)) { + proto->set_double_value(val); + return absl::OkStatus(); + } + } catch (...) { + } + // Fallback to string + proto->set_string_value(node.as()); + return absl::OkStatus(); + } + if (node.IsSequence()) { + auto* list = proto->mutable_list_value(); + for (const auto& elem : node) { + CEL_RETURN_IF_ERROR(YamlToExprValue(elem, list->add_values())); + } + return absl::OkStatus(); + } + if (node.IsMap()) { + auto* map_val = proto->mutable_map_value(); + for (auto it = node.begin(); it != node.end(); ++it) { + auto* entry = map_val->add_entries(); + CEL_RETURN_IF_ERROR(YamlToExprValue(it->first, entry->mutable_key())); + CEL_RETURN_IF_ERROR(YamlToExprValue(it->second, entry->mutable_value())); + } + return absl::OkStatus(); + } + return absl::InvalidArgumentError("Unknown YAML node type"); +} + +absl::Status ParseInputValue( + const YAML::Node& node, + cel::expr::conformance::test::InputValue* input_val) { + if (node.IsMap() && node["expr"].IsDefined()) { + input_val->set_expr(node["expr"].as()); + return absl::OkStatus(); + } + if (node.IsMap() && node["value"].IsDefined()) { + return YamlToExprValue(node["value"], input_val->mutable_value()); + } + return YamlToExprValue(node, input_val->mutable_value()); +} + +absl::Status ParseTestOutput(const YAML::Node& node, + cel::expr::conformance::test::TestOutput* output) { + if (!node.IsDefined()) { + return absl::InvalidArgumentError("Missing output node"); + } + if (node.IsMap()) { + if (node["expr"].IsDefined()) { + output->set_result_expr(node["expr"].as()); + return absl::OkStatus(); + } + if (node["value"].IsDefined()) { + return YamlToExprValue(node["value"], output->mutable_result_value()); + } + if (node["error"].IsDefined()) { + auto* eval_error = output->mutable_eval_error(); + eval_error->add_errors()->set_message(node["error"].as()); + return absl::OkStatus(); + } + if (node["error_set"].IsDefined()) { + auto* eval_error = output->mutable_eval_error(); + for (const auto& err : node["error_set"]) { + eval_error->add_errors()->set_message(err.as()); + } + return absl::OkStatus(); + } + if (node["unknown"].IsDefined()) { + auto* unknown = output->mutable_unknown(); + for (const auto& expr_id_node : node["unknown"]) { + unknown->add_exprs(expr_id_node.as()); + } + return absl::OkStatus(); + } + } + return YamlToExprValue(node, output->mutable_result_value()); +} + +absl::StatusOr +ParsePolicyTestSuiteYamlImpl(absl::string_view yaml_content) { + YAML::Node tests_node; + try { + tests_node = YAML::Load(std::string(yaml_content)); + } catch (const std::exception& e) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse YAML: ", e.what())); + } + + cel::expr::conformance::test::TestSuite test_suite; + if (tests_node["description"].IsDefined()) { + test_suite.set_description(tests_node["description"].as()); + } + + YAML::Node sections = tests_node["sections"]; + if (!sections.IsDefined()) { + sections = tests_node["section"]; // support singular format + } + if (!sections.IsDefined()) { + return absl::InvalidArgumentError( + "Missing 'sections' or 'section' in tests YAML"); + } + + for (const auto& section_node : sections) { + auto* section = test_suite.add_sections(); + if (section_node["name"].IsDefined()) { + section->set_name(section_node["name"].as()); + } + if (section_node["description"].IsDefined()) { + section->set_description(section_node["description"].as()); + } + + YAML::Node tests = section_node["tests"]; + if (!tests.IsDefined()) { + tests = section_node["test"]; // support singular format + } + if (!tests.IsDefined()) { + continue; + } + + for (const auto& test_node : tests) { + auto* test_case = section->add_tests(); + if (test_node["name"].IsDefined()) { + test_case->set_name(test_node["name"].as()); + } + if (test_node["description"].IsDefined()) { + test_case->set_description(test_node["description"].as()); + } + if (test_node["context_expr"].IsDefined()) { + test_case->mutable_input_context()->set_context_expr( + test_node["context_expr"].as()); + } + + YAML::Node input_node = test_node["input"]; + if (input_node.IsDefined() && input_node.IsMap()) { + auto* input_map = test_case->mutable_input(); + for (auto it = input_node.begin(); it != input_node.end(); ++it) { + std::string var_name = it->first.as(); + cel::expr::conformance::test::InputValue input_val; + CEL_RETURN_IF_ERROR(ParseInputValue(it->second, &input_val)); + (*input_map)[var_name] = std::move(input_val); + } + } + + YAML::Node output_node = test_node["output"]; + if (output_node.IsDefined()) { + CEL_RETURN_IF_ERROR( + ParseTestOutput(output_node, test_case->mutable_output())); + } + } + } + + return test_suite; +} + +} // namespace + +absl::StatusOr +ParsePolicyTestSuiteYaml(absl::string_view yaml_content) { + try { + return ParsePolicyTestSuiteYamlImpl(yaml_content); + } catch (...) { + return absl::InvalidArgumentError("Failed to parse YAML"); + } +} + +} // namespace cel::test diff --git a/policy/test_util.h b/policy/test_util.h new file mode 100644 index 000000000..5fe306050 --- /dev/null +++ b/policy/test_util.h @@ -0,0 +1,33 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "cel/expr/conformance/test/suite.pb.h" + +namespace cel::test { + +// Parses a YAML content representing a policy test suite (tests.yaml) +// and adapts it to the cel.expr.conformance.test.TestSuite protobuf message. +// +// TODO(uncreated-issue/92): Move to the testrunner library. +absl::StatusOr +ParsePolicyTestSuiteYaml(absl::string_view yaml_content); + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ diff --git a/policy/testdata/BUILD b/policy/testdata/BUILD new file mode 100644 index 000000000..10a26fa0b --- /dev/null +++ b/policy/testdata/BUILD @@ -0,0 +1,19 @@ +package( + default_testonly = True, + default_visibility = ["//visibility:public"], +) + +filegroup( + name = "policy_testdata", + srcs = glob([ + "*.yaml", + "*.baseline", + ]), +) + +exports_files( + srcs = glob([ + "*.yaml", + "*.baseline", + ]), +) diff --git a/policy/testdata/cel_policy.yaml b/policy/testdata/cel_policy.yaml new file mode 100644 index 000000000..010ad8855 --- /dev/null +++ b/policy/testdata/cel_policy.yaml @@ -0,0 +1,42 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Environment: +# spec: TestAllTypes +name: cel_policy +description: A test policy for CEL +display_name: Cel Policy +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +- name: cel.expr.conformance.proto3.TestAllTypes.NestedEnum +rule: + id: test_rule + description: test rule description + variables: + - name: test_var + expression: > + TestAllTypes{single_int64: 10}.single_int64 + match: + - condition: > + spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + output: | + "invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + explanation: | + "invalid spec, spec is greater than 10" + - condition: > + spec.standalone_enum == NestedEnum.BAR + output: | + "invalid spec, reference to BAR is not allowed" + - condition: spec.single_int64 == variables.test_var + output: '"invalid spec: exactly matches test_var"' + explanation: '"the spec cannot have single_int64 set to a known bad value"' \ No newline at end of file diff --git a/policy/testdata/cel_policy_parser.baseline b/policy/testdata/cel_policy_parser.baseline new file mode 100644 index 000000000..7a6678bfe --- /dev/null +++ b/policy/testdata/cel_policy_parser.baseline @@ -0,0 +1,89 @@ +POLICY SOURCE: cel_policy.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2026 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # Environment: + # spec: TestAllTypes + #0> name: #1> cel_policy + #2> description: #3> A test policy for CEL + #4> display_name: #5> Cel Policy + #6> imports: + - #7> name: #8> cel.expr.conformance.proto3.TestAllTypes + - #9> name: #10> cel.expr.conformance.proto3.TestAllTypes.NestedEnum + #11> rule: + #13> #12> id: #14> test_rule + #15> description: #16> test rule description + #17> variables: + - #18> name: #19> test_var + #20> expression: #21> > + TestAllTypes{single_int64: 10}.single_int64 + #22> match: + - #24> #23> condition: #25> > + spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + #26> output: #27> | + "invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + #28> explanation: #29> | + "invalid spec, spec is greater than 10" + - #31> #30> condition: #32> > + spec.standalone_enum == NestedEnum.BAR + #33> output: #34> | + "invalid spec, reference to BAR is not allowed" + - #36> #35> condition: #37> spec.single_int64 == variables.test_var + #38> output: #39> '"invalid spec: exactly matches test_var"' + #40> explanation: #41> '"the spec cannot have single_int64 set to a known bad value"' + =========================================================== + name: #1> "cel_policy" + description: #3> "A test policy for CEL" + display_name: #5> "Cel Policy" + imports: + #7> name: #8> "cel.expr.conformance.proto3.TestAllTypes" + #9> name: #10> "cel.expr.conformance.proto3.TestAllTypes.NestedEnum" + #12> rule: { + rule_id: #14> "test_rule" + description: #16> "test rule description" + variable: { + name: #19> "test_var" + expression: #21> "TestAllTypes{single_int64: 10}.single_int64 + " + } + #23> match: { + condition: #25> "spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + " + result: { + output: #27> ""invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + " + explanation: #29> ""invalid spec, spec is greater than 10" + " + } + } + #30> match: { + condition: #32> "spec.standalone_enum == NestedEnum.BAR + " + result: { + output: #34> ""invalid spec, reference to BAR is not allowed" + " + } + } + #35> match: { + condition: #37> "spec.single_int64 == variables.test_var" + result: { + output: #39> ""invalid spec: exactly matches test_var"" + explanation: #41> ""the spec cannot have single_int64 set to a known bad value"" + } + } + } +} diff --git a/policy/testdata/custom_policy_format.yaml b/policy/testdata/custom_policy_format.yaml new file mode 100644 index 000000000..a67356906 --- /dev/null +++ b/policy/testdata/custom_policy_format.yaml @@ -0,0 +1,29 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: cel_policy_custom_tags +description: A custom policy format +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +purpose: test +version: 42 +conditions: +- if: spec.single_string == "none" + then: "'zero'" + else: + if: spec.single_string == "integer" + then: + if: spec.single_int32 > 0 + then: "'positive integer'" + else: "'negative integer'" + else: "'not an integer'" diff --git a/policy/testdata/custom_policy_format_parser.baseline b/policy/testdata/custom_policy_format_parser.baseline new file mode 100644 index 000000000..d5b1a2235 --- /dev/null +++ b/policy/testdata/custom_policy_format_parser.baseline @@ -0,0 +1,75 @@ +POLICY SOURCE: custom_policy_format.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2026 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + #0> name: #1> cel_policy_custom_tags + #2> description: #3> A custom policy format + #4> imports: + - #5> name: #6> cel.expr.conformance.proto3.TestAllTypes + #7> purpose: #8> test + #9> version: #10> 42 + #11> conditions: + - #13> #12> if: #14> spec.single_string == "none" + #15> then: #16> "'zero'" + #17> else: + #19> #18> if: #20> spec.single_string == "integer" + #21> then: + #23> #22> if: #24> spec.single_int32 > 0 + #25> then: #26> "'positive integer'" + #27> else: #29> #28> "'negative integer'" + #30> else: #32> #31> "'not an integer'" + + =========================================================== + name: #1> "cel_policy_custom_tags" + description: #3> "A custom policy format" + metadata: { + purpose: #8> "test" + version: 42 + } + imports: + #5> name: #6> "cel.expr.conformance.proto3.TestAllTypes" + rule: { + #12> match: { + condition: #14> "spec.single_string == "none"" + result: { + output: #16> "'zero'" + } + } + #18> match: { + condition: #20> "spec.single_string == "integer"" + result: + rule: { + #22> match: { + condition: #24> "spec.single_int32 > 0" + result: { + output: #26> "'positive integer'" + } + } + #29> match: { + result: { + output: #28> "'negative integer'" + } + } + } + } + #32> match: { + result: { + output: #31> "'not an integer'" + } + } + } +} diff --git a/policy/testdata/custom_policy_format_with_errors.yaml b/policy/testdata/custom_policy_format_with_errors.yaml new file mode 100644 index 000000000..594747c60 --- /dev/null +++ b/policy/testdata/custom_policy_format_with_errors.yaml @@ -0,0 +1,33 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: cel_policy_custom_tags +description: A custom policy format +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +purpose: + - testing +version: new +conditions: +- if: + spec.single_string: "none" + then: "'zero'" + else: "'not zero'" +- if: spec.single_string == "number" + then: + if: spec.single_int32 > 0 + then: "'positive integer'" + else: + - ignore +- else: "'negative integer'" + diff --git a/policy/testdata/custom_policy_format_with_errors_parser.baseline b/policy/testdata/custom_policy_format_with_errors_parser.baseline new file mode 100644 index 000000000..978d27bda --- /dev/null +++ b/policy/testdata/custom_policy_format_with_errors_parser.baseline @@ -0,0 +1,16 @@ +POLICY SOURCE: custom_policy_format_with_errors.yaml +-------------------------------------------------------------------- +-------------------------------------------------------------------- +PARSER ISSUES: +ERROR: custom_policy_format_with_errors.yaml:19:3: Policy purpose is not a string + | - testing + | ..^ +ERROR: custom_policy_format_with_errors.yaml:20:10: Policy version is not an integer: new + | version: new + | .........^ +ERROR: custom_policy_format_with_errors.yaml:23:5: Policy 'if' condition is not a string + | spec.single_string: "none" + | ....^ +ERROR: custom_policy_format_with_errors.yaml:31:7: Bad syntax in 'if/then' block + | - ignore + | ......^ diff --git a/policy/testdata/nested_rule.yaml b/policy/testdata/nested_rule.yaml new file mode 100644 index 000000000..2b07faa64 --- /dev/null +++ b/policy/testdata/nested_rule.yaml @@ -0,0 +1,37 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: nested_rule +rule: + variables: + - name: "permitted_regions" + expression: "['us', 'uk', 'es']" + match: + - rule: + id: "banned regions" + description: > + determine whether the resource origin is in the banned + list. If the region is also in the permitted list, the + ban has no effect. + variables: + - name: "banned_regions" + expression: "{'us': false, 'ru': false, 'ir': false}" + match: + - condition: | + resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + output: "{'banned': true}" + - condition: resource.origin in variables.permitted_regions + output: "{'banned': false}" + - output: "{'banned': true}" + explanation: "'resource is in the banned region ' + resource.origin" \ No newline at end of file diff --git a/policy/testdata/nested_rule_parser.baseline b/policy/testdata/nested_rule_parser.baseline new file mode 100644 index 000000000..128f81bda --- /dev/null +++ b/policy/testdata/nested_rule_parser.baseline @@ -0,0 +1,84 @@ +POLICY SOURCE: nested_rule.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2024 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + #0> name: #1> nested_rule + #2> rule: + #4> #3> variables: + - #5> name: #6> "permitted_regions" + #7> expression: #8> "['us', 'uk', 'es']" + #9> match: + - #11> #10> rule: + #13> #12> id: #14> "banned regions" + #15> description: #16> > + determine whether the resource origin is in the banned + list. If the region is also in the permitted list, the + ban has no effect. + #17> variables: + - #18> name: #19> "banned_regions" + #20> expression: #21> "{'us': false, 'ru': false, 'ir': false}" + #22> match: + - #24> #23> condition: #25> | + resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + #26> output: #27> "{'banned': true}" + - #29> #28> condition: #30> resource.origin in variables.permitted_regions + #31> output: #32> "{'banned': false}" + - #34> #33> output: #35> "{'banned': true}" + #36> explanation: #37> "'resource is in the banned region ' + resource.origin" + =========================================================== + name: #1> "nested_rule" + description: "nested_rule.yaml" + #3> rule: { + variable: { + name: #6> "permitted_regions" + expression: #8> "['us', 'uk', 'es']" + } + #10> match: { + result: + #12> rule: { + rule_id: #14> "banned regions" + description: #16> "determine whether the resource origin is in the banned list. If the region is also in the permitted list, the ban has no effect. + " + variable: { + name: #19> "banned_regions" + expression: #21> "{'us': false, 'ru': false, 'ir': false}" + } + #23> match: { + condition: #25> "resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + " + result: { + output: #27> "{'banned': true}" + } + } + } + } + #28> match: { + condition: #30> "resource.origin in variables.permitted_regions" + result: { + output: #32> "{'banned': false}" + } + } + #33> match: { + result: { + output: #35> "{'banned': true}" + explanation: #37> "'resource is in the banned region ' + resource.origin" + } + } + } +} diff --git a/policy/yaml_policy_parser.cc b/policy/yaml_policy_parser.cc new file mode 100644 index 000000000..c838cff33 --- /dev/null +++ b/policy/yaml_policy_parser.cc @@ -0,0 +1,411 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/yaml_policy_parser.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_context.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_parser.h" +#include "yaml-cpp/exceptions.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/node/parse.h" +#include "yaml-cpp/null.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel { + +CelPolicyElementId YamlPolicyParser::CollectMetadata( + CelPolicyParseContext& ctx, const YAML::Node& node) const { + CelPolicyElementId element_id = ctx.next_element_id(); + if (!node.Mark().is_null()) { + ctx.policy_source().NoteSourcePosition(element_id, node.Mark().pos); + } + return element_id; +} + +std::optional YamlPolicyParser::GetValueString( + CelPolicyParseContext& ctx, const YAML::Node& node, + std::string_view error_message) const { + if (!node.IsDefined()) { + // This should never happen since the YAML syntax has already been checked. + return std::nullopt; + } + + CelPolicyElementId id = CollectMetadata(ctx, node); + if (!node.IsScalar()) { + ctx.ReportError(id, error_message); + return std::nullopt; + } + + try { + return ValueString(id, node.as()); + } catch (YAML::Exception& e) { + // This should never happen since we already checked that the node is a + // scalar and all scalars can be converted to strings. + return std::nullopt; + } +} + +absl::Status YamlPolicyParser::ParsePolicy(CelPolicyParseContext& ctx) const { + const Source* source = ctx.policy_source().content(); + if (source == nullptr) { + return absl::OkStatus(); + } + + ctx.policy().set_description(ValueString(-1, source->description())); + std::string text = source->content().ToString(); + YAML::Node node; + try { + node = YAML::Load(text); + } catch (YAML::Exception& e) { + if (!e.mark.is_null()) { + ctx.policy_source().NoteSourcePosition(0, e.mark.pos); + } + ctx.ReportError(0, "Invalid CEL policy YAML syntax"); + return absl::OkStatus(); + } + + if (!node.IsMap()) { + ctx.ReportError(CollectMetadata(ctx, node), "Policy is not a map"); + return absl::OkStatus(); + } + + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, ParsePolicyTag(ctx, *key, value_node)); + if (!handled) { + ctx.ReportError( + key->id(), + absl::StrCat("Unrecognized top-level policy tag: ", key->value())); + } + } + + return absl::OkStatus(); +} + +absl::StatusOr YamlPolicyParser::ParsePolicyTag( + CelPolicyParseContext& ctx, const ValueString& tag_name, + const YAML::Node& node) const { + if (tag_name.value() == "imports") { + CEL_RETURN_IF_ERROR(ParseImports(ctx, node)); + return true; + } + if (tag_name.value() == "name") { + std::optional name = + GetValueString(ctx, node, "Policy 'name' is not a string"); + if (name.has_value()) { + ctx.policy().set_name(*name); + } + return true; + } + if (tag_name.value() == "description") { + std::optional description = + GetValueString(ctx, node, "Policy 'description' is not a string"); + if (description.has_value()) { + ctx.policy().set_description(*description); + } + return true; + } + if (tag_name.value() == "display_name") { + std::optional display_name = + GetValueString(ctx, node, "Policy 'display_name' is not a string"); + if (display_name.has_value()) { + ctx.policy().set_display_name(*display_name); + } + return true; + } + if (tag_name.value() == "rule") { + CEL_RETURN_IF_ERROR(ParseRule(ctx, node, ctx.policy().mutable_rule())); + return true; + } + return false; +} + +absl::Status YamlPolicyParser::ParseImports(CelPolicyParseContext& ctx, + const YAML::Node& node) const { + if (!node.IsSequence()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy 'imports' is not a sequence"); + return absl::OkStatus(); + } + + for (const YAML::Node& import : node) { + CelPolicyElementId import_id = CollectMetadata(ctx, import); + if (!import.IsMap()) { + ctx.ReportError(import_id, "Import is not a map"); + continue; + } + const YAML::Node& name_node = import["name"]; + if (!name_node.IsDefined()) { + ctx.ReportError(import_id, "No 'name' tag in import"); + continue; + } + std::optional import_name = + GetValueString(ctx, name_node, "Import name is not a string"); + if (import_name.has_value()) { + ctx.policy().mutable_imports().push_back(Import(import_id, *import_name)); + } + } + return absl::OkStatus(); +} + +absl::Status YamlPolicyParser::ParseRule(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const { + if (!node.IsMap()) { + ctx.ReportError(CollectMetadata(ctx, node), "Policy 'rule' is not a map"); + return absl::OkStatus(); + } + rule.set_id(CollectMetadata(ctx, node)); + + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy rule tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, + ParseRuleTag(ctx, *key, value_node, rule)); + if (!handled) { + ctx.ReportError(key->id(), absl::StrCat("Unrecognized policy rule tag: ", + key->value())); + } + } + return absl::OkStatus(); +} + +absl::StatusOr YamlPolicyParser::ParseRuleTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Rule& rule) const { + if (tag_name.value() == "id") { + std::optional rule_id = + GetValueString(ctx, node, "Policy rule 'id' is not a string"); + if (rule_id.has_value()) { + rule.set_rule_id(*rule_id); + } + return true; + } + if (tag_name.value() == "description") { + std::optional description = + GetValueString(ctx, node, "Policy rule 'description' is not a string"); + if (description.has_value()) { + rule.set_description(*description); + } + return true; + } + if (tag_name.value() == "variables") { + if (!node.IsSequence()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy rule 'variables' is not a sequence"); + return true; + } + for (const YAML::Node& variable_node : node) { + CEL_ASSIGN_OR_RETURN(Variable variable, + ParseVariable(ctx, variable_node, rule)); + rule.mutable_variables().push_back(std::move(variable)); + } + return true; + } + if (tag_name.value() == "match") { + if (!node.IsSequence()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy rule 'match' is not a sequence"); + return true; + } + for (const YAML::Node& match_node : node) { + CEL_ASSIGN_OR_RETURN(Match match, ParseMatch(ctx, match_node, rule)); + rule.mutable_matches().push_back(std::move(match)); + } + return true; + } + return false; +} + +absl::StatusOr YamlPolicyParser::ParseVariable( + CelPolicyParseContext& ctx, const YAML::Node& node, Rule& rule) const { + Variable variable; + if (!node.IsMap()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy rule 'variable' is not a map"); + return variable; + } + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy variable tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, + ParseVariableTag(ctx, *key, value_node, variable)); + if (!handled) { + ctx.ReportError( + key->id(), + absl::StrCat("Unrecognized policy variable tag: ", key->value())); + } + } + return variable; +} + +absl::StatusOr YamlPolicyParser::ParseVariableTag( + CelPolicyParseContext& ctx, const ValueString& tag_name, + const YAML::Node& node, Variable& variable) const { + if (tag_name.value() == "name") { + std::optional name = + GetValueString(ctx, node, "Policy variable 'name' is not a string"); + if (name.has_value()) { + variable.set_name(*name); + } + return true; + } + if (tag_name.value() == "expression") { + std::optional expression = GetValueString( + ctx, node, "Policy variable 'expression' is not a string"); + if (expression.has_value()) { + variable.set_expression(*expression); + } + return true; + } + return false; +} + +absl::StatusOr YamlPolicyParser::ParseMatch(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const { + Match match; + match.set_id(CollectMetadata(ctx, node)); + if (!node.IsMap()) { + ctx.ReportError(match.id(), "Policy rule 'match' is not a map"); + return match; + } + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy match tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, + ParseMatchTag(ctx, *key, value_node, match, rule)); + if (!handled) { + ctx.ReportError(key->id(), absl::StrCat("Unrecognized policy match tag: ", + key->value())); + } + } + + if (match.has_output_block()) { + if (match.output_block().output().value().empty() && + match.output_block().explanation().has_value()) { + ctx.ReportError(match.id(), "Match specifies explanation but no output"); + } + } + + return match; +} + +absl::StatusOr YamlPolicyParser::ParseMatchTag( + CelPolicyParseContext& ctx, const ValueString& tag_name, + const YAML::Node& node, Match& match, Rule& rule) const { + if (tag_name.value() == "condition") { + std::optional condition = + GetValueString(ctx, node, "Policy match 'condition' is not a string"); + if (condition.has_value()) { + match.set_condition(*condition); + } + return true; + } + if (tag_name.value() == "explanation") { + std::optional explanation = + GetValueString(ctx, node, "Policy match 'explanation' is not a string"); + if (explanation.has_value()) { + if (match.has_rule()) { + ctx.ReportError( + tag_name.id(), + "Cannot specify explanation when a nested rule is present"); + } else { + match.mutable_output_block().set_explanation(*explanation); + } + } + return true; + } + if (tag_name.value() == "output") { + std::optional output = + GetValueString(ctx, node, "Policy match 'output' is not a string"); + if (output.has_value()) { + if (match.has_rule()) { + ctx.ReportError(tag_name.id(), + "Cannot specify output when a nested rule is present"); + } else { + match.mutable_output_block().set_output(*output); + } + } + return true; + } + if (tag_name.value() == "rule") { + if (match.has_output_block()) { + ctx.ReportError(tag_name.id(), + "Cannot specify nested rule when output/explanation is " + "present"); + } + auto nested_rule = std::make_unique(); + CEL_RETURN_IF_ERROR(ParseRule(ctx, node, *nested_rule)); + match.set_result(std::move(nested_rule)); + return true; + } + return false; +} + +const CelPolicyParser& GetDefaultYamlPolicyParser() { + static const auto* const parser = new YamlPolicyParser(); + return *parser; +} + +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source) { + return ParseYamlCelPolicy(std::move(policy_source), + GetDefaultYamlPolicyParser()); +} + +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source, + const CelPolicyParser& parser) { + CelPolicyParseContext ctx(std::move(policy_source)); + CEL_RETURN_IF_ERROR(parser.ParsePolicy(ctx)); + return ctx.GetResult(); +} + +} // namespace cel diff --git a/policy/yaml_policy_parser.h b/policy/yaml_policy_parser.h new file mode 100644 index 000000000..469209333 --- /dev/null +++ b/policy/yaml_policy_parser.h @@ -0,0 +1,135 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_context.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_parser.h" +#include "yaml-cpp/node/node.h" + +namespace cel { + +// A parser for YAML-based CEL policies. +// +// To support additional or alternative YAML elements, subclass +// `YamlPolicyParser` and override specific parsing methods, `Parse*` +class YamlPolicyParser : public CelPolicyParser { + public: + std::optional GetValueString( + CelPolicyParseContext& ctx, const YAML::Node& node, + std::string_view error_message) const; + + absl::Status ParsePolicy(CelPolicyParseContext& ctx) const override; + + protected: + // Collects metadata (e.g. source position) for the given YAML node, stores it + // in the context, and returns an ID that can be used to refer to it. + virtual CelPolicyElementId CollectMetadata(CelPolicyParseContext& ctx, + const YAML::Node& node) const; + + // Parses a top-level tag in the policy YAML. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParsePolicyTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node) const; + + // Parses the imports section of the policy YAML. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::Status ParseImports(CelPolicyParseContext& ctx, + const YAML::Node& node) const; + + // Parses a rule element of the policy YAML, which may be the top-level rule + // or a sub-rule of a match. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::Status ParseRule(CelPolicyParseContext& ctx, + const YAML::Node& node, Rule& rule) const; + + // Parses a tag in a policy YAML rule. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseRuleTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Rule& rule) const; + + // Parses a variable element of the policy YAML. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseVariable(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const; + + // Parses a tag in a policy YAML variable. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseVariableTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Variable& variable) const; + + // Parses a match element of the policy YAML. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseMatch(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const; + + // Parses a tag in a policy YAML match. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseMatchTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Match& match, Rule& rule) const; +}; + +// Returns a default implementation of YamlPolicyParser. +const CelPolicyParser& GetDefaultYamlPolicyParser(); + +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source, + const CelPolicyParser& parser); + +// YAML CelPolicy parser that uses the default format as implemented by +// `YamlPolicyParser`. +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_ diff --git a/policy/yaml_policy_parser_test.cc b/policy/yaml_policy_parser_test.cc new file mode 100644 index 000000000..4e7dfc49c --- /dev/null +++ b/policy/yaml_policy_parser_test.cc @@ -0,0 +1,305 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/yaml_policy_parser.h" + +#include +#include +#include +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "internal/runfiles.h" +#include "internal/testing.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_parser.h" +#include "yaml-cpp/node/node.h" + +namespace cel { + +namespace internal { +const CelPolicyParser& GetTestCustomYamlPolicyParser(); +} // namespace internal + +namespace { + +using ::absl_testing::IsOk; +using ::testing::HasSubstr; +using ::testing::IsNull; + +constexpr absl::string_view kTestPolicyFilePath = +"_main/policy/testdata/"; + +constexpr absl::string_view kBaselineSeparator = + "--------------------------------------------------------------------\n"; + +struct YamlPolicyParserTestCase { + std::string policy_source_file; + std::string baseline_file; + const cel::CelPolicyParser& (*parser_factory)(); +}; + +using YamlPolicyParserTest = testing::TestWithParam; + +TEST_P(YamlPolicyParserTest, Parse) { + std::string contents; + std::string test_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, GetParam().policy_source_file)); + ASSERT_THAT(cel::internal::GetFileContents(test_file, &contents), IsOk()); + + std::string baseline; + std::string baseline_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, GetParam().baseline_file)); + ASSERT_THAT(cel::internal::GetFileContents(baseline_file, &baseline), IsOk()); + baseline = absl::StripAsciiWhitespace(baseline); + + std::ostringstream out; + out << "POLICY SOURCE: " << GetParam().policy_source_file << "\n"; + + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(contents, GetParam().policy_source_file)); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + + ASSERT_OK_AND_ASSIGN( + CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source, GetParam().parser_factory())); + + out << kBaselineSeparator; + if (parse_result.IsValid()) { + out << "PARSED POLICY:\n"; + out << parse_result.GetPolicy()->DebugString(); + } else { + ASSERT_THAT(parse_result.GetPolicy(), IsNull()); + out << kBaselineSeparator; + out << "PARSER ISSUES:\n"; + for (const auto& issue : parse_result.GetIssues()) { + out << issue.ToDisplayString(*policy_source) << "\n"; + } + } + + std::string actual(absl::StripAsciiWhitespace(out.str())); + if (actual != baseline) { + // Log the actual result to make it easier to copy/paste into the baseline + // file when updating the tests. + ABSL_LOG(INFO) << "Actual:\n" << actual; + EXPECT_EQ(actual, baseline); + } +} + +INSTANTIATE_TEST_SUITE_P( + Formats, YamlPolicyParserTest, + testing::ValuesIn({ + YamlPolicyParserTestCase{ + .policy_source_file = "cel_policy.yaml", + .baseline_file = "cel_policy_parser.baseline", + .parser_factory = GetDefaultYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "nested_rule.yaml", + .baseline_file = "nested_rule_parser.baseline", + .parser_factory = GetDefaultYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "custom_policy_format.yaml", + .baseline_file = "custom_policy_format_parser.baseline", + .parser_factory = internal::GetTestCustomYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "custom_policy_format_with_errors.yaml", + .baseline_file = "custom_policy_format_with_errors_parser.baseline", + .parser_factory = internal::GetTestCustomYamlPolicyParser, + }, + })); + +struct ParseTestCase { + std::string yaml; + std::string expected_error; +}; + +using YamlPolicyParseErrorTest = testing::TestWithParam; + +TEST_P(YamlPolicyParseErrorTest, YamlSyntaxError) { + const ParseTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(param.yaml, "test")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + ASSERT_OK_AND_ASSIGN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + EXPECT_THAT(parse_result.FormattedIssues(), HasSubstr(param.expected_error)); +} + +std::vector GetParseTestCases() { + return { + ParseTestCase{ + .yaml = R"yaml( ? [ John, Doe ]: age: 30 )yaml", + .expected_error = "1:22: Invalid CEL policy YAML syntax\n" + " | ? [ John, Doe ]: age: 30 \n" + " | .....................^", + }, + ParseTestCase{ + .yaml = R"yaml( invalid yaml )yaml", + .expected_error = "1:2: Policy is not a map\n" + " | invalid yaml \n" + " | .^", + }, + ParseTestCase{ + .yaml = R"yaml( + ? [1, 2, 3] + : "Prime numbers sequence" + )yaml", + .expected_error = "2:23: Policy tag is not a string\n" + " | ? [1, 2, 3]\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: N/A + )yaml", + .expected_error = "2:28: Policy 'imports' is not a sequence\n" + " | imports: N/A\n" + " | ...........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: + - cel.expr.conformance + )yaml", + .expected_error = "3:21: Import is not a map\n" + " | - cel.expr.conformance\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: + - name: + - cel.expr.conformance + )yaml", + .expected_error = "4:21: Import name is not a string\n" + " | - cel.expr.conformance\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: do something + )yaml", + .expected_error = "2:25: Policy 'rule' is not a map\n" + " | rule: do something\n" + " | ........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + id: + - 22 + )yaml", + .expected_error = "4:21: Policy rule 'id' is not a string\n" + " | - 22\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + no vars + )yaml", + .expected_error = "4:23: Policy rule 'variables' is not a sequence\n" + " | no vars\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: + foo: bar + )yaml", + .expected_error = "5:25: Policy variable 'name' is not a string\n" + " | foo: bar\n" + " | ........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: test_var + expression: + - 22 + )yaml", + .expected_error = + "6:23: Policy variable 'expression' is not a string\n" + " | - 22\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: '\u0041\u00a9\u20ac\U0001f680' + - '\u0041\u00a9\u20ac\U0001f680': name + )yaml", + .expected_error = + "5:23: Unrecognized policy variable tag: " + "\\u0041\\u00a9\\u20ac\\U0001f680\n" + " | - '\\u0041\\u00a9\\u20ac\\U0001f680': " + "name\n" + " | ......................^", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(YamlPolicyParseErrorTest, YamlPolicyParseErrorTest, + ::testing::ValuesIn(GetParseTestCases())); + +TEST(YamlPolicyParserTest, OffsetIssueFormatting) { + // TODO(b/506179116): will need to copy the go implementation in extracting + // the source string from the YAML document instead of the interpreted string + // value to fix up error locations in folded and block literals. + std::string contents; + std::string test_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, "cel_policy.yaml")); + ASSERT_THAT(cel::internal::GetFileContents(test_file, &contents), IsOk()); + + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(contents, "cel_policy.yaml")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + ASSERT_OK_AND_ASSIGN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + + ASSERT_TRUE(parse_result.IsValid()); + const CelPolicy* policy = parse_result.GetPolicy(); + + CelPolicyElementId name_id = policy->name().id(); + + CelPolicyIssue issue(name_id, 4, CelPolicyIssue::Severity::kError, + "Test error"); + + std::string formatted = issue.ToDisplayString(*policy_source); + + EXPECT_THAT(formatted, HasSubstr("ERROR: cel_policy.yaml:16:11: Test error")); + EXPECT_THAT(formatted, HasSubstr(" | name: cel_policy")); + EXPECT_THAT(formatted, HasSubstr(" | ..........^")); +} + +} // namespace +} // namespace cel diff --git a/runtime/BUILD b/runtime/BUILD index b58880146..34ff411a1 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -344,14 +344,11 @@ cc_test( deps = [ ":activation", ":constant_folding", - ":function", - ":register_function_helper", ":runtime_builder", ":runtime_options", ":standard_runtime_builder_factory", "//base:function_adapter", "//common:function_descriptor", - "//common:kind", "//common:value", "//extensions/protobuf:runtime_adapter", "//internal:testing", @@ -618,6 +615,7 @@ cc_test( ":activation", ":constant_folding", ":function_adapter", + ":optional_types", ":reference_resolver", ":regex_precompilation", ":runtime", diff --git a/runtime/activation_test.cc b/runtime/activation_test.cc index e6a74f027..4303116a3 100644 --- a/runtime/activation_test.cc +++ b/runtime/activation_test.cc @@ -326,7 +326,7 @@ TEST_F(ActivationTest, MoveAssignment) { "val_provided", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) - -> absl::StatusOr> { return IntValue(42); })); + -> absl::StatusOr> { return IntValue(42); })); moved_from.SetUnknownPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), @@ -377,7 +377,7 @@ TEST_F(ActivationTest, MoveCtor) { "val_provided", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) - -> absl::StatusOr> { return IntValue(42); })); + -> absl::StatusOr> { return IntValue(42); })); moved_from.SetUnknownPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), diff --git a/runtime/function_registry.cc b/runtime/function_registry.cc index ac1e53eb5..59f267255 100644 --- a/runtime/function_registry.cc +++ b/runtime/function_registry.cc @@ -44,14 +44,13 @@ class ActivationFunctionProviderImpl public: ActivationFunctionProviderImpl() = default; - absl::StatusOr> GetFunction( + absl::StatusOr> GetFunction( const cel::FunctionDescriptor& descriptor, const cel::ActivationInterface& activation) const override { std::vector overloads = activation.FindFunctionOverloads(descriptor.name()); - absl::optional matching_overload = - absl::nullopt; + std::optional matching_overload = absl::nullopt; for (const auto& overload : overloads) { if (overload.descriptor.ShapeMatches(descriptor)) { diff --git a/runtime/function_registry_test.cc b/runtime/function_registry_test.cc index af7f5bc06..53916777a 100644 --- a/runtime/function_registry_test.cc +++ b/runtime/function_registry_test.cc @@ -120,7 +120,7 @@ TEST(FunctionRegistryTest, DefaultLazyProviderNoOverloadFound) { ASSERT_THAT(providers, SizeIs(1)); const FunctionProvider& provider = providers[0].provider; ASSERT_OK_AND_ASSIGN( - absl::optional func, + std::optional func, provider.GetFunction({"LazyFunc", false, {cel::Kind::kInt64}}, activation)); @@ -146,7 +146,7 @@ TEST(FunctionRegistryTest, DefaultLazyProviderReturnsImpl) { ASSERT_THAT(providers, SizeIs(1)); const FunctionProvider& provider = providers[0].provider; ASSERT_OK_AND_ASSIGN( - absl::optional func, + std::optional func, provider.GetFunction( FunctionDescriptor("LazyFunction", false, {Kind::kInt}), activation)); diff --git a/runtime/internal/BUILD b/runtime/internal/BUILD index 28f9bd1cb..1223ff6d1 100644 --- a/runtime/internal/BUILD +++ b/runtime/internal/BUILD @@ -195,6 +195,7 @@ cc_library( deps = [ "//common:type", "//common:value", + "//internal:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", diff --git a/runtime/internal/convert_constant.cc b/runtime/internal/convert_constant.cc index a9effd229..33f382858 100644 --- a/runtime/internal/convert_constant.cc +++ b/runtime/internal/convert_constant.cc @@ -33,7 +33,7 @@ using ::cel::Constant; struct ConvertVisitor { Allocator<> allocator; - absl::StatusOr operator()(absl::monostate) { + absl::StatusOr operator()(std::monostate) { return absl::InvalidArgumentError("unspecified constant"); } absl::StatusOr operator()(std::nullptr_t) { return NullValue(); } diff --git a/runtime/internal/runtime_type_provider.cc b/runtime/internal/runtime_type_provider.cc index 1acb52223..40f5ff575 100644 --- a/runtime/internal/runtime_type_provider.cc +++ b/runtime/internal/runtime_type_provider.cc @@ -44,8 +44,10 @@ absl::Status RuntimeTypeProvider::RegisterType(const OpaqueType& type) { absl::StatusOr> RuntimeTypeProvider::FindTypeImpl( absl::string_view name) const { - // We do not have to worry about well known types here. - // `TypeIntrospector::FindType` handles those directly. + auto type = FindWellKnownType(name); + if (type.has_value()) { + return type; + } const auto* desc = descriptor_pool_->FindMessageTypeByName(name); if (desc != nullptr) { return MessageType(desc); @@ -60,9 +62,12 @@ absl::StatusOr> RuntimeTypeProvider::FindTypeImpl( absl::StatusOr> RuntimeTypeProvider::FindEnumConstantImpl(absl::string_view type, absl::string_view value) const { + auto enum_constant = FindWellKnownTypeEnumConstant(type, value); + if (enum_constant.has_value()) { + return enum_constant; + } const google::protobuf::EnumDescriptor* enum_desc = descriptor_pool_->FindEnumTypeByName(type); - // google.protobuf.NullValue is special cased in the base class. if (enum_desc == nullptr) { return absl::nullopt; } @@ -84,8 +89,10 @@ RuntimeTypeProvider::FindEnumConstantImpl(absl::string_view type, absl::StatusOr> RuntimeTypeProvider::FindStructTypeFieldByNameImpl( absl::string_view type, absl::string_view name) const { - // We do not have to worry about well known types here. - // `TypeIntrospector::FindStructTypeFieldByName` handles those directly. + auto field = FindWellKnownTypeFieldByName(type, name); + if (field.has_value()) { + return field; + } const auto* desc = descriptor_pool_->FindMessageTypeByName(type); if (desc == nullptr) { return absl::nullopt; diff --git a/runtime/memory_safety_test.cc b/runtime/memory_safety_test.cc index 7e864ecf6..a60b4ce60 100644 --- a/runtime/memory_safety_test.cc +++ b/runtime/memory_safety_test.cc @@ -45,6 +45,7 @@ #include "runtime/activation.h" #include "runtime/constant_folding.h" #include "runtime/function_adapter.h" +#include "runtime/optional_types.h" #include "runtime/reference_resolver.h" #include "runtime/regex_precompilation.h" #include "runtime/runtime.h" @@ -73,7 +74,7 @@ struct TestCase { std::string name; std::string expression; absl::flat_hash_map> + std::variant> activation; test::ValueMatcher expected_matcher; bool reference_resolver_enabled = false; @@ -174,6 +175,7 @@ absl::StatusOr> ConfigureRuntimeImpl( if (resolve_references) { CEL_RETURN_IF_ERROR(EnableReferenceResolver( runtime_builder, ReferenceResolverEnabled::kAlways)); + CEL_RETURN_IF_ERROR(extensions::EnableOptionalTypes(runtime_builder)); } if (evaluation_options == Options::kFoldConstants) { CEL_RETURN_IF_ERROR(extensions::EnableConstantFolding(runtime_builder)); @@ -315,6 +317,14 @@ INSTANTIATE_TEST_SUITE_P( {{"condition", BoolValue(false)}}, test::StringValueIs("long_right_hand_string_0123456789"), }, + {"optional_of_long_const_string", + "condition ? optional.of('lhs_short') : " + "optional.of('long_right_hand_string_0123456789')", + {{"condition", BoolValue(false)}}, + test::OptionalValueIs( + test::StringValueIs("long_right_hand_string_0123456789")), + // optional.of is a namespaced function. + /*enable_reference_resolver=*/true}, { "computed_string", "(condition ? 'a.b' : 'b.c') + '.d.e.f'", diff --git a/runtime/runtime_options.h b/runtime/runtime_options.h index 1e18fef95..7a61208a0 100644 --- a/runtime/runtime_options.h +++ b/runtime/runtime_options.h @@ -139,17 +139,23 @@ struct RuntimeOptions { // removed in a later update. bool enable_lazy_bind_initialization = true; - // Maximum recursion depth for evaluable programs. + // Enable recursive planning with a maximum recursion depth for evaluable + // programs. // - // This is proportional to the maximum number of recursive Evaluate calls that - // a single expression program might require while evaluating. This is - // coarse -- the actual C++ stack requirements will vary depending on the + // This limit is proportional to the maximum number of recursive Evaluate + // calls that a single expression program might require while evaluating. This + // is coarse -- the actual C++ stack requirements will vary depending on the // expression. // // This does not account for re-entrant evaluation in a client's extension - // function. + // function (i.e. a CEL function that calls Evaluate on another CEL program) + // + // If the limit is exceeded, the planner will return an error instead of + // planning the program. // // -1 means unbounded. + // 0 means disabled (using a heap-based stack machine instead), which is the + // default. int max_recursion_depth = 0; // Enable tracing support for recursively planned programs. diff --git a/runtime/standard/container_membership_functions.cc b/runtime/standard/container_membership_functions.cc index 9f5ca3755..cc0638429 100644 --- a/runtime/standard/container_membership_functions.cc +++ b/runtime/standard/container_membership_functions.cc @@ -174,15 +174,16 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { - auto result = - map_value.Has(BoolValue(key), descriptor_pool, message_factory, arena); - if (result.ok()) { - return std::move(*result); + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(BoolValue(key), descriptor_pool, + message_factory, arena, &has)); + if (has.IsTrue()) { + return has; } if (enable_heterogeneous_equality) { return BoolValue(false); } - return ErrorValue(result.status()); + return has; }; auto intKeyInSet = @@ -191,27 +192,26 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { - auto result = - map_value.Has(IntValue(key), descriptor_pool, message_factory, arena); + Value result; + CEL_RETURN_IF_ERROR(map_value.Has(IntValue(key), descriptor_pool, + message_factory, arena, &result)); if (enable_heterogeneous_equality) { - if (result.ok() && result->IsTrue()) { - return std::move(*result); + if (result.IsTrue()) { + return result; } Number number = Number::FromInt64(key); if (number.LosslessConvertibleToUint()) { - const auto& result = - map_value.Has(UintValue(number.AsUint()), descriptor_pool, - message_factory, arena); - if (result.ok() && result->IsTrue()) { - return std::move(*result); + Value result_alt; + CEL_RETURN_IF_ERROR(map_value.Has(UintValue(number.AsUint()), + descriptor_pool, message_factory, + arena, &result_alt)); + if (result_alt.IsTrue()) { + return result_alt; } } return BoolValue(false); } - if (!result.ok()) { - return ErrorValue(result.status()); - } - return std::move(*result); + return result; }; auto stringKeyInSet = @@ -220,14 +220,16 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { - auto result = map_value.Has(key, descriptor_pool, message_factory, arena); - if (result.ok()) { - return std::move(*result); + Value result; + CEL_RETURN_IF_ERROR( + map_value.Has(key, descriptor_pool, message_factory, arena, &result)); + if (result.IsBool()) { + return result; } if (enable_heterogeneous_equality) { return BoolValue(false); } - return ErrorValue(result.status()); + return result; }; auto uintKeyInSet = @@ -236,26 +238,26 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { - const auto& result = - map_value.Has(UintValue(key), descriptor_pool, message_factory, arena); + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(UintValue(key), descriptor_pool, + message_factory, arena, &has)); if (enable_heterogeneous_equality) { - if (result.ok() && result->IsTrue()) { - return std::move(*result); + if (has.IsTrue()) { + return has; } + Value has_alt; Number number = Number::FromUint64(key); if (number.LosslessConvertibleToInt()) { - const auto& result = map_value.Has( - IntValue(number.AsInt()), descriptor_pool, message_factory, arena); - if (result.ok() && result->IsTrue()) { - return std::move(*result); + CEL_RETURN_IF_ERROR(map_value.Has(IntValue(number.AsInt()), + descriptor_pool, message_factory, + arena, &has_alt)); + if (has.IsTrue()) { + return has; } } return BoolValue(false); } - if (!result.ok()) { - return ErrorValue(result.status()); - } - return std::move(*result); + return has; }; auto doubleKeyInSet = @@ -265,17 +267,21 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { Number number = Number::FromDouble(key); if (number.LosslessConvertibleToInt()) { - const auto& result = map_value.Has( - IntValue(number.AsInt()), descriptor_pool, message_factory, arena); - if (result.ok() && result->IsTrue()) { - return std::move(*result); + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(IntValue(number.AsInt()), + descriptor_pool, message_factory, arena, + &has)); + if (has.IsTrue()) { + return has; } } if (number.LosslessConvertibleToUint()) { - const auto& result = map_value.Has( - UintValue(number.AsUint()), descriptor_pool, message_factory, arena); - if (result.ok() && result->IsTrue()) { - return std::move(*result); + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(UintValue(number.AsUint()), + descriptor_pool, message_factory, arena, + &has)); + if (has.IsTrue()) { + return has; } } return BoolValue(false); diff --git a/runtime/standard/type_conversion_functions.cc b/runtime/standard/type_conversion_functions.cc index 50b6e28ea..76e95751b 100644 --- a/runtime/standard/type_conversion_functions.cc +++ b/runtime/standard/type_conversion_functions.cc @@ -69,7 +69,7 @@ Value FormatDouble(double v, const Function::InvokeContext& context) { return cel::ErrorValue(absl::InvalidArgumentError(absl::StrCat( "double format error: ", std::make_error_code(result.ec).message()))); } - absl::string_view out(buf, result.ptr); + absl::string_view out(buf, result.ptr - buf); return StringValue::From(out, arena); #endif } diff --git a/runtime/standard_runtime_builder_factory_test.cc b/runtime/standard_runtime_builder_factory_test.cc index b73085f3c..029897233 100644 --- a/runtime/standard_runtime_builder_factory_test.cc +++ b/runtime/standard_runtime_builder_factory_test.cc @@ -52,24 +52,14 @@ using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::extensions::ProtobufRuntimeAdapter; using ::cel::test::BoolValueIs; +using ::cel::test::IntValueIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; +using ::testing::HasSubstr; using ::testing::TestWithParam; using ::testing::Truly; -struct EvaluateResultTestCase { - std::string name; - std::string expression; - bool expected_result; - std::function activation_builder; - - template - friend void AbslStringify(S& sink, const EvaluateResultTestCase& tc) { - sink.Append(tc.name); - } -}; - const cel::MacroRegistry& GetMacros() { static absl::NoDestructor macros([]() { MacroRegistry registry; @@ -88,6 +78,84 @@ absl::StatusOr ParseWithTestMacros(absl::string_view expression) { return Parse(**src, GetMacros()); } +TEST(StandardRuntimeTest, RecursionLimitExceeded) { + RuntimeOptions opts; + opts.max_recursion_depth = 1; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros("1 + 2")); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Maximum recursion depth of 1 exceeded"))); +} + +TEST(StandardRuntimeTest, RecursionUnderLimit) { + RuntimeOptions opts; + opts.max_recursion_depth = 2; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros("1 + 2")); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + // Whether the implementation is recursive shouldn't affect observable + // behavior, but it does have performance implications (it will skip + // allocating a value stack). + EXPECT_TRUE(runtime_internal::TestOnly_IsRecursiveImpl(program.get())); + + google::protobuf::Arena arena; + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + EXPECT_THAT(result, IntValueIs(3)); +} + +TEST(StandardRuntimeTest, RecursionLimitTracksLazyExpressions) { + RuntimeOptions opts; + opts.max_recursion_depth = 8; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros(R"cel( + cel.bind(a, 4 + (3 + (2 + 1)), + cel.bind(b, 7 + (6 + (5 + a)), + 9 + (8 + b) + ) + ))cel")); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Maximum recursion depth of 8 exceeded"))); +} + +struct EvaluateResultTestCase { + std::string name; + std::string expression; + bool expected_result; + std::function activation_builder; + + template + friend void AbslStringify(S& sink, const EvaluateResultTestCase& tc) { + sink.Append(tc.name); + } +}; + class StandardRuntimeTest : public TestWithParam { public: const EvaluateResultTestCase& GetTestCase() { return GetParam(); } diff --git a/testing/testrunner/user_tests/BUILD b/testing/testrunner/user_tests/BUILD index 140b77aef..53cd8f716 100644 --- a/testing/testrunner/user_tests/BUILD +++ b/testing/testrunner/user_tests/BUILD @@ -59,6 +59,7 @@ cc_library( "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", "@com_google_protobuf//:protobuf", ], + alwayslink = True, ) cc_library( diff --git a/testutil/BUILD b/testutil/BUILD index 3f1aa4fe8..782c95ca6 100644 --- a/testutil/BUILD +++ b/testutil/BUILD @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") @@ -61,6 +62,26 @@ cc_library( deps = ["//internal:proto_matchers"], ) +cc_library( + name = "test_macros", + testonly = True, + srcs = ["test_macros.cc"], + hdrs = ["test_macros.h"], + deps = [ + "//common:expr", + "//internal:status_macros", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:macro_registry", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "baseline_tests", testonly = True, @@ -86,3 +107,8 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +proto_library( + name = "test_json_names_proto", + srcs = ["test_json_names.proto"], +) diff --git a/testutil/baseline_tests.cc b/testutil/baseline_tests.cc index 4e56ad485..8ce43e63d 100644 --- a/testutil/baseline_tests.cc +++ b/testutil/baseline_tests.cc @@ -28,75 +28,6 @@ namespace cel::test { namespace { -std::string FormatPrimitive(PrimitiveType t) { - switch (t) { - case PrimitiveType::kBool: - return "bool"; - case PrimitiveType::kInt64: - return "int"; - case PrimitiveType::kUint64: - return "uint"; - case PrimitiveType::kDouble: - return "double"; - case PrimitiveType::kString: - return "string"; - case PrimitiveType::kBytes: - return "bytes"; - default: - return ""; - } -} - -std::string FormatType(const TypeSpec& t) { - if (t.has_dyn()) { - return "dyn"; - } else if (t.has_null()) { - return "null"; - } else if (t.has_primitive()) { - return FormatPrimitive(t.primitive()); - } else if (t.has_wrapper()) { - return absl::StrCat("wrapper(", FormatPrimitive(t.wrapper()), ")"); - } else if (t.has_well_known()) { - switch (t.well_known()) { - case WellKnownTypeSpec::kAny: - return "google.protobuf.Any"; - case WellKnownTypeSpec::kDuration: - return "google.protobuf.Duration"; - case WellKnownTypeSpec::kTimestamp: - return "google.protobuf.Timestamp"; - default: - return ""; - } - } else if (t.has_abstract_type()) { - const auto& abs_type = t.abstract_type(); - std::string s = abs_type.name(); - if (!abs_type.parameter_types().empty()) { - absl::StrAppend(&s, "(", - absl::StrJoin(abs_type.parameter_types(), ",", - [](std::string* out, const auto& t) { - absl::StrAppend(out, FormatType(t)); - }), - ")"); - } - return s; - } else if (t.has_type()) { - if (t.type() == TypeSpec()) { - return "type"; - } - return absl::StrCat("type(", FormatType(t.type()), ")"); - } else if (t.has_message_type()) { - return t.message_type().type(); - } else if (t.has_type_param()) { - return t.type_param().type(); - } else if (t.has_list_type()) { - return absl::StrCat("list(", FormatType(t.list_type().elem_type()), ")"); - } else if (t.has_map_type()) { - return absl::StrCat("map(", FormatType(t.map_type().key_type()), ", ", - FormatType(t.map_type().value_type()), ")"); - } - return ""; -} - std::string FormatReference(const cel::Reference& r) { if (r.overload_id().empty()) { return r.name(); @@ -113,7 +44,7 @@ class TypeAdorner : public ExpressionAdorner { auto t = ast_.type_map().find(e.id()); if (t != ast_.type_map().end()) { - absl::StrAppend(&s, "~", FormatType(t->second)); + absl::StrAppend(&s, "~", FormatTypeSpec(t->second)); } if (const auto r = ast_.reference_map().find(e.id()); r != ast_.reference_map().end()) { diff --git a/testutil/baseline_tests_test.cc b/testutil/baseline_tests_test.cc index 33050583f..f4e89706c 100644 --- a/testutil/baseline_tests_test.cc +++ b/testutil/baseline_tests_test.cc @@ -184,7 +184,7 @@ INSTANTIATE_TEST_SUITE_P( "x~google.protobuf.Timestamp"}, TestCase{TypeSpec(DynTypeSpec()), "x~dyn"}, TestCase{TypeSpec(NullTypeSpec()), "x~null"}, - TestCase{TypeSpec(UnsetTypeSpec()), "x~"}, + TestCase{TypeSpec(UnsetTypeSpec()), "x~*error*"}, TestCase{TypeSpec(MessageTypeSpec("com.example.Type")), "x~com.example.Type"}, TestCase{TypeSpec(AbstractType("optional_type", diff --git a/testutil/test_json_names.proto b/testutil/test_json_names.proto new file mode 100644 index 000000000..a9551085b --- /dev/null +++ b/testutil/test_json_names.proto @@ -0,0 +1,31 @@ +edition = "2024"; + +package cel.cpp.testutil; + +option features.enforce_naming_style = STYLE_LEGACY; + +// This proto tests json_name options +message TestJsonNames { + int32 int32_snake_case_json_name = 1 + [json_name = "int32_snake_case_json_name"]; + int64 int64_camel_case_json_name = 2 [json_name = "int64CamelCaseJsonName"]; + uint32 uint32_default_json_name = 3; + uint64 uint64_custom_json_name = 4 [json_name = "uint64-custom-json-name"]; + + // Collides with normal field name. + string string_json_name_shadows = 5 [json_name = "single_string"]; + string single_string = 6; + + // protoc should fail on cases like these + // double double_json_shadow_default = 7 [json_name = "doubleJsonDefault"] + // double double_json_default = 8; + // double double_json_swapped_a = 7 [json_name = "double_json_swapped_b"]; + // double double_json_swapped_b = 8 [json_name = "double_json_swapped_a"]; + + extensions 100 to 199; +} + +extend TestJsonNames { + int32 int32_snake_case_ext = 100; + int64 int64CamelCaseExt = 101; +} diff --git a/testutil/test_macros.cc b/testutil/test_macros.cc new file mode 100644 index 000000000..19e9a4844 --- /dev/null +++ b/testutil/test_macros.cc @@ -0,0 +1,173 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "testutil/test_macros.h" + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/expr.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/macro_registry.h" + +namespace cel::test { + +namespace { + +bool IsCelNamespace(const Expr& target) { + return target.has_ident_expr() && target.ident_expr().name() == "cel"; +} + +std::optional CelBlockMacroExpander(MacroExprFactory& factory, + Expr& target, absl::Span args) { + if (!IsCelNamespace(target)) { + return std::nullopt; + } + Expr& bindings_arg = args[0]; + if (!bindings_arg.has_list_expr()) { + return factory.ReportErrorAt( + bindings_arg, "cel.block requires the first arg to be a list literal"); + } + return factory.NewCall("cel.@block", args); +} + +std::optional CelIndexMacroExpander(MacroExprFactory& factory, + Expr& target, absl::Span args) { + if (!IsCelNamespace(target)) { + return std::nullopt; + } + Expr& index_arg = args[0]; + if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + int64_t index = index_arg.const_expr().int_value(); + if (index < 0) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + return factory.NewIdent(absl::StrCat("@index", index)); +} + +std::optional CelIterVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return std::nullopt; + } + Expr& depth_arg = args[0]; + if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || + depth_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + depth_arg, "cel.iterVar requires two non-negative int constant args"); + } + Expr& unique_arg = args[1]; + if (!unique_arg.has_const_expr() || + !unique_arg.const_expr().has_int_value() || + unique_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + unique_arg, "cel.iterVar requires two non-negative int constant args"); + } + return factory.NewIdent( + absl::StrCat("@it:", depth_arg.const_expr().int_value(), ":", + unique_arg.const_expr().int_value())); +} + +std::optional CelAccuVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return std::nullopt; + } + Expr& depth_arg = args[0]; + if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || + depth_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + depth_arg, "cel.accuVar requires two non-negative int constant args"); + } + Expr& unique_arg = args[1]; + if (!unique_arg.has_const_expr() || + !unique_arg.const_expr().has_int_value() || + unique_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + unique_arg, "cel.accuVar requires two non-negative int constant args"); + } + return factory.NewIdent( + absl::StrCat("@ac:", depth_arg.const_expr().int_value(), ":", + unique_arg.const_expr().int_value())); +} + +Macro MakeCelBlockMacro() { + auto macro_or_status = Macro::Receiver("block", 2, CelBlockMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelIndexMacro() { + auto macro_or_status = Macro::Receiver("index", 1, CelIndexMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelIterVarMacro() { + auto macro_or_status = Macro::Receiver("iterVar", 2, CelIterVarMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelAccuVarMacro() { + auto macro_or_status = Macro::Receiver("accuVar", 2, CelAccuVarMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +} // namespace + +const Macro& CelBlockMacro() { + static const absl::NoDestructor macro(MakeCelBlockMacro()); + return *macro; +} + +const Macro& CelIndexMacro() { + static const absl::NoDestructor macro(MakeCelIndexMacro()); + return *macro; +} + +const Macro& CelIterVarMacro() { + static const absl::NoDestructor macro(MakeCelIterVarMacro()); + return *macro; +} + +const Macro& CelAccuVarMacro() { + static const absl::NoDestructor macro(MakeCelAccuVarMacro()); + return *macro; +} + +absl::Status RegisterTestMacros(MacroRegistry& registry) { + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelBlockMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelIndexMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelIterVarMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelAccuVarMacro())); + return absl::OkStatus(); +} + +} // namespace cel::test diff --git a/testutil/test_macros.h b/testutil/test_macros.h new file mode 100644 index 000000000..cad897999 --- /dev/null +++ b/testutil/test_macros.h @@ -0,0 +1,33 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ + +#include "absl/status/status.h" +#include "parser/macro.h" +#include "parser/macro_registry.h" + +namespace cel::test { + +const Macro& CelBlockMacro(); +const Macro& CelIndexMacro(); +const Macro& CelIterVarMacro(); +const Macro& CelAccuVarMacro(); + +absl::Status RegisterTestMacros(MacroRegistry& registry); + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ diff --git a/tools/BUILD b/tools/BUILD index ceb2befc5..af006a67b 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -204,6 +204,56 @@ cc_library( ], ) +cc_library( + name = "proto_to_predicate", + srcs = ["proto_to_predicate.cc"], + hdrs = ["proto_to_predicate.h"], + deps = [ + "//common:ast", + "//common:expr", + "//common:expr_factory", + "//common:operators", + "//internal:status_macros", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_to_predicate_test", + srcs = ["proto_to_predicate_test.cc"], + deps = [ + ":cel_unparser", + ":proto_to_predicate", + "//common:ast", + "//common:ast_proto", + "//common:value", + "//env:config", + "//env:env_runtime", + "//env:env_yaml", + "//env:runtime_std_extensions", + "//eval/testutil:test_message_cc_proto", + "//extensions/protobuf:value", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//runtime", + "//runtime:activation", + "//tools/testdata:test_policy_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "descriptor_pool_builder_test", srcs = ["descriptor_pool_builder_test.cc"], diff --git a/tools/branch_coverage.cc b/tools/branch_coverage.cc index 00ab7cb5a..b5bba3ffe 100644 --- a/tools/branch_coverage.cc +++ b/tools/branch_coverage.cc @@ -71,7 +71,7 @@ struct OtherNode { // Representation for coverage of an AST node. struct CoverageNode { int evaluate_count; - absl::variant kind; + std::variant kind; }; const Type* absl_nullable FindCheckerType(const CheckedExpr& expr, diff --git a/tools/cel_unparser.cc b/tools/cel_unparser.cc index 28a1187bb..741d91208 100644 --- a/tools/cel_unparser.cc +++ b/tools/cel_unparser.cc @@ -150,6 +150,8 @@ class Unparser { // - a ternary conditional operator bool IsBinaryOrTernaryOperator(const Expr& expr); + bool IsLogicalOperator(absl::string_view op); + template void Print(Ts&&... args) { absl::StrAppend(&output_, std::forward(args)...); @@ -436,6 +438,24 @@ absl::Status Unparser::VisitUnary(const Expr::Call& expr, absl::Status Unparser::VisitBinary(const Expr::Call& expr, const std::string& op) { + if (expr.args_size() < 2) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected binary: ", expr.ShortDebugString())); + } + + const auto& fun = expr.function(); + if (IsLogicalOperator(fun)) { + for (int i = 0; i < expr.args_size(); ++i) { + if (i > 0) { + Print(kSpace, op, kSpace); + } + const auto& arg = expr.args(i); + bool arg_paren = IsComplexOperatorWithRespectTo(arg, fun); + CEL_RETURN_IF_ERROR(VisitMaybeNested(arg, arg_paren)); + } + return absl::OkStatus(); + } + if (expr.args_size() != 2) { return absl::InvalidArgumentError( absl::StrCat("Unexpected binary: ", expr.ShortDebugString())); @@ -443,7 +463,6 @@ absl::Status Unparser::VisitBinary(const Expr::Call& expr, const auto& lhs = expr.args(0); const auto& rhs = expr.args(1); - const auto& fun = expr.function(); // add parens if the current operator is lower precedence than the lhs expr // operator. @@ -549,6 +568,10 @@ bool Unparser::IsBinaryOrTernaryOperator(const Expr& expr) { IsOperatorSamePrecedence(CelOperator::CONDITIONAL, expr); } +bool Unparser::IsLogicalOperator(absl::string_view op) { + return op == CelOperator::LOGICAL_AND || op == CelOperator::LOGICAL_OR; +} + } // namespace absl::StatusOr Unparse(const Expr& expr, diff --git a/tools/cel_unparser_test.cc b/tools/cel_unparser_test.cc index 4cba4ce4d..aca6e91fd 100644 --- a/tools/cel_unparser_test.cc +++ b/tools/cel_unparser_test.cc @@ -67,6 +67,22 @@ INSTANTIATE_TEST_SUITE_P( {// Empty Expr error {"", absl::InvalidArgumentError("Unsupported Expr")}, + // Logical operators with too few arguments (single argument) + { + R"pb( + call_expr { + function: "_&&_" + args { const_expr { bool_value: true } } + })pb", + absl::InvalidArgumentError("Unexpected binary")}, + { + R"pb( + call_expr { + function: "_||_" + args { const_expr { bool_value: true } } + })pb", + absl::InvalidArgumentError("Unexpected binary")}, + // Constants {"const_expr{}", absl::InvalidArgumentError("Unsupported Constant")}, {"const_expr{bool_value: true}", "true"}, @@ -619,6 +635,7 @@ TEST_P(UnparserTestTextExpr, Test) { options.add_macro_calls = true; options.enable_optional_syntax = true; options.enable_quoted_identifiers = true; + options.enable_variadic_logical_operators = true; ASSERT_OK_AND_ASSIGN(ParsedExpr result, Parse(GetParam().expr, "unparser", options)); @@ -779,6 +796,8 @@ INSTANTIATE_TEST_SUITE_P( {"has(a.`b.c`)", ""}, {"a.`b/c`", ""}, {"a.?`b/c`", ""}, + {"a && b && c && d", ""}, + {"a || b || c || d", ""}, })); } // namespace diff --git a/tools/flatbuffers_backed_impl.cc b/tools/flatbuffers_backed_impl.cc index 10c0b1cb8..2ee226859 100644 --- a/tools/flatbuffers_backed_impl.cc +++ b/tools/flatbuffers_backed_impl.cc @@ -127,7 +127,7 @@ class ObjectStringIndexedMapImpl : public CelMap { arena_, **it, schema_, object_, arena_)); } } - return absl::nullopt; + return std::nullopt; } absl::StatusOr ListKeys() const override { return &keys_; } @@ -188,7 +188,7 @@ absl::optional FlatBuffersMapImpl::operator[]( } auto field = keys_.fields->LookupByKey(cel_key.StringOrDie().value().data()); if (field == nullptr) { - return absl::nullopt; + return std::nullopt; } switch (field->type()->base_type()) { case reflection::Byte: @@ -323,15 +323,15 @@ absl::optional FlatBuffersMapImpl::operator[]( } default: // Unsupported vector base types - return absl::nullopt; + return std::nullopt; } break; } default: // Unsupported types: enums, unions, arrays - return absl::nullopt; + return std::nullopt; } - return absl::nullopt; + return std::nullopt; } const CelMap* CreateFlatBuffersBackedObject(const uint8_t* flatbuf, diff --git a/tools/proto_to_predicate.cc b/tools/proto_to_predicate.cc new file mode 100644 index 000000000..8c89ee2f0 --- /dev/null +++ b/tools/proto_to_predicate.cc @@ -0,0 +1,459 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/proto_to_predicate.h" + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/descriptor.pb.h" +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/operators.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/reflection.h" + +namespace cel::tools { + +using ::google::api::expr::common::CelOperator; +using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::Message; +using ::google::protobuf::Reflection; + +class ProtoToPredicateBuilder final : private ExprFactory { + public: + ProtoToPredicateBuilder() : id_(1) {} + + absl::StatusOr Build(absl::string_view input_name, + const Message& message) { + std::vector predicates; + Expr base_expr = NewIdent(NextId(), input_name); + + CEL_RETURN_IF_ERROR(Walk(message, base_expr, predicates)); + Expr root = LogicalAnd(predicates); + return Ast(std::move(root), std::move(source_info_)); + } + + absl::StatusOr Build(absl::string_view input_name, + absl::Span messages) { + if (messages.empty()) { + return Ast(NewBoolConst(NextId(), true), std::move(source_info_)); + } + + std::vector message_asts; + message_asts.reserve(messages.size()); + for (const auto* message : messages) { + std::vector predicates; + Expr base_expr = NewIdent(NextId(), input_name); + + CEL_RETURN_IF_ERROR(Walk(*message, base_expr, predicates)); + message_asts.push_back(LogicalAnd(predicates)); + } + + return Ast(LogicalOr(message_asts), std::move(source_info_)); + } + + private: + // Retrieves the "match_path" string option from the field options if + // defined, returning an empty string otherwise. + std::string GetMatchPath(const ::google::protobuf::FieldDescriptor* field) { + const ::google::protobuf::Message& options = field->options(); + const ::google::protobuf::Reflection* refl = options.GetReflection(); + std::vector fields; + refl->ListFields(options, &fields); + for (const auto* f : fields) { + if (f->name() == "match_path") { + return refl->GetString(options, f); + } + } + return ""; + } + + // Parses a dot-separated string representation of a path (e.g. "dest.region") + // and builds a corresponding select chain AST. + Expr ParseAndBuildPath(absl::string_view path_str) { + std::vector parts = absl::StrSplit(path_str, '.'); + Expr e = NewIdent(NextId(), parts[0]); + for (size_t i = 1; i < parts.size(); ++i) { + e = NewSelect(NextId(), std::move(e), parts[i]); + } + return e; + } + ExprId NextId() { return id_++; } + + // --------------------------------------------------------------------------- + // Field value extraction + // --------------------------------------------------------------------------- + + // Converts a singular field value to a CEL constant expression. + Expr PrimitiveToExpr(ExprId expr_id, const Message& message, + const Reflection* reflection, + const FieldDescriptor* field) { + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return NewIntConst(expr_id, reflection->GetInt32(message, field)); + case FieldDescriptor::CPPTYPE_INT64: + return NewIntConst(expr_id, reflection->GetInt64(message, field)); + case FieldDescriptor::CPPTYPE_UINT32: + return NewUintConst(expr_id, reflection->GetUInt32(message, field)); + case FieldDescriptor::CPPTYPE_UINT64: + return NewUintConst(expr_id, reflection->GetUInt64(message, field)); + case FieldDescriptor::CPPTYPE_DOUBLE: + return NewDoubleConst(expr_id, reflection->GetDouble(message, field)); + case FieldDescriptor::CPPTYPE_FLOAT: + return NewDoubleConst(expr_id, reflection->GetFloat(message, field)); + case FieldDescriptor::CPPTYPE_BOOL: + return NewBoolConst(expr_id, reflection->GetBool(message, field)); + case FieldDescriptor::CPPTYPE_ENUM: + return NewIntConst(expr_id, reflection->GetEnumValue(message, field)); + case FieldDescriptor::CPPTYPE_STRING: { + std::string str_val = reflection->GetString(message, field); + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return NewBytesConst(expr_id, std::move(str_val)); + } + return NewStringConst(expr_id, std::move(str_val)); + } + default: + // Log a warning as message should be handled by Walk. + ABSL_LOG(WARNING) << "PrimitiveToExpr: Unhandled field type: " + << FieldDescriptor::TypeName(field->type()); + break; + } + return NewNullConst(expr_id); + } + + Expr PrimitiveToExpr(const Message& message, const Reflection* reflection, + const FieldDescriptor* field) { + return PrimitiveToExpr(NextId(), message, reflection, field); + } + + // Converts a repeated field element to a CEL constant expression. + Expr RepeatedPrimitiveToExpr(const Message& message, + const Reflection* reflection, + const FieldDescriptor* field, int index) { + const ExprId id = NextId(); + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return NewIntConst(id, + reflection->GetRepeatedInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_INT64: + return NewIntConst(id, + reflection->GetRepeatedInt64(message, field, index)); + case FieldDescriptor::CPPTYPE_UINT32: + return NewUintConst( + id, reflection->GetRepeatedUInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_UINT64: + return NewUintConst( + id, reflection->GetRepeatedUInt64(message, field, index)); + case FieldDescriptor::CPPTYPE_DOUBLE: + return NewDoubleConst( + id, reflection->GetRepeatedDouble(message, field, index)); + case FieldDescriptor::CPPTYPE_FLOAT: + return NewDoubleConst( + id, reflection->GetRepeatedFloat(message, field, index)); + case FieldDescriptor::CPPTYPE_BOOL: + return NewBoolConst(id, + reflection->GetRepeatedBool(message, field, index)); + case FieldDescriptor::CPPTYPE_ENUM: + return NewIntConst( + id, reflection->GetRepeatedEnumValue(message, field, index)); + case FieldDescriptor::CPPTYPE_STRING: { + std::string str_val = + reflection->GetRepeatedString(message, field, index); + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return NewBytesConst(id, std::move(str_val)); + } + return NewStringConst(id, std::move(str_val)); + } + default: + break; + } + return NewNullConst(id); + } + + // --------------------------------------------------------------------------- + // Expression construction helpers + // --------------------------------------------------------------------------- + + // Creates a binary operator call: `lhs rhs`. + Expr ConstructBinaryOp(absl::string_view op, Expr lhs, Expr rhs) { + std::vector args = {std::move(lhs), std::move(rhs)}; + return NewCall(NextId(), op, std::move(args)); + } + + Expr ConstructEquality(Expr lhs, Expr rhs) { + return ConstructBinaryOp(CelOperator::EQUALS, std::move(lhs), + std::move(rhs)); + } + + Expr LogicalOr(std::vector& exprs) { + return LogicalOp(CelOperator::LOGICAL_OR, exprs); + } + + Expr LogicalAnd(std::vector& exprs) { + return LogicalOp(CelOperator::LOGICAL_AND, exprs); + } + + // Left-folds a vector of expressions with a binary operator. + // Requires: `exprs` is non-empty. + Expr LogicalOp(absl::string_view op, std::vector& exprs) { + if (exprs.empty()) { + return NewBoolConst(NextId(), true); + } + if (exprs.size() == 1) { + return std::move(exprs[0]); + } + return NewCall(NextId(), op, std::move(exprs)); + } + + // --------------------------------------------------------------------------- + // Map field predicate (extracted from Walk for readability) + // --------------------------------------------------------------------------- + + // Builds the predicate for a map field to assert that all key-value pairs + // specified in the policy are present in the input map field: + // "key" in input.map && input.map["key"] == value + absl::Status WalkMapField(const Reflection* reflection, + const Message& message, + const FieldDescriptor* field, const Expr& base_expr, + int size, std::vector& predicates) { + const FieldDescriptor* const key_field = + field->message_type()->FindFieldByName("key"); + const FieldDescriptor* const value_field = + field->message_type()->FindFieldByName("value"); + + Expr map_path = NewSelect(NextId(), base_expr, field->name()); + + struct MapEntry { + const Message* message; + }; + std::vector entries; + entries.reserve(size); + for (int i = 0; i < size; ++i) { + entries.push_back({&reflection->GetRepeatedMessage(message, field, i)}); + } + + if (!entries.empty()) { + const Reflection* const entry_ref = entries[0].message->GetReflection(); + std::sort(entries.begin(), entries.end(), + [entry_ref, key_field](const MapEntry& a, const MapEntry& b) { + switch (key_field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return entry_ref->GetInt32(*a.message, key_field) < + entry_ref->GetInt32(*b.message, key_field); + case FieldDescriptor::CPPTYPE_INT64: + return entry_ref->GetInt64(*a.message, key_field) < + entry_ref->GetInt64(*b.message, key_field); + case FieldDescriptor::CPPTYPE_UINT32: + return entry_ref->GetUInt32(*a.message, key_field) < + entry_ref->GetUInt32(*b.message, key_field); + case FieldDescriptor::CPPTYPE_UINT64: + return entry_ref->GetUInt64(*a.message, key_field) < + entry_ref->GetUInt64(*b.message, key_field); + case FieldDescriptor::CPPTYPE_BOOL: + return !entry_ref->GetBool(*a.message, key_field) && + entry_ref->GetBool(*b.message, key_field); + case FieldDescriptor::CPPTYPE_STRING: + return entry_ref->GetString(*a.message, key_field) < + entry_ref->GetString(*b.message, key_field); + default: + return false; + } + }); + } + + std::vector map_checks; + map_checks.reserve(size); + for (const auto& entry : entries) { + const Message& entry_msg = *entry.message; + const Reflection* const entry_ref = entry_msg.GetReflection(); + + Expr key_expr = PrimitiveToExpr(entry_msg, entry_ref, key_field); + + // Represents `"key" in input.map` to assert the key exists. + Expr in_check = NewCall(NextId(), CelOperator::IN, + std::vector{key_expr, map_path}); + // Represents `input.map["key"]` to lookup the value. + Expr lookup_path = NewCall(NextId(), CelOperator::INDEX, + std::vector{map_path, key_expr}); + + if (value_field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + const Message& value_msg = + entry_ref->GetMessage(entry_msg, value_field); + std::vector val_predicates; + CEL_RETURN_IF_ERROR(Walk(value_msg, lookup_path, val_predicates)); + + if (!val_predicates.empty()) { + // Represents `"key" in input.map && (nested message fields check...)` + map_checks.push_back(std::move(in_check)); + map_checks.insert(map_checks.end(), + std::make_move_iterator(val_predicates.begin()), + std::make_move_iterator(val_predicates.end())); + } else { + // Represents `"key" in input.map` if nested message is empty. + map_checks.push_back(std::move(in_check)); + } + } else { + Expr value_expr = PrimitiveToExpr(entry_msg, entry_ref, value_field); + // Represents `input.map["key"] == value` + Expr eq_check = + ConstructEquality(std::move(lookup_path), std::move(value_expr)); + + // Represents `"key" in input.map && input.map["key"] == value` + map_checks.push_back(std::move(in_check)); + map_checks.push_back(std::move(eq_check)); + } + } + + predicates.push_back(LogicalAnd(map_checks)); + return absl::OkStatus(); + } + + // --------------------------------------------------------------------------- + // Repeated field predicate (extracted from Walk for readability) + // --------------------------------------------------------------------------- + + // Builds predicates for a repeated field: + // - Repeated Messages are mapped to a logical OR (||) of the generated + // predicates for each message. + // - Repeated Primitives are mapped either to: + // - `lhs in [values]` if a "match_path" option is specified. + // - `value in input.field` conjoined with && for each value otherwise. + absl::Status WalkRepeatedField(const Reflection* reflection, + const Message& message, + const FieldDescriptor* field, + const Expr& base_expr, int size, + std::vector& predicates) { + if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + std::vector message_asts; + message_asts.reserve(size); + for (int i = 0; i < size; ++i) { + const Message& sub_message = + reflection->GetRepeatedMessage(message, field, i); + std::vector sub_predicates; + Expr sub_base = NewSelect(NextId(), base_expr, field->name()); + CEL_RETURN_IF_ERROR(Walk(sub_message, sub_base, sub_predicates)); + message_asts.push_back(LogicalAnd(sub_predicates)); + } + // Represents alternate message predicates conjoined with OR: `msg_1 || + // msg_2 || ...` + predicates.push_back(LogicalOr(message_asts)); + return absl::OkStatus(); + } + + std::vector elements; + elements.reserve(size); + for (int i = 0; i < size; ++i) { + elements.push_back(NewListElement( + RepeatedPrimitiveToExpr(message, reflection, field, i))); + } + Expr literal_list = NewList(NextId(), std::move(elements)); + + std::string match_path_val = GetMatchPath(field); + if (!match_path_val.empty()) { + Expr lhs = ParseAndBuildPath(match_path_val); + // Represents `lhs in [values]` check (e.g. `dest.region in ["us-east", + // "us-west"]`). + predicates.push_back( + NewCall(NextId(), CelOperator::IN, + std::vector{std::move(lhs), std::move(literal_list)})); + return absl::OkStatus(); + } + + Expr map_path = NewSelect(NextId(), base_expr, field->name()); + std::vector element_checks; + element_checks.reserve(size); + for (int i = 0; i < size; ++i) { + Expr elem_expr = RepeatedPrimitiveToExpr(message, reflection, field, i); + // Represents `value in input.field` check. + Expr in_check = + NewCall(NextId(), CelOperator::IN, + std::vector{std::move(elem_expr), map_path}); + element_checks.push_back(std::move(in_check)); + } + // Represents `"val1" in input.list && "val2" in input.list && ...` + predicates.push_back(LogicalAnd(element_checks)); + + return absl::OkStatus(); + } + + // --------------------------------------------------------------------------- + // Recursive message walk + // --------------------------------------------------------------------------- + + absl::Status Walk(const Message& message, const Expr& base_expr, + std::vector& predicates) { + const Reflection* const reflection = message.GetReflection(); + std::vector fields; + reflection->ListFields(message, &fields); + + for (const auto* field : fields) { + if (field->is_map()) { + const int size = reflection->FieldSize(message, field); + if (size > 0) { + CEL_RETURN_IF_ERROR(WalkMapField(reflection, message, field, + base_expr, size, predicates)); + } + } else if (field->is_repeated()) { + const int size = reflection->FieldSize(message, field); + if (size > 0) { + CEL_RETURN_IF_ERROR(WalkRepeatedField(reflection, message, field, + base_expr, size, predicates)); + } + } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + const Message& sub_message = reflection->GetMessage(message, field); + Expr field_path = NewSelect(NextId(), base_expr, field->name()); + CEL_RETURN_IF_ERROR(Walk(sub_message, field_path, predicates)); + } else { + // Primitive field: base_expr.field == + Expr field_path = NewSelect(NextId(), base_expr, field->name()); + predicates.push_back( + ConstructEquality(std::move(field_path), + PrimitiveToExpr(message, reflection, field))); + } + } + return absl::OkStatus(); + } + + ExprId id_; + SourceInfo source_info_; +}; + +absl::StatusOr ProtoToPredicateAst(absl::string_view input_name, + const ::google::protobuf::Message& message) { + ProtoToPredicateBuilder builder; + return builder.Build(input_name, message); +} + +absl::StatusOr ProtoToPredicateAst( + absl::string_view input_name, + absl::Span messages) { + ProtoToPredicateBuilder builder; + return builder.Build(input_name, messages); +} + +} // namespace cel::tools diff --git a/tools/proto_to_predicate.h b/tools/proto_to_predicate.h new file mode 100644 index 000000000..ed01cb1e8 --- /dev/null +++ b/tools/proto_to_predicate.h @@ -0,0 +1,48 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ +#define THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "google/protobuf/message.h" + +namespace cel::tools { + +// Translates a Protocol Buffer message into a CEL AST representing a predicate. +// +// NOTE: The protocol message schemas used for policy definition should use +// `proto2` or `editions` (and not `proto3` implicit presence) to ensure correct +// behavior, as this library relies on field presence (via reflection) to +// identify which fields are explicitly set by the policy. +absl::StatusOr ProtoToPredicateAst(absl::string_view input_name, + const ::google::protobuf::Message& message); + +// Translates a list of Protocol Buffer messages into a CEL AST representing a +// conjoined or alternate predicate. +// +// NOTE: The protocol message schemas used for policy definition should use +// `proto2` or `editions` (and not `proto3` implicit presence) to ensure correct +// behavior, as this library relies on field presence (via reflection) to +// identify which fields are explicitly set by the policy. +absl::StatusOr ProtoToPredicateAst( + absl::string_view input_name, + absl::Span messages); + +} // namespace cel::tools + +#endif // THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ diff --git a/tools/proto_to_predicate_test.cc b/tools/proto_to_predicate_test.cc new file mode 100644 index 000000000..80ad140c7 --- /dev/null +++ b/tools/proto_to_predicate_test.cc @@ -0,0 +1,593 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/proto_to_predicate.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/value.h" +#include "env/config.h" +#include "env/env_runtime.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "eval/testutil/test_message.pb.h" +#include "extensions/protobuf/value.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "tools/cel_unparser.h" +#include "tools/testdata/test_policy.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/json/json.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace cel::tools { +namespace { + +using ::absl_testing::IsOk; +using ::google::api::expr::runtime::TestMessage; + +constexpr absl::string_view kEnvYaml = R"( +name: "test" +extensions: + - name: "bindings" + - name: "optional" +variables: + - name: "input" + type: "google.api.expr.runtime.TestMessage" +)"; + +TestMessage ParseTestMessage(absl::string_view textproto) { + TestMessage msg; + google::protobuf::TextFormat::ParseFromString(textproto, &msg); + return msg; +} + +absl::StatusOr EvaluatePredicate(const cel::Ast& ast, + const TestMessage& input) { + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + + CEL_ASSIGN_OR_RETURN(cel::Config config, + cel::EnvConfigFromYaml(std::string(kEnvYaml))); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::make_unique(ast))); + + google::protobuf::Arena arena; + cel::Activation activation; + CEL_ASSIGN_OR_RETURN( + cel::Value val, cel::extensions::ProtoMessageToValue( + input, descriptor_pool.get(), + google::protobuf::MessageFactory::generated_factory(), &arena)); + activation.InsertOrAssignValue("input", val); + + CEL_ASSIGN_OR_RETURN(cel::Value result, + program->Evaluate(&arena, activation)); + if (!result.IsBool()) { + return absl::InvalidArgumentError( + "Predicate evaluate result must be a boolean value."); + } + return result.GetBool(); +} + +struct TestCase { + std::string name; + std::vector input_textprotos; + std::string expected_unparsed; + std::string eval_textproto; + bool expected_eval_result = true; + // If true, skip the eval step of the test. This is useful for tests where + // the expected expression does not share the same type structure as the + // input proto, such as empty messages. + bool skip_eval = false; +}; + +class ProtoToPredicateTest : public ::testing::TestWithParam {}; + +TEST_P(ProtoToPredicateTest, ConformanceTests) { + const TestCase& param = GetParam(); + + std::vector input_messages; + input_messages.reserve(param.input_textprotos.size()); + for (const auto& proto_str : param.input_textprotos) { + input_messages.push_back(ParseTestMessage(proto_str)); + } + + std::vector ptr_messages; + ptr_messages.reserve(input_messages.size()); + for (const auto& msg : input_messages) { + ptr_messages.push_back(&msg); + } + + absl::StatusOr ast_or; + if (input_messages.size() == 1) { + ast_or = ProtoToPredicateAst("input", input_messages[0]); + } else { + ast_or = ProtoToPredicateAst("input", absl::MakeSpan(ptr_messages)); + } + + ASSERT_THAT(ast_or, IsOk()); + cel::Ast ast = std::move(*ast_or); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, google::api::expr::Unparse(parsed_expr)); + + EXPECT_EQ(unparsed, param.expected_unparsed); + + if (!param.skip_eval) { + TestMessage eval_msg = ParseTestMessage(param.eval_textproto); + ASSERT_OK_AND_ASSIGN(bool eval_result, EvaluatePredicate(ast, eval_msg)); + EXPECT_EQ(eval_result, param.expected_eval_result); + } +} + +INSTANTIATE_TEST_SUITE_P( + ProtoToPredicateSubCases, ProtoToPredicateTest, + testing::Values( + TestCase{ + .name = "EmptyMessageTest", + .input_textprotos = {""}, + .expected_unparsed = "true", + .eval_textproto = "", + }, + TestCase{ + .name = "EmptyMessagesListTest", + .input_textprotos = {}, + .expected_unparsed = "true", + .eval_textproto = "", + }, + TestCase{ + .name = "PrimitivesTest", + .input_textprotos = {R"pb( + int32_value: 42 string_value: "hello" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 && input.string_value == \"hello\"", + .eval_textproto = R"pb( + int32_value: 42 string_value: "hello" + )pb", + }, + TestCase{ + .name = "AllPrimitivesTest", + .input_textprotos = {R"pb( + int32_value: 42 + int64_value: 43 + uint32_value: 44 + uint64_value: 45 + float_value: 46.5 + double_value: 47.5 + bool_value: true + enum_value: TEST_ENUM_1 + string_value: "hello" + bytes_value: "world" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 && input.int64_value == 43 && " + "input.uint32_value == 44u && input.uint64_value == 45u && " + "input.float_value == 46.5 && input.double_value == 47.5 && " + "input.string_value == \"hello\" && " + "input.bytes_value == b\"world\" && " + "input.bool_value == true && " + "input.enum_value == 1", + .eval_textproto = R"pb( + int32_value: 42 + int64_value: 43 + uint32_value: 44 + uint64_value: 45 + float_value: 46.5 + double_value: 47.5 + bool_value: true + enum_value: TEST_ENUM_1 + string_value: "hello" + bytes_value: "world" + )pb", + }, + TestCase{ + .name = "NestedMessageTest", + .input_textprotos = {R"pb( + message_value: { int32_value: 42 } + )pb"}, + .expected_unparsed = "input.message_value.int32_value == 42", + .eval_textproto = R"pb( + message_value: { int32_value: 42 } + )pb", + }, + TestCase{ + .name = "RepeatedFieldTest", + .input_textprotos = {R"pb( + int32_list: [ 1, 2 ] + )pb"}, + .expected_unparsed = + "1 in input.int32_list && 2 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 1, 2 ] + )pb", + }, + TestCase{ + .name = "RepeatedFieldSingleElementTest", + .input_textprotos = {R"pb( + int32_list: [ 42 ] + )pb"}, + .expected_unparsed = "42 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 42 ] + )pb", + }, + TestCase{ + .name = "RepeatedFieldEmptyTest", + .input_textprotos = {R"pb( + int32_list: [] + )pb"}, + .expected_unparsed = "true", + .eval_textproto = R"pb( + int32_list: [] + )pb", + }, + TestCase{ + .name = "ListFieldEvalNegative", + .input_textprotos = {R"pb( + int32_list: [ 1, 2 ] + )pb"}, + .expected_unparsed = + "1 in input.int32_list && 2 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 1, 3 ] + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "SingleRepeatedFieldAllPrimitivesTest", + .input_textprotos = {R"pb( + int32_list: [ 42 ] + int64_list: [ 43 ] + uint32_list: [ 44 ] + uint64_list: [ 45 ] + float_list: [ 46.5 ] + double_list: [ 47.5 ] + bool_list: [ true ] + enum_list: [ TEST_ENUM_1 ] + string_list: [ "hello" ] + bytes_list: [ "world" ] + )pb"}, + .expected_unparsed = "42 in input.int32_list && " + "43 in input.int64_list && " + "44u in input.uint32_list && " + "45u in input.uint64_list && " + "46.5 in input.float_list && " + "47.5 in input.double_list && " + "\"hello\" in input.string_list && " + "b\"world\" in input.bytes_list && " + "true in input.bool_list && " + "1 in input.enum_list", + .eval_textproto = R"pb( + int32_list: [ 42 ] + int64_list: [ 43 ] + uint32_list: [ 44 ] + uint64_list: [ 45 ] + float_list: [ 46.5 ] + double_list: [ 47.5 ] + bool_list: [ true ] + enum_list: [ TEST_ENUM_1 ] + string_list: [ "hello" ] + bytes_list: [ "world" ] + )pb", + }, + TestCase{ + .name = "MultipleRepeatedFieldAllPrimitivesTest", + .input_textprotos = {R"pb( + int32_list: [ 42, 142 ] + int64_list: [ 43, 143 ] + uint32_list: [ 44, 144 ] + uint64_list: [ 45, 145 ] + float_list: [ 46.5, 146.5 ] + double_list: [ 47.5, 147.5 ] + bool_list: [ true, false ] + enum_list: [ TEST_ENUM_1, TEST_ENUM_2 ] + string_list: [ "hello", "universe" ] + bytes_list: [ "world", "space" ] + )pb"}, + .expected_unparsed = + "42 in input.int32_list && 142 in input.int32_list && " + "43 in input.int64_list && 143 in input.int64_list && " + "44u in input.uint32_list && 144u in input.uint32_list && " + "45u in input.uint64_list && 145u in input.uint64_list && " + "46.5 in input.float_list && 146.5 in input.float_list && " + "47.5 in input.double_list && 147.5 in input.double_list && " + "\"hello\" in input.string_list && \"universe\" in " + "input.string_list && " + "b\"world\" in input.bytes_list && b\"space\" in " + "input.bytes_list && " + "true in input.bool_list && false in input.bool_list && " + "1 in input.enum_list && 2 in input.enum_list", + .eval_textproto = R"pb( + int32_list: [ 42, 142 ] + int64_list: [ 43, 143 ] + uint32_list: [ 44, 144 ] + uint64_list: [ 45, 145 ] + float_list: [ 46.5, 146.5 ] + double_list: [ 47.5, 147.5 ] + bool_list: [ true, false ] + enum_list: [ TEST_ENUM_1, TEST_ENUM_2 ] + string_list: [ "hello", "universe" ] + bytes_list: [ "world", "space" ] + )pb", + }, + TestCase{ + .name = "MapFieldTest", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb", + }, + TestCase{ + .name = "MapFieldEvalNegativeVal", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 3 } + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "MapFieldEvalNegativeNoKey", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "MapFieldIntKeySortingTest", + .input_textprotos = {R"pb( + int32_int32_map: { key: 10 value: 100 } + int32_int32_map: { key: 5 value: 50 } + int32_int32_map: { key: 8 value: 80 } + )pb"}, + .expected_unparsed = "5 in input.int32_int32_map && " + "input.int32_int32_map[5] == 50 && " + "8 in input.int32_int32_map && " + "input.int32_int32_map[8] == 80 && " + "10 in input.int32_int32_map && " + "input.int32_int32_map[10] == 100", + .eval_textproto = R"pb( + int32_int32_map: { key: 5 value: 50 } + int32_int32_map: { key: 8 value: 80 } + int32_int32_map: { key: 10 value: 100 } + )pb", + }, + TestCase{ + .name = "MultipleMessagesTest", + .input_textprotos = {R"pb( + int32_value: 42 + )pb", + R"pb( + int32_value: 41 string_value: "hello" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 || input.int32_value == 41 && " + "input.string_value == \"hello\"", + .eval_textproto = R"pb( + int32_value: 41 string_value: "hello" + )pb", + }, + TestCase{ + .name = "RepeatedMessageFieldTest", + .input_textprotos = {R"pb( + message_list: + [ { int32_value: 42 } + , { int32_value: 43 }] + )pb"}, + .expected_unparsed = "input.message_list.int32_value == 42 || " + "input.message_list.int32_value == 43", + .skip_eval = true, + }, + TestCase{ + .name = "RepeatedMessageSingleElementTest", + .input_textprotos = {R"pb( + message_list: + [ { int32_value: 42 }] + )pb"}, + .expected_unparsed = "input.message_list.int32_value == 42", + .skip_eval = true, + })); + +struct PolicyTestCase { + std::string name; + std::string json_input; + std::string expected_unparsed; +}; + +class PolicyJsonTest : public ::testing::TestWithParam {}; + +TEST_P(PolicyJsonTest, Conformance) { + const PolicyTestCase& param = GetParam(); + + cel::cpp::tools::Policy policy; + google::protobuf::json::ParseOptions options; + options.ignore_unknown_fields = true; + auto status = + google::protobuf::json::JsonStringToMessage(param.json_input, &policy, options); + ASSERT_THAT(status, IsOk()) << "Failed to parse JSON: " << param.json_input; + + absl::StatusOr ast_or; + std::vector ptr_messages; + ptr_messages.reserve(policy.destinations_size()); + for (const auto& dest : policy.destinations()) { + ptr_messages.push_back(&dest); + } + + if (ptr_messages.empty()) { + auto parsed_expr_or = google::api::expr::parser::Parse("false"); + ASSERT_THAT(parsed_expr_or, IsOk()); + auto ast_ptr_or = cel::CreateAstFromParsedExpr(*parsed_expr_or); + ASSERT_THAT(ast_ptr_or, IsOk()); + ast_or = std::move(**ast_ptr_or); + } else if (ptr_messages.size() == 1) { + ast_or = ProtoToPredicateAst("dest", *ptr_messages[0]); + } else { + ast_or = ProtoToPredicateAst("dest", absl::MakeSpan(ptr_messages)); + } + + ASSERT_THAT(ast_or, IsOk()); + cel::Ast ast = std::move(*ast_or); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, google::api::expr::Unparse(parsed_expr)); + + EXPECT_EQ(unparsed, param.expected_unparsed); +} + +INSTANTIATE_TEST_SUITE_P( + PolicyJsonSubCases, PolicyJsonTest, + testing::Values( + PolicyTestCase{ + .name = "SimpleMatch", + .json_input = + R"({ "destinations": [ { "agent": { "id": "agent-007" } } ] })", + .expected_unparsed = "dest.agent.name == \"agent-007\"", + }, + PolicyTestCase{ + .name = "MultipleFields", + .json_input = + R"({ "destinations": [ { + "tool": { + "name": "admin_tool", + "annotations": { + "read_only_hint": false + } + } + } + ] })", + .expected_unparsed = + "dest.tool.name == \"admin_tool\" && " + "dest.tool.annotations.read_only_hint == false", + }, + PolicyTestCase{ + .name = "RepeatedMessages", + .json_input = + R"({ "destinations": [ + { "agent": { "id": "worker-1" } }, + { "agent": { "id": "worker-2" } }, + ] })", + .expected_unparsed = "dest.agent.name == \"worker-1\" || " + "dest.agent.name == \"worker-2\"", + }, + PolicyTestCase{ + .name = "RepeatedPrimitiveArraySingleElement", + .json_input = + R"({ "destinations": [ { + "tool": { + "role_members": { + "admin": { + "principals": ["alice"] + } + } + } + } ] })", + .expected_unparsed = + "\"admin\" in dest.tool.role_members && " + "\"alice\" in dest.tool.role_members[\"admin\"].principals", + }, + PolicyTestCase{ + .name = "RepeatedArrayEmpty", + .json_input = R"({ "destinations": [ { "tool": { } } ] })", + .expected_unparsed = "true", + }, + PolicyTestCase{ + .name = "MapEquality", + .json_input = + R"({ "destinations": [ + { "tool": { + "name": "shell", + "labels": { + "cluster": "us-central1", + "project": "dev" + } + } + } ] })", + .expected_unparsed = + "dest.tool.name == \"shell\" && \"cluster\" in " + "dest.tool.labels && dest.tool.labels[\"cluster\"] == " + "\"us-central1\" && \"project\" in dest.tool.labels && " + "dest.tool.labels[\"project\"] == \"dev\"", + }, + PolicyTestCase{ + .name = "NestedMapEquality", + .json_input = + R"({ "destinations": [ + { "tool": { + "role_members": { + "admin": { + "all_users": true + } + } + } } + ] })", + .expected_unparsed = + "\"admin\" in dest.tool.role_members && " + "dest.tool.role_members[\"admin\"].all_users == true", + }, + PolicyTestCase{ + .name = "EmptyPolicy", + .json_input = "{}", + .expected_unparsed = "false", + })); + +} // namespace +} // namespace cel::tools diff --git a/tools/testdata/BUILD b/tools/testdata/BUILD index 493f0ff2f..c88c9c478 100644 --- a/tools/testdata/BUILD +++ b/tools/testdata/BUILD @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -load( - "@com_github_google_flatbuffers//:build_defs.bzl", - "flatbuffer_library_public", -) +load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_library_public") +load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@rules_cc//cc:cc_library.bzl", "cc_library") licenses(["notice"]) @@ -46,3 +45,15 @@ cc_library( linkstatic = True, deps = ["@com_github_google_flatbuffers//:runtime_cc"], ) + +proto_library( + name = "test_policy_proto", + srcs = ["test_policy.proto"], + visibility = ["//tools:__subpackages__"], +) + +cc_proto_library( + name = "test_policy_cc_proto", + visibility = ["//tools:__subpackages__"], + deps = [":test_policy_proto"], +) diff --git a/tools/testdata/test_policy.proto b/tools/testdata/test_policy.proto new file mode 100644 index 000000000..b5d424c04 --- /dev/null +++ b/tools/testdata/test_policy.proto @@ -0,0 +1,73 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test schema representing client-configured policies. +// It is used by the `proto_to_predicate` tool to translate Protobuf policies +// into CEL predicates. +edition = "2023"; + +package cel.cpp.tools; + +option cc_enable_arenas = true; + +// Represents the targeted client agent. +message Agent { + string name = 1 [json_name = "id"]; +} + +// Specifies additional metadata tool annotations. +message ToolAnnotations { + bool read_only_hint = 1; +} + +// Represents a mapped nested message entry value inside map fields. +message Members { + repeated string principals = 1; + + repeated string regions = 2; + + bool all_users = 3; + + bool all_authenticated_users = 4; +} + +// Represents a metadata tool block. +message Tool { + // The name of the tool. + string name = 1; + + // Additional metadata annotations for the tool. + ToolAnnotations annotations = 2; + + // A string-to-string map, transpiled as conjoined existence and equality + // checks. + map labels = 3; + + // A map with string keys representing roles and Member instances as values. + map role_members = 4; +} + +// Represents a policy mapping destination block. +message Target { + oneof kind { + Agent agent = 1; + Tool tool = 2; + } +} + +// Represents the top-level policy containing multiple alternate destination +// rules. +message Policy { + repeated Target destinations = 1; +} diff --git a/validator/BUILD b/validator/BUILD new file mode 100644 index 000000000..9910a6b97 --- /dev/null +++ b/validator/BUILD @@ -0,0 +1,214 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "validator", + srcs = ["validator.cc"], + hdrs = ["validator.h"], + deps = [ + "//checker:type_check_issue", + "//checker:validation_result", + "//common:ast", + "//common:navigable_ast", + "//common:source", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "validator_test", + srcs = ["validator_test.cc"], + deps = [ + ":validator", + "//checker:type_check_issue", + "//common:ast", + "//common:expr", + "//common:source", + "//internal:testing", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "timestamp_literal_validator_test", + srcs = ["timestamp_literal_validator_test.cc"], + deps = [ + ":timestamp_literal_validator", + ":validator", + "//checker:validation_result", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "timestamp_literal_validator", + srcs = ["timestamp_literal_validator.cc"], + hdrs = ["timestamp_literal_validator.h"], + deps = [ + ":validator", + "//common:constant", + "//common:navigable_ast", + "//common:standard_definitions", + "//internal:time", + "//tools:navigable_ast", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "ast_depth_validator", + srcs = ["ast_depth_validator.cc"], + hdrs = ["ast_depth_validator.h"], + deps = [ + ":validator", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "homogeneous_literal_validator", + srcs = ["homogeneous_literal_validator.cc"], + hdrs = ["homogeneous_literal_validator.h"], + deps = [ + ":validator", + "//common:ast", + "//common:expr", + "//common:navigable_ast", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "regex_validator", + srcs = ["regex_validator.cc"], + hdrs = ["regex_validator.h"], + deps = [ + ":validator", + "//common:ast", + "//common:constant", + "//common:expr", + "//common:navigable_ast", + "//common:standard_definitions", + "//internal:re2_options", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "homogeneous_literal_validator_test", + srcs = ["homogeneous_literal_validator_test.cc"], + deps = [ + ":homogeneous_literal_validator", + ":validator", + "//checker:validation_result", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:optional", + "//compiler:standard_library", + "//extensions:strings", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "ast_depth_validator_test", + srcs = ["ast_depth_validator_test.cc"], + deps = [ + ":ast_depth_validator", + ":validator", + "//checker:type_check_issue", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/log:absl_check", + ], +) + +cc_test( + name = "regex_validator_test", + srcs = ["regex_validator_test.cc"], + deps = [ + ":regex_validator", + ":validator", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "comprehension_nesting_validator", + srcs = ["comprehension_nesting_validator.cc"], + hdrs = ["comprehension_nesting_validator.h"], + deps = [ + ":validator", + "//common:expr", + "//common:navigable_ast", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "comprehension_nesting_validator_test", + srcs = ["comprehension_nesting_validator_test.cc"], + deps = [ + ":comprehension_nesting_validator", + ":validator", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//extensions:bindings_ext", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:statusor", + ], +) + +licenses(["notice"]) diff --git a/validator/ast_depth_validator.cc b/validator/ast_depth_validator.cc new file mode 100644 index 000000000..0f6b8d93d --- /dev/null +++ b/validator/ast_depth_validator.cc @@ -0,0 +1,34 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/ast_depth_validator.h" + +#include "absl/strings/str_cat.h" +#include "validator/validator.h" + +namespace cel { + +Validation AstDepthValidator(int max_depth) { + return Validation([max_depth](ValidationContext& context) { + int height = context.navigable_ast().Root().height(); + if (height > max_depth) { + context.ReportError(absl::StrCat("AST depth ", height, + " exceeds maximum of ", max_depth)); + return false; + } + return true; + }); +} + +} // namespace cel diff --git a/common/type_factory.h b/validator/ast_depth_validator.h similarity index 58% rename from common/type_factory.h rename to validator/ast_depth_validator.h index 33829ea8b..a640af12e 100644 --- a/common/type_factory.h +++ b/validator/ast_depth_validator.h @@ -1,4 +1,4 @@ -// Copyright 2023 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,19 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_FACTORY_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_FACTORY_H_ +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ +#include "validator/validator.h" namespace cel { -// `TypeFactory` is the preferred way for constructing compound types such as -// lists, maps, structs, and opaques. It caches types and avoids constructing -// them multiple times. -class TypeFactory { - public: - virtual ~TypeFactory() = default; -}; +// Returns a `Validation` that checks the AST depth is less than or equal to +// max_depth. +Validation AstDepthValidator(int max_depth); } // namespace cel -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_FACTORY_H_ +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ diff --git a/validator/ast_depth_validator_test.cc b/validator/ast_depth_validator_test.cc new file mode 100644 index 000000000..eda59b40d --- /dev/null +++ b/validator/ast_depth_validator_test.cc @@ -0,0 +1,81 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/ast_depth_validator.h" + +#include +#include + +#include "absl/log/absl_check.h" +#include "checker/type_check_issue.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +std::unique_ptr CreateCompiler() { + auto builder = NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()); + ABSL_CHECK_OK(builder); + ABSL_CHECK_OK((*builder)->AddLibrary(StandardCompilerLibrary())); + auto compiler = (*builder)->Build(); + ABSL_CHECK_OK(compiler); + return *std::move(compiler); +} + +TEST(AstDepthValidatorTest, Basic) { + auto compiler = CreateCompiler(); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("1 + 2 + 3")); + + Validator validator; + validator.AddValidation(AstDepthValidator(10)); + auto output = validator.Validate(*result.GetAst()); + EXPECT_TRUE(output.valid); + + Validator validator2; + validator2.AddValidation(AstDepthValidator(2)); + output = validator2.Validate(*result.GetAst()); + EXPECT_FALSE(output.valid); + EXPECT_THAT(output.issues, + testing::Contains(testing::Property( + &TypeCheckIssue::message, + testing::Eq("AST depth 3 exceeds maximum of 2")))); +} + +TEST(AstDepthValidatorTest, Nested) { + auto compiler = CreateCompiler(); + ASSERT_OK_AND_ASSIGN(auto result, + compiler->Compile("1 + (2 + (3 + (4 + 5)))")); + + Validator validator; + validator.AddValidation(AstDepthValidator(10)); + auto output = validator.Validate(*result.GetAst()); + EXPECT_TRUE(output.valid); + + Validator validator2; + validator2.AddValidation(AstDepthValidator(4)); + output = validator2.Validate(*result.GetAst()); + EXPECT_FALSE(output.valid); + EXPECT_THAT(output.issues, + testing::Contains(testing::Property( + &TypeCheckIssue::message, + testing::Eq("AST depth 5 exceeds maximum of 4")))); +} + +} // namespace +} // namespace cel diff --git a/validator/comprehension_nesting_validator.cc b/validator/comprehension_nesting_validator.cc new file mode 100644 index 000000000..81c47cbc3 --- /dev/null +++ b/validator/comprehension_nesting_validator.cc @@ -0,0 +1,72 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/comprehension_nesting_validator.h" + +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "common/expr.h" +#include "common/navigable_ast.h" +#include "validator/validator.h" + +namespace cel { + +namespace { + +bool IsEmptyRangeComprehension(const NavigableAstNode& node) { + ABSL_DCHECK(node.expr()->has_comprehension_expr()); + const auto& comp = node.expr()->comprehension_expr(); + return comp.has_iter_range() && comp.iter_range().has_list_expr() && + comp.iter_range().list_expr().elements().empty(); +} + +} // namespace + +Validation ComprehensionNestingLimitValidator(int limit) { + return Validation( + [limit](ValidationContext& context) -> bool { + bool is_valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() != NodeKind::kComprehension) { + continue; + } + if (IsEmptyRangeComprehension(node)) { + continue; + } + + int count = 0; + const NavigableAstNode* current = &node; + while (current != nullptr) { + if (current->node_kind() == NodeKind::kComprehension && + !IsEmptyRangeComprehension(*current)) { + count++; + } + current = current->parent(); + } + if (count > limit) { + context.ReportErrorAt( + node.expr()->id(), + absl::StrCat("comprehension nesting level of ", count, + " exceeds limit of ", limit)); + is_valid = false; + break; + } + } + return is_valid; + }, + "cel.validator.comprehension_nesting_limit"); +} + +} // namespace cel diff --git a/validator/comprehension_nesting_validator.h b/validator/comprehension_nesting_validator.h new file mode 100644 index 000000000..4dab78db0 --- /dev/null +++ b/validator/comprehension_nesting_validator.h @@ -0,0 +1,31 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ + +#include "validator/validator.h" + +namespace cel { + +// Returns a `Validation` that checks that comprehensions are not nested beyond +// the specified limit. +// +// Comprehensions with an empty iteration range (e.g. `cel.bind`) do not count +// towards the nesting limit. +Validation ComprehensionNestingLimitValidator(int limit); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ diff --git a/validator/comprehension_nesting_validator_test.cc b/validator/comprehension_nesting_validator_test.cc new file mode 100644 index 000000000..c1b47f82d --- /dev/null +++ b/validator/comprehension_nesting_validator_test.cc @@ -0,0 +1,96 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/comprehension_nesting_validator.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "extensions/bindings_ext.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCompilerLibrary())); + return builder->Build(); +} + +struct TestCase { + std::string expression; + int limit; + bool valid; + std::string error_substr = ""; +}; + +using ComprehensionNestingValidatorTest = testing::TestWithParam; + +TEST_P(ComprehensionNestingValidatorTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(ComprehensionNestingLimitValidator(test_case.limit)); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + auto result_or = compiler->Compile(test_case.expression); + if (!result_or.ok()) { + GTEST_SKIP() << "Expression failed to compile: " << test_case.expression + << " " << result_or.status().message(); + } + auto result = std::move(result_or).value(); + + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid) + << "Expression: " << test_case.expression + << " Limit: " << test_case.limit; + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + ComprehensionNestingValidatorTest, ComprehensionNestingValidatorTest, + testing::Values( + TestCase{"[1, 2].all(x, x > 0)", 1, true}, + TestCase{"[1, 2].all(x, [1, 2].all(y, x > y))", 1, false, + "comprehension nesting level of 2 exceeds limit of 1"}, + TestCase{"[1, 2].all(x, [1, 2].all(y, x > y))", 2, true}, + // Empty range comprehension (does not count) + TestCase{"[].all(x, [1, 2].all(y, y > 0))", 1, true}, + TestCase{"cel.bind(x, [1, 2].all(y, y > 0), [1, 2].all(z, z > 0))", 1, + true}, + // Nested empty range comprehensions + TestCase{"[].all(x, [].all(y, true))", 0, true}, + // Deeply nested mixed + TestCase{"[1].all(x, [].all(y, [2].all(z, true)))", 1, false, + "comprehension nesting level of 2 exceeds limit of 1"}, + TestCase{"[1].all(x, [].all(y, [2].all(z, true)))", 2, true})); + +} // namespace +} // namespace cel diff --git a/validator/homogeneous_literal_validator.cc b/validator/homogeneous_literal_validator.cc new file mode 100644 index 000000000..4a490dea2 --- /dev/null +++ b/validator/homogeneous_literal_validator.cc @@ -0,0 +1,190 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/homogeneous_literal_validator.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/navigable_ast.h" +#include "validator/validator.h" + +namespace cel { + +namespace { + +bool InExemptFunction(const NavigableAstNode& node, + const std::vector& exempt_functions) { + const NavigableAstNode* parent = node.parent(); + while (parent != nullptr) { + if (parent->node_kind() == NodeKind::kCall) { + absl::string_view fn_name = parent->expr()->call_expr().function(); + for (const auto& exempt : exempt_functions) { + if (exempt == fn_name) { + return true; + } + } + } + parent = parent->parent(); + } + return false; +} + +bool IsOptional(const TypeSpec& t) { + return t.has_abstract_type() && t.abstract_type().name() == "optional_type"; +} + +const TypeSpec& GetOptionalParameter(const TypeSpec& t) { + return t.abstract_type().parameter_types()[0]; +} + +void TypeMismatch(ValidationContext& context, int64_t id, + const TypeSpec& expected, const TypeSpec& actual) { + context.ReportErrorAt( + id, absl::StrCat("expected type '", FormatTypeSpec(expected), + "' but found '", FormatTypeSpec(actual), "'")); +} + +bool TypeEquiv(const TypeSpec& a, const TypeSpec& b) { + if (a == b) { + return true; + } + + if (a.has_error() || b.has_error()) { + // Don't report mismatch if there's an error (type checking failed for the + // expression). + return true; + } + + if (a.has_wrapper() && b.has_primitive()) { + return a.wrapper() == b.primitive(); + } else if (a.has_primitive() && b.has_wrapper()) { + return a.primitive() == b.wrapper(); + } + + if (a.has_list_type() && b.has_list_type()) { + return TypeEquiv(a.list_type().elem_type(), b.list_type().elem_type()); + } + + if (a.has_map_type() && b.has_map_type()) { + return TypeEquiv(a.map_type().key_type(), b.map_type().key_type()) && + TypeEquiv(a.map_type().value_type(), b.map_type().value_type()); + } + + if (a.has_abstract_type() && b.has_abstract_type() && + a.abstract_type().name() == b.abstract_type().name() && + a.abstract_type().parameter_types().size() == + b.abstract_type().parameter_types().size()) { + for (int i = 0; i < a.abstract_type().parameter_types().size(); ++i) { + if (!TypeEquiv(a.abstract_type().parameter_types()[i], + b.abstract_type().parameter_types()[i])) { + return false; + } + } + return true; + } + + return false; +} + +} // namespace + +Validation HomogeneousLiteralValidator( + std::vector exempt_functions) { + return Validation([exempt_functions = std::move(exempt_functions)]( + ValidationContext& context) -> bool { + bool valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() == NodeKind::kList) { + if (InExemptFunction(node, exempt_functions)) { + continue; + } + const auto& list_expr = node.expr()->list_expr(); + const auto& elements = list_expr.elements(); + const TypeSpec* expected_type = nullptr; + + for (const auto& element : elements) { + int64_t id = element.expr().id(); + const TypeSpec& actual_type = context.ast().GetTypeOrDyn(id); + const TypeSpec* type_to_check = &actual_type; + + if (element.optional() && IsOptional(actual_type)) { + type_to_check = &GetOptionalParameter(actual_type); + } + + if (expected_type == nullptr) { + expected_type = type_to_check; + continue; + } + + if (!(TypeEquiv(*expected_type, *type_to_check))) { + TypeMismatch(context, id, *expected_type, *type_to_check); + valid = false; + break; + } + } + } else if (node.node_kind() == NodeKind::kMap) { + if (InExemptFunction(node, exempt_functions)) { + continue; + } + const auto& map_expr = node.expr()->map_expr(); + const auto& entries = map_expr.entries(); + const TypeSpec* expected_key_type = nullptr; + const TypeSpec* expected_value_type = nullptr; + + for (const auto& entry : entries) { + int64_t key_id = entry.key().id(); + int64_t val_id = entry.value().id(); + const TypeSpec& actual_key_type = context.ast().GetTypeOrDyn(key_id); + const TypeSpec& actual_val_type = context.ast().GetTypeOrDyn(val_id); + const TypeSpec* key_type_to_check = &actual_key_type; + const TypeSpec* val_type_to_check = &actual_val_type; + + if (entry.optional() && IsOptional(actual_val_type)) { + val_type_to_check = &GetOptionalParameter(actual_val_type); + } + + if (expected_key_type == nullptr) { + expected_key_type = key_type_to_check; + expected_value_type = val_type_to_check; + continue; + } + + if (!(TypeEquiv(*expected_key_type, *key_type_to_check))) { + TypeMismatch(context, key_id, *expected_key_type, + *key_type_to_check); + valid = false; + break; + } + if (!(TypeEquiv(*expected_value_type, *val_type_to_check))) { + TypeMismatch(context, val_id, *expected_value_type, + *val_type_to_check); + valid = false; + break; + } + } + } + } + return valid; + }); +} + +} // namespace cel diff --git a/validator/homogeneous_literal_validator.h b/validator/homogeneous_literal_validator.h new file mode 100644 index 000000000..e37648a25 --- /dev/null +++ b/validator/homogeneous_literal_validator.h @@ -0,0 +1,38 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ + +#include +#include + +#include "validator/validator.h" + +namespace cel { + +// Returns a `Validation` that checks that all literals in map or list literals +// are the same type. If the list or map is part of an argument to an exempted +// function, it is not checked. +Validation HomogeneousLiteralValidator( + std::vector exempt_functions); + +inline Validation HomogeneousLiteralValidator() { + // Default to exempting the strings extension "format" function. + return HomogeneousLiteralValidator({"format"}); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ diff --git a/validator/homogeneous_literal_validator_test.cc b/validator/homogeneous_literal_validator_test.cc new file mode 100644 index 000000000..b027fa4b0 --- /dev/null +++ b/validator/homogeneous_literal_validator_test.cc @@ -0,0 +1,145 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/homogeneous_literal_validator.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/validation_result.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/optional.h" +#include "compiler/standard_library.h" +#include "extensions/strings.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); + builder->AddLibrary(OptionalCompilerLibrary()).IgnoreError(); + builder->AddLibrary(extensions::StringsCompilerLibrary()).IgnoreError(); + cel::Type message_type = cel::Type::Message( + builder->GetCheckerBuilder().descriptor_pool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("msg", message_type))); + return builder->Build(); +} + +struct TestCase { + std::string expression; + bool valid; + std::string error_substr = ""; +}; + +using HomogeneousLiteralValidatorTest = testing::TestWithParam; + +TEST_P(HomogeneousLiteralValidatorTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(HomogeneousLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid); + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + HomogeneousLiteralValidatorTest, HomogeneousLiteralValidatorTest, + testing::Values( + // Lists + TestCase{"[1, 2, 3]", true}, TestCase{"['a', 'b', 'c']", true}, + TestCase{"[1, 'a']", false, "expected type 'int' but found 'string'"}, + TestCase{"[1, 2, 'a']", false, + "expected type 'int' but found 'string'"}, + TestCase{"[[1], [2]]", true}, + TestCase{"[[1], ['a']]", false, + "expected type 'list(int)' but found 'list(string)'"}, + + // Dyn casts + TestCase{"[dyn(1), dyn('a')]", true, ""}, + TestCase{"[dyn(1), 2]", false, "expected type 'dyn' but found 'int'"}, + + // Maps + TestCase{"{1: 'a', 2: 'b'}", true}, TestCase{"{'a': 1, 'b': 2}", true}, + TestCase{"{1: 'a', 'b': 2}", false, + "expected type 'int' but found 'string'"}, + TestCase{"{1: 'a', 2: 3}", false, + "expected type 'string' but found 'int'"}, + + // Optionals + TestCase{"[optional.of(1), optional.of(2)]", true}, + TestCase{"[optional.of(1), optional.of('b')]", false, + "expected type 'optional_type(int)' but found " + "'optional_type(string)'"}, + + TestCase{"[?optional.of(1), ?optional.of(2)]", true}, + TestCase{"[?optional.of(1), ?optional.of('a')]", false, + "expected type 'int' but found 'string'"}, + TestCase{"{?1: optional.of('a'), ?2: optional.none()}", true}, + TestCase{"{?1: optional.of('a'), ?2: optional.of(1)}", false, + "expected type 'string' but found 'int'"}, + + // Exempted Functions + TestCase{"'%v %v'.format([1, 'a'])", true}, + + // Mixed Primitives and Wrappers + TestCase{"[1, msg.single_int64_wrapper]", true}, + TestCase{"[msg.single_int64_wrapper, 1]", true}, + TestCase{"['foo', msg.single_string_wrapper]", true}, + TestCase{"[msg.single_string_wrapper, 'foo']", true}, + TestCase{"{1: msg.single_int64_wrapper, 2: 3}", true}, + TestCase{"{1: 2, 2: msg.single_int64_wrapper}", true}, + TestCase{"[[1], [msg.single_int64_wrapper]]", true}, + TestCase{"[optional.of(1), optional.of(msg.single_int64_wrapper)]", + true}, + TestCase{"[1, msg.single_string_wrapper]", false, + "expected type 'int' but found 'wrapper(string)'"}, + TestCase{"[msg.single_int64_wrapper, 'foo']", false, + "expected type 'wrapper(int)' but found 'string'"}, + TestCase{"[msg.single_int64_wrapper, msg.single_string_wrapper]", false, + "expected type 'wrapper(int)' but found 'wrapper(string)'"}, + + // Nested + TestCase{"[1, [2, 'a']]", false, + "expected type 'int' but found 'string'"}, + TestCase{"[[1, 2], [3, 4]]", true, ""}, + TestCase{"[{1: 2}, {'foo': 3}]", false, + "expected type 'map(int, int)' but found 'map(string, int)'"}, + TestCase{"[{1: 2}, {3: 'foo'}]", false, + "expected type 'map(int, int)' but found 'map(int, string)'"}, + TestCase{"[{1: 2}, {3: 4}]", true, ""})); + +} // namespace +} // namespace cel diff --git a/validator/regex_validator.cc b/validator/regex_validator.cc new file mode 100644 index 000000000..df92bfb1e --- /dev/null +++ b/validator/regex_validator.cc @@ -0,0 +1,96 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/regex_validator.h" + +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/navigable_ast.h" +#include "internal/re2_options.h" +#include "validator/validator.h" +#include "re2/re2.h" + +namespace cel { + +namespace { + +bool CheckPattern(ValidationContext& context, const NavigableAstNode& node, + int arg_index) { + ABSL_DCHECK(node.expr()->has_call_expr()); + const auto& call_expr = node.expr()->call_expr(); + + const Expr* pattern_expr = nullptr; + + if (call_expr.has_target()) { + if (arg_index == 0) { + pattern_expr = &call_expr.target(); + } else if (call_expr.args().size() > arg_index - 1) { + pattern_expr = &call_expr.args()[arg_index - 1]; + } + } else if (call_expr.args().size() > arg_index) { + pattern_expr = &call_expr.args()[arg_index]; + } + + if (pattern_expr == nullptr || !pattern_expr->has_const_expr()) { + return true; + } + + const auto& const_expr = pattern_expr->const_expr(); + if (!const_expr.has_string_value()) { + return true; + } + + absl::string_view pattern_string = const_expr.string_value(); + RE2 re(pattern_string, internal::MakeRE2Options()); + if (!re.ok()) { + context.ReportErrorAt( + pattern_expr->id(), + absl::StrCat("invalid regular expression: ", re.error())); + return false; + } + return true; +} + +} // namespace + +Validation RegexPatternValidator( + absl::string_view id, std::vector config) { + return Validation( + [config = std::move(config)](ValidationContext& context) -> bool { + bool result = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() == NodeKind::kCall) { + for (const auto& config : config) { + if (node.expr()->call_expr().function() == config.function_name) { + if (!CheckPattern(context, node, config.pattern_arg_index)) { + result = false; + } + break; + } + } + } + } + return result; + }, + id); +} + +} // namespace cel diff --git a/validator/regex_validator.h b/validator/regex_validator.h new file mode 100644 index 000000000..15ee1755e --- /dev/null +++ b/validator/regex_validator.h @@ -0,0 +1,53 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "common/standard_definitions.h" +#include "validator/validator.h" + +namespace cel { + +// Configuration for the regex pattern validator. +struct RegexPatternValidatorConfig { + // The resolved function name. + std::string function_name; + // the index of the pattern argument (counting the receiver as arg 0 if + // present). + int pattern_arg_index; +}; + +// Returns a `Validation` that checks all calls to the given regex functions +// It validates that the specified argument is a valid regular expression if it +// is a literal string. +Validation RegexPatternValidator( + absl::string_view id, std::vector config); + +// Returns a `Validation` that checks all calls to the CEL `matches` function. +// It validates that if the pattern is a literal string, it is a valid regular +// expression. +inline Validation MatchesValidator() { + return RegexPatternValidator( + "cel.validator.matches", + {{std::string(StandardFunctions::kRegexMatch), 1}}); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ diff --git a/validator/regex_validator_test.cc b/validator/regex_validator_test.cc new file mode 100644 index 000000000..cfab1468d --- /dev/null +++ b/validator/regex_validator_test.cc @@ -0,0 +1,91 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/regex_validator.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("p", StringType()))); + return builder->Build(); +} + +struct TestCase { + std::string expression; + bool valid; + std::string error_substr = ""; +}; + +using MatchesValidatorTest = testing::TestWithParam; + +TEST_P(MatchesValidatorTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(MatchesValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); + + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid) + << "Expression: " << test_case.expression; + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + MatchesValidatorTest, MatchesValidatorTest, + testing::Values( + // Member calls + TestCase{"'hello'.matches('h.*')", true}, + TestCase{"'hello'.matches('h[')", false, "invalid regular expression"}, + TestCase{"'hello'.matches('h(a|b)')", true}, + TestCase{"'hello'.matches('h(a|b')", false, + "invalid regular expression"}, + // Global calls + TestCase{"matches('hello', 'h.*')", true}, + TestCase{"matches('hello', 'h[')", false, "invalid regular expression"}, + // Non-literal patterns (should not report regex errors) + TestCase{"'hello'.matches(p)", true}, + TestCase{"'hello'.matches('h' + 'ello')", true}, + TestCase{"'hello'.matches(dyn(1))", true}, + + // Empty pattern + TestCase{"'hello'.matches('')", true})); + +} // namespace +} // namespace cel diff --git a/validator/timestamp_literal_validator.cc b/validator/timestamp_literal_validator.cc new file mode 100644 index 000000000..8b9b76ebb --- /dev/null +++ b/validator/timestamp_literal_validator.cc @@ -0,0 +1,134 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/timestamp_literal_validator.h" + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "common/navigable_ast.h" +#include "common/standard_definitions.h" +#include "internal/time.h" +#include "tools/navigable_ast.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +bool ValidateTimestamps(ValidationContext& context) { + bool valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() != NodeKind::kCall || + node.expr()->call_expr().function() != StandardFunctions::kTimestamp) { + continue; + } + if (node.children().size() != 1) { + // Checker should have already reported an error. + continue; + } + const NavigableAstNode& child = *node.children()[0]; + if (child.node_kind() != NodeKind::kConstant) { + // Not a literal, so nothing to do. + continue; + } + absl::Time ts; + const Constant& constant = child.expr()->const_expr(); + if (constant.has_string_value()) { + absl::string_view timestamp_str = + child.expr()->const_expr().string_value(); + if (!absl::ParseTime(absl::RFC3339_full, timestamp_str, &ts, nullptr)) { + context.ReportErrorAt(child.expr()->id(), "invalid timestamp literal"); + valid = false; + continue; + } + } else if (constant.has_int_value()) { + ts = absl::FromUnixSeconds(constant.int_value()); + } else { + // Checker should have already reported an error. + continue; + } + + if (absl::Status status = internal::ValidateTimestamp(ts); !status.ok()) { + context.ReportErrorAt( + child.expr()->id(), + absl::StrCat("invalid timestamp literal: ", status.message())); + valid = false; + } + } + + return valid; +} + +bool ValidateDurations(ValidationContext& context) { + bool valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() != NodeKind::kCall || + node.expr()->call_expr().function() != StandardFunctions::kDuration) { + continue; + } + if (node.children().size() != 1) { + // Checker should have already reported an error. + continue; + } + const NavigableAstNode& child = *node.children()[0]; + if (child.node_kind() != NodeKind::kConstant) { + // Not a literal, so nothing to do. + continue; + } + const Constant& constant = child.expr()->const_expr(); + if (!constant.has_string_value()) { + continue; + } + absl::Duration duration; + + absl::string_view duration_str = child.expr()->const_expr().string_value(); + if (!absl::ParseDuration(duration_str, &duration)) { + context.ReportErrorAt(child.expr()->id(), "invalid duration literal"); + valid = false; + continue; + } + + if (absl::Status status = internal::ValidateDuration(duration); + !status.ok()) { + context.ReportErrorAt( + child.expr()->id(), + absl::StrCat("invalid duration literal: ", status.message())); + valid = false; + } + } + + return valid; +} + +} // namespace + +const Validation& TimestampLiteralValidator() { + static const absl::NoDestructor kInstance( + ValidateTimestamps, "cel.validator.timestamp"); + return *kInstance; +} + +// Returns a validator that checks duration literals. +const Validation& DurationLiteralValidator() { + static const absl::NoDestructor kInstance( + ValidateDurations, "cel.validator.duration"); + return *kInstance; +} + +} // namespace cel diff --git a/validator/timestamp_literal_validator.h b/validator/timestamp_literal_validator.h new file mode 100644 index 000000000..6d2a39318 --- /dev/null +++ b/validator/timestamp_literal_validator.h @@ -0,0 +1,29 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_ + +#include "validator/validator.h" +namespace cel { + +// Returns a `Validation` that checks timestamp literals are valid for CEL. +const Validation& TimestampLiteralValidator(); + +// Returns a `Validation` that checks duration literals are valid for CEL. +const Validation& DurationLiteralValidator(); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_ diff --git a/validator/timestamp_literal_validator_test.cc b/validator/timestamp_literal_validator_test.cc new file mode 100644 index 000000000..136f7d645 --- /dev/null +++ b/validator/timestamp_literal_validator_test.cc @@ -0,0 +1,146 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/timestamp_literal_validator.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/validation_result.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + auto builder = + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()).value(); + builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); + return builder->Build(); +} + +class TimestampLiteralValidatorTest : public ::testing::Test { + protected: + TimestampLiteralValidatorTest() { + validator_.AddValidation(TimestampLiteralValidator()); + } + + std::unique_ptr compiler_; + Validator validator_; +}; + +TEST(TimestampLiteralValidatorTest, FormatsIssues) { + Validator validator; + validator.AddValidation(TimestampLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(cel::ValidationResult result, + compiler->Compile("timestamp('invalid')")); + + validator.UpdateValidationResult(result); + + EXPECT_FALSE(result.IsValid()); + EXPECT_EQ(result.FormatError(), + R"(ERROR: :1:11: invalid timestamp literal + | timestamp('invalid') + | ..........^)"); +} + +TEST(TimestampLiteralValidatorTest, AccumulatesIssues) { + Validator validator; + validator.AddValidation(TimestampLiteralValidator()); + validator.AddValidation(DurationLiteralValidator()); + + constexpr absl::string_view kExpression = R"cel( + [ timestamp('invalid'), + timestamp('9999-12-31T23:59:59Z'), + timestamp('10000-01-01T00:00:00Z') + ].all(t, + t - timestamp(0) < duration('10000s') && + t - timestamp(0) > duration("invalid") + ))cel"; + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(cel::ValidationResult result, + compiler->Compile(kExpression)); + + validator.UpdateValidationResult(result); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), + AllOf(HasSubstr("2:17: invalid timestamp literal"), + HasSubstr("4:17: invalid timestamp literal"), + HasSubstr("7:35: invalid duration literal"))); +} + +struct TestCase { + std::string expression; + bool valid; + std::string error_substr = ""; +}; + +using TimestampLiteralValidatorParameterizedTest = + testing::TestWithParam; + +TEST_P(TimestampLiteralValidatorParameterizedTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(TimestampLiteralValidator()); + validator.AddValidation(DurationLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid); + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + TimestampLiteralValidatorParameterizedTest, + TimestampLiteralValidatorParameterizedTest, + ::testing::Values( + TestCase{"timestamp('2023-01-01T00:00:00Z')", true}, + TestCase{"timestamp('9999-12-31T23:59:59Z')", true}, + TestCase{"timestamp('invalid')", false, "invalid timestamp literal"}, + TestCase{"timestamp('10000-01-01T00:00:00Z')", false, + "invalid timestamp literal"}, + TestCase{"timestamp(0)", true}, + TestCase{"timestamp(-62135596801)", false, + "invalid timestamp literal: Timestamp \"0-12-31T23:59:59Z\" " + "below minimum allowed timestamp \"1-01-01T00:00:00Z\""}, + TestCase{"timestamp(253402300800)", false, + "invalid timestamp literal: Timestamp " + "\"10000-01-01T00:00:00Z\" above maximum allowed timestamp " + "\"9999-12-31T23:59:59.999999999Z\""}, + TestCase{"duration('1s')", true}, + TestCase{"duration('invalid')", false, "invalid duration literal"}, + TestCase{"duration('-1000000000000s')", false, + "below minimum allowed duration"}, + TestCase{"duration('1000000000000s')", false, + "above maximum allowed duration"})); + +} // namespace +} // namespace cel diff --git a/validator/validator.cc b/validator/validator.cc new file mode 100644 index 000000000..e000c71e8 --- /dev/null +++ b/validator/validator.cc @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/validator.h" + +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/source.h" + +namespace cel { + +void Validator::AddValidation(Validation validation) { + ABSL_DCHECK(validation); + if (!validation) return; + validations_.push_back(std::move(validation)); +} + +Validator::ValidationOutput Validator::Validate(const Ast& ast) const { + ValidationOutput result; + ValidationContext context(ast); + for (const auto& validation : validations_) { + if (!validation(context)) { + result.valid = false; + } + } + result.issues = context.ReleaseIssues(); + return result; +} + +void Validator::UpdateValidationResult(ValidationResult& in) const { + if (!in.IsValid() || in.GetAst() == nullptr) { + // If the result is already decided invalid, just return it. + return; + } + + auto result = Validate(*in.GetAst()); + if (!result.valid) { + in.ReleaseAst().IgnoreError(); + } + for (auto& issue : result.issues) { + in.AddIssue(std::move(issue)); + } +} + +void ValidationContext::ReportWarningAt(int64_t id, absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kWarning, + ast_.ComputeSourceLocation(id), + std::string(message))); +} + +void ValidationContext::ReportErrorAt(int64_t id, absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kError, + ast_.ComputeSourceLocation(id), + std::string(message))); +} + +void ValidationContext::ReportWarning(absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kWarning, + SourceLocation{}, std::string(message))); +} + +void ValidationContext::ReportError(absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kError, + SourceLocation{}, std::string(message))); +} + +} // namespace cel diff --git a/validator/validator.h b/validator/validator.h new file mode 100644 index 000000000..a278bd44f --- /dev/null +++ b/validator/validator.h @@ -0,0 +1,151 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/navigable_ast.h" +namespace cel { + +// Context for a validation pass. +// +// Assumed to be scoped to a Validator::Validate() call. Instances must not +// outlive the `ast` passed to the constructor. +class ValidationContext { + public: + explicit ValidationContext(const Ast& ast ABSL_ATTRIBUTE_LIFETIME_BOUND) + : ast_(ast) {} + + const Ast& ast() const { return ast_; } + const NavigableAst& navigable_ast() const { + if (!navigable_ast_) { + navigable_ast_ = NavigableAst::Build(ast_.root_expr()); + } + return navigable_ast_; + } + + void ReportWarningAt(int64_t id, absl::string_view message); + void ReportErrorAt(int64_t id, absl::string_view message); + void ReportWarning(absl::string_view message); + void ReportError(absl::string_view message); + + std::vector ReleaseIssues() { + auto out = std::move(issues_); + issues_.clear(); + return out; + } + + private: + const Ast& ast_; + mutable NavigableAst navigable_ast_; + std::vector issues_; +}; + +// A single validation to apply to an AST. +// +// May be empty if default constructed or moved from. +// use operator bool() to check if the validation is empty. +class Validation { + public: + // Tests the AST reports any issues to the context. + // + // Returns false if the AST is invalid. + // + // The same instance is used across Validate() so must be thread safe + // (typically stateless). + using ImplFunction = + absl::AnyInvocable; + + Validation() = default; + explicit Validation(ImplFunction impl); + Validation(ImplFunction impl, absl::string_view id); + + const ImplFunction& impl() const { + ABSL_DCHECK(rep_ != nullptr); + return rep_->impl; + } + + absl::string_view id() const { + ABSL_DCHECK(rep_ != nullptr); + return rep_->id; + } + + bool operator()(ValidationContext& context) const { + ABSL_DCHECK(rep_ != nullptr); + return rep_->impl(context); + } + + explicit operator bool() const { return rep_ != nullptr; } + + private: + struct Rep { + ImplFunction impl; + // Optional id if supported in environment config. + std::string id; + }; + + std::shared_ptr rep_; +}; + +// A validator checks a set of semantic rules for a given AST. +class Validator { + public: + Validator() = default; + + void AddValidation(Validation validation); + absl::Span validations() const { return validations_; } + + struct ValidationOutput { + bool valid = true; + std::vector issues; + }; + + // Validates the given AST by applying all of the validations. + ValidationOutput Validate(const Ast& ast) const; + + // Validates the given AST, updating the validation result in place. + // + // Used to apply validators to the output of the type checker. + void UpdateValidationResult(ValidationResult& in) const; + + private: + std::vector validations_; +}; + +// Implementation details. +inline Validation::Validation(ImplFunction impl) + : rep_(std::make_shared( + Validation::Rep{std::move(impl)})) {} + +inline Validation::Validation(ImplFunction impl, absl::string_view id) + : rep_(std::make_shared( + Validation::Rep{std::move(impl), std::string(id)})) {} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ diff --git a/validator/validator_test.cc b/validator/validator_test.cc new file mode 100644 index 000000000..744475ec1 --- /dev/null +++ b/validator/validator_test.cc @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/validator.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/source.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Property; + +TEST(ValidatorTest, AddValidationAndValidate) { + Validator validator; + validator.AddValidation(Validation([](ValidationContext& context) { + context.ReportError("error 1"); + return false; + })); + validator.AddValidation(Validation([](ValidationContext& context) { + context.ReportWarning("warning 1"); + return true; + })); + + Ast ast; + auto output = validator.Validate(ast); + + EXPECT_FALSE(output.valid); + EXPECT_THAT(output.issues, + ElementsAre(Property(&TypeCheckIssue::message, Eq("error 1")), + Property(&TypeCheckIssue::message, Eq("warning 1")))); + EXPECT_EQ(output.issues[0].severity(), TypeCheckIssue::Severity::kError); + EXPECT_EQ(output.issues[1].severity(), TypeCheckIssue::Severity::kWarning); +} + +TEST(ValidatorTest, ReportAt) { + Validator validator; + validator.AddValidation(Validation([](ValidationContext& context) { + context.ReportErrorAt(1, "error at 1"); + context.ReportWarningAt(2, "warning at 2"); + return false; + })); + + Expr expr; + expr.set_id(1); + SourceInfo source_info; + source_info.mutable_positions()[1] = 10; + source_info.mutable_positions()[2] = 20; + source_info.set_line_offsets({15, 25}); + + Ast ast(std::move(expr), std::move(source_info)); + auto output = validator.Validate(ast); + + EXPECT_FALSE(output.valid); + ASSERT_EQ(output.issues.size(), 2); + + EXPECT_EQ(output.issues[0].location().line, 1); + EXPECT_EQ(output.issues[0].location().column, 10); + + EXPECT_EQ(output.issues[1].location().line, 2); + EXPECT_EQ(output.issues[1].location().column, 5); +} + +} // namespace +} // namespace cel