From 0295c5288cc60ed1e0055bcef8f73e8b9999241d Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 3 Mar 2026 14:38:20 -0800 Subject: [PATCH 001/143] Update error message. PiperOrigin-RevId: 878130486 --- eval/eval/regex_match_step_test.cc | 2 +- internal/re2_options.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/eval/eval/regex_match_step_test.cc b/eval/eval/regex_match_step_test.cc index 8d54a0188..53b955b25 100644 --- a/eval/eval/regex_match_step_test.cc +++ b/eval/eval/regex_match_step_test.cc @@ -94,7 +94,7 @@ TEST(RegexMatchStep, PrecompiledInvalidProgramTooLarge) { ASSERT_OK(RegisterBuiltinFunctions(expr_builder->GetRegistry(), options)); EXPECT_THAT(expr_builder->CreateExpression(&checked_expr), StatusIs(absl::StatusCode::kInvalidArgument, - Eq("regular expressions exceeds max allowed size"))); + Eq("regular expression exceeds max allowed size"))); } } // namespace diff --git a/internal/re2_options.h b/internal/re2_options.h index 9c20ceb63..25a30f6bd 100644 --- a/internal/re2_options.h +++ b/internal/re2_options.h @@ -45,13 +45,13 @@ inline absl::Status CheckRE2(const RE2& re, int max_program_size) { if (max_program_size > 0 && program_size > 0 && program_size > max_program_size) { return absl::InvalidArgumentError( - "regular expressions exceeds max allowed size"); + "regular expression exceeds max allowed size"); } int reverse_program_size = re.ReverseProgramSize(); if (max_program_size > 0 && reverse_program_size > 0 && reverse_program_size > max_program_size) { return absl::InvalidArgumentError( - "regular expressions exceeds max allowed size"); + "regular expression exceeds max allowed size"); } return absl::OkStatus(); } From 0fc3715279c337e77aabebdbb954c895054f3d09 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 6 Mar 2026 22:13:29 -0800 Subject: [PATCH 002/143] No public description PiperOrigin-RevId: 879974382 --- eval/public/value_export_util.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/eval/public/value_export_util.cc b/eval/public/value_export_util.cc index edb6e83e0..bca8a8d65 100644 --- a/eval/public/value_export_util.cc +++ b/eval/public/value_export_util.cc @@ -67,8 +67,8 @@ absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value, break; } case CelValue::Type::kBytes: { - absl::Base64Escape(in_value.BytesOrDie().value(), - out_value->mutable_string_value()); + *out_value->mutable_string_value() = + absl::Base64Escape(in_value.BytesOrDie().value()); break; } case CelValue::Type::kDuration: { From dd1bceff1ec9eda8d2c6a7ebb2651d3a14e513d4 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Tue, 10 Mar 2026 15:06:43 -0700 Subject: [PATCH 003/143] Rewrite IdentStep as Constant when variable value is provided to TypeChecker PiperOrigin-RevId: 881641178 --- eval/compiler/qualified_reference_resolver.cc | 11 ++- .../qualified_reference_resolver_test.cc | 91 +++++++++++++++++++ 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 09950bfe8..67f86ebb6 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -135,8 +135,17 @@ class ReferenceResolver : public cel::AstRewriterBase { expr.mutable_const_expr().set_int64_value( reference->value().int64_value()); return true; + } else if (expr.has_ident_expr()) { + // "google.protobuf.NullValue.NULL_VALUE" is a special case: sometimes + // it is interpreted as null value and sometimes as an enum constant. + if (reference->value().has_null_value() && + expr.ident_expr().name() == + "google.protobuf.NullValue.NULL_VALUE") { + return false; + } + expr.set_const_expr(reference->value()); + return true; } else { - // No update if the constant reference isn't an int (an enum value). return false; } } diff --git a/eval/compiler/qualified_reference_resolver_test.cc b/eval/compiler/qualified_reference_resolver_test.cc index 0d710a465..3fa7fca21 100644 --- a/eval/compiler/qualified_reference_resolver_test.cc +++ b/eval/compiler/qualified_reference_resolver_test.cc @@ -45,6 +45,7 @@ namespace google::api::expr::runtime { namespace { +using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::Ast; @@ -343,6 +344,60 @@ TEST(ResolveReferences, EnumConstReferenceUsedSelect) { })pb")); } +// foo && bar +constexpr char kConstReferenceExpr[] = R"( + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + ident_expr { + name: "foo" + } + } + args { + id: 5 + ident_expr { + name: "bar" + } + } + } +)"; + +TEST(ResolveReferences, ConstReferenceFolded) { + std::unique_ptr expr_ast = ParseTestProto(kConstReferenceExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_THAT(RegisterBuiltinFunctions(&func_registry), IsOk()); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[2].set_name("foo"); + expr_ast->mutable_reference_map()[2].mutable_value().set_bool_value(true); + expr_ast->mutable_reference_map()[5].set_name("bar"); + expr_ast->mutable_reference_map()[5].mutable_value().set_bool_value(false); + IssueCollector issues(RuntimeIssue::Severity::kError); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(true)); + + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + const_expr { bool_value: true } + } + args { + id: 5 + const_expr { bool_value: false } + } + })pb")); +} + TEST(ResolveReferences, ConstReferenceSkipped) { std::unique_ptr expr_ast = ParseTestProto(kExpr); SourceInfo source_info; @@ -388,6 +443,42 @@ TEST(ResolveReferences, ConstReferenceSkipped) { })pb")); } +constexpr char kNullValueReferenceExpr[] = R"( + id: 1 + call_expr { + function: "_+_" + args { + id: 2 + ident_expr { + name: "google.protobuf.NullValue.NULL_VALUE" + } + } + args { + id: 5 + const_expr { int64_value: 1 } + } + } +)"; + +TEST(ResolveReferences, NullValueReferenceSkipped) { + std::unique_ptr expr_ast = ParseTestProto(kNullValueReferenceExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_THAT(RegisterBuiltinFunctions(&func_registry), IsOk()); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[2].set_name( + "google.protobuf.NullValue.NULL_VALUE"); + expr_ast->mutable_reference_map()[2].mutable_value().set_null_value(nullptr); + IssueCollector issues(RuntimeIssue::Severity::kError); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(/*was_rewritten=*/false)); +} + constexpr char kExtensionAndExpr[] = R"( id: 1 call_expr { From 8be02b6f370d601a72e9c8bd64516168fc558da8 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 11 Mar 2026 12:15:14 -0700 Subject: [PATCH 004/143] Update checker to treat gp.NullValue as an int constant. Makes the C++ checker behave more consistent with cel-go and cel-java. Direct references to google.protobuf.NullValue should behave as an enum/int, but google.protobuf.Value{} with null_value alternative active behaves as CEL null. PiperOrigin-RevId: 882134119 --- checker/internal/type_checker_impl.cc | 18 +----------------- checker/internal/type_checker_impl_test.cc | 5 ----- checker/optional_test.cc | 11 ++++++----- checker/standard_library.cc | 6 ++---- checker/type_checker_builder_factory_test.cc | 2 +- common/type.cc | 4 +++- common/type_test.cc | 4 ++-- 7 files changed, 15 insertions(+), 35 deletions(-) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 14dce1647..1e9995b19 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -103,21 +103,6 @@ SourceLocation ComputeSourceLocation(const Ast& ast, int64_t expr_id) { return SourceLocation{line_idx + 1, rel_position}; } -// Special case for protobuf null fields. -bool IsPbNullFieldAssignable(const Type& value, const Type& field) { - if (field.IsNull()) { - return value.IsInt() || value.IsNull(); - } - - if (field.IsOptional() && value.IsOptional() && - field.AsOptional()->GetParameter().IsNull()) { - auto value_param = value.AsOptional()->GetParameter(); - return value_param.IsInt() || value_param.IsNull(); - } - - return false; -} - // Flatten the type to the AST type representation to remove any lifecycle // dependency between the type check environment and the AST. // @@ -421,8 +406,7 @@ class ResolveVisitor : public AstVisitorBase { if (field.optional()) { field_type = OptionalType(arena_, field_type); } - if (!inference_context_->IsAssignable(value_type, field_type) && - !IsPbNullFieldAssignable(value_type, field_type)) { + if (!inference_context_->IsAssignable(value_type, field_type)) { ReportIssue(TypeCheckIssue::CreateError( ComputeSourceLocation(*ast_, field.id()), absl::StrCat( diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index c36051376..6eccc3701 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -1966,11 +1966,6 @@ INSTANTIATE_TEST_SUITE_P( .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, - CheckedExprTestCase{ - .expr = "TestAllTypes{null_value: null}", - .expected_result_type = AstType( - MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), - }, // Legacy nullability behaviors. CheckedExprTestCase{ .expr = "TestAllTypes{single_duration: null}", diff --git a/checker/optional_test.cc b/checker/optional_test.cc index 28ae9a889..87c14f0cd 100644 --- a/checker/optional_test.cc +++ b/checker/optional_test.cc @@ -267,15 +267,16 @@ INSTANTIATE_TEST_SUITE_P( IsOptionalType(TypeSpec(PrimitiveType::kString))}, TestCase{"optional.of('abc').optFlatMap(x, optional.of(x + 'def'))", IsOptionalType(TypeSpec(PrimitiveType::kString))}, - // Legacy nullability behaviors. TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: " "optional.of(0)}", Eq(TypeSpec(MessageTypeSpec( "cel.expr.conformance.proto3.TestAllTypes")))}, - TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: null}", - Eq(TypeSpec(MessageTypeSpec( - "cel.expr.conformance.proto3.TestAllTypes")))}, - TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: " + // Legacy nullability behaviors. + TestCase{ + "cel.expr.conformance.proto3.TestAllTypes{?single_value: null}", + Eq(TypeSpec( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")))}, + TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_value: " "optional.of(null)}", Eq(TypeSpec(MessageTypeSpec( "cel.expr.conformance.proto3.TestAllTypes")))}, diff --git a/checker/standard_library.cc b/checker/standard_library.cc index 4cd9e9831..744a171ef 100644 --- a/checker/standard_library.cc +++ b/checker/standard_library.cc @@ -14,6 +14,7 @@ #include "checker/standard_library.h" +#include #include #include "absl/base/no_destructor.h" @@ -833,11 +834,8 @@ absl::Status AddTypeConstantVariables(TypeCheckerBuilder& builder) { absl::Status AddEnumConstants(TypeCheckerBuilder& builder) { VariableDecl pb_null; pb_null.set_name("google.protobuf.NullValue.NULL_VALUE"); - // TODO(uncreated-issue/74): This is interpreted as an enum (int) or null in - // different cases. We should add some additional spec tests to cover this and - // update the behavior to be consistent. pb_null.set_type(IntType()); - pb_null.set_value(Constant(nullptr)); + pb_null.set_value(Constant(int64_t{0})); CEL_RETURN_IF_ERROR(builder.AddVariable(std::move(pb_null))); return absl::OkStatus(); } diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index d5cf47fee..a15d2e173 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -464,7 +464,7 @@ TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationValue) { // Note: one of fields are all added with safe traversal, so // we lose the union discriminator information. R"cel( - null_value == null && + null_value == 0 && number_value == 0.0 && string_value == '' && list_value == [] && diff --git a/common/type.cc b/common/type.cc index 2b81e39f8..ce8c7a89a 100644 --- a/common/type.cc +++ b/common/type.cc @@ -75,7 +75,9 @@ Type Type::Message(const Descriptor* absl_nonnull descriptor) { Type Type::Enum(const google::protobuf::EnumDescriptor* absl_nonnull descriptor) { if (descriptor->full_name() == "google.protobuf.NullValue") { - return NullType(); + // Special case NullValue to prevent the emebedder providing a different + // descriptor for it and it leaking. + return IntType(); } return EnumType(descriptor); } diff --git a/common/type_test.cc b/common/type_test.cc index 119234fdc..2cebf27ba 100644 --- a/common/type_test.cc +++ b/common/type_test.cc @@ -45,7 +45,7 @@ TEST(Type, Enum) { EXPECT_EQ(Type::Enum( ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( "google.protobuf.NullValue"))), - NullType()); + IntType()); } TEST(Type, Field) { @@ -58,7 +58,7 @@ TEST(Type, Field) { BoolType()); EXPECT_EQ( Type::Field(ABSL_DIE_IF_NULL(descriptor->FindFieldByName("null_value"))), - NullType()); + IntType()); EXPECT_EQ(Type::Field( ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_int32"))), IntType()); From 5cc391e2ea5e3d2d563bf18b33fd411fb1f46d84 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 13 Mar 2026 16:31:23 -0700 Subject: [PATCH 005/143] internal change PiperOrigin-RevId: 883393454 --- runtime/BUILD | 3 --- 1 file changed, 3 deletions(-) diff --git a/runtime/BUILD b/runtime/BUILD index b58880146..776a8223d 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -344,14 +344,11 @@ cc_test( deps = [ ":activation", ":constant_folding", - ":function", - ":register_function_helper", ":runtime_builder", ":runtime_options", ":standard_runtime_builder_factory", "//base:function_adapter", "//common:function_descriptor", - "//common:kind", "//common:value", "//extensions/protobuf:runtime_adapter", "//internal:testing", From f72d56a653f93d474479108958c9d8a64a0d97fb Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 20 Mar 2026 18:17:21 -0700 Subject: [PATCH 006/143] Fix C++-17 compatibility issue for double(string) impl. StartIt, EndIt ctor for std::string_view is a C++ 20 feature. PiperOrigin-RevId: 887065454 --- runtime/standard/type_conversion_functions.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/standard/type_conversion_functions.cc b/runtime/standard/type_conversion_functions.cc index 50b6e28ea..76e95751b 100644 --- a/runtime/standard/type_conversion_functions.cc +++ b/runtime/standard/type_conversion_functions.cc @@ -69,7 +69,7 @@ Value FormatDouble(double v, const Function::InvokeContext& context) { return cel::ErrorValue(absl::InvalidArgumentError(absl::StrCat( "double format error: ", std::make_error_code(result.ec).message()))); } - absl::string_view out(buf, result.ptr); + absl::string_view out(buf, result.ptr - buf); return StringValue::From(out, arena); #endif } From 298ac3d30f42a67996ed36b2f3ce01fc6757faa9 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 23 Mar 2026 13:30:26 -0700 Subject: [PATCH 007/143] Simplify recursive planning limit checks. Instead of trying to wrap a recursive subprogram in a stack machine step, the planner fails if the limit is exceeded. This simplifies program planning and avoids high overhead in deep but unbalanced ASTs. PiperOrigin-RevId: 888254034 --- eval/compiler/BUILD | 1 + eval/compiler/flat_expr_builder.cc | 271 ++++++++---------- eval/public/cel_options.h | 16 +- .../expression_builder_benchmark_test.cc | 95 ++++-- runtime/runtime_options.h | 16 +- .../standard_runtime_builder_factory_test.cc | 92 +++++- 6 files changed, 297 insertions(+), 194 deletions(-) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index e82b0ce13..62b208772 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -33,6 +33,7 @@ cc_library( "//base:data", "//common:expr", "//common:native_type", + "//common:navigable_ast", "//common:value", "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index a0fd427bd..91822092c 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -109,6 +110,13 @@ constexpr absl::string_view kBlock = "cel.@block"; // Forward declare to resolve circular dependency for short_circuiting visitors. class FlatExprVisitor; +// Error code for failed recursive program building. Generally indicates an +// optimization doesn't support recursive programs. +absl::Status FailedRecursivePlanning() { + return absl::InternalError( + "failed to build recursive program. check for unsupported optimizations"); +} + // Helper for bookkeeping variables mapped to indexes. class IndexManager { public: @@ -577,6 +585,12 @@ class FlatExprVisitor : public cel::AstVisitor { } } + void SetMaxRecursionDepth(int max_recursion_depth) { + max_recursion_depth_ = max_recursion_depth; + } + + bool PlanRecursiveProgram() const { return max_recursion_depth_ > 0; } + void PreVisitExpr(const cel::Expr& expr) override { ValidateOrError(!absl::holds_alternative(expr.kind()), "Invalid empty expression"); @@ -947,8 +961,7 @@ class FlatExprVisitor : public cel::AstVisitor { return; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 1) { SetProgressStatusError(absl::InternalError( @@ -1064,21 +1077,13 @@ class FlatExprVisitor : public cel::AstVisitor { } } + // Returns the maximum recursion depth of the current program if it is + // eligible for recursion, or nullopt if it is not. absl::optional RecursionEligible() { - if (program_builder_.current() == nullptr) { + if (!PlanRecursiveProgram() || program_builder_.current() == nullptr) { return absl::nullopt; } - absl::optional depth = - program_builder_.current()->RecursiveDependencyDepth(); - if (!depth.has_value()) { - // one or more of the dependencies isn't eligible. - return depth; - } - if (options_.max_recursion_depth < 0 || - *depth < options_.max_recursion_depth) { - return depth; - } - return absl::nullopt; + return program_builder_.current()->RecursiveDependencyDepth(); } std::vector> @@ -1089,10 +1094,7 @@ class FlatExprVisitor : public cel::AstVisitor { return program_builder_.current()->ExtractRecursiveDependencies(); } - void MaybeMakeTernaryRecursive(const cel::Expr* expr) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeTernaryRecursive(const cel::Expr* expr) { if (expr->call_expr().args().size() != 3) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for builtin ternary")); @@ -1107,26 +1109,16 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if (condition_plan == nullptr || !condition_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, condition_plan->recursive_program().depth); - - if (left_plan == nullptr || !left_plan->IsRecursive()) { + if (condition_plan == nullptr || !condition_plan->IsRecursive() || + left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - if (right_plan == nullptr || !right_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { - return; - } + int max_depth = std::max({0, condition_plan->recursive_program().depth, + left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); SetRecursiveStep( CreateDirectTernaryStep(condition_plan->ExtractRecursiveProgram().step, @@ -1136,10 +1128,7 @@ class FlatExprVisitor : public cel::AstVisitor { max_depth + 1); } - void MaybeMakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { if (expr->call_expr().args().size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for builtin boolean operator &&/||")); @@ -1151,21 +1140,14 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if (left_plan == nullptr || !left_plan->IsRecursive()) { + if (left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - if (right_plan == nullptr || !right_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { - return; - } + int max_depth = std::max({0, left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); if (is_or) { SetRecursiveStep( @@ -1182,11 +1164,7 @@ class FlatExprVisitor : public cel::AstVisitor { } } - void MaybeMakeOptionalShortcircuitRecursive(const cel::Expr* expr, - bool is_or_value) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeOptionalShortcircuit(const cel::Expr* expr, bool is_or_value) { if (!expr->call_expr().has_target() || expr->call_expr().args().size() != 1) { SetProgressStatusError(absl::InvalidArgumentError( @@ -1199,21 +1177,13 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if (left_plan == nullptr || !left_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - - if (right_plan == nullptr || !right_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { + if (left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } + int max_depth = std::max({0, left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); SetRecursiveStep(CreateDirectOptionalOrStep( expr->id(), left_plan->ExtractRecursiveProgram().step, @@ -1225,7 +1195,7 @@ class FlatExprVisitor : public cel::AstVisitor { void MaybeMakeBindRecursive(const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, size_t accu_slot) { - if (options_.max_recursion_depth == 0) { + if (!PlanRecursiveProgram()) { return; } @@ -1233,16 +1203,12 @@ class FlatExprVisitor : public cel::AstVisitor { program_builder_.GetSubexpression(&comprehension->result()); if (result_plan == nullptr || !result_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } int result_depth = result_plan->recursive_program().depth; - if (options_.max_recursion_depth > 0 && - result_depth >= options_.max_recursion_depth) { - return; - } - auto program = result_plan->ExtractRecursiveProgram(); SetRecursiveStep( CreateDirectBindStep(accu_slot, std::move(program.step), expr->id()), @@ -1252,42 +1218,26 @@ class FlatExprVisitor : public cel::AstVisitor { void MaybeMakeComprehensionRecursive( const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, size_t iter_slot, size_t iter2_slot, size_t accu_slot) { - if (options_.max_recursion_depth == 0) { + if (!PlanRecursiveProgram()) { return; } auto* accu_plan = program_builder_.GetSubexpression(&comprehension->accu_init()); - - if (accu_plan == nullptr || !accu_plan->IsRecursive()) { - return; - } - auto* range_plan = program_builder_.GetSubexpression(&comprehension->iter_range()); - - if (range_plan == nullptr || !range_plan->IsRecursive()) { - return; - } - auto* loop_plan = program_builder_.GetSubexpression(&comprehension->loop_step()); - - if (loop_plan == nullptr || !loop_plan->IsRecursive()) { - return; - } - auto* condition_plan = program_builder_.GetSubexpression(&comprehension->loop_condition()); - - if (condition_plan == nullptr || !condition_plan->IsRecursive()) { - return; - } - auto* result_plan = program_builder_.GetSubexpression(&comprehension->result()); - - if (result_plan == nullptr || !result_plan->IsRecursive()) { + if (accu_plan == nullptr || !accu_plan->IsRecursive() || + range_plan == nullptr || !range_plan->IsRecursive() || + loop_plan == nullptr || !loop_plan->IsRecursive() || + condition_plan == nullptr || !condition_plan->IsRecursive() || + result_plan == nullptr || !result_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } @@ -1298,11 +1248,6 @@ class FlatExprVisitor : public cel::AstVisitor { max_depth = std::max(max_depth, condition_plan->recursive_program().depth); max_depth = std::max(max_depth, result_plan->recursive_program().depth); - if (options_.max_recursion_depth > 0 && - max_depth >= options_.max_recursion_depth) { - return; - } - auto step = CreateDirectComprehensionStep( iter_slot, iter2_slot, accu_slot, range_plan->ExtractRecursiveProgram().step, @@ -1566,7 +1511,7 @@ class FlatExprVisitor : public cel::AstVisitor { comprehension_stack_.back(); if (comprehension.is_optimizable_list_append) { if (&(comprehension.comprehension->accu_init()) == &expr) { - if (options_.max_recursion_depth != 0) { + if (PlanRecursiveProgram()) { SetRecursiveStep(CreateDirectMutableListStep(expr.id()), 1); return; } @@ -1579,8 +1524,7 @@ class FlatExprVisitor : public cel::AstVisitor { } } } - absl::optional depth = RecursionEligible(); - if (depth.has_value()) { + if (absl::optional depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != list_expr.elements().size()) { SetProgressStatusError(absl::InternalError( @@ -1614,8 +1558,7 @@ class FlatExprVisitor : public cel::AstVisitor { std::vector fields = std::move(status_or_resolved_fields.value().second); - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != struct_expr.fields().size()) { SetProgressStatusError(absl::InternalError( @@ -1646,7 +1589,7 @@ class FlatExprVisitor : public cel::AstVisitor { comprehension_stack_.back(); if (comprehension.is_optimizable_map_insert) { if (&(comprehension.comprehension->accu_init()) == &expr) { - if (options_.max_recursion_depth != 0) { + if (PlanRecursiveProgram()) { SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1); return; } @@ -1656,8 +1599,7 @@ class FlatExprVisitor : public cel::AstVisitor { } } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 2 * map_expr.entries().size()) { SetProgressStatusError(absl::InternalError( @@ -1696,8 +1638,7 @@ class FlatExprVisitor : public cel::AstVisitor { auto lazy_overloads = resolver_.FindLazyOverloads( function, call_expr->has_target(), num_args, expr->id()); if (!lazy_overloads.empty()) { - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = program_builder_.current()->ExtractRecursiveDependencies(); SetRecursiveStep(CreateDirectLazyFunctionStep( expr->id(), *call_expr, std::move(args), @@ -1727,8 +1668,9 @@ class FlatExprVisitor : public cel::AstVisitor { return; } } - auto recursion_depth = RecursionEligible(); - if (recursion_depth.has_value()) { + + if (auto recursion_depth = RecursionEligible(); + recursion_depth.has_value()) { // Nonnull while active -- nullptr indicates logic error elsewhere in the // builder. ABSL_DCHECK(program_builder_.current() != nullptr); @@ -1777,6 +1719,11 @@ class FlatExprVisitor : public cel::AstVisitor { return; } program_builder_.current()->set_recursive_program(std::move(step), depth); + if (depth > max_recursion_depth_) { + SetProgressStatusError(absl::InvalidArgumentError( + absl::StrCat("Maximum recursion depth of ", + options_.max_recursion_depth, " exceeded"))); + } } void SetProgressStatusError(const absl::Status& status) { @@ -1980,17 +1927,17 @@ class FlatExprVisitor : public cel::AstVisitor { IssueCollector& issue_collector_; ProgramBuilder& program_builder_; - PlannerContext extension_context_; + PlannerContext& extension_context_; IndexManager index_manager_; bool enable_optional_types_; absl::optional block_; + int max_recursion_depth_ = 0; }; FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( const cel::Expr& expr, const cel::CallExpr& call_expr) { ABSL_DCHECK(call_expr.function() == cel::builtin::kIndex); - auto depth = RecursionEligible(); if (!ValidateOrError( (call_expr.args().size() == 2 && !call_expr.has_target()) || // TODO(uncreated-issue/79): A few clients use the index operator with a @@ -2000,7 +1947,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( return CallHandlerResult::kIntercepted; } - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2027,9 +1974,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2046,15 +1991,13 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNotStrictlyFalse( const cel::Expr& expr, const cel::CallExpr& call_expr) { - auto depth = RecursionEligible(); - if (!ValidateOrError(call_expr.args().size() == 1 && !call_expr.has_target(), "unexpected number of args for builtin " "not_strictly_false operator")) { return CallHandlerResult::kIntercepted; } - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { SetProgressStatusError( @@ -2155,9 +2098,8 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleHeterogeneousEquality( "unexpected number of args for builtin equality operator")) { return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2182,8 +2124,7 @@ FlatExprVisitor::HandleHeterogeneousEqualityIn(const cel::Expr& expr, return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2221,6 +2162,9 @@ void BinaryCondVisitor::PreVisit(const cel::Expr* expr) { } void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } if (short_circuiting_ && arg_num == 0 && (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) { // If first branch evaluation result is enough to determine output, @@ -2248,6 +2192,9 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { } void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } if (short_circuiting_ && (cond_ == BinaryCond::kOptionalOr || cond_ == BinaryCond::kOptionalOrValue)) { // If first branch evaluation result is enough to determine output, @@ -2275,6 +2222,28 @@ void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { } void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/false); + break; + case BinaryCond::kOr: + visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/true); + break; + case BinaryCond::kOptionalOr: + visitor_->MakeOptionalShortcircuit(expr, + /*is_or_value=*/false); + break; + case BinaryCond::kOptionalOrValue: + visitor_->MakeOptionalShortcircuit(expr, + /*is_or_value=*/true); + break; + default: + ABSL_UNREACHABLE(); + } + return; + } + switch (cond_) { case BinaryCond::kAnd: visitor_->AddStep(CreateAndStep(expr->id())); @@ -2298,26 +2267,6 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { visitor_->SetProgressStatusError( jump_step_.set_target(visitor_->GetCurrentIndex())); } - // Handle maybe replacing the subprogram with a recursive version. This needs - // to happen after the jump step is updated (though it may get overwritten). - switch (cond_) { - case BinaryCond::kAnd: - visitor_->MaybeMakeShortcircuitRecursive(expr, /*is_or=*/false); - break; - case BinaryCond::kOr: - visitor_->MaybeMakeShortcircuitRecursive(expr, /*is_or=*/true); - break; - case BinaryCond::kOptionalOr: - visitor_->MaybeMakeOptionalShortcircuitRecursive(expr, - /*is_or_value=*/false); - break; - case BinaryCond::kOptionalOrValue: - visitor_->MaybeMakeOptionalShortcircuitRecursive(expr, - /*is_or_value=*/true); - break; - default: - ABSL_UNREACHABLE(); - } } void TernaryCondVisitor::PreVisit(const cel::Expr* expr) { @@ -2327,6 +2276,9 @@ void TernaryCondVisitor::PreVisit(const cel::Expr* expr) { } void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } // Ternary operator "_?_:_" requires a special handing. // In contrary to regular function call, its execution affects the control // flow of the overall CEL expression. @@ -2380,6 +2332,10 @@ void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { } void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + visitor_->MakeTernaryRecursive(expr); + return; + } // Determine and set jump offset in jump instruction. if (visitor_->ValidateOrError( error_jump_.exists(), @@ -2393,7 +2349,6 @@ void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { visitor_->SetProgressStatusError( jump_after_first_.set_target(visitor_->GetCurrentIndex())); } - visitor_->MaybeMakeTernaryRecursive(expr); } void ExhaustiveTernaryCondVisitor::PreVisit(const cel::Expr* expr) { @@ -2403,8 +2358,11 @@ void ExhaustiveTernaryCondVisitor::PreVisit(const cel::Expr* expr) { } void ExhaustiveTernaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + visitor_->MakeTernaryRecursive(expr); + return; + } visitor_->AddStep(CreateTernaryStep(expr->id())); - visitor_->MaybeMakeTernaryRecursive(expr); } void ComprehensionVisitor::PreVisit(const cel::Expr* expr) { @@ -2417,6 +2375,9 @@ void ComprehensionVisitor::PreVisit(const cel::Expr* expr) { absl::Status ComprehensionVisitor::PostVisitArgDefault( cel::ComprehensionArg arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return absl::OkStatus(); + } switch (arg_num) { case cel::ITER_RANGE: { init_step_pos_ = visitor_->GetCurrentIndex(); @@ -2491,6 +2452,9 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault( void ComprehensionVisitor::PostVisitArgTrivial(cel::ComprehensionArg arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } switch (arg_num) { case cel::ITER_RANGE: { break; @@ -2590,6 +2554,13 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( issue_collector, program_builder, extension_context, enable_optional_types_); + if (options_.max_recursion_depth == -1 || options_.max_recursion_depth > 0) { + int depth_limit = options_.max_recursion_depth == -1 + ? std::numeric_limits::max() + : options_.max_recursion_depth; + visitor.SetMaxRecursionDepth(depth_limit); + } + cel::TraversalOptions opts; opts.use_comprehension_callbacks = true; AstTraverse(ast->root_expr(), visitor, opts); diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index 779839583..4d81eb8a7 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -171,17 +171,23 @@ struct InterpreterOptions { // removed in a later update. bool enable_lazy_bind_initialization = true; - // Maximum recursion depth for evaluable programs. + // Enable recursive planning with a maximum recursion depth for evaluable + // programs. // - // This is proportional to the maximum number of recursive Evaluate calls that - // a single expression program might require while evaluating. This is - // coarse -- the actual C++ stack requirements will vary depending on the + // This limit is proportional to the maximum number of recursive Evaluate + // calls that a single expression program might require while evaluating. This + // is coarse -- the actual C++ stack requirements will vary depending on the // expression. // // This does not account for re-entrant evaluation in a client's extension - // function. + // function (i.e. a CEL function that calls Evaluate on another CEL program) + // + // If the limit is exceeded, the planner will return an error instead of + // planning the program. // // -1 means unbounded. + // 0 means disabled (using a heap-based stack machine instead), which is the + // default. int max_recursion_depth = 0; // Enable tracing support for recursively planned programs. diff --git a/eval/tests/expression_builder_benchmark_test.cc b/eval/tests/expression_builder_benchmark_test.cc index c26a7cd5c..410df8902 100644 --- a/eval/tests/expression_builder_benchmark_test.cc +++ b/eval/tests/expression_builder_benchmark_test.cc @@ -1,18 +1,16 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -50,8 +48,24 @@ using google::api::expr::parser::Parse; enum BenchmarkParam : int { kDefault = 0, kFoldConstants = 1, + kRecursivePlanning = 2, + kRecursivePlanningWithConstantFolding = 3, }; +std::string LabelForParam(BenchmarkParam param) { + switch (param) { + case BenchmarkParam::kDefault: + return "default"; + case BenchmarkParam::kFoldConstants: + return "fold_constants"; + case BenchmarkParam::kRecursivePlanning: + return "recursive_planning"; + case BenchmarkParam::kRecursivePlanningWithConstantFolding: + return "recursive_planning_with_constant_folding"; + } + return "unknown"; +} + void BM_RegisterBuiltins(benchmark::State& state) { for (auto _ : state) { auto builder = CreateCelExpressionBuilder(); @@ -64,21 +78,33 @@ BENCHMARK(BM_RegisterBuiltins); InterpreterOptions OptionsForParam(BenchmarkParam param, google::protobuf::Arena& arena) { InterpreterOptions options; - switch (param) { case BenchmarkParam::kFoldConstants: + case BenchmarkParam::kRecursivePlanningWithConstantFolding: options.constant_arena = &arena; options.constant_folding = true; break; case BenchmarkParam::kDefault: + case BenchmarkParam::kRecursivePlanning: options.constant_folding = false; break; } + switch (param) { + case BenchmarkParam::kRecursivePlanning: + case BenchmarkParam::kRecursivePlanningWithConstantFolding: + options.max_recursion_depth = 48; + break; + case BenchmarkParam::kDefault: + case BenchmarkParam::kFoldConstants: + options.max_recursion_depth = 0; + break; + } return options; } void BM_SymbolicPolicy(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && @@ -105,7 +131,9 @@ void BM_SymbolicPolicy(benchmark::State& state) { BENCHMARK(BM_SymbolicPolicy) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); absl::StatusOr> MakeBuilderForEnums( absl::string_view container, absl::string_view enum_type, @@ -209,6 +237,7 @@ BENCHMARK(BM_EnumResolution256Candidate)->ThreadRange(1, 32); void BM_NestedComprehension(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( [4, 5, 6].all(x, [1, 2, 3].all(y, x > y) && [7, 8, 9].all(z, x < z)) @@ -231,10 +260,13 @@ void BM_NestedComprehension(benchmark::State& state) { BENCHMARK(BM_NestedComprehension) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_Comparisons(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( v11 < v12 && v12 < v13 @@ -260,7 +292,9 @@ void BM_Comparisons(benchmark::State& state) { BENCHMARK(BM_Comparisons) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_ComparisonsConcurrent(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( @@ -290,6 +324,8 @@ BENCHMARK(BM_ComparisonsConcurrent)->ThreadRange(1, 32); void RegexPrecompilationBench(bool enabled, benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(absl::StrCat(LabelForParam(param), "_", + enabled ? "enabled" : "disabled")); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( input_str.matches(r'192\.168\.' + '[0-9]{1,3}' + r'\.' + '[0-9]{1,3}') || @@ -325,7 +361,9 @@ void BM_RegexPrecompilationDisabled(benchmark::State& state) { BENCHMARK(BM_RegexPrecompilationDisabled) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_RegexPrecompilationEnabled(benchmark::State& state) { RegexPrecompilationBench(true, state); @@ -333,10 +371,13 @@ void BM_RegexPrecompilationEnabled(benchmark::State& state) { BENCHMARK(BM_RegexPrecompilationEnabled) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_StringConcat(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); auto size = state.range(1); std::string source = "'1234567890' + '1234567890'"; @@ -377,7 +418,17 @@ BENCHMARK(BM_StringConcat) ->Args({BenchmarkParam::kFoldConstants, 4}) ->Args({BenchmarkParam::kFoldConstants, 8}) ->Args({BenchmarkParam::kFoldConstants, 16}) - ->Args({BenchmarkParam::kFoldConstants, 32}); + ->Args({BenchmarkParam::kFoldConstants, 32}) + ->Args({BenchmarkParam::kRecursivePlanning, 2}) + ->Args({BenchmarkParam::kRecursivePlanning, 4}) + ->Args({BenchmarkParam::kRecursivePlanning, 8}) + ->Args({BenchmarkParam::kRecursivePlanning, 16}) + ->Args({BenchmarkParam::kRecursivePlanning, 32}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 2}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 4}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 8}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 16}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 32}); void BM_StringConcat32Concurrent(benchmark::State& state) { std::string source = "'1234567890' + '1234567890'"; diff --git a/runtime/runtime_options.h b/runtime/runtime_options.h index 1e18fef95..7a61208a0 100644 --- a/runtime/runtime_options.h +++ b/runtime/runtime_options.h @@ -139,17 +139,23 @@ struct RuntimeOptions { // removed in a later update. bool enable_lazy_bind_initialization = true; - // Maximum recursion depth for evaluable programs. + // Enable recursive planning with a maximum recursion depth for evaluable + // programs. // - // This is proportional to the maximum number of recursive Evaluate calls that - // a single expression program might require while evaluating. This is - // coarse -- the actual C++ stack requirements will vary depending on the + // This limit is proportional to the maximum number of recursive Evaluate + // calls that a single expression program might require while evaluating. This + // is coarse -- the actual C++ stack requirements will vary depending on the // expression. // // This does not account for re-entrant evaluation in a client's extension - // function. + // function (i.e. a CEL function that calls Evaluate on another CEL program) + // + // If the limit is exceeded, the planner will return an error instead of + // planning the program. // // -1 means unbounded. + // 0 means disabled (using a heap-based stack machine instead), which is the + // default. int max_recursion_depth = 0; // Enable tracing support for recursively planned programs. diff --git a/runtime/standard_runtime_builder_factory_test.cc b/runtime/standard_runtime_builder_factory_test.cc index b73085f3c..029897233 100644 --- a/runtime/standard_runtime_builder_factory_test.cc +++ b/runtime/standard_runtime_builder_factory_test.cc @@ -52,24 +52,14 @@ using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::extensions::ProtobufRuntimeAdapter; using ::cel::test::BoolValueIs; +using ::cel::test::IntValueIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; +using ::testing::HasSubstr; using ::testing::TestWithParam; using ::testing::Truly; -struct EvaluateResultTestCase { - std::string name; - std::string expression; - bool expected_result; - std::function activation_builder; - - template - friend void AbslStringify(S& sink, const EvaluateResultTestCase& tc) { - sink.Append(tc.name); - } -}; - const cel::MacroRegistry& GetMacros() { static absl::NoDestructor macros([]() { MacroRegistry registry; @@ -88,6 +78,84 @@ absl::StatusOr ParseWithTestMacros(absl::string_view expression) { return Parse(**src, GetMacros()); } +TEST(StandardRuntimeTest, RecursionLimitExceeded) { + RuntimeOptions opts; + opts.max_recursion_depth = 1; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros("1 + 2")); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Maximum recursion depth of 1 exceeded"))); +} + +TEST(StandardRuntimeTest, RecursionUnderLimit) { + RuntimeOptions opts; + opts.max_recursion_depth = 2; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros("1 + 2")); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + // Whether the implementation is recursive shouldn't affect observable + // behavior, but it does have performance implications (it will skip + // allocating a value stack). + EXPECT_TRUE(runtime_internal::TestOnly_IsRecursiveImpl(program.get())); + + google::protobuf::Arena arena; + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + EXPECT_THAT(result, IntValueIs(3)); +} + +TEST(StandardRuntimeTest, RecursionLimitTracksLazyExpressions) { + RuntimeOptions opts; + opts.max_recursion_depth = 8; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros(R"cel( + cel.bind(a, 4 + (3 + (2 + 1)), + cel.bind(b, 7 + (6 + (5 + a)), + 9 + (8 + b) + ) + ))cel")); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Maximum recursion depth of 8 exceeded"))); +} + +struct EvaluateResultTestCase { + std::string name; + std::string expression; + bool expected_result; + std::function activation_builder; + + template + friend void AbslStringify(S& sink, const EvaluateResultTestCase& tc) { + sink.Append(tc.name); + } +}; + class StandardRuntimeTest : public TestWithParam { public: const EvaluateResultTestCase& GetTestCase() { return GetParam(); } From 993bc2d65c215070413f564605d7b29b4cfb51ec Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 23 Mar 2026 15:54:42 -0700 Subject: [PATCH 008/143] Limit precision to <1000 in (string).format in the strings extension. PiperOrigin-RevId: 888321829 --- extensions/formatting.cc | 5 ++++ extensions/formatting_test.cc | 49 +++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/extensions/formatting.cc b/extensions/formatting.cc index 970cc6388..6e58a7b86 100644 --- a/extensions/formatting.cc +++ b/extensions/formatting.cc @@ -54,6 +54,7 @@ namespace { static constexpr int32_t kNanosPerMillisecond = 1000000; static constexpr int32_t kNanosPerMicrosecond = 1000; +static constexpr int kMaxPrecision = 1000; absl::StatusOr FormatString( const Value& value, @@ -79,6 +80,10 @@ absl::StatusOr>> ParsePrecision( return absl::InvalidArgumentError( "unable to convert precision specifier to integer"); } + if (precision > kMaxPrecision) { + return absl::InvalidArgumentError( + absl::StrCat("precision specifier exceeds maximum of ", kMaxPrecision)); + } return std::pair{i, precision}; } diff --git a/extensions/formatting_test.cc b/extensions/formatting_test.cc index 433e4ae24..824f14e45 100644 --- a/extensions/formatting_test.cc +++ b/extensions/formatting_test.cc @@ -59,6 +59,49 @@ using ::testing::HasSubstr; using ::testing::TestWithParam; using ::testing::ValuesIn; +using StringFormatLimitsTest = TestWithParam; + +// Check that formatted floating points are reversible. +TEST_P(StringFormatLimitsTest, FormatLimits) { + google::protobuf::Arena arena; + const RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT( + RegisterStringFormattingFunctions(builder.function_registry(), options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse(GetParam(), "", ParserOptions{})); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + Activation activation; + + static_assert(std::numeric_limits::min_exponent == -1021); + for (double x : { + 0x1p-1021, + 0x3p-1021, + std::numeric_limits::epsilon() * 0x1p-3, + std::numeric_limits::epsilon() * 0x7p-3, + 1.1 / 7.0 * 1e-101, + 1.2 / 7.0 * 1e-101, + }) { + activation.InsertOrAssignValue("x", DoubleValue(x)); + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + ASSERT_TRUE(value.Is()); + EXPECT_TRUE(value.GetBool().NativeValue()); + } +} + +INSTANTIATE_TEST_SUITE_P(StringFormatLimitsTest, StringFormatLimitsTest, + ValuesIn({ + "double('%.326f'.format([x])) == x", + "double('%.17e'.format([x])) == x", + })); + struct FormattingTestCase { std::string name; std::string format; @@ -207,6 +250,12 @@ INSTANTIATE_TEST_SUITE_P( .format_args = "'hello'", .error = "unable to find end of precision specifier", }, + { + .name = "InvalidPrecisionOutOfRange", + .format = "%.1001f", + .format_args = "1.2345", + .error = "precision specifier exceeds maximum of 100", + }, { .name = "DecimalFormatingClause", .format = "int %d, uint %d", From 614b79f15c94beca72d98fc9852fbf8b3baa3e39 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 23 Mar 2026 15:59:33 -0700 Subject: [PATCH 009/143] Introduce EnvRuntime and config-driven standard extension functions PiperOrigin-RevId: 888323645 --- extensions/BUILD | 1 + extensions/math_ext.cc | 29 +++++++++++++++++++---------- extensions/math_ext.h | 6 ++++-- extensions/strings.cc | 37 +++++++++++++++++++++++++------------ extensions/strings.h | 5 +++-- 5 files changed, 52 insertions(+), 26 deletions(-) diff --git a/extensions/BUILD b/extensions/BUILD index 1e6e9204a..fe97af46a 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -75,6 +75,7 @@ cc_library( srcs = ["math_ext.cc"], hdrs = ["math_ext.h"], deps = [ + ":math_ext_decls", "//common:casting", "//common:value", "//eval/public:cel_function_registry", diff --git a/extensions/math_ext.cc b/extensions/math_ext.cc index 7b3655de3..4d133d90c 100644 --- a/extensions/math_ext.cc +++ b/extensions/math_ext.cc @@ -308,7 +308,8 @@ Value BitShiftRightUint(uint64_t lhs, int64_t rhs) { } // namespace absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { + const RuntimeOptions& options, + int version) { CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( kMathMin, Identity, registry))); @@ -360,6 +361,9 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, UnaryFunctionAdapter, ListValue>::RegisterGlobalOverload(kMathMax, MaxList, registry))); + if (version == 0) { + return absl::OkStatus(); + } CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( @@ -370,15 +374,6 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "math.round", RoundDouble, registry))); - CEL_RETURN_IF_ERROR( - (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtDouble, registry))); - CEL_RETURN_IF_ERROR( - (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtInt, registry))); - CEL_RETURN_IF_ERROR( - (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtUint, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "math.trunc", TruncDouble, registry))); @@ -453,6 +448,20 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, (BinaryFunctionAdapter::RegisterGlobalOverload( "math.bitShiftRight", BitShiftRightUint, registry))); + if (version == 1) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtInt, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtUint, registry))); + return absl::OkStatus(); } diff --git a/extensions/math_ext.h b/extensions/math_ext.h index 63d9e964b..fe000e476 100644 --- a/extensions/math_ext.h +++ b/extensions/math_ext.h @@ -18,6 +18,7 @@ #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" +#include "extensions/math_ext_decls.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" @@ -25,8 +26,9 @@ namespace cel::extensions { // Register extension functions for supporting mathematical operations above // and beyond the set defined in the CEL standard environment. -absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); +absl::Status RegisterMathExtensionFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + int version = kMathExtensionLatestVersion); absl::Status RegisterMathExtensionFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, diff --git a/extensions/strings.cc b/extensions/strings.cc index 652c72572..ed6f27319 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -306,17 +306,8 @@ absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder, int version) { } // namespace absl::Status RegisterStringsFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter, ListValue>::CreateDescriptor( - "join", /*receiver_style=*/true), - UnaryFunctionAdapter, ListValue>::WrapFunction( - Join1))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter, ListValue, StringValue>:: - CreateDescriptor("join", /*receiver_style=*/true), - BinaryFunctionAdapter, ListValue, - StringValue>::WrapFunction(Join2))); + const RuntimeOptions& options, + int version) { CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, StringValue>:: CreateDescriptor("split", /*receiver_style=*/true), @@ -350,7 +341,6 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, int64_t>::CreateDescriptor("replace", /*receiver_style=*/true), QuaternaryFunctionAdapter, StringValue, StringValue, StringValue, int64_t>::WrapFunction(Replace2))); - CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(registry, options)); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterMemberOverload("charAt", &CharAt, @@ -388,9 +378,32 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterMemberOverload( "trim", &Trim, registry))); + if (version == 0) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(registry, options)); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "strings.quote", &Quote, registry))); + if (version == 1) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, ListValue>::CreateDescriptor( + "join", /*receiver_style=*/true), + UnaryFunctionAdapter, ListValue>::WrapFunction( + Join1))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, ListValue, StringValue>:: + CreateDescriptor("join", /*receiver_style=*/true), + BinaryFunctionAdapter, ListValue, + StringValue>::WrapFunction(Join2))); + if (version == 2) { + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterMemberOverload( "reverse", &Reverse, registry))); diff --git a/extensions/strings.h b/extensions/strings.h index 5dab33c5d..3cbc9f19f 100644 --- a/extensions/strings.h +++ b/extensions/strings.h @@ -28,8 +28,9 @@ namespace cel::extensions { constexpr int kStringsExtensionLatestVersion = 4; // Register extension functions for strings. -absl::Status RegisterStringsFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); +absl::Status RegisterStringsFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + int version = kStringsExtensionLatestVersion); absl::Status RegisterStringsFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, From 1792682e83b401dced56213229cd4f4df68afdbd Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Tue, 24 Mar 2026 10:26:16 -0700 Subject: [PATCH 010/143] Publish cel/cpp/env to GitHub PiperOrigin-RevId: 888737100 --- MODULE.bazel | 5 + env/BUILD | 290 ++++ env/config.cc | 190 +++ env/config.h | 160 +++ env/config_test.cc | 222 ++++ env/env.cc | 318 +++++ env/env.h | 71 + env/env_runtime.cc | 77 ++ env/env_runtime.h | 73 + env/env_runtime_test.cc | 160 +++ env/env_std_extensions.cc | 76 ++ env/env_std_extensions.h | 42 + env/env_std_extensions_test.cc | 116 ++ env/env_test.cc | 713 ++++++++++ env/env_yaml.cc | 1026 ++++++++++++++ env/env_yaml.h | 39 + env/env_yaml_test.cc | 1467 +++++++++++++++++++++ env/internal/BUILD | 87 ++ env/internal/ext_registry.cc | 63 + env/internal/ext_registry.h | 74 ++ env/internal/ext_registry_test.cc | 73 + env/internal/runtime_ext_registry.cc | 64 + env/internal/runtime_ext_registry.h | 84 ++ env/internal/runtime_ext_registry_test.cc | 126 ++ env/runtime_std_extensions.cc | 130 ++ env/runtime_std_extensions.h | 46 + env/runtime_std_extensions_test.cc | 229 ++++ 27 files changed, 6021 insertions(+) create mode 100644 env/BUILD create mode 100644 env/config.cc create mode 100644 env/config.h create mode 100644 env/config_test.cc create mode 100644 env/env.cc create mode 100644 env/env.h create mode 100644 env/env_runtime.cc create mode 100644 env/env_runtime.h create mode 100644 env/env_runtime_test.cc create mode 100644 env/env_std_extensions.cc create mode 100644 env/env_std_extensions.h create mode 100644 env/env_std_extensions_test.cc create mode 100644 env/env_test.cc create mode 100644 env/env_yaml.cc create mode 100644 env/env_yaml.h create mode 100644 env/env_yaml_test.cc create mode 100644 env/internal/BUILD create mode 100644 env/internal/ext_registry.cc create mode 100644 env/internal/ext_registry.h create mode 100644 env/internal/ext_registry_test.cc create mode 100644 env/internal/runtime_ext_registry.cc create mode 100644 env/internal/runtime_ext_registry.h create mode 100644 env/internal/runtime_ext_registry_test.cc create mode 100644 env/runtime_std_extensions.cc create mode 100644 env/runtime_std_extensions.h create mode 100644 env/runtime_std_extensions_test.cc diff --git a/MODULE.bazel b/MODULE.bazel index fbe9b41fc..02404b645 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -100,3 +100,8 @@ http_jar( sha256 = "eae2dfa119a64327444672aff63e9ec35a20180dc5b8090b7a6ab85125df4d76", urls = ["https://www.antlr.org/download/antlr-" + ANTLR4_VERSION + "-complete.jar"], ) + +bazel_dep( + name = "yaml-cpp", + version = "0.9.0", +) diff --git a/env/BUILD b/env/BUILD new file mode 100644 index 000000000..f5ce35557 --- /dev/null +++ b/env/BUILD @@ -0,0 +1,290 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + deps = [ + "//common:constant", + "//internal:status_macros", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "env", + srcs = ["env.cc"], + hdrs = ["env.h"], + deps = [ + ":config", + "//checker:type_checker_builder", + "//common:constant", + "//common:decl", + "//common:type", + "//common:type_kind", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//env/internal:ext_registry", + "//internal:status_macros", + "//parser:macro", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "env_runtime", + srcs = ["env_runtime.cc"], + hdrs = ["env_runtime.h"], + deps = [ + ":config", + "//env/internal:runtime_ext_registry", + "//internal:status_macros", + "//runtime", + "//runtime:runtime_builder", + "//runtime:runtime_builder_factory", + "//runtime:runtime_options", + "//runtime:standard_functions", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "env_std_extensions", + srcs = ["env_std_extensions.cc"], + hdrs = ["env_std_extensions.h"], + deps = [ + ":env", + "//checker:optional", + "//compiler:optional", + "//extensions:bindings_ext", + "//extensions:comprehensions_v2", + "//extensions:encoders", + "//extensions:lists_functions", + "//extensions:math_ext_decls", + "//extensions:proto_ext", + "//extensions:regex_ext", + "//extensions:sets_functions", + "//extensions:strings", + ], +) + +cc_library( + name = "env_yaml", + srcs = ["env_yaml.cc"], + hdrs = ["env_yaml.h"], + copts = [ + "-fexceptions", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + "//common:constant", + "//internal:status_macros", + "//internal:strings", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@yaml-cpp", + ], +) + +cc_library( + name = "runtime_std_extensions", + srcs = ["runtime_std_extensions.cc"], + hdrs = ["runtime_std_extensions.h"], + deps = [ + ":env_runtime", + "//checker:optional", + "//env/internal:runtime_ext_registry", + "//extensions:encoders", + "//extensions:lists_functions", + "//extensions:math_ext", + "//extensions:math_ext_decls", + "//extensions:regex_ext", + "//extensions:sets_functions", + "//extensions:strings", + "//runtime:optional_types", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "config_test", + srcs = ["config_test.cc"], + deps = [ + ":config", + "//common:constant", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) + +cc_test( + name = "env_test", + srcs = ["env_test.cc"], + deps = [ + ":config", + ":env", + "//checker:type_check_issue", + "//checker:type_checker_builder", + "//checker:validation_result", + "//common:ast", + "//common:ast_proto", + "//common:constant", + "//common:decl", + "//common:expr", + "//common:type", + "//common:value", + "//compiler", + "//internal:proto_matchers", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:parser_interface", + "//runtime", + "//runtime:activation", + "//runtime:reference_resolver", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "env_runtime_test", + srcs = ["env_runtime_test.cc"], + deps = [ + ":config", + ":env", + ":env_runtime", + ":env_std_extensions", + ":env_yaml", + ":runtime_std_extensions", + "//checker:validation_result", + "//common:ast", + "//common:source", + "//common:value", + "//compiler", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "env_std_extensions_test", + srcs = ["env_std_extensions_test.cc"], + deps = [ + ":config", + ":env", + ":env_std_extensions", + "//compiler", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "env_yaml_test", + srcs = ["env_yaml_test.cc"], + deps = [ + ":config", + ":env_yaml", + "//common:constant", + "//internal:status_macros", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "runtime_std_extensions_test", + srcs = ["runtime_std_extensions_test.cc"], + deps = [ + ":config", + ":env", + ":env_runtime", + ":env_std_extensions", + ":runtime_std_extensions", + "//checker:optional", + "//checker:validation_result", + "//common:ast", + "//common:value", + "//compiler", + "//extensions:lists_functions", + "//extensions:math_ext_decls", + "//extensions:strings", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/env/config.cc b/env/config.cc new file mode 100644 index 000000000..ccb4de34c --- /dev/null +++ b/env/config.cc @@ -0,0 +1,190 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/config.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/overload.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "internal/status_macros.h" + +namespace cel { + +namespace { + +const char* ConstantKindToTypeName(const ConstantKind& kind) { + return std::visit(absl::Overload{ + [](const std::monostate& arg) { return "dyn"; }, + [](const std::nullptr_t& arg) { return "null"; }, + [](bool arg) { return "bool"; }, + [](int64_t arg) { return "int"; }, + [](uint64_t arg) { return "uint"; }, + [](double arg) { return "double"; }, + [](const BytesConstant& arg) { return "bytes"; }, + [](const StringConstant& arg) { return "string"; }, + [](absl::Duration arg) { return "duration"; }, + [](absl::Time arg) { return "timestamp"; }, + }, + kind); +} +} // namespace + +absl::Status Config::AddExtensionConfig(std::string name, int version) { + for (const ExtensionConfig& extension_config : extension_configs_) { + if (extension_config.name == name) { + if (extension_config.version == version) { + return absl::OkStatus(); + } + return absl::AlreadyExistsError(absl::StrCat( + "Extension '", name, "' version ", extension_config.version, + " is already included. Cannot also include version ", version)); + } + } + extension_configs_.push_back( + ExtensionConfig{.name = std::move(name), .version = version}); + return absl::OkStatus(); +} + +absl::Status Config::SetStandardLibraryConfig( + const Config::StandardLibraryConfig& standard_library_config) { + if (!standard_library_config.included_macros.empty() && + !standard_library_config.excluded_macros.empty()) { + return absl::InvalidArgumentError( + "Cannot set both included and excluded macros."); + } + + if (!standard_library_config.included_functions.empty() && + !standard_library_config.excluded_functions.empty()) { + return absl::InvalidArgumentError( + "Cannot set both included and excluded functions."); + } + + absl::flat_hash_set included_function_names; + for (const auto& function : standard_library_config.included_functions) { + if (function.second.empty()) { + included_function_names.insert(function.first); + } + } + for (const auto& function : standard_library_config.included_functions) { + if (included_function_names.contains(function.first) && + !function.second.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Cannot include function '", function.first, + "' and also its specific overload '", function.second, "'")); + } + } + + absl::flat_hash_set excluded_function_names; + for (const auto& function : standard_library_config.excluded_functions) { + if (function.second.empty()) { + excluded_function_names.insert(function.first); + } + } + for (const auto& function : standard_library_config.excluded_functions) { + if (excluded_function_names.contains(function.first) && + !function.second.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Cannot exclude function '", function.first, + "' and also its specific overload '", function.second, "'")); + } + } + + standard_library_config_ = standard_library_config; + return absl::OkStatus(); +} + +absl::Status Config::AddVariableConfig(const VariableConfig& variable_config) { + for (const VariableConfig& existing_variable_config : variable_configs_) { + if (existing_variable_config.name == variable_config.name) { + return absl::AlreadyExistsError(absl::StrCat( + "Variable '", variable_config.name, "' is already included.")); + } + } + if (variable_config.value.has_value()) { + absl::string_view constant_type_name = + ConstantKindToTypeName(variable_config.value.kind()); + if (constant_type_name != variable_config.type_info.name) { + return absl::InvalidArgumentError( + absl::StrCat("Variable '", variable_config.name, "' has type ", + variable_config.type_info.name, + " but is assigned a constant value of type ", + constant_type_name, ".")); + } + } + variable_configs_.push_back(variable_config); + return absl::OkStatus(); +} + +absl::Status Config::ValidateFunctionConfig( + const FunctionConfig& function_config) { + for (const auto& overload : function_config.overload_configs) { + if (overload.is_member_function && overload.parameters.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Function '", function_config.name, "' overload '", + overload.overload_id, + "' is marked as a member function but has no parameters. Member " + "functions must have at least one parameter (target).")); + } + } + return absl::OkStatus(); +} + +absl::Status Config::AddFunctionConfig(const FunctionConfig& function_config) { + CEL_RETURN_IF_ERROR(ValidateFunctionConfig(function_config)); + function_configs_.push_back(function_config); + return absl::OkStatus(); +} + +std::ostream& operator<<(std::ostream& os, + const Config::StandardLibraryConfig& config) { + os << "StandardLibraryConfig("; + if (!config.included_macros.empty()) { + os << "\n included_macros=" << absl::StrJoin(config.included_macros, ", "); + } + if (!config.excluded_macros.empty()) { + os << "\n excluded_macros=" << absl::StrJoin(config.excluded_macros, ", "); + } + if (!config.included_functions.empty()) { + os << "\n included_functions=" + << absl::StrJoin(config.included_functions, ", ", + [](std::string* out, + const std::pair& p) { + absl::StrAppend(out, p.first, ":", p.second); + }); + } + if (!config.excluded_functions.empty()) { + os << "\n excluded_functions=" + << absl::StrJoin(config.excluded_functions, ", ", + [](std::string* out, + const std::pair& p) { + absl::StrAppend(out, p.first, ":", p.second); + }); + } + os << "\n)"; + return os; +} + +} // namespace cel diff --git a/env/config.h b/env/config.h new file mode 100644 index 000000000..10b23d030 --- /dev/null +++ b/env/config.h @@ -0,0 +1,160 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ +#define THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "common/constant.h" + +namespace cel { + +class Config { + public: + void SetName(std::string name) { name_ = std::move(name); } + std::string GetName() const { return name_; } + + struct ContainerConfig { + std::string name; + // TODO(uncreated-issue/87): add support for aliases and abbreviations. + + bool IsEmpty() const { return name.empty(); } + }; + + void SetContainerConfig(ContainerConfig container_config) { + container_config_ = std::move(container_config); + } + + const ContainerConfig& GetContainerConfig() const { + return container_config_; + } + + struct ExtensionConfig { + static constexpr int kLatest = std::numeric_limits::max(); + + std::string name; + int version = kLatest; + }; + + absl::Status AddExtensionConfig(std::string name, + int version = ExtensionConfig::kLatest); + + const std::vector& GetExtensionConfigs() const { + return extension_configs_; + } + + struct StandardLibraryConfig { + // Exclude the entire standard library. + bool disable = false; + + // Exclude all standard library macros. + bool disable_macros = false; + + // Either included or excluded macros can be set, not both. If neither are + // set, all standard library macros are included. + absl::flat_hash_set included_macros; + absl::flat_hash_set excluded_macros; + + // Sets of pairs of function name and overload id to include or exclude. + // Either included or excluded functions can be set, not both. If neither + // are set, all standard library functions are included. + // If an overload is specified, only that overload is included or excluded. + // If no overload is specified (empty second element of pair), all overloads + // are included or excluded. + absl::flat_hash_set> included_functions; + absl::flat_hash_set> excluded_functions; + + bool IsEmpty() const { + return !disable && !disable_macros && included_macros.empty() && + excluded_macros.empty() && included_functions.empty() && + excluded_functions.empty(); + } + }; + + absl::Status SetStandardLibraryConfig( + const StandardLibraryConfig& standard_library_config); + + const StandardLibraryConfig& GetStandardLibraryConfig() const { + return standard_library_config_; + } + + struct TypeInfo { + std::string name; + std::vector params; + bool is_type_param = false; + }; + + struct VariableConfig { + std::string name; + std::string description; + TypeInfo type_info; + Constant value; + }; + + // Adds a variable config to the environment. The variable name and type + // are used by the CEL type checker to validate expressions. The variable + // value is used as an input value at runtime. + // + // Returns an error if a variable with the same name already exists, or if the + // type of the constant value does not match the specified type. + absl::Status AddVariableConfig(const VariableConfig& variable_config); + + const std::vector& GetVariableConfigs() const { + return variable_configs_; + } + + struct FunctionOverloadConfig { + std::string overload_id; + std::vector examples; + bool is_member_function = false; + std::vector parameters; + TypeInfo return_type; + }; + + struct FunctionConfig { + std::string name; + std::string description; + std::vector overload_configs; + }; + + absl::Status AddFunctionConfig(const FunctionConfig& function_config); + + const std::vector& GetFunctionConfigs() const { + return function_configs_; + } + + private: + std::string name_; + ContainerConfig container_config_; + std::vector extension_configs_; + StandardLibraryConfig standard_library_config_; + std::vector variable_configs_; + std::vector function_configs_; + + absl::Status ValidateFunctionConfig(const FunctionConfig& function_config); +}; + +std::ostream& operator<<(std::ostream& os, + const Config::StandardLibraryConfig& config); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ diff --git a/env/config_test.cc b/env/config_test.cc new file mode 100644 index 000000000..df0d6f875 --- /dev/null +++ b/env/config_test.cc @@ -0,0 +1,222 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/config.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/constant.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::AllOf; +using ::testing::ElementsAre; +using ::testing::Field; +using ::testing::HasSubstr; +using ::testing::UnorderedElementsAre; + +TEST(EnvConfigTest, ExtensionConfigs) { + Config config; + ASSERT_THAT( + config.AddExtensionConfig("math", Config::ExtensionConfig::kLatest), + IsOk()); + ASSERT_THAT(config.AddExtensionConfig("optional", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("strings"), IsOk()); + + EXPECT_THAT(config.GetExtensionConfigs(), + UnorderedElementsAre( + AllOf(Field(&Config::ExtensionConfig::name, "math"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)), + AllOf(Field(&Config::ExtensionConfig::name, "optional"), + Field(&Config::ExtensionConfig::version, 2)), + AllOf(Field(&Config::ExtensionConfig::name, "strings"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)))); +} + +TEST(EnvConfigTest, ExtensionConfigConflict) { + Config config; + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("math", 3), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +struct StandardLibraryConfigTestCase { + Config::StandardLibraryConfig standard_library_config; + std::string expected_error; // Empty if no error is expected. +}; + +class StandardLibraryConfigTest + : public testing::TestWithParam {}; + +TEST_P(StandardLibraryConfigTest, StandardLibraryConfig) { + const StandardLibraryConfigTestCase& param = GetParam(); + + Config config; + absl::Status status = + config.SetStandardLibraryConfig(param.standard_library_config); + if (param.expected_error.empty()) { + EXPECT_THAT(status, IsOk()); + } else { + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + } +} + +INSTANTIATE_TEST_SUITE_P( + StandardLibraryConfigTest, StandardLibraryConfigTest, + ::testing::Values( + StandardLibraryConfigTestCase{ + .standard_library_config = {}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_macros = {"all", "exists"}, + .excluded_macros = {"map", "filter"}, + }, + .expected_error = "Cannot set both included and excluded macros.", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + .excluded_functions = {{"_-_", ""}}, + }, + .expected_error = + "Cannot set both included and excluded functions.", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", ""}, {"_+_", "add_list"}}, + }, + .expected_error = "Cannot include function '_+_' and also its " + "specific overload 'add_list'", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", ""}, {"_+_", "add_list"}}, + }, + .expected_error = "Cannot exclude function '_+_' and also its " + "specific overload 'add_list'", + })); + +TEST(VariableConfigTest, VariableConfig) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = + { + .name = "mytype", + .params = {{.name = "int"}, {.name = "A", .is_type_param = true}}, + }, + }; + ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); + + ASSERT_EQ(config.GetVariableConfigs().size(), 1); + const auto& added_config = config.GetVariableConfigs()[0]; + EXPECT_EQ(added_config.type_info.name, "mytype"); + ASSERT_THAT(added_config.type_info.params.size(), 2); + EXPECT_EQ(added_config.type_info.params[0].name, "int"); + EXPECT_FALSE(added_config.type_info.params[0].is_type_param); + EXPECT_EQ(added_config.type_info.params[1].name, "A"); + EXPECT_TRUE(added_config.type_info.params[1].is_type_param); +} + +TEST(VariableConfigTest, VariableConfigConflict) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = {.name = "int"}, + }; + EXPECT_THAT(config.AddVariableConfig(variable_config), IsOk()); + EXPECT_THAT(config.AddVariableConfig(variable_config), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +TEST(VariableConfigTest, VariableConfigValueTypeMismatch) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = {.name = "int"}, + .value = Constant(StringConstant("hello")), + }; + EXPECT_THAT(config.AddVariableConfig(variable_config), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Variable 'test' has type int but is assigned " + "a constant value of type string."))); +} + +TEST(FunctionConfigTest, FunctionConfig) { + Config config; + Config::FunctionConfig function_config; + function_config.name = "test"; + function_config.description = "Ultimate test"; + function_config.overload_configs.push_back(Config::FunctionOverloadConfig{ + .overload_id = "test_with_pill", + .examples = {"oracle.isTheOne('Neo', RED)"}, + .is_member_function = true, + .parameters = {{.name = "string"}, {.name = "Choice"}}, + .return_type = {.name = "bool"}, + }); + ASSERT_THAT(config.AddFunctionConfig(function_config), IsOk()); + ASSERT_EQ(config.GetFunctionConfigs().size(), 1); + const auto& added_config = config.GetFunctionConfigs()[0]; + EXPECT_EQ(added_config.name, "test"); + EXPECT_EQ(added_config.description, "Ultimate test"); + EXPECT_EQ(added_config.overload_configs.size(), 1); + + const auto& overload_config = added_config.overload_configs[0]; + EXPECT_EQ(overload_config.overload_id, "test_with_pill"); + EXPECT_THAT(overload_config.examples, + ElementsAre("oracle.isTheOne('Neo', RED)")); + EXPECT_TRUE(overload_config.is_member_function); + EXPECT_THAT( + overload_config.parameters, + ElementsAre(AllOf(Field(&Config::TypeInfo::name, "string"), + Field(&Config::TypeInfo::is_type_param, false)), + AllOf(Field(&Config::TypeInfo::name, "Choice"), + Field(&Config::TypeInfo::is_type_param, false)))); + EXPECT_THAT(overload_config.return_type, + Field(&Config::TypeInfo::name, "bool")); +} + +TEST(FunctionConfigTest, FunctionConfigInvalidMember) { + Config config; + Config::FunctionConfig function_config; + function_config.name = "test"; + function_config.overload_configs.push_back(Config::FunctionOverloadConfig{ + .overload_id = "test_member_no_params", + .is_member_function = true, + .parameters = {}, + }); + EXPECT_THAT(config.AddFunctionConfig(function_config), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("is marked as a member function but has no " + "parameters"))); +} + +} // namespace +} // namespace cel diff --git a/env/env.cc b/env/env.cc new file mode 100644 index 000000000..2c2555f14 --- /dev/null +++ b/env/env.cc @@ -0,0 +1,318 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env.h" + +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/type_checker_builder.h" +#include "common/constant.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +bool ShouldIncludeMacro(const Config::StandardLibraryConfig& config, + absl::string_view macro) { + if (config.disable_macros) { + return false; + } + if (config.excluded_macros.contains(macro)) { + return false; + } + if (!config.included_macros.empty() && + !config.included_macros.contains(macro)) { + return false; + } + return true; +} + +bool ShouldIncludeFunction(const Config::StandardLibraryConfig& config, + absl::string_view function, + absl::string_view overload_id) { + if (config.excluded_functions.contains( + std::make_pair(std::string(function), std::string(overload_id))) || + config.excluded_functions.contains( + std::make_pair(std::string(function), ""))) { + return false; + } + if (!config.included_functions.empty() && + !config.included_functions.contains( + std::make_pair(std::string(function), "")) && + !config.included_functions.contains( + std::make_pair(std::string(function), std::string(overload_id)))) { + return false; + } + return true; +} + +absl::StatusOr MakeStdlibSubset( + const Config::StandardLibraryConfig& standard_library_config) { + CompilerLibrarySubset subset; + subset.library_id = "stdlib"; + // Capturing by reference is safe. The returned CompilerLibrarySubset's + // callbacks are only used during CompilerBuilder::Build() to configure + // contributed functions and macros. They are not retained by the constructed + // Compiler instance. The referenced config outlives the Build() call. + subset.should_include_macro = [&standard_library_config](const Macro& macro) { + return ShouldIncludeMacro(standard_library_config, macro.function()); + }; + subset.should_include_overload = [&standard_library_config]( + absl::string_view function, + absl::string_view overload_id) { + return ShouldIncludeFunction(standard_library_config, function, + overload_id); + }; + return subset; +} + +std::optional TypeNameToTypeKind(absl::string_view type_name) { + // Excluded types: + // kUnknown + // kError + // kTypeParam + // kFunction + // kEnum + + static const absl::NoDestructor< + absl::flat_hash_map> + kTypeNameToTypeKind({ + {"null", TypeKind::kNull}, + {"bool", TypeKind::kBool}, + {"int", TypeKind::kInt}, + {"uint", TypeKind::kUint}, + {"double", TypeKind::kDouble}, + {"string", TypeKind::kString}, + {"bytes", TypeKind::kBytes}, + {"timestamp", TypeKind::kTimestamp}, + {TimestampType::kName, TypeKind::kTimestamp}, + {"duration", TypeKind::kDuration}, + {DurationType::kName, TypeKind::kDuration}, + {"list", TypeKind::kList}, + {"map", TypeKind::kMap}, + {"", TypeKind::kDyn}, + {"any", TypeKind::kAny}, + {"dyn", TypeKind::kDyn}, + {BoolWrapperType::kName, TypeKind::kBoolWrapper}, + {IntWrapperType::kName, TypeKind::kIntWrapper}, + {UintWrapperType::kName, TypeKind::kUintWrapper}, + {DoubleWrapperType::kName, TypeKind::kDoubleWrapper}, + {StringWrapperType::kName, TypeKind::kStringWrapper}, + {BytesWrapperType::kName, TypeKind::kBytesWrapper}, + {"type", TypeKind::kType}, + }); + if (auto it = kTypeNameToTypeKind->find(type_name); + it != kTypeNameToTypeKind->end()) { + return it->second; + } + + return std::nullopt; +} + +absl::StatusOr TypeInfoToType( + const Config::TypeInfo& type_info, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool* descriptor_pool) { + if (type_info.is_type_param) { + return TypeParamType(type_info.name); + } + + std::optional type_kind = TypeNameToTypeKind(type_info.name); + if (!type_kind.has_value()) { + if (type_info.params.empty() && descriptor_pool != nullptr) { + const google::protobuf::Descriptor* type = + descriptor_pool->FindMessageTypeByName(type_info.name); + if (type != nullptr) { + return MessageType(type); + } + } + // TODO(uncreated-issue/88): use a TypeIntrospector to validate opaque types + std::vector parameter_types; + for (const Config::TypeInfo& param : type_info.params) { + CEL_ASSIGN_OR_RETURN(Type parameter_type, + TypeInfoToType(param, arena, descriptor_pool)); + parameter_types.push_back(parameter_type); + } + + return OpaqueType(arena, type_info.name, parameter_types); + } + + switch (*type_kind) { + case TypeKind::kNull: + return NullType(); + case TypeKind::kBool: + return BoolType(); + case TypeKind::kInt: + return IntType(); + case TypeKind::kUint: + return UintType(); + case TypeKind::kDouble: + return DoubleType(); + case TypeKind::kString: + return StringType(); + case TypeKind::kBytes: + return BytesType(); + case TypeKind::kDuration: + return DurationType(); + case TypeKind::kTimestamp: + return TimestampType(); + case TypeKind::kList: { + Type element_type; + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN( + element_type, + TypeInfoToType(type_info.params[0], arena, descriptor_pool)); + } else { + element_type = DynType(); + } + return ListType(arena, element_type); + } + case TypeKind::kMap: { + Type key_type = DynType(); + Type value_type = DynType(); + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN(key_type, TypeInfoToType(type_info.params[0], + arena, descriptor_pool)); + } + if (type_info.params.size() > 1) { + CEL_ASSIGN_OR_RETURN( + value_type, + TypeInfoToType(type_info.params[1], arena, descriptor_pool)); + } + return MapType(arena, key_type, value_type); + } + case TypeKind::kDyn: + return DynType(); + case TypeKind::kAny: + return AnyType(); + case TypeKind::kBoolWrapper: + return BoolWrapperType(); + case TypeKind::kIntWrapper: + return IntWrapperType(); + case TypeKind::kUintWrapper: + return UintWrapperType(); + case TypeKind::kDoubleWrapper: + return DoubleWrapperType(); + case TypeKind::kStringWrapper: + return StringWrapperType(); + case TypeKind::kBytesWrapper: + return BytesWrapperType(); + case TypeKind::kType: { + if (type_info.params.empty()) { + return TypeType(arena, DynType()); + } + CEL_ASSIGN_OR_RETURN(Type type, TypeInfoToType(type_info.params[0], arena, + descriptor_pool)); + return TypeType(arena, type); + } + default: + return DynType(); + } +} + +absl::StatusOr FunctionConfigToFunctionDecl( + const Config::FunctionConfig& function_config, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool* descriptor_pool) { + FunctionDecl function_decl; + function_decl.set_name(function_config.name); + for (const Config::FunctionOverloadConfig& overload_config : + function_config.overload_configs) { + OverloadDecl overload_decl; + overload_decl.set_id(overload_config.overload_id); + overload_decl.set_member(overload_config.is_member_function); + for (const Config::TypeInfo& parameter : overload_config.parameters) { + CEL_ASSIGN_OR_RETURN(Type parameter_type, + TypeInfoToType(parameter, arena, descriptor_pool)); + overload_decl.mutable_args().push_back(parameter_type); + } + CEL_ASSIGN_OR_RETURN( + Type return_type, + TypeInfoToType(overload_config.return_type, arena, descriptor_pool)); + overload_decl.set_result(return_type); + CEL_RETURN_IF_ERROR(function_decl.AddOverload(overload_decl)); + } + return function_decl; +} + +} // namespace + +absl::StatusOr> Env::NewCompiler() { + CEL_ASSIGN_OR_RETURN( + std::unique_ptr compiler_builder, + cel::NewCompilerBuilder(descriptor_pool_, compiler_options_)); + cel::TypeCheckerBuilder& checker_builder = + compiler_builder->GetCheckerBuilder(); + + checker_builder.set_container(config_.GetContainerConfig().name); + + if (!config_.GetStandardLibraryConfig().disable) { + CEL_RETURN_IF_ERROR( + compiler_builder->AddLibrary(StandardCompilerLibrary())); + CEL_ASSIGN_OR_RETURN(CompilerLibrarySubset standard_library_subset, + MakeStdlibSubset(config_.GetStandardLibraryConfig())); + CEL_RETURN_IF_ERROR( + compiler_builder->AddLibrarySubset(std::move(standard_library_subset))); + } + for (const Config::ExtensionConfig& extension_config : + config_.GetExtensionConfigs()) { + CEL_ASSIGN_OR_RETURN(CompilerLibrary library, + extension_registry_.GetCompilerLibrary( + extension_config.name, extension_config.version)); + CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(std::move(library))); + } + + google::protobuf::Arena* arena = checker_builder.arena(); + for (const Config::VariableConfig& variable_config : + config_.GetVariableConfigs()) { + VariableDecl variable_decl; + variable_decl.set_name(variable_config.name); + CEL_ASSIGN_OR_RETURN(Type type, + TypeInfoToType(variable_config.type_info, arena, + descriptor_pool_.get())); + variable_decl.set_type(type); + if (variable_config.value.has_value()) { + variable_decl.set_value(variable_config.value); + } + CEL_RETURN_IF_ERROR(checker_builder.AddVariable(variable_decl)); + } + + for (const Config::FunctionConfig& function_config : + config_.GetFunctionConfigs()) { + CEL_ASSIGN_OR_RETURN(FunctionDecl function_decl, + FunctionConfigToFunctionDecl(function_config, arena, + descriptor_pool_.get())); + CEL_RETURN_IF_ERROR(checker_builder.AddFunction(function_decl)); + } + + return compiler_builder->Build(); +} + +} // namespace cel diff --git a/env/env.h b/env/env.h new file mode 100644 index 000000000..f46e5947c --- /dev/null +++ b/env/env.h @@ -0,0 +1,71 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_H_ +#define THIRD_PARTY_CEL_CPP_ENV_ENV_H_ + +#include + +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/internal/ext_registry.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Env class establishes the environment for compiling CEL expressions. +// +// It is used to configure compiler options, extension functions, and other +// customizable CEL features. +class Env { + public: + // Registers a `CompilerLibrary` with the environment. Note that the library + // does not automatically get added to a `Compiler`. `NewCompiler` relies + // on `Config` to determine which libraries to load. + void RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) { + extension_registry_.RegisterCompilerLibrary(name, alias, version, + std::move(library_factory)); + } + + void SetDescriptorPool( + std::shared_ptr descriptor_pool) { + descriptor_pool_ = std::move(descriptor_pool); + } + + const google::protobuf::DescriptorPool* GetDescriptorPool() const { + return descriptor_pool_.get(); + } + + void SetConfig(const Config& config) { config_ = config; } + + absl::StatusOr> NewCompiler(); + + private: + cel::env_internal::ExtensionRegistry extension_registry_; + std::shared_ptr descriptor_pool_; + CompilerOptions compiler_options_; + Config config_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_ENV_H_ diff --git a/env/env_runtime.cc b/env/env_runtime.cc new file mode 100644 index 000000000..09bbcde04 --- /dev/null +++ b/env/env_runtime.cc @@ -0,0 +1,77 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_runtime.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_builder_factory.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_functions.h" + +namespace cel { + +absl::StatusOr EnvRuntime::CreateRuntimeBuilder() { + const std::vector& extension_configs = + config_.GetExtensionConfigs(); + const Config::ExtensionConfig* optional_extension_config = nullptr; + for (const Config::ExtensionConfig& extension_config : extension_configs) { + if (extension_config.name == "optional") { + optional_extension_config = &extension_config; + runtime_options_.enable_qualified_type_identifiers = true; + break; + } + } + + CEL_ASSIGN_OR_RETURN( + RuntimeBuilder runtime_builder, + cel::CreateRuntimeBuilder(descriptor_pool_, runtime_options_)); + + if (!config_.GetStandardLibraryConfig().disable) { + CEL_RETURN_IF_ERROR(RegisterStandardFunctions( + runtime_builder.function_registry(), runtime_options_)); + } + + // Register optional extension functions first, because other extensions + // depend on it (e.g. regex). + if (optional_extension_config != nullptr) { + CEL_RETURN_IF_ERROR(extension_registry_.RegisterExtensionFunctions( + runtime_builder, runtime_options_, optional_extension_config->name, + optional_extension_config->version)); + } + + for (const Config::ExtensionConfig& extension_config : extension_configs) { + if (&extension_config == optional_extension_config) { + continue; + } + CEL_RETURN_IF_ERROR(extension_registry_.RegisterExtensionFunctions( + runtime_builder, runtime_options_, extension_config.name, + extension_config.version)); + } + return runtime_builder; +} + +absl::StatusOr> EnvRuntime::NewRuntime() { + CEL_ASSIGN_OR_RETURN(RuntimeBuilder runtime_builder, CreateRuntimeBuilder()); + return std::move(runtime_builder).Build(); +} + +} // namespace cel diff --git a/env/env_runtime.h b/env/env_runtime.h new file mode 100644 index 000000000..ff62ec1d4 --- /dev/null +++ b/env/env_runtime.h @@ -0,0 +1,73 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ +#define THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "env/config.h" +#include "env/internal/runtime_ext_registry.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// EnvRuntime class establishes the environment for creating CEL runtimes. +// +// It is used to configure runtime options, extension functions, and other +// customizable CEL runtime features. +// +// EnvRuntime is separate from Env to avoid a dependency on the compiler for +// binaries that only use the runtime. +// +// Even though EnvRuntime is separate from Env, the Config and DescriptorPool +// passed to EnvRuntime are expected to be the same as those passed to Env for +// compilation. This ensures consistency between compilation and runtime. +class EnvRuntime { + public: + void SetDescriptorPool( + std::shared_ptr descriptor_pool) { + descriptor_pool_ = std::move(descriptor_pool); + } + + void SetConfig(const Config& config) { config_ = config; } + + RuntimeOptions& mutable_runtime_options() { return runtime_options_; } + + absl::StatusOr CreateRuntimeBuilder(); + + // Shortcut for CreateRuntimeBuilder() followed by Build(). + absl::StatusOr> NewRuntime(); + + private: + cel::env_internal::RuntimeExtensionRegistry& GetRuntimeExtensionRegistry() { + return extension_registry_; + } + + friend void RegisterStandardExtensions(EnvRuntime& env_runtime); + + cel::env_internal::RuntimeExtensionRegistry extension_registry_; + std::shared_ptr descriptor_pool_; + Config config_; + RuntimeOptions runtime_options_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ diff --git a/env/env_runtime_test.cc b/env/env_runtime_test.cc new file mode 100644 index 000000000..1c4205224 --- /dev/null +++ b/env/env_runtime_test.cc @@ -0,0 +1,160 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_runtime.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/source.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_std_extensions.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::IsEmpty; +using ::testing::ValuesIn; + +struct TestCase { + std::string config_yaml; + std::string expr; + bool expected_to_fail = false; +}; + +class EnvRuntimeTest : public testing::TestWithParam {}; + +TEST_P(EnvRuntimeTest, EndToEnd) { + const TestCase& param = GetParam(); + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(param.config_yaml)); + + Env env; + env.SetDescriptorPool(descriptor_pool); + RegisterStandardExtensions(env); + env.SetConfig(config); + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + std::unique_ptr ast; + if (!param.expected_to_fail) { + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile(param.expr)); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(ast, result.ReleaseAst()); + } else { + // Bypass type checking to allow compilation to succeed since we expect the + // runtime to fail. + ASSERT_OK_AND_ASSIGN(std::unique_ptr source, + NewSource(param.expr, "")); + ASSERT_OK_AND_ASSIGN(ast, compiler->GetParser().Parse(*source)); + } + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + + absl::StatusOr> program_or = + runtime->CreateProgram(std::move(ast)); + if (param.expected_to_fail) { + EXPECT_THAT(program_or, StatusIs(absl::StatusCode::kInvalidArgument)) + << " expr: " << param.expr; + return; + } + + ASSERT_THAT(program_or, IsOk()) << " expr: " << param.expr; + + std::unique_ptr program = *std::move(program_or); + ASSERT_NE(program, nullptr); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()) << " expr: " << param.expr; +} + +std::vector GetEnvRuntimeTestCases() { + return { + TestCase{ + .config_yaml = R"yaml( + extensions: + - name: "encoders" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8='", + }, + TestCase{ + .config_yaml = R"yaml( + extensions: + - name: "encoders" + - name: "optional" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8=' && " + "optional.of(1).hasValue()", + }, + TestCase{ + .config_yaml = R"yaml( + extensions: + - name: "encoders" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8=' && " + "optional.of(1).hasValue()", + .expected_to_fail = true, + }, + TestCase{ + .config_yaml = R"yaml( + stdlib: + disable: true + )yaml", + .expr = "1 + 2 == 3", + .expected_to_fail = true, + }, + TestCase{ + .config_yaml = R"yaml( + stdlib: + disable: true + extensions: + - name: "encoders" + )yaml", + .expr = "base64.encode(b'hello') == 'aGVsbG8=' && " + "1 + 2 == 3", + .expected_to_fail = true, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvRuntimeTest, EnvRuntimeTest, + ValuesIn(GetEnvRuntimeTestCases())); + +} // namespace +} // namespace cel diff --git a/env/env_std_extensions.cc b/env/env_std_extensions.cc new file mode 100644 index 000000000..f2041b979 --- /dev/null +++ b/env/env_std_extensions.cc @@ -0,0 +1,76 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_std_extensions.h" + +#include "checker/optional.h" +#include "compiler/optional.h" +#include "env/env.h" +#include "extensions/bindings_ext.h" +#include "extensions/comprehensions_v2.h" +#include "extensions/encoders.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext_decls.h" +#include "extensions/proto_ext.h" +#include "extensions/regex_ext.h" +#include "extensions/sets_functions.h" +#include "extensions/strings.h" + +namespace cel { + +void RegisterStandardExtensions(Env& env) { + env.RegisterCompilerLibrary("cel.lib.ext.bindings", "bindings", 0, []() { + return extensions::BindingsCompilerLibrary(); + }); + env.RegisterCompilerLibrary("cel.lib.ext.encoders", "encoders", 0, []() { + return extensions::EncodersCompilerLibrary(); + }); + for (int version = 0; version <= extensions::kListsExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.lists", "lists", version, + [version]() { return extensions::ListsCompilerLibrary(version); }); + } + for (int version = 0; version <= extensions::kMathExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.math", "math", version, + [version]() { return extensions::MathCompilerLibrary(version); }); + } + for (int version = 0; version <= kOptionalExtensionLatestVersion; ++version) { + env.RegisterCompilerLibrary("optional", "", version, [version]() { + return OptionalCompilerLibrary(version); + }); + } + env.RegisterCompilerLibrary("cel.lib.ext.protos", "protos", 0, []() { + return extensions::ProtoExtCompilerLibrary(); + }); + env.RegisterCompilerLibrary("cel.lib.ext.sets", "sets", 0, []() { + return extensions::SetsCompilerLibrary(); + }); + for (int version = 0; version <= extensions::kStringsExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.strings", "strings", version, + [version]() { return extensions::StringsCompilerLibrary(version); }); + } + env.RegisterCompilerLibrary( + "cel.lib.ext.comprev2", "two-var-comprehensions", 0, + []() { return extensions::ComprehensionsV2CompilerLibrary(); }); + env.RegisterCompilerLibrary("cel.lib.ext.regex", "regex", 0, []() { + return extensions::RegexExtCompilerLibrary(); + }); +} + +} // namespace cel diff --git a/env/env_std_extensions.h b/env/env_std_extensions.h new file mode 100644 index 000000000..79cf37dbf --- /dev/null +++ b/env/env_std_extensions.h @@ -0,0 +1,42 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ + +#include "env/env.h" + +namespace cel { + +// Registers the standard CEL extensions with the given environment. This makes +// them available, but does not enable them. See Env::Config for how to enable +// extensions. +// +// Extensions are registered under the following names: +// +// - cel.lib.ext.bindings (alias: "bindings") +// - cel.lib.ext.encoders (alias: "encoders") +// - cel.lib.ext.lists (alias: "lists") +// - cel.lib.ext.math (alias: "math") +// - optional +// - cel.lib.ext.protos (alias: "protos") +// - cel.lib.ext.sets (alias: "sets") +// - cel.lib.ext.strings (alias: "strings") +// - cel.lib.ext.comprev2 (alias: "two-var-comprehensions") +// - cel.lib.ext.regex (alias: "regex") +void RegisterStandardExtensions(Env& env); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ diff --git a/env/env_std_extensions_test.cc b/env/env_std_extensions_test.cc new file mode 100644 index 000000000..7d9572cc0 --- /dev/null +++ b/env/env_std_extensions_test.cc @@ -0,0 +1,116 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_std_extensions.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::testing::TestWithParam; + +struct TestCase { + std::string extension; + std::string expr; +}; + +class EnvStdExtensions : public testing::TestWithParam {}; + +TEST_P(EnvStdExtensions, RegistrationTest) { + const TestCase& param = GetParam(); + + Env env; + RegisterStandardExtensions(env); + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + + Config config; + ASSERT_THAT(config.AddExtensionConfig(param.extension), IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(param.expr)); + ASSERT_TRUE(result.IsValid()) << "Expected no issues for expr: " << param.expr + << " but got: " << result.FormatError(); +} + +INSTANTIATE_TEST_SUITE_P( + RegistrationTest, EnvStdExtensions, + ::testing::Values( + TestCase{ + .extension = "cel.lib.ext.bindings", // official name + .expr = "cel.bind(t, true, t)", + }, + TestCase{ + .extension = "bindings", // alias + .expr = "cel.bind(t, true, t)", + }, + TestCase{ + .extension = "encoders", + .expr = "base64.encode(b'hello')", + }, + TestCase{ + .extension = "lists", + .expr = "[1, 2, 3].sort()", + }, + TestCase{ + .extension = "lists", + .expr = "['a'].sortBy(e, e)", + }, + TestCase{ + .extension = "math", + .expr = "math.sqrt(-1)", + }, + TestCase{ + .extension = "optional", + .expr = "[1, 2].first()", + }, + TestCase{ + .extension = "optional", + .expr = "[0][?1]", // optional syntax auto-enabled + }, + TestCase{ + .extension = "protos", + .expr = "!proto.hasExt(cel.expr.conformance.proto2.TestAllTypes{}, " + "cel.expr.conformance.proto2.nested_ext)", + }, + TestCase{ + .extension = "sets", + .expr = "sets.contains([1], [1])", + }, + TestCase{ + .extension = "strings", + .expr = "'foo'.reverse()", + }, + TestCase{ + .extension = "two-var-comprehensions", + .expr = "[1, 2, 3, 4].all(i, v, i < v)", + }, + TestCase{ + .extension = "regex", + .expr = "regex.replace('abc', '$', '_end')", + })); + +} // namespace +} // namespace cel diff --git a/env/env_test.cc b/env/env_test.cc new file mode 100644 index 000000000..dcd2d97fa --- /dev/null +++ b/env/env_test.cc @@ -0,0 +1,713 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env.h" + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/constant.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/type.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "internal/proto_matchers.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/parser_interface.h" +#include "runtime/activation.h" +#include "runtime/reference_resolver.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::cel::internal::test::EqualsProto; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::NotNull; +using ::testing::Property; +using ::testing::UnorderedElementsAre; +using ::testing::Values; +using ::testing::ValuesIn; + +Expr TestMacroExpander(MacroExprFactory& factory, absl::Span args) { + return factory.NewStringConst("Hello"); +} + +class TestLibrary : public CompilerLibrary { + public: + explicit TestLibrary(int version) + : CompilerLibrary( + "testlib", + [version](ParserBuilder& builder) { + absl::Status status; + CEL_ASSIGN_OR_RETURN( + auto macro1, + cel::Macro::Global("testMacro1", 0, TestMacroExpander)); + status.Update(builder.AddMacro(macro1)); + if (version == 2) { + CEL_ASSIGN_OR_RETURN( + auto macro2, + cel::Macro::Global("testMacro2", 0, TestMacroExpander)); + status.Update(builder.AddMacro(macro2)); + } + return status; + }, + [version](TypeCheckerBuilder& builder) { + absl::Status status; + CEL_ASSIGN_OR_RETURN( + auto func1, cel::MakeFunctionDecl( + "testFunc1", MakeOverloadDecl(StringType()))); + status.Update(builder.AddFunction(func1)); + if (version == 2) { + CEL_ASSIGN_OR_RETURN( + auto func2, + cel::MakeFunctionDecl("testFunc2", + MakeOverloadDecl(StringType()))); + status.Update(builder.AddFunction(func2)); + } + return status; + }) {}; +}; + +absl::StatusOr CompileAndEvalExpr( + Env& env, absl::string_view expr, + const Activation& activation = Activation()) { + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, env.NewCompiler()); + if (compiler == nullptr) { + return absl::InternalError("Failed to create compiler"); + } + CEL_ASSIGN_OR_RETURN(ValidationResult result, compiler->Compile(expr)); + if (!result.GetIssues().empty()) { + return absl::InvalidArgumentError(result.FormatError()); + } + + cel::RuntimeOptions opts; + CEL_ASSIGN_OR_RETURN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder(env.GetDescriptorPool(), opts)); + CEL_RETURN_IF_ERROR(cel::EnableReferenceResolver( + rt_builder, cel::ReferenceResolverEnabled::kAlways)); + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(rt_builder).Build()); + if (runtime == nullptr) { + return absl::InternalError("Failed to create runtime"); + } + + CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, result.ReleaseAst()); + if (ast == nullptr) { + return absl::InternalError("Failed to create AST"); + } + google::protobuf::Arena arena; + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + if (program == nullptr) { + return absl::InternalError("Failed to create program"); + } + CEL_ASSIGN_OR_RETURN(Value value, program->Evaluate(&arena, activation)); + return value; +} + +absl::StatusOr CompileAndEvalBooleanExpr( + Env& env, absl::string_view expr, + const Activation& activation = Activation()) { + CEL_ASSIGN_OR_RETURN(auto value, CompileAndEvalExpr(env, expr, activation)); + return value.GetBool(); +} + +class LibraryConfigTest : public testing::Test { + protected: + void SetUp() override { + env_.RegisterCompilerLibrary("testlib", "ml", 1, + []() { return TestLibrary(1); }); + env_.RegisterCompilerLibrary("testlib", "ml", 2, + []() { return TestLibrary(2); }); + env_.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + } + + Env env_; +}; + +TEST_F(LibraryConfigTest, DefaultVersion) { + Config config; + ASSERT_THAT(config.AddExtensionConfig("testlib"), IsOk()); + + env_.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env_.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile("testMacro1()")); + ASSERT_OK_AND_ASSIGN(auto result2, compiler->Compile("testFunc1()")); + ASSERT_OK_AND_ASSIGN(auto result3, compiler->Compile("testMacro2()")); + ASSERT_OK_AND_ASSIGN(auto result4, compiler->Compile("testFunc2()")); + + EXPECT_THAT(result1.GetIssues(), IsEmpty()); + EXPECT_THAT(result2.GetIssues(), IsEmpty()); + EXPECT_THAT(result3.GetIssues(), IsEmpty()); + EXPECT_THAT(result4.GetIssues(), IsEmpty()); +} + +TEST_F(LibraryConfigTest, SpecificVersion) { + Config config; + ASSERT_THAT(config.AddExtensionConfig("testlib", 1), IsOk()); + + env_.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env_.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile("testMacro1()")); + ASSERT_OK_AND_ASSIGN(auto result2, compiler->Compile("testFunc1()")); + ASSERT_OK_AND_ASSIGN(auto result3, compiler->Compile("testMacro2()")); + ASSERT_OK_AND_ASSIGN(auto result4, compiler->Compile("testFunc2()")); + + EXPECT_THAT(result1.GetIssues(), IsEmpty()); + EXPECT_THAT(result2.GetIssues(), IsEmpty()); + EXPECT_THAT(result3.GetIssues(), + UnorderedElementsAre( + Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference to 'testMacro2'")))); + EXPECT_THAT(result4.GetIssues(), + UnorderedElementsAre( + Property(&TypeCheckIssue::message, + HasSubstr("undeclared reference to 'testFunc2'")))); +} + +struct StandardLibraryConfigTestCase { + Config::StandardLibraryConfig standard_library_config; + std::vector expected_valid_expressions; + std::vector expected_invalid_expressions; +}; + +class StandardLibraryConfigTest + : public testing::TestWithParam {}; + +TEST_P(StandardLibraryConfigTest, StandardLibraryConfig) { + const StandardLibraryConfigTestCase& param = GetParam(); + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + + Config config; + ASSERT_THAT(config.SetStandardLibraryConfig(param.standard_library_config), + IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + + for (const std::string& expr : param.expected_valid_expressions) { + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile(expr)); + EXPECT_THAT(result1.GetIssues(), IsEmpty()) + << "With config: " << param.standard_library_config + << ", expected no issues for expr: " << expr + << " but got: " << result1.FormatError(); + } + for (const std::string& expr : param.expected_invalid_expressions) { + ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile(expr)); + EXPECT_THAT(result1.GetIssues(), Not(IsEmpty())) + << "With config: " << param.standard_library_config + << ", expected compilation error for expr: " << expr << " but got: \'" + << result1.FormatError() << "\'"; + } +} + +INSTANTIATE_TEST_SUITE_P( + StandardLibraryConfigTest, StandardLibraryConfigTest, + Values( + StandardLibraryConfigTestCase{ + .standard_library_config = {}, + .expected_valid_expressions = {"1 + 2", + "[1, 2, 3].exists(x, x == 1)", + "[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.disable = true}, + .expected_invalid_expressions = {"1 + 2", + "[1, 2, 3].exists(x, x == 1)", + "[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.disable_macros = true}, + .expected_valid_expressions = {"1 + 2"}, + .expected_invalid_expressions = {"[1, 2, 3].exists(x, x == 1)", + "[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.excluded_macros = {"map", "all"}}, + .expected_valid_expressions = {"[1, 2, 3].exists(x, x == 1)"}, + .expected_invalid_expressions = {"[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.included_macros = {"map", "all"}}, + .expected_valid_expressions = {"[1, 2, 3].all(x, x == 1)", + "[1, 2, 3].map(x, x)"}, + .expected_invalid_expressions = {"[1, 2, 3].exists(x, x == 1)"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.excluded_functions = {{"_+_", ""}}}, + .expected_invalid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.excluded_functions = {{"_+_", "add_bytes"}, + {"_+_", "add_list"}, + {"_+_", "add_string"}}}, + .expected_valid_expressions = {"1 + 2"}, + .expected_invalid_expressions = {"[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = {.included_functions = {{"_+_", ""}}}, + .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}}, + .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]"}, + .expected_invalid_expressions = {"'hello' + 'world'"}, + })); + +TEST(ContainerConfigTest, ContainerConfig) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig({.name = "cel.expr.conformance.proto2"}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("TestAllTypes{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + +struct TypeInfoTestCase { + Config::TypeInfo type_info; + std::string expected_type_pb; +}; + +using TypeInfoTest = testing::TestWithParam; + +TEST_P(TypeInfoTest, TypeInfo) { + const TypeInfoTestCase& param = GetParam(); + cel::expr::Type expected_type_pb; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(param.expected_type_pb, + &expected_type_pb)); + + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + Config::VariableConfig variable_config; + variable_config.name = "test"; + variable_config.type_info = param.type_info; + ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_THAT(compiler, NotNull()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("test")); + EXPECT_THAT(result.GetIssues(), IsEmpty()) + << " error: " << result.FormatError(); + + // Obtain the inferred return type of the expression `test`. + const Ast* ast = result.GetAst(); + ASSERT_THAT(ast, NotNull()); + cel::expr::CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*ast, &checked_expr), IsOk()); + auto it = checked_expr.type_map().find(checked_expr.expr().id()); + ASSERT_NE(it, checked_expr.type_map().end()); + + cel::expr::Type actual_type_pb = it->second; + EXPECT_THAT(actual_type_pb, EqualsProto(expected_type_pb)); +} + +std::vector GetTypeInfoTestCases() { + return { + TypeInfoTestCase{ + .type_info = {.name = "int"}, + .expected_type_pb = "primitive: INT64", + }, + TypeInfoTestCase{ + .type_info = {.name = "list", + .params = {Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "list_type { elem_type { primitive: INT64 } }", + }, + TypeInfoTestCase{ + .type_info = {.name = "list"}, + .expected_type_pb = "list_type { elem_type { dyn {} }}", + }, + TypeInfoTestCase{ + .type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "map_type { key_type { primitive: STRING } " + "value_type { primitive: INT64 }}", + }, + TypeInfoTestCase{ + .type_info = {.name = "cel.expr.conformance.proto2.TestAllTypes"}, + .expected_type_pb = + "message_type: 'cel.expr.conformance.proto2.TestAllTypes'", + }, + TypeInfoTestCase{ + .type_info = {.name = "A", + .params = {Config::TypeInfo{.name = "B", + .is_type_param = true}}}, + // TypeParam is replaced with dyn by the type checker. + .expected_type_pb = + "abstract_type { name: 'A' parameter_types { dyn {} } }", + }, + TypeInfoTestCase{ + .type_info = {.name = "any"}, + .expected_type_pb = "well_known: ANY", + }, + TypeInfoTestCase{ + .type_info = {.name = "timestamp"}, + .expected_type_pb = "well_known: TIMESTAMP", + }, + TypeInfoTestCase{ + .type_info = {.name = "google.protobuf.DoubleValue"}, + .expected_type_pb = "wrapper: DOUBLE", + }, + TypeInfoTestCase{ + .type_info = {.name = "type", + .params = {Config::TypeInfo{.name = "duration"}}}, + .expected_type_pb = "type: { well_known: DURATION }", + }, + TypeInfoTestCase{ + .type_info = {.name = "parameterized", + .params = {{.name = "A", .is_type_param = true}, + {.name = "double"}}}, + // TypeParam is replaced with dyn by the type checker. + .expected_type_pb = "abstract_type { name: 'parameterized' " + "parameter_types { dyn {} } " + "parameter_types { primitive: DOUBLE } }", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(VariableConfigTest, TypeInfoTest, + ValuesIn(GetTypeInfoTestCases())); + +struct VariableConfigWithValueTestCase { + Config::VariableConfig variable_config; + std::string validate_type_expr; + std::string validate_value_expr; +}; + +class VariableConfigWithValueTest + : public testing::TestWithParam {}; + +TEST_P(VariableConfigWithValueTest, VariableConfigWithValue) { + const VariableConfigWithValueTestCase& param = GetParam(); + + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + ASSERT_THAT(config.AddVariableConfig(param.variable_config), IsOk()); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN( + bool type_as_expected, + CompileAndEvalBooleanExpr(env, param.validate_type_expr)); + ASSERT_TRUE(type_as_expected) << " expr: " << param.validate_type_expr; + if (!param.validate_value_expr.empty()) { + ASSERT_OK_AND_ASSIGN( + bool value_as_expected, + CompileAndEvalBooleanExpr(env, param.validate_value_expr)); + ASSERT_TRUE(value_as_expected) << " expr: " << param.validate_value_expr; + } +} + +Config::VariableConfig MakeConstant( + absl::string_view variable_name, absl::string_view type_name, + absl::AnyInvocable setter) { + Config::VariableConfig variable_config; + variable_config.name = variable_name; + Constant c; + setter(c); + variable_config.type_info.name = type_name; + variable_config.value = c; + return variable_config; +} + +std::vector +GetVariableConfigWithValueTestCases() { + return { + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "null", [](auto& c) { c.set_null_value(nullptr); }), + .validate_type_expr = "type(x) == type(null)", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "bool", [](auto& c) { c.set_bool_value(true); }), + .validate_type_expr = "type(x) == bool", + .validate_value_expr = "x == true", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "int", [](Constant& c) { c.set_int_value(42); }), + .validate_type_expr = "type(x) == int", + .validate_value_expr = "x == 42", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "uint", [](Constant& c) { c.set_uint_value(777); }), + .validate_type_expr = "type(x) == uint", + .validate_value_expr = "x == 777u", + }, + VariableConfigWithValueTestCase{ + .variable_config = + MakeConstant("x", "double", + [](Constant& c) { c.set_double_value(1.0 / 3.0); }), + .validate_type_expr = "type(x) == double", + .validate_value_expr = "x > 0.333 && x < 0.334", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant("x", "bytes", + [](Constant& c) { + c.set_bytes_value(absl::string_view( + "\xff\x00\x01", 3)); + }), + .validate_type_expr = "type(x) == bytes", + .validate_value_expr = "x == b'\\xff\\x00\\x01'", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "string", [](Constant& c) { c.set_string_value("hello"); }), + .validate_type_expr = "type(x) == string", + .validate_value_expr = "x == 'hello'", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "timestamp", + [](Constant& c) { + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + c.set_timestamp_value(absl::FromUnixSeconds(1767323045)); + }), + .validate_type_expr = + "type(x) == type(timestamp('2026-01-02T03:04:05Z'))", + .validate_value_expr = "x == timestamp('2026-01-02T03:04:05Z')", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "duration", + [](Constant& c) { + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + c.set_duration_value(absl::Hours(1) + absl::Minutes(2) + + absl::Seconds(3)); + }), + .validate_type_expr = "type(x) == type(duration('1h2m3s'))", + .validate_value_expr = "x == duration('1h2m3s')", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(VariableConfigTest, VariableConfigWithValueTest, + ValuesIn(GetVariableConfigWithValueTestCases())); + +struct FunctionConfigTestCase { + Config::FunctionConfig function_config; + std::vector variable_configs; + std::string expr; + std::string expected_error; +}; + +class FunctionConfigTest + : public testing::TestWithParam {}; + +TEST_P(FunctionConfigTest, FunctionConfig) { + const FunctionConfigTestCase& param = GetParam(); + + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + for (const Config::VariableConfig& variable_config : param.variable_configs) { + ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); + } + ASSERT_THAT(config.AddFunctionConfig(param.function_config), IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(param.expr)); + if (param.expected_error.empty()) { + EXPECT_TRUE(result.GetIssues().empty()) + << " expr: " << param.expr << " error: " << result.FormatError(); + } else { + EXPECT_THAT(result.GetIssues(), + UnorderedElementsAre(Property(&TypeCheckIssue::message, + HasSubstr(param.expected_error)))) + << " expr: " << param.expr << " error: " << result.FormatError(); + } +} + +std::vector GetFunctionConfigTestCases() { + return {{ + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "plus(int,int)", + .examples = {"add(1, 2) -> 3"}, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "int"}, + }, + }, + }, + .expr = "add(1, 2)", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "int.plus(int)", + .examples = {"1.add(2) -> 3"}, + .is_member_function = true, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "int"}, + }, + }, + }, + .expr = "1.add(2) == 3", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "plus(string,string)", + .examples = + {"add('hello', 'world') -> 'hello world'"}, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "string"}, + }, + }, + }, + .expr = "add('hello', 'world')", + .expected_error = "found no matching overload for 'add' applied to " + "'(string, string)'", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "int.plus(int)", + .examples = {"1.add(2) -> 'three'"}, + .is_member_function = true, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "string"}, + }, + }, + }, + .expr = "1.add(2) == 3", + .expected_error = "found no matching overload for '_==_' applied to " + "'(string, int)'", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "sum", + .description = "Sum a collection, which is an opaque type.", + .overload_configs = + { + { + .overload_id = "sum(collection)", + .examples = {"sum(my_collection) -> 100"}, + .parameters = {{.name = "collection", + .params = {{.name = "double"}}}}, + .return_type = {.name = "double"}, + }, + }, + }, + .variable_configs = + { + {.name = "my_collection", + .description = "Matching opaque type.", + .type_info = {.name = "collection", + .params = {{.name = "double"}}}}, + }, + .expr = "sum(my_collection) / 3.0", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "sum", + .description = "Sum a collection, which is an opaque type.", + .overload_configs = + { + { + .overload_id = "sum(collection)", + .examples = {"sum(my_collection) -> 100"}, + .parameters = {{.name = "collection", + .params = {{.name = "int"}}}}, + .return_type = {.name = "double"}, + }, + }, + }, + .variable_configs = + { + {.name = "my_collection", + .description = "Mismatched opaque type.", + .type_info = {.name = "collection", + .params = {{.name = "double"}}}}, + }, + .expr = "sum(my_collection) / 3.0", + .expected_error = "found no matching overload for 'sum' applied to " + "'(collection(double))'", + }, + }}; +} + +INSTANTIATE_TEST_SUITE_P(FunctionConfigTest, FunctionConfigTest, + ::testing::ValuesIn(GetFunctionConfigTestCases())); + +} // namespace +} // namespace cel diff --git a/env/env_yaml.cc b/env/env_yaml.cc new file mode 100644 index 000000000..0035709e9 --- /dev/null +++ b/env/env_yaml.cc @@ -0,0 +1,1026 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_yaml.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "internal/strings.h" +#include "yaml-cpp/emitter.h" +#include "yaml-cpp/emittermanip.h" +#include "yaml-cpp/exceptions.h" +#include "yaml-cpp/mark.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/node/parse.h" +#include "yaml-cpp/null.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel { + +namespace { + +std::string FormatYamlErrorMessage(absl::string_view yaml, + absl::string_view error, + const YAML::Mark& mark) { + if (mark.is_null()) { + return std::string(error); + } + std::string message; + absl::StrAppend(&message, mark.line + 1, ":", mark.column + 1, ": ", error, + "\n|"); + size_t start = mark.pos - mark.column; + size_t end = yaml.find('\n', mark.pos); + if (end == std::string::npos) { + end = yaml.size(); + } + + absl::StrAppend(&message, yaml.substr(start, end - start), "\n|", + std::string(mark.column, ' '), "^"); + + return message; +} + +absl::StatusOr LoadYaml(const std::string& yaml) { + try { + return YAML::Load(yaml); + } catch (YAML::ParserException& e) { + return absl::InvalidArgumentError( + FormatYamlErrorMessage(yaml, e.msg, e.mark)); + } +} + +absl::Status YamlError(absl::string_view yaml, const YAML::Node& node, + absl::string_view error) { + return absl::InvalidArgumentError( + FormatYamlErrorMessage(yaml, error, node.Mark())); +} + +std::string GetString(absl::string_view yaml, const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar()) { + return ""; + } + try { + return node.as(); + } catch (YAML::Exception& e) { + // This should never happen since we already checked that the node is a + // scalar and all scalars can be converted to strings. + return ""; + } +} + +bool IsBinary(const YAML::Node& node) { + return node.Tag() == "!!binary" || node.Tag() == "tag:yaml.org,2002:binary"; +} + +absl::StatusOr GetBinary(absl::string_view yaml, + const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar() || !IsBinary(node)) { + return ""; + } + std::string binary; + // Instead of using the YAML::Binary type, we use absl::Base64Unescape + // because YAML::Binary is lenient to Base64 decoding errors. + if (absl::Base64Unescape(GetString(yaml, node), &binary)) { + return binary; + } else { + return YamlError(yaml, node, + "Node '" + GetString(yaml, node) + + "' is not a valid Base64 encoded binary"); + } +} + +absl::StatusOr GetBool(absl::string_view yaml, absl::string_view key, + const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar()) { + return false; + } + try { + return node.as(); + } catch (YAML::Exception& e) { + return YamlError(yaml, node, + "Node '" + std::string(key) + "' is not a boolean"); + } +} + +absl::Status ParseName(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node name = root["name"]; + if (name.IsDefined()) { + if (!name.IsScalar()) { + return YamlError(yaml, name, "Node 'name' is not a string"); + } + config.SetName(GetString(yaml, name)); + } + return absl::OkStatus(); +} + +absl::Status ParseContainerConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node container = root["container"]; + if (container.IsDefined()) { + if (!container.IsScalar()) { + return YamlError(yaml, container, "Node 'container' is not a string"); + } + config.SetContainerConfig({.name = GetString(yaml, container)}); + } + return absl::OkStatus(); +} + +absl::Status ParseExtensionConfigs(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node extensions = root["extensions"]; + if (!extensions.IsDefined()) { + return absl::OkStatus(); + } + if (!extensions.IsSequence()) { + return YamlError(yaml, extensions, "Node 'extensions' is not a sequence"); + } + + for (const YAML::Node& extension : extensions) { + if (!extension || !extension.IsMap()) { + return YamlError(yaml, extension, "Extension is not a map"); + } + const YAML::Node name = extension["name"]; + if (!name || !name.IsScalar()) { + return YamlError(yaml, name, "Extension name is not a string"); + } + std::string name_str = GetString(yaml, name); + + const YAML::Node version = extension["version"]; + std::string version_str = GetString(yaml, version); + int extension_version; + if (version.IsDefined()) { + bool is_valid_version = false; + if (version.IsScalar()) { + if (version_str == "latest") { + extension_version = Config::ExtensionConfig::kLatest; + is_valid_version = true; + } else { + if (absl::SimpleAtoi(version_str, &extension_version) && + extension_version >= 0) { + is_valid_version = true; + } + } + } + if (!is_valid_version) { + return YamlError( + yaml, version, + absl::StrCat("Extension '", name_str, + "' version is not a valid number or 'latest'")); + } + } else { + extension_version = Config::ExtensionConfig::kLatest; + } + absl::Status add_status = + config.AddExtensionConfig(name_str, extension_version); + if (!add_status.ok()) { + return YamlError(yaml, extension, add_status.message()); + } + } + return absl::OkStatus(); +} + +absl::StatusOr> ParseMacroList( + absl::string_view yaml, const YAML::Node& standard_library, + absl::string_view key) { + absl::flat_hash_set macro_set; + const YAML::Node macros = standard_library[std::string(key)]; + if (!macros.IsDefined()) { + return macro_set; + } + if (!macros.IsSequence()) { + return YamlError(yaml, macros, + absl::StrCat("Node '", key, "' is not a sequence")); + } + for (const YAML::Node& macro : macros) { + if (!macro.IsScalar()) { + return YamlError(yaml, macro, + absl::StrCat("Entry in '", key, "' is not a string")); + } + macro_set.insert(GetString(yaml, macro)); + } + return macro_set; +} + +absl::StatusOr>> +ParseFunctionList(absl::string_view yaml, const YAML::Node& standard_library, + absl::string_view key) { + absl::flat_hash_set> function_set; + const YAML::Node functions = standard_library[std::string(key)]; + if (!functions.IsDefined()) { + return function_set; + } + if (!functions.IsSequence()) { + return YamlError(yaml, functions, + absl::StrCat("Node '", key, "' is not a sequence")); + } + for (const YAML::Node& function : functions) { + if (!function.IsMap()) { + return YamlError(yaml, function, + absl::StrCat("Entry in '", key, "' is not a map")); + } + const YAML::Node name = function["name"]; + if (!name.IsDefined()) { + return YamlError( + yaml, function, + absl::StrCat("Function name in not specified in '", key, "'")); + } + if (!name.IsScalar()) { + return YamlError( + yaml, name, + absl::StrCat("Function name in '", key, "' entry is not a string")); + } + std::string name_str = GetString(yaml, name); + const YAML::Node overloads = function["overloads"]; + if (!overloads.IsDefined()) { + function_set.insert(std::make_pair(name_str, "")); + } else { + if (!overloads.IsSequence()) { + return YamlError( + yaml, overloads, + absl::StrCat("Overloads in '", key, "' entry is not a sequence")); + } + for (const YAML::Node& overload : overloads) { + if (!overload.IsMap()) { + return YamlError( + yaml, overload, + absl::StrCat("Overload in '", key, "' entry is not a map")); + } + const YAML::Node id = overload["id"]; + if (!id || !id.IsScalar()) { + return YamlError( + yaml, id, + absl::StrCat("Overload id in '", key, "' entry is not a string")); + } + function_set.insert(std::make_pair(name_str, GetString(yaml, id))); + } + } + } + return function_set; +} + +absl::Status ParseStandardLibraryConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node standard_library = root["stdlib"]; + if (!standard_library.IsDefined()) { + return absl::OkStatus(); + } + + if (!standard_library.IsMap()) { + return YamlError(yaml, standard_library, + "Standard library config ('stdlib') is not a map"); + } + + Config::StandardLibraryConfig standard_library_config; + + const YAML::Node disable = standard_library["disable"]; + if (disable.IsDefined()) { + if (!disable.IsScalar()) { + return YamlError(yaml, disable, "Node 'disable' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(standard_library_config.disable, + GetBool(yaml, "disable", disable)); + } + + const YAML::Node disable_macros = standard_library["disable_macros"]; + if (disable_macros.IsDefined()) { + if (!disable_macros.IsScalar()) { + return YamlError(yaml, disable_macros, + "Node 'disable_macros' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(standard_library_config.disable_macros, + GetBool(yaml, "disable_macros", disable_macros)); + } + + CEL_ASSIGN_OR_RETURN( + standard_library_config.included_macros, + ParseMacroList(yaml, standard_library, "include_macros")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.excluded_macros, + ParseMacroList(yaml, standard_library, "exclude_macros")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.included_functions, + ParseFunctionList(yaml, standard_library, "include_functions")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.excluded_functions, + ParseFunctionList(yaml, standard_library, "exclude_functions")); + + return config.SetStandardLibraryConfig(standard_library_config); +} + +absl::StatusOr ParseTypeInfo(const YAML::Node& node, + absl::string_view yaml) { + Config::TypeInfo type_config; + const YAML::Node type_name = node["type_name"]; + if (!type_name.IsDefined()) { + return type_config; + } + if (!type_name || !type_name.IsScalar()) { + return YamlError(yaml, type_name, "Node 'type_name' is not a string"); + } + type_config.name = GetString(yaml, type_name); + + const YAML::Node is_type_param = node["is_type_param"]; + if (is_type_param.IsDefined()) { + if (!is_type_param.IsScalar()) { + return YamlError(yaml, is_type_param, + "Node 'is_type_param' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(type_config.is_type_param, + GetBool(yaml, "is_type_param", is_type_param)); + } + + const YAML::Node params = node["params"]; + if (!params.IsDefined()) { + return type_config; + } + if (!params.IsSequence()) { + return YamlError(yaml, params, "Node 'params' is not a sequence"); + } + for (const YAML::Node& param : params) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param_config, + ParseTypeInfo(param, yaml)); + type_config.params.push_back(param_config); + } + + return type_config; +} + +bool CompareTypeInfo(const Config::TypeInfo& a, const Config::TypeInfo& b) { + if (a.name != b.name) { + return a.name < b.name; + } + if (a.params.size() != b.params.size()) { + return a.params.size() < b.params.size(); + } + for (size_t i = 0; i < a.params.size(); ++i) { + if (CompareTypeInfo(a.params[i], b.params[i])) { + return true; + } + if (CompareTypeInfo(b.params[i], a.params[i])) { + return false; + } + } + return false; // They are equal +} + +ConstantKindCase GetConstantKindCase(absl::string_view type_name) { + static const auto kTypeNameToConstantKindCase = + absl::NoDestructor>({ + {"null", ConstantKindCase::kNull}, + {"bool", ConstantKindCase::kBool}, + {"int", ConstantKindCase::kInt}, + {"uint", ConstantKindCase::kUint}, + {"double", ConstantKindCase::kDouble}, + {"string", ConstantKindCase::kString}, + {"bytes", ConstantKindCase::kBytes}, + {"duration", ConstantKindCase::kDuration}, + {"timestamp", ConstantKindCase::kTimestamp}, + }); + if (auto it = kTypeNameToConstantKindCase->find(type_name); + it != kTypeNameToConstantKindCase->end()) { + return it->second; + } + return ConstantKindCase::kUnspecified; +} + +absl::StatusOr ParseConstantValue(absl::string_view yaml, + const YAML::Node& node, + ConstantKindCase constant_kind_case, + absl::string_view value) { + switch (constant_kind_case) { + case ConstantKindCase::kNull: + if (!value.empty()) { + return YamlError(yaml, node, "Failed to parse null constant"); + } + return Constant(nullptr); + case ConstantKindCase::kBool: + if (absl::EqualsIgnoreCase(value, "true")) { + return Constant(true); + } else if (absl::EqualsIgnoreCase(value, "false")) { + return Constant(false); + } else { + return YamlError(yaml, node, "Failed to parse bool constant"); + } + case ConstantKindCase::kInt: + int64_t int_value; + if (!absl::SimpleAtoi(value, &int_value)) { + return YamlError(yaml, node, "Failed to parse int constant"); + } + return Constant(int_value); + case ConstantKindCase::kUint: + uint64_t uint_value; + if (absl::EndsWith(value, "u")) { + value = value.substr(0, value.size() - 1); + } + if (!absl::SimpleAtoi(value, &uint_value)) { + return YamlError(yaml, node, "Failed to parse uint constant"); + } + return Constant(uint_value); + case ConstantKindCase::kDouble: + double double_value; + if (!absl::SimpleAtod(value, &double_value)) { + return YamlError(yaml, node, "Failed to parse double constant"); + } + return Constant(double_value); + case ConstantKindCase::kBytes: { + if (!IsBinary(node)) { + absl::StatusOr bytes_literal = + internal::ParseBytesLiteral(value); + if (bytes_literal.ok()) { + return Constant(BytesConstant(*bytes_literal)); + } + } + return Constant(BytesConstant(value)); + } + case ConstantKindCase::kString: + return Constant(StringConstant(value)); + case ConstantKindCase::kDuration: { + // Duration is deprecated as a builtin type, but still supported for + // compatibility. + absl::Duration duration_value; + if (!absl::ParseDuration(value, &duration_value)) { + return YamlError(yaml, node, "Failed to parse duration constant"); + } + return Constant(duration_value); + } + case ConstantKindCase::kTimestamp: { + // Timestamp is deprecated as a builtin type, but still supported for + // compatibility. + absl::Time timestamp_value; + std::string error; + // Format: YYYY-MM-DDThh:mm:ssZ + if (!absl::ParseTime("%Y-%m-%d%ET%H:%M:%E*SZ", value, ×tamp_value, + &error)) { + return YamlError( + yaml, node, + absl::StrCat("Failed to parse timestamp constant: ", error, + " supported format: YYYY-MM-DDThh:mm:ssZ")); + } + return Constant(timestamp_value); + } + default: + // This should never happen. + return YamlError(yaml, node, "Constant type is not supported"); + } +} + +absl::Status ParseVariableConfigs(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node variables = root["variables"]; + if (!variables.IsDefined()) { + return absl::OkStatus(); + } + if (!variables.IsSequence()) { + return YamlError(yaml, variables, "Node 'variables' is not a sequence"); + } + + for (const YAML::Node& variable : variables) { + Config::VariableConfig variable_config; + if (!variable || !variable.IsMap()) { + return YamlError(yaml, variable, "Variable is not a map"); + } + const YAML::Node name = variable["name"]; + if (!name || !name.IsScalar()) { + return YamlError(yaml, name, "Variable name is not a string"); + } + variable_config.name = GetString(yaml, name); + const YAML::Node description = variable["description"]; + if (description.IsDefined()) { + if (!description.IsScalar()) { + return YamlError(yaml, description, + "Variable description is not a string"); + } + variable_config.description = GetString(yaml, description); + } + + CEL_ASSIGN_OR_RETURN(auto type_info, ParseTypeInfo(variable, yaml)); + ConstantKindCase constant_kind_case = GetConstantKindCase(type_info.name); + std::string value_str; + YAML::Node value = variable["value"]; + if (value.IsDefined()) { + if (constant_kind_case == ConstantKindCase::kUnspecified) { + return YamlError(yaml, value, + absl::StrCat("Constant type '", type_info.name, + "' is not supported")); + } + if (!value.IsScalar()) { + return YamlError(yaml, value, "Variable value is not a scalar"); + } + if (IsBinary(value)) { + CEL_ASSIGN_OR_RETURN(value_str, GetBinary(yaml, value)); + } else { + value_str = GetString(yaml, value); + } + } + + variable_config.type_info = type_info; + + if (constant_kind_case != ConstantKindCase::kUnspecified) { + CEL_ASSIGN_OR_RETURN( + variable_config.value, + ParseConstantValue(yaml, value, constant_kind_case, value_str)); + } + + CEL_RETURN_IF_ERROR(config.AddVariableConfig(variable_config)); + } + return absl::OkStatus(); +} + +absl::StatusOr ParseFunctionOverloadConfig( + absl::string_view yaml, const YAML::Node& overload) { + Config::FunctionOverloadConfig overload_config; + if (!overload || !overload.IsMap()) { + return YamlError(yaml, overload, "Function overload is not a map"); + } + const YAML::Node id = overload["id"]; + if (id.IsDefined()) { + if (!id.IsScalar()) { + return YamlError(yaml, id, "Function overload id is not a string"); + } + overload_config.overload_id = GetString(yaml, id); + } + const YAML::Node examples = overload["examples"]; + if (examples.IsDefined()) { + if (!examples.IsSequence()) { + return YamlError(yaml, examples, + "Function overload examples is not a sequence"); + } + for (const YAML::Node& example : examples) { + if (!example.IsScalar()) { + return YamlError(yaml, example, + "Function overload example is not a string"); + } + overload_config.examples.push_back(GetString(yaml, example)); + } + } + + const YAML::Node target = overload["target"]; + if (target.IsDefined()) { + if (!target.IsMap()) { + return YamlError(yaml, target, "Function overload target is not a map"); + } + CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, + ParseTypeInfo(target, yaml)); + overload_config.is_member_function = true; + overload_config.parameters.push_back(type_info); + } + + const YAML::Node args = overload["args"]; + if (args.IsDefined()) { + if (!args.IsSequence()) { + return YamlError(yaml, args, "Function overload args is not a sequence"); + } + for (const YAML::Node& arg : args) { + if (!arg.IsMap()) { + return YamlError(yaml, arg, "Function overload arg is not a map"); + } + CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, + ParseTypeInfo(arg, yaml)); + overload_config.parameters.push_back(type_info); + } + } + + const YAML::Node return_type = overload["return"]; + if (return_type.IsDefined()) { + if (!return_type.IsMap()) { + return YamlError(yaml, return_type, + "Function overload return type is not a map"); + } + CEL_ASSIGN_OR_RETURN(overload_config.return_type, + ParseTypeInfo(return_type, yaml)); + } + return overload_config; +} + +absl::Status ParseFunctionConfigs(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node functions = root["functions"]; + if (!functions.IsDefined()) { + return absl::OkStatus(); + } + if (!functions.IsSequence()) { + return YamlError(yaml, functions, "Node 'functions' is not a sequence"); + } + + for (const YAML::Node& function : functions) { + Config::FunctionConfig function_config; + if (!function || !function.IsMap()) { + return YamlError(yaml, function, "Function is not a map"); + } + const YAML::Node name = function["name"]; + if (!name || !name.IsScalar()) { + return YamlError(yaml, name, "Function name is not a string"); + } + function_config.name = GetString(yaml, name); + const YAML::Node description = function["description"]; + if (description.IsDefined()) { + if (!description.IsScalar()) { + return YamlError(yaml, description, + "Function description is not a string"); + } + function_config.description = GetString(yaml, description); + } + const YAML::Node overloads = function["overloads"]; + if (overloads.IsDefined()) { + if (!overloads.IsSequence()) { + return YamlError(yaml, overloads, + "Function 'overloads' item is not a sequence"); + } + + for (const YAML::Node& overload : overloads) { + CEL_ASSIGN_OR_RETURN(Config::FunctionOverloadConfig overload_config, + ParseFunctionOverloadConfig(yaml, overload)); + function_config.overload_configs.push_back(std::move(overload_config)); + } + } + + CEL_RETURN_IF_ERROR(config.AddFunctionConfig(function_config)); + } + return absl::OkStatus(); +} + +void EmitContainerConfig(const Config& env_config, YAML::Emitter& out) { + const auto& container_config = env_config.GetContainerConfig(); + if (container_config.IsEmpty()) { + return; + } + + out << YAML::Key << "container"; + out << YAML::Value << YAML::DoubleQuoted << container_config.name; +} + +void EmitExtensionConfigs(const Config& env_config, YAML::Emitter& out) { + if (env_config.GetExtensionConfigs().empty()) { + return; + } + + // Sort the extensions to make the output deterministic. + std::vector sorted_extensions = + env_config.GetExtensionConfigs(); + absl::c_sort(sorted_extensions, [](const Config::ExtensionConfig& a, + const Config::ExtensionConfig& b) { + return a.name < b.name; + }); + out << YAML::Key << "extensions"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::ExtensionConfig& extension_config : sorted_extensions) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << extension_config.name; + if (extension_config.version != Config::ExtensionConfig::kLatest) { + out << YAML::Key << "version"; + out << YAML::Value << extension_config.version; + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +void EmitMacroList(YAML::Emitter& out, absl::string_view key, + const absl::flat_hash_set& macros) { + if (macros.empty()) { + return; + } + out << YAML::Key << std::string(key); + out << YAML::Value << YAML::BeginSeq; + std::vector sorted_macros(macros.begin(), macros.end()); + absl::c_sort(sorted_macros); + for (const std::string& macro : sorted_macros) { + out << YAML::Value << YAML::DoubleQuoted << macro; + } + out << YAML::EndSeq; +} + +void EmitFunctionList( + YAML::Emitter& out, absl::string_view key, + const absl::flat_hash_set>& functions) { + if (functions.empty()) { + return; + } + + // Build a map from function name to a vector of overload ids. + // Using std::map ensures function names are sorted. + std::map> function_overloads; + for (const auto& pair : functions) { + function_overloads[pair.first].push_back(pair.second); + } + + out << YAML::Key << std::string(key) << YAML::Value << YAML::BeginSeq; + for (auto const& [name, overloads] : function_overloads) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << name; + + // If the only overload is the empty string, it signifies that all overloads + // of the function are included/excluded. In this case, we don't emit the + // "overloads" key. Otherwise, emit the specific overloads. + if (!(overloads.size() == 1 && overloads[0].empty())) { + // Sort overloads for deterministic output. + std::vector sorted_overloads = overloads; + absl::c_sort(sorted_overloads); + + out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq; + for (const std::string& overload : sorted_overloads) { + out << YAML::BeginMap; + out << YAML::Key << "id"; + out << YAML::Value << YAML::DoubleQuoted << overload; + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +void EmitStandardLibraryConfig(const Config& env_config, YAML::Emitter& out) { + const Config::StandardLibraryConfig& standard_library_config = + env_config.GetStandardLibraryConfig(); + if (standard_library_config.IsEmpty()) { + return; + } + + out << YAML::Key << "stdlib" << YAML::Value << YAML::BeginMap; + if (standard_library_config.disable) { + out << YAML::Key << "disable" << YAML::Value << true; + } + if (standard_library_config.disable_macros) { + out << YAML::Key << "disable_macros" << YAML::Value << true; + } + EmitMacroList(out, "include_macros", standard_library_config.included_macros); + EmitMacroList(out, "exclude_macros", standard_library_config.excluded_macros); + EmitFunctionList(out, "include_functions", + standard_library_config.included_functions); + EmitFunctionList(out, "exclude_functions", + standard_library_config.excluded_functions); + out << YAML::EndMap; +} + +void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out) { + // Note: the map is already started when this is called, so we don't emit + // BeginMap here or EndMap at the end. + out << YAML::Key << "type_name"; + out << YAML::Value << YAML::DoubleQuoted << type_info.name; + if (type_info.is_type_param) { + out << YAML::Key << "is_type_param" << YAML::Value << true; + } + if (!type_info.params.empty()) { + out << YAML::Key << "params" << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& param : type_info.params) { + out << YAML::BeginMap; + EmitTypeInfo(param, out); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } +} + +void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { + const auto& variable_configs = env_config.GetVariableConfigs(); + if (variable_configs.empty()) { + return; + } + + // Sort variable_configs by name to ensure deterministic output. + std::vector sorted_variable_configs = + variable_configs; + absl::c_sort(sorted_variable_configs, + [](const Config::VariableConfig& a, + const Config::VariableConfig& b) { return a.name < b.name; }); + + out << YAML::Key << "variables"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::VariableConfig& variable_config : + sorted_variable_configs) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << variable_config.name; + if (!variable_config.description.empty()) { + out << YAML::Key << "description"; + out << YAML::Value << YAML::DoubleQuoted << variable_config.description; + } + EmitTypeInfo(variable_config.type_info, out); + if (variable_config.value.has_value()) { + const Constant& constant = variable_config.value; + switch (constant.kind_case()) { + case ConstantKindCase::kUnspecified: + case ConstantKindCase::kNull: + break; + case ConstantKindCase::kBool: + out << YAML::Key << "value" << YAML::Value << constant.bool_value(); + break; + case ConstantKindCase::kInt: + out << YAML::Key << "value" << YAML::Value << constant.int_value(); + break; + case ConstantKindCase::kUint: + out << YAML::Key << "value" << YAML::Value << constant.uint_value(); + break; + case ConstantKindCase::kDouble: + out << YAML::Key << "value" << YAML::Value << constant.double_value(); + break; + case ConstantKindCase::kBytes: { + out << YAML::Key << "value"; + const std::string& bytes_value = constant.bytes_value(); + std::string hex_escaped = "b\""; + for (unsigned char byte : bytes_value) { + absl::StrAppend(&hex_escaped, "\\x"); + absl::StrAppendFormat(&hex_escaped, "%02x", byte); + } + absl::StrAppend(&hex_escaped, "\""); + out << YAML::Value << hex_escaped; + break; + } + case ConstantKindCase::kString: + out << YAML::Key << "value"; + out << YAML::Value << YAML::DoubleQuoted << constant.string_value(); + break; + case ConstantKindCase::kDuration: + out << YAML::Key << "value" << YAML::Value; + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + out << absl::FormatDuration(constant.duration_value()); + break; + case ConstantKindCase::kTimestamp: + out << YAML::Key << "value" << YAML::Value; + out << absl::FormatTime( + "%4Y-%2m-%2d%ET%2H:%2M:%E*SZ", + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + constant.timestamp_value(), absl::UTCTimeZone()); + break; + } + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} + +void EmitFunctionOverloadConfig( + const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out) { + out << YAML::BeginMap; + out << YAML::Key << "id"; + out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; + if (overload_config.is_member_function) { + out << YAML::Key << "target" << YAML::Value; + out << YAML::BeginMap; + if (overload_config.parameters.empty()) { + // This should never happen, but if it does, emit a dynamic type. + EmitTypeInfo({.name = "dyn"}, out); + } else { + EmitTypeInfo(overload_config.parameters[0], out); + } + out << YAML::EndMap; + if (overload_config.parameters.size() > 1) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (size_t i = 1; i < overload_config.parameters.size(); ++i) { + out << YAML::BeginMap; + EmitTypeInfo(overload_config.parameters[i], out); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + } else { + if (!overload_config.parameters.empty()) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& parameter : overload_config.parameters) { + out << YAML::BeginMap; + EmitTypeInfo(parameter, out); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + } + out << YAML::Key << "return"; + out << YAML::Value << YAML::BeginMap; + EmitTypeInfo(overload_config.return_type, out); + out << YAML::EndMap; + + out << YAML::EndMap; +} + +void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out) { + const std::vector& function_configs = + env_config.GetFunctionConfigs(); + if (function_configs.empty()) { + return; + } + + // Sort function_configs by name to ensure deterministic output. + std::vector sorted_function_configs = + function_configs; + absl::c_sort(sorted_function_configs, + [](const Config::FunctionConfig& a, + const Config::FunctionConfig& b) { return a.name < b.name; }); + + out << YAML::Key << "functions"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::FunctionConfig& function_config : + sorted_function_configs) { + out << YAML::BeginMap; + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << function_config.name; + if (!function_config.description.empty()) { + out << YAML::Key << "description"; + out << YAML::Value << YAML::DoubleQuoted << function_config.description; + } + if (!function_config.overload_configs.empty()) { + // Sort overloads for deterministic output. + std::vector sorted_overloads = + function_config.overload_configs; + absl::c_sort(sorted_overloads, + [](const Config::FunctionOverloadConfig& a, + const Config::FunctionOverloadConfig& b) { + for (size_t i = 0; i < a.parameters.size(); ++i) { + // Order like this: foo(a), foo(a, b) + if (i >= b.parameters.size()) { + return false; + } + if (CompareTypeInfo(a.parameters[i], b.parameters[i])) { + return true; + } + if (CompareTypeInfo(b.parameters[i], a.parameters[i])) { + return false; + } + } + return false; + }); + + out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq; + for (const Config::FunctionOverloadConfig& overload_config : + sorted_overloads) { + EmitFunctionOverloadConfig(overload_config, out); + } + out << YAML::EndSeq; + } + out << YAML::EndMap; + } + out << YAML::EndSeq; +} +} // namespace + +absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { + Config config; + CEL_ASSIGN_OR_RETURN(YAML::Node root, LoadYaml(yaml)); + CEL_RETURN_IF_ERROR(ParseName(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseContainerConfig(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseExtensionConfigs(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseStandardLibraryConfig(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseVariableConfigs(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseFunctionConfigs(config, yaml, root)); + return config; +} + +void EnvConfigToYaml(const Config& env_config, std::ostream& os) { + YAML::Emitter out(os); + out.SetIndent(2); + out << YAML::BeginMap; + if (!env_config.GetName().empty()) { + out << YAML::Key << "name"; + out << YAML::Value << YAML::DoubleQuoted << env_config.GetName(); + } + EmitContainerConfig(env_config, out); + EmitExtensionConfigs(env_config, out); + EmitStandardLibraryConfig(env_config, out); + EmitVariableConfigs(env_config, out); + EmitFunctionConfigs(env_config, out); + out << YAML::EndMap; +} + +} // namespace cel diff --git a/env/env_yaml.h b/env/env_yaml.h new file mode 100644 index 000000000..c96b45933 --- /dev/null +++ b/env/env_yaml.h @@ -0,0 +1,39 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ +#define THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "env/config.h" + +namespace cel { + +// EnvConfigFromYaml creates an environment configuration from a YAML string. +// +// To ensure safety, only pass trusted YAML input. yaml-cpp has some fuzz +// coverage, but its security model is unclear. Additionally, callers should be +// aware that improper CEL configuration can lead to unsafe or unpredictably +// expensive expressions. +absl::StatusOr EnvConfigFromYaml(const std::string& yaml); + +// EnvConfigToYaml serializes an environment configuration as a YAML string. +void EnvConfigToYaml(const Config& env_config, std::ostream& os); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_ diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc new file mode 100644 index 000000000..828a39b48 --- /dev/null +++ b/env/env_yaml_test.cc @@ -0,0 +1,1467 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_yaml.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::StatusIs; +using ::testing::AllOf; +using ::testing::ElementsAreArray; +using ::testing::Field; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +TEST(EnvYamlTest, ParseContainerConfig) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + container: "test.container" + )yaml")); + + EXPECT_THAT(config.GetContainerConfig(), + Field(&Config::ContainerConfig::name, "test.container")); +} + +TEST(EnvYamlTest, ParseExtensionConfigs) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + extensions: + - name: "math" + version: latest + - name: "optional" + version: 2 + - name: "strings" + )yaml")); + + EXPECT_THAT(config.GetExtensionConfigs(), + UnorderedElementsAre( + AllOf(Field(&Config::ExtensionConfig::name, "math"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)), + AllOf(Field(&Config::ExtensionConfig::name, "optional"), + Field(&Config::ExtensionConfig::version, 2)), + AllOf(Field(&Config::ExtensionConfig::name, "strings"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)))); +} + +TEST(EnvYamlTest, DefaultExtensionConfigs) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + )yaml")); + + EXPECT_THAT(config.GetExtensionConfigs(), IsEmpty()); +} + +TEST(EnvYamlTest, ParseStdlibConfig_ExclusionStyle) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + stdlib: + disable: true + disable_macros: true + exclude_macros: + - map + - filter + exclude_functions: + - name: "_+_" + overloads: + - id: add_bytes + - id: add_list + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml")); + + const auto& stdlib_config = config.GetStandardLibraryConfig(); + EXPECT_TRUE(stdlib_config.disable); + EXPECT_TRUE(stdlib_config.disable_macros); + EXPECT_THAT(stdlib_config.excluded_macros, + UnorderedElementsAre("map", "filter")); + EXPECT_THAT(stdlib_config.included_macros, IsEmpty()); + EXPECT_THAT( + stdlib_config.excluded_functions, + UnorderedElementsAre(std::make_pair("_+_", "add_bytes"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("timestamp", "string_to_timestamp"))) + << " Actual stdlib config: " << stdlib_config; +} + +TEST(EnvYamlTest, ParseStdlibConfig_InclusionStyle) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + stdlib: + include_macros: + - map + - filter + include_functions: + - name: "_+_" + overloads: + - id: add_bytes + - id: add_list + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml")); + + const auto& stdlib_config = config.GetStandardLibraryConfig(); + EXPECT_THAT(stdlib_config.included_macros, + UnorderedElementsAre("map", "filter")); + EXPECT_THAT( + stdlib_config.included_functions, + UnorderedElementsAre(std::make_pair("_+_", "add_bytes"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("timestamp", "string_to_timestamp"))) + << " Actual stdlib config: " << stdlib_config; +} + +TEST(EnvYamlTest, ParseVariableConfigs) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "msg" + type_name: "google.expr.proto3.test.TestAllTypes" + description: >- + msg represents all possible type permutation which + CEL understands from a proto perspective + )yaml")); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "msg"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "google.expr.proto3.test.TestAllTypes"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, IsEmpty()); + EXPECT_EQ(variable_config.description, + "msg represents all possible type permutation which CEL " + "understands from a proto perspective"); +} + +TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "dict" + type_name: "map" + params: + - type_name: "string" + - type_name: "A" + is_type_param: true + )yaml")); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "dict"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "map"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, SizeIs(2)); + EXPECT_EQ(type_info.params[0].name, "string"); + EXPECT_FALSE(type_info.params[0].is_type_param); + EXPECT_THAT(type_info.params[0].params, IsEmpty()); + EXPECT_EQ(type_info.params[1].name, "A"); + EXPECT_TRUE(type_info.params[1].is_type_param); + EXPECT_THAT(type_info.params[1].params, IsEmpty()); +} + +struct ParseConstantTestCase { + std::string type_name; + std::string value; + std::string expected_error; // Empty if no error. + Constant expected_constant; +}; + +class EnvYamlParseConstantTest + : public testing::TestWithParam {}; + +TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { + const ParseConstantTestCase& param = GetParam(); + const std::string yaml = absl::StrFormat( + R"yaml( + variables: + - name: "const" + type_name: "%s" + value: %s + )yaml", + param.type_name, param.value); + absl::StatusOr status_or_config = EnvConfigFromYaml(yaml); + if (!param.expected_error.empty()) { + EXPECT_THAT(status_or_config, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + return; + } + ASSERT_OK_AND_ASSIGN(Config config, status_or_config); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "const"); + EXPECT_EQ(variable_config.type_info.name, param.type_name); + EXPECT_EQ(variable_config.value, param.expected_constant); +} + +std::vector GetParseConstantTestCases() { + return { + ParseConstantTestCase{ + .type_name = "null", + .value = "\"\"", + .expected_constant = Constant(nullptr), + }, + ParseConstantTestCase{ + .type_name = "null", + .value = "anything", + .expected_error = "Failed to parse null constant", + }, + ParseConstantTestCase{ + .type_name = "bool", + .value = "TRUE", + .expected_constant = Constant(true), + }, + ParseConstantTestCase{ + .type_name = "bool", + .value = "false", + .expected_constant = Constant(false), + }, + ParseConstantTestCase{ + .type_name = "bool", + .value = "yes", + .expected_error = "Failed to parse bool constant", + }, + ParseConstantTestCase{ + .type_name = "int", + .value = "42", + .expected_constant = Constant(int64_t{42}), + }, + ParseConstantTestCase{ + .type_name = "int", + .value = "41.999", + .expected_error = "Failed to parse int constant", + }, + ParseConstantTestCase{ + .type_name = "uint", + .value = "42", + .expected_constant = Constant(uint64_t{42}), + }, + ParseConstantTestCase{ + .type_name = "uint", + .value = "42u", + .expected_constant = Constant(uint64_t{42}), + }, + ParseConstantTestCase{ + .type_name = "uint", + .value = "-1", + .expected_error = "Failed to parse uint constant", + }, + ParseConstantTestCase{ + .type_name = "double", + .value = "42.42", + .expected_constant = Constant(42.42), + }, + ParseConstantTestCase{ + .type_name = "double", + .value = "abc", + .expected_error = "Failed to parse double constant", + }, + ParseConstantTestCase{ + .type_name = "bytes", + .value = "abc", + .expected_constant = Constant(BytesConstant("abc")), + }, + ParseConstantTestCase{ + .type_name = "bytes", + .value = "b\"\\xFF\\x00\\x01\"", + .expected_constant = + Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), + }, + ParseConstantTestCase{ + .type_name = "bytes", + .value = "!!binary /wAB", + .expected_constant = + Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), + }, + ParseConstantTestCase{ + .type_name = "bytes", + .value = "!!binary YWJj=", + .expected_error = "Node 'YWJj=' is not a valid Base64 encoded binary", + }, + ParseConstantTestCase{ + .type_name = "bytes", + .value = "abc", + .expected_constant = Constant(BytesConstant("abc")), + }, + ParseConstantTestCase{ + .type_name = "string", + .value = "abc", + .expected_constant = Constant(StringConstant("abc")), + }, + ParseConstantTestCase{ + .type_name = "string", + .value = "\"\\\"abc\\\"\"", + .expected_constant = Constant(StringConstant("\"abc\"")), + }, + ParseConstantTestCase{ + .type_name = "duration", + .value = "1s", + .expected_constant = Constant(absl::Seconds(1)), + }, + ParseConstantTestCase{ + .type_name = "duration", + .value = "abc", + .expected_error = "Failed to parse duration constant", + }, + ParseConstantTestCase{ + .type_name = "timestamp", + .value = "2023-01-01T00:00:00Z", + .expected_constant = Constant(absl::FromUnixSeconds(1672531200)), + }, + ParseConstantTestCase{ + .type_name = "timestamp", + .value = "abc", + .expected_error = "Failed to parse timestamp constant", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlParseConstantTest, EnvYamlParseConstantTest, + ::testing::ValuesIn(GetParseConstantTestCases())); + +struct ParseFunctionTestCase { + std::string yaml; + Config::FunctionConfig expected_function_config; +}; + +class EnvYamlParseFunctionTest + : public testing::TestWithParam {}; + +void ExpectTypeInfoEqual(const Config::TypeInfo& actual, + const Config::TypeInfo& expected) { + EXPECT_EQ(actual.name, expected.name); + EXPECT_EQ(actual.is_type_param, expected.is_type_param); + ASSERT_THAT(actual.params, SizeIs(expected.params.size())); + for (size_t i = 0; i < expected.params.size(); ++i) { + ExpectTypeInfoEqual(actual.params[i], expected.params[i]); + } +} + +TEST_P(EnvYamlParseFunctionTest, EnvYamlParseFunction) { + const ParseFunctionTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(param.yaml)); + + ASSERT_THAT(config.GetFunctionConfigs(), SizeIs(1)); + const Config::FunctionConfig& function_config = + config.GetFunctionConfigs()[0]; + const Config::FunctionConfig& expected = param.expected_function_config; + + EXPECT_EQ(function_config.name, expected.name); + EXPECT_EQ(function_config.description, expected.description); + + ASSERT_THAT(function_config.overload_configs, + SizeIs(expected.overload_configs.size())); + + for (size_t i = 0; i < expected.overload_configs.size(); ++i) { + const auto& actual_overload = function_config.overload_configs[i]; + const auto& expected_overload = expected.overload_configs[i]; + + EXPECT_EQ(actual_overload.overload_id, expected_overload.overload_id); + EXPECT_THAT(actual_overload.examples, + ElementsAreArray(expected_overload.examples)); + EXPECT_EQ(actual_overload.is_member_function, + expected_overload.is_member_function); + + ASSERT_THAT(actual_overload.parameters, + SizeIs(expected_overload.parameters.size())); + for (size_t j = 0; j < expected_overload.parameters.size(); ++j) { + ExpectTypeInfoEqual(actual_overload.parameters[j], + expected_overload.parameters[j]); + } + + ExpectTypeInfoEqual(actual_overload.return_type, + expected_overload.return_type); + } +} + +std::vector GetParseFunctionTestCases() { + return { + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "isEmpty" + description: |- + determines whether a list is empty, + or a string has no characters + overloads: + - id: "wrapper_string_isEmpty" + examples: + - "''.isEmpty() // true" + target: + type_name: "google.protobuf.StringValue" + return: + type_name: "bool" + - id: "list_isEmpty" + examples: + - "[].isEmpty() // true" + - "[1].isEmpty() // false" + target: + type_name: "list" + params: + - type_name: "T" + is_type_param: true + return: + type_name: "bool" + )yaml", + .expected_function_config = + { + .name = "isEmpty", + .description = "determines whether a list is empty,\nor a " + "string has no characters", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .overload_id = "wrapper_string_isEmpty", + .examples = {"''.isEmpty() // true"}, + .is_member_function = true, + .parameters = + {{.name = "google.protobuf.StringValue"}}, + .return_type = {.name = "bool"}, + }, + Config::FunctionOverloadConfig{ + .overload_id = "list_isEmpty", + .examples = {"[].isEmpty() // true", + "[1].isEmpty() // false"}, + .is_member_function = true, + .parameters = {{.name = "list", + .params = {{.name = "T", + .is_type_param = + true}}}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "contains" + overloads: + - id: "global_contains" + examples: + - "contains([1, 2, 3], 2) // true" + args: + - type_name: "list" + params: + - type_name: "T" + is_type_param: true + - type_name: "T" + is_type_param: true + return: + type_name: "bool" + )yaml", + .expected_function_config = + { + .name = "contains", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .overload_id = "global_contains", + .examples = {"contains([1, 2, 3], 2) // true"}, + .is_member_function = false, + .parameters = + {{.name = "list", + .params = {{.name = "T", + .is_type_param = true}}}, + {.name = "T", .is_type_param = true}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlParseFunctionTest, EnvYamlParseFunctionTest, + ::testing::ValuesIn(GetParseFunctionTestCases())); + +struct ParseTestCase { + std::string yaml; + std::string expected_error; +}; + +class EnvYamlParseTest : public testing::TestWithParam {}; + +TEST_P(EnvYamlParseTest, EnvYamlSyntaxError) { + const ParseTestCase& param = GetParam(); + absl::StatusOr config = EnvConfigFromYaml(param.yaml); + EXPECT_THAT(config, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); +} + +INSTANTIATE_TEST_SUITE_P( + EnvYamlParseTest, EnvYamlParseTest, + ::testing::Values( + ParseTestCase{ + .yaml = R"yaml( + name: + - error: "error" + )yaml", + .expected_error = "3:19: Node 'name' is not a string\n" + "| - error: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + - error: "error" + )yaml", + .expected_error = "3:19: Node 'container' is not a string\n" + "| - error: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + -name: "optional" + - name: "other" + )yaml", + .expected_error = "5:21: end of map not found\n" + "| - name: \"other\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: "bar" + )yaml", + .expected_error = "2:27: Node 'extensions' is not a sequence\n" + "| extensions: \"bar\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: + - something: "bar" + )yaml", + .expected_error = "4:19: Extension name is not a string\n" + "| - something: \"bar\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: last + )yaml", + .expected_error = "4:28: Extension 'math' version is not a valid " + "number or 'latest'\n" + "| version: last\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: -15 + )yaml", + .expected_error = "4:28: Extension 'math' version is not a valid " + "number or 'latest'\n" + "| version: -15\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: 1 + - name: "math" + version: 2 + )yaml", + .expected_error = "5:19: Extension 'math' version 1 is already " + "included. Cannot also include version 2\n" + "| - name: \"math\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: "error" + )yaml", + .expected_error = "2:23: Standard library config ('stdlib') " + "is not a map\n" + "| stdlib: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + disable: "error" + )yaml", + .expected_error = "3:26: Node 'disable' is not a boolean\n" + "| disable: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + disable_macros: "error" + )yaml", + .expected_error = "3:33: Node 'disable_macros' is not a boolean\n" + "| disable_macros: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + exclude_macros: "error" + )yaml", + .expected_error = "3:33: Node 'exclude_macros' is not a sequence\n" + "| exclude_macros: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + exclude_macros: + - foo: "error" + )yaml", + .expected_error = "4:19: Entry in 'exclude_macros' " + "is not a string\n" + "| - foo: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: "error" + )yaml", + .expected_error = "3:36: Node 'include_functions' " + "is not a sequence\n" + "| include_functions: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - "error" + )yaml", + .expected_error = "4:19: Entry in 'include_functions' " + "is not a map\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - foo: "error" + )yaml", + .expected_error = "4:19: Function name in not specified in " + "'include_functions'\n" + "| - foo: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: "error" + )yaml", + .expected_error = "5:30: Overloads in 'include_functions' entry " + "is not a sequence\n" + "| overloads: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: + - foo_string + )yaml", + .expected_error = "6:21: Overload in 'include_functions' entry " + "is not a map\n" + "| - foo_string\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: + - id: + - foo_int64 + )yaml", + .expected_error = "7:21: Overload id in 'include_functions' entry " + "is not a string\n" + "| - foo_int64\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: + - type_name: "opaque" + )yaml", + .expected_error = "4:19: Variable name is not a string\n" + "| - type_name: \"opaque\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: + - params: + )yaml", + .expected_error = "5:21: Node 'type_name' is not a string\n" + "| - params:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "opaque" + params: + - type_name: "int" + - type_name: "A" + is_type_param: maybe + )yaml", + .expected_error = "8:38: Node 'is_type_param' is not a boolean\n" + "| is_type_param: maybe\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "uint" + value: -1 + )yaml", + .expected_error = "5:26: Failed to parse uint constant\n" + "| value: -1\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: many + )yaml", + .expected_error = "2:26: Node 'functions' is not a sequence\n" + "| functions: many\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: + - overloads: + )yaml", + .expected_error = "4:19: Function name is not a string\n" + "| - overloads:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: "error" + )yaml", + .expected_error = "4:30: Function 'overloads' item " + "is not a sequence\n" + "| overloads: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: + - "error" + )yaml", + .expected_error = "6:25: Function overload id is not a string\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + target: + - "error" + )yaml", + .expected_error = "7:25: Function overload target is not a map\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + target: + type_name: "Foo" + params: + - type_name: + - is_type_param: true + )yaml", + .expected_error = "10:31: Node 'type_name' is not a string\n" + "| " + "- is_type_param: true\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + args: "a bunch" + )yaml", + .expected_error = "6:29: Function overload args is not a sequence\n" + "| args: \"a bunch\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + return: "to sender" + )yaml", + .expected_error = "6:31: Function overload return type" + " is not a map\n" + "| return: \"to sender\"\n" + "| ^", + })); + +std::string Unindent(std::string_view yaml) { + std::vector lines = absl::StrSplit(yaml, '\n'); + int indent = -1; + std::vector unindented_lines; + for (auto& line : lines) { + std::size_t pos = line.find_first_not_of(" \t"); + if (pos == std::string::npos) { + // Skip blank lines. + continue; + } + if (indent == -1) { + indent = pos; + } + if (pos >= indent) { + unindented_lines.push_back(line.substr(indent)); + } else { + unindented_lines.push_back(line); + } + } + return absl::StrJoin(unindented_lines, "\n"); +} + +struct ExportTestCase { + absl::StatusOr config; + std::string expected_yaml; +}; + +class EnvYamlExportTest : public testing::TestWithParam {}; + +TEST_P(EnvYamlExportTest, EnvYamlExport) { + const ExportTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config config, param.config); + std::stringstream ss; + EnvConfigToYaml(config, ss); + std::string yaml_output = Unindent(ss.str()); + std::string expected_yaml = Unindent(param.expected_yaml); + EXPECT_EQ(yaml_output, expected_yaml); +} + +std::vector GetExportTestCases() { + return { + ExportTestCase{ + .config = + []() { + Config config; + config.SetName("test.env"); + config.SetContainerConfig({.name = "test.container"}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: "test.container" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + config.SetName("test.env"); + config.SetContainerConfig({.name = "test.container"}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: "test.container" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("math")); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("optional", 2)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("bindings")); + return config; + }(), + .expected_yaml = R"yaml( + extensions: + - name: "bindings" + - name: "math" + - name: "optional" + version: 2 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .disable = true, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + disable: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .disable_macros = true, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + disable_macros: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .excluded_macros = {"map", "filter"}, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + exclude_macros: + - "filter" + - "map" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .included_macros = {"map", "filter"}, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + include_macros: + - "filter" + - "map" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .excluded_functions = + { + std::make_pair("timestamp", "string_to_timestamp"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("_+_", "add_bytes"), + }, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + exclude_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .included_functions = + { + std::make_pair("timestamp", "string_to_timestamp"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("_+_", "add_bytes"), + }, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + include_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "null"}, + .value = Constant(nullptr)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "null" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "bool"}, + .value = Constant(true)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "bool" + value: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "int"}, + .value = Constant(int64_t{42})})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "int" + value: 42 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "uint"}, + .value = Constant(uint64_t{777})})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "uint" + value: 777 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "double"}, + .value = Constant(0.75)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "double" + value: 0.75 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "bytes"}, + .value = Constant( + BytesConstant(absl::string_view("\xff\x00\x01", 3)))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "bytes" + value: b"\xff\x00\x01" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + Constant c; + c.set_string_value("'single' \"double\""); + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "string"}, + .value = Constant(StringConstant("'single' \"double\""))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "string" + value: "'single' \"double\"" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "duration"}, + .value = Constant(absl::Hours(1) + absl::Minutes(2) + + absl::Seconds(3))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "duration" + value: 1h2m3s + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "timestamp"}, + .value = Constant(absl::FromUnixSeconds(1767323045))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = + "google.expr.proto3.test.TestAllTypes"}})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "google.expr.proto3.test.TestAllTypes" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = { + .name = "A", + .params = {{.name = "int"}, + {.name = "B", .is_type_param = true}}}})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "A" + params: + - type_name: "int" + - type_name: "B" + is_type_param: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig({.name = "foo"})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .overload_configs = { + {.overload_id = "foo_overload_id", + .is_member_function = true, + .parameters = {{.name = "timestamp"}, + {.name = "A", .params = {{.name = "B"}}}}, + .return_type = {.name = "int"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + target: + type_name: "timestamp" + args: + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "int" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .overload_configs = { + {.overload_id = "foo_overload_a", + .parameters = {{.name = "timestamp"}}, + .return_type = {.name = "list", + .params = {{.name = "int"}}}}, + {.overload_id = "foo_overload_b", + .parameters = {{.name = "double"}, + {.name = "A", .params = {{.name = "B"}}}}, + .return_type = {.name = "string"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_b" + args: + - type_name: "double" + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "string" + - id: "foo_overload_a" + args: + - type_name: "timestamp" + return: + type_name: "list" + params: + - type_name: "int" + )yaml", + }, + }; +}; + +INSTANTIATE_TEST_SUITE_P(EnvYamlExportTest, EnvYamlExportTest, + ::testing::ValuesIn(GetExportTestCases())); + +class EnvYamlRoundTripTest : public testing::TestWithParam {}; + +TEST_P(EnvYamlRoundTripTest, EnvYamlRoundTrip) { + const std::string& yaml = Unindent(GetParam()); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(yaml)); + + std::stringstream ss; + EnvConfigToYaml(config, ss); + EXPECT_EQ(ss.str(), yaml); +} + +std::vector GetRoundTripTestCases() { + return { + R"yaml( + stdlib: + disable: true + disable_macros: true + )yaml", + R"yaml( + name: "test.env" + container: "common.proto.prefix" + extensions: + - name: "math" + version: 0 + - name: "optional" + version: 2 + stdlib: + include_macros: + - "filter" + - "map" + include_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + R"yaml( + extensions: + - name: "bindings" + - name: "math" + stdlib: + exclude_macros: + - "filter" + - "map" + exclude_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + R"yaml( + variables: + - name: "a" + type_name: "null" + - name: "b" + type_name: "bool" + value: true + - name: "c" + type_name: "int" + value: 42 + - name: "d" + type_name: "uint" + value: 777 + - name: "e" + type_name: "double" + value: 0.75 + - name: "f" + type_name: "bytes" + value: b"\xff\x00\x01" + - name: "g" + type_name: "string" + value: "plain 'single' \"double\"" + - name: "h" + type_name: "duration" + value: 1h2m3s + - name: "i" + type_name: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + R"yaml( + functions: + - name: "bar" + - name: "foo" + )yaml", + R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + target: + type_name: "timestamp" + args: + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "int" + )yaml", + R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + args: + - type_name: "timestamp" + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "list" + params: + - type_name: "int" + )yaml", + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlRoundTripTest, EnvYamlRoundTripTest, + ::testing::ValuesIn(GetRoundTripTestCases())); + +} // namespace +} // namespace cel diff --git a/env/internal/BUILD b/env/internal/BUILD new file mode 100644 index 000000000..ec4a0b15c --- /dev/null +++ b/env/internal/BUILD @@ -0,0 +1,87 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "ext_registry", + srcs = ["ext_registry.cc"], + hdrs = ["ext_registry.h"], + deps = [ + "//compiler", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "runtime_ext_registry", + srcs = ["runtime_ext_registry.cc"], + hdrs = ["runtime_ext_registry.h"], + deps = [ + "//runtime:runtime_builder", + "//runtime:runtime_options", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "ext_registry_test", + srcs = ["ext_registry_test.cc"], + deps = [ + ":ext_registry", + "//checker:type_checker_builder", + "//compiler", + "//internal:testing", + "//parser:parser_interface", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) + +cc_test( + name = "runtime_ext_registry_test", + srcs = ["runtime_ext_registry_test.cc"], + deps = [ + ":runtime_ext_registry", + "//common:ast", + "//common:source", + "//common:value", + "//common:value_testing", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//parser:options", + "//parser:parser_interface", + "//runtime", + "//runtime:activation", + "//runtime:function", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_builder", + "//runtime:runtime_builder_factory", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/env/internal/ext_registry.cc b/env/internal/ext_registry.cc new file mode 100644 index 000000000..b32239ac3 --- /dev/null +++ b/env/internal/ext_registry.cc @@ -0,0 +1,63 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/ext_registry.h" + +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" + +namespace cel { +namespace env_internal { + +void ExtensionRegistry::RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) { + library_registry_.push_back( + LibraryRegistration(name, alias, version, std::move(library_factory))); +} + +absl::StatusOr ExtensionRegistry::GetCompilerLibrary( + absl::string_view name, int version) const { + if (version == kLatest) { + int max_version = -1; + for (const auto& registration : library_registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ > max_version) { + max_version = registration.version_; + } + } + if (max_version == -1) { + return absl::NotFoundError( + absl::StrCat("CompilerLibrary not registered: ", name)); + } + version = max_version; + } + for (const auto& registration : library_registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ == version) { + return registration.GetLibrary(); + } + } + + return absl::NotFoundError( + absl::StrCat("CompilerLibrary not registered: ", name, "#", version)); +} +} // namespace env_internal +} // namespace cel diff --git a/env/internal/ext_registry.h b/env/internal/ext_registry.h new file mode 100644 index 000000000..ab5b67a24 --- /dev/null +++ b/env/internal/ext_registry.h @@ -0,0 +1,74 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" + +namespace cel { +namespace env_internal { + +// A registry for CEL compiler extension libraries. +// +// Used to register and retrieve CompilerLibraries by name (or alias) and +// version. +class ExtensionRegistry { + public: + static constexpr int kLatest = std::numeric_limits::max(); + + void RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory); + + absl::StatusOr GetCompilerLibrary(absl::string_view name, + int version) const; + + private: + class LibraryRegistration final { + public: + LibraryRegistration( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) + : name_(name), + alias_(!alias.empty() ? alias : name), + version_(version), + factory_(std::move(library_factory)) {} + + CompilerLibrary GetLibrary() const { return factory_(); } + + private: + std::string name_; + std::string alias_; + int version_; + absl::AnyInvocable factory_; + + friend class ExtensionRegistry; + }; + + std::vector library_registry_; +}; + +} // namespace env_internal +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ diff --git a/env/internal/ext_registry_test.cc b/env/internal/ext_registry_test.cc new file mode 100644 index 000000000..9e345c781 --- /dev/null +++ b/env/internal/ext_registry_test.cc @@ -0,0 +1,73 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/ext_registry.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" +#include "internal/testing.h" +#include "parser/parser_interface.h" + +namespace cel::env_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::testing::Field; +using ::testing::HasSubstr; + +TEST(ExtensionRegistryTest, GetCompilerLibrary) { + ExtensionRegistry registry; + registry.RegisterCompilerLibrary("foo1", "f", 1, []() { + return CompilerLibrary("foo1_1", nullptr, nullptr); + }); + registry.RegisterCompilerLibrary("foo1", "f", 2, []() { + return CompilerLibrary("foo1_2", nullptr, nullptr); + }); + registry.RegisterCompilerLibrary("foo2", "", 1, []() { + return CompilerLibrary("foo2_1", nullptr, nullptr); + }); + + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("f", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 2), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo1", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("f", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo2", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo2_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo2", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo2_1"))); + + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 3), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo1#3"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo3", 1), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo3"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo3", ExtensionRegistry::kLatest), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo3"))); +} + +} // namespace +} // namespace cel::env_internal diff --git a/env/internal/runtime_ext_registry.cc b/env/internal/runtime_ext_registry.cc new file mode 100644 index 000000000..dc78a38e3 --- /dev/null +++ b/env/internal/runtime_ext_registry.cc @@ -0,0 +1,64 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/runtime_ext_registry.h" + +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace env_internal { + +void RuntimeExtensionRegistry::AddFunctionRegistration( + absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback) { + registry_.push_back(Registration(name, alias, version, + std::move(function_registration_callback))); +} + +absl::Status RuntimeExtensionRegistry::RegisterExtensionFunctions( + RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options, + absl::string_view name, int version) const { + if (version == kLatest) { + int max_version = -1; + for (const Registration& registration : registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ > max_version) { + max_version = registration.version_; + } + } + if (max_version == -1) { + return absl::NotFoundError(absl::StrCat( + "Runtime functions are not registered for extension: ", name)); + } + version = max_version; + } + for (const Registration& registration : registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ == version) { + return registration.RegisterExtensionFunctions(runtime_builder, + runtime_options); + } + } + + return absl::NotFoundError(absl::StrCat( + "Runtime functions are not registered for extension: ", name)); +} +} // namespace env_internal +} // namespace cel diff --git a/env/internal/runtime_ext_registry.h b/env/internal/runtime_ext_registry.h new file mode 100644 index 000000000..67838519f --- /dev/null +++ b/env/internal/runtime_ext_registry.h @@ -0,0 +1,84 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace env_internal { + +using FunctionRegistrationCallback = absl::AnyInvocable; + +// A registry for CEL runtime extension functions. +// +// Used to register runtime functions for extensions by name (or alias) and +// version. +class RuntimeExtensionRegistry { + public: + static constexpr int kLatest = std::numeric_limits::max(); + + void AddFunctionRegistration( + absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback); + + absl::Status RegisterExtensionFunctions(RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options, + absl::string_view name, + int version) const; + + private: + class Registration final { + public: + Registration(absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback) + : name_(name), + alias_(!alias.empty() ? alias : name), + version_(version), + function_registration_callback_( + std::move(function_registration_callback)) {} + + absl::Status RegisterExtensionFunctions( + RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) const { + return function_registration_callback_(runtime_builder, runtime_options); + } + + private: + std::string name_; + std::string alias_; + int version_; + FunctionRegistrationCallback function_registration_callback_; + + friend class RuntimeExtensionRegistry; + }; + + std::vector registry_; +}; + +} // namespace env_internal +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ diff --git a/env/internal/runtime_ext_registry_test.cc b/env/internal/runtime_ext_registry_test.cc new file mode 100644 index 000000000..c6125d20f --- /dev/null +++ b/env/internal/runtime_ext_registry_test.cc @@ -0,0 +1,126 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/runtime_ext_registry.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/source.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "parser/parser_interface.h" +#include "runtime/activation.h" +#include "runtime/function.h" +#include "runtime/function_adapter.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_builder_factory.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace cel::env_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::cel::test::StringValueIs; + +Value Hello1(const StringValue& input, const Function::InvokeContext& context) { + return StringValue::From("Hello, old " + input.ToString() + "!", + context.arena()); +} + +Value Hello2(const StringValue& input, const Function::InvokeContext& context) { + return StringValue::From("Hello, new " + input.ToString() + "!", + context.arena()); +} + +RuntimeExtensionRegistry GetRuntimeExtensionRegistry() { + RuntimeExtensionRegistry registry; + registry.AddFunctionRegistration( + "hello_extension", "hello_extension_alias", 1, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::UnaryFunctionAdapter:: + RegisterGlobalOverload("hello", &Hello1, + runtime_builder.function_registry()); + }); + registry.AddFunctionRegistration( + "hello_extension", "hello_extension_alias", 2, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::UnaryFunctionAdapter:: + RegisterMemberOverload("hello", &Hello2, + runtime_builder.function_registry()); + }); + return registry; +} + +class RuntimeExtensionRegistryTest : public testing::Test { + protected: + absl::StatusOr Run(std::string_view extension_name, int version, + std::string_view expr) { + const RuntimeExtensionRegistry registry = GetRuntimeExtensionRegistry(); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr parser, + NewParserBuilder(ParserOptions())->Build()); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr source, NewSource(expr, "")); + CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, parser->Parse(*source)); + + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + cel::RuntimeOptions runtime_options; + CEL_ASSIGN_OR_RETURN( + cel::RuntimeBuilder runtime_builder, + cel::CreateRuntimeBuilder(descriptor_pool, runtime_options)); + + CEL_RETURN_IF_ERROR(registry.RegisterExtensionFunctions( + runtime_builder, runtime_options, extension_name, version)); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(runtime_builder).Build()); + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + Activation activation; + return program->Evaluate(&arena_, activation); + } + + private: + google::protobuf::Arena arena_; +}; + +TEST_F(RuntimeExtensionRegistryTest, SpecificExtensionVersion) { + EXPECT_THAT(Run("hello_extension", 1, "hello('world')"), + IsOkAndHolds(StringValueIs("Hello, old world!"))); +} + +TEST_F(RuntimeExtensionRegistryTest, LatestExtensionVersion) { + EXPECT_THAT(Run("hello_extension_alias", RuntimeExtensionRegistry::kLatest, + "'world'.hello()"), + IsOkAndHolds(StringValueIs("Hello, new world!"))); +} + +} // namespace +} // namespace cel::env_internal diff --git a/env/runtime_std_extensions.cc b/env/runtime_std_extensions.cc new file mode 100644 index 000000000..167e3b104 --- /dev/null +++ b/env/runtime_std_extensions.cc @@ -0,0 +1,130 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/runtime_std_extensions.h" + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "checker/optional.h" +#include "env/env_runtime.h" +#include "env/internal/runtime_ext_registry.h" +#include "extensions/encoders.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext.h" +#include "extensions/math_ext_decls.h" +#include "extensions/regex_ext.h" +#include "extensions/sets_functions.h" +#include "extensions/strings.h" +#include "runtime/optional_types.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { + +void RegisterStandardExtensions(EnvRuntime& env_runtime) { + env_internal::RuntimeExtensionRegistry& registry = + env_runtime.GetRuntimeExtensionRegistry(); + registry.AddFunctionRegistration( + "cel.lib.ext.bindings", "bindings", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. + return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.encoders", "encoders", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterEncodersFunctions( + runtime_builder.function_registry(), runtime_options); + }); + + for (int version = 0; version <= extensions::kListsExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.lists", "lists", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterListsFunctions( + runtime_builder.function_registry(), runtime_options, version); + }); + } + + for (int version = 0; version <= extensions::kMathExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.math", "math", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterMathExtensionFunctions( + runtime_builder.function_registry(), runtime_options, version); + }); + } + + for (int version = 0; version <= cel::kOptionalExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.optional", "optional", version, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::EnableOptionalTypes(runtime_builder); + }); + } + + registry.AddFunctionRegistration( + "cel.lib.ext.protos", "protos", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. + return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.sets", "sets", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterSetsFunctions( + runtime_builder.function_registry(), runtime_options); + }); + + for (int version = 0; version <= extensions::kStringsExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.strings", "strings", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterStringsFunctions( + runtime_builder.function_registry(), runtime_options, version); + }); + } + + registry.AddFunctionRegistration( + "cel.lib.ext.comprev2", "two-var-comprehensions", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. + return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.regex", "regex", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterRegexExtensionFunctions( + runtime_builder); + }); +} + +} // namespace cel diff --git a/env/runtime_std_extensions.h b/env/runtime_std_extensions.h new file mode 100644 index 000000000..d7f714226 --- /dev/null +++ b/env/runtime_std_extensions.h @@ -0,0 +1,46 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ + +#include "env/env_runtime.h" + +namespace cel { + +// Registers the standard CEL extension functions with the given environment +// runtime. This makes them available, but does not enable them. See Env::Config +// for how to enable extensions. +// +// Included in the standard runtime environment: +// +// - cel.lib.ext.bindings (alias: "bindings") +// - cel.lib.ext.encoders (alias: "encoders") +// - cel.lib.ext.lists (alias: "lists") +// - cel.lib.ext.math (alias: "math") +// - optional +// - cel.lib.ext.protos (alias: "protos") +// - cel.lib.ext.sets (alias: "sets") +// - cel.lib.ext.strings (alias: "strings") +// - cel.lib.ext.comprev2 (alias: "two-var-comprehensions") +// +// NOTE: Not included in the standard runtime environment yet - include manually +// if needed: +// - cel.lib.ext.regex (alias: "regex") +// +void RegisterStandardExtensions(EnvRuntime& env_runtime); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_RUNTIME_STD_EXTENSIONS_H_ diff --git a/env/runtime_std_extensions_test.cc b/env/runtime_std_extensions_test.cc new file mode 100644 index 000000000..4c7cb9829 --- /dev/null +++ b/env/runtime_std_extensions_test.cc @@ -0,0 +1,229 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/runtime_std_extensions.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "checker/optional.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_runtime.h" +#include "env/env_std_extensions.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext_decls.h" +#include "extensions/strings.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::IsEmpty; +using ::testing::ValuesIn; + +struct TestCase { + std::string extension_name; + std::vector extension_versions = {0}; + int latest_extension_version = 0; + std::string expr; + bool requires_optional_extension = false; +}; + +using RuntimeStdExtensionTest = testing::TestWithParam; + +TEST_P(RuntimeStdExtensionTest, RegisterStandardExtensions) { + const TestCase& param = GetParam(); + Env env; + env.SetDescriptorPool(cel::internal::GetSharedTestingDescriptorPool()); + RegisterStandardExtensions(env); + + Config compiler_config; + // For the compilation step, assume latest version of the extension to ensure + // a successful compilation. Later, we will test the runtime with different + // extension versions. + ASSERT_THAT(compiler_config.AddExtensionConfig( + param.extension_name, Config::ExtensionConfig::kLatest), + IsOk()); + env.SetConfig(compiler_config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(param.expr)); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + for (int version = 0; version <= param.latest_extension_version; ++version) { + Config runtime_config; + // Request a specific version of the extension to be configured in the + // runtime. + ASSERT_THAT( + runtime_config.AddExtensionConfig(param.extension_name, version), + IsOk()); + if (param.requires_optional_extension) { + ASSERT_THAT(runtime_config.AddExtensionConfig("optional"), IsOk()); + } + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool( + cel::internal::GetSharedTestingDescriptorPool()); + RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(runtime_config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + absl::StatusOr> program_or = + runtime->CreateProgram(std::make_unique(*ast)); + + // If the function is not supported in this extension version, check that + // the program creation returned an error. + if (!absl::c_contains(param.extension_versions, version)) { + EXPECT_THAT(program_or, StatusIs(absl::StatusCode::kInvalidArgument)) + << " expr: " << param.expr << " version: " << version; + continue; + } + + ASSERT_THAT(program_or, IsOk()) + << " expr: " << param.expr << " version: " << version; + std::unique_ptr program = *std::move(program_or); + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()) + << " expr: " << param.expr << " version: " << version; + } +} + +std::vector GetRuntimeStdExtensionTestCases() { + return { + TestCase{ + // The "bindings" extension does not register any runtime functions - + // only macros. + .extension_name = "bindings", + .expr = "cel.bind(t, 42, t + 1) == 43", + }, + TestCase{ + .extension_name = "encoders", + .expr = "base64.encode(b'hello') == 'aGVsbG8='", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {0, 1, 2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[3, 2, 1].slice(0, 1) == [3]", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {1, 2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[[1, 2], 3].flatten() == [1, 2, 3]", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[3, 2, 1].sort() == [1, 2, 3]", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {0, 1, 2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.least([1, -2, 3]) == -2", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {1, 2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.floor(42.9) == 42.0", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.sqrt(4) == 2.0", + }, + TestCase{ + .extension_name = "optional", + .extension_versions = {0, 1, 2}, + .latest_extension_version = kOptionalExtensionLatestVersion, + .expr = "optional.of(1).hasValue()", + }, + TestCase{ + // No runtime functions. + .extension_name = "protos", + .expr = "!proto.hasExt(cel.expr.conformance.proto2.TestAllTypes{}, " + "cel.expr.conformance.proto2.nested_ext)", + }, + TestCase{ + .extension_name = "sets", + .expr = "sets.contains([1], [1])", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {0, 1, 2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "'Hello, who!'.replace('who', 'World') == 'Hello, World!'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {1, 2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "strings.quote('hello') == '\"hello\"'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "['hello', 'world'].join(', ') == 'hello, world'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "'stressed'.reverse() == 'desserts'", + }, + TestCase{ + // No runtime functions. + .extension_name = "cel.lib.ext.comprev2", + .expr = "[1, 2, 3].map(i, i * 2) == [2, 4, 6]", + }, + TestCase{ + .extension_name = "cel.lib.ext.regex", + .expr = "regex.replace('abc', '$', '_end') == 'abc_end'", + .requires_optional_extension = true, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(RuntimeStdExtensionTest, RuntimeStdExtensionTest, + ValuesIn(GetRuntimeStdExtensionTestCases())); + +} // namespace +} // namespace cel From f4d72d97c79f45c323cbe1036cfc067953bc070a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 24 Mar 2026 19:36:40 -0700 Subject: [PATCH 011/143] Add planner support for checking runtime extensions in the AST. Introduces check_ast_extensions to extract and validate runtime-affecting extensions from the SourceInfo. Currently, flat_expr_builder returns an error if any runtime extensions are present, as support is not yet implemented. Fixes ExtensionSpec copy constructor and assignment to correctly handle null version_. PiperOrigin-RevId: 888974129 --- common/ast/metadata.cc | 10 +- common/ast/metadata_test.cc | 31 ++++++ eval/compiler/BUILD | 29 +++++- eval/compiler/check_ast_extensions.cc | 58 +++++++++++ eval/compiler/check_ast_extensions.h | 34 +++++++ eval/compiler/check_ast_extensions_test.cc | 110 +++++++++++++++++++++ eval/compiler/flat_expr_builder.cc | 32 ++++++ eval/compiler/flat_expr_builder_test.cc | 42 +++++--- 8 files changed, 328 insertions(+), 18 deletions(-) create mode 100644 eval/compiler/check_ast_extensions.cc create mode 100644 eval/compiler/check_ast_extensions.h create mode 100644 eval/compiler/check_ast_extensions_test.cc diff --git a/common/ast/metadata.cc b/common/ast/metadata.cc index f744deb00..eecb0dbb3 100644 --- a/common/ast/metadata.cc +++ b/common/ast/metadata.cc @@ -61,12 +61,18 @@ const ExtensionSpec& ExtensionSpec::DefaultInstance() { ExtensionSpec::ExtensionSpec(const ExtensionSpec& other) : id_(other.id_), affected_components_(other.affected_components_), - version_(std::make_unique(*other.version_)) {} + version_(other.version_ == nullptr + ? nullptr + : std::make_unique(*other.version_)) {} ExtensionSpec& ExtensionSpec::operator=(const ExtensionSpec& other) { id_ = other.id_; affected_components_ = other.affected_components_; - version_ = std::make_unique(*other.version_); + if (other.version_ != nullptr) { + version_ = std::make_unique(other.version()); + } else { + version_ = nullptr; + } return *this; } diff --git a/common/ast/metadata_test.cc b/common/ast/metadata_test.cc index 4afb0d07d..5553f4c8f 100644 --- a/common/ast/metadata_test.cc +++ b/common/ast/metadata_test.cc @@ -25,6 +25,8 @@ namespace cel { namespace { +using ::testing::ElementsAre; + TEST(AstTest, ListTypeSpecMutableConstruction) { ListTypeSpec type; type.mutable_elem_type() = TypeSpec(PrimitiveType::kBool); @@ -264,5 +266,34 @@ TEST(AstTest, ExtensionSpecEquality) { std::make_unique(0, 0), {})); } +TEST(AstTest, ExtensionCopyMove) { + ExtensionSpec a("constant_folding", nullptr, {}); + a.mutable_version().set_major(1); + a.mutable_version().set_minor(2); + a.mutable_affected_components().push_back(ExtensionSpec::Component::kRuntime); + + ExtensionSpec b(a); + + EXPECT_EQ(b.id(), "constant_folding"); + EXPECT_EQ(b.version().major(), 1); + EXPECT_EQ(b.version().minor(), 2); + EXPECT_THAT(b.affected_components(), + ElementsAre(ExtensionSpec::Component::kRuntime)); + + ExtensionSpec c(std::move(b)); + EXPECT_EQ(c, a); + + a.set_version(nullptr); + b = a; + EXPECT_EQ(b.id(), "constant_folding"); + EXPECT_EQ(b.version().major(), 0); + EXPECT_EQ(b.version().minor(), 0); + EXPECT_THAT(b.affected_components(), + ElementsAre(ExtensionSpec::Component::kRuntime)); + + c = std::move(b); + EXPECT_EQ(c, a); +} + } // namespace } // namespace cel diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 62b208772..ed8e4d20c 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -33,7 +33,6 @@ cc_library( "//base:data", "//common:expr", "//common:native_type", - "//common:navigable_ast", "//common:value", "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", @@ -96,6 +95,7 @@ cc_library( "flat_expr_builder.h", ], deps = [ + ":check_ast_extensions", ":flat_expr_builder_extensions", ":resolver", "//base:ast", @@ -413,6 +413,33 @@ cc_library( ], ) +cc_library( + name = "check_ast_extensions", + srcs = ["check_ast_extensions.cc"], + hdrs = ["check_ast_extensions.h"], + deps = [ + "//common:ast", + "//common/ast:metadata", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "check_ast_extensions_test", + srcs = ["check_ast_extensions_test.cc"], + deps = [ + ":check_ast_extensions", + "//common:ast", + "//common:expr", + "//common/ast:metadata", + "//internal:testing", + "@com_google_absl//absl/status", + ], +) + cc_library( name = "resolver", srcs = ["resolver.cc"], diff --git a/eval/compiler/check_ast_extensions.cc b/eval/compiler/check_ast_extensions.cc new file mode 100644 index 000000000..37181b535 --- /dev/null +++ b/eval/compiler/check_ast_extensions.cc @@ -0,0 +1,58 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/check_ast_extensions.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/ast/metadata.h" + +namespace google::api::expr::runtime { + +absl::StatusOr> +ExtractAndValidateRuntimeExtensions(const cel::Ast& ast) { + std::vector runtime_extensions; + absl::flat_hash_set seen_extension_ids; + + for (const cel::ExtensionSpec& extension : ast.source_info().extensions()) { + bool is_runtime = false; + for (const cel::ExtensionSpec::Component& component : + extension.affected_components()) { + if (component == cel::ExtensionSpec::Component::kRuntime) { + is_runtime = true; + break; + } + } + + if (!is_runtime) { + continue; + } + + if (!seen_extension_ids.insert(extension.id()).second) { + return absl::InvalidArgumentError( + absl::StrCat("duplicate extension ID: ", extension.id())); + } + runtime_extensions.push_back(extension); + } + + return runtime_extensions; +} + +} // namespace google::api::expr::runtime diff --git a/eval/compiler/check_ast_extensions.h b/eval/compiler/check_ast_extensions.h new file mode 100644 index 000000000..443c6ac09 --- /dev/null +++ b/eval/compiler/check_ast_extensions.h @@ -0,0 +1,34 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ + +#include + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/ast/metadata.h" + +namespace google::api::expr::runtime { + +// Extracts and validates extension tags from the AST `ast` that affect the +// runtime component. Returns the validated list of runtime extensions, or an +// error if there are multiple runtime extensions with the same ID. +absl::StatusOr> +ExtractAndValidateRuntimeExtensions(const cel::Ast& ast); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ diff --git a/eval/compiler/check_ast_extensions_test.cc b/eval/compiler/check_ast_extensions_test.cc new file mode 100644 index 000000000..9e5838905 --- /dev/null +++ b/eval/compiler/check_ast_extensions_test.cc @@ -0,0 +1,110 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/compiler/check_ast_extensions.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "common/ast.h" +#include "common/ast/metadata.h" +#include "common/expr.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::Ast; +using ::cel::Expr; +using ::cel::ExtensionSpec; +using ::cel::SourceInfo; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Property; +using ::testing::SizeIs; + +TEST(ExtractAndValidateRuntimeExtensionsTest, EmptyExtensions) { + Ast ast(Expr{}, SourceInfo{}); + EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast), + IsOkAndHolds(SizeIs(0))); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, FiltersNonRuntimeExtensions) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kParser})); + source_info.mutable_extensions().push_back( + ExtensionSpec("ext2", nullptr, {ExtensionSpec::Component::kTypeChecker})); + + Ast ast(Expr(), std::move(source_info)); + + EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast), + IsOkAndHolds(SizeIs(0))); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, ExtractsRuntimeExtensions) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back(ExtensionSpec( + "ext2", nullptr, + {ExtensionSpec::Component::kParser, ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back( + ExtensionSpec("ext3", nullptr, {ExtensionSpec::Component::kParser})); + + Ast ast(Expr(), std::move(source_info)); + + auto result = ExtractAndValidateRuntimeExtensions(ast); + ASSERT_THAT(result, IsOk()); + EXPECT_THAT(*result, ElementsAre(Property(&ExtensionSpec::id, Eq("ext1")), + Property(&ExtensionSpec::id, Eq("ext2")))); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, FailsOnDuplicateRuntimeID) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back(ExtensionSpec( + "ext1", nullptr, + {ExtensionSpec::Component::kParser, ExtensionSpec::Component::kRuntime})); + + Ast ast(Expr(), std::move(source_info)); + + EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast), + StatusIs(absl::StatusCode::kInvalidArgument, + "duplicate extension ID: ext1")); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, IgnoresDuplicateNonRuntimeID) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kParser})); + + Ast ast(Expr(), std::move(source_info)); + + auto result = ExtractAndValidateRuntimeExtensions(ast); + ASSERT_THAT(result, IsOk()); + EXPECT_THAT(*result, ElementsAre(Property(&ExtensionSpec::id, Eq("ext1")))); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 91822092c..e38c912c0 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -60,6 +60,7 @@ #include "common/kind.h" #include "common/type.h" #include "common/value.h" +#include "eval/compiler/check_ast_extensions.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" #include "eval/eval/comprehension_step.h" @@ -2512,6 +2513,22 @@ std::vector FlattenExpressionTable( return subexpression_indexes; } +absl::Status CheckAstExtensions( + const std::vector& extensions) { + for (const cel::ExtensionSpec& extension : extensions) { + if (extension.id() == "cel_block" && extension.version().major() == 1) { + // cel_block v1 is always supported. + continue; + } + + // TODO(uncreated-issue/89): Add support for json field names. + return absl::InvalidArgumentError(absl::StrCat( + "unsupported CEL extension: ", extension.id(), "@", + extension.version().major(), ".", extension.version().minor())); + } + return absl::OkStatus(); +} + } // namespace absl::StatusOr FlatExprBuilder::CreateExpressionImpl( @@ -2525,6 +2542,21 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( ? RuntimeIssue::Severity::kWarning : RuntimeIssue::Severity::kError; IssueCollector issue_collector(max_severity); + + absl::StatusOr> runtime_extensions = + ExtractAndValidateRuntimeExtensions(*ast); + + if (!runtime_extensions.ok()) { + CEL_RETURN_IF_ERROR(issue_collector.AddIssue( + RuntimeIssue::CreateError(runtime_extensions.status()))); + } + + auto status = CheckAstExtensions(*runtime_extensions); + if (!status.ok()) { + CEL_RETURN_IF_ERROR( + issue_collector.AddIssue(RuntimeIssue::CreateError(status))); + } + Resolver resolver(container_, function_registry_, type_registry_, GetTypeProvider(), options_.enable_qualified_type_identifiers); diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 2b705398a..5fc20f01e 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -1,18 +1,16 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "eval/compiler/flat_expr_builder.h" @@ -187,6 +185,20 @@ TEST(FlatExprBuilderTest, ExprUnset) { HasSubstr("Invalid empty expression"))); } +TEST(FlatExprBuilderTest, RuntimeExtensionsError) { + Expr expr; + SourceInfo source_info; + auto* ext = source_info.add_extensions(); + ext->set_id("ext1"); + ext->add_affected_components( + cel::expr::SourceInfo_Extension_Component_COMPONENT_RUNTIME); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("unsupported CEL extension: ext1"))); +} + TEST(FlatExprBuilderTest, ConstValueUnset) { Expr expr; SourceInfo source_info; From f6cd0c895e42954475b3734c867e8902557d1b37 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Wed, 25 Mar 2026 10:49:56 -0700 Subject: [PATCH 012/143] Fix YAML syntax error reporting on non-map YAML (e.g. "hello") PiperOrigin-RevId: 889328156 --- env/env_yaml.cc | 9 +++++++++ env/env_yaml_test.cc | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 0035709e9..5e7c9631d 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -998,6 +998,15 @@ void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out) { absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { Config config; CEL_ASSIGN_OR_RETURN(YAML::Node root, LoadYaml(yaml)); + if (!root.IsDefined() || root.IsNull()) { + return config; + } + + if (!root.IsMap()) { + return absl::InvalidArgumentError(FormatYamlErrorMessage( + yaml, "Invalid CEL environment config YAML", root.Mark())); + } + CEL_RETURN_IF_ERROR(ParseName(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseContainerConfig(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseExtensionConfigs(config, yaml, root)); diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index 828a39b48..b34c25254 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -528,6 +528,12 @@ TEST_P(EnvYamlParseTest, EnvYamlSyntaxError) { INSTANTIATE_TEST_SUITE_P( EnvYamlParseTest, EnvYamlParseTest, ::testing::Values( + ParseTestCase{ + .yaml = R"yaml( invalid yaml )yaml", + .expected_error = "1:2: Invalid CEL environment config YAML\n" + "| invalid yaml \n" + "| ^", + }, ParseTestCase{ .yaml = R"yaml( name: From 89d4f5fae45941cfb6a258256ead905a1d443414 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 27 Mar 2026 14:45:46 -0700 Subject: [PATCH 013/143] Update the CelExpressionBuilder documentation, relaxing the input parameter lifetime requirement PiperOrigin-RevId: 890645482 --- eval/public/cel_expression.h | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index 3f52ad60d..4cf029e89 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -80,10 +80,10 @@ class CelExpressionBuilder { virtual ~CelExpressionBuilder() = default; // Creates CelExpression object from AST tree. - // expr specifies root of AST tree - // - // IMPORTANT: The `expr` and `source_info` must outlive the resulting - // CelExpression. + // expr specifies root of AST tree. + // Method implementation is expected to create copies of expr and source_info, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::Expr* expr, const cel::expr::SourceInfo* source_info) const = 0; @@ -91,9 +91,9 @@ class CelExpressionBuilder { // Creates CelExpression object from AST tree. // expr specifies root of AST tree. // non-fatal build warnings are written to warnings if encountered. - // - // IMPORTANT: The `expr` and `source_info` must outlive the resulting - // CelExpression. + // Method implementation is expected to create copies of expr and source_info, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::Expr* expr, const cel::expr::SourceInfo* source_info, @@ -101,8 +101,9 @@ class CelExpressionBuilder { // Creates CelExpression object from a checked expression. // This includes an AST, source info, type hints and ident hints. - // - // IMPORTANT: The `checked_expr` must outlive the resulting CelExpression. + // Method implementation is expected to create copy of checked_expr, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::CheckedExpr* checked_expr) const { // Default implementation just passes through the expr and source info. @@ -113,8 +114,9 @@ class CelExpressionBuilder { // Creates CelExpression object from a checked expression. // This includes an AST, source info, type hints and ident hints. // non-fatal build warnings are written to warnings if encountered. - // - // IMPORTANT: The `checked_expr` must outlive the resulting CelExpression. + // Method implementation is expected to create copy of checked_expr, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::CheckedExpr* checked_expr, std::vector* warnings) const { From 4015e768b83cc2928a73e78a2bcd10c3894d0668 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 27 Mar 2026 14:56:12 -0700 Subject: [PATCH 014/143] Add option for format precision limit. High values for precision can lead to very large strings (exponential time / memory cost w.r.t format specifier), but are technically allowed. Will lower the default value in a follow-up. PiperOrigin-RevId: 890649934 --- env/runtime_std_extensions.cc | 5 ++- extensions/BUILD | 1 + extensions/formatting.cc | 60 ++++++++++++++++++++--------------- extensions/formatting.h | 13 ++++++-- extensions/formatting_test.cc | 26 +++++++++++++++ extensions/strings.cc | 19 +++++++---- extensions/strings.h | 34 +++++++++++++++++--- extensions/strings_test.cc | 44 +++++++++++++++++++++++++ 8 files changed, 161 insertions(+), 41 deletions(-) diff --git a/env/runtime_std_extensions.cc b/env/runtime_std_extensions.cc index 167e3b104..b866a5965 100644 --- a/env/runtime_std_extensions.cc +++ b/env/runtime_std_extensions.cc @@ -105,8 +105,11 @@ void RegisterStandardExtensions(EnvRuntime& env_runtime) { "cel.lib.ext.strings", "strings", version, [version](RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) -> absl::Status { + cel::extensions::StringsExtensionOptions strings_options; + strings_options.version = version; return cel::extensions::RegisterStringsFunctions( - runtime_builder.function_registry(), runtime_options, version); + runtime_builder.function_registry(), runtime_options, + strings_options); }); } diff --git a/extensions/BUILD b/extensions/BUILD index fe97af46a..c393ec13a 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -609,6 +609,7 @@ cc_test( "//checker:type_check_issue", "//checker:type_checker_builder", "//checker:validation_result", + "//common:ast", "//common:decl", "//common:type", "//common:value", diff --git a/extensions/formatting.cc b/extensions/formatting.cc index 6e58a7b86..935815569 100644 --- a/extensions/formatting.cc +++ b/extensions/formatting.cc @@ -14,20 +14,19 @@ #include "extensions/formatting.h" +#include #include #include #include #include #include #include -#include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/btree_map.h" -#include "absl/memory/memory.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -64,7 +63,7 @@ absl::StatusOr FormatString( std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); absl::StatusOr>> ParsePrecision( - absl::string_view format) { + absl::string_view format, int max_precision) { if (format.empty() || format[0] != '.') return std::pair{0, std::nullopt}; int64_t i = 1; @@ -80,9 +79,9 @@ absl::StatusOr>> ParsePrecision( return absl::InvalidArgumentError( "unable to convert precision specifier to integer"); } - if (precision > kMaxPrecision) { + if (precision > max_precision) { return absl::InvalidArgumentError( - absl::StrCat("precision specifier exceeds maximum of ", kMaxPrecision)); + absl::StrCat("precision specifier exceeds maximum of ", max_precision)); } return std::pair{i, precision}; } @@ -444,12 +443,13 @@ absl::StatusOr FormatScientific( } absl::StatusOr> ParseAndFormatClause( - absl::string_view format, const Value& value, + absl::string_view format, const Value& value, int max_precision, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { - CEL_ASSIGN_OR_RETURN(auto precision_pair, ParsePrecision(format)); + CEL_ASSIGN_OR_RETURN(auto precision_pair, + ParsePrecision(format, max_precision)); auto [read, precision] = precision_pair; switch (format[read]) { case 's': { @@ -494,7 +494,7 @@ absl::StatusOr> ParseAndFormatClause( } absl::StatusOr Format( - const StringValue& format_value, const ListValue& args, + const StringValue& format_value, const ListValue& args, int max_precision, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { @@ -512,43 +512,51 @@ absl::StatusOr Format( } ++i; if (i >= format.size()) { - return absl::InvalidArgumentError("unexpected end of format string"); + return ErrorValue( + absl::InvalidArgumentError("unexpected end of format string")); } if (format[i] == '%') { result.push_back('%'); continue; } if (arg_index >= args_size) { - return absl::InvalidArgumentError( - absl::StrFormat("index %d out of range", arg_index)); + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("index %d out of range", arg_index))); } CEL_ASSIGN_OR_RETURN(auto value, args.Get(arg_index++, descriptor_pool, message_factory, arena)); - CEL_ASSIGN_OR_RETURN( - auto clause, - ParseAndFormatClause(format.substr(i), value, descriptor_pool, - message_factory, arena, clause_scratch)); - absl::StrAppend(&result, clause.second); - i += clause.first; + + auto clause = ParseAndFormatClause(format.substr(i), value, max_precision, + descriptor_pool, message_factory, arena, + clause_scratch); + if (!clause.ok()) { + return ErrorValue(std::move(clause).status()); + } + absl::StrAppend(&result, clause->second); + i += clause->first; } - return StringValue(arena, std::move(result)); + return StringValue::From(std::move(result), arena); } } // namespace -absl::Status RegisterStringFormattingFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { +absl::Status RegisterStringFormattingFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + StringsExtensionFormatOptions format_options) { + const int max_precision = + std::clamp(format_options.max_precision, 0, kMaxPrecision); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, ListValue>:: CreateDescriptor("format", /*receiver_style=*/true), BinaryFunctionAdapter, StringValue, ListValue>:: WrapFunction( - [](const StringValue& format, const ListValue& args, - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, - google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) { - return Format(format, args, descriptor_pool, message_factory, - arena); + [max_precision]( + const StringValue& format, const ListValue& args, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return Format(format, args, max_precision, descriptor_pool, + message_factory, arena); }))); return absl::OkStatus(); } diff --git a/extensions/formatting.h b/extensions/formatting.h index bc2002006..88954857b 100644 --- a/extensions/formatting.h +++ b/extensions/formatting.h @@ -21,9 +21,18 @@ namespace cel::extensions { +struct StringsExtensionFormatOptions { + // The maximum precision to permit for formatting floating-point numbers. + int max_precision = 1000; +}; + // Register extension functions for string formatting. -absl::Status RegisterStringFormattingFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); +// +// This implements (string).format([args...]) in the strings extension. Most +// users should add these functions via `extensions/strings.h` instead. +absl::Status RegisterStringFormattingFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + StringsExtensionFormatOptions format_options = {}); } // namespace cel::extensions diff --git a/extensions/formatting_test.cc b/extensions/formatting_test.cc index 824f14e45..b80fe9bc0 100644 --- a/extensions/formatting_test.cc +++ b/extensions/formatting_test.cc @@ -96,6 +96,32 @@ TEST_P(StringFormatLimitsTest, FormatLimits) { } } +TEST(StringFormatLimitsTest, MaxPrecisionOption) { + google::protobuf::Arena arena; + const RuntimeOptions options; + StringsExtensionFormatOptions format_options; + format_options.max_precision = 99; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT(RegisterStringFormattingFunctions(builder.function_registry(), + options, format_options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("'%.100f'.format([1.123])", + "", ParserOptions{})); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + ASSERT_TRUE(value.Is()); + EXPECT_THAT(value.GetError().ToStatus().message(), + HasSubstr("precision specifier exceeds maximum of 99")); +} + INSTANTIATE_TEST_SUITE_P(StringFormatLimitsTest, StringFormatLimitsTest, ValuesIn({ "double('%.326f'.format([x])) == x", diff --git a/extensions/strings.cc b/extensions/strings.cc index ed6f27319..54fda20d6 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -305,9 +305,10 @@ absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder, int version) { } // namespace -absl::Status RegisterStringsFunctions(FunctionRegistry& registry, - const RuntimeOptions& options, - int version) { +absl::Status RegisterStringsFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + const StringsExtensionOptions& extension_options) { + const int version = extension_options.version; CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, StringValue>:: CreateDescriptor("split", /*receiver_style=*/true), @@ -382,7 +383,8 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, return absl::OkStatus(); } - CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(registry, options)); + CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions( + registry, options, {extension_options.max_precision})); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "strings.quote", &Quote, registry))); @@ -412,13 +414,16 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, absl::Status RegisterStringsFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, - const google::api::expr::runtime::InterpreterOptions& options) { + const google::api::expr::runtime::InterpreterOptions& options, + const StringsExtensionOptions& extension_options) { return RegisterStringsFunctions( registry->InternalGetRegistry(), - google::api::expr::runtime::ConvertToRuntimeOptions(options)); + google::api::expr::runtime::ConvertToRuntimeOptions(options), + extension_options); } -CheckerLibrary StringsCheckerLibrary(int version) { +CheckerLibrary StringsCheckerLibrary(const StringsExtensionOptions& options) { + const int version = options.version; return {"strings", [version](TypeCheckerBuilder& builder) { return RegisterStringsDecls(builder, version); }}; diff --git a/extensions/strings.h b/extensions/strings.h index 3cbc9f19f..3ec92d603 100644 --- a/extensions/strings.h +++ b/extensions/strings.h @@ -27,21 +27,45 @@ namespace cel::extensions { constexpr int kStringsExtensionLatestVersion = 4; +struct StringsExtensionOptions { + int version = kStringsExtensionLatestVersion; + + // Maximum precision allowed for floating point format specifiers in + // format() function. This is used for both fixed and scientific notations. + // Value must be in the range [0, 1000], otherwise clamped. + // + // Does not affect default precisions for %e and %f format specifiers. + int max_precision = 1000; +}; + // Register extension functions for strings. absl::Status RegisterStringsFunctions( FunctionRegistry& registry, const RuntimeOptions& options, - int version = kStringsExtensionLatestVersion); + const StringsExtensionOptions& extension_options = {}); absl::Status RegisterStringsFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, - const google::api::expr::runtime::InterpreterOptions& options); + const google::api::expr::runtime::InterpreterOptions& options, + const StringsExtensionOptions& extension_options = {}); CheckerLibrary StringsCheckerLibrary( - int version = kStringsExtensionLatestVersion); + const StringsExtensionOptions& extension_options = {}); + +inline CheckerLibrary StringsCheckerLibrary(int version) { + StringsExtensionOptions options; + options.version = version; + return StringsCheckerLibrary(options); +} inline CompilerLibrary StringsCompilerLibrary( - int version = kStringsExtensionLatestVersion) { - return CompilerLibrary::FromCheckerLibrary(StringsCheckerLibrary(version)); + const StringsExtensionOptions& options = {}) { + return CompilerLibrary::FromCheckerLibrary(StringsCheckerLibrary(options)); +} + +inline CompilerLibrary StringsCompilerLibrary(int version) { + StringsExtensionOptions options; + options.version = version; + return StringsCompilerLibrary(options); } } // namespace cel::extensions diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc index a5d56eaed..c3059808f 100644 --- a/extensions/strings_test.cc +++ b/extensions/strings_test.cc @@ -27,6 +27,7 @@ #include "checker/type_check_issue.h" #include "checker/type_checker_builder.h" #include "checker/validation_result.h" +#include "common/ast.h" #include "common/decl.h" #include "common/type.h" #include "common/value.h" @@ -50,6 +51,7 @@ namespace cel::extensions { namespace { using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::parser::ParserOptions; @@ -85,6 +87,48 @@ TEST(StringsCheckerLibrary, SmokeTest) { )~bool^equals)"); } +TEST(StringsExtTest, MaxPrecisionOption) { + StringsExtensionOptions extension_options; + extension_options.max_precision = 99; + + ASSERT_OK_AND_ASSIGN( + auto compiler_builder, + NewCompilerBuilder(internal::GetTestingDescriptorPool())); + + ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(StringsCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + compiler_builder->Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult result, + compiler->Compile("'abc %.100f'.format([2.0])", "")); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + RuntimeOptions opts; + ASSERT_OK_AND_ASSIGN( + auto runtime_builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_THAT(RegisterStringsFunctions(runtime_builder.function_registry(), + opts, extension_options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + google::protobuf::Arena arena; + cel::Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value.Is()); + EXPECT_THAT(value.GetError().ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("precision specifier exceeds maximum of 99"))); +} + using StringsExtFunctionsTest = testing::TestWithParam; TEST_P(StringsExtFunctionsTest, ParserAndCheckerTests) { From c93de861d872fb345512117b79aa2c54f4f93759 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 30 Mar 2026 10:32:41 -0700 Subject: [PATCH 015/143] Enable escaped backtick quoted identifiers by default. PiperOrigin-RevId: 891789266 --- parser/options.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/parser/options.h b/parser/options.h index ad03102e8..a41d16104 100644 --- a/parser/options.h +++ b/parser/options.h @@ -57,8 +57,9 @@ struct ParserOptions final { // Enables support for identifier quoting syntax: // "message.`skewer-case-field`" // - // Limited to field specifiers in select and message creation. - bool enable_quoted_identifiers = false; + // Limited to field specifiers in select and message creation, + // enabled by default + bool enable_quoted_identifiers = true; }; } // namespace cel From 5a3463337cf2a9b90b53833af2bbc1f35da90d64 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 30 Mar 2026 11:20:09 -0700 Subject: [PATCH 016/143] Expose TypeInfoToType and NewCompilerBuilder APIs PiperOrigin-RevId: 891814525 --- env/BUILD | 39 ++++++--- env/env.cc | 168 ++++----------------------------------- env/env.h | 5 ++ env/env_test.cc | 112 -------------------------- env/env_yaml.cc | 5 +- env/env_yaml_test.cc | 6 +- env/type_info.cc | 178 ++++++++++++++++++++++++++++++++++++++++++ env/type_info.h | 35 +++++++++ env/type_info_test.cc | 127 ++++++++++++++++++++++++++++++ 9 files changed, 397 insertions(+), 278 deletions(-) create mode 100644 env/type_info.cc create mode 100644 env/type_info.h create mode 100644 env/type_info_test.cc diff --git a/env/BUILD b/env/BUILD index f5ce35557..8d477cc1f 100644 --- a/env/BUILD +++ b/env/BUILD @@ -19,17 +19,28 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "config", - srcs = ["config.cc"], - hdrs = ["config.h"], + srcs = [ + "config.cc", + "type_info.cc", + ], + hdrs = [ + "config.h", + "type_info.h", + ], deps = [ "//common:constant", + "//common:type", + "//common:type_kind", "//internal:status_macros", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:overload", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", ], ) @@ -43,23 +54,16 @@ cc_library( "//common:constant", "//common:decl", "//common:type", - "//common:type_kind", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", "//env/internal:ext_registry", "//internal:status_macros", "//parser:macro", - "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/functional:overload", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:variant", "@com_google_protobuf//:protobuf", ], ) @@ -163,6 +167,20 @@ cc_test( ], ) +cc_test( + name = "type_info_test", + srcs = ["type_info_test.cc"], + deps = [ + ":config", + "//common:type", + "//common:type_proto", + "//internal:proto_matchers", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "env_test", srcs = ["env_test.cc"], @@ -173,7 +191,6 @@ cc_test( "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", - "//common:ast_proto", "//common:constant", "//common:decl", "//common:expr", diff --git a/env/env.cc b/env/env.cc index 2c2555f14..5a4198497 100644 --- a/env/env.cc +++ b/env/env.cc @@ -15,12 +15,10 @@ #include "env/env.h" #include -#include #include #include #include -#include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -28,11 +26,11 @@ #include "common/constant.h" #include "common/decl.h" #include "common/type.h" -#include "common/type_kind.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h" #include "env/config.h" +#include "env/type_info.h" #include "internal/status_macros.h" #include "parser/macro.h" #include "google/protobuf/arena.h" @@ -95,149 +93,6 @@ absl::StatusOr MakeStdlibSubset( return subset; } -std::optional TypeNameToTypeKind(absl::string_view type_name) { - // Excluded types: - // kUnknown - // kError - // kTypeParam - // kFunction - // kEnum - - static const absl::NoDestructor< - absl::flat_hash_map> - kTypeNameToTypeKind({ - {"null", TypeKind::kNull}, - {"bool", TypeKind::kBool}, - {"int", TypeKind::kInt}, - {"uint", TypeKind::kUint}, - {"double", TypeKind::kDouble}, - {"string", TypeKind::kString}, - {"bytes", TypeKind::kBytes}, - {"timestamp", TypeKind::kTimestamp}, - {TimestampType::kName, TypeKind::kTimestamp}, - {"duration", TypeKind::kDuration}, - {DurationType::kName, TypeKind::kDuration}, - {"list", TypeKind::kList}, - {"map", TypeKind::kMap}, - {"", TypeKind::kDyn}, - {"any", TypeKind::kAny}, - {"dyn", TypeKind::kDyn}, - {BoolWrapperType::kName, TypeKind::kBoolWrapper}, - {IntWrapperType::kName, TypeKind::kIntWrapper}, - {UintWrapperType::kName, TypeKind::kUintWrapper}, - {DoubleWrapperType::kName, TypeKind::kDoubleWrapper}, - {StringWrapperType::kName, TypeKind::kStringWrapper}, - {BytesWrapperType::kName, TypeKind::kBytesWrapper}, - {"type", TypeKind::kType}, - }); - if (auto it = kTypeNameToTypeKind->find(type_name); - it != kTypeNameToTypeKind->end()) { - return it->second; - } - - return std::nullopt; -} - -absl::StatusOr TypeInfoToType( - const Config::TypeInfo& type_info, google::protobuf::Arena* arena, - const google::protobuf::DescriptorPool* descriptor_pool) { - if (type_info.is_type_param) { - return TypeParamType(type_info.name); - } - - std::optional type_kind = TypeNameToTypeKind(type_info.name); - if (!type_kind.has_value()) { - if (type_info.params.empty() && descriptor_pool != nullptr) { - const google::protobuf::Descriptor* type = - descriptor_pool->FindMessageTypeByName(type_info.name); - if (type != nullptr) { - return MessageType(type); - } - } - // TODO(uncreated-issue/88): use a TypeIntrospector to validate opaque types - std::vector parameter_types; - for (const Config::TypeInfo& param : type_info.params) { - CEL_ASSIGN_OR_RETURN(Type parameter_type, - TypeInfoToType(param, arena, descriptor_pool)); - parameter_types.push_back(parameter_type); - } - - return OpaqueType(arena, type_info.name, parameter_types); - } - - switch (*type_kind) { - case TypeKind::kNull: - return NullType(); - case TypeKind::kBool: - return BoolType(); - case TypeKind::kInt: - return IntType(); - case TypeKind::kUint: - return UintType(); - case TypeKind::kDouble: - return DoubleType(); - case TypeKind::kString: - return StringType(); - case TypeKind::kBytes: - return BytesType(); - case TypeKind::kDuration: - return DurationType(); - case TypeKind::kTimestamp: - return TimestampType(); - case TypeKind::kList: { - Type element_type; - if (!type_info.params.empty()) { - CEL_ASSIGN_OR_RETURN( - element_type, - TypeInfoToType(type_info.params[0], arena, descriptor_pool)); - } else { - element_type = DynType(); - } - return ListType(arena, element_type); - } - case TypeKind::kMap: { - Type key_type = DynType(); - Type value_type = DynType(); - if (!type_info.params.empty()) { - CEL_ASSIGN_OR_RETURN(key_type, TypeInfoToType(type_info.params[0], - arena, descriptor_pool)); - } - if (type_info.params.size() > 1) { - CEL_ASSIGN_OR_RETURN( - value_type, - TypeInfoToType(type_info.params[1], arena, descriptor_pool)); - } - return MapType(arena, key_type, value_type); - } - case TypeKind::kDyn: - return DynType(); - case TypeKind::kAny: - return AnyType(); - case TypeKind::kBoolWrapper: - return BoolWrapperType(); - case TypeKind::kIntWrapper: - return IntWrapperType(); - case TypeKind::kUintWrapper: - return UintWrapperType(); - case TypeKind::kDoubleWrapper: - return DoubleWrapperType(); - case TypeKind::kStringWrapper: - return StringWrapperType(); - case TypeKind::kBytesWrapper: - return BytesWrapperType(); - case TypeKind::kType: { - if (type_info.params.empty()) { - return TypeType(arena, DynType()); - } - CEL_ASSIGN_OR_RETURN(Type type, TypeInfoToType(type_info.params[0], arena, - descriptor_pool)); - return TypeType(arena, type); - } - default: - return DynType(); - } -} - absl::StatusOr FunctionConfigToFunctionDecl( const Config::FunctionConfig& function_config, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool* descriptor_pool) { @@ -250,12 +105,12 @@ absl::StatusOr FunctionConfigToFunctionDecl( overload_decl.set_member(overload_config.is_member_function); for (const Config::TypeInfo& parameter : overload_config.parameters) { CEL_ASSIGN_OR_RETURN(Type parameter_type, - TypeInfoToType(parameter, arena, descriptor_pool)); + TypeInfoToType(parameter, descriptor_pool, arena)); overload_decl.mutable_args().push_back(parameter_type); } CEL_ASSIGN_OR_RETURN( Type return_type, - TypeInfoToType(overload_config.return_type, arena, descriptor_pool)); + TypeInfoToType(overload_config.return_type, descriptor_pool, arena)); overload_decl.set_result(return_type); CEL_RETURN_IF_ERROR(function_decl.AddOverload(overload_decl)); } @@ -264,7 +119,11 @@ absl::StatusOr FunctionConfigToFunctionDecl( } // namespace -absl::StatusOr> Env::NewCompiler() { +Env::Env() { + compiler_options_.parser_options.enable_quoted_identifiers = true; +} + +absl::StatusOr> Env::NewCompilerBuilder() { CEL_ASSIGN_OR_RETURN( std::unique_ptr compiler_builder, cel::NewCompilerBuilder(descriptor_pool_, compiler_options_)); @@ -295,8 +154,8 @@ absl::StatusOr> Env::NewCompiler() { VariableDecl variable_decl; variable_decl.set_name(variable_config.name); CEL_ASSIGN_OR_RETURN(Type type, - TypeInfoToType(variable_config.type_info, arena, - descriptor_pool_.get())); + TypeInfoToType(variable_config.type_info, + descriptor_pool_.get(), arena)); variable_decl.set_type(type); if (variable_config.value.has_value()) { variable_decl.set_value(variable_config.value); @@ -312,7 +171,12 @@ absl::StatusOr> Env::NewCompiler() { CEL_RETURN_IF_ERROR(checker_builder.AddFunction(function_decl)); } - return compiler_builder->Build(); + return compiler_builder; } +absl::StatusOr> Env::NewCompiler() { + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler_builder, + NewCompilerBuilder()); + return compiler_builder->Build(); +} } // namespace cel diff --git a/env/env.h b/env/env.h index f46e5947c..9830b67d7 100644 --- a/env/env.h +++ b/env/env.h @@ -36,6 +36,8 @@ namespace cel { // customizable CEL features. class Env { public: + Env(); + // Registers a `CompilerLibrary` with the environment. Note that the library // does not automatically get added to a `Compiler`. `NewCompiler` relies // on `Config` to determine which libraries to load. @@ -57,6 +59,9 @@ class Env { void SetConfig(const Config& config) { config_ = config; } + absl::StatusOr> NewCompilerBuilder(); + + // Shortcut for NewCompilerBuilder() followed by Build(). absl::StatusOr> NewCompiler(); private: diff --git a/env/env_test.cc b/env/env_test.cc index dcd2d97fa..076eb57bc 100644 --- a/env/env_test.cc +++ b/env/env_test.cc @@ -30,7 +30,6 @@ #include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" -#include "common/ast_proto.h" #include "common/constant.h" #include "common/decl.h" #include "common/expr.h" @@ -38,7 +37,6 @@ #include "common/value.h" #include "compiler/compiler.h" #include "env/config.h" -#include "internal/proto_matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" @@ -52,16 +50,13 @@ #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" -#include "google/protobuf/text_format.h" namespace cel { namespace { using ::absl_testing::IsOk; -using ::cel::internal::test::EqualsProto; using ::testing::HasSubstr; using ::testing::IsEmpty; -using ::testing::NotNull; using ::testing::Property; using ::testing::UnorderedElementsAre; using ::testing::Values; @@ -319,113 +314,6 @@ TEST(ContainerConfigTest, ContainerConfig) { EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); } -struct TypeInfoTestCase { - Config::TypeInfo type_info; - std::string expected_type_pb; -}; - -using TypeInfoTest = testing::TestWithParam; - -TEST_P(TypeInfoTest, TypeInfo) { - const TypeInfoTestCase& param = GetParam(); - cel::expr::Type expected_type_pb; - ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(param.expected_type_pb, - &expected_type_pb)); - - Env env; - env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); - Config config; - Config::VariableConfig variable_config; - variable_config.name = "test"; - variable_config.type_info = param.type_info; - ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); - env.SetConfig(config); - - ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); - ASSERT_THAT(compiler, NotNull()); - ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("test")); - EXPECT_THAT(result.GetIssues(), IsEmpty()) - << " error: " << result.FormatError(); - - // Obtain the inferred return type of the expression `test`. - const Ast* ast = result.GetAst(); - ASSERT_THAT(ast, NotNull()); - cel::expr::CheckedExpr checked_expr; - ASSERT_THAT(cel::AstToCheckedExpr(*ast, &checked_expr), IsOk()); - auto it = checked_expr.type_map().find(checked_expr.expr().id()); - ASSERT_NE(it, checked_expr.type_map().end()); - - cel::expr::Type actual_type_pb = it->second; - EXPECT_THAT(actual_type_pb, EqualsProto(expected_type_pb)); -} - -std::vector GetTypeInfoTestCases() { - return { - TypeInfoTestCase{ - .type_info = {.name = "int"}, - .expected_type_pb = "primitive: INT64", - }, - TypeInfoTestCase{ - .type_info = {.name = "list", - .params = {Config::TypeInfo{.name = "int"}}}, - .expected_type_pb = "list_type { elem_type { primitive: INT64 } }", - }, - TypeInfoTestCase{ - .type_info = {.name = "list"}, - .expected_type_pb = "list_type { elem_type { dyn {} }}", - }, - TypeInfoTestCase{ - .type_info = {.name = "map", - .params = {Config::TypeInfo{.name = "string"}, - Config::TypeInfo{.name = "int"}}}, - .expected_type_pb = "map_type { key_type { primitive: STRING } " - "value_type { primitive: INT64 }}", - }, - TypeInfoTestCase{ - .type_info = {.name = "cel.expr.conformance.proto2.TestAllTypes"}, - .expected_type_pb = - "message_type: 'cel.expr.conformance.proto2.TestAllTypes'", - }, - TypeInfoTestCase{ - .type_info = {.name = "A", - .params = {Config::TypeInfo{.name = "B", - .is_type_param = true}}}, - // TypeParam is replaced with dyn by the type checker. - .expected_type_pb = - "abstract_type { name: 'A' parameter_types { dyn {} } }", - }, - TypeInfoTestCase{ - .type_info = {.name = "any"}, - .expected_type_pb = "well_known: ANY", - }, - TypeInfoTestCase{ - .type_info = {.name = "timestamp"}, - .expected_type_pb = "well_known: TIMESTAMP", - }, - TypeInfoTestCase{ - .type_info = {.name = "google.protobuf.DoubleValue"}, - .expected_type_pb = "wrapper: DOUBLE", - }, - TypeInfoTestCase{ - .type_info = {.name = "type", - .params = {Config::TypeInfo{.name = "duration"}}}, - .expected_type_pb = "type: { well_known: DURATION }", - }, - TypeInfoTestCase{ - .type_info = {.name = "parameterized", - .params = {{.name = "A", .is_type_param = true}, - {.name = "double"}}}, - // TypeParam is replaced with dyn by the type checker. - .expected_type_pb = "abstract_type { name: 'parameterized' " - "parameter_types { dyn {} } " - "parameter_types { primitive: DOUBLE } }", - }, - }; -} - -INSTANTIATE_TEST_SUITE_P(VariableConfigTest, TypeInfoTest, - ValuesIn(GetTypeInfoTestCases())); - struct VariableConfigWithValueTestCase { Config::VariableConfig variable_config; std::string validate_type_expr; diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 5e7c9631d..a6f66bd83 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -552,10 +552,13 @@ absl::Status ParseVariableConfigs(Config& config, absl::string_view yaml, variable_config.type_info = type_info; - if (constant_kind_case != ConstantKindCase::kUnspecified) { + if (constant_kind_case != ConstantKindCase::kUnspecified && + !value_str.empty()) { CEL_ASSIGN_OR_RETURN( variable_config.value, ParseConstantValue(yaml, value, constant_kind_case, value_str)); + } else if (constant_kind_case == ConstantKindCase::kNull) { + variable_config.value = Constant(nullptr); } CEL_RETURN_IF_ERROR(config.AddVariableConfig(variable_config)); diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index b34c25254..c3e4839af 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -226,8 +226,10 @@ TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { const Config::VariableConfig& variable_config = config.GetVariableConfigs()[0]; EXPECT_EQ(variable_config.name, "const"); - EXPECT_EQ(variable_config.type_info.name, param.type_name); - EXPECT_EQ(variable_config.value, param.expected_constant); + EXPECT_EQ(variable_config.type_info.name, param.type_name) + << " yaml: " << yaml; + EXPECT_EQ(variable_config.value, param.expected_constant) + << " yaml: " << yaml; } std::vector GetParseConstantTestCases() { diff --git a/env/type_info.cc b/env/type_info.cc new file mode 100644 index 000000000..ed72a842f --- /dev/null +++ b/env/type_info.cc @@ -0,0 +1,178 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/type_info.h" + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +std::optional TypeNameToTypeKind(absl::string_view type_name) { + // Excluded types: + // kUnknown + // kError + // kTypeParam + // kFunction + // kEnum + + static const absl::NoDestructor< + absl::flat_hash_map> + kTypeNameToTypeKind({ + {"null", TypeKind::kNull}, + {"bool", TypeKind::kBool}, + {"int", TypeKind::kInt}, + {"uint", TypeKind::kUint}, + {"double", TypeKind::kDouble}, + {"string", TypeKind::kString}, + {"bytes", TypeKind::kBytes}, + {"timestamp", TypeKind::kTimestamp}, + {TimestampType::kName, TypeKind::kTimestamp}, + {"duration", TypeKind::kDuration}, + {DurationType::kName, TypeKind::kDuration}, + {"list", TypeKind::kList}, + {"map", TypeKind::kMap}, + {"", TypeKind::kDyn}, + {"any", TypeKind::kAny}, + {"dyn", TypeKind::kDyn}, + {BoolWrapperType::kName, TypeKind::kBoolWrapper}, + {IntWrapperType::kName, TypeKind::kIntWrapper}, + {UintWrapperType::kName, TypeKind::kUintWrapper}, + {DoubleWrapperType::kName, TypeKind::kDoubleWrapper}, + {StringWrapperType::kName, TypeKind::kStringWrapper}, + {BytesWrapperType::kName, TypeKind::kBytesWrapper}, + {"type", TypeKind::kType}, + }); + if (auto it = kTypeNameToTypeKind->find(type_name); + it != kTypeNameToTypeKind->end()) { + return it->second; + } + + return std::nullopt; +} +} // namespace + +absl::StatusOr TypeInfoToType( + const Config::TypeInfo& type_info, + const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::Arena* arena) { + if (type_info.is_type_param) { + return TypeParamType(type_info.name); + } + + std::optional type_kind = TypeNameToTypeKind(type_info.name); + if (!type_kind.has_value()) { + if (type_info.params.empty() && descriptor_pool != nullptr) { + const google::protobuf::Descriptor* type = + descriptor_pool->FindMessageTypeByName(type_info.name); + if (type != nullptr) { + return Type::Message(type); + } + } + // TODO(uncreated-issue/88): use a TypeIntrospector to validate opaque types + std::vector parameter_types; + for (const Config::TypeInfo& param : type_info.params) { + CEL_ASSIGN_OR_RETURN(Type parameter_type, + TypeInfoToType(param, descriptor_pool, arena)); + parameter_types.push_back(parameter_type); + } + + return OpaqueType(arena, type_info.name, parameter_types); + } + + switch (*type_kind) { + case TypeKind::kNull: + return NullType(); + case TypeKind::kBool: + return BoolType(); + case TypeKind::kInt: + return IntType(); + case TypeKind::kUint: + return UintType(); + case TypeKind::kDouble: + return DoubleType(); + case TypeKind::kString: + return StringType(); + case TypeKind::kBytes: + return BytesType(); + case TypeKind::kDuration: + return DurationType(); + case TypeKind::kTimestamp: + return TimestampType(); + case TypeKind::kList: { + Type element_type; + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN( + element_type, + TypeInfoToType(type_info.params[0], descriptor_pool, arena)); + } else { + element_type = DynType(); + } + return ListType(arena, element_type); + } + case TypeKind::kMap: { + Type key_type = DynType(); + Type value_type = DynType(); + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN(key_type, TypeInfoToType(type_info.params[0], + descriptor_pool, arena)); + } + if (type_info.params.size() > 1) { + CEL_ASSIGN_OR_RETURN( + value_type, + TypeInfoToType(type_info.params[1], descriptor_pool, arena)); + } + return MapType(arena, key_type, value_type); + } + case TypeKind::kDyn: + return DynType(); + case TypeKind::kAny: + return AnyType(); + case TypeKind::kBoolWrapper: + return BoolWrapperType(); + case TypeKind::kIntWrapper: + return IntWrapperType(); + case TypeKind::kUintWrapper: + return UintWrapperType(); + case TypeKind::kDoubleWrapper: + return DoubleWrapperType(); + case TypeKind::kStringWrapper: + return StringWrapperType(); + case TypeKind::kBytesWrapper: + return BytesWrapperType(); + case TypeKind::kType: { + if (type_info.params.empty()) { + return TypeType(arena, DynType()); + } + CEL_ASSIGN_OR_RETURN(Type type, TypeInfoToType(type_info.params[0], + descriptor_pool, arena)); + return TypeType(arena, type); + } + default: + return DynType(); + } +} + +} // namespace cel diff --git a/env/type_info.h b/env/type_info.h new file mode 100644 index 000000000..bb3cfde43 --- /dev/null +++ b/env/type_info.h @@ -0,0 +1,35 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ +#define THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ + +#include "absl/status/statusor.h" +#include "common/type.h" +#include "env/config.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Converts a Config::TypeInfo to a cel::Type. Returns an error if the type_info +// cannot be converted to a known cel::Type, a list configured with more than +// one parameter. +absl::StatusOr TypeInfoToType( + const Config::TypeInfo& type_info, + const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::Arena* arena); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ diff --git a/env/type_info_test.cc b/env/type_info_test.cc new file mode 100644 index 000000000..ca9d0467c --- /dev/null +++ b/env/type_info_test.cc @@ -0,0 +1,127 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/type_info.h" + +#include +#include + +#include "common/type.h" +#include "common/type_proto.h" +#include "env/config.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace { + +using absl_testing::IsOk; +using testing::ValuesIn; + +struct TestCase { + Config::TypeInfo type_info; + std::string expected_type_pb; +}; + +using TypeInfoTest = testing::TestWithParam; + +TEST_P(TypeInfoTest, TypeInfo) { + const TestCase& param = GetParam(); + cel::expr::Type expected_type_pb; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(param.expected_type_pb, + &expected_type_pb)); + + google::protobuf::Arena arena; + const google::protobuf::DescriptorPool* descriptor_pool = + cel::internal::GetTestingDescriptorPool(); + ASSERT_OK_AND_ASSIGN( + cel::Type actual_type, + cel::TypeInfoToType(param.type_info, descriptor_pool, &arena)); + + cel::expr::Type actual_type_pb; + ASSERT_THAT(cel::TypeToProto(actual_type, &actual_type_pb), IsOk()); + EXPECT_THAT(actual_type_pb, + cel::internal::test::EqualsProto(expected_type_pb)); +} + +std::vector GetTestCases() { + return { + TestCase{ + .type_info = {.name = "int"}, + .expected_type_pb = "primitive: INT64", + }, + TestCase{ + .type_info = {.name = "list", + .params = {Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "list_type { elem_type { primitive: INT64 } }", + }, + TestCase{ + .type_info = {.name = "list"}, + .expected_type_pb = "list_type { elem_type { dyn {} }}", + }, + TestCase{ + .type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "map_type { key_type { primitive: STRING } " + "value_type { primitive: INT64 }}", + }, + TestCase{ + .type_info = {.name = "cel.expr.conformance.proto2.TestAllTypes"}, + .expected_type_pb = + "message_type: 'cel.expr.conformance.proto2.TestAllTypes'", + }, + TestCase{ + .type_info = {.name = "A", + .params = {Config::TypeInfo{.name = "B", + .is_type_param = true}}}, + .expected_type_pb = + "abstract_type { name: 'A' parameter_types { type_param: 'B' } }", + }, + TestCase{ + .type_info = {.name = "any"}, + .expected_type_pb = "well_known: ANY", + }, + TestCase{ + .type_info = {.name = "timestamp"}, + .expected_type_pb = "well_known: TIMESTAMP", + }, + TestCase{ + .type_info = {.name = "google.protobuf.DoubleValue"}, + .expected_type_pb = "wrapper: DOUBLE", + }, + TestCase{ + .type_info = {.name = "type", + .params = {Config::TypeInfo{.name = "duration"}}}, + .expected_type_pb = "type: { well_known: DURATION }", + }, + TestCase{ + .type_info = {.name = "parameterized", + .params = {{.name = "A", .is_type_param = true}, + {.name = "double"}}}, + .expected_type_pb = "abstract_type { name: 'parameterized' " + "parameter_types { type_param: 'A' } " + "parameter_types { primitive: DOUBLE } }", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(TypeInfoTest, TypeInfoTest, ValuesIn(GetTestCases())); + +} // namespace +} // namespace cel From fb214306e1517a61220af3a5cf51395b49a0299c Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 31 Mar 2026 10:20:16 -0700 Subject: [PATCH 017/143] Remove support for hierarchical type_envs. This was not used anywhere and extended envs will likely need to deep copy. PiperOrigin-RevId: 892414339 --- checker/internal/type_check_env.cc | 91 ++++++++++++------------------ checker/internal/type_check_env.h | 20 +------ 2 files changed, 37 insertions(+), 74 deletions(-) diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index d856a7230..e76621435 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -34,25 +34,18 @@ namespace cel::checker_internal { const VariableDecl* absl_nullable TypeCheckEnv::LookupVariable( absl::string_view name) const { - const TypeCheckEnv* scope = this; - while (scope != nullptr) { - if (auto it = scope->variables_.find(name); it != scope->variables_.end()) { - return &it->second; - } - scope = scope->parent_; + if (auto it = variables_.find(name); it != variables_.end()) { + return &it->second; } return nullptr; } const FunctionDecl* absl_nullable TypeCheckEnv::LookupFunction( absl::string_view name) const { - const TypeCheckEnv* scope = this; - while (scope != nullptr) { - if (auto it = scope->functions_.find(name); it != scope->functions_.end()) { - return &it->second; - } - scope = scope->parent_; + if (auto it = functions_.find(name); it != functions_.end()) { + return &it->second; } + return nullptr; } @@ -71,17 +64,13 @@ absl::StatusOr> TypeCheckEnv::LookupTypeName( return Type::Enum(enum_descriptor); } } - const TypeCheckEnv* scope = this; - do { - for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); - ++iter) { - auto type = (*iter)->FindType(name); - if (!type.ok() || type->has_value()) { - return type; - } + for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + ++iter) { + auto type = (*iter)->FindType(name); + if (!type.ok() || type->has_value()) { + return type; } - scope = scope->parent_; - } while ((scope != nullptr)); + } return absl::nullopt; } @@ -106,26 +95,21 @@ absl::StatusOr> TypeCheckEnv::LookupEnumConstant( return decl; } } - const TypeCheckEnv* scope = this; - do { - for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); - ++iter) { - auto enum_constant = (*iter)->FindEnumConstant(type, value); - if (!enum_constant.ok()) { - return enum_constant.status(); - } - if (enum_constant->has_value()) { - auto decl = - MakeVariableDecl(absl::StrCat((**enum_constant).type_full_name, ".", - (**enum_constant).value_name), - (**enum_constant).type); - decl.set_value( - Constant(static_cast((**enum_constant).number))); - return decl; - } + for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + ++iter) { + auto enum_constant = (*iter)->FindEnumConstant(type, value); + if (!enum_constant.ok()) { + return enum_constant.status(); } - scope = scope->parent_; - } while (scope != nullptr); + if (enum_constant->has_value()) { + auto decl = + MakeVariableDecl(absl::StrCat((**enum_constant).type_full_name, ".", + (**enum_constant).value_name), + (**enum_constant).type); + decl.set_value(Constant(static_cast((**enum_constant).number))); + return decl; + } + } return absl::nullopt; } @@ -165,22 +149,17 @@ absl::StatusOr> TypeCheckEnv::LookupStructField( return cel::MessageTypeField(field_descriptor); } } - const TypeCheckEnv* scope = this; - do { - // Check the type providers in reverse registration order. - // Note: this doesn't allow for shadowing a type with a subset type of the - // same name -- the parent type provider will still be considered when - // checking field accesses. - for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); - ++iter) { - auto field_info = - (*iter)->FindStructTypeFieldByName(type_name, field_name); - if (!field_info.ok() || field_info->has_value()) { - return field_info; - } + // Check the type providers in reverse registration order. + // Note: this doesn't allow for shadowing a type with a subset type of the + // same name -- the prior type provider will still be considered when + // checking field accesses. + for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + ++iter) { + auto field_info = (*iter)->FindStructTypeFieldByName(type_name, field_name); + if (!field_info.ok() || field_info->has_value()) { + return field_info; } - scope = scope->parent_; - } while (scope != nullptr); + } return absl::nullopt; } diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index a4d242fdf..5c8b3629c 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -89,17 +89,14 @@ class TypeCheckEnv { explicit TypeCheckEnv( absl_nonnull std::shared_ptr descriptor_pool) - : descriptor_pool_(std::move(descriptor_pool)), - container_(""), - parent_(nullptr) {} + : descriptor_pool_(std::move(descriptor_pool)), container_("") {} TypeCheckEnv(absl_nonnull std::shared_ptr descriptor_pool, std::shared_ptr arena) : descriptor_pool_(std::move(descriptor_pool)), arena_(std::move(arena)), - container_(""), - parent_(nullptr) {} + container_("") {} // Move-only. TypeCheckEnv(TypeCheckEnv&&) = default; @@ -163,9 +160,6 @@ class TypeCheckEnv { functions_[decl.name()] = std::move(decl); } - const TypeCheckEnv* absl_nullable parent() const { return parent_; } - void set_parent(TypeCheckEnv* parent) { parent_ = parent; } - // Returns the declaration for the given name if it is found in the current // or any parent scope. // Note: the returned declaration ptr is only valid as long as no changes are @@ -184,10 +178,6 @@ class TypeCheckEnv { absl::StatusOr> LookupTypeConstant( google::protobuf::Arena* absl_nonnull arena, absl::string_view type_name) const; - TypeCheckEnv MakeExtendedEnvironment() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - return TypeCheckEnv(this); - } - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { return descriptor_pool_.get(); } @@ -200,11 +190,6 @@ class TypeCheckEnv { } private: - explicit TypeCheckEnv(const TypeCheckEnv* absl_nonnull parent) - : descriptor_pool_(parent->descriptor_pool_), - container_(parent != nullptr ? parent->container() : ""), - parent_(parent) {} - absl::StatusOr> LookupEnumConstant( absl::string_view type, absl::string_view value) const; @@ -212,7 +197,6 @@ class TypeCheckEnv { // If set, an arena was needed to allocate types in the environment. absl_nullable std::shared_ptr arena_; std::string container_; - const TypeCheckEnv* absl_nullable parent_; // Maps fully qualified names to declarations. absl::flat_hash_map variables_; From 463ddc0a6dd94979c7f3779ec6e8c68e7e89ad2b Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 31 Mar 2026 13:30:29 -0700 Subject: [PATCH 018/143] Delete types/interfaces related to TypeFactory and TypeManager. PiperOrigin-RevId: 892512872 --- common/BUILD | 3 -- common/type_factory.h | 30 ---------------- common/type_introspector.h | 2 -- common/type_manager.h | 57 ------------------------------ common/types/legacy_type_manager.h | 45 ----------------------- common/values/opaque_value.h | 2 +- eval/compiler/flat_expr_builder.h | 12 ------- eval/public/cel_type_registry.h | 2 +- 8 files changed, 2 insertions(+), 151 deletions(-) delete mode 100644 common/type_factory.h delete mode 100644 common/type_manager.h delete mode 100644 common/types/legacy_type_manager.h diff --git a/common/BUILD b/common/BUILD index da96b1c98..a4ac6a3ef 100644 --- a/common/BUILD +++ b/common/BUILD @@ -541,12 +541,9 @@ cc_library( ], ) + [ "type.h", - "type_factory.h", "type_introspector.h", - "type_manager.h", ], deps = [ - ":memory", ":type_kind", "//internal:string_pool", "@com_google_absl//absl/algorithm:container", diff --git a/common/type_factory.h b/common/type_factory.h deleted file mode 100644 index 33829ea8b..000000000 --- a/common/type_factory.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_FACTORY_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_FACTORY_H_ - -namespace cel { - -// `TypeFactory` is the preferred way for constructing compound types such as -// lists, maps, structs, and opaques. It caches types and avoids constructing -// them multiple times. -class TypeFactory { - public: - virtual ~TypeFactory() = default; -}; - -} // namespace cel - -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_FACTORY_H_ diff --git a/common/type_introspector.h b/common/type_introspector.h index 7f4a19a31..159e49ab4 100644 --- a/common/type_introspector.h +++ b/common/type_introspector.h @@ -24,8 +24,6 @@ namespace cel { -class TypeFactory; - // `TypeIntrospector` is an interface which allows querying type-related // information. It handles type introspection, but not type reflection. That is, // it is not capable of instantiating new values or understanding values. Its diff --git a/common/type_manager.h b/common/type_manager.h deleted file mode 100644 index 354f4c9b8..000000000 --- a/common/type_manager.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_MANAGER_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_MANAGER_H_ - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "common/memory.h" -#include "common/type.h" -#include "common/type_factory.h" -#include "common/type_introspector.h" - -namespace cel { - -// `TypeManager` is an additional layer on top of `TypeFactory` and -// `TypeIntrospector` which combines the two and adds additional functionality. -class TypeManager : public virtual TypeFactory { - public: - virtual ~TypeManager() = default; - - // See `TypeIntrospector::FindType`. - absl::StatusOr> FindType(absl::string_view name) { - return GetTypeIntrospector().FindType(name); - } - - // See `TypeIntrospector::FindStructTypeFieldByName`. - absl::StatusOr> FindStructTypeFieldByName( - absl::string_view type, absl::string_view name) { - return GetTypeIntrospector().FindStructTypeFieldByName(type, name); - } - - // See `TypeIntrospector::FindStructTypeFieldByName`. - absl::StatusOr> FindStructTypeFieldByName( - const StructType& type, absl::string_view name) { - return GetTypeIntrospector().FindStructTypeFieldByName(type, name); - } - - protected: - virtual const TypeIntrospector& GetTypeIntrospector() const = 0; -}; - -} // namespace cel - -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_MANAGER_H_ diff --git a/common/types/legacy_type_manager.h b/common/types/legacy_type_manager.h deleted file mode 100644 index 238335b52..000000000 --- a/common/types/legacy_type_manager.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_MANAGER_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_MANAGER_H_ - -#include "common/memory.h" -#include "common/type_introspector.h" -#include "common/type_manager.h" - -namespace cel::common_internal { - -// `LegacyTypeManager` is an implementation which should be used when -// converting between `cel::Value` and `google::api::expr::runtime::CelValue` -// and only then. -class LegacyTypeManager : public virtual TypeManager { - public: - explicit LegacyTypeManager(const TypeIntrospector& type_introspector) - : type_introspector_(type_introspector) {} - - protected: - const TypeIntrospector& GetTypeIntrospector() const final { - return type_introspector_; - } - - private: - const TypeIntrospector& type_introspector_; -}; - -} // namespace cel::common_internal - -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_MANAGER_H_ diff --git a/common/values/opaque_value.h b/common/values/opaque_value.h index 273b7889a..57af78ae0 100644 --- a/common/values/opaque_value.h +++ b/common/values/opaque_value.h @@ -52,7 +52,7 @@ class Value; class OpaqueValueInterface; class OpaqueValueInterfaceIterator; class OpaqueValue; -class TypeFactory; + using OpaqueValueContent = CustomValueContent; struct OpaqueValueDispatcher { diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index eab1e7ff8..7d770b443 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -53,18 +53,6 @@ class FlatExprBuilder { type_registry_(env_->type_registry), use_legacy_type_provider_(use_legacy_type_provider) {} - FlatExprBuilder( - absl_nonnull std::shared_ptr env, - const cel::FunctionRegistry& function_registry, - const cel::TypeRegistry& type_registry, - const cel::RuntimeOptions& options, bool use_legacy_type_provider = false) - : env_(std::move(env)), - options_(options), - container_(options.container), - function_registry_(function_registry), - type_registry_(type_registry), - use_legacy_type_provider_(use_legacy_type_provider) {} - void AddAstTransform(std::unique_ptr transform) { ast_transforms_.push_back(std::move(transform)); } diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index 290726bfe..0c01eb8e9 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -86,7 +86,7 @@ class CelTypeRegistry { // registry. // // This is a composited type provider that should check in order: - // - builtins (via TypeManager) + // - builtins // - custom enumerations // - registered extension type providers in the order registered. const cel::TypeProvider& GetTypeProvider() const { From 8773590dc5fd54ac50963b78bb33cbb82043d28f Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 31 Mar 2026 18:05:30 -0700 Subject: [PATCH 019/143] Refactor cel::TypeIntrospector - simplify base class logic, add an explicit implementation for looking up WKTs - Make WellKnownType lookups explicit in runtime implementations. - delete extensions/protobuf/type_introspector and related. They should be unused now and don't work as expected with the type checker or runtime. PiperOrigin-RevId: 892637710 --- common/BUILD | 1 + common/type_introspector.cc | 55 +++++----- common/type_introspector.h | 51 ++++++++- .../thread_compatible_type_introspector.h | 34 ------ eval/compiler/flat_expr_builder.h | 2 - eval/public/structs/legacy_type_provider.cc | 8 ++ extensions/protobuf/BUILD | 36 ------ extensions/protobuf/type_introspector.cc | 80 -------------- extensions/protobuf/type_introspector.h | 58 ---------- extensions/protobuf/type_introspector_test.cc | 103 ------------------ extensions/protobuf/type_reflector.h | 41 ------- runtime/internal/BUILD | 1 + runtime/internal/runtime_type_provider.cc | 17 ++- 13 files changed, 96 insertions(+), 391 deletions(-) delete mode 100644 common/types/thread_compatible_type_introspector.h delete mode 100644 extensions/protobuf/type_introspector.cc delete mode 100644 extensions/protobuf/type_introspector.h delete mode 100644 extensions/protobuf/type_introspector_test.cc delete mode 100644 extensions/protobuf/type_reflector.h diff --git a/common/BUILD b/common/BUILD index a4ac6a3ef..8dd8921cc 100644 --- a/common/BUILD +++ b/common/BUILD @@ -548,6 +548,7 @@ cc_library( "//internal:string_pool", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/common/type_introspector.cc b/common/type_introspector.cc index c69235b3b..6d5158a2f 100644 --- a/common/type_introspector.cc +++ b/common/type_introspector.cc @@ -211,35 +211,6 @@ const WellKnownTypesMap& GetWellKnownTypesMap() { } // namespace -absl::StatusOr> TypeIntrospector::FindType( - absl::string_view name) const { - const auto& well_known_types = GetWellKnownTypesMap(); - if (auto it = well_known_types.find(name); it != well_known_types.end()) { - return it->second.type; - } - return FindTypeImpl(name); -} - -absl::StatusOr> -TypeIntrospector::FindEnumConstant(absl::string_view type, - absl::string_view value) const { - if (type == "google.protobuf.NullValue" && value == "NULL_VALUE") { - return EnumConstant{NullType{}, "google.protobuf.NullValue", "NULL_VALUE", - 0}; - } - return FindEnumConstantImpl(type, value); -} - -absl::StatusOr> -TypeIntrospector::FindStructTypeFieldByName(absl::string_view type, - absl::string_view name) const { - const auto& well_known_types = GetWellKnownTypesMap(); - if (auto it = well_known_types.find(type); it != well_known_types.end()) { - return it->second.FieldByName(name); - } - return FindStructTypeFieldByNameImpl(type, name); -} - absl::StatusOr> TypeIntrospector::FindTypeImpl( absl::string_view) const { return absl::nullopt; @@ -257,4 +228,30 @@ TypeIntrospector::FindStructTypeFieldByNameImpl(absl::string_view, return absl::nullopt; } +absl::optional FindWellKnownType(absl::string_view name) { + const auto& well_known_types = GetWellKnownTypesMap(); + if (auto it = well_known_types.find(name); it != well_known_types.end()) { + return it->second.type; + } + return absl::nullopt; +} + +absl::optional FindWellKnownTypeEnumConstant( + absl::string_view type, absl::string_view value) { + if (type == "google.protobuf.NullValue" && value == "NULL_VALUE") { + return TypeIntrospector::EnumConstant{ + NullType{}, "google.protobuf.NullValue", "NULL_VALUE", 0}; + } + return absl::nullopt; +} + +absl::optional FindWellKnownTypeFieldByName( + absl::string_view type, absl::string_view name) { + const auto& well_known_types = GetWellKnownTypesMap(); + if (auto it = well_known_types.find(type); it != well_known_types.end()) { + return it->second.FieldByName(name); + } + return absl::nullopt; +} + } // namespace cel diff --git a/common/type_introspector.h b/common/type_introspector.h index 159e49ab4..fb6ea09c1 100644 --- a/common/type_introspector.h +++ b/common/type_introspector.h @@ -43,17 +43,23 @@ class TypeIntrospector { virtual ~TypeIntrospector() = default; // `FindType` find the type corresponding to name `name`. - absl::StatusOr> FindType(absl::string_view name) const; + absl::StatusOr> FindType(absl::string_view name) const { + return FindTypeImpl(name); + } // `FindEnumConstant` find a fully qualified enumerator name `name` in enum // type `type`. absl::StatusOr> FindEnumConstant( - absl::string_view type, absl::string_view value) const; + absl::string_view type, absl::string_view value) const { + return FindEnumConstantImpl(type, value); + } // `FindStructTypeFieldByName` find the name, number, and type of the field // `name` in type `type`. absl::StatusOr> FindStructTypeFieldByName( - absl::string_view type, absl::string_view name) const; + absl::string_view type, absl::string_view name) const { + return FindStructTypeFieldByNameImpl(type, name); + } // `FindStructTypeFieldByName` find the name, number, and type of the field // `name` in struct type `type`. @@ -74,6 +80,45 @@ class TypeIntrospector { absl::string_view name) const; }; +// Looks up a well-known type by name. +absl::optional FindWellKnownType(absl::string_view name); + +// Looks up a well-known enum constant by type and value. +absl::optional FindWellKnownTypeEnumConstant( + absl::string_view type, absl::string_view value); + +// Looks up a well-known struct type field by type and field name. +absl::optional FindWellKnownTypeFieldByName( + absl::string_view type, absl::string_view name); + +// `WellKnownTypeIntrospector` is an implementation of `TypeIntrospector` which +// handles well known types that are treated specially by CEL. +// +// This also serves as a minimal implementation of a TypeInstrospector when no +// custom types are present. +// +// This class has no mutable state, so trivially thread-safe. +class WellKnownTypeIntrospector : public virtual TypeIntrospector { + public: + WellKnownTypeIntrospector() = default; + + private: + absl::StatusOr> FindTypeImpl( + absl::string_view name) const final { + return FindWellKnownType(name); + } + + absl::StatusOr> FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const final { + return FindWellKnownTypeEnumConstant(type, value); + } + + absl::StatusOr> FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const final { + return FindWellKnownTypeFieldByName(type, name); + } +}; + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_INTROSPECTOR_H_ diff --git a/common/types/thread_compatible_type_introspector.h b/common/types/thread_compatible_type_introspector.h deleted file mode 100644 index 870ea9054..000000000 --- a/common/types/thread_compatible_type_introspector.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_THREAD_COMPATIBLE_TYPE_INTROSPECTOR_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_THREAD_COMPATIBLE_TYPE_INTROSPECTOR_H_ - -#include "common/type_introspector.h" - -namespace cel::common_internal { - -// `ThreadCompatibleTypeIntrospector` is a basic implementation of -// `TypeIntrospector` which is thread compatible. By default this implementation -// just returns `NOT_FOUND` for most methods. -class ThreadCompatibleTypeIntrospector : public virtual TypeIntrospector { - public: - ThreadCompatibleTypeIntrospector() = default; -}; - -} // namespace cel::common_internal - -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_THREAD_COMPATIBLE_TYPE_INTROSPECTOR_H_ diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 7d770b443..aa4d0b4e5 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -23,12 +23,10 @@ #include #include "absl/base/nullability.h" -#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "base/ast.h" #include "base/type_provider.h" -#include "common/value.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/eval/evaluator_core.h" #include "runtime/function_registry.h" diff --git a/eval/public/structs/legacy_type_provider.cc b/eval/public/structs/legacy_type_provider.cc index a85f08911..f87ab9645 100644 --- a/eval/public/structs/legacy_type_provider.cc +++ b/eval/public/structs/legacy_type_provider.cc @@ -27,6 +27,7 @@ #include "common/legacy_value.h" #include "common/memory.h" #include "common/type.h" +#include "common/type_introspector.h" #include "common/value.h" #include "eval/public/message_wrapper.h" #include "eval/public/structs/legacy_type_adapter.h" @@ -175,6 +176,9 @@ LegacyTypeProvider::NewValueBuilder( absl::StatusOr> LegacyTypeProvider::FindTypeImpl( absl::string_view name) const { + if (auto type = cel::FindWellKnownType(name); type.has_value()) { + return type; + } if (auto type_info = ProvideLegacyTypeInfo(name); type_info.has_value()) { const auto* descriptor = (*type_info)->GetDescriptor(MessageWrapper()); if (descriptor != nullptr) { @@ -189,6 +193,10 @@ absl::StatusOr> LegacyTypeProvider::FindTypeImpl( absl::StatusOr> LegacyTypeProvider::FindStructTypeFieldByNameImpl( absl::string_view type, absl::string_view name) const { + if (auto result = cel::FindWellKnownTypeFieldByName(type, name); + result.has_value()) { + return result; + } if (auto type_info = ProvideLegacyTypeInfo(type); type_info.has_value()) { if (auto field_desc = (*type_info)->FindFieldByName(name); field_desc.has_value()) { diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 6c3f654f9..3f4081b09 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -87,48 +87,12 @@ cc_library( ], ) -cc_library( - name = "type", - srcs = [ - "type_introspector.cc", - ], - hdrs = [ - "type_introspector.h", - ], - deps = [ - "//common:type", - "@com_google_absl//absl/base:nullability", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:optional", - "@com_google_protobuf//:protobuf", - ], -) - -cc_test( - name = "type_test", - srcs = [ - "type_introspector_test.cc", - ], - deps = [ - ":type", - "//common:type", - "//common:type_kind", - "//internal:testing", - "@com_google_absl//absl/types:optional", - "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", - "@com_google_protobuf//:protobuf", - ], -) - cc_library( name = "value", hdrs = [ - "type_reflector.h", "value.h", ], deps = [ - ":type", "//common:memory", "//common:type", "//common:value", diff --git a/extensions/protobuf/type_introspector.cc b/extensions/protobuf/type_introspector.cc deleted file mode 100644 index 8b445c359..000000000 --- a/extensions/protobuf/type_introspector.cc +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "extensions/protobuf/type_introspector.h" - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "common/type.h" -#include "common/type_introspector.h" -#include "google/protobuf/descriptor.h" - -namespace cel::extensions { - -absl::StatusOr> ProtoTypeIntrospector::FindTypeImpl( - absl::string_view name) const { - // We do not have to worry about well known types here. - // `TypeIntrospector::FindType` handles those directly. - const auto* desc = descriptor_pool()->FindMessageTypeByName(name); - if (desc == nullptr) { - return absl::nullopt; - } - return MessageType(desc); -} - -absl::StatusOr> -ProtoTypeIntrospector::FindEnumConstantImpl(absl::string_view type, - absl::string_view value) const { - const google::protobuf::EnumDescriptor* enum_desc = - descriptor_pool()->FindEnumTypeByName(type); - // google.protobuf.NullValue is special cased in the base class. - if (enum_desc == nullptr) { - return absl::nullopt; - } - - // Note: we don't support strong enum typing at this time so only the fully - // qualified enum values are meaningful, so we don't provide any signal if the - // enum type is found but can't match the value name. - const google::protobuf::EnumValueDescriptor* value_desc = - enum_desc->FindValueByName(value); - if (value_desc == nullptr) { - return absl::nullopt; - } - - return TypeIntrospector::EnumConstant{ - EnumType(enum_desc), enum_desc->full_name(), value_desc->name(), - value_desc->number()}; -} - -absl::StatusOr> -ProtoTypeIntrospector::FindStructTypeFieldByNameImpl( - absl::string_view type, absl::string_view name) const { - // We do not have to worry about well known types here. - // `TypeIntrospector::FindStructTypeFieldByName` handles those directly. - const auto* desc = descriptor_pool()->FindMessageTypeByName(type); - if (desc == nullptr) { - return absl::nullopt; - } - const auto* field_desc = desc->FindFieldByName(name); - if (field_desc == nullptr) { - field_desc = descriptor_pool()->FindExtensionByPrintableName(desc, name); - if (field_desc == nullptr) { - return absl::nullopt; - } - } - return MessageTypeField(field_desc); -} - -} // namespace cel::extensions diff --git a/extensions/protobuf/type_introspector.h b/extensions/protobuf/type_introspector.h deleted file mode 100644 index 5eb9c3ddc..000000000 --- a/extensions/protobuf/type_introspector.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_INTROSPECTOR_H_ -#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_INTROSPECTOR_H_ - -#include "absl/base/nullability.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "common/type.h" -#include "common/type_introspector.h" -#include "google/protobuf/descriptor.h" - -namespace cel::extensions { - -class ProtoTypeIntrospector : public virtual TypeIntrospector { - public: - ProtoTypeIntrospector() - : ProtoTypeIntrospector(google::protobuf::DescriptorPool::generated_pool()) {} - - explicit ProtoTypeIntrospector( - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) - : descriptor_pool_(descriptor_pool) {} - - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { - return descriptor_pool_; - } - - protected: - absl::StatusOr> FindTypeImpl( - absl::string_view name) const final; - - absl::StatusOr> - FindEnumConstantImpl(absl::string_view type, - absl::string_view value) const final; - - absl::StatusOr> FindStructTypeFieldByNameImpl( - absl::string_view type, absl::string_view name) const final; - - private: - const google::protobuf::DescriptorPool* absl_nonnull const descriptor_pool_; -}; - -} // namespace cel::extensions - -#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_INTROSPECTOR_H_ diff --git a/extensions/protobuf/type_introspector_test.cc b/extensions/protobuf/type_introspector_test.cc deleted file mode 100644 index 0a7b21524..000000000 --- a/extensions/protobuf/type_introspector_test.cc +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "extensions/protobuf/type_introspector.h" - -#include "absl/types/optional.h" -#include "common/type.h" -#include "common/type_kind.h" -#include "internal/testing.h" -#include "cel/expr/conformance/proto2/test_all_types.pb.h" -#include "google/protobuf/descriptor.h" - -namespace cel::extensions { -namespace { - -using ::absl_testing::IsOkAndHolds; -using ::cel::expr::conformance::proto2::TestAllTypes; -using ::testing::Eq; -using ::testing::Optional; - -TEST(ProtoTypeIntrospector, FindType) { - ProtoTypeIntrospector introspector; - EXPECT_THAT( - introspector.FindType(TestAllTypes::descriptor()->full_name()), - IsOkAndHolds(Optional(Eq(MessageType(TestAllTypes::GetDescriptor()))))); - EXPECT_THAT(introspector.FindType("type.that.does.not.Exist"), - IsOkAndHolds(Eq(absl::nullopt))); -} - -TEST(ProtoTypeIntrospector, FindStructTypeFieldByName) { - ProtoTypeIntrospector introspector; - ASSERT_OK_AND_ASSIGN( - auto field, introspector.FindStructTypeFieldByName( - TestAllTypes::descriptor()->full_name(), "single_int32")); - ASSERT_TRUE(field.has_value()); - EXPECT_THAT(field->name(), Eq("single_int32")); - EXPECT_THAT(field->number(), Eq(1)); - EXPECT_THAT( - introspector.FindStructTypeFieldByName( - TestAllTypes::descriptor()->full_name(), "field_that_does_not_exist"), - IsOkAndHolds(Eq(absl::nullopt))); - EXPECT_THAT(introspector.FindStructTypeFieldByName("type.that.does.not.Exist", - "does_not_matter"), - IsOkAndHolds(Eq(absl::nullopt))); -} - -TEST(ProtoTypeIntrospector, FindEnumConstant) { - ProtoTypeIntrospector introspector; - const auto* enum_desc = TestAllTypes::NestedEnum_descriptor(); - ASSERT_OK_AND_ASSIGN( - auto enum_constant, - introspector.FindEnumConstant( - "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", "BAZ")); - ASSERT_TRUE(enum_constant.has_value()); - EXPECT_EQ(enum_constant->type.kind(), TypeKind::kEnum); - EXPECT_EQ(enum_constant->type_full_name, enum_desc->full_name()); - EXPECT_EQ(enum_constant->value_name, "BAZ"); - EXPECT_EQ(enum_constant->number, 2); -} - -TEST(ProtoTypeIntrospector, FindEnumConstantNull) { - ProtoTypeIntrospector introspector; - ASSERT_OK_AND_ASSIGN( - auto enum_constant, - introspector.FindEnumConstant("google.protobuf.NullValue", "NULL_VALUE")); - ASSERT_TRUE(enum_constant.has_value()); - EXPECT_EQ(enum_constant->type.kind(), TypeKind::kNull); - EXPECT_EQ(enum_constant->type_full_name, "google.protobuf.NullValue"); - EXPECT_EQ(enum_constant->value_name, "NULL_VALUE"); - EXPECT_EQ(enum_constant->number, 0); -} - -TEST(ProtoTypeIntrospector, FindEnumConstantUnknownEnum) { - ProtoTypeIntrospector introspector; - - ASSERT_OK_AND_ASSIGN(auto enum_constant, - introspector.FindEnumConstant("NotARealEnum", "BAZ")); - EXPECT_FALSE(enum_constant.has_value()); -} - -TEST(ProtoTypeIntrospector, FindEnumConstantUnknownValue) { - ProtoTypeIntrospector introspector; - - ASSERT_OK_AND_ASSIGN( - auto enum_constant, - introspector.FindEnumConstant( - "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", "QUX")); - ASSERT_FALSE(enum_constant.has_value()); -} - -} // namespace -} // namespace cel::extensions diff --git a/extensions/protobuf/type_reflector.h b/extensions/protobuf/type_reflector.h deleted file mode 100644 index 4665235fe..000000000 --- a/extensions/protobuf/type_reflector.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_REFLECTOR_H_ -#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_REFLECTOR_H_ - -#include "absl/base/nullability.h" -#include "common/type_reflector.h" -#include "extensions/protobuf/type_introspector.h" -#include "google/protobuf/descriptor.h" - -namespace cel::extensions { - -class ProtoTypeReflector : public TypeReflector, public ProtoTypeIntrospector { - public: - ProtoTypeReflector() - : ProtoTypeReflector(google::protobuf::DescriptorPool::generated_pool()) {} - - explicit ProtoTypeReflector( - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) - : ProtoTypeIntrospector(descriptor_pool) {} - - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { - return ProtoTypeIntrospector::descriptor_pool(); - } -}; - -} // namespace cel::extensions - -#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_REFLECTOR_H_ diff --git a/runtime/internal/BUILD b/runtime/internal/BUILD index 28f9bd1cb..1223ff6d1 100644 --- a/runtime/internal/BUILD +++ b/runtime/internal/BUILD @@ -195,6 +195,7 @@ cc_library( deps = [ "//common:type", "//common:value", + "//internal:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", diff --git a/runtime/internal/runtime_type_provider.cc b/runtime/internal/runtime_type_provider.cc index 1acb52223..40f5ff575 100644 --- a/runtime/internal/runtime_type_provider.cc +++ b/runtime/internal/runtime_type_provider.cc @@ -44,8 +44,10 @@ absl::Status RuntimeTypeProvider::RegisterType(const OpaqueType& type) { absl::StatusOr> RuntimeTypeProvider::FindTypeImpl( absl::string_view name) const { - // We do not have to worry about well known types here. - // `TypeIntrospector::FindType` handles those directly. + auto type = FindWellKnownType(name); + if (type.has_value()) { + return type; + } const auto* desc = descriptor_pool_->FindMessageTypeByName(name); if (desc != nullptr) { return MessageType(desc); @@ -60,9 +62,12 @@ absl::StatusOr> RuntimeTypeProvider::FindTypeImpl( absl::StatusOr> RuntimeTypeProvider::FindEnumConstantImpl(absl::string_view type, absl::string_view value) const { + auto enum_constant = FindWellKnownTypeEnumConstant(type, value); + if (enum_constant.has_value()) { + return enum_constant; + } const google::protobuf::EnumDescriptor* enum_desc = descriptor_pool_->FindEnumTypeByName(type); - // google.protobuf.NullValue is special cased in the base class. if (enum_desc == nullptr) { return absl::nullopt; } @@ -84,8 +89,10 @@ RuntimeTypeProvider::FindEnumConstantImpl(absl::string_view type, absl::StatusOr> RuntimeTypeProvider::FindStructTypeFieldByNameImpl( absl::string_view type, absl::string_view name) const { - // We do not have to worry about well known types here. - // `TypeIntrospector::FindStructTypeFieldByName` handles those directly. + auto field = FindWellKnownTypeFieldByName(type, name); + if (field.has_value()) { + return field; + } const auto* desc = descriptor_pool_->FindMessageTypeByName(type); if (desc == nullptr) { return absl::nullopt; From 6e1cb5311aa17ddc2ed9e4fd30e042d4c592b359 Mon Sep 17 00:00:00 2001 From: Justin King Date: Wed, 1 Apr 2026 10:27:42 -0700 Subject: [PATCH 020/143] Fix compatibility with newer versions of protobuf PiperOrigin-RevId: 893000224 --- internal/well_known_types.cc | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index f66a9360b..c736be69f 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -71,6 +71,17 @@ using ::google::protobuf::util::TimeUtil; using CppStringType = ::google::protobuf::FieldDescriptor::CppStringType; +FieldDescriptor::Label GetFieldLabel( + const FieldDescriptor* absl_nonnull field) { + if (field->is_required()) { + return FieldDescriptor::LABEL_REQUIRED; + } else if (field->is_repeated()) { + return FieldDescriptor::LABEL_REPEATED; + } else { + return FieldDescriptor::LABEL_OPTIONAL; + } +} + absl::string_view FlatStringValue( const StringValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { @@ -264,11 +275,11 @@ absl::string_view LabelToString(FieldDescriptor::Label label) { absl::Status CheckFieldCardinality(const FieldDescriptor* absl_nonnull field, FieldDescriptor::Label label) { - if (ABSL_PREDICT_FALSE(field->label() != label)) { - return absl::InvalidArgumentError( - absl::StrCat("unexpected field cardinality for protocol buffer message " - "well known type: ", - field->full_name(), " ", LabelToString(field->label()))); + if (ABSL_PREDICT_FALSE(GetFieldLabel(field) != label)) { + return absl::InvalidArgumentError(absl::StrCat( + "unexpected field cardinality for protocol buffer message " + "well known type: ", + field->full_name(), " ", LabelToString(GetFieldLabel(field)))); } return absl::OkStatus(); } From cf07a8775491c705739416adcb3a77f3cf5c0916 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 2 Apr 2026 10:14:53 -0700 Subject: [PATCH 021/143] Expose source_info position -> SourceLocation helper. PiperOrigin-RevId: 893571451 --- checker/internal/type_checker_impl.cc | 58 +++++---------------------- common/BUILD | 2 + common/ast.cc | 38 ++++++++++++++++++ common/ast.h | 8 ++++ common/ast_test.cc | 52 ++++++++++++++++++++++++ 5 files changed, 110 insertions(+), 48 deletions(-) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 1e9995b19..28cbf21e0 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -48,7 +48,6 @@ #include "common/constant.h" #include "common/decl.h" #include "common/expr.h" -#include "common/source.h" #include "common/type.h" #include "common/type_kind.h" #include "internal/status_macros.h" @@ -66,43 +65,6 @@ std::string FormatCandidate(absl::Span qualifiers) { return absl::StrJoin(qualifiers, "."); } -SourceLocation ComputeSourceLocation(const Ast& ast, int64_t expr_id) { - const auto& source_info = ast.source_info(); - auto iter = source_info.positions().find(expr_id); - if (iter == source_info.positions().end()) { - return SourceLocation{}; - } - int32_t absolute_position = iter->second; - if (absolute_position < 0) { - return SourceLocation{}; - } - - // Find the first line offset that is greater than the absolute position. - int32_t line_idx = -1; - int32_t offset = 0; - for (int32_t i = 0; i < source_info.line_offsets().size(); ++i) { - int32_t next_offset = source_info.line_offsets()[i]; - if (next_offset <= offset) { - // Line offset is not monotonically increasing, so line information is - // invalid. - return SourceLocation{}; - } - if (absolute_position < next_offset) { - line_idx = i; - break; - } - offset = next_offset; - } - - if (line_idx < 0 || line_idx >= source_info.line_offsets().size()) { - return SourceLocation{}; - } - - int32_t rel_position = absolute_position - offset; - - return SourceLocation{line_idx + 1, rel_position}; -} - // Flatten the type to the AST type representation to remove any lifecycle // dependency between the type check environment and the AST. // @@ -362,7 +324,7 @@ class ResolveVisitor : public AstVisitorBase { void ReportMissingReference(const Expr& expr, absl::string_view name) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("undeclared reference to '", name, "' (in container '", container_, "')"))); } @@ -370,7 +332,7 @@ class ResolveVisitor : public AstVisitorBase { void ReportUndefinedField(int64_t expr_id, absl::string_view field_name, absl::string_view struct_name) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr_id), + ast_->ComputeSourceLocation(expr_id), absl::StrCat("undefined field '", field_name, "' not found in struct '", struct_name, "'"))); } @@ -378,7 +340,7 @@ class ResolveVisitor : public AstVisitorBase { void ReportTypeMismatch(int64_t expr_id, const Type& expected, const Type& actual) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr_id), + ast_->ComputeSourceLocation(expr_id), absl::StrCat("expected type '", FormatTypeName(inference_context_->FinalizeType(expected)), "' but found '", @@ -408,7 +370,7 @@ class ResolveVisitor : public AstVisitorBase { } if (!inference_context_->IsAssignable(value_type, field_type)) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, field.id()), + ast_->ComputeSourceLocation(field.id()), absl::StrCat( "expected type of field '", field_info->name(), "' is '", FormatTypeName(inference_context_->FinalizeType(field_type)), @@ -553,7 +515,7 @@ void ResolveVisitor::PostVisitConst(const Expr& expr, break; default: ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("unsupported constant type: ", constant.kind().index()))); types_[&expr] = ErrorType(); @@ -605,7 +567,7 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { // To match the Go implementation, we just warn here, but in the future // we should consider making this an error. ReportIssue(TypeCheckIssue( - Severity::kWarning, ComputeSourceLocation(*ast_, key->id()), + Severity::kWarning, ast_->ComputeSourceLocation(key->id()), absl::StrCat( "unsupported map key type: ", FormatTypeName(inference_context_->FinalizeType(key_type))))); @@ -711,7 +673,7 @@ void ResolveVisitor::PostVisitStruct(const Expr& expr, if (resolved_type.kind() != TypeKind::kStruct && !IsWellKnownMessageType(resolved_name)) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("type '", resolved_name, "' does not support message creation"))); types_[&expr] = ErrorType(); @@ -862,7 +824,7 @@ void ResolveVisitor::PostVisitComprehensionSubexpression( break; default: ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, comprehension.iter_range().id()), + ast_->ComputeSourceLocation(comprehension.iter_range().id()), absl::StrCat( "expression of type '", FormatTypeName(inference_context_->FinalizeType(range_type)), @@ -933,7 +895,7 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, if (!resolution.has_value()) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("found no matching overload for '", decl.name(), "' applied to '(", absl::StrJoin(arg_types, ", ", @@ -1133,7 +1095,7 @@ absl::optional ResolveVisitor::CheckFieldType(int64_t id, } ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, id), + ast_->ComputeSourceLocation(id), absl::StrCat( "expression of type '", FormatTypeName(inference_context_->FinalizeType(operand_type)), diff --git a/common/BUILD b/common/BUILD index 8dd8921cc..e289ef413 100644 --- a/common/BUILD +++ b/common/BUILD @@ -25,6 +25,7 @@ cc_library( hdrs = ["ast.h"], deps = [ ":expr", + ":source", "//common/ast:metadata", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", @@ -39,6 +40,7 @@ cc_test( deps = [ ":ast", ":expr", + ":source", "//internal:testing", "@com_google_absl//absl/container:flat_hash_map", ], diff --git a/common/ast.cc b/common/ast.cc index aea153197..48b6f5e0b 100644 --- a/common/ast.cc +++ b/common/ast.cc @@ -19,6 +19,7 @@ #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "common/ast/metadata.h" +#include "common/source.h" namespace cel { namespace { @@ -57,4 +58,41 @@ const Reference* absl_nullable Ast::GetReference(int64_t expr_id) const { return &iter->second; } +SourceLocation Ast::ComputeSourceLocation(int64_t expr_id) const { + const auto& source_info = this->source_info(); + auto iter = source_info.positions().find(expr_id); + if (iter == source_info.positions().end()) { + return SourceLocation{}; + } + int32_t absolute_position = iter->second; + if (absolute_position < 0) { + return SourceLocation{}; + } + + // Find the first line offset that is greater than the absolute position. + int32_t line_idx = -1; + int32_t offset = 0; + for (int32_t i = 0; i < source_info.line_offsets().size(); ++i) { + int32_t next_offset = source_info.line_offsets()[i]; + if (next_offset <= offset) { + // Line offset is not monotonically increasing, so line information is + // invalid. + return SourceLocation{}; + } + if (absolute_position < next_offset) { + line_idx = i; + break; + } + offset = next_offset; + } + + if (line_idx < 0 || line_idx >= source_info.line_offsets().size()) { + return SourceLocation{}; + } + + int32_t rel_position = absolute_position - offset; + + return SourceLocation{line_idx + 1, rel_position}; +} + } // namespace cel diff --git a/common/ast.h b/common/ast.h index 1b07b9878..db336f52d 100644 --- a/common/ast.h +++ b/common/ast.h @@ -24,6 +24,7 @@ #include "absl/strings/string_view.h" #include "common/ast/metadata.h" // IWYU pragma: export #include "common/expr.h" +#include "common/source.h" namespace cel { @@ -135,6 +136,13 @@ class Ast final { expr_version_ = expr_version; } + // Computes the source location (line and column) for the given expression id + // from the source info (which stores absolute positions). + // + // Returns a default (empty) source location if the expression id is not found + // or the source info is not populated correctly. + SourceLocation ComputeSourceLocation(int64_t expr_id) const; + private: Expr root_expr_; SourceInfo source_info_; diff --git a/common/ast_test.cc b/common/ast_test.cc index 744b9e8d3..56e1bcd1e 100644 --- a/common/ast_test.cc +++ b/common/ast_test.cc @@ -18,6 +18,7 @@ #include "absl/container/flat_hash_map.h" #include "common/expr.h" +#include "common/source.h" #include "internal/testing.h" namespace cel { @@ -132,5 +133,56 @@ TEST(AstImpl, CheckedExprDeepCopy) { EXPECT_EQ(ast.source_info().syntax_version(), "1.0"); } +TEST(AstImpl, ComputeSourceLocation) { + SourceInfo source_info; + source_info.set_line_offsets({10, 20, 30}); + source_info.mutable_positions()[1] = 0; // Start of first line + source_info.mutable_positions()[2] = 5; // Middle of first line + source_info.mutable_positions()[3] = 10; // ... + source_info.mutable_positions()[4] = 15; + source_info.mutable_positions()[5] = 20; + source_info.mutable_positions()[6] = 25; + + Ast ast(Expr{}, std::move(source_info)); + + EXPECT_EQ(ast.ComputeSourceLocation(1), (SourceLocation{1, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(2), (SourceLocation{1, 5})); + EXPECT_EQ(ast.ComputeSourceLocation(3), (SourceLocation{2, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(4), (SourceLocation{2, 5})); + EXPECT_EQ(ast.ComputeSourceLocation(5), (SourceLocation{3, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(6), (SourceLocation{3, 5})); +} + +TEST(AstImpl, ComputeSourceLocationFailures) { + SourceInfo source_info; + source_info.set_line_offsets({10, 20}); + source_info.mutable_positions()[1] = -1; // Negative position + source_info.mutable_positions()[2] = 25; // Beyond last line offset + // ID 3 is missing + + Ast ast; + ast.mutable_source_info() = std::move(source_info); + + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + EXPECT_EQ(ast.ComputeSourceLocation(2), SourceLocation{}); + EXPECT_EQ(ast.ComputeSourceLocation(3), SourceLocation{}); +} + +TEST(AstImpl, ComputeSourceLocationInvalidLineOffsets) { + { + // Empty line offsets + Ast ast; + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + } + { + // Non-monotonic + SourceInfo source_info; + source_info.set_line_offsets({10, 5}); + source_info.mutable_positions()[1] = 12; + Ast ast(Expr{}, std::move(source_info)); + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + } +} + } // namespace } // namespace cel From 6bf474b38bcd754488e44ef30f9c61031ee47b37 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 2 Apr 2026 10:57:13 -0700 Subject: [PATCH 022/143] No public description PiperOrigin-RevId: 893593905 --- common/optional_ref.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/common/optional_ref.h b/common/optional_ref.h index 454926c80..c7ba580fc 100644 --- a/common/optional_ref.h +++ b/common/optional_ref.h @@ -84,7 +84,12 @@ class optional_ref final { constexpr T& value() const { return ABSL_PREDICT_TRUE(has_value()) ? *value_ - : (absl::optional().value(), *value_); + // Replicate the same error logic as in `absl::optional`'s + // `value()`. It either throws an exception or aborts the + // program. We intentionally ignore the return value of + // the constructed optional's value as we only need to run + // the code for error checking. + : ((void)absl::optional().value(), *value_); } constexpr T& operator*() const { From 74e7666b01c14dad63052d7e0b6c2a7673161e8d Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 2 Apr 2026 13:46:30 -0700 Subject: [PATCH 023/143] Initial implementation for Ast Validator Adds a new `Validator` type for applying semantic checks on a compiled expresion. Adds timestamp and duration literal validators as examples for `Validations`. PiperOrigin-RevId: 893679645 --- checker/validation_result.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/checker/validation_result.h b/checker/validation_result.h index 45f949739..8c84a84da 100644 --- a/checker/validation_result.h +++ b/checker/validation_result.h @@ -58,6 +58,8 @@ class ValidationResult { absl::Span GetIssues() const { return issues_; } + void AddIssue(TypeCheckIssue issue) { issues_.push_back(std::move(issue)); } + // The source expression may optionally be set if it is available. const cel::Source* absl_nullable GetSource() const { return source_.get(); } From 082088649a2af3342a6fbdcb8f338d4c6557dc70 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 2 Apr 2026 15:30:04 -0700 Subject: [PATCH 024/143] Add proto type introspector to C++ type checker. Intended to mediate field lookups so the checker follows any options consistently. Adds support for resolving fields via JSON name. PiperOrigin-RevId: 893734042 --- checker/BUILD | 1 + checker/checker_options.h | 8 + checker/internal/BUILD | 36 ++- .../descriptor_pool_type_introspector.cc | 245 ++++++++++++++++++ .../descriptor_pool_type_introspector.h | 105 ++++++++ .../descriptor_pool_type_introspector_test.cc | 175 +++++++++++++ checker/internal/test_ast_helpers.cc | 22 +- checker/internal/type_check_env.cc | 88 ++----- checker/internal/type_check_env.h | 27 +- checker/internal/type_checker_builder_impl.cc | 57 +++- .../type_checker_builder_impl_test.cc | 1 - checker/internal/type_checker_impl.cc | 6 + checker/type_checker_builder_factory_test.cc | 108 ++++++++ common/type_introspector.cc | 24 +- common/type_introspector.h | 33 +++ internal/BUILD | 1 + testutil/BUILD | 6 + testutil/test_json_names.proto | 31 +++ 18 files changed, 869 insertions(+), 105 deletions(-) create mode 100644 checker/internal/descriptor_pool_type_introspector.cc create mode 100644 checker/internal/descriptor_pool_type_introspector.h create mode 100644 checker/internal/descriptor_pool_type_introspector_test.cc create mode 100644 testutil/test_json_names.proto diff --git a/checker/BUILD b/checker/BUILD index 42e37e81d..d5eb3601c 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -128,6 +128,7 @@ cc_test( ":type_checker_builder_factory", ":validation_result", "//checker/internal:test_ast_helpers", + "//common:ast", "//common:decl", "//common:type", "//internal:status_macros", diff --git a/checker/checker_options.h b/checker/checker_options.h index 0b6d1af7f..cb85337fa 100644 --- a/checker/checker_options.h +++ b/checker/checker_options.h @@ -95,6 +95,14 @@ struct CheckerOptions { // Temporary flag to allow rolling out the change. No functional changes to // evaluation behavior in either mode. bool enable_function_name_in_reference = true; + + // If true, the checker will use the proto json field names for protobuf + // messages. Unlike protojson parsers, it will not accept the standard proto + // field names as valid json field names. + // + // Note: The checked AST will contain the json field names and an extension + // tag, but will require runtime support for resolving the json field names. + bool use_json_field_names = false; }; } // namespace cel diff --git a/checker/internal/BUILD b/checker/internal/BUILD index c539a2cc9..3f64417a0 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -27,10 +27,11 @@ cc_library( hdrs = ["test_ast_helpers.h"], deps = [ "//common:ast", - "//extensions/protobuf:ast_converters", "//internal:status_macros", "//parser", "//parser:options", + "//parser:parser_interface", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], @@ -64,6 +65,7 @@ cc_library( srcs = ["type_check_env.cc"], hdrs = ["type_check_env.h"], deps = [ + ":descriptor_pool_type_introspector", "//common:constant", "//common:decl", "//common:type", @@ -118,6 +120,7 @@ cc_library( "type_checker_impl.h", ], deps = [ + ":descriptor_pool_type_introspector", ":format_type_name", ":namespace_generator", ":type_check_env", @@ -261,14 +264,35 @@ cc_library( ], ) +cc_library( + name = "descriptor_pool_type_introspector", + srcs = ["descriptor_pool_type_introspector.cc"], + hdrs = ["descriptor_pool_type_introspector.h"], + deps = [ + "//common:type", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( - name = "format_type_name_test", - srcs = ["format_type_name_test.cc"], + name = "descriptor_pool_type_introspector_test", + srcs = ["descriptor_pool_type_introspector_test.cc"], deps = [ - ":format_type_name", + ":descriptor_pool_type_introspector", "//common:type", "//internal:testing", - "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", - "@com_google_protobuf//:protobuf", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", ], ) diff --git a/checker/internal/descriptor_pool_type_introspector.cc b/checker/internal/descriptor_pool_type_introspector.cc new file mode 100644 index 000000000..f6001e947 --- /dev/null +++ b/checker/internal/descriptor_pool_type_introspector.cc @@ -0,0 +1,245 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/descriptor_pool_type_introspector.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { +namespace { + +// Standard implementation for field lookups. +// Avoids building a FieldTable and just checks the DescriptorPool directly. +absl::StatusOr> +FindStructTypeFieldByNameDirectly( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type, absl::string_view name) { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool->FindMessageTypeByName(type); + if (descriptor == nullptr) { + return absl::nullopt; + } + const google::protobuf::FieldDescriptor* absl_nullable field = + descriptor->FindFieldByName(name); + if (field != nullptr) { + return StructTypeField(MessageTypeField(field)); + } + + field = descriptor_pool->FindExtensionByPrintableName(descriptor, name); + if (field != nullptr) { + return StructTypeField(MessageTypeField(field)); + } + return absl::nullopt; +} + +// Standard implementation for listing fields. +// Avoids building a FieldTable and just checks the DescriptorPool directly. +absl::StatusOr< + absl::optional>> +ListStructTypeFieldsDirectly( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type) { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool->FindMessageTypeByName(type); + if (descriptor == nullptr) { + return absl::nullopt; + } + + std::vector extensions; + descriptor_pool->FindAllExtensions(descriptor, &extensions); + + std::vector fields; + fields.reserve(descriptor->field_count() + extensions.size()); + + for (int i = 0; i < descriptor->field_count(); ++i) { + const google::protobuf::FieldDescriptor* field = descriptor->field(i); + fields.push_back({field->name(), StructTypeField(MessageTypeField(field))}); + } + + return fields; +} + +} // namespace + +using Field = DescriptorPoolTypeIntrospector::Field; + +absl::StatusOr> +DescriptorPoolTypeIntrospector::FindTypeImpl(absl::string_view name) const { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool_->FindMessageTypeByName(name); + if (descriptor != nullptr) { + return Type::Message(descriptor); + } + const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = + descriptor_pool_->FindEnumTypeByName(name); + if (enum_descriptor != nullptr) { + return Type::Enum(enum_descriptor); + } + return absl::nullopt; +} + +absl::StatusOr> +DescriptorPoolTypeIntrospector::FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const { + const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = + descriptor_pool_->FindEnumTypeByName(type); + if (enum_descriptor != nullptr) { + const google::protobuf::EnumValueDescriptor* absl_nullable enum_value_descriptor = + enum_descriptor->FindValueByName(value); + if (enum_value_descriptor == nullptr) { + return absl::nullopt; + } + return EnumConstant{ + .type = Type::Enum(enum_descriptor), + .type_full_name = enum_descriptor->full_name(), + .value_name = enum_value_descriptor->name(), + .number = enum_value_descriptor->number(), + }; + } + return absl::nullopt; +} + +absl::StatusOr> +DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const { + if (!use_json_name_) { + return FindStructTypeFieldByNameDirectly(descriptor_pool_, type, name); + } + + const FieldTable* field_table = GetFieldTable(type); + + if (field_table == nullptr) { + return absl::nullopt; + } + + if (auto it = field_table->json_name_map.find(name); + it != field_table->json_name_map.end()) { + return field_table->fields[it->second].field; + } + + if (auto it = field_table->extension_name_map.find(name); + it != field_table->extension_name_map.end()) { + return field_table->fields[it->second].field; + } + + return absl::nullopt; +} + +absl::StatusOr< + absl::optional>> +DescriptorPoolTypeIntrospector::ListFieldsForStructTypeImpl( + absl::string_view type) const { + if (!use_json_name_) { + return ListStructTypeFieldsDirectly(descriptor_pool_, type); + } + + const FieldTable* field_table = GetFieldTable(type); + if (field_table == nullptr) { + return absl::nullopt; + } + std::vector fields; + fields.reserve(field_table->non_extensions.size()); + for (const auto& field : field_table->non_extensions) { + fields.push_back({field.json_name, field.field}); + } + return fields; +} + +const DescriptorPoolTypeIntrospector::FieldTable* +DescriptorPoolTypeIntrospector::GetFieldTable( + absl::string_view type_name) const { + absl::MutexLock lock(mu_); + if (auto it = field_tables_.find(type_name); it != field_tables_.end()) { + return it->second.get(); + } + if (cel::IsWellKnownMessageType(type_name)) { + return nullptr; + } + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool_->FindMessageTypeByName(type_name); + if (descriptor == nullptr) { + return nullptr; + } + absl::string_view stable_type_name = descriptor->full_name(); + ABSL_DCHECK(stable_type_name == type_name); + std::unique_ptr field_table = CreateFieldTable(descriptor); + const FieldTable* field_table_ptr = field_table.get(); + field_tables_[stable_type_name] = std::move(field_table); + return field_table_ptr; +} + +std::unique_ptr +DescriptorPoolTypeIntrospector::CreateFieldTable( + const google::protobuf::Descriptor* absl_nonnull descriptor) const { + ABSL_DCHECK(!IsWellKnownMessageType(descriptor)); + std::vector fields; + absl::flat_hash_map json_name_map; + absl::flat_hash_map field_name_map; + absl::flat_hash_map extension_name_map; + + std::vector extensions; + descriptor_pool_->FindAllExtensions(descriptor, &extensions); + fields.reserve(descriptor->field_count() + extensions.size()); + + for (int i = 0; i < descriptor->field_count(); i++) { + const google::protobuf::FieldDescriptor* field = descriptor->field(i); + fields.push_back(Field{ + .field = StructTypeField(MessageTypeField(field)), + .json_name = field->json_name(), + .is_extension = false, + }); + field_name_map[field->name()] = fields.size() - 1; + if (use_json_name_ && !field->json_name().empty()) { + json_name_map[field->json_name()] = fields.size() - 1; + } + } + int non_extension_count = fields.size(); + + for (const google::protobuf::FieldDescriptor* extension : extensions) { + fields.push_back(Field{ + .field = StructTypeField(MessageTypeField(extension)), + .json_name = "", + .is_extension = true, + }); + extension_name_map[extension->full_name()] = fields.size() - 1; + } + int extension_count = fields.size() - non_extension_count; + auto result = std::make_unique(); + result->descriptor = descriptor; + result->fields = std::move(fields); + result->non_extensions = + absl::MakeConstSpan(result->fields).subspan(0, non_extension_count); + result->extensions = absl::MakeConstSpan(result->fields) + .subspan(non_extension_count, extension_count); + result->json_name_map = std::move(json_name_map); + result->field_name_map = std::move(field_name_map); + result->extension_name_map = std::move(extension_name_map); + return result; +} + +} // namespace cel::checker_internal diff --git a/checker/internal/descriptor_pool_type_introspector.h b/checker/internal/descriptor_pool_type_introspector.h new file mode 100644 index 000000000..8a970ea00 --- /dev/null +++ b/checker/internal/descriptor_pool_type_introspector.h @@ -0,0 +1,105 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Implementation of `TypeIntrospector` that uses a `google::protobuf::DescriptorPool`. +// +// This is used by the type checker to resolve protobuf types and their fields +// and apply any options like using JSON names. +// +// Neither copyable nor movable. Should be managed by a TypeCheckEnv. +class DescriptorPoolTypeIntrospector : public TypeIntrospector { + public: + struct Field { + StructTypeField field; + absl::string_view json_name; + bool is_extension = false; + }; + + DescriptorPoolTypeIntrospector() = delete; + explicit DescriptorPoolTypeIntrospector( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) + : descriptor_pool_(descriptor_pool) {} + + DescriptorPoolTypeIntrospector(const DescriptorPoolTypeIntrospector&) = + delete; + DescriptorPoolTypeIntrospector& operator=( + const DescriptorPoolTypeIntrospector&) = delete; + DescriptorPoolTypeIntrospector(DescriptorPoolTypeIntrospector&&) = delete; + DescriptorPoolTypeIntrospector& operator=(DescriptorPoolTypeIntrospector&&) = + delete; + + void set_use_json_name(bool use_json_name) { use_json_name_ = use_json_name; } + + bool use_json_name() const { return use_json_name_; } + + private: + struct FieldTable { + const google::protobuf::Descriptor* absl_nonnull descriptor; + std::vector fields; + absl::Span non_extensions; + absl::Span extensions; + absl::flat_hash_map json_name_map; + absl::flat_hash_map field_name_map; + absl::flat_hash_map extension_name_map; + }; + + absl::StatusOr> FindTypeImpl( + absl::string_view name) const final; + + absl::StatusOr> FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const final; + + absl::StatusOr> FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const final; + + absl::StatusOr>> + ListFieldsForStructTypeImpl(absl::string_view type) const final; + + std::unique_ptr CreateFieldTable( + const google::protobuf::Descriptor* absl_nonnull descriptor) const; + + const FieldTable* GetFieldTable(absl::string_view type_name) const; + + // Cached map of type to field table. + mutable absl::flat_hash_map> + field_tables_ ABSL_GUARDED_BY(mu_); + + mutable absl::Mutex mu_; + bool use_json_name_ = false; + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ diff --git a/checker/internal/descriptor_pool_type_introspector_test.cc b/checker/internal/descriptor_pool_type_introspector_test.cc new file mode 100644 index 000000000..e2fdc9d40 --- /dev/null +++ b/checker/internal/descriptor_pool_type_introspector_test.cc @@ -0,0 +1,175 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/descriptor_pool_type_introspector.h" + +#include + +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::testing::AllOf; +using ::testing::Contains; +using ::testing::Eq; +using ::testing::Not; +using ::testing::Optional; +using ::testing::Property; +using ::testing::SizeIs; +using ::testing::Truly; + +TEST(DescriptorPoolTypeIntrospectorTest, FindType) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + EXPECT_THAT(introspector.FindType("cel.expr.conformance.proto3.TestAllTypes"), + IsOkAndHolds(Optional(Property(&Type::IsMessage, true)))); + EXPECT_THAT(introspector.FindType( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"), + IsOkAndHolds(Optional(Property(&Type::IsEnum, true)))); + EXPECT_THAT(introspector.FindType("non.existent.Type"), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindEnumConstant) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + auto result = introspector.FindEnumConstant( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum", "FOO"); + ASSERT_THAT(result, IsOkAndHolds(Optional(AllOf( + Truly([](const TypeIntrospector::EnumConstant& v) { + return v.value_name == "FOO" && v.number == 0; + }))))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindStructTypeFieldByName) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "single_int64"); + introspector.set_use_json_name(false); + + ASSERT_THAT(field, + IsOkAndHolds(Optional(Property(&StructTypeField::GetType, + Property(&Type::IsInt, true))))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, + FindStructTypeFieldByNameJsonNameIgnored) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(false); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); + + EXPECT_THAT(field, IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindExtension) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto2.TestAllTypes", + "cel.expr.conformance.proto2.int32_ext"); + + ASSERT_THAT(field, + IsOkAndHolds(Optional(Property(&StructTypeField::GetType, + Property(&Type::IsInt, true))))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindStructTypeFieldByNameWithJsonOpt) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(true); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "single_int64"); + + ASSERT_THAT(field, IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, + FindStructTypeFieldByNameWithJsonNameOpt) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(true); + + absl::StatusOr> field = + introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); + + ASSERT_THAT(field, + IsOkAndHolds(Optional(Property(&StructTypeField::GetType, + Property(&Type::IsInt, true))))); +} + +MATCHER_P(FieldListingIs, field_name, "") { return arg.name == field_name; } + +TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructType) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + absl::StatusOr< + absl::optional>> + fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(260)))); + EXPECT_THAT(*fields, Optional(Contains(FieldListingIs("single_int64")))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructTypeExtensions) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + auto fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto2.TestAllTypes"); + ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(259)))); + EXPECT_THAT(**fields, Contains(FieldListingIs("single_int64"))); + EXPECT_THAT( + **fields, + Not(Contains(FieldListingIs("cel.expr.conformance.proto2.int32_ext")))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, + ListFieldsForStructTypeWithJsonNameOpt) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(true); + auto fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(260)))); + EXPECT_THAT(**fields, Contains(FieldListingIs("singleInt64"))); + EXPECT_THAT(**fields, Not(Contains(FieldListingIs("single_int64")))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructTypeNotFound) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + auto fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto3.SomeOtherType"); + EXPECT_THAT(fields, IsOkAndHolds(Eq(absl::nullopt))); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/test_ast_helpers.cc b/checker/internal/test_ast_helpers.cc index 6ef7c2c05..543f70a89 100644 --- a/checker/internal/test_ast_helpers.cc +++ b/checker/internal/test_ast_helpers.cc @@ -14,29 +14,31 @@ #include "checker/internal/test_ast_helpers.h" #include -#include +#include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/ast.h" -#include "extensions/protobuf/ast_converters.h" #include "internal/status_macros.h" #include "parser/options.h" #include "parser/parser.h" +#include "parser/parser_interface.h" namespace cel::checker_internal { -using ::cel::extensions::CreateAstFromParsedExpr; -using ::google::api::expr::parser::Parse; - absl::StatusOr> MakeTestParsedAst( absl::string_view expression) { - static ParserOptions options; - options.enable_optional_syntax = true; - CEL_ASSIGN_OR_RETURN(auto parsed, - Parse(expression, /*description=*/expression, options)); + static const cel::Parser* parser = []() { + cel::ParserOptions options = {.enable_optional_syntax = true}; + auto parser = NewParserBuilder(options)->Build(); + ABSL_CHECK_OK(parser); + return parser->release(); + }(); - return CreateAstFromParsedExpr(std::move(parsed)); + CEL_ASSIGN_OR_RETURN( + auto source, + cel::NewSource(expression, /*description=*/std::string(expression))); + return parser->Parse(*source); } } // namespace cel::checker_internal diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index e76621435..c080326cb 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -28,7 +28,6 @@ #include "common/type_introspector.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" -#include "google/protobuf/descriptor.h" namespace cel::checker_internal { @@ -51,23 +50,10 @@ const FunctionDecl* absl_nullable TypeCheckEnv::LookupFunction( absl::StatusOr> TypeCheckEnv::LookupTypeName( absl::string_view name) const { - { - // Check the descriptor pool first, then fallback to custom type providers. - const google::protobuf::Descriptor* absl_nullable descriptor = - descriptor_pool_->FindMessageTypeByName(name); - if (descriptor != nullptr) { - return Type::Message(descriptor); - } - const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = - descriptor_pool_->FindEnumTypeByName(name); - if (enum_descriptor != nullptr) { - return Type::Enum(enum_descriptor); - } - } - for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + for (auto iter = type_providers_.begin(); iter != type_providers_.end(); ++iter) { - auto type = (*iter)->FindType(name); - if (!type.ok() || type->has_value()) { + CEL_ASSIGN_OR_RETURN(auto type, (*iter)->FindType(name)); + if (type.has_value()) { return type; } } @@ -76,37 +62,15 @@ absl::StatusOr> TypeCheckEnv::LookupTypeName( absl::StatusOr> TypeCheckEnv::LookupEnumConstant( absl::string_view type, absl::string_view value) const { - { - // Check the descriptor pool first, then fallback to custom type providers. - const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = - descriptor_pool_->FindEnumTypeByName(type); - if (enum_descriptor != nullptr) { - const google::protobuf::EnumValueDescriptor* absl_nullable enum_value_descriptor = - enum_descriptor->FindValueByName(value); - if (enum_value_descriptor == nullptr) { - return absl::nullopt; - } - auto decl = - MakeVariableDecl(absl::StrCat(enum_descriptor->full_name(), ".", - enum_value_descriptor->name()), - Type::Enum(enum_descriptor)); - decl.set_value( - Constant(static_cast(enum_value_descriptor->number()))); - return decl; - } - } - for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + for (auto iter = type_providers_.begin(); iter != type_providers_.end(); ++iter) { - auto enum_constant = (*iter)->FindEnumConstant(type, value); - if (!enum_constant.ok()) { - return enum_constant.status(); - } - if (enum_constant->has_value()) { - auto decl = - MakeVariableDecl(absl::StrCat((**enum_constant).type_full_name, ".", - (**enum_constant).value_name), - (**enum_constant).type); - decl.set_value(Constant(static_cast((**enum_constant).number))); + CEL_ASSIGN_OR_RETURN(auto enum_constant, + (*iter)->FindEnumConstant(type, value)); + if (enum_constant.has_value()) { + auto decl = MakeVariableDecl(absl::StrCat(enum_constant->type_full_name, + ".", enum_constant->value_name), + enum_constant->type); + decl.set_value(Constant(static_cast(enum_constant->number))); return decl; } } @@ -132,32 +96,16 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( absl::StatusOr> TypeCheckEnv::LookupStructField( absl::string_view type_name, absl::string_view field_name) const { - { - // Check the descriptor pool first, then fallback to custom type providers. - const google::protobuf::Descriptor* absl_nullable descriptor = - descriptor_pool_->FindMessageTypeByName(type_name); - if (descriptor != nullptr) { - const google::protobuf::FieldDescriptor* absl_nullable field_descriptor = - descriptor->FindFieldByName(field_name); - if (field_descriptor == nullptr) { - field_descriptor = descriptor_pool_->FindExtensionByPrintableName( - descriptor, field_name); - if (field_descriptor == nullptr) { - return absl::nullopt; - } - } - return cel::MessageTypeField(field_descriptor); - } - } - // Check the type providers in reverse registration order. + // Check the type providers in registration order. // Note: this doesn't allow for shadowing a type with a subset type of the - // same name -- the prior type provider will still be considered when + // same name -- the later type provider will still be considered when // checking field accesses. - for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + for (auto iter = type_providers_.begin(); iter != type_providers_.end(); ++iter) { - auto field_info = (*iter)->FindStructTypeFieldByName(type_name, field_name); - if (!field_info.ok() || field_info->has_value()) { - return field_info; + CEL_ASSIGN_OR_RETURN( + auto field, (*iter)->FindStructTypeFieldByName(type_name, field_name)); + if (field.has_value()) { + return field; } } return absl::nullopt; diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 5c8b3629c..520b0eab6 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -28,6 +28,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "checker/internal/descriptor_pool_type_introspector.h" #include "common/constant.h" #include "common/decl.h" #include "common/type.h" @@ -89,14 +90,15 @@ class TypeCheckEnv { explicit TypeCheckEnv( absl_nonnull std::shared_ptr descriptor_pool) - : descriptor_pool_(std::move(descriptor_pool)), container_("") {} - - TypeCheckEnv(absl_nonnull std::shared_ptr - descriptor_pool, - std::shared_ptr arena) : descriptor_pool_(std::move(descriptor_pool)), - arena_(std::move(arena)), - container_("") {} + container_(""), + proto_type_introspector_( + std::make_shared( + descriptor_pool_.get())) { + type_providers_.push_back( + std::make_shared()); + type_providers_.push_back(proto_type_introspector_); + } // Move-only. TypeCheckEnv(TypeCheckEnv&&) = default; @@ -108,6 +110,13 @@ class TypeCheckEnv { container_ = std::move(container); } + const DescriptorPoolTypeIntrospector& proto_type_introspector() const { + return *proto_type_introspector_; + } + DescriptorPoolTypeIntrospector& proto_type_introspector() { + return *proto_type_introspector_; + } + void set_expected_type(const Type& type) { expected_type_ = std::move(type); } const absl::optional& expected_type() const { return expected_type_; } @@ -194,10 +203,14 @@ class TypeCheckEnv { absl::string_view type, absl::string_view value) const; absl_nonnull std::shared_ptr descriptor_pool_; + // If set, an arena was needed to allocate types in the environment. absl_nullable std::shared_ptr arena_; std::string container_; + // Used to resolve fields on message types. + std::shared_ptr proto_type_introspector_; + // Maps fully qualified names to declarations. absl::flat_hash_map variables_; absl::flat_hash_map functions_; diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 8aa5177a5..7545aa949 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -84,19 +84,55 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) { return absl::OkStatus(); } +absl::Status AddWellKnownContextDeclarationVariables( + const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env, + bool use_json_name) { + for (int i = 0; i < descriptor->field_count(); ++i) { + const google::protobuf::FieldDescriptor* field = descriptor->field(i); + Type type = MessageTypeField(field).GetType(); + if (type.IsEnum()) { + type = IntType(); + } + absl::string_view name = field->name(); + if (use_json_name) { + name = field->json_name(); + } + if (!env.InsertVariableIfAbsent(MakeVariableDecl(name, type))) { + return absl::AlreadyExistsError( + absl::StrCat("variable '", name, + "' declared multiple times (from context declaration: '", + descriptor->full_name(), "')")); + } + } + return absl::OkStatus(); +} + absl::Status AddContextDeclarationVariables( const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env) { - for (int i = 0; i < descriptor->field_count(); i++) { - const google::protobuf::FieldDescriptor* proto_field = descriptor->field(i); - MessageTypeField cel_field(proto_field); - Type field_type = cel_field.GetType(); - if (field_type.IsEnum()) { - field_type = IntType(); + const bool use_json_name = env.proto_type_introspector().use_json_name(); + if (IsWellKnownMessageType(descriptor)) { + return AddWellKnownContextDeclarationVariables(descriptor, env, + use_json_name); + } + CEL_ASSIGN_OR_RETURN(auto fields, + env.proto_type_introspector().ListFieldsForStructType( + descriptor->full_name())); + if (!fields.has_value()) { + return absl::InternalError(absl::StrCat("context declaration '", + descriptor->full_name(), + "' not found, but was expected")); + } + for (const auto& field_entry : *fields) { + Type type = field_entry.field.GetType(); + if (type.IsEnum()) { + type = IntType(); } - if (!env.InsertVariableIfAbsent( - MakeVariableDecl(cel_field.name(), field_type))) { + + absl::string_view name = field_entry.name; + + if (!env.InsertVariableIfAbsent(MakeVariableDecl(name, type))) { return absl::AlreadyExistsError( - absl::StrCat("variable '", cel_field.name(), + absl::StrCat("variable '", name, "' declared multiple times (from context declaration: '", descriptor->full_name(), "')")); } @@ -324,6 +360,9 @@ absl::StatusOr> TypeCheckerBuilderImpl::Build() { CEL_RETURN_IF_ERROR(BuildLibraryConfig(library, config)); } + env.proto_type_introspector().set_use_json_name( + options_.use_json_field_names); + for (const ConfigRecord& config : configs) { TypeCheckerSubset* subset = nullptr; if (!config.id.empty()) { diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc index e23c26165..f7a3dff97 100644 --- a/checker/internal/type_checker_builder_impl_test.cc +++ b/checker/internal/type_checker_builder_impl_test.cc @@ -42,7 +42,6 @@ namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; - struct ContextDeclsTestCase { std::string expr; TypeSpec expected_type; diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 28cbf21e0..8e8047755 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -1319,6 +1319,12 @@ absl::StatusOr TypeCheckerImpl::Check( CEL_RETURN_IF_ERROR(rewriter.status()); ast->set_is_checked(true); + if (options_.use_json_field_names) { + ast->mutable_source_info().mutable_extensions().push_back( + cel::ExtensionSpec("json_name", + std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime})); + } return ValidationResult(std::move(ast), std::move(issues)); } diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index a15d2e173..030186c83 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -27,6 +27,7 @@ #include "checker/type_checker.h" #include "checker/type_checker_builder.h" #include "checker/validation_result.h" +#include "common/ast.h" #include "common/decl.h" #include "common/type.h" #include "internal/status_macros.h" @@ -496,6 +497,113 @@ TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationInt64Value) { ASSERT_TRUE(result.IsValid()); } +TEST(TypeCheckerBuilderTest, ContextDeclarationWithJsonName) { + CheckerOptions options; + options.use_json_field_names = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + + ASSERT_THAT(builder->AddContextDeclaration("cel.cpp.testutil.TestJsonNames"), + IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( + R"cel(int32_snake_case_json_name == 1 && + int64CamelCaseJsonName == 2 && + uint32DefaultJsonName == 3u && + // `uint64-custom-json-name` == 4u && + single_string == 'shadows' && + singleString == 'shadowed')cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->GetReturnType(), TypeSpec(PrimitiveType::kBool)); + EXPECT_THAT( + checked_ast->source_info().extensions(), + ElementsAre(cel::ExtensionSpec( + "json_name", std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime}))); +} + +TEST(TypeCheckerBuilderTest, JsonFieldNameOptionStructCreation) { + CheckerOptions options; + options.use_json_field_names = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( + R"cel(cel.cpp.testutil.TestJsonNames{ + int32_snake_case_json_name: 1, + int64CamelCaseJsonName: 2, + uint32DefaultJsonName: 3u, + `uint64-custom-json-name`: 4u, + single_string: 'shadows', + singleString: 'shadowed' + })cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->GetReturnType(), + TypeSpec(MessageTypeSpec("cel.cpp.testutil.TestJsonNames"))); + EXPECT_THAT( + checked_ast->source_info().extensions(), + ElementsAre(cel::ExtensionSpec( + "json_name", std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime}))); +} + +TEST(TypeCheckerBuilderTest, JsonFieldNameOptionFieldAccess) { + CheckerOptions options; + options.use_json_field_names = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT( + builder->AddVariable(MakeVariableDecl( + "jsonObj", + cel::MessageType(builder->descriptor_pool()->FindMessageTypeByName( + "cel.cpp.testutil.TestJsonNames")))), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( + R"cel( + jsonObj.int32_snake_case_json_name == 1 && + jsonObj.int64CamelCaseJsonName == 2 && + jsonObj.uint32DefaultJsonName == 3u && + jsonObj.`uint64-custom-json-name` == 4u && + jsonObj.single_string == 'shadows' && + jsonObj.singleString == 'shadowed' && + jsonObj.`cel.cpp.testutil.int32_snake_case_ext` == 5 && + jsonObj.`cel.cpp.testutil.int64CamelCaseExt` == 6 + )cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->GetReturnType(), TypeSpec(PrimitiveType::kBool)); + EXPECT_THAT( + checked_ast->source_info().extensions(), + ElementsAre(cel::ExtensionSpec( + "json_name", std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime}))); +} + TEST(TypeCheckerBuilderTest, AddLibraryRedeclaredError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, diff --git a/common/type_introspector.cc b/common/type_introspector.cc index 6d5158a2f..26f53685e 100644 --- a/common/type_introspector.cc +++ b/common/type_introspector.cc @@ -17,7 +17,9 @@ #include #include #include +#include +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" @@ -173,7 +175,8 @@ const WellKnownTypesMap& GetWellKnownTypesMap() { "google.protobuf.Value", WellKnownType{ DynType{}, - {MakeBasicStructTypeField("null_value", NullType{}, 1), + {// NullValue enum is an int. Not normally referenced directly. + MakeBasicStructTypeField("null_value", IntType{}, 1), MakeBasicStructTypeField("number_value", DoubleType{}, 2), MakeBasicStructTypeField("string_value", StringType{}, 3), MakeBasicStructTypeField("bool_value", BoolType{}, 4), @@ -228,6 +231,12 @@ TypeIntrospector::FindStructTypeFieldByNameImpl(absl::string_view, return absl::nullopt; } +absl::StatusOr< + absl::optional>> +TypeIntrospector::ListFieldsForStructTypeImpl(absl::string_view) const { + return absl::nullopt; +} + absl::optional FindWellKnownType(absl::string_view name) { const auto& well_known_types = GetWellKnownTypesMap(); if (auto it = well_known_types.find(name); it != well_known_types.end()) { @@ -240,7 +249,7 @@ absl::optional FindWellKnownTypeEnumConstant( absl::string_view type, absl::string_view value) { if (type == "google.protobuf.NullValue" && value == "NULL_VALUE") { return TypeIntrospector::EnumConstant{ - NullType{}, "google.protobuf.NullValue", "NULL_VALUE", 0}; + IntType{}, "google.protobuf.NullValue", "NULL_VALUE", 0}; } return absl::nullopt; } @@ -254,4 +263,15 @@ absl::optional FindWellKnownTypeFieldByName( return absl::nullopt; } +absl::optional> +ListFieldsForWellKnownType(absl::string_view type) { + const auto& well_known_types = GetWellKnownTypesMap(); + auto it = well_known_types.find(type); + if (it == well_known_types.end()) { + return absl::nullopt; + } + // The fields are not normally gettable. + return {}; +} + } // namespace cel diff --git a/common/type_introspector.h b/common/type_introspector.h index fb6ea09c1..932fb108e 100644 --- a/common/type_introspector.h +++ b/common/type_introspector.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_COMMON_TYPE_INTROSPECTOR_H_ #include +#include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -40,6 +41,15 @@ class TypeIntrospector { int32_t number; }; + struct StructTypeFieldListing { + // The name used to access the field in source CEL. + // This is assumed owned by the TypeIntrospector or a dependency that + // outlives it. + absl::string_view name; + // The field description. + StructTypeField field; + }; + virtual ~TypeIntrospector() = default; // `FindType` find the type corresponding to name `name`. @@ -61,6 +71,18 @@ class TypeIntrospector { return FindStructTypeFieldByNameImpl(type, name); } + // `ListFieldsForStructType` returns the fields of struct type `type`. + // + // This is used when the struct is declared as a context type. + // + // If the type is not found, returns `absl::nullopt`. + // If the type exists but is not a struct or has no fields, returns an empty + // vector. + absl::StatusOr>> + ListFieldsForStructType(absl::string_view type) const { + return ListFieldsForStructTypeImpl(type); + } + // `FindStructTypeFieldByName` find the name, number, and type of the field // `name` in struct type `type`. absl::StatusOr> FindStructTypeFieldByName( @@ -78,6 +100,9 @@ class TypeIntrospector { virtual absl::StatusOr> FindStructTypeFieldByNameImpl(absl::string_view type, absl::string_view name) const; + + virtual absl::StatusOr>> + ListFieldsForStructTypeImpl(absl::string_view type) const; }; // Looks up a well-known type by name. @@ -91,6 +116,9 @@ absl::optional FindWellKnownTypeEnumConstant( absl::optional FindWellKnownTypeFieldByName( absl::string_view type, absl::string_view name); +absl::optional> +ListFieldsForWellKnownType(absl::string_view type); + // `WellKnownTypeIntrospector` is an implementation of `TypeIntrospector` which // handles well known types that are treated specially by CEL. // @@ -117,6 +145,11 @@ class WellKnownTypeIntrospector : public virtual TypeIntrospector { absl::string_view type, absl::string_view name) const final { return FindWellKnownTypeFieldByName(type, name); } + + absl::StatusOr>> + ListFieldsForStructTypeImpl(absl::string_view type) const final { + return ListFieldsForWellKnownType(type); + } }; } // namespace cel diff --git a/internal/BUILD b/internal/BUILD index 59f68df9b..6bd0f0a46 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -523,6 +523,7 @@ cel_proto_transitive_descriptor_set( deps = [ "//eval/testutil:test_extensions_proto", "//eval/testutil:test_message_proto", + "//testutil:test_json_names_proto", "@com_google_cel_spec//proto/cel/expr:checked_proto", "@com_google_cel_spec//proto/cel/expr:expr_proto", "@com_google_cel_spec//proto/cel/expr:syntax_proto", diff --git a/testutil/BUILD b/testutil/BUILD index 3f1aa4fe8..292696033 100644 --- a/testutil/BUILD +++ b/testutil/BUILD @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") @@ -86,3 +87,8 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +proto_library( + name = "test_json_names_proto", + srcs = ["test_json_names.proto"], +) diff --git a/testutil/test_json_names.proto b/testutil/test_json_names.proto new file mode 100644 index 000000000..a9551085b --- /dev/null +++ b/testutil/test_json_names.proto @@ -0,0 +1,31 @@ +edition = "2024"; + +package cel.cpp.testutil; + +option features.enforce_naming_style = STYLE_LEGACY; + +// This proto tests json_name options +message TestJsonNames { + int32 int32_snake_case_json_name = 1 + [json_name = "int32_snake_case_json_name"]; + int64 int64_camel_case_json_name = 2 [json_name = "int64CamelCaseJsonName"]; + uint32 uint32_default_json_name = 3; + uint64 uint64_custom_json_name = 4 [json_name = "uint64-custom-json-name"]; + + // Collides with normal field name. + string string_json_name_shadows = 5 [json_name = "single_string"]; + string single_string = 6; + + // protoc should fail on cases like these + // double double_json_shadow_default = 7 [json_name = "doubleJsonDefault"] + // double double_json_default = 8; + // double double_json_swapped_a = 7 [json_name = "double_json_swapped_b"]; + // double double_json_swapped_b = 8 [json_name = "double_json_swapped_a"]; + + extensions 100 to 199; +} + +extend TestJsonNames { + int32 int32_snake_case_ext = 100; + int64 int64CamelCaseExt = 101; +} From 5ef8a8f30dcbf2b891866273f2d74ef53fdaa763 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 7 Apr 2026 10:53:00 -0700 Subject: [PATCH 025/143] Add `Validator` and update Compiler to support. Adds a `Validator` type for applying post type-check validations. Add a timestamp/duration literal validator as in cel-java. Adds a Validator to the cel::Compiler to optionally add post-compilation validations to run after type-checking. PiperOrigin-RevId: 895985690 --- compiler/BUILD | 3 + compiler/compiler.h | 5 + compiler/compiler_factory.cc | 18 ++- compiler/compiler_factory_test.cc | 18 +++ validator/BUILD | 89 +++++++++++ validator/timestamp_literal_validator.cc | 134 ++++++++++++++++ validator/timestamp_literal_validator.h | 29 ++++ validator/timestamp_literal_validator_test.cc | 146 +++++++++++++++++ validator/validator.cc | 85 ++++++++++ validator/validator.h | 151 ++++++++++++++++++ validator/validator_test.cc | 85 ++++++++++ 11 files changed, 760 insertions(+), 3 deletions(-) create mode 100644 validator/BUILD create mode 100644 validator/timestamp_literal_validator.cc create mode 100644 validator/timestamp_literal_validator.h create mode 100644 validator/timestamp_literal_validator_test.cc create mode 100644 validator/validator.cc create mode 100644 validator/validator.h create mode 100644 validator/validator_test.cc diff --git a/compiler/BUILD b/compiler/BUILD index 02bbb37dd..44ef4f537 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -27,6 +27,7 @@ cc_library( "//checker:validation_result", "//parser:options", "//parser:parser_interface", + "//validator", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -48,6 +49,7 @@ cc_library( "//internal:status_macros", "//parser", "//parser:parser_interface", + "//validator", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -78,6 +80,7 @@ cc_test( "//parser:macro", "//parser:parser_interface", "//testutil:baseline_tests", + "//validator:timestamp_literal_validator", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", diff --git a/compiler/compiler.h b/compiler/compiler.h index 8b867cd60..6178cf2dc 100644 --- a/compiler/compiler.h +++ b/compiler/compiler.h @@ -28,6 +28,7 @@ #include "checker/validation_result.h" #include "parser/options.h" #include "parser/parser_interface.h" +#include "validator/validator.h" namespace cel { @@ -109,6 +110,7 @@ class CompilerBuilder { virtual TypeCheckerBuilder& GetCheckerBuilder() = 0; virtual ParserBuilder& GetParserBuilder() = 0; + virtual Validator& GetValidator() = 0; virtual absl::StatusOr> Build() = 0; }; @@ -135,6 +137,9 @@ class Compiler { // Accessor for the underlying parser. virtual const Parser& GetParser() const = 0; + + // Accessor for the underlying validator. + virtual const Validator& GetValidator() const = 0; }; } // namespace cel diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc index 6530dd816..c83633f68 100644 --- a/compiler/compiler_factory.cc +++ b/compiler/compiler_factory.cc @@ -32,6 +32,7 @@ #include "internal/status_macros.h" #include "parser/parser.h" #include "parser/parser_interface.h" +#include "validator/validator.h" #include "google/protobuf/descriptor.h" namespace cel { @@ -41,8 +42,12 @@ namespace { class CompilerImpl : public Compiler { public: CompilerImpl(std::unique_ptr type_checker, - std::unique_ptr parser) - : type_checker_(std::move(type_checker)), parser_(std::move(parser)) {} + std::unique_ptr parser, + // Copy the validator in case builder is reused. + Validator validator) + : type_checker_(std::move(type_checker)), + parser_(std::move(parser)), + validator_(std::move(validator)) {} absl::StatusOr Compile( absl::string_view expression, @@ -54,15 +59,20 @@ class CompilerImpl : public Compiler { type_checker_->Check(std::move(ast))); result.SetSource(std::move(source)); + if (!validator_.validations().empty()) { + validator_.UpdateValidationResult(result); + } return result; } const TypeChecker& GetTypeChecker() const override { return *type_checker_; } const Parser& GetParser() const override { return *parser_; } + const Validator& GetValidator() const override { return validator_; } private: std::unique_ptr type_checker_; std::unique_ptr parser_; + Validator validator_; }; class CompilerBuilderImpl : public CompilerBuilder { @@ -126,17 +136,19 @@ class CompilerBuilderImpl : public CompilerBuilder { TypeCheckerBuilder& GetCheckerBuilder() override { return *type_checker_builder_; } + Validator& GetValidator() override { return validator_; } absl::StatusOr> Build() override { CEL_ASSIGN_OR_RETURN(auto parser, parser_builder_->Build()); CEL_ASSIGN_OR_RETURN(auto type_checker, type_checker_builder_->Build()); return std::make_unique(std::move(type_checker), - std::move(parser)); + std::move(parser), validator_); } private: std::unique_ptr type_checker_builder_; std::unique_ptr parser_builder_; + Validator validator_; absl::flat_hash_set library_ids_; absl::flat_hash_set subsets_; diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc index 5df0f4794..cfdc68e26 100644 --- a/compiler/compiler_factory_test.cc +++ b/compiler/compiler_factory_test.cc @@ -35,6 +35,7 @@ #include "parser/macro.h" #include "parser/parser_interface.h" #include "testutil/baseline_tests.h" +#include "validator/timestamp_literal_validator.h" #include "google/protobuf/descriptor.h" namespace cel { @@ -287,6 +288,23 @@ TEST(CompilerFactoryTest, DisableStandardMacrosWithStdlib) { EXPECT_TRUE(result.IsValid()); } +TEST(CompilerFactoryTest, AddValidator) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + builder->GetValidator().AddValidation(TimestampLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("timestamp('invalid')")); + EXPECT_FALSE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(result, + compiler->Compile("timestamp('2024-01-01T00:00:00Z')")); + EXPECT_TRUE(result.IsValid()); +} + TEST(CompilerFactoryTest, FailsIfLibraryAddedTwice) { ASSERT_OK_AND_ASSIGN( auto builder, diff --git a/validator/BUILD b/validator/BUILD new file mode 100644 index 000000000..65f7fd6b3 --- /dev/null +++ b/validator/BUILD @@ -0,0 +1,89 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "validator", + srcs = ["validator.cc"], + hdrs = ["validator.h"], + deps = [ + "//checker:type_check_issue", + "//checker:validation_result", + "//common:ast", + "//common:navigable_ast", + "//common:source", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "validator_test", + srcs = ["validator_test.cc"], + deps = [ + ":validator", + "//checker:type_check_issue", + "//common:ast", + "//common:expr", + "//common:source", + "//internal:testing", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "timestamp_literal_validator_test", + srcs = ["timestamp_literal_validator_test.cc"], + deps = [ + ":timestamp_literal_validator", + ":validator", + "//checker:validation_result", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "timestamp_literal_validator", + srcs = ["timestamp_literal_validator.cc"], + hdrs = ["timestamp_literal_validator.h"], + deps = [ + ":validator", + "//common:constant", + "//common:navigable_ast", + "//common:standard_definitions", + "//internal:time", + "//tools:navigable_ast", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + ], +) + +licenses(["notice"]) diff --git a/validator/timestamp_literal_validator.cc b/validator/timestamp_literal_validator.cc new file mode 100644 index 000000000..8b9b76ebb --- /dev/null +++ b/validator/timestamp_literal_validator.cc @@ -0,0 +1,134 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/timestamp_literal_validator.h" + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "common/navigable_ast.h" +#include "common/standard_definitions.h" +#include "internal/time.h" +#include "tools/navigable_ast.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +bool ValidateTimestamps(ValidationContext& context) { + bool valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() != NodeKind::kCall || + node.expr()->call_expr().function() != StandardFunctions::kTimestamp) { + continue; + } + if (node.children().size() != 1) { + // Checker should have already reported an error. + continue; + } + const NavigableAstNode& child = *node.children()[0]; + if (child.node_kind() != NodeKind::kConstant) { + // Not a literal, so nothing to do. + continue; + } + absl::Time ts; + const Constant& constant = child.expr()->const_expr(); + if (constant.has_string_value()) { + absl::string_view timestamp_str = + child.expr()->const_expr().string_value(); + if (!absl::ParseTime(absl::RFC3339_full, timestamp_str, &ts, nullptr)) { + context.ReportErrorAt(child.expr()->id(), "invalid timestamp literal"); + valid = false; + continue; + } + } else if (constant.has_int_value()) { + ts = absl::FromUnixSeconds(constant.int_value()); + } else { + // Checker should have already reported an error. + continue; + } + + if (absl::Status status = internal::ValidateTimestamp(ts); !status.ok()) { + context.ReportErrorAt( + child.expr()->id(), + absl::StrCat("invalid timestamp literal: ", status.message())); + valid = false; + } + } + + return valid; +} + +bool ValidateDurations(ValidationContext& context) { + bool valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() != NodeKind::kCall || + node.expr()->call_expr().function() != StandardFunctions::kDuration) { + continue; + } + if (node.children().size() != 1) { + // Checker should have already reported an error. + continue; + } + const NavigableAstNode& child = *node.children()[0]; + if (child.node_kind() != NodeKind::kConstant) { + // Not a literal, so nothing to do. + continue; + } + const Constant& constant = child.expr()->const_expr(); + if (!constant.has_string_value()) { + continue; + } + absl::Duration duration; + + absl::string_view duration_str = child.expr()->const_expr().string_value(); + if (!absl::ParseDuration(duration_str, &duration)) { + context.ReportErrorAt(child.expr()->id(), "invalid duration literal"); + valid = false; + continue; + } + + if (absl::Status status = internal::ValidateDuration(duration); + !status.ok()) { + context.ReportErrorAt( + child.expr()->id(), + absl::StrCat("invalid duration literal: ", status.message())); + valid = false; + } + } + + return valid; +} + +} // namespace + +const Validation& TimestampLiteralValidator() { + static const absl::NoDestructor kInstance( + ValidateTimestamps, "cel.validator.timestamp"); + return *kInstance; +} + +// Returns a validator that checks duration literals. +const Validation& DurationLiteralValidator() { + static const absl::NoDestructor kInstance( + ValidateDurations, "cel.validator.duration"); + return *kInstance; +} + +} // namespace cel diff --git a/validator/timestamp_literal_validator.h b/validator/timestamp_literal_validator.h new file mode 100644 index 000000000..6d2a39318 --- /dev/null +++ b/validator/timestamp_literal_validator.h @@ -0,0 +1,29 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_ + +#include "validator/validator.h" +namespace cel { + +// Returns a `Validation` that checks timestamp literals are valid for CEL. +const Validation& TimestampLiteralValidator(); + +// Returns a `Validation` that checks duration literals are valid for CEL. +const Validation& DurationLiteralValidator(); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_ diff --git a/validator/timestamp_literal_validator_test.cc b/validator/timestamp_literal_validator_test.cc new file mode 100644 index 000000000..136f7d645 --- /dev/null +++ b/validator/timestamp_literal_validator_test.cc @@ -0,0 +1,146 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/timestamp_literal_validator.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/validation_result.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + auto builder = + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()).value(); + builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); + return builder->Build(); +} + +class TimestampLiteralValidatorTest : public ::testing::Test { + protected: + TimestampLiteralValidatorTest() { + validator_.AddValidation(TimestampLiteralValidator()); + } + + std::unique_ptr compiler_; + Validator validator_; +}; + +TEST(TimestampLiteralValidatorTest, FormatsIssues) { + Validator validator; + validator.AddValidation(TimestampLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(cel::ValidationResult result, + compiler->Compile("timestamp('invalid')")); + + validator.UpdateValidationResult(result); + + EXPECT_FALSE(result.IsValid()); + EXPECT_EQ(result.FormatError(), + R"(ERROR: :1:11: invalid timestamp literal + | timestamp('invalid') + | ..........^)"); +} + +TEST(TimestampLiteralValidatorTest, AccumulatesIssues) { + Validator validator; + validator.AddValidation(TimestampLiteralValidator()); + validator.AddValidation(DurationLiteralValidator()); + + constexpr absl::string_view kExpression = R"cel( + [ timestamp('invalid'), + timestamp('9999-12-31T23:59:59Z'), + timestamp('10000-01-01T00:00:00Z') + ].all(t, + t - timestamp(0) < duration('10000s') && + t - timestamp(0) > duration("invalid") + ))cel"; + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(cel::ValidationResult result, + compiler->Compile(kExpression)); + + validator.UpdateValidationResult(result); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), + AllOf(HasSubstr("2:17: invalid timestamp literal"), + HasSubstr("4:17: invalid timestamp literal"), + HasSubstr("7:35: invalid duration literal"))); +} + +struct TestCase { + std::string expression; + bool valid; + std::string error_substr = ""; +}; + +using TimestampLiteralValidatorParameterizedTest = + testing::TestWithParam; + +TEST_P(TimestampLiteralValidatorParameterizedTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(TimestampLiteralValidator()); + validator.AddValidation(DurationLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid); + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + TimestampLiteralValidatorParameterizedTest, + TimestampLiteralValidatorParameterizedTest, + ::testing::Values( + TestCase{"timestamp('2023-01-01T00:00:00Z')", true}, + TestCase{"timestamp('9999-12-31T23:59:59Z')", true}, + TestCase{"timestamp('invalid')", false, "invalid timestamp literal"}, + TestCase{"timestamp('10000-01-01T00:00:00Z')", false, + "invalid timestamp literal"}, + TestCase{"timestamp(0)", true}, + TestCase{"timestamp(-62135596801)", false, + "invalid timestamp literal: Timestamp \"0-12-31T23:59:59Z\" " + "below minimum allowed timestamp \"1-01-01T00:00:00Z\""}, + TestCase{"timestamp(253402300800)", false, + "invalid timestamp literal: Timestamp " + "\"10000-01-01T00:00:00Z\" above maximum allowed timestamp " + "\"9999-12-31T23:59:59.999999999Z\""}, + TestCase{"duration('1s')", true}, + TestCase{"duration('invalid')", false, "invalid duration literal"}, + TestCase{"duration('-1000000000000s')", false, + "below minimum allowed duration"}, + TestCase{"duration('1000000000000s')", false, + "above maximum allowed duration"})); + +} // namespace +} // namespace cel diff --git a/validator/validator.cc b/validator/validator.cc new file mode 100644 index 000000000..e000c71e8 --- /dev/null +++ b/validator/validator.cc @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/validator.h" + +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/source.h" + +namespace cel { + +void Validator::AddValidation(Validation validation) { + ABSL_DCHECK(validation); + if (!validation) return; + validations_.push_back(std::move(validation)); +} + +Validator::ValidationOutput Validator::Validate(const Ast& ast) const { + ValidationOutput result; + ValidationContext context(ast); + for (const auto& validation : validations_) { + if (!validation(context)) { + result.valid = false; + } + } + result.issues = context.ReleaseIssues(); + return result; +} + +void Validator::UpdateValidationResult(ValidationResult& in) const { + if (!in.IsValid() || in.GetAst() == nullptr) { + // If the result is already decided invalid, just return it. + return; + } + + auto result = Validate(*in.GetAst()); + if (!result.valid) { + in.ReleaseAst().IgnoreError(); + } + for (auto& issue : result.issues) { + in.AddIssue(std::move(issue)); + } +} + +void ValidationContext::ReportWarningAt(int64_t id, absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kWarning, + ast_.ComputeSourceLocation(id), + std::string(message))); +} + +void ValidationContext::ReportErrorAt(int64_t id, absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kError, + ast_.ComputeSourceLocation(id), + std::string(message))); +} + +void ValidationContext::ReportWarning(absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kWarning, + SourceLocation{}, std::string(message))); +} + +void ValidationContext::ReportError(absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kError, + SourceLocation{}, std::string(message))); +} + +} // namespace cel diff --git a/validator/validator.h b/validator/validator.h new file mode 100644 index 000000000..a278bd44f --- /dev/null +++ b/validator/validator.h @@ -0,0 +1,151 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/navigable_ast.h" +namespace cel { + +// Context for a validation pass. +// +// Assumed to be scoped to a Validator::Validate() call. Instances must not +// outlive the `ast` passed to the constructor. +class ValidationContext { + public: + explicit ValidationContext(const Ast& ast ABSL_ATTRIBUTE_LIFETIME_BOUND) + : ast_(ast) {} + + const Ast& ast() const { return ast_; } + const NavigableAst& navigable_ast() const { + if (!navigable_ast_) { + navigable_ast_ = NavigableAst::Build(ast_.root_expr()); + } + return navigable_ast_; + } + + void ReportWarningAt(int64_t id, absl::string_view message); + void ReportErrorAt(int64_t id, absl::string_view message); + void ReportWarning(absl::string_view message); + void ReportError(absl::string_view message); + + std::vector ReleaseIssues() { + auto out = std::move(issues_); + issues_.clear(); + return out; + } + + private: + const Ast& ast_; + mutable NavigableAst navigable_ast_; + std::vector issues_; +}; + +// A single validation to apply to an AST. +// +// May be empty if default constructed or moved from. +// use operator bool() to check if the validation is empty. +class Validation { + public: + // Tests the AST reports any issues to the context. + // + // Returns false if the AST is invalid. + // + // The same instance is used across Validate() so must be thread safe + // (typically stateless). + using ImplFunction = + absl::AnyInvocable; + + Validation() = default; + explicit Validation(ImplFunction impl); + Validation(ImplFunction impl, absl::string_view id); + + const ImplFunction& impl() const { + ABSL_DCHECK(rep_ != nullptr); + return rep_->impl; + } + + absl::string_view id() const { + ABSL_DCHECK(rep_ != nullptr); + return rep_->id; + } + + bool operator()(ValidationContext& context) const { + ABSL_DCHECK(rep_ != nullptr); + return rep_->impl(context); + } + + explicit operator bool() const { return rep_ != nullptr; } + + private: + struct Rep { + ImplFunction impl; + // Optional id if supported in environment config. + std::string id; + }; + + std::shared_ptr rep_; +}; + +// A validator checks a set of semantic rules for a given AST. +class Validator { + public: + Validator() = default; + + void AddValidation(Validation validation); + absl::Span validations() const { return validations_; } + + struct ValidationOutput { + bool valid = true; + std::vector issues; + }; + + // Validates the given AST by applying all of the validations. + ValidationOutput Validate(const Ast& ast) const; + + // Validates the given AST, updating the validation result in place. + // + // Used to apply validators to the output of the type checker. + void UpdateValidationResult(ValidationResult& in) const; + + private: + std::vector validations_; +}; + +// Implementation details. +inline Validation::Validation(ImplFunction impl) + : rep_(std::make_shared( + Validation::Rep{std::move(impl)})) {} + +inline Validation::Validation(ImplFunction impl, absl::string_view id) + : rep_(std::make_shared( + Validation::Rep{std::move(impl), std::string(id)})) {} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ diff --git a/validator/validator_test.cc b/validator/validator_test.cc new file mode 100644 index 000000000..744475ec1 --- /dev/null +++ b/validator/validator_test.cc @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/validator.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/source.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Property; + +TEST(ValidatorTest, AddValidationAndValidate) { + Validator validator; + validator.AddValidation(Validation([](ValidationContext& context) { + context.ReportError("error 1"); + return false; + })); + validator.AddValidation(Validation([](ValidationContext& context) { + context.ReportWarning("warning 1"); + return true; + })); + + Ast ast; + auto output = validator.Validate(ast); + + EXPECT_FALSE(output.valid); + EXPECT_THAT(output.issues, + ElementsAre(Property(&TypeCheckIssue::message, Eq("error 1")), + Property(&TypeCheckIssue::message, Eq("warning 1")))); + EXPECT_EQ(output.issues[0].severity(), TypeCheckIssue::Severity::kError); + EXPECT_EQ(output.issues[1].severity(), TypeCheckIssue::Severity::kWarning); +} + +TEST(ValidatorTest, ReportAt) { + Validator validator; + validator.AddValidation(Validation([](ValidationContext& context) { + context.ReportErrorAt(1, "error at 1"); + context.ReportWarningAt(2, "warning at 2"); + return false; + })); + + Expr expr; + expr.set_id(1); + SourceInfo source_info; + source_info.mutable_positions()[1] = 10; + source_info.mutable_positions()[2] = 20; + source_info.set_line_offsets({15, 25}); + + Ast ast(std::move(expr), std::move(source_info)); + auto output = validator.Validate(ast); + + EXPECT_FALSE(output.valid); + ASSERT_EQ(output.issues.size(), 2); + + EXPECT_EQ(output.issues[0].location().line, 1); + EXPECT_EQ(output.issues[0].location().column, 10); + + EXPECT_EQ(output.issues[1].location().line, 2); + EXPECT_EQ(output.issues[1].location().column, 5); +} + +} // namespace +} // namespace cel From ea0deec704ddafab3b4bdad6c50d1906cdbcd7df Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 7 Apr 2026 11:16:40 -0700 Subject: [PATCH 026/143] Add AST depth validator. PiperOrigin-RevId: 895998515 --- validator/BUILD | 26 +++++++++ validator/ast_depth_validator.cc | 34 +++++++++++ validator/ast_depth_validator.h | 27 +++++++++ validator/ast_depth_validator_test.cc | 81 +++++++++++++++++++++++++++ 4 files changed, 168 insertions(+) create mode 100644 validator/ast_depth_validator.cc create mode 100644 validator/ast_depth_validator.h create mode 100644 validator/ast_depth_validator_test.cc diff --git a/validator/BUILD b/validator/BUILD index 65f7fd6b3..dec3d8616 100644 --- a/validator/BUILD +++ b/validator/BUILD @@ -86,4 +86,30 @@ cc_library( ], ) +cc_library( + name = "ast_depth_validator", + srcs = ["ast_depth_validator.cc"], + hdrs = ["ast_depth_validator.h"], + deps = [ + ":validator", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "ast_depth_validator_test", + srcs = ["ast_depth_validator_test.cc"], + deps = [ + ":ast_depth_validator", + ":validator", + "//checker:type_check_issue", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/log:absl_check", + ], +) + licenses(["notice"]) diff --git a/validator/ast_depth_validator.cc b/validator/ast_depth_validator.cc new file mode 100644 index 000000000..0f6b8d93d --- /dev/null +++ b/validator/ast_depth_validator.cc @@ -0,0 +1,34 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/ast_depth_validator.h" + +#include "absl/strings/str_cat.h" +#include "validator/validator.h" + +namespace cel { + +Validation AstDepthValidator(int max_depth) { + return Validation([max_depth](ValidationContext& context) { + int height = context.navigable_ast().Root().height(); + if (height > max_depth) { + context.ReportError(absl::StrCat("AST depth ", height, + " exceeds maximum of ", max_depth)); + return false; + } + return true; + }); +} + +} // namespace cel diff --git a/validator/ast_depth_validator.h b/validator/ast_depth_validator.h new file mode 100644 index 000000000..a640af12e --- /dev/null +++ b/validator/ast_depth_validator.h @@ -0,0 +1,27 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ +#include "validator/validator.h" + +namespace cel { + +// Returns a `Validation` that checks the AST depth is less than or equal to +// max_depth. +Validation AstDepthValidator(int max_depth); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ diff --git a/validator/ast_depth_validator_test.cc b/validator/ast_depth_validator_test.cc new file mode 100644 index 000000000..eda59b40d --- /dev/null +++ b/validator/ast_depth_validator_test.cc @@ -0,0 +1,81 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/ast_depth_validator.h" + +#include +#include + +#include "absl/log/absl_check.h" +#include "checker/type_check_issue.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +std::unique_ptr CreateCompiler() { + auto builder = NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()); + ABSL_CHECK_OK(builder); + ABSL_CHECK_OK((*builder)->AddLibrary(StandardCompilerLibrary())); + auto compiler = (*builder)->Build(); + ABSL_CHECK_OK(compiler); + return *std::move(compiler); +} + +TEST(AstDepthValidatorTest, Basic) { + auto compiler = CreateCompiler(); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("1 + 2 + 3")); + + Validator validator; + validator.AddValidation(AstDepthValidator(10)); + auto output = validator.Validate(*result.GetAst()); + EXPECT_TRUE(output.valid); + + Validator validator2; + validator2.AddValidation(AstDepthValidator(2)); + output = validator2.Validate(*result.GetAst()); + EXPECT_FALSE(output.valid); + EXPECT_THAT(output.issues, + testing::Contains(testing::Property( + &TypeCheckIssue::message, + testing::Eq("AST depth 3 exceeds maximum of 2")))); +} + +TEST(AstDepthValidatorTest, Nested) { + auto compiler = CreateCompiler(); + ASSERT_OK_AND_ASSIGN(auto result, + compiler->Compile("1 + (2 + (3 + (4 + 5)))")); + + Validator validator; + validator.AddValidation(AstDepthValidator(10)); + auto output = validator.Validate(*result.GetAst()); + EXPECT_TRUE(output.valid); + + Validator validator2; + validator2.AddValidation(AstDepthValidator(4)); + output = validator2.Validate(*result.GetAst()); + EXPECT_FALSE(output.valid); + EXPECT_THAT(output.issues, + testing::Contains(testing::Property( + &TypeCheckIssue::message, + testing::Eq("AST depth 5 exceeds maximum of 4")))); +} + +} // namespace +} // namespace cel From 1b2a84162ac9f891a22b2b4d3e1018d190184b45 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 7 Apr 2026 14:51:06 -0700 Subject: [PATCH 027/143] Expose TypeSpec formatter in common/ast/metadata.h PiperOrigin-RevId: 896098407 --- common/ast/BUILD | 3 + common/ast/metadata.cc | 117 ++++++++++++++++++++++++++++++++ common/ast/metadata.h | 3 + testutil/baseline_tests.cc | 71 +------------------ testutil/baseline_tests_test.cc | 2 +- 5 files changed, 125 insertions(+), 71 deletions(-) diff --git a/common/ast/BUILD b/common/ast/BUILD index 410d38c65..17456566b 100644 --- a/common/ast/BUILD +++ b/common/ast/BUILD @@ -98,10 +98,13 @@ cc_library( deps = [ "//common:constant", "//common:expr", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", ], diff --git a/common/ast/metadata.cc b/common/ast/metadata.cc index eecb0dbb3..38f7ef610 100644 --- a/common/ast/metadata.cc +++ b/common/ast/metadata.cc @@ -14,11 +14,18 @@ #include "common/ast/metadata.h" +#include #include +#include +#include +#include #include #include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" #include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" #include "absl/types/variant.h" namespace cel { @@ -30,6 +37,96 @@ const TypeSpec& DefaultTypeSpec() { return *type; } +std::string FormatPrimitive(PrimitiveType t) { + switch (t) { + case PrimitiveType::kBool: + return "bool"; + case PrimitiveType::kInt64: + return "int"; + case PrimitiveType::kUint64: + return "uint"; + case PrimitiveType::kDouble: + return "double"; + case PrimitiveType::kString: + return "string"; + case PrimitiveType::kBytes: + return "bytes"; + default: + return "*unspecified primitive*"; + } +} + +std::string FormatWellKnown(WellKnownTypeSpec t) { + switch (t) { + case WellKnownTypeSpec::kAny: + return "google.protobuf.Any"; + case WellKnownTypeSpec::kDuration: + return "google.protobuf.Duration"; + case WellKnownTypeSpec::kTimestamp: + return "google.protobuf.Timestamp"; + default: + return "*unspecified well known*"; + } +} + +using FormatIns = std::variant; +using FormatStack = std::vector; + +void HandleFormatTypeSpec(const TypeSpec& t, FormatStack& stack, + std::string* out) { + if (t.has_dyn()) { + absl::StrAppend(out, "dyn"); + } else if (t.has_null()) { + absl::StrAppend(out, "null"); + } else if (t.has_primitive()) { + absl::StrAppend(out, FormatPrimitive(t.primitive())); + } else if (t.has_wrapper()) { + absl::StrAppend(out, "wrapper(", FormatPrimitive(t.wrapper()), ")"); + } else if (t.has_well_known()) { + absl::StrAppend(out, FormatWellKnown(t.well_known())); + return; + } else if (t.has_abstract_type()) { + const auto& abs_type = t.abstract_type(); + if (abs_type.parameter_types().empty()) { + absl::StrAppend(out, abs_type.name()); + return; + } + absl::StrAppend(out, abs_type.name(), "("); + stack.push_back(")"); + for (size_t i = abs_type.parameter_types().size(); i > 0; --i) { + stack.push_back(&abs_type.parameter_types()[i - 1]); + if (i > 1) { + stack.push_back(", "); + } + } + + } else if (t.has_type()) { + if (t.type() == TypeSpec()) { + absl::StrAppend(out, "type"); + return; + } + absl::StrAppend(out, "type("); + stack.push_back(")"); + stack.push_back(&t.type()); + } else if (t.has_message_type()) { + absl::StrAppend(out, t.message_type().type()); + } else if (t.has_type_param()) { + absl::StrAppend(out, t.type_param().type()); + } else if (t.has_list_type()) { + absl::StrAppend(out, "list("); + stack.push_back(")"); + stack.push_back(&t.list_type().elem_type()); + } else if (t.has_map_type()) { + absl::StrAppend(out, "map("); + stack.push_back(")"); + stack.push_back(&t.map_type().value_type()); + stack.push_back(", "); + stack.push_back(&t.map_type().key_type()); + } else { + absl::StrAppend(out, "*error*"); + } +} + TypeSpecKind CopyImpl(const TypeSpecKind& other) { return absl::visit( absl::Overload( @@ -142,4 +239,24 @@ FunctionTypeSpec& FunctionTypeSpec::operator=(const FunctionTypeSpec& other) { return *this; } +std::string FormatTypeSpec(const TypeSpec& t) { + // Use a stack to avoid recursion. + // Probably overly defensive, but fuzzers will often notice the recursion + // and try to trigger it. + std::string out; + FormatStack seq; + seq.push_back(&t); + while (!seq.empty()) { + FormatIns ins = std::move(seq.back()); + seq.pop_back(); + if (std::holds_alternative(ins)) { + absl::StrAppend(&out, std::get(ins)); + continue; + } + ABSL_DCHECK(std::holds_alternative(ins)); + HandleFormatTypeSpec(*std::get(ins), seq, &out); + } + return out; +} + } // namespace cel diff --git a/common/ast/metadata.h b/common/ast/metadata.h index a82e999f8..197790ff3 100644 --- a/common/ast/metadata.h +++ b/common/ast/metadata.h @@ -740,6 +740,9 @@ class TypeSpec { TypeSpecKind type_kind_; }; +// Returns a string representation of the given TypeSpec. +std::string FormatTypeSpec(const TypeSpec& t); + // Describes a resolved reference to a declaration. class Reference { public: diff --git a/testutil/baseline_tests.cc b/testutil/baseline_tests.cc index 4e56ad485..8ce43e63d 100644 --- a/testutil/baseline_tests.cc +++ b/testutil/baseline_tests.cc @@ -28,75 +28,6 @@ namespace cel::test { namespace { -std::string FormatPrimitive(PrimitiveType t) { - switch (t) { - case PrimitiveType::kBool: - return "bool"; - case PrimitiveType::kInt64: - return "int"; - case PrimitiveType::kUint64: - return "uint"; - case PrimitiveType::kDouble: - return "double"; - case PrimitiveType::kString: - return "string"; - case PrimitiveType::kBytes: - return "bytes"; - default: - return ""; - } -} - -std::string FormatType(const TypeSpec& t) { - if (t.has_dyn()) { - return "dyn"; - } else if (t.has_null()) { - return "null"; - } else if (t.has_primitive()) { - return FormatPrimitive(t.primitive()); - } else if (t.has_wrapper()) { - return absl::StrCat("wrapper(", FormatPrimitive(t.wrapper()), ")"); - } else if (t.has_well_known()) { - switch (t.well_known()) { - case WellKnownTypeSpec::kAny: - return "google.protobuf.Any"; - case WellKnownTypeSpec::kDuration: - return "google.protobuf.Duration"; - case WellKnownTypeSpec::kTimestamp: - return "google.protobuf.Timestamp"; - default: - return ""; - } - } else if (t.has_abstract_type()) { - const auto& abs_type = t.abstract_type(); - std::string s = abs_type.name(); - if (!abs_type.parameter_types().empty()) { - absl::StrAppend(&s, "(", - absl::StrJoin(abs_type.parameter_types(), ",", - [](std::string* out, const auto& t) { - absl::StrAppend(out, FormatType(t)); - }), - ")"); - } - return s; - } else if (t.has_type()) { - if (t.type() == TypeSpec()) { - return "type"; - } - return absl::StrCat("type(", FormatType(t.type()), ")"); - } else if (t.has_message_type()) { - return t.message_type().type(); - } else if (t.has_type_param()) { - return t.type_param().type(); - } else if (t.has_list_type()) { - return absl::StrCat("list(", FormatType(t.list_type().elem_type()), ")"); - } else if (t.has_map_type()) { - return absl::StrCat("map(", FormatType(t.map_type().key_type()), ", ", - FormatType(t.map_type().value_type()), ")"); - } - return ""; -} - std::string FormatReference(const cel::Reference& r) { if (r.overload_id().empty()) { return r.name(); @@ -113,7 +44,7 @@ class TypeAdorner : public ExpressionAdorner { auto t = ast_.type_map().find(e.id()); if (t != ast_.type_map().end()) { - absl::StrAppend(&s, "~", FormatType(t->second)); + absl::StrAppend(&s, "~", FormatTypeSpec(t->second)); } if (const auto r = ast_.reference_map().find(e.id()); r != ast_.reference_map().end()) { diff --git a/testutil/baseline_tests_test.cc b/testutil/baseline_tests_test.cc index 33050583f..f4e89706c 100644 --- a/testutil/baseline_tests_test.cc +++ b/testutil/baseline_tests_test.cc @@ -184,7 +184,7 @@ INSTANTIATE_TEST_SUITE_P( "x~google.protobuf.Timestamp"}, TestCase{TypeSpec(DynTypeSpec()), "x~dyn"}, TestCase{TypeSpec(NullTypeSpec()), "x~null"}, - TestCase{TypeSpec(UnsetTypeSpec()), "x~"}, + TestCase{TypeSpec(UnsetTypeSpec()), "x~*error*"}, TestCase{TypeSpec(MessageTypeSpec("com.example.Type")), "x~com.example.Type"}, TestCase{TypeSpec(AbstractType("optional_type", From 7022451780867bd9c173956f5793b5dfb600928e Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Tue, 7 Apr 2026 16:58:41 -0700 Subject: [PATCH 028/143] Expose EnvRuntime::RegisterExtensionFunctions API PiperOrigin-RevId: 896158131 --- env/BUILD | 7 +++++++ env/config.cc | 8 +++++++- env/env_runtime.cc | 12 ++++++++++++ env/env_runtime.h | 12 ++++++++++++ env/env_runtime_test.cc | 39 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 77 insertions(+), 1 deletion(-) diff --git a/env/BUILD b/env/BUILD index 8d477cc1f..55297b190 100644 --- a/env/BUILD +++ b/env/BUILD @@ -81,7 +81,10 @@ cc_library( "//runtime:runtime_builder_factory", "//runtime:runtime_options", "//runtime:standard_functions", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_protobuf//:protobuf", ], ) @@ -236,10 +239,14 @@ cc_test( "//common:source", "//common:value", "//compiler", + "//extensions:math_ext", + "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//runtime", "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:runtime_options", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", diff --git a/env/config.cc b/env/config.cc index ccb4de34c..202a607bf 100644 --- a/env/config.cc +++ b/env/config.cc @@ -58,9 +58,15 @@ absl::Status Config::AddExtensionConfig(std::string name, int version) { if (extension_config.version == version) { return absl::OkStatus(); } + std::string version_str; + if (version == ExtensionConfig::kLatest) { + version_str = "'latest'"; + } else { + version_str = absl::StrCat(version); + } return absl::AlreadyExistsError(absl::StrCat( "Extension '", name, "' version ", extension_config.version, - " is already included. Cannot also include version ", version)); + " is already included. Cannot also include version ", version_str)); } } extension_configs_.push_back( diff --git a/env/env_runtime.cc b/env/env_runtime.cc index 09bbcde04..33e0747cc 100644 --- a/env/env_runtime.cc +++ b/env/env_runtime.cc @@ -18,7 +18,10 @@ #include #include +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "env/config.h" #include "internal/status_macros.h" #include "runtime/runtime.h" @@ -29,6 +32,15 @@ namespace cel { +void EnvRuntime::RegisterExtensionFunctions( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable + function_registration_callback) { + extension_registry_.AddFunctionRegistration( + name, alias, version, std::move(function_registration_callback)); +} + absl::StatusOr EnvRuntime::CreateRuntimeBuilder() { const std::vector& extension_configs = config_.GetExtensionConfigs(); diff --git a/env/env_runtime.h b/env/env_runtime.h index ff62ec1d4..63473c295 100644 --- a/env/env_runtime.h +++ b/env/env_runtime.h @@ -18,7 +18,10 @@ #include #include +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "env/config.h" #include "env/internal/runtime_ext_registry.h" #include "runtime/runtime.h" @@ -41,6 +44,15 @@ namespace cel { // compilation. This ensures consistency between compilation and runtime. class EnvRuntime { public: + // Registers a function registration callback for an extension. The callback + // is invoked when a runtime is created, if the corresponding functions are + // enabled in the runtime config. + void RegisterExtensionFunctions( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable + function_registration_callback); + void SetDescriptorPool( std::shared_ptr descriptor_pool) { descriptor_pool_ = std::move(descriptor_pool); diff --git a/env/env_runtime_test.cc b/env/env_runtime_test.cc index 1c4205224..47892772c 100644 --- a/env/env_runtime_test.cc +++ b/env/env_runtime_test.cc @@ -31,10 +31,13 @@ #include "env/env_std_extensions.h" #include "env/env_yaml.h" #include "env/runtime_std_extensions.h" +#include "extensions/math_ext.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "runtime/activation.h" #include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace cel { @@ -156,5 +159,41 @@ std::vector GetEnvRuntimeTestCases() { INSTANTIATE_TEST_SUITE_P(EnvRuntimeTest, EnvRuntimeTest, ValuesIn(GetEnvRuntimeTestCases())); +TEST(EnvRuntimeTest, RegisterExtensionFunctions) { + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + Config config; + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + + Env env; + env.SetDescriptorPool(descriptor_pool); + RegisterStandardExtensions(env); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("math.sqrt(4) == 2.0")); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + env_runtime.RegisterExtensionFunctions( + "cel.lib.math", "math", 2, + [](cel::RuntimeBuilder& runtime_builder, + const cel::RuntimeOptions& opts) -> absl::Status { + return cel::extensions::RegisterMathExtensionFunctions( + runtime_builder.function_registry(), opts, 2); + }); + env_runtime.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + ASSERT_NE(program, nullptr); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()); +} } // namespace } // namespace cel From e9ba71a59cb8c7c35315c330c79119dc61d68733 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Mon, 13 Apr 2026 09:05:30 -0700 Subject: [PATCH 029/143] Use proto2::ConstMapIterator and stop const_cast'ing. PiperOrigin-RevId: 899029119 --- common/values/parsed_json_map_value.cc | 4 +- common/values/parsed_map_field_value.cc | 26 ++++++------ .../protobuf/internal/map_reflection.cc | 42 ++++++++----------- extensions/protobuf/internal/map_reflection.h | 12 +++--- internal/json.cc | 10 ++--- internal/message_equality.cc | 8 ++-- internal/well_known_types.cc | 12 +++--- internal/well_known_types.h | 9 ++-- 8 files changed, 59 insertions(+), 64 deletions(-) diff --git a/common/values/parsed_json_map_value.cc b/common/values/parsed_json_map_value.cc index 6072a0b21..ec8c91a4f 100644 --- a/common/values/parsed_json_map_value.cc +++ b/common/values/parsed_json_map_value.cc @@ -408,8 +408,8 @@ class ParsedJsonMapValueIterator final : public ValueIterator { private: const google::protobuf::Message* absl_nonnull const message_; const well_known_types::StructReflection reflection_; - google::protobuf::MapIterator begin_; - const google::protobuf::MapIterator end_; + google::protobuf::ConstMapIterator begin_; + const google::protobuf::ConstMapIterator end_; std::string scratch_; }; diff --git a/common/values/parsed_map_field_value.cc b/common/values/parsed_map_field_value.cc index 737593cca..47b737f82 100644 --- a/common/values/parsed_map_field_value.cc +++ b/common/values/parsed_map_field_value.cc @@ -415,10 +415,10 @@ absl::Status ParsedMapFieldValue::ListKeys( field_->message_type()->map_key())); auto builder = NewListValueBuilder(arena); builder->Reserve(Size()); - auto begin = - extensions::protobuf_internal::MapBegin(*reflection, *message_, *field_); - const auto end = - extensions::protobuf_internal::MapEnd(*reflection, *message_, *field_); + auto begin = extensions::protobuf_internal::ConstMapBegin(*reflection, + *message_, *field_); + const auto end = extensions::protobuf_internal::ConstMapEnd( + *reflection, *message_, *field_); for (; begin != end; ++begin) { Value scratch; (*key_accessor)(begin.GetKey(), message_, arena, &scratch); @@ -446,10 +446,10 @@ absl::Status ParsedMapFieldValue::ForEach( CEL_ASSIGN_OR_RETURN( auto value_accessor, common_internal::MapFieldValueAccessorFor(value_field)); - auto begin = extensions::protobuf_internal::MapBegin(*reflection, *message_, - *field_); - const auto end = - extensions::protobuf_internal::MapEnd(*reflection, *message_, *field_); + auto begin = extensions::protobuf_internal::ConstMapBegin( + *reflection, *message_, *field_); + const auto end = extensions::protobuf_internal::ConstMapEnd( + *reflection, *message_, *field_); Value key_scratch; Value value_scratch; for (; begin != end; ++begin) { @@ -479,10 +479,10 @@ class ParsedMapFieldValueIterator final : public ValueIterator { value_field_(field->message_type()->map_value()), key_accessor_(key_accessor), value_accessor_(value_accessor), - begin_(extensions::protobuf_internal::MapBegin( + begin_(extensions::protobuf_internal::ConstMapBegin( *message_->GetReflection(), *message_, *field)), - end_(extensions::protobuf_internal::MapEnd(*message_->GetReflection(), - *message_, *field)) {} + end_(extensions::protobuf_internal::ConstMapEnd( + *message_->GetReflection(), *message_, *field)) {} bool HasNext() override { return begin_ != end_; } @@ -545,8 +545,8 @@ class ParsedMapFieldValueIterator final : public ValueIterator { const google::protobuf::FieldDescriptor* absl_nonnull const value_field_; const absl_nonnull common_internal::MapFieldKeyAccessor key_accessor_; const absl_nonnull common_internal::MapFieldValueAccessor value_accessor_; - google::protobuf::MapIterator begin_; - const google::protobuf::MapIterator end_; + google::protobuf::ConstMapIterator begin_; + const google::protobuf::ConstMapIterator end_; }; } // namespace diff --git a/extensions/protobuf/internal/map_reflection.cc b/extensions/protobuf/internal/map_reflection.cc index 22a6dc23c..605e4437d 100644 --- a/extensions/protobuf/internal/map_reflection.cc +++ b/extensions/protobuf/internal/map_reflection.cc @@ -42,22 +42,16 @@ class CelMapReflectionFriend final { return reflection.MapSize(message, &field); } - static google::protobuf::MapIterator MapBegin(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field) { - return reflection.MapBegin( - const_cast< // NOLINT(google3-runtime-proto-const-cast) - google::protobuf::Message*>(&message), - &field); + static google::protobuf::ConstMapIterator ConstMapBegin( + const google::protobuf::Reflection& reflection, const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return reflection.ConstMapBegin(&message, &field); } - static google::protobuf::MapIterator MapEnd(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field) { - return reflection.MapEnd( - const_cast< // NOLINT(google3-runtime-proto-const-cast) - google::protobuf::Message*>(&message), - &field); + static google::protobuf::ConstMapIterator ConstMapEnd( + const google::protobuf::Reflection& reflection, const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return reflection.ConstMapEnd(&message, &field); } static bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection, @@ -104,18 +98,18 @@ int MapSize(const google::protobuf::Reflection& reflection, field); } -google::protobuf::MapIterator MapBegin(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field) { - return google::protobuf::expr::CelMapReflectionFriend::MapBegin(reflection, message, - field); +google::protobuf::ConstMapIterator ConstMapBegin(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return google::protobuf::expr::CelMapReflectionFriend::ConstMapBegin(reflection, + message, field); } -google::protobuf::MapIterator MapEnd(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field) { - return google::protobuf::expr::CelMapReflectionFriend::MapEnd(reflection, message, - field); +google::protobuf::ConstMapIterator ConstMapEnd(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return google::protobuf::expr::CelMapReflectionFriend::ConstMapEnd(reflection, message, + field); } bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection, diff --git a/extensions/protobuf/internal/map_reflection.h b/extensions/protobuf/internal/map_reflection.h index 6e696bbe3..681d7693d 100644 --- a/extensions/protobuf/internal/map_reflection.h +++ b/extensions/protobuf/internal/map_reflection.h @@ -42,13 +42,13 @@ int MapSize(const google::protobuf::Reflection& reflection, const google::protobuf::Message& message, const google::protobuf::FieldDescriptor& field); -google::protobuf::MapIterator MapBegin(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field); +google::protobuf::ConstMapIterator ConstMapBegin(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field); -google::protobuf::MapIterator MapEnd(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field); +google::protobuf::ConstMapIterator ConstMapEnd(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field); bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection, google::protobuf::Message* message, diff --git a/internal/json.cc b/internal/json.cc index 200d18bfb..630ceb267 100644 --- a/internal/json.cc +++ b/internal/json.cc @@ -803,10 +803,10 @@ class MessageToJsonState { const auto* value_descriptor = field->message_type()->map_value(); CEL_ASSIGN_OR_RETURN(const auto value_to_value, GetMapFieldValueToValue(value_descriptor)); - auto begin = - extensions::protobuf_internal::MapBegin(*reflection, message, *field); - const auto end = - extensions::protobuf_internal::MapEnd(*reflection, message, *field); + auto begin = extensions::protobuf_internal::ConstMapBegin(*reflection, + message, *field); + const auto end = extensions::protobuf_internal::ConstMapEnd( + *reflection, message, *field); for (; begin != end; ++begin) { auto key = (*key_to_string)(begin.GetKey()); CEL_RETURN_IF_ERROR((this->*value_to_value)( @@ -1381,7 +1381,7 @@ class JsonMapIterator final { using Generated = typename google::protobuf::Map::const_iterator; - using Dynamic = google::protobuf::MapIterator; + using Dynamic = google::protobuf::ConstMapIterator; using Value = std::pair; diff --git a/internal/message_equality.cc b/internal/message_equality.cc index 628432d66..945cca8df 100644 --- a/internal/message_equality.cc +++ b/internal/message_equality.cc @@ -50,9 +50,9 @@ namespace cel::internal { namespace { +using ::cel::extensions::protobuf_internal::ConstMapBegin; +using ::cel::extensions::protobuf_internal::ConstMapEnd; using ::cel::extensions::protobuf_internal::LookupMapValue; -using ::cel::extensions::protobuf_internal::MapBegin; -using ::cel::extensions::protobuf_internal::MapEnd; using ::cel::extensions::protobuf_internal::MapSize; using ::google::protobuf::Descriptor; using ::google::protobuf::DescriptorPool; @@ -904,8 +904,8 @@ class MessageEqualsState final { MapSize(*rhs_reflection, rhs, *rhs_field)) { return false; } - auto lhs_begin = MapBegin(*lhs_reflection, lhs, *lhs_field); - const auto lhs_end = MapEnd(*lhs_reflection, lhs, *lhs_field); + auto lhs_begin = ConstMapBegin(*lhs_reflection, lhs, *lhs_field); + const auto lhs_end = ConstMapEnd(*lhs_reflection, lhs, *lhs_field); Unique lhs_unpacked; EquatableValue lhs_value; Unique rhs_unpacked; diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index c736be69f..dee029534 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -1643,20 +1643,20 @@ int StructReflection::FieldsSize(const google::protobuf::Message& message) const message, *fields_field_); } -google::protobuf::MapIterator StructReflection::BeginFields( +google::protobuf::ConstMapIterator StructReflection::BeginFields( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); - return cel::extensions::protobuf_internal::MapBegin(*message.GetReflection(), - message, *fields_field_); + return cel::extensions::protobuf_internal::ConstMapBegin( + *message.GetReflection(), message, *fields_field_); } -google::protobuf::MapIterator StructReflection::EndFields( +google::protobuf::ConstMapIterator StructReflection::EndFields( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); - return cel::extensions::protobuf_internal::MapEnd(*message.GetReflection(), - message, *fields_field_); + return cel::extensions::protobuf_internal::ConstMapEnd( + *message.GetReflection(), message, *fields_field_); } bool StructReflection::ContainsField(const google::protobuf::Message& message, diff --git a/internal/well_known_types.h b/internal/well_known_types.h index dce88a420..f63e5e76b 100644 --- a/internal/well_known_types.h +++ b/internal/well_known_types.h @@ -698,8 +698,9 @@ absl::StatusOr GetAnyReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); -AnyReflection GetAnyReflectionOrDie(const google::protobuf::Descriptor* absl_nonnull - descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); +AnyReflection GetAnyReflectionOrDie( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); class DurationReflection final { public: @@ -1193,10 +1194,10 @@ class StructReflection final { int FieldsSize(const google::protobuf::Message& message) const; - google::protobuf::MapIterator BeginFields( + google::protobuf::ConstMapIterator BeginFields( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; - google::protobuf::MapIterator EndFields( + google::protobuf::ConstMapIterator EndFields( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; bool ContainsField(const google::protobuf::Message& message, From 15009ff50e3eb011b8e0ba3b46d55c559f03ac95 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 13 Apr 2026 15:00:28 -0700 Subject: [PATCH 030/143] Add homogeneous literal validator. Ported from the go implementation. PiperOrigin-RevId: 899199115 --- validator/BUILD | 35 ++++ validator/homogeneous_literal_validator.cc | 190 ++++++++++++++++++ validator/homogeneous_literal_validator.h | 38 ++++ .../homogeneous_literal_validator_test.cc | 145 +++++++++++++ 4 files changed, 408 insertions(+) create mode 100644 validator/homogeneous_literal_validator.cc create mode 100644 validator/homogeneous_literal_validator.h create mode 100644 validator/homogeneous_literal_validator_test.cc diff --git a/validator/BUILD b/validator/BUILD index dec3d8616..98d1316c7 100644 --- a/validator/BUILD +++ b/validator/BUILD @@ -96,6 +96,41 @@ cc_library( ], ) +cc_library( + name = "homogeneous_literal_validator", + srcs = ["homogeneous_literal_validator.cc"], + hdrs = ["homogeneous_literal_validator.h"], + deps = [ + ":validator", + "//common:ast", + "//common:expr", + "//common:navigable_ast", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "homogeneous_literal_validator_test", + srcs = ["homogeneous_literal_validator_test.cc"], + deps = [ + ":homogeneous_literal_validator", + ":validator", + "//checker:validation_result", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:optional", + "//compiler:standard_library", + "//extensions:strings", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + cc_test( name = "ast_depth_validator_test", srcs = ["ast_depth_validator_test.cc"], diff --git a/validator/homogeneous_literal_validator.cc b/validator/homogeneous_literal_validator.cc new file mode 100644 index 000000000..4a490dea2 --- /dev/null +++ b/validator/homogeneous_literal_validator.cc @@ -0,0 +1,190 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/homogeneous_literal_validator.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/navigable_ast.h" +#include "validator/validator.h" + +namespace cel { + +namespace { + +bool InExemptFunction(const NavigableAstNode& node, + const std::vector& exempt_functions) { + const NavigableAstNode* parent = node.parent(); + while (parent != nullptr) { + if (parent->node_kind() == NodeKind::kCall) { + absl::string_view fn_name = parent->expr()->call_expr().function(); + for (const auto& exempt : exempt_functions) { + if (exempt == fn_name) { + return true; + } + } + } + parent = parent->parent(); + } + return false; +} + +bool IsOptional(const TypeSpec& t) { + return t.has_abstract_type() && t.abstract_type().name() == "optional_type"; +} + +const TypeSpec& GetOptionalParameter(const TypeSpec& t) { + return t.abstract_type().parameter_types()[0]; +} + +void TypeMismatch(ValidationContext& context, int64_t id, + const TypeSpec& expected, const TypeSpec& actual) { + context.ReportErrorAt( + id, absl::StrCat("expected type '", FormatTypeSpec(expected), + "' but found '", FormatTypeSpec(actual), "'")); +} + +bool TypeEquiv(const TypeSpec& a, const TypeSpec& b) { + if (a == b) { + return true; + } + + if (a.has_error() || b.has_error()) { + // Don't report mismatch if there's an error (type checking failed for the + // expression). + return true; + } + + if (a.has_wrapper() && b.has_primitive()) { + return a.wrapper() == b.primitive(); + } else if (a.has_primitive() && b.has_wrapper()) { + return a.primitive() == b.wrapper(); + } + + if (a.has_list_type() && b.has_list_type()) { + return TypeEquiv(a.list_type().elem_type(), b.list_type().elem_type()); + } + + if (a.has_map_type() && b.has_map_type()) { + return TypeEquiv(a.map_type().key_type(), b.map_type().key_type()) && + TypeEquiv(a.map_type().value_type(), b.map_type().value_type()); + } + + if (a.has_abstract_type() && b.has_abstract_type() && + a.abstract_type().name() == b.abstract_type().name() && + a.abstract_type().parameter_types().size() == + b.abstract_type().parameter_types().size()) { + for (int i = 0; i < a.abstract_type().parameter_types().size(); ++i) { + if (!TypeEquiv(a.abstract_type().parameter_types()[i], + b.abstract_type().parameter_types()[i])) { + return false; + } + } + return true; + } + + return false; +} + +} // namespace + +Validation HomogeneousLiteralValidator( + std::vector exempt_functions) { + return Validation([exempt_functions = std::move(exempt_functions)]( + ValidationContext& context) -> bool { + bool valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() == NodeKind::kList) { + if (InExemptFunction(node, exempt_functions)) { + continue; + } + const auto& list_expr = node.expr()->list_expr(); + const auto& elements = list_expr.elements(); + const TypeSpec* expected_type = nullptr; + + for (const auto& element : elements) { + int64_t id = element.expr().id(); + const TypeSpec& actual_type = context.ast().GetTypeOrDyn(id); + const TypeSpec* type_to_check = &actual_type; + + if (element.optional() && IsOptional(actual_type)) { + type_to_check = &GetOptionalParameter(actual_type); + } + + if (expected_type == nullptr) { + expected_type = type_to_check; + continue; + } + + if (!(TypeEquiv(*expected_type, *type_to_check))) { + TypeMismatch(context, id, *expected_type, *type_to_check); + valid = false; + break; + } + } + } else if (node.node_kind() == NodeKind::kMap) { + if (InExemptFunction(node, exempt_functions)) { + continue; + } + const auto& map_expr = node.expr()->map_expr(); + const auto& entries = map_expr.entries(); + const TypeSpec* expected_key_type = nullptr; + const TypeSpec* expected_value_type = nullptr; + + for (const auto& entry : entries) { + int64_t key_id = entry.key().id(); + int64_t val_id = entry.value().id(); + const TypeSpec& actual_key_type = context.ast().GetTypeOrDyn(key_id); + const TypeSpec& actual_val_type = context.ast().GetTypeOrDyn(val_id); + const TypeSpec* key_type_to_check = &actual_key_type; + const TypeSpec* val_type_to_check = &actual_val_type; + + if (entry.optional() && IsOptional(actual_val_type)) { + val_type_to_check = &GetOptionalParameter(actual_val_type); + } + + if (expected_key_type == nullptr) { + expected_key_type = key_type_to_check; + expected_value_type = val_type_to_check; + continue; + } + + if (!(TypeEquiv(*expected_key_type, *key_type_to_check))) { + TypeMismatch(context, key_id, *expected_key_type, + *key_type_to_check); + valid = false; + break; + } + if (!(TypeEquiv(*expected_value_type, *val_type_to_check))) { + TypeMismatch(context, val_id, *expected_value_type, + *val_type_to_check); + valid = false; + break; + } + } + } + } + return valid; + }); +} + +} // namespace cel diff --git a/validator/homogeneous_literal_validator.h b/validator/homogeneous_literal_validator.h new file mode 100644 index 000000000..e37648a25 --- /dev/null +++ b/validator/homogeneous_literal_validator.h @@ -0,0 +1,38 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ + +#include +#include + +#include "validator/validator.h" + +namespace cel { + +// Returns a `Validation` that checks that all literals in map or list literals +// are the same type. If the list or map is part of an argument to an exempted +// function, it is not checked. +Validation HomogeneousLiteralValidator( + std::vector exempt_functions); + +inline Validation HomogeneousLiteralValidator() { + // Default to exempting the strings extension "format" function. + return HomogeneousLiteralValidator({"format"}); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ diff --git a/validator/homogeneous_literal_validator_test.cc b/validator/homogeneous_literal_validator_test.cc new file mode 100644 index 000000000..b027fa4b0 --- /dev/null +++ b/validator/homogeneous_literal_validator_test.cc @@ -0,0 +1,145 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/homogeneous_literal_validator.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/validation_result.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/optional.h" +#include "compiler/standard_library.h" +#include "extensions/strings.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); + builder->AddLibrary(OptionalCompilerLibrary()).IgnoreError(); + builder->AddLibrary(extensions::StringsCompilerLibrary()).IgnoreError(); + cel::Type message_type = cel::Type::Message( + builder->GetCheckerBuilder().descriptor_pool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("msg", message_type))); + return builder->Build(); +} + +struct TestCase { + std::string expression; + bool valid; + std::string error_substr = ""; +}; + +using HomogeneousLiteralValidatorTest = testing::TestWithParam; + +TEST_P(HomogeneousLiteralValidatorTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(HomogeneousLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid); + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + HomogeneousLiteralValidatorTest, HomogeneousLiteralValidatorTest, + testing::Values( + // Lists + TestCase{"[1, 2, 3]", true}, TestCase{"['a', 'b', 'c']", true}, + TestCase{"[1, 'a']", false, "expected type 'int' but found 'string'"}, + TestCase{"[1, 2, 'a']", false, + "expected type 'int' but found 'string'"}, + TestCase{"[[1], [2]]", true}, + TestCase{"[[1], ['a']]", false, + "expected type 'list(int)' but found 'list(string)'"}, + + // Dyn casts + TestCase{"[dyn(1), dyn('a')]", true, ""}, + TestCase{"[dyn(1), 2]", false, "expected type 'dyn' but found 'int'"}, + + // Maps + TestCase{"{1: 'a', 2: 'b'}", true}, TestCase{"{'a': 1, 'b': 2}", true}, + TestCase{"{1: 'a', 'b': 2}", false, + "expected type 'int' but found 'string'"}, + TestCase{"{1: 'a', 2: 3}", false, + "expected type 'string' but found 'int'"}, + + // Optionals + TestCase{"[optional.of(1), optional.of(2)]", true}, + TestCase{"[optional.of(1), optional.of('b')]", false, + "expected type 'optional_type(int)' but found " + "'optional_type(string)'"}, + + TestCase{"[?optional.of(1), ?optional.of(2)]", true}, + TestCase{"[?optional.of(1), ?optional.of('a')]", false, + "expected type 'int' but found 'string'"}, + TestCase{"{?1: optional.of('a'), ?2: optional.none()}", true}, + TestCase{"{?1: optional.of('a'), ?2: optional.of(1)}", false, + "expected type 'string' but found 'int'"}, + + // Exempted Functions + TestCase{"'%v %v'.format([1, 'a'])", true}, + + // Mixed Primitives and Wrappers + TestCase{"[1, msg.single_int64_wrapper]", true}, + TestCase{"[msg.single_int64_wrapper, 1]", true}, + TestCase{"['foo', msg.single_string_wrapper]", true}, + TestCase{"[msg.single_string_wrapper, 'foo']", true}, + TestCase{"{1: msg.single_int64_wrapper, 2: 3}", true}, + TestCase{"{1: 2, 2: msg.single_int64_wrapper}", true}, + TestCase{"[[1], [msg.single_int64_wrapper]]", true}, + TestCase{"[optional.of(1), optional.of(msg.single_int64_wrapper)]", + true}, + TestCase{"[1, msg.single_string_wrapper]", false, + "expected type 'int' but found 'wrapper(string)'"}, + TestCase{"[msg.single_int64_wrapper, 'foo']", false, + "expected type 'wrapper(int)' but found 'string'"}, + TestCase{"[msg.single_int64_wrapper, msg.single_string_wrapper]", false, + "expected type 'wrapper(int)' but found 'wrapper(string)'"}, + + // Nested + TestCase{"[1, [2, 'a']]", false, + "expected type 'int' but found 'string'"}, + TestCase{"[[1, 2], [3, 4]]", true, ""}, + TestCase{"[{1: 2}, {'foo': 3}]", false, + "expected type 'map(int, int)' but found 'map(string, int)'"}, + TestCase{"[{1: 2}, {3: 'foo'}]", false, + "expected type 'map(int, int)' but found 'map(int, string)'"}, + TestCase{"[{1: 2}, {3: 4}]", true, ""})); + +} // namespace +} // namespace cel From 072542b0a1128e1c6b26ce774dc9c2824cb98ad2 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 13 Apr 2026 16:41:07 -0700 Subject: [PATCH 031/143] Add RegexPatternValidator. PiperOrigin-RevId: 899243591 --- extensions/BUILD | 3 + extensions/regex_ext.cc | 8 +++ extensions/regex_ext.h | 8 +++ extensions/regex_ext_test.cc | 40 +++++++++++++ validator/BUILD | 35 +++++++++++ validator/regex_validator.cc | 96 +++++++++++++++++++++++++++++++ validator/regex_validator.h | 53 +++++++++++++++++ validator/regex_validator_test.cc | 91 +++++++++++++++++++++++++++++ 8 files changed, 334 insertions(+) create mode 100644 validator/regex_validator.cc create mode 100644 validator/regex_validator.h create mode 100644 validator/regex_validator_test.cc diff --git a/extensions/BUILD b/extensions/BUILD index c393ec13a..ff37e2c3f 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -774,6 +774,8 @@ cc_library( "//runtime:runtime_builder", "//runtime/internal:runtime_friend_access", "//runtime/internal:runtime_impl", + "//validator", + "//validator:regex_validator", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:bind_front", @@ -814,6 +816,7 @@ cc_test( "//runtime:reference_resolver", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", + "//validator", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", diff --git a/extensions/regex_ext.cc b/extensions/regex_ext.cc index c3d7cae53..9c06d90c2 100644 --- a/extensions/regex_ext.cc +++ b/extensions/regex_ext.cc @@ -42,6 +42,8 @@ #include "runtime/internal/runtime_friend_access.h" #include "runtime/internal/runtime_impl.h" #include "runtime/runtime_builder.h" +#include "validator/regex_validator.h" +#include "validator/validator.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -341,4 +343,10 @@ CompilerLibrary RegexExtCompilerLibrary() { return CompilerLibrary::FromCheckerLibrary(RegexExtCheckerLibrary()); } +Validation RegexExtValidator() { + return RegexPatternValidator( + /*id=*/"", + {{"regex.extract", 1}, {"regex.extractAll", 1}, {"regex.replace", 1}}); +} + } // namespace cel::extensions diff --git a/extensions/regex_ext.h b/extensions/regex_ext.h index dc401f5bd..7b32aee00 100644 --- a/extensions/regex_ext.h +++ b/extensions/regex_ext.h @@ -81,6 +81,7 @@ #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "runtime/runtime_builder.h" +#include "validator/validator.h" namespace cel::extensions { @@ -119,5 +120,12 @@ CheckerLibrary RegexExtCheckerLibrary(); // regex.extractAll(target: str, pattern: str) -> list CompilerLibrary RegexExtCompilerLibrary(); +// Returns a `Validation` that checks all calls to the CEL regex extension +// functions. +// +// It validates that if the pattern is a literal string, it is a valid regular +// expression. +Validation RegexExtValidator(); + } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ diff --git a/extensions/regex_ext_test.cc b/extensions/regex_ext_test.cc index e69f7cce1..26d9936aa 100644 --- a/extensions/regex_ext_test.cc +++ b/extensions/regex_ext_test.cc @@ -46,6 +46,7 @@ #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" +#include "validator/validator.h" #include "google/protobuf/arena.h" #include "google/protobuf/extension_set.h" @@ -497,5 +498,44 @@ std::vector createRegexCheckerParams() { INSTANTIATE_TEST_SUITE_P(RegexExtCheckerLibraryTest, RegexExtCheckerLibraryTest, ValuesIn(createRegexCheckerParams())); + +absl::StatusOr> CreateRegexExtCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, NewCompilerBuilder(internal::GetTestingDescriptorPool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(StandardCheckerLibrary())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(RegexExtCompilerLibrary())); + return std::move(*builder).Build(); +} + +class RegexExtValidatorTest : public TestWithParam {}; + +TEST_P(RegexExtValidatorTest, Basic) { + ASSERT_OK_AND_ASSIGN(auto compiler, CreateRegexExtCompiler()); + + Validator validator; + validator.AddValidation(RegexExtValidator()); + + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(GetParam().expr_string)); + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), GetParam().error_substr.empty()) + << "Expression: " << GetParam().expr_string; + if (!GetParam().error_substr.empty()) { + EXPECT_THAT(result.FormatError(), HasSubstr(GetParam().error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P(RegexExtValidatorTest, RegexExtValidatorTest, + testing::ValuesIn(std::vector{ + {"regex.extract('hello world', 'hello (.*)')"}, + {"regex.extract('hello world', 'hello ([') ", + "invalid regular expression"}, + {"regex.extractAll('hello world', 'hello (.*)')"}, + {"regex.extractAll('hello world', 'hello ([') ", + "invalid regular expression"}, + {"regex.replace('hello world', 'hello', 'hi')"}, + {"regex.replace('hello world', 'he([', 'hi') ", + "invalid regular expression"}, + })); } // namespace } // namespace cel::extensions diff --git a/validator/BUILD b/validator/BUILD index 98d1316c7..e3f639142 100644 --- a/validator/BUILD +++ b/validator/BUILD @@ -109,6 +109,24 @@ cc_library( ], ) +cc_library( + name = "regex_validator", + srcs = ["regex_validator.cc"], + hdrs = ["regex_validator.h"], + deps = [ + ":validator", + "//common:ast", + "//common:constant", + "//common:expr", + "//common:navigable_ast", + "//common:standard_definitions", + "//internal:re2_options", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + ], +) + cc_test( name = "homogeneous_literal_validator_test", srcs = ["homogeneous_literal_validator_test.cc"], @@ -147,4 +165,21 @@ cc_test( ], ) +cc_test( + name = "regex_validator_test", + srcs = ["regex_validator_test.cc"], + deps = [ + ":regex_validator", + ":validator", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:statusor", + ], +) + licenses(["notice"]) diff --git a/validator/regex_validator.cc b/validator/regex_validator.cc new file mode 100644 index 000000000..df92bfb1e --- /dev/null +++ b/validator/regex_validator.cc @@ -0,0 +1,96 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/regex_validator.h" + +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/navigable_ast.h" +#include "internal/re2_options.h" +#include "validator/validator.h" +#include "re2/re2.h" + +namespace cel { + +namespace { + +bool CheckPattern(ValidationContext& context, const NavigableAstNode& node, + int arg_index) { + ABSL_DCHECK(node.expr()->has_call_expr()); + const auto& call_expr = node.expr()->call_expr(); + + const Expr* pattern_expr = nullptr; + + if (call_expr.has_target()) { + if (arg_index == 0) { + pattern_expr = &call_expr.target(); + } else if (call_expr.args().size() > arg_index - 1) { + pattern_expr = &call_expr.args()[arg_index - 1]; + } + } else if (call_expr.args().size() > arg_index) { + pattern_expr = &call_expr.args()[arg_index]; + } + + if (pattern_expr == nullptr || !pattern_expr->has_const_expr()) { + return true; + } + + const auto& const_expr = pattern_expr->const_expr(); + if (!const_expr.has_string_value()) { + return true; + } + + absl::string_view pattern_string = const_expr.string_value(); + RE2 re(pattern_string, internal::MakeRE2Options()); + if (!re.ok()) { + context.ReportErrorAt( + pattern_expr->id(), + absl::StrCat("invalid regular expression: ", re.error())); + return false; + } + return true; +} + +} // namespace + +Validation RegexPatternValidator( + absl::string_view id, std::vector config) { + return Validation( + [config = std::move(config)](ValidationContext& context) -> bool { + bool result = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() == NodeKind::kCall) { + for (const auto& config : config) { + if (node.expr()->call_expr().function() == config.function_name) { + if (!CheckPattern(context, node, config.pattern_arg_index)) { + result = false; + } + break; + } + } + } + } + return result; + }, + id); +} + +} // namespace cel diff --git a/validator/regex_validator.h b/validator/regex_validator.h new file mode 100644 index 000000000..15ee1755e --- /dev/null +++ b/validator/regex_validator.h @@ -0,0 +1,53 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "common/standard_definitions.h" +#include "validator/validator.h" + +namespace cel { + +// Configuration for the regex pattern validator. +struct RegexPatternValidatorConfig { + // The resolved function name. + std::string function_name; + // the index of the pattern argument (counting the receiver as arg 0 if + // present). + int pattern_arg_index; +}; + +// Returns a `Validation` that checks all calls to the given regex functions +// It validates that the specified argument is a valid regular expression if it +// is a literal string. +Validation RegexPatternValidator( + absl::string_view id, std::vector config); + +// Returns a `Validation` that checks all calls to the CEL `matches` function. +// It validates that if the pattern is a literal string, it is a valid regular +// expression. +inline Validation MatchesValidator() { + return RegexPatternValidator( + "cel.validator.matches", + {{std::string(StandardFunctions::kRegexMatch), 1}}); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ diff --git a/validator/regex_validator_test.cc b/validator/regex_validator_test.cc new file mode 100644 index 000000000..cfab1468d --- /dev/null +++ b/validator/regex_validator_test.cc @@ -0,0 +1,91 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/regex_validator.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("p", StringType()))); + return builder->Build(); +} + +struct TestCase { + std::string expression; + bool valid; + std::string error_substr = ""; +}; + +using MatchesValidatorTest = testing::TestWithParam; + +TEST_P(MatchesValidatorTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(MatchesValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); + + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid) + << "Expression: " << test_case.expression; + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + MatchesValidatorTest, MatchesValidatorTest, + testing::Values( + // Member calls + TestCase{"'hello'.matches('h.*')", true}, + TestCase{"'hello'.matches('h[')", false, "invalid regular expression"}, + TestCase{"'hello'.matches('h(a|b)')", true}, + TestCase{"'hello'.matches('h(a|b')", false, + "invalid regular expression"}, + // Global calls + TestCase{"matches('hello', 'h.*')", true}, + TestCase{"matches('hello', 'h[')", false, "invalid regular expression"}, + // Non-literal patterns (should not report regex errors) + TestCase{"'hello'.matches(p)", true}, + TestCase{"'hello'.matches('h' + 'ello')", true}, + TestCase{"'hello'.matches(dyn(1))", true}, + + // Empty pattern + TestCase{"'hello'.matches('')", true})); + +} // namespace +} // namespace cel From 9fd4d79dff5f5b95654f1eb60a55a104e063deac Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 13 Apr 2026 17:10:35 -0700 Subject: [PATCH 032/143] Add comprehension nesting limit validator. PiperOrigin-RevId: 899255266 --- validator/BUILD | 29 ++++++ validator/comprehension_nesting_validator.cc | 72 ++++++++++++++ validator/comprehension_nesting_validator.h | 31 ++++++ .../comprehension_nesting_validator_test.cc | 96 +++++++++++++++++++ 4 files changed, 228 insertions(+) create mode 100644 validator/comprehension_nesting_validator.cc create mode 100644 validator/comprehension_nesting_validator.h create mode 100644 validator/comprehension_nesting_validator_test.cc diff --git a/validator/BUILD b/validator/BUILD index e3f639142..9910a6b97 100644 --- a/validator/BUILD +++ b/validator/BUILD @@ -182,4 +182,33 @@ cc_test( ], ) +cc_library( + name = "comprehension_nesting_validator", + srcs = ["comprehension_nesting_validator.cc"], + hdrs = ["comprehension_nesting_validator.h"], + deps = [ + ":validator", + "//common:expr", + "//common:navigable_ast", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "comprehension_nesting_validator_test", + srcs = ["comprehension_nesting_validator_test.cc"], + deps = [ + ":comprehension_nesting_validator", + ":validator", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//extensions:bindings_ext", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:statusor", + ], +) + licenses(["notice"]) diff --git a/validator/comprehension_nesting_validator.cc b/validator/comprehension_nesting_validator.cc new file mode 100644 index 000000000..81c47cbc3 --- /dev/null +++ b/validator/comprehension_nesting_validator.cc @@ -0,0 +1,72 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/comprehension_nesting_validator.h" + +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "common/expr.h" +#include "common/navigable_ast.h" +#include "validator/validator.h" + +namespace cel { + +namespace { + +bool IsEmptyRangeComprehension(const NavigableAstNode& node) { + ABSL_DCHECK(node.expr()->has_comprehension_expr()); + const auto& comp = node.expr()->comprehension_expr(); + return comp.has_iter_range() && comp.iter_range().has_list_expr() && + comp.iter_range().list_expr().elements().empty(); +} + +} // namespace + +Validation ComprehensionNestingLimitValidator(int limit) { + return Validation( + [limit](ValidationContext& context) -> bool { + bool is_valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() != NodeKind::kComprehension) { + continue; + } + if (IsEmptyRangeComprehension(node)) { + continue; + } + + int count = 0; + const NavigableAstNode* current = &node; + while (current != nullptr) { + if (current->node_kind() == NodeKind::kComprehension && + !IsEmptyRangeComprehension(*current)) { + count++; + } + current = current->parent(); + } + if (count > limit) { + context.ReportErrorAt( + node.expr()->id(), + absl::StrCat("comprehension nesting level of ", count, + " exceeds limit of ", limit)); + is_valid = false; + break; + } + } + return is_valid; + }, + "cel.validator.comprehension_nesting_limit"); +} + +} // namespace cel diff --git a/validator/comprehension_nesting_validator.h b/validator/comprehension_nesting_validator.h new file mode 100644 index 000000000..4dab78db0 --- /dev/null +++ b/validator/comprehension_nesting_validator.h @@ -0,0 +1,31 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ + +#include "validator/validator.h" + +namespace cel { + +// Returns a `Validation` that checks that comprehensions are not nested beyond +// the specified limit. +// +// Comprehensions with an empty iteration range (e.g. `cel.bind`) do not count +// towards the nesting limit. +Validation ComprehensionNestingLimitValidator(int limit); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ diff --git a/validator/comprehension_nesting_validator_test.cc b/validator/comprehension_nesting_validator_test.cc new file mode 100644 index 000000000..c1b47f82d --- /dev/null +++ b/validator/comprehension_nesting_validator_test.cc @@ -0,0 +1,96 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/comprehension_nesting_validator.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "extensions/bindings_ext.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCompilerLibrary())); + return builder->Build(); +} + +struct TestCase { + std::string expression; + int limit; + bool valid; + std::string error_substr = ""; +}; + +using ComprehensionNestingValidatorTest = testing::TestWithParam; + +TEST_P(ComprehensionNestingValidatorTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(ComprehensionNestingLimitValidator(test_case.limit)); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + auto result_or = compiler->Compile(test_case.expression); + if (!result_or.ok()) { + GTEST_SKIP() << "Expression failed to compile: " << test_case.expression + << " " << result_or.status().message(); + } + auto result = std::move(result_or).value(); + + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid) + << "Expression: " << test_case.expression + << " Limit: " << test_case.limit; + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + ComprehensionNestingValidatorTest, ComprehensionNestingValidatorTest, + testing::Values( + TestCase{"[1, 2].all(x, x > 0)", 1, true}, + TestCase{"[1, 2].all(x, [1, 2].all(y, x > y))", 1, false, + "comprehension nesting level of 2 exceeds limit of 1"}, + TestCase{"[1, 2].all(x, [1, 2].all(y, x > y))", 2, true}, + // Empty range comprehension (does not count) + TestCase{"[].all(x, [1, 2].all(y, y > 0))", 1, true}, + TestCase{"cel.bind(x, [1, 2].all(y, y > 0), [1, 2].all(z, z > 0))", 1, + true}, + // Nested empty range comprehensions + TestCase{"[].all(x, [].all(y, true))", 0, true}, + // Deeply nested mixed + TestCase{"[1].all(x, [].all(y, [2].all(z, true)))", 1, false, + "comprehension nesting level of 2 exceeds limit of 1"}, + TestCase{"[1].all(x, [].all(y, [2].all(z, true)))", 2, true})); + +} // namespace +} // namespace cel From 7a37461941067c8fe90cce8724dd5153b42c21bf Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 14 Apr 2026 13:12:00 -0700 Subject: [PATCH 033/143] Test fixes for macos builds - ignore conformance test that depends on charconv shortest float rep formatting - fix internal test with ambiguous overloads PiperOrigin-RevId: 899739434 --- conformance/BUILD | 44 ++------- internal/overflow_test.cc | 189 ++++++++++++++++++++++++-------------- 2 files changed, 126 insertions(+), 107 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index ba485f36d..139739891 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -164,7 +164,7 @@ _ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/type_deduction.textproto", ] -_TESTS_TO_SKIP_MODERN = [ +_TESTS_TO_SKIP = [ # Tests which require spec changes. # TODO(issues/93): Deprecate Duration.getMilliseconds. "timestamps/duration_converters/get_milliseconds", @@ -197,10 +197,13 @@ _TESTS_TO_SKIP_MODERN = [ "timestamps/timestamp_selectors_tz/getDayOfMonth_name_pos", "timestamps/timestamp_selectors_tz/getDayOfYear", # These depend on using charconv (or equivalent) to format doubles with shortest possible - # precision to preserve value. Not available on older compilers. + # precision to preserve value. Not available on older compilers where we just use absl::Format. + # We should probably update the spec to allow different formats that parse to the same value. "conversions/string/double_hard", ] +_TESTS_TO_SKIP_MODERN = _TESTS_TO_SKIP + _TESTS_TO_SKIP_MODERN_DASHBOARD = [ # Future features for CEL 1.0 # TODO(issues/119): Strong typing support for enums, specified but not implemented. @@ -208,34 +211,7 @@ _TESTS_TO_SKIP_MODERN_DASHBOARD = [ "enums/strong_proto3", ] -_TESTS_TO_SKIP_LEGACY = [ - # Tests which require spec changes. - # TODO(issues/93): Deprecate Duration.getMilliseconds. - "timestamps/duration_converters/get_milliseconds", - - # Broken test cases which should be supported. - # TODO(issues/112): Unbound functions result in empty eval response. - "basic/functions/unbound", - "basic/functions/unbound_is_runtime_error", - - # TODO(issues/97): Parse-only qualified variable lookup "x.y" with binding "x.y" or "y" within container "x" fails - "fields/qualified_identifier_resolution/qualified_ident,map_field_select,ident_with_longest_prefix_check,qualified_identifier_resolution_unchecked", - "namespace/qualified/self_eval_qualified_lookup", - "namespace/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", - # TODO(issues/117): Integer overflow on enum assignments should error. - "enums/legacy_proto2/select_big,select_neg", - - # Skip until fixed. - "wrappers/field_mask/to_json", - "wrappers/empty/to_json", - "fields/qualified_identifier_resolution/map_value_repeat_key_heterogeneous", - "parse/receiver_function_names", - - # Future features for CEL 1.0 - # TODO(issues/119): Strong typing support for enums, specified but not implemented. - "enums/strong_proto2", - "enums/strong_proto3", - +_TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [ # Legacy value does not support optional_type. "optionals/optionals", @@ -245,14 +221,6 @@ _TESTS_TO_SKIP_LEGACY = [ "proto3/set_null/list_value", "proto3/set_null/single_struct", - # These depend on legacy US/ timezones. It's spotty if these are included with a normally - # configured timezone database. - "timestamps/timestamp_selectors_tz/getDayOfMonth_name_pos", - "timestamps/timestamp_selectors_tz/getDayOfYear", - # These depend on using charconv (or equivalent) to format doubles with shortest possible - # precision to preserve value. Not available on older compilers. - "conversions/string/double_hard", - # cel.@block "block_ext/basic/optional_list", "block_ext/basic/optional_map", diff --git a/internal/overflow_test.cc b/internal/overflow_test.cc index 38c5fa750..213e7a79d 100644 --- a/internal/overflow_test.cc +++ b/internal/overflow_test.cc @@ -57,25 +57,30 @@ INSTANTIATE_TEST_SUITE_P( CheckedIntMathTest, CheckedIntResultTest, ValuesIn(std::vector{ // Addition tests. - {"OneAddOne", [] { return CheckedAdd(1L, 1L); }, 2L}, - {"ZeroAddOne", [] { return CheckedAdd(0, 1L); }, 1L}, - {"ZeroAddMinusOne", [] { return CheckedAdd(0, -1L); }, -1L}, - {"OneAddZero", [] { return CheckedAdd(1L, 0); }, 1L}, - {"MinusOneAddZero", [] { return CheckedAdd(-1L, 0); }, -1L}, + {"OneAddOne", [] { return CheckedAdd(int64_t{1L}, 1L); }, 2L}, + {"ZeroAddOne", [] { return CheckedAdd(int64_t{0}, 1L); }, 1L}, + {"ZeroAddMinusOne", [] { return CheckedAdd(int64_t{0}, -1L); }, -1L}, + {"OneAddZero", [] { return CheckedAdd(int64_t{1L}, 0); }, 1L}, + {"MinusOneAddZero", [] { return CheckedAdd(int64_t{-1L}, 0); }, -1L}, {"OneAddIntMax", - [] { return CheckedAdd(1L, std::numeric_limits::max()); }, + [] { + return CheckedAdd(int64_t{1L}, std::numeric_limits::max()); + }, absl::OutOfRangeError("integer overflow")}, {"MinusOneAddIntMin", - [] { return CheckedAdd(-1L, std::numeric_limits::lowest()); }, + [] { + return CheckedAdd(int64_t{-1L}, + std::numeric_limits::lowest()); + }, absl::OutOfRangeError("integer overflow")}, // Subtraction tests. - {"TwoSubThree", [] { return CheckedSub(2L, 3L); }, -1L}, - {"TwoSubZero", [] { return CheckedSub(2L, 0); }, 2L}, - {"ZeroSubTwo", [] { return CheckedSub(0, 2L); }, -2L}, - {"MinusTwoSubThree", [] { return CheckedSub(-2L, 3L); }, -5L}, - {"MinusTwoSubZero", [] { return CheckedSub(-2L, 0); }, -2L}, - {"ZeroSubMinusTwo", [] { return CheckedSub(0, -2L); }, 2L}, + {"TwoSubThree", [] { return CheckedSub(int64_t{2L}, 3L); }, -1L}, + {"TwoSubZero", [] { return CheckedSub(int64_t{2L}, 0); }, 2L}, + {"ZeroSubTwo", [] { return CheckedSub(int64_t{0}, 2L); }, -2L}, + {"MinusTwoSubThree", [] { return CheckedSub(int64_t{-2L}, 3L); }, -5L}, + {"MinusTwoSubZero", [] { return CheckedSub(int64_t{-2L}, 0); }, -2L}, + {"ZeroSubMinusTwo", [] { return CheckedSub(int64_t{0}, -2L); }, 2L}, {"IntMinSubIntMax", [] { return CheckedSub(std::numeric_limits::max(), @@ -84,66 +89,100 @@ INSTANTIATE_TEST_SUITE_P( absl::OutOfRangeError("integer overflow")}, // Multiplication tests. - {"TwoMulThree", [] { return CheckedMul(2L, 3L); }, 6L}, - {"MinusTwoMulThree", [] { return CheckedMul(-2L, 3L); }, -6L}, - {"MinusTwoMulMinusThree", [] { return CheckedMul(-2L, -3L); }, 6L}, - {"TwoMulMinusThree", [] { return CheckedMul(2L, -3L); }, -6L}, + {"TwoMulThree", [] { return CheckedMul(int64_t{2L}, 3L); }, 6L}, + {"MinusTwoMulThree", [] { return CheckedMul(int64_t{-2L}, 3L); }, -6L}, + {"MinusTwoMulMinusThree", [] { return CheckedMul(int64_t{-2L}, -3L); }, + 6L}, + {"TwoMulMinusThree", [] { return CheckedMul(int64_t{2L}, -3L); }, -6L}, {"TwoMulIntMax", - [] { return CheckedMul(2L, std::numeric_limits::max()); }, + [] { + return CheckedMul(int64_t{2L}, std::numeric_limits::max()); + }, absl::OutOfRangeError("integer overflow")}, {"MinusOneMulIntMin", - [] { return CheckedMul(-1L, std::numeric_limits::lowest()); }, + [] { + return CheckedMul(int64_t{-1L}, + std::numeric_limits::lowest()); + }, absl::OutOfRangeError("integer overflow")}, {"IntMinMulMinusOne", - [] { return CheckedMul(std::numeric_limits::lowest(), -1L); }, + [] { + return CheckedMul(std::numeric_limits::lowest(), + int64_t{-1L}); + }, absl::OutOfRangeError("integer overflow")}, {"IntMinMulZero", - [] { return CheckedMul(std::numeric_limits::lowest(), 0); }, + [] { + return CheckedMul(std::numeric_limits::lowest(), + int64_t{0}); + }, 0}, {"ZeroMulIntMin", - [] { return CheckedMul(0, std::numeric_limits::lowest()); }, + [] { + return CheckedMul(int64_t{0}, + std::numeric_limits::lowest()); + }, 0}, {"IntMaxMulZero", - [] { return CheckedMul(std::numeric_limits::max(), 0); }, 0}, + [] { + return CheckedMul(std::numeric_limits::max(), int64_t{0}); + }, + 0}, {"ZeroMulIntMax", - [] { return CheckedMul(0, std::numeric_limits::max()); }, 0}, + [] { + return CheckedMul(int64_t{0}, std::numeric_limits::max()); + }, + 0}, // Division cases. - {"ZeroDivOne", [] { return CheckedDiv(0, 1L); }, 0}, - {"TenDivTwo", [] { return CheckedDiv(10L, 2L); }, 5}, - {"TenDivMinusOne", [] { return CheckedDiv(10L, -1L); }, -10}, - {"MinusTenDivMinusOne", [] { return CheckedDiv(-10L, -1L); }, 10}, - {"MinusTenDivTwo", [] { return CheckedDiv(-10L, 2L); }, -5}, - {"OneDivZero", [] { return CheckedDiv(1L, 0L); }, + {"ZeroDivOne", [] { return CheckedDiv(int64_t{0}, 1L); }, 0}, + {"TenDivTwo", [] { return CheckedDiv(int64_t{10L}, 2L); }, 5}, + {"TenDivMinusOne", [] { return CheckedDiv(int64_t{10L}, -1L); }, -10}, + {"MinusTenDivMinusOne", [] { return CheckedDiv(int64_t{-10L}, -1L); }, + 10}, + {"MinusTenDivTwo", [] { return CheckedDiv(int64_t{-10L}, 2L); }, -5}, + {"OneDivZero", [] { return CheckedDiv(int64_t{1L}, 0L); }, absl::InvalidArgumentError("divide by zero")}, {"IntMinDivMinusOne", - [] { return CheckedDiv(std::numeric_limits::lowest(), -1L); }, + [] { + return CheckedDiv(std::numeric_limits::lowest(), + int64_t{-1L}); + }, absl::OutOfRangeError("integer overflow")}, // Modulus cases. - {"ZeroModTwo", [] { return CheckedMod(0, 2L); }, 0}, - {"TwoModTwo", [] { return CheckedMod(2L, 2L); }, 0}, - {"ThreeModTwo", [] { return CheckedMod(3L, 2L); }, 1L}, - {"TwoModZero", [] { return CheckedMod(2L, 0); }, + {"ZeroModTwo", [] { return CheckedMod(int64_t{0}, 2L); }, 0}, + {"TwoModTwo", [] { return CheckedMod(int64_t{2L}, 2L); }, 0}, + {"ThreeModTwo", [] { return CheckedMod(int64_t{3L}, 2L); }, 1L}, + {"TwoModZero", [] { return CheckedMod(int64_t{2L}, 0); }, absl::InvalidArgumentError("modulus by zero")}, {"IntMinModTwo", - [] { return CheckedMod(std::numeric_limits::lowest(), 2L); }, + [] { + return CheckedMod(std::numeric_limits::lowest(), + int64_t{2L}); + }, 0}, {"IntMaxModMinusOne", - [] { return CheckedMod(std::numeric_limits::max(), -1L); }, + [] { + return CheckedMod(std::numeric_limits::max(), int64_t{-1L}); + }, 0}, {"IntMinModMinusOne", - [] { return CheckedMod(std::numeric_limits::lowest(), -1L); }, + [] { + return CheckedMod(std::numeric_limits::lowest(), + int64_t{-1L}); + }, absl::OutOfRangeError("integer overflow")}, // Negation cases. - {"NegateOne", [] { return CheckedNegation(1L); }, -1L}, + {"NegateOne", [] { return CheckedNegation(int64_t{1L}); }, -1L}, {"NegateMinInt64", [] { return CheckedNegation(std::numeric_limits::lowest()); }, absl::OutOfRangeError("integer overflow")}, // Numeric conversion cases for uint -> int, double -> int - {"Uint64Conversion", [] { return CheckedUint64ToInt64(1UL); }, 1L}, + {"Uint64Conversion", [] { return CheckedUint64ToInt64(uint64_t{1UL}); }, + 1L}, {"Uint32MaxConversion", [] { return CheckedUint64ToInt64( @@ -156,7 +195,8 @@ INSTANTIATE_TEST_SUITE_P( static_cast(std::numeric_limits::max())); }, absl::OutOfRangeError("out of int64 range")}, - {"DoubleConversion", [] { return CheckedDoubleToInt64(100.1); }, 100L}, + {"DoubleConversion", [] { return CheckedDoubleToInt64(double{100.1}); }, + 100L}, {"DoubleInt64MaxConversionError", [] { return CheckedDoubleToInt64( @@ -201,9 +241,10 @@ INSTANTIATE_TEST_SUITE_P( }, absl::OutOfRangeError("out of int64 range")}, {"NegRangeConversionError", - [] { return CheckedDoubleToInt64(-1.0e99); }, + [] { return CheckedDoubleToInt64(double{-1.0e99}); }, absl::OutOfRangeError("out of int64 range")}, - {"PosRangeConversionError", [] { return CheckedDoubleToInt64(1.0e99); }, + {"PosRangeConversionError", + [] { return CheckedDoubleToInt64(double{1.0e99}); }, absl::OutOfRangeError("out of int64 range")}, }), [](const testing::TestParamInfo& info) { @@ -218,51 +259,58 @@ INSTANTIATE_TEST_SUITE_P( CheckedUintMathTest, CheckedUintResultTest, ValuesIn(std::vector{ // Addition tests. - {"OneAddOne", [] { return CheckedAdd(1UL, 1UL); }, 2UL}, - {"ZeroAddOne", [] { return CheckedAdd(0, 1UL); }, 1UL}, - {"OneAddZero", [] { return CheckedAdd(1UL, 0); }, 1UL}, + {"OneAddOne", [] { return CheckedAdd(uint64_t{1UL}, 1UL); }, 2UL}, + {"ZeroAddOne", [] { return CheckedAdd(uint64_t{0}, 1UL); }, 1UL}, + {"OneAddZero", [] { return CheckedAdd(uint64_t{1UL}, 0); }, 1UL}, {"OneAddIntMax", - [] { return CheckedAdd(1UL, std::numeric_limits::max()); }, + [] { + return CheckedAdd(uint64_t{1UL}, + std::numeric_limits::max()); + }, absl::OutOfRangeError("unsigned integer overflow")}, // Subtraction tests. - {"OneSubOne", [] { return CheckedSub(1UL, 1UL); }, 0}, - {"ZeroSubOne", [] { return CheckedSub(0, 1UL); }, + {"OneSubOne", [] { return CheckedSub(uint64_t{1UL}, 1UL); }, 0}, + {"ZeroSubOne", [] { return CheckedSub(uint64_t{0}, 1UL); }, absl::OutOfRangeError("unsigned integer overflow")}, - {"OneSubZero", [] { return CheckedSub(1UL, 0); }, 1UL}, + {"OneSubZero", [] { return CheckedSub(uint64_t{1UL}, 0); }, 1UL}, // Multiplication tests. - {"OneMulOne", [] { return CheckedMul(1UL, 1UL); }, 1UL}, - {"ZeroMulOne", [] { return CheckedMul(0, 1UL); }, 0}, - {"OneMulZero", [] { return CheckedMul(1UL, 0); }, 0}, + {"OneMulOne", [] { return CheckedMul(uint64_t{1UL}, 1UL); }, 1UL}, + {"ZeroMulOne", [] { return CheckedMul(uint64_t{0}, 1UL); }, 0}, + {"OneMulZero", [] { return CheckedMul(uint64_t{1UL}, 0); }, 0}, {"TwoMulUintMax", - [] { return CheckedMul(2UL, std::numeric_limits::max()); }, + [] { + return CheckedMul(uint64_t{2UL}, + std::numeric_limits::max()); + }, absl::OutOfRangeError("unsigned integer overflow")}, // Division tests. - {"TwoDivTwo", [] { return CheckedDiv(2UL, 2UL); }, 1UL}, - {"TwoDivFour", [] { return CheckedDiv(2UL, 4UL); }, 0}, - {"OneDivZero", [] { return CheckedDiv(1UL, 0); }, + {"TwoDivTwo", [] { return CheckedDiv(uint64_t{2UL}, 2UL); }, 1UL}, + {"TwoDivFour", [] { return CheckedDiv(uint64_t{2UL}, 4UL); }, 0}, + {"OneDivZero", [] { return CheckedDiv(uint64_t{1UL}, 0); }, absl::InvalidArgumentError("divide by zero")}, // Modulus tests. - {"TwoModTwo", [] { return CheckedMod(2UL, 2UL); }, 0}, - {"TwoModFour", [] { return CheckedMod(2UL, 4UL); }, 2UL}, - {"OneModZero", [] { return CheckedMod(1UL, 0); }, + {"TwoModTwo", [] { return CheckedMod(uint64_t{2UL}, 2UL); }, 0}, + {"TwoModFour", [] { return CheckedMod(uint64_t{2UL}, 4UL); }, 2UL}, + {"OneModZero", [] { return CheckedMod(uint64_t{1UL}, 0); }, absl::InvalidArgumentError("modulus by zero")}, // Conversion test cases for int -> uint, double -> uint. - {"Int64Conversion", [] { return CheckedInt64ToUint64(1L); }, 1UL}, + {"Int64Conversion", [] { return CheckedInt64ToUint64(int64_t{1L}); }, + 1UL}, {"Int64MaxConversion", [] { return CheckedInt64ToUint64(std::numeric_limits::max()); }, static_cast(std::numeric_limits::max())}, {"NegativeInt64ConversionError", - [] { return CheckedInt64ToUint64(-1L); }, + [] { return CheckedInt64ToUint64(int64_t{-1L}); }, absl::OutOfRangeError("out of uint64 range")}, - {"DoubleConversion", [] { return CheckedDoubleToUint64(100.1); }, - 100UL}, + {"DoubleConversion", + [] { return CheckedDoubleToUint64(double{100.1}); }, 100UL}, {"DoubleUint64MaxConversionError", [] { return CheckedDoubleToUint64( @@ -287,13 +335,14 @@ INSTANTIATE_TEST_SUITE_P( std::numeric_limits::infinity()); }, absl::OutOfRangeError("out of uint64 range")}, - {"NegConversionError", [] { return CheckedDoubleToUint64(-1.1); }, + {"NegConversionError", + [] { return CheckedDoubleToUint64(double{-1.1}); }, absl::OutOfRangeError("out of uint64 range")}, {"NegRangeConversionError", - [] { return CheckedDoubleToUint64(-1.0e99); }, + [] { return CheckedDoubleToUint64(double{-1.0e99}); }, absl::OutOfRangeError("out of uint64 range")}, {"PosRangeConversionError", - [] { return CheckedDoubleToUint64(1.0e99); }, + [] { return CheckedDoubleToUint64(double{1.0e99}); }, absl::OutOfRangeError("out of uint64 range")}, }), [](const testing::TestParamInfo& info) { @@ -571,7 +620,8 @@ TEST_P(CheckedConvertInt64Int32Test, Conversions) { ExpectResult(GetParam()); } INSTANTIATE_TEST_SUITE_P( CheckedConvertInt64Int32Test, CheckedConvertInt64Int32Test, ValuesIn(std::vector{ - {"SimpleConversion", [] { return CheckedInt64ToInt32(1L); }, 1}, + {"SimpleConversion", [] { return CheckedInt64ToInt32(int64_t{1L}); }, + 1}, {"Int32MaxConversion", [] { return CheckedInt64ToInt32( @@ -610,7 +660,8 @@ TEST_P(CheckedConvertUint64Uint32Test, Conversions) { INSTANTIATE_TEST_SUITE_P( CheckedConvertUint64Uint32Test, CheckedConvertUint64Uint32Test, ValuesIn(std::vector{ - {"SimpleConversion", [] { return CheckedUint64ToUint32(1UL); }, 1U}, + {"SimpleConversion", + [] { return CheckedUint64ToUint32(uint64_t{1UL}); }, 1U}, {"Uint32MaxConversion", [] { return CheckedUint64ToUint32( From a69b0ea03670d65745564a931d350ba00e63ce54 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 14 Apr 2026 14:31:16 -0700 Subject: [PATCH 034/143] More macos test fixes. PiperOrigin-RevId: 899776008 --- env/env_yaml.cc | 2 +- env/env_yaml_test.cc | 3 ++- eval/public/builtin_func_test.cc | 42 +++++++++++++++++++------------- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/env/env_yaml.cc b/env/env_yaml.cc index a6f66bd83..4ba16ea84 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -882,7 +882,7 @@ void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { case ConstantKindCase::kTimestamp: out << YAML::Key << "value" << YAML::Value; out << absl::FormatTime( - "%4Y-%2m-%2d%ET%2H:%2M:%E*SZ", + "%Y-%m-%d%ET%H:%M:%E*SZ", // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) constant.timestamp_value(), absl::UTCTimeZone()); break; diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index c3e4839af..25cc63206 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -880,7 +880,8 @@ INSTANTIATE_TEST_SUITE_P( })); std::string Unindent(std::string_view yaml) { - std::vector lines = absl::StrSplit(yaml, '\n'); + absl::string_view yaml_view = yaml; + std::vector lines = absl::StrSplit(yaml_view, '\n'); int indent = -1; std::vector unindented_lines; for (auto& line : lines) { diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 1eeb07193..dba71d307 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -750,22 +750,25 @@ TEST_F(BuiltinsTest, TestBytesConversions_string) { TEST_F(BuiltinsTest, TestDoubleConversions_double) { double ref = 100.1; - TestTypeConverts(builtin::kDouble, CelValue::CreateDouble(ref), 100.1); + TestTypeConverts(builtin::kDouble, CelValue::CreateDouble(ref), + double{100.1}); } TEST_F(BuiltinsTest, TestDoubleConversions_int) { int64_t ref = 100L; - TestTypeConverts(builtin::kDouble, CelValue::CreateInt64(ref), 100.0); + TestTypeConverts(builtin::kDouble, CelValue::CreateInt64(ref), double{100.0}); } TEST_F(BuiltinsTest, TestDoubleConversions_string) { std::string ref = "-100.1"; - TestTypeConverts(builtin::kDouble, CelValue::CreateString(&ref), -100.1); + TestTypeConverts(builtin::kDouble, CelValue::CreateString(&ref), + double{-100.1}); } TEST_F(BuiltinsTest, TestDoubleConversions_uint) { uint64_t ref = 100UL; - TestTypeConverts(builtin::kDouble, CelValue::CreateUint64(ref), 100.0); + TestTypeConverts(builtin::kDouble, CelValue::CreateUint64(ref), + double{100.0}); } TEST_F(BuiltinsTest, TestDoubleConversionError_stringInvalid) { @@ -774,34 +777,36 @@ TEST_F(BuiltinsTest, TestDoubleConversionError_stringInvalid) { } TEST_F(BuiltinsTest, TestDynConversions) { - TestTypeConverts(builtin::kDyn, CelValue::CreateDouble(100.1), 100.1); - TestTypeConverts(builtin::kDyn, CelValue::CreateInt64(100L), 100L); - TestTypeConverts(builtin::kDyn, CelValue::CreateUint64(100UL), 100UL); + TestTypeConverts(builtin::kDyn, CelValue::CreateDouble(100.1), double{100.1}); + TestTypeConverts(builtin::kDyn, CelValue::CreateInt64(100L), int64_t{100L}); + TestTypeConverts(builtin::kDyn, CelValue::CreateUint64(100UL), + uint64_t{100UL}); } TEST_F(BuiltinsTest, TestIntConversions_int) { - TestTypeConverts(builtin::kInt, CelValue::CreateInt64(100L), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateInt64(100L), int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_Timestamp) { Timestamp ref; ref.set_seconds(100); - TestTypeConverts(builtin::kInt, CelProtoWrapper::CreateTimestamp(&ref), 100L); + TestTypeConverts(builtin::kInt, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_double) { double ref = 100.1; - TestTypeConverts(builtin::kInt, CelValue::CreateDouble(ref), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateDouble(ref), int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_string) { std::string ref = "100"; - TestTypeConverts(builtin::kInt, CelValue::CreateString(&ref), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateString(&ref), int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_uint) { uint64_t ref = 100; - TestTypeConverts(builtin::kInt, CelValue::CreateUint64(ref), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateUint64(ref), int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_doubleIntMin) { @@ -826,7 +831,7 @@ TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMaxMinus512) { // value, but it will rountrip to a valid 64-bit integer. double range = std::numeric_limits::max() - 512; TestTypeConverts(builtin::kInt, CelValue::CreateDouble(range), - std::numeric_limits::max() - 1023); + int64_t{std::numeric_limits::max() - 1023}); } TEST_F(BuiltinsTest, TestIntConversionError_doubleNegRange) { @@ -874,21 +879,24 @@ TEST_F(BuiltinsTest, TestIntConversionError_uintRange) { TEST_F(BuiltinsTest, TestUintConversions_double) { double ref = 100.1; - TestTypeConverts(builtin::kUint, CelValue::CreateDouble(ref), 100UL); + TestTypeConverts(builtin::kUint, CelValue::CreateDouble(ref), + uint64_t{100UL}); } TEST_F(BuiltinsTest, TestUintConversions_int) { int64_t ref = 100L; - TestTypeConverts(builtin::kUint, CelValue::CreateInt64(ref), 100UL); + TestTypeConverts(builtin::kUint, CelValue::CreateInt64(ref), uint64_t{100UL}); } TEST_F(BuiltinsTest, TestUintConversions_string) { std::string ref = "100"; - TestTypeConverts(builtin::kUint, CelValue::CreateString(&ref), 100UL); + TestTypeConverts(builtin::kUint, CelValue::CreateString(&ref), + uint64_t{100UL}); } TEST_F(BuiltinsTest, TestUintConversions_uint) { - TestTypeConverts(builtin::kUint, CelValue::CreateUint64(100UL), 100UL); + TestTypeConverts(builtin::kUint, CelValue::CreateUint64(uint64_t{100UL}), + uint64_t{100UL}); } TEST_F(BuiltinsTest, TestUintConversionError_doubleNegRange) { From 0253936925f3d5274c24c8ffa7e5eed683ecdd41 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 15 Apr 2026 11:31:54 -0700 Subject: [PATCH 035/143] No-op doc change to trigger builds. PiperOrigin-RevId: 900266859 --- common/ast.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/ast.h b/common/ast.h index db336f52d..afd0575ad 100644 --- a/common/ast.h +++ b/common/ast.h @@ -136,10 +136,10 @@ class Ast final { expr_version_ = expr_version; } - // Computes the source location (line and column) for the given expression id + // Computes the source location (line and column) for the given expression ID // from the source info (which stores absolute positions). // - // Returns a default (empty) source location if the expression id is not found + // Returns a default (empty) source location if the expression ID is not found // or the source info is not populated correctly. SourceLocation ComputeSourceLocation(int64_t expr_id) const; From bbe1aaa9dd1caa60c48a0d2278b3577c3d1d5797 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 16 Apr 2026 15:36:24 -0700 Subject: [PATCH 036/143] Remove special error for null select target. Just return the general invalid select target error. PiperOrigin-RevId: 900943253 --- eval/eval/select_step.cc | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index b95915145..420f3ac31 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -19,7 +19,6 @@ #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" -#include "eval/internal/errors.h" #include "internal/status_macros.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" @@ -158,13 +157,6 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { result_trail = trail.Step(&field_); } - if (arg->Is()) { - frame->value_stack().PopAndPush( - cel::ErrorValue(cel::runtime_internal::CreateError("Message is NULL")), - std::move(result_trail)); - return absl::OkStatus(); - } - absl::optional optional_arg; if (enable_optional_types_ && arg.IsOptional()) { @@ -354,10 +346,6 @@ class DirectSelectStep : public DirectExpressionStep { case ValueKind::kStruct: case ValueKind::kMap: break; - case ValueKind::kNull: - result = cel::ErrorValue( - cel::runtime_internal::CreateError("Message is NULL")); - return absl::OkStatus(); default: if (optional_arg) { break; From 69719524b2c99de9802239acc7a8f1cb8712955c Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 17 Apr 2026 12:02:08 -0700 Subject: [PATCH 037/143] No public description PiperOrigin-RevId: 901405918 --- internal/to_address.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/to_address.h b/internal/to_address.h index 5dffef3c1..36e7eeb60 100644 --- a/internal/to_address.h +++ b/internal/to_address.h @@ -49,7 +49,7 @@ struct PointerTraitsToAddress { template struct PointerTraitsToAddress< - T, absl::void_t::to_address( + T, std::void_t::to_address( std::declval()))> > { static constexpr auto Dispatch( const T& p ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { From 77d98c5b13884b10f516b42034929505e3e46ab6 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 17 Apr 2026 15:17:59 -0700 Subject: [PATCH 038/143] Disable incompatible test on MacOS PiperOrigin-RevId: 901492992 --- checker/internal/format_type_name_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/checker/internal/format_type_name_test.cc b/checker/internal/format_type_name_test.cc index 23bc2bda9..ff04e04d2 100644 --- a/checker/internal/format_type_name_test.cc +++ b/checker/internal/format_type_name_test.cc @@ -101,6 +101,7 @@ TEST(FormatTypeNameTest, Opaque) { "tuple(tuple(int, int), tuple(int, int), tuple(int, int))"); } +#ifndef __APPLE__ TEST(FormatTypeNameTest, ArbitraryNesting) { google::protobuf::Arena arena; Type type = IntType(); @@ -111,6 +112,7 @@ TEST(FormatTypeNameTest, ArbitraryNesting) { EXPECT_THAT(FormatTypeName(type), MatchesRegex(R"(^(ptype\(){1000}int(\)){1000})")); } +#endif } // namespace } // namespace cel::checker_internal From e20bf48de055dcd02a517151624f21bff2eda160 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 17 Apr 2026 15:18:25 -0700 Subject: [PATCH 039/143] Ensure RequestContext is in the descriptor pool Mark benchmark tests as 'manual' PiperOrigin-RevId: 901493173 --- eval/tests/BUILD | 20 ++++++++++++++++---- eval/tests/allocation_benchmark_test.cc | 3 +++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 8eeafd521..9163548d1 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -18,7 +18,10 @@ cc_test( srcs = [ "benchmark_test.cc", ], - tags = ["benchmark"], + tags = [ + "benchmark", + "manual", + ], deps = [ ":request_context_cc_proto", "//eval/public:activation", @@ -52,7 +55,10 @@ cc_test( srcs = [ "modern_benchmark_test.cc", ], - tags = ["benchmark"], + tags = [ + "benchmark", + "manual", + ], deps = [ ":request_context_cc_proto", "//common:allocator", @@ -102,7 +108,10 @@ cc_test( srcs = [ "allocation_benchmark_test.cc", ], - tags = ["benchmark"], + tags = [ + "benchmark", + "manual", + ], deps = [ ":request_context_cc_proto", "//eval/public:activation", @@ -151,7 +160,10 @@ cc_test( srcs = [ "expression_builder_benchmark_test.cc", ], - tags = ["benchmark"], + tags = [ + "benchmark", + "manual", + ], deps = [ ":request_context_cc_proto", "//common:minimal_descriptor_pool", diff --git a/eval/tests/allocation_benchmark_test.cc b/eval/tests/allocation_benchmark_test.cc index 5364d3fc0..425355e3a 100644 --- a/eval/tests/allocation_benchmark_test.cc +++ b/eval/tests/allocation_benchmark_test.cc @@ -169,6 +169,9 @@ static void BM_AllocateMessage(benchmark::State& state) { "google.api.expr.runtime.RequestContext{" "ip: '192.168.0.1'," "path: '/root'}"); + // Make sure RequestContext is loaded in the generated descriptor pool. + RequestContext context; + static_cast(context); auto builder = CreateCelExpressionBuilder(); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); From c62346262ea2ecaa98c98847651210d94af70348 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 17 Apr 2026 15:27:27 -0700 Subject: [PATCH 040/143] Add missing 'alwayslink' directive PiperOrigin-RevId: 901496760 --- testing/testrunner/user_tests/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/testing/testrunner/user_tests/BUILD b/testing/testrunner/user_tests/BUILD index 140b77aef..53cd8f716 100644 --- a/testing/testrunner/user_tests/BUILD +++ b/testing/testrunner/user_tests/BUILD @@ -59,6 +59,7 @@ cc_library( "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", "@com_google_protobuf//:protobuf", ], + alwayslink = True, ) cc_library( From ecece1a7758a1219854f04d68de34bc73471ec76 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 17 Apr 2026 15:35:23 -0700 Subject: [PATCH 041/143] Add missing absl/str_cat deps PiperOrigin-RevId: 901499946 --- internal/BUILD | 2 ++ internal/strings_test.cc | 1 + internal/testing.cc | 2 ++ 3 files changed, 5 insertions(+) diff --git a/internal/BUILD b/internal/BUILD index 6bd0f0a46..3891c635d 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -296,6 +296,7 @@ cc_library( deps = [ ":status_macros", "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], ) @@ -312,6 +313,7 @@ cc_library( deps = [ ":status_macros", "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", ], ) diff --git a/internal/strings_test.cc b/internal/strings_test.cc index d6c90473e..fcdb6d4ec 100644 --- a/internal/strings_test.cc +++ b/internal/strings_test.cc @@ -24,6 +24,7 @@ #include "absl/strings/cord.h" #include "absl/strings/cord_test_helpers.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "internal/testing.h" diff --git a/internal/testing.cc b/internal/testing.cc index 77e4c65b4..84aa58cce 100644 --- a/internal/testing.cc +++ b/internal/testing.cc @@ -14,6 +14,8 @@ #include "internal/testing.h" +#include "absl/strings/str_cat.h" // IWYU pragma: keep + namespace cel::internal { void AddFatalFailure(const char* file, int line, absl::string_view expression, From 170a758a0f9c5963630b074ece3f28bb52158883 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 17 Apr 2026 15:49:49 -0700 Subject: [PATCH 042/143] Use ShortDebugString instead of implicit proto string serialization PiperOrigin-RevId: 901506584 --- internal/message_equality_test.cc | 53 ++++++++++++++++++------------- parser/parser_test.cc | 16 ++++++---- 2 files changed, 41 insertions(+), 28 deletions(-) diff --git a/internal/message_equality_test.cc b/internal/message_equality_test.cc index bc5914bef..318138d9b 100644 --- a/internal/message_equality_test.cc +++ b/internal/message_equality_test.cc @@ -110,22 +110,22 @@ TEST_P(UnaryMessageEqualsTest, Equals) { } EXPECT_THAT(MessageEquals(*lhs, *rhs, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs << " " << *rhs; + << lhs->ShortDebugString() << " " << rhs->ShortDebugString(); EXPECT_THAT(MessageEquals(*rhs, *lhs, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs << " " << *rhs; + << lhs->ShortDebugString() << " " << rhs->ShortDebugString(); // Test any. auto lhs_any = PackMessage(*lhs); auto rhs_any = PackMessage(*rhs); EXPECT_THAT(MessageEquals(*lhs_any, *rhs, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any << " " << *rhs; + << lhs_any->ShortDebugString() << " " << rhs->ShortDebugString(); EXPECT_THAT(MessageEquals(*lhs, *rhs_any, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs << " " << *rhs_any; + << lhs->ShortDebugString() << " " << rhs_any->ShortDebugString(); EXPECT_THAT(MessageEquals(*lhs_any, *rhs_any, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any << " " << *rhs_any; + << lhs_any->ShortDebugString() << " " << rhs_any->ShortDebugString(); } } } @@ -455,28 +455,30 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() << " " + << rhs_message->ShortDebugString() << " " << rhs_field->name(); EXPECT_THAT(MessageFieldEquals(*rhs_message, rhs_field, *lhs_message, lhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() << " " + << rhs_message->ShortDebugString() << " " << rhs_field->name(); if (!lhs_field->is_repeated() && lhs_field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { EXPECT_THAT(MessageFieldEquals(lhs_message->GetReflection()->GetMessage( *lhs_message, lhs_field), *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); EXPECT_THAT(MessageFieldEquals(*rhs_message, rhs_field, lhs_message->GetReflection()->GetMessage( *lhs_message, lhs_field), pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); } if (!rhs_field->is_repeated() && rhs_field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { @@ -485,14 +487,16 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { *rhs_message, rhs_field), pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); EXPECT_THAT(MessageFieldEquals(rhs_message->GetReflection()->GetMessage( *rhs_message, rhs_field), *lhs_message, lhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); } // Test `google.protobuf.Any`. absl::optional, @@ -505,21 +509,24 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { EXPECT_THAT(MessageFieldEquals(*lhs_any->first, lhs_any->second, *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any->first << " " << *rhs_message; + << lhs_any->first->ShortDebugString() << " " + << rhs_message->ShortDebugString(); if (!lhs_any->second->is_repeated()) { EXPECT_THAT( MessageFieldEquals(lhs_any->first->GetReflection()->GetMessage( *lhs_any->first, lhs_any->second), *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any->first << " " << *rhs_message; + << lhs_any->first->ShortDebugString() << " " + << rhs_message->ShortDebugString(); } } if (rhs_any) { EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, *rhs_any->first, rhs_any->second, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << *rhs_any->first; + << lhs_message->ShortDebugString() << " " + << rhs_any->first->ShortDebugString(); if (!rhs_any->second->is_repeated()) { EXPECT_THAT( MessageFieldEquals(*lhs_message, lhs_field, @@ -527,7 +534,8 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { *rhs_any->first, rhs_any->second), pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << *rhs_any->first; + << lhs_message->ShortDebugString() << " " + << rhs_any->first->ShortDebugString(); } } if (lhs_any && rhs_any) { @@ -535,7 +543,8 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { MessageFieldEquals(*lhs_any->first, lhs_any->second, *rhs_any->first, rhs_any->second, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any->first << " " << *rhs_any->second; + << lhs_any->first->ShortDebugString() << " " + << rhs_any->first->ShortDebugString(); } } } diff --git a/parser/parser_test.cc b/parser/parser_test.cc index aee121051..c96845e67 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1495,14 +1495,16 @@ TEST_P(ExpressionTest, Parse) { KindAndIdAdorner kind_and_id_adorner; ExprPrinter w(kind_and_id_adorner); std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.P, adorned_string) << result->parsed_expr(); + EXPECT_EQ(test_info.P, adorned_string) + << result->parsed_expr().ShortDebugString(); } if (!test_info.L.empty()) { LocationAdorner location_adorner(result->parsed_expr().source_info()); ExprPrinter w(location_adorner); std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.L, adorned_string) << result->parsed_expr(); + EXPECT_EQ(test_info.L, adorned_string) + << result->parsed_expr().ShortDebugString(); ; } @@ -1514,7 +1516,7 @@ TEST_P(ExpressionTest, Parse) { if (!test_info.M.empty()) { EXPECT_EQ(test_info.M, ConvertMacroCallsToString( result.value().parsed_expr().source_info())) - << result->parsed_expr(); + << result->parsed_expr().ShortDebugString(); ; } } @@ -1867,14 +1869,16 @@ TEST_P(UpdatedAccuVarDisabledTest, Parse) { KindAndIdAdorner kind_and_id_adorner; ExprPrinter w(kind_and_id_adorner); std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.P, adorned_string) << result->parsed_expr(); + EXPECT_EQ(test_info.P, adorned_string) + << result->parsed_expr().ShortDebugString(); } if (!test_info.L.empty()) { LocationAdorner location_adorner(result->parsed_expr().source_info()); ExprPrinter w(location_adorner); std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.L, adorned_string) << result->parsed_expr(); + EXPECT_EQ(test_info.L, adorned_string) + << result->parsed_expr().ShortDebugString(); } if (!test_info.R.empty()) { @@ -1885,7 +1889,7 @@ TEST_P(UpdatedAccuVarDisabledTest, Parse) { if (!test_info.M.empty()) { EXPECT_EQ(test_info.M, ConvertMacroCallsToString( result.value().parsed_expr().source_info())) - << result->parsed_expr(); + << result->parsed_expr().ShortDebugString(); } } From 474683e6a183629bd76dafd2526a2b8dbaeae0be Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 17 Apr 2026 18:04:04 -0700 Subject: [PATCH 043/143] Use explicit numeric types in tests PiperOrigin-RevId: 901553198 --- eval/public/builtin_func_test.cc | 149 +++++++++++++++++++------------ 1 file changed, 92 insertions(+), 57 deletions(-) diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index dba71d307..037fa8345 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -544,13 +544,14 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { ref.set_seconds(93541L); ref.set_nanos(11000000L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), 25L); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), + int64_t{25L}); TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateDuration(&ref), - 1559L); + int64_t{1559L}); TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateDuration(&ref), - 93541L); + int64_t{93541L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), - 11L); + int64_t{11L}); std::string result = "93541.011s"; TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), @@ -560,13 +561,14 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { ref.set_seconds(-93541L); ref.set_nanos(-11000000L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), -25L); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), + int64_t{-25L}); TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateDuration(&ref), - -1559L); + int64_t{-1559L}); TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateDuration(&ref), - -93541L); + int64_t{-93541L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), - -11L); + int64_t{-11L}); result = "-93541.011s"; TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), @@ -595,23 +597,28 @@ TEST_F(BuiltinsTest, TestTimestampFunctions) { ref.set_seconds(1L); ref.set_nanos(11000000L); TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), - 1970L); - TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), 0L); + int64_t{1970L}); + TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{0L}); TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref), - 0L); + int64_t{0L}); TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref), - 0L); - TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), 1L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), 0L); - TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), 0L); - TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), 1L); + int64_t{0L}); + TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{1L}); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{0L}); + TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{0L}); + TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{1L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateTimestamp(&ref), - 11L); + int64_t{11L}); ref.set_seconds(259200L); ref.set_nanos(0L); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), - 0L); + int64_t{0L}); } TEST_F(BuiltinsTest, TestTimestampConversionToString) { @@ -640,46 +647,60 @@ TEST_F(BuiltinsTest, TestTimestampFunctionsWithTimeZone) { TestFunctionsWithParams(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), params, - 1969L); + int64_t{1969L}); TestFunctionsWithParams(builtin::kMonth, - CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); TestFunctionsWithParams(builtin::kDayOfYear, - CelProtoWrapper::CreateTimestamp(&ref), params, 364L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{364L}); TestFunctionsWithParams(builtin::kDayOfMonth, - CelProtoWrapper::CreateTimestamp(&ref), params, 30L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{30L}); TestFunctionsWithParams(builtin::kDate, - CelProtoWrapper::CreateTimestamp(&ref), params, 31L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{31L}); TestFunctionsWithParams(builtin::kHours, - CelProtoWrapper::CreateTimestamp(&ref), params, 16L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{16L}); TestFunctionsWithParams(builtin::kMinutes, - CelProtoWrapper::CreateTimestamp(&ref), params, 0L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{0L}); TestFunctionsWithParams(builtin::kSeconds, - CelProtoWrapper::CreateTimestamp(&ref), params, 1L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{1L}); TestFunctionsWithParams(builtin::kMilliseconds, - CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); ref.set_seconds(259200L); ref.set_nanos(0L); TestFunctionsWithParams(builtin::kDayOfWeek, - CelProtoWrapper::CreateTimestamp(&ref), params, 6L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{6L}); // Test timestamp functions with negative value ref.set_seconds(-1L); ref.set_nanos(0L); TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), - 1969L); - TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), 11L); + int64_t{1969L}); + TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{11L}); TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref), - 364L); + int64_t{364L}); TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref), - 30L); - TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), 31L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), 23L); - TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), 59L); - TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), 59L); + int64_t{30L}); + TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{31L}); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{23L}); + TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); + TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), - 3L); + int64_t{3L}); // Test timestamp functions w/ fixed timezone ref.set_seconds(1L); @@ -690,46 +711,60 @@ TEST_F(BuiltinsTest, TestTimestampFunctionsWithTimeZone) { TestFunctionsWithParams(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), params, - 1969L); + int64_t{1969L}); TestFunctionsWithParams(builtin::kMonth, - CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); TestFunctionsWithParams(builtin::kDayOfYear, - CelProtoWrapper::CreateTimestamp(&ref), params, 364L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{364L}); TestFunctionsWithParams(builtin::kDayOfMonth, - CelProtoWrapper::CreateTimestamp(&ref), params, 30L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{30L}); TestFunctionsWithParams(builtin::kDate, - CelProtoWrapper::CreateTimestamp(&ref), params, 31L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{31L}); TestFunctionsWithParams(builtin::kHours, - CelProtoWrapper::CreateTimestamp(&ref), params, 16L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{16L}); TestFunctionsWithParams(builtin::kMinutes, - CelProtoWrapper::CreateTimestamp(&ref), params, 0L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{0L}); TestFunctionsWithParams(builtin::kSeconds, - CelProtoWrapper::CreateTimestamp(&ref), params, 1L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{1L}); TestFunctionsWithParams(builtin::kMilliseconds, - CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); ref.set_seconds(259200L); ref.set_nanos(0L); TestFunctionsWithParams(builtin::kDayOfWeek, - CelProtoWrapper::CreateTimestamp(&ref), params, 6L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{6L}); // Test timestamp functions with negative value ref.set_seconds(-1L); ref.set_nanos(0L); TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), - 1969L); - TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), 11L); + int64_t{1969L}); + TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{11L}); TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref), - 364L); + int64_t{364L}); TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref), - 30L); - TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), 31L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), 23L); - TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), 59L); - TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), 59L); + int64_t{30L}); + TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{31L}); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{23L}); + TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); + TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), - 3L); + int64_t{3L}); TestTypeConversionError( builtin::kString, @@ -828,7 +863,7 @@ TEST_F(BuiltinsTest, TestIntConversions_doubleIntMinMinus1024) { TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMaxMinus512) { // Converting int64_t max - 512 to a double will not roundtrip to the original - // value, but it will rountrip to a valid 64-bit integer. + // value, but it will roundtrip to a valid 64-bit integer. double range = std::numeric_limits::max() - 512; TestTypeConverts(builtin::kInt, CelValue::CreateDouble(range), int64_t{std::numeric_limits::max() - 1023}); From 1629bb4739a26d3cdbbbe7d2f964121c241e81ed Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 20 Apr 2026 09:08:41 -0700 Subject: [PATCH 044/143] Update alloc macros in internal/new.cc PiperOrigin-RevId: 902667189 --- internal/new.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/internal/new.cc b/internal/new.cc index 5bd9e8158..05396e624 100644 --- a/internal/new.cc +++ b/internal/new.cc @@ -67,6 +67,13 @@ void* AlignedNew(size_t size, std::align_val_t alignment) { ThrowStdBadAlloc(); } return ptr; +#elif defined(__APPLE__) + void* ptr; + if (ABSL_PREDICT_FALSE( + posix_memalign(&ptr, static_cast(alignment), size) != 0)) { + ThrowStdBadAlloc(); + } + return ptr; #else void* ptr = std::aligned_alloc(static_cast(alignment), size); if (ABSL_PREDICT_FALSE(size != 0 && ptr == nullptr)) { @@ -107,7 +114,7 @@ void AlignedDelete(void* ptr, std::align_val_t alignment) noexcept { ::operator delete(ptr, alignment); #else if (static_cast(alignment) <= kDefaultNewAlignment) { - Delete(ptr, size); + SizedDelete(ptr, size); } else { #if defined(_MSC_VER) _aligned_free(ptr); From 8c140223c1262ff9481d9bdef70fd020d154fcd1 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 20 Apr 2026 09:28:15 -0700 Subject: [PATCH 045/143] Update Bazel build rules for MacOS PiperOrigin-RevId: 902675874 --- .bazelrc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.bazelrc b/.bazelrc index a6d7a13f0..475706072 100644 --- a/.bazelrc +++ b/.bazelrc @@ -15,6 +15,13 @@ build:msvc --define=protobuf_allow_msvc=true build:msvc --test_tag_filters=-benchmark,-notap,-no_test_msvc build:msvc --build_tag_filters=-no_test_msvc +build:macos --cxxopt=-faligned-allocation +build:macos --cxxopt=-mmacosx-version-min=10.13 +build:macos --linkopt=-mmacosx-version-min=10.13 + +# ANTLR tool requires Java 17+. +build --java_runtime_version=remotejdk_17 + test --test_output=errors # Enable matchers in googletest From c1e3026841f93b01774264e2c7171d7e55549119 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 20 Apr 2026 13:33:49 -0700 Subject: [PATCH 046/143] Make error handling in wrapped legacy map consistent with modern version. The wrapped legacy implementation would use absl status return to propagate expected errors leading to inconsistencies in error handling depending on whether the underlying map was backed by a modern or legacy map. Update so that expected errors are cel::ErrorValue and unexpected errors (internal bug or client error) are returned as non-ok Status return value. Add Documentation for MapValue members. PiperOrigin-RevId: 902798507 --- common/legacy_value.cc | 17 ++-- common/values/custom_map_value.h | 16 ++-- common/values/legacy_map_value.h | 12 +-- common/values/map_value.h | 55 ++++++++---- common/values/null_value.h | 3 +- common/values/parsed_json_map_value.h | 12 +-- common/values/parsed_map_field_value.h | 12 +-- common/values/values.h | 1 - eval/eval/equality_steps.cc | 32 +++---- eval/public/builtin_func_test.cc | 3 +- eval/public/cel_type_registry.h | 6 -- eval/public/structs/field_access_impl.cc | 6 +- .../proto_message_type_adapter_test.cc | 4 +- .../container_membership_functions.cc | 90 ++++++++++--------- 14 files changed, 142 insertions(+), 127 deletions(-) diff --git a/common/legacy_value.cc b/common/legacy_value.cc index 5c81fdacb..7fbf16732 100644 --- a/common/legacy_value.cc +++ b/common/legacy_value.cc @@ -700,7 +700,8 @@ absl::Status LegacyMapValue::Get( case ValueKind::kString: break; default: - return InvalidMapKeyTypeError(key.kind()); + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); + return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); auto cel_value = impl_->Get(arena, cel_key); @@ -732,7 +733,7 @@ absl::StatusOr LegacyMapValue::Find( case ValueKind::kString: break; default: - return InvalidMapKeyTypeError(key.kind()); + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); auto cel_value = impl_->Get(arena, cel_key); @@ -764,11 +765,17 @@ absl::Status LegacyMapValue::Has( case ValueKind::kString: break; default: - return InvalidMapKeyTypeError(key.kind()); + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); + return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); - CEL_ASSIGN_OR_RETURN(auto has, impl_->Has(cel_key)); - *result = BoolValue{has}; + absl::StatusOr has = impl_->Has(cel_key); + if (!has.ok()) { + *result = ErrorValue(std::move(has).status()); + return absl::OkStatus(); + } + + *result = BoolValue(*has); return absl::OkStatus(); } diff --git a/common/values/custom_map_value.h b/common/values/custom_map_value.h index 9e840e07f..ca6e1e025 100644 --- a/common/values/custom_map_value.h +++ b/common/values/custom_map_value.h @@ -225,7 +225,7 @@ class CustomMapValueInterface { // Returns the number of entries in this map. virtual size_t Size() const = 0; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. virtual absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -233,7 +233,7 @@ class CustomMapValueInterface { google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const = 0; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. virtual absl::Status ForEach( ForEachCallback callback, @@ -347,7 +347,7 @@ class CustomMapValue final size_t Size() const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -356,7 +356,7 @@ class CustomMapValue final Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, @@ -365,7 +365,7 @@ class CustomMapValue final google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -374,7 +374,7 @@ class CustomMapValue final Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -386,7 +386,7 @@ class CustomMapValue final // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ForEach( ForEachCallback callback, @@ -394,7 +394,7 @@ class CustomMapValue final google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr NewIterator() const; diff --git a/common/values/legacy_map_value.h b/common/values/legacy_map_value.h index 31865a873..c83b7fc2f 100644 --- a/common/values/legacy_map_value.h +++ b/common/values/legacy_map_value.h @@ -102,7 +102,7 @@ class LegacyMapValue final size_t Size() const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -111,7 +111,7 @@ class LegacyMapValue final Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, @@ -120,7 +120,7 @@ class LegacyMapValue final google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -129,7 +129,7 @@ class LegacyMapValue final Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -137,11 +137,11 @@ class LegacyMapValue final google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; - // See the corresponding type declaration of `MapValueInterface` for + // See the corresponding type declaration of `MapValue` for // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ForEach( ForEachCallback callback, diff --git a/common/values/map_value.h b/common/values/map_value.h index ffbdea6c9..b6e69ea57 100644 --- a/common/values/map_value.h +++ b/common/values/map_value.h @@ -15,10 +15,16 @@ // IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" -// `MapValue` represents values of the primitive `map` type. `MapValueView` -// is a non-owning view of `MapValue`. `MapValueInterface` is the abstract -// base class of implementations. `MapValue` and `MapValueView` act as smart -// pointers to `MapValueInterface`. +// `MapValue` represents values of the primitive `map` type. It provides a +// unified interface for accessing map contents, regardless of the underlying +// implementation (e.g., JSON, protobuf map field, or custom implementation). +// +// Public member functions: +// - `IsEmpty()` / `Size()`: Query map size. +// - `Get()` / `Find()` / `Has()`: Access entries by key. +// - `ListKeys()` / `NewIterator()` / `ForEach()`: Iterate over entries. +// - `ConvertToJson()` / `ConvertToJsonObject()`: JSON conversion. +// - `IsCustom()` / `AsCustom()` / `GetCustom()`: Access custom implementation. #ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_H_ @@ -54,7 +60,6 @@ namespace cel { -class MapValueInterface; class MapValue; class Value; @@ -119,8 +124,13 @@ class MapValue final : private common_internal::MapValueMixin { absl::StatusOr Size() const; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `Get` sets the value `result` to (via `result`) the value associated with + // `key`. If `key` is not found, `no such key` is set to `result`. If an error + // occurs (e.g., invalid key type), an `no such key` is returned. + // + // A non-ok status may be returned if an unexpected error is encountered or to + // propagate an error from a custom implementation, in which case `result` is + // unspecified. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, @@ -128,8 +138,13 @@ class MapValue final : private common_internal::MapValueMixin { Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `Find` returns `true` if `key` is found in the map, and stores the + // associated value in `result`. If `key` is not found, `false` is returned + // and `result` is unchanged. + // + // A non-ok status may be returned if an unexpected error is encountered or to + // propagate an error from a custom implementation, in which case `result` is + // unspecified. absl::StatusOr Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -137,8 +152,13 @@ class MapValue final : private common_internal::MapValueMixin { google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `Has` returns `true` if `key` is found in the map, and stores the BoolValue + // result in `result`. In case of an error, the result is set to an + // ErrorValue. + // + // A non-ok status may be returned if an unexpected error is encountered or to + // propagate an error from a custom implementation, in which case `result` is + // unspecified. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, @@ -146,28 +166,25 @@ class MapValue final : private common_internal::MapValueMixin { Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `ListKeys` returns a `ListValue` containing all keys in the map. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; - // See the corresponding type declaration of `MapValueInterface` for - // documentation. + // `ForEachCallback` is the callback type for `ForEach`. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `ForEach` calls `callback` for each entry in the map. Iteration continues + // until all entries are visited or `callback` returns an error or `false`. absl::Status ForEach( ForEachCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `NewIterator` returns a new iterator for the map. absl::StatusOr NewIterator() const; // Returns `true` if this value is an instance of a custom map value. diff --git a/common/values/null_value.h b/common/values/null_value.h index 53c3161a1..d4d05dba3 100644 --- a/common/values/null_value.h +++ b/common/values/null_value.h @@ -37,8 +37,7 @@ namespace cel { class Value; class NullValue; -// `NullValue` represents values of the primitive `duration` type. - +// `NullValue` represents the CEL `null` value. class NullValue final : private common_internal::ValueMixin { public: static constexpr ValueKind kKind = ValueKind::kNull; diff --git a/common/values/parsed_json_map_value.h b/common/values/parsed_json_map_value.h index b20fe032b..ba8d3490d 100644 --- a/common/values/parsed_json_map_value.h +++ b/common/values/parsed_json_map_value.h @@ -132,7 +132,7 @@ class ParsedJsonMapValue final size_t Size() const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -141,7 +141,7 @@ class ParsedJsonMapValue final Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, @@ -150,7 +150,7 @@ class ParsedJsonMapValue final google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -159,7 +159,7 @@ class ParsedJsonMapValue final Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -167,11 +167,11 @@ class ParsedJsonMapValue final google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; - // See the corresponding type declaration of `MapValueInterface` for + // See the corresponding type declaration of `MapValue` for // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ForEach( ForEachCallback callback, diff --git a/common/values/parsed_map_field_value.h b/common/values/parsed_map_field_value.h index 3478f75bc..21d686bfd 100644 --- a/common/values/parsed_map_field_value.h +++ b/common/values/parsed_map_field_value.h @@ -117,7 +117,7 @@ class ParsedMapFieldValue final size_t Size() const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -126,7 +126,7 @@ class ParsedMapFieldValue final Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, @@ -135,7 +135,7 @@ class ParsedMapFieldValue final google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -144,7 +144,7 @@ class ParsedMapFieldValue final Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -152,11 +152,11 @@ class ParsedMapFieldValue final google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; - // See the corresponding type declaration of `MapValueInterface` for + // See the corresponding type declaration of `MapValue` for // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ForEach( ForEachCallback callback, diff --git a/common/values/values.h b/common/values/values.h index c9703dcbb..aaa6f8659 100644 --- a/common/values/values.h +++ b/common/values/values.h @@ -48,7 +48,6 @@ namespace cel { class ValueInterface; class ListValueInterface; -class MapValueInterface; class StructValueInterface; class Value; diff --git a/eval/eval/equality_steps.cc b/eval/eval/equality_steps.cc index e134069d5..d720302e4 100644 --- a/eval/eval/equality_steps.cc +++ b/eval/eval/equality_steps.cc @@ -132,15 +132,11 @@ class IterativeEqualityStep : public ExpressionStepBase { absl::StatusOr EvaluateInMap(ExecutionFrameBase& frame, const Value& item, const MapValue& container) { - absl::StatusOr result = {BoolValue(false)}; switch (item.kind()) { case ValueKind::kBool: case ValueKind::kString: case ValueKind::kInt: case ValueKind::kUint: - result = container.Has(item, frame.descriptor_pool(), - frame.message_factory(), frame.arena()); - break; case ValueKind::kDouble: break; default: @@ -148,9 +144,12 @@ absl::StatusOr EvaluateInMap(ExecutionFrameBase& frame, cel::runtime_internal::CreateNoMatchingOverloadError( cel::builtin::kIn)); } + Value result; + CEL_RETURN_IF_ERROR(container.Has(item, frame.descriptor_pool(), + frame.message_factory(), frame.arena(), + &result)); - if (result.ok() && result.value().IsBool() && - result.value().GetBool().NativeValue()) { + if (result.IsTrue()) { return result; } @@ -159,10 +158,10 @@ absl::StatusOr EvaluateInMap(ExecutionFrameBase& frame, ? Number::FromDouble(item.GetDouble().NativeValue()) : Number::FromUint64(item.GetUint().NativeValue()); if (number.LosslessConvertibleToInt()) { - result = container.Has(IntValue(number.AsInt()), frame.descriptor_pool(), - frame.message_factory(), frame.arena()); - if (result.ok() && result.value().IsBool() && - result.value().GetBool().NativeValue()) { + CEL_RETURN_IF_ERROR( + container.Has(IntValue(number.AsInt()), frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result)); + if (result.IsTrue()) { return result; } } @@ -173,21 +172,16 @@ absl::StatusOr EvaluateInMap(ExecutionFrameBase& frame, ? Number::FromDouble(item.GetDouble().NativeValue()) : Number::FromInt64(item.GetInt().NativeValue()); if (number.LosslessConvertibleToUint()) { - result = + CEL_RETURN_IF_ERROR( container.Has(UintValue(number.AsUint()), frame.descriptor_pool(), - frame.message_factory(), frame.arena()); - if (result.ok() && result.value().IsBool() && - result.value().GetBool().NativeValue()) { + frame.message_factory(), frame.arena(), &result)); + if (result.IsTrue()) { return result; } } } - if (!result.ok()) { - return BoolValue(false); - } - - return result; + return BoolValue(false); } absl::StatusOr EvaluateIn(ExecutionFrameBase& frame, const Value& item, diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 037fa8345..b73a2dc55 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -1632,7 +1632,8 @@ TEST_F(BuiltinsTest, TestMapInError) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( builtin::kIn, {}, {key, CelValue::CreateMap(&cel_map)}, &result_value)); - EXPECT_TRUE(result_value.IsBool()); + ASSERT_TRUE(result_value.IsBool()) + << key.DebugString() << " : " << result_value.DebugString(); EXPECT_FALSE(result_value.BoolOrDie()); } diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index 0c01eb8e9..3fb80bcea 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -28,7 +28,6 @@ #include "base/type_provider.h" #include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/structs/legacy_type_provider.h" -#include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "runtime/type_registry.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -139,11 +138,6 @@ class CelTypeRegistry { private: // Internal modern registry. cel::TypeRegistry modern_type_registry_; - - // TODO(uncreated-issue/44): This is needed to inspect the registered legacy type - // providers for client tests. This can be removed when they are migrated to - // use the modern APIs. - std::shared_ptr legacy_type_provider_; }; } // namespace google::api::expr::runtime diff --git a/eval/public/structs/field_access_impl.cc b/eval/public/structs/field_access_impl.cc index 3b3cb9847..2bd9fff9d 100644 --- a/eval/public/structs/field_access_impl.cc +++ b/eval/public/structs/field_access_impl.cc @@ -139,8 +139,7 @@ class FieldAccessor { case FieldDescriptor::TYPE_BYTES: return CelValue::CreateBytesView(value); default: - return absl::Status(absl::StatusCode::kInvalidArgument, - "Error handling C++ string conversion"); + break; } break; } @@ -153,8 +152,7 @@ class FieldAccessor { return CelValue::CreateInt64(enum_value); } default: - return absl::Status(absl::StatusCode::kInvalidArgument, - "Unhandled C++ type conversion"); + break; } return absl::Status(absl::StatusCode::kInvalidArgument, "Unhandled C++ type conversion"); diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index 088d20d48..32608bc3f 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -69,8 +69,8 @@ class ProtoMessageTypeAccessorTest : public testing::TestWithParam { bool use_generic_instance = GetParam(); if (use_generic_instance) { // implementation detail: in general, type info implementations may - // return a different accessor object based on the messsage instance, but - // this implemenation returns the same one no matter the message. + // return a different accessor object based on the message instance, but + // this implementation returns the same one no matter the message. return *GetGenericProtoTypeInfoInstance().GetAccessApis(dummy_); } else { diff --git a/runtime/standard/container_membership_functions.cc b/runtime/standard/container_membership_functions.cc index 9f5ca3755..cc0638429 100644 --- a/runtime/standard/container_membership_functions.cc +++ b/runtime/standard/container_membership_functions.cc @@ -174,15 +174,16 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { - auto result = - map_value.Has(BoolValue(key), descriptor_pool, message_factory, arena); - if (result.ok()) { - return std::move(*result); + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(BoolValue(key), descriptor_pool, + message_factory, arena, &has)); + if (has.IsTrue()) { + return has; } if (enable_heterogeneous_equality) { return BoolValue(false); } - return ErrorValue(result.status()); + return has; }; auto intKeyInSet = @@ -191,27 +192,26 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { - auto result = - map_value.Has(IntValue(key), descriptor_pool, message_factory, arena); + Value result; + CEL_RETURN_IF_ERROR(map_value.Has(IntValue(key), descriptor_pool, + message_factory, arena, &result)); if (enable_heterogeneous_equality) { - if (result.ok() && result->IsTrue()) { - return std::move(*result); + if (result.IsTrue()) { + return result; } Number number = Number::FromInt64(key); if (number.LosslessConvertibleToUint()) { - const auto& result = - map_value.Has(UintValue(number.AsUint()), descriptor_pool, - message_factory, arena); - if (result.ok() && result->IsTrue()) { - return std::move(*result); + Value result_alt; + CEL_RETURN_IF_ERROR(map_value.Has(UintValue(number.AsUint()), + descriptor_pool, message_factory, + arena, &result_alt)); + if (result_alt.IsTrue()) { + return result_alt; } } return BoolValue(false); } - if (!result.ok()) { - return ErrorValue(result.status()); - } - return std::move(*result); + return result; }; auto stringKeyInSet = @@ -220,14 +220,16 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { - auto result = map_value.Has(key, descriptor_pool, message_factory, arena); - if (result.ok()) { - return std::move(*result); + Value result; + CEL_RETURN_IF_ERROR( + map_value.Has(key, descriptor_pool, message_factory, arena, &result)); + if (result.IsBool()) { + return result; } if (enable_heterogeneous_equality) { return BoolValue(false); } - return ErrorValue(result.status()); + return result; }; auto uintKeyInSet = @@ -236,26 +238,26 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { - const auto& result = - map_value.Has(UintValue(key), descriptor_pool, message_factory, arena); + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(UintValue(key), descriptor_pool, + message_factory, arena, &has)); if (enable_heterogeneous_equality) { - if (result.ok() && result->IsTrue()) { - return std::move(*result); + if (has.IsTrue()) { + return has; } + Value has_alt; Number number = Number::FromUint64(key); if (number.LosslessConvertibleToInt()) { - const auto& result = map_value.Has( - IntValue(number.AsInt()), descriptor_pool, message_factory, arena); - if (result.ok() && result->IsTrue()) { - return std::move(*result); + CEL_RETURN_IF_ERROR(map_value.Has(IntValue(number.AsInt()), + descriptor_pool, message_factory, + arena, &has_alt)); + if (has.IsTrue()) { + return has; } } return BoolValue(false); } - if (!result.ok()) { - return ErrorValue(result.status()); - } - return std::move(*result); + return has; }; auto doubleKeyInSet = @@ -265,17 +267,21 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { Number number = Number::FromDouble(key); if (number.LosslessConvertibleToInt()) { - const auto& result = map_value.Has( - IntValue(number.AsInt()), descriptor_pool, message_factory, arena); - if (result.ok() && result->IsTrue()) { - return std::move(*result); + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(IntValue(number.AsInt()), + descriptor_pool, message_factory, arena, + &has)); + if (has.IsTrue()) { + return has; } } if (number.LosslessConvertibleToUint()) { - const auto& result = map_value.Has( - UintValue(number.AsUint()), descriptor_pool, message_factory, arena); - if (result.ok() && result->IsTrue()) { - return std::move(*result); + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(UintValue(number.AsUint()), + descriptor_pool, message_factory, arena, + &has)); + if (has.IsTrue()) { + return has; } } return BoolValue(false); From c2c1205482616c260313eda6b506cde5f5dc04f7 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 21 Apr 2026 16:03:30 -0700 Subject: [PATCH 047/143] Add ExpressionContainer to hold namespace config ExpressionContainer will hold configuration for namespace resolution rules in a given expression, adding support for abbrevs (imports) and aliases. PiperOrigin-RevId: 903474367 --- common/BUILD | 23 +++++ common/container.cc | 189 +++++++++++++++++++++++++++++++++++++++ common/container.h | 116 ++++++++++++++++++++++++ common/container_test.cc | 103 +++++++++++++++++++++ 4 files changed, 431 insertions(+) create mode 100644 common/container.cc create mode 100644 common/container.h create mode 100644 common/container_test.cc diff --git a/common/BUILD b/common/BUILD index e289ef413..ea6246b51 100644 --- a/common/BUILD +++ b/common/BUILD @@ -1140,3 +1140,26 @@ cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "container", + srcs = ["container.cc"], + hdrs = ["container.h"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "container_test", + srcs = ["container_test.cc"], + deps = [ + ":container", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) diff --git a/common/container.cc b/common/container.cc new file mode 100644 index 000000000..4abceea2d --- /dev/null +++ b/common/container.cc @@ -0,0 +1,189 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/container.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace cel { +namespace { + +// Basic validation for accidental misuse. Does not fully validate against the +// CEL grammar rules for identifiers. +bool IsIdentifierChar(char c) { + return c == '_' || std::isalnum(static_cast(c)); +} + +bool IsValidQualifiedName(absl::string_view name) { + bool dot_ok = false; + for (char c : name) { + if (c == '.') { + if (!dot_ok) { + return false; + } + dot_ok = false; + continue; + } + if (!IsIdentifierChar(c)) { + return false; + } + dot_ok = true; + } + // Must not end in a dot. + return dot_ok; +} + +bool IsValidAlias(absl::string_view alias) { + if (alias.empty()) { + return false; + } + for (char c : alias) { + if (!IsIdentifierChar(c)) { + return false; + } + } + return true; +} + +bool IsAbreviation(absl::string_view alias, absl::string_view name) { + auto pos = name.rfind('.'); + return pos != std::string::npos && pos > 0 && pos < name.size() - 1 && + alias == name.substr(pos + 1); +} + +} // namespace + +bool ExpressionContainer::AliasListing::IsAbbreviation() const { + return IsAbreviation(alias, name); +} + +absl::StatusOr ExpressionContainer::Create( + absl::string_view name) { + ExpressionContainer container; + + absl::Status status = container.SetContainer(name); + if (!status.ok()) { + return status; + } + return container; +} + +absl::Status ExpressionContainer::SetContainer(absl::string_view name) { + if (name.empty()) { + container_ = ""; + return absl::OkStatus(); + } + + if (!IsValidQualifiedName(name)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", name)); + } + + for (const auto& entry : aliases_) { + const std::string& alias = entry.first; + if (name == alias || + (name.size() > alias.size() && + absl::string_view(name).substr(0, alias.size()) == alias && + name.at(alias.size()) == '.')) { + return absl::InvalidArgumentError( + absl::StrCat("container name collides with alias: ", alias)); + } + } + + container_ = std::string(name); + return absl::OkStatus(); +} + +absl::Status ExpressionContainer::AddAbbreviation(absl::string_view abrev) { + if (!IsValidQualifiedName(abrev)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", abrev)); + } + + auto pos = abrev.rfind('.'); + if (pos == 0 || pos == absl::string_view::npos || pos == abrev.size() - 1) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", abrev, + ", wanted name of the form 'qualified.name'")); + } + + absl::string_view alias = abrev.substr(pos + 1); + return AddAlias(alias, abrev); +} + +absl::Status ExpressionContainer::AddAlias(absl::string_view alias, + absl::string_view name) { + if (!IsValidAlias(alias)) { + return absl::InvalidArgumentError(absl::StrCat( + "alias must be non-empty and simple (not qualified): ", alias)); + } + + if (!IsValidQualifiedName(name)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", name)); + } + + if (auto it = aliases_.find(alias); it != aliases_.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "alias collides with existing reference: ", alias, " -> ", it->second)); + } + + if (container_ == alias || + (container_.size() > alias.size() && + absl::string_view(container_).substr(0, alias.size()) == alias && + container_.at(alias.size()) == '.')) { + return absl::InvalidArgumentError( + absl::StrCat("alias collides with container name: ", alias)); + } + + aliases_.insert({std::string(alias), std::string(name)}); + return absl::OkStatus(); +} + +absl::string_view ExpressionContainer::FindAlias( + absl::string_view alias) const { + auto it = aliases_.find(alias); + if (it != aliases_.end()) { + return it->second; + } + return ""; +} + +std::vector ExpressionContainer::ListAbbreviations() const { + std::vector res; + for (const auto& entry : aliases_) { + if (IsAbreviation(entry.first, entry.second)) { + res.push_back(entry.second); + } + } + return res; +} + +std::vector +ExpressionContainer::ListAliases() const { + std::vector res; + for (const auto& entry : aliases_) { + res.push_back({entry.first, entry.second}); + } + return res; +} + +} // namespace cel diff --git a/common/container.h b/common/container.h new file mode 100644 index 000000000..a6555a8ac --- /dev/null +++ b/common/container.h @@ -0,0 +1,116 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace cel { + +// ExpressionContainer represents the namespace configuration for a CEL +// expression. +// +// The container defines the default resolution order for names referenced in +// the expression. It generally maps to a protobuf package and follows +// approximately the same resolution rules as protobuf or C++ namespaces. +// +// Aliases declare short names that can be referenced without resolving against +// the scopes defined by the container. For consistency, an alias cannot be +// a prefix of the container name. Aliases are always unqualified identifiers. +// +// An abbreviation is a special case of alias that behaves like an import or +// using declaration in other languages. (pkg.TypeName -> TypeName). +// +// For better traceability, prefer using abbreviations over aliases. +class ExpressionContainer { + public: + struct AliasListing { + std::string alias; + std::string name; + + bool IsAbbreviation() const; + }; + + ExpressionContainer() = default; + + static absl::StatusOr Create(absl::string_view name); + + ExpressionContainer(const ExpressionContainer&) = default; + ExpressionContainer(ExpressionContainer&&) = default; + ExpressionContainer& operator=(const ExpressionContainer&) = default; + ExpressionContainer& operator=(ExpressionContainer&&) = default; + + // Returns the full name of the container. + // + // The default value is an empty string meaning no container. + absl::string_view container() const { return container_; } + + // Sets the container name. + // + // Returns an error if the container name is malformed or conflicts with an + // existing alias. + absl::Status SetContainer(absl::string_view name); + + // Adds an abbreviation to the container. + // + // Returns an error if the abbreviation is malformed or conflicts with the + // container or an existing alias. + absl::Status AddAbbreviation(absl::string_view abrev); + + // Adds an alias to the container. + // + // Returns an error if the alias is malformed or conflicts with the container + // or an existing alias. + absl::Status AddAlias(absl::string_view alias, absl::string_view name); + + // Returns the full name of the alias or an empty string if not found. + // + // The returned string view may be invalidated by updates to the + // ExpressionContainer. + absl::string_view FindAlias(absl::string_view alias) const; + + // Utility method for listing the abbreviations in the container. + // Order is not guaranteed. + std::vector ListAbbreviations() const; + + // Utility method for listing the aliases in the container. + // Includes abbreviations. + // Order is not guaranteed. + std::vector ListAliases() const; + + // Removes all aliases and abbreviations from the container. + void clear() { + container_.clear(); + aliases_.clear(); + } + + private: + explicit ExpressionContainer(absl::string_view name); + + std::string container_; + + // alias -> full name. + absl::flat_hash_map aliases_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ diff --git a/common/container_test.cc b/common/container_test.cc new file mode 100644 index 000000000..d8c052040 --- /dev/null +++ b/common/container_test.cc @@ -0,0 +1,103 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/container.h" + +#include "absl/status/status.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +TEST(ExpressionContainerTest, DefaultConstructed) { + ExpressionContainer container; + EXPECT_THAT(container.container(), IsEmpty()); + EXPECT_THAT(container.FindAlias("foo"), IsEmpty()); +} + +TEST(ExpressionContainerTest, SetContainer) { + ExpressionContainer container; + EXPECT_THAT(container.SetContainer("my.container.name"), IsOk()); + EXPECT_THAT(container.container(), Eq("my.container.name")); + EXPECT_THAT(container.SetContainer("..invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, AddAlias) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + ExpressionContainer::Create("my.container")); + EXPECT_THAT(container.AddAlias("foo", "bar.baz"), IsOk()); + EXPECT_THAT(container.FindAlias("foo"), Eq("bar.baz")); +} + +TEST(ExpressionContainerTest, AddAbbreviation) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + ExpressionContainer::Create("my.container")); + EXPECT_THAT(container.AddAbbreviation("qual.pkg.TypeName"), IsOk()); + EXPECT_THAT(container.FindAlias("TypeName"), Eq("qual.pkg.TypeName")); +} + +TEST(ExpressionContainerTest, ListAbbreviationsAndAliases) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + ExpressionContainer::Create("my.container")); + EXPECT_THAT(container.AddAbbreviation("qual.pkg.Abbr"), IsOk()); + EXPECT_THAT(container.AddAlias("AliasSym", "some.long.name"), IsOk()); + + EXPECT_THAT(container.ListAbbreviations(), + UnorderedElementsAre("qual.pkg.Abbr")); + + auto aliases = container.ListAliases(); + EXPECT_THAT(aliases, SizeIs(2)); +} + +TEST(ExpressionContainerTest, InvalidAbbreviation) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + ExpressionContainer::Create("my.container")); + EXPECT_THAT(container.AddAbbreviation(""), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAbbreviation("pkg"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAbbreviation(".pkg"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAbbreviation("pkg."), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, InvalidAlias) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + ExpressionContainer::Create("my.container")); + EXPECT_THAT(container.AddAlias("", "bar"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAlias("foo.bar", "baz"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAlias("foo", ".baz"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, CollidesWithContainer) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + ExpressionContainer::Create("my.container")); + EXPECT_THAT(container.AddAlias("my", "bar"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace cel From 40176739fca23895ef5b7e1923d24746eeb8d519 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 22 Apr 2026 11:23:30 -0700 Subject: [PATCH 048/143] Wire ExpressionContainer into type checker. Update to Make* convention for factories to match other things in /common. Not used yet. PiperOrigin-RevId: 903948803 --- checker/BUILD | 1 + checker/internal/BUILD | 5 +-- checker/internal/type_check_env.h | 8 ++--- checker/internal/type_checker_builder_impl.cc | 9 +++-- checker/internal/type_checker_builder_impl.h | 6 +++- checker/internal/type_checker_impl.cc | 15 ++++----- checker/internal/type_checker_impl_test.cc | 15 +++++---- checker/type_checker_builder.h | 9 ++++- common/container.cc | 2 +- common/container.h | 29 +++++++++++++--- common/container_test.cc | 33 +++++++++++++++---- 11 files changed, 95 insertions(+), 37 deletions(-) diff --git a/checker/BUILD b/checker/BUILD index d5eb3601c..7b151d6a8 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -88,6 +88,7 @@ cc_library( deps = [ ":checker_options", ":type_checker", + "//common:container", "//common:decl", "//common:type", "@com_google_absl//absl/base:nullability", diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 3f64417a0..73e5c177d 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -67,6 +67,7 @@ cc_library( deps = [ ":descriptor_pool_type_introspector", "//common:constant", + "//common:container", "//common:decl", "//common:type", "//internal:status_macros", @@ -120,7 +121,6 @@ cc_library( "type_checker_impl.h", ], deps = [ - ":descriptor_pool_type_introspector", ":format_type_name", ":namespace_generator", ":type_check_env", @@ -136,9 +136,9 @@ cc_library( "//common:ast_visitor", "//common:ast_visitor_base", "//common:constant", + "//common:container", "//common:decl", "//common:expr", - "//common:source", "//common:type", "//common:type_kind", "//internal:lexis", @@ -172,6 +172,7 @@ cc_test( "//checker:type_check_issue", "//checker:validation_result", "//common:ast", + "//common:container", "//common:decl", "//common:expr", "//common:source", diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 520b0eab6..491e4b550 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -30,6 +30,7 @@ #include "absl/types/span.h" #include "checker/internal/descriptor_pool_type_introspector.h" #include "common/constant.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" @@ -91,7 +92,6 @@ class TypeCheckEnv { absl_nonnull std::shared_ptr descriptor_pool) : descriptor_pool_(std::move(descriptor_pool)), - container_(""), proto_type_introspector_( std::make_shared( descriptor_pool_.get())) { @@ -104,9 +104,9 @@ class TypeCheckEnv { TypeCheckEnv(TypeCheckEnv&&) = default; TypeCheckEnv& operator=(TypeCheckEnv&&) = default; - const std::string& container() const { return container_; } + const ExpressionContainer& container() const { return container_; } - void set_container(std::string container) { + void set_container(ExpressionContainer container) { container_ = std::move(container); } @@ -206,7 +206,7 @@ class TypeCheckEnv { // If set, an arena was needed to allocate types in the environment. absl_nullable std::shared_ptr arena_; - std::string container_; + ExpressionContainer container_; // Used to resolve fields on message types. std::shared_ptr proto_type_introspector_; diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 7545aa949..9ebcb4e34 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -343,7 +343,7 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( absl::StatusOr> TypeCheckerBuilderImpl::Build() { TypeCheckEnv env(descriptor_pool_); - env.set_container(container_); + env.set_container(expression_container_); if (expected_type_.has_value()) { env.set_expected_type(*expected_type_); } @@ -479,7 +479,12 @@ void TypeCheckerBuilderImpl::AddTypeProvider( } void TypeCheckerBuilderImpl::set_container(absl::string_view container) { - container_ = container; + expression_container_.SetContainer(container).IgnoreError(); +} + +void TypeCheckerBuilderImpl::SetExpressionContainer( + ExpressionContainer container) { + expression_container_ = std::move(container); } void TypeCheckerBuilderImpl::SetExpectedType(const Type& type) { diff --git a/checker/internal/type_checker_builder_impl.h b/checker/internal/type_checker_builder_impl.h index 3b3472232..7a099040b 100644 --- a/checker/internal/type_checker_builder_impl.h +++ b/checker/internal/type_checker_builder_impl.h @@ -31,6 +31,7 @@ #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" @@ -76,6 +77,9 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { void set_container(absl::string_view container) override; + void SetExpressionContainer( + ExpressionContainer expression_container) override; + const CheckerOptions& options() const override { return options_; } google::protobuf::Arena* absl_nonnull arena() override { @@ -137,7 +141,7 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { std::vector libraries_; absl::flat_hash_map subsets_; absl::flat_hash_set library_ids_; - std::string container_; + ExpressionContainer expression_container_; absl::optional expected_type_; }; diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 8e8047755..df8f83683 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -187,14 +187,12 @@ class ResolveVisitor : public AstVisitorBase { bool requires_disambiguation; }; - ResolveVisitor(absl::string_view container, - NamespaceGenerator namespace_generator, + ResolveVisitor(NamespaceGenerator namespace_generator, const TypeCheckEnv& env, const Ast& ast, TypeInferenceContext& inference_context, std::vector& issues, google::protobuf::Arena* absl_nonnull arena) - : container_(container), - namespace_generator_(std::move(namespace_generator)), + : namespace_generator_(std::move(namespace_generator)), env_(&env), inference_context_(&inference_context), issues_(&issues), @@ -326,7 +324,7 @@ class ResolveVisitor : public AstVisitorBase { ReportIssue(TypeCheckIssue::CreateError( ast_->ComputeSourceLocation(expr.id()), absl::StrCat("undeclared reference to '", name, "' (in container '", - container_, "')"))); + env_->container().container(), "')"))); } void ReportUndefinedField(int64_t expr_id, absl::string_view field_name, @@ -407,7 +405,6 @@ class ResolveVisitor : public AstVisitorBase { return DynType(); } - absl::string_view container_; NamespaceGenerator namespace_generator_; const TypeCheckEnv* absl_nonnull env_; TypeInferenceContext* absl_nonnull inference_context_; @@ -1260,12 +1257,12 @@ absl::StatusOr TypeCheckerImpl::Check( google::protobuf::Arena type_arena; std::vector issues; - CEL_ASSIGN_OR_RETURN(auto generator, - NamespaceGenerator::Create(env_.container())); + CEL_ASSIGN_OR_RETURN( + auto generator, NamespaceGenerator::Create(env_.container().container())); TypeInferenceContext type_inference_context( &type_arena, options_.enable_legacy_null_assignment); - ResolveVisitor visitor(env_.container(), std::move(generator), env_, *ast, + ResolveVisitor visitor(std::move(generator), env_, *ast, type_inference_context, issues, &type_arena); TraversalOptions opts; diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index 6eccc3701..714e669cd 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -35,6 +35,7 @@ #include "checker/type_check_issue.h" #include "checker/validation_result.h" #include "common/ast.h" +#include "common/container.h" #include "common/decl.h" #include "common/expr.h" #include "common/source.h" @@ -757,7 +758,7 @@ TEST(TypeCheckerImplTest, NestedComprehensions) { TEST(TypeCheckerImplTest, ComprehensionVarsShadowNamespacePriorityRules) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("com"); + env.set_container(*MakeExpressionContainer("com")); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); @@ -1462,7 +1463,7 @@ TEST(TypeCheckerImplTest, BadLineOffsets) { TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("google.protobuf"); + env.set_container(*MakeExpressionContainer("google.protobuf")); env.AddTypeProvider(std::make_unique()); TypeCheckerImpl impl(std::move(env)); @@ -1483,7 +1484,7 @@ TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) { TEST(TypeCheckerImplTest, ContainerLookupForMessageCreationNoRewrite) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("google.protobuf"); + env.set_container(*MakeExpressionContainer("google.protobuf")); env.AddTypeProvider(std::make_unique()); CheckerOptions options; @@ -1508,7 +1509,7 @@ TEST(TypeCheckerImplTest, ContainerLookupForMessageCreationNoRewrite) { TEST(TypeCheckerImplTest, EnumValueCopiedToReferenceMap) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("cel.expr.conformance.proto3"); + env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, @@ -1538,7 +1539,7 @@ TEST_P(WktCreationTest, MessageCreation) { const CheckedExprTestCase& test_case = GetParam(); TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.AddTypeProvider(std::make_unique()); - env.set_container("google.protobuf"); + env.set_container(*MakeExpressionContainer("google.protobuf")); ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); @@ -1696,7 +1697,7 @@ TEST_P(GenericMessagesTest, TypeChecksProto3) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("cel.expr.conformance.proto3"); + env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); google::protobuf::LinkMessageReflection(); ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( @@ -2247,7 +2248,7 @@ TEST_P(StrictNullAssignmentTest, TypeChecksProto3) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("cel.expr.conformance.proto3"); + env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); google::protobuf::LinkMessageReflection(); ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index e5942b157..b3a86f64c 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -25,6 +25,7 @@ #include "absl/strings/string_view.h" #include "checker/checker_options.h" #include "checker/type_checker.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" @@ -132,10 +133,16 @@ class TypeCheckerBuilder { // // This is used for resolving references in the expressions being built. // + // Prefer setting the container via SetExpressionContainer(). + // // Note: if set multiple times, the last value is used. This can lead to - // surprising behavior if used in a custom library. + // surprising behavior if used in a custom library. If container is not a + // valid container name, the operation is ignored. virtual void set_container(absl::string_view container) = 0; + virtual void SetExpressionContainer( + ExpressionContainer expression_container) = 0; + // The current options for the TypeChecker being built. virtual const CheckerOptions& options() const = 0; diff --git a/common/container.cc b/common/container.cc index 4abceea2d..dbfa987d0 100644 --- a/common/container.cc +++ b/common/container.cc @@ -75,7 +75,7 @@ bool ExpressionContainer::AliasListing::IsAbbreviation() const { return IsAbreviation(alias, name); } -absl::StatusOr ExpressionContainer::Create( +absl::StatusOr MakeExpressionContainer( absl::string_view name) { ExpressionContainer container; diff --git a/common/container.h b/common/container.h index a6555a8ac..cd40aaef9 100644 --- a/common/container.h +++ b/common/container.h @@ -51,8 +51,6 @@ class ExpressionContainer { ExpressionContainer() = default; - static absl::StatusOr Create(absl::string_view name); - ExpressionContainer(const ExpressionContainer&) = default; ExpressionContainer(ExpressionContainer&&) = default; ExpressionContainer& operator=(const ExpressionContainer&) = default; @@ -103,14 +101,37 @@ class ExpressionContainer { } private: - explicit ExpressionContainer(absl::string_view name); - std::string container_; // alias -> full name. absl::flat_hash_map aliases_; }; +// Factory function for creating an ExpressionContainer. +absl::StatusOr MakeExpressionContainer( + absl::string_view name); + +// Factory function for creating an ExpressionContainer with a list of +// abbreviations. +template +absl::StatusOr MakeExpressionContainer( + absl::string_view name, Args&&... abbrevs) { + ExpressionContainer container; + absl::Status status = container.SetContainer(name); + if (!status.ok()) { + return status; + } + absl::string_view abbrevs_view[] = {std::forward(abbrevs)...}; + for (absl::string_view abrev : abbrevs_view) { + status.Update(container.AddAbbreviation(abrev)); + if (!status.ok()) { + return status; + } + } + + return container; +} + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ diff --git a/common/container_test.cc b/common/container_test.cc index d8c052040..991362320 100644 --- a/common/container_test.cc +++ b/common/container_test.cc @@ -33,6 +33,27 @@ TEST(ExpressionContainerTest, DefaultConstructed) { EXPECT_THAT(container.FindAlias("foo"), IsEmpty()); } +TEST(ExpressionContainerTest, MakeExpressionContainer) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.container(), Eq("my.container")); + + EXPECT_THAT(MakeExpressionContainer("..invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, MakeExpressionContainerWithAbbrevs) { + ASSERT_OK_AND_ASSIGN( + ExpressionContainer container, + MakeExpressionContainer("my.container", "pkg.Abbr", "qual.pkg.Abbr2")); + EXPECT_THAT(container.container(), Eq("my.container")); + EXPECT_THAT(container.FindAlias("Abbr"), Eq("pkg.Abbr")); + EXPECT_THAT(container.FindAlias("Abbr2"), Eq("qual.pkg.Abbr2")); + + EXPECT_THAT(MakeExpressionContainer("my.container", "invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + TEST(ExpressionContainerTest, SetContainer) { ExpressionContainer container; EXPECT_THAT(container.SetContainer("my.container.name"), IsOk()); @@ -43,21 +64,21 @@ TEST(ExpressionContainerTest, SetContainer) { TEST(ExpressionContainerTest, AddAlias) { ASSERT_OK_AND_ASSIGN(ExpressionContainer container, - ExpressionContainer::Create("my.container")); + MakeExpressionContainer("my.container")); EXPECT_THAT(container.AddAlias("foo", "bar.baz"), IsOk()); EXPECT_THAT(container.FindAlias("foo"), Eq("bar.baz")); } TEST(ExpressionContainerTest, AddAbbreviation) { ASSERT_OK_AND_ASSIGN(ExpressionContainer container, - ExpressionContainer::Create("my.container")); + MakeExpressionContainer("my.container")); EXPECT_THAT(container.AddAbbreviation("qual.pkg.TypeName"), IsOk()); EXPECT_THAT(container.FindAlias("TypeName"), Eq("qual.pkg.TypeName")); } TEST(ExpressionContainerTest, ListAbbreviationsAndAliases) { ASSERT_OK_AND_ASSIGN(ExpressionContainer container, - ExpressionContainer::Create("my.container")); + MakeExpressionContainer("my.container")); EXPECT_THAT(container.AddAbbreviation("qual.pkg.Abbr"), IsOk()); EXPECT_THAT(container.AddAlias("AliasSym", "some.long.name"), IsOk()); @@ -70,7 +91,7 @@ TEST(ExpressionContainerTest, ListAbbreviationsAndAliases) { TEST(ExpressionContainerTest, InvalidAbbreviation) { ASSERT_OK_AND_ASSIGN(ExpressionContainer container, - ExpressionContainer::Create("my.container")); + MakeExpressionContainer("my.container")); EXPECT_THAT(container.AddAbbreviation(""), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(container.AddAbbreviation("pkg"), @@ -83,7 +104,7 @@ TEST(ExpressionContainerTest, InvalidAbbreviation) { TEST(ExpressionContainerTest, InvalidAlias) { ASSERT_OK_AND_ASSIGN(ExpressionContainer container, - ExpressionContainer::Create("my.container")); + MakeExpressionContainer("my.container")); EXPECT_THAT(container.AddAlias("", "bar"), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(container.AddAlias("foo.bar", "baz"), @@ -94,7 +115,7 @@ TEST(ExpressionContainerTest, InvalidAlias) { TEST(ExpressionContainerTest, CollidesWithContainer) { ASSERT_OK_AND_ASSIGN(ExpressionContainer container, - ExpressionContainer::Create("my.container")); + MakeExpressionContainer("my.container")); EXPECT_THAT(container.AddAlias("my", "bar"), StatusIs(absl::StatusCode::kInvalidArgument)); } From 0341704b5c79509317360faec164e034d765fd9b Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 23 Apr 2026 13:01:06 -0700 Subject: [PATCH 049/143] Add support for resolving aliases. PiperOrigin-RevId: 904596936 --- checker/internal/BUILD | 5 + checker/internal/namespace_generator.cc | 109 +++++-- checker/internal/namespace_generator.h | 44 ++- checker/internal/namespace_generator_test.cc | 70 +++-- checker/internal/type_checker_impl.cc | 4 +- checker/internal/type_checker_impl_test.cc | 290 ++++++++++++++++++- common/BUILD | 1 + common/container.cc | 43 +-- common/container.h | 5 +- common/container_test.cc | 2 + 10 files changed, 482 insertions(+), 91 deletions(-) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 73e5c177d..1af48af57 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -89,8 +89,12 @@ cc_library( srcs = ["namespace_generator.cc"], hdrs = ["namespace_generator.h"], deps = [ + "//common:container", "//internal:lexis", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -104,6 +108,7 @@ cc_test( srcs = ["namespace_generator_test.cc"], deps = [ ":namespace_generator", + "//common:container", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", diff --git a/checker/internal/namespace_generator.cc b/checker/internal/namespace_generator.cc index e5b2cfa51..7ab7628e4 100644 --- a/checker/internal/namespace_generator.cc +++ b/checker/internal/namespace_generator.cc @@ -20,7 +20,7 @@ #include #include "absl/functional/function_ref.h" -#include "absl/status/status.h" +#include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -28,19 +28,20 @@ #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "common/container.h" #include "internal/lexis.h" namespace cel::checker_internal { namespace { -bool FieldSelectInterpretationCandidates( +bool FieldSelectInterpretationCandidatesImpl( absl::string_view prefix, - absl::Span partly_qualified_name, + absl::Span partly_qualified_name, bool prefix_is_alias, absl::FunctionRef callback) { for (int i = 0; i < partly_qualified_name.size(); ++i) { std::string buf; int count = partly_qualified_name.size() - i; - auto end_idx = count - 1; + auto end_idx = count - (prefix_is_alias ? 0 : 1); auto ident = absl::StrJoin(partly_qualified_name.subspan(0, count), "."); absl::string_view candidate = ident; if (absl::StartsWith(candidate, ".")) { @@ -54,28 +55,44 @@ bool FieldSelectInterpretationCandidates( return false; } } + if (prefix_is_alias) { + return callback(prefix, 0); + } return true; } +bool FieldSelectInterpretationCandidates( + absl::string_view prefix, + absl::Span partly_qualified_name, + absl::FunctionRef callback) { + return FieldSelectInterpretationCandidatesImpl( + prefix, partly_qualified_name, /*prefix_is_alias=*/false, callback); +} + +bool FieldSelectInterpretationCandidatesWithAlias( + absl::string_view prefix, + absl::Span partly_qualified_name, + absl::FunctionRef callback) { + return FieldSelectInterpretationCandidatesImpl( + prefix, partly_qualified_name, /*prefix_is_alias=*/true, callback); +} + } // namespace absl::StatusOr NamespaceGenerator::Create( - absl::string_view container) { + const ExpressionContainer& expression_container) { std::vector candidates; + absl::string_view container = expression_container.container(); if (container.empty()) { - return NamespaceGenerator(std::move(candidates)); + return NamespaceGenerator(&expression_container, std::move(candidates)); } - if (absl::StartsWith(container, ".")) { - return absl::InvalidArgumentError("container must not start with a '.'"); - } std::string prefix; for (auto segment : absl::StrSplit(container, '.')) { - if (!internal::LexisIsIdentifier(segment)) { - return absl::InvalidArgumentError( - "container must only contain valid identifier segments"); - } + // Assumes the the ExpressionContainer has already validated the container + // and aliases. + ABSL_DCHECK(internal::LexisIsIdentifier(segment)); if (prefix.empty()) { prefix = segment; } else { @@ -84,31 +101,75 @@ absl::StatusOr NamespaceGenerator::Create( candidates.push_back(prefix); } std::reverse(candidates.begin(), candidates.end()); - return NamespaceGenerator(std::move(candidates)); + return NamespaceGenerator(&expression_container, std::move(candidates)); } void NamespaceGenerator::GenerateCandidates( - absl::string_view unqualified_name, - absl::FunctionRef callback) { - if (absl::StartsWith(unqualified_name, ".")) { - callback(unqualified_name.substr(1)); + absl::string_view simple_name, + absl::FunctionRef callback) const { + // Special case for root-relative names. Aliases still apply first. + bool is_root_relative = absl::StartsWith(simple_name, "."); + if (is_root_relative) { + simple_name = simple_name.substr(1); + } + + // The name is unqualified, but may include a namespace (struct creation). + // This is just a quirk of the parser. + if (auto dot_pos = simple_name.find('.'); + dot_pos != absl::string_view::npos) { + absl::string_view first_segment = simple_name.substr(0, dot_pos); + absl::string_view rest = simple_name.substr(dot_pos + 1); + if (auto resolved_alias = expression_container_->FindAlias(first_segment); + !resolved_alias.empty()) { + callback(absl::StrCat(resolved_alias, ".", rest)); + return; + } + } else { + if (auto resolved_alias = expression_container_->FindAlias(simple_name); + !resolved_alias.empty()) { + callback(resolved_alias); + return; + } + } + + if (is_root_relative) { + callback(simple_name); return; } + for (const auto& prefix : candidates_) { - std::string candidate = absl::StrCat(prefix, ".", unqualified_name); + std::string candidate = absl::StrCat(prefix, ".", simple_name); if (!callback(candidate)) { return; } } - callback(unqualified_name); + callback(simple_name); } void NamespaceGenerator::GenerateCandidates( absl::Span partly_qualified_name, - absl::FunctionRef callback) { - // Special case for explicit root relative name. e.g. '.com.example.Foo' - if (!partly_qualified_name.empty() && - absl::StartsWith(partly_qualified_name[0], ".")) { + absl::FunctionRef callback) const { + if (partly_qualified_name.empty()) { + return; + } + + // Special case for root-relative names. Aliases still apply first. + absl::string_view first_segment = partly_qualified_name[0]; + bool is_root_relative = absl::StartsWith(first_segment, "."); + if (is_root_relative) { + first_segment = first_segment.substr(1); + } + + if (auto resolved_alias = expression_container_->FindAlias(first_segment); + !resolved_alias.empty()) { + FieldSelectInterpretationCandidatesWithAlias( + resolved_alias, partly_qualified_name.subspan(1), callback); + // If the alias matches, we don't check the container even if name + // resolution fails. + return; + } + + if (is_root_relative) { FieldSelectInterpretationCandidates("", partly_qualified_name, callback); return; } diff --git a/checker/internal/namespace_generator.h b/checker/internal/namespace_generator.h index 18c40dbda..61cb1956b 100644 --- a/checker/internal/namespace_generator.h +++ b/checker/internal/namespace_generator.h @@ -19,18 +19,26 @@ #include #include +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" #include "absl/functional/function_ref.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "common/container.h" namespace cel::checker_internal { // Utility class for generating namespace qualified candidates for reference // resolution. +// +// This class is expected to be scoped to a single type checking operation and +// borrows the ExpressionContainer from the TypeCheckEnv. class NamespaceGenerator { public: - static absl::StatusOr Create(absl::string_view container); + static absl::StatusOr Create( + const ExpressionContainer& expression_container + ABSL_ATTRIBUTE_LIFETIME_BOUND); // Copyable and movable. NamespaceGenerator(const NamespaceGenerator&) = default; @@ -51,8 +59,18 @@ class NamespaceGenerator { // and unqualified name foo // // com.google.foo, com.foo, foo - void GenerateCandidates(absl::string_view unqualified_name, - absl::FunctionRef callback); + // + // If aliases are present, they override the normal container resolution. + // + // Example: + // container (com.google) + // alias (foo = com.example) + // unqualified name foo + // + // com.example + void GenerateCandidates( + absl::string_view simple_name, + absl::FunctionRef callback) const; // For a partially qualified name, generate all the qualified candidates in // order of resolution precedence and pass them to the provided callback. The @@ -72,16 +90,30 @@ class NamespaceGenerator { // (com.Foo).bar, // (Foo.bar), // (Foo).bar, + // + // If aliases are present, they override the normal container resolution. + // + // Example: + // container (com.google) + // alias (Foo = com.example.Foo) + // partially qualified name Foo.bar + // + // (com.example.Foo.bar), + // (com.example.Foo).bar, void GenerateCandidates( absl::Span partly_qualified_name, - absl::FunctionRef callback); + absl::FunctionRef callback) const; private: - explicit NamespaceGenerator(std::vector candidates) - : candidates_(std::move(candidates)) {} + explicit NamespaceGenerator( + const ExpressionContainer* absl_nonnull expression_container, + std::vector candidates) + : candidates_(std::move(candidates)), + expression_container_(expression_container) {} // list of prefixes ordered from most qualified to least. std::vector candidates_; + const ExpressionContainer* absl_nonnull expression_container_; }; } // namespace cel::checker_internal diff --git a/checker/internal/namespace_generator_test.cc b/checker/internal/namespace_generator_test.cc index da174748a..ba9bb88a4 100644 --- a/checker/internal/namespace_generator_test.cc +++ b/checker/internal/namespace_generator_test.cc @@ -18,19 +18,20 @@ #include #include -#include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "common/container.h" #include "internal/testing.h" namespace cel::checker_internal { namespace { -using ::absl_testing::StatusIs; +using ::absl_testing::IsOk; using ::testing::ElementsAre; using ::testing::Pair; TEST(NamespaceGeneratorTest, EmptyContainer) { - ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create("")); + ExpressionContainer container; + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); std::vector candidates; generator.GenerateCandidates("foo", [&](absl::string_view candidate) { candidates.push_back(std::string(candidate)); @@ -40,8 +41,9 @@ TEST(NamespaceGeneratorTest, EmptyContainer) { } TEST(NamespaceGeneratorTest, MultipleSegments) { - ASSERT_OK_AND_ASSIGN(auto generator, - NamespaceGenerator::Create("com.example")); + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); std::vector candidates; generator.GenerateCandidates("foo", [&](absl::string_view candidate) { candidates.push_back(std::string(candidate)); @@ -51,8 +53,9 @@ TEST(NamespaceGeneratorTest, MultipleSegments) { } TEST(NamespaceGeneratorTest, MultipleSegmentsRootNamespace) { - ASSERT_OK_AND_ASSIGN(auto generator, - NamespaceGenerator::Create("com.example")); + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); std::vector candidates; generator.GenerateCandidates(".foo", [&](absl::string_view candidate) { candidates.push_back(std::string(candidate)); @@ -61,18 +64,46 @@ TEST(NamespaceGeneratorTest, MultipleSegmentsRootNamespace) { EXPECT_THAT(candidates, ElementsAre("foo")); } -TEST(NamespaceGeneratorTest, InvalidContainers) { - EXPECT_THAT(NamespaceGenerator::Create(".com.example"), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(NamespaceGenerator::Create("com..example"), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(NamespaceGenerator::Create("com.$example"), - StatusIs(absl::StatusCode::kInvalidArgument)); +TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretation) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + std::vector qualified_ident = {"foo", "Bar"}; + std::vector> candidates; + generator.GenerateCandidates( + qualified_ident, [&](absl::string_view candidate, int segment_index) { + candidates.push_back(std::pair(std::string(candidate), segment_index)); + return true; + }); + EXPECT_THAT( + candidates, + ElementsAre(Pair("com.example.foo.Bar", 1), Pair("com.example.foo", 0), + Pair("com.foo.Bar", 1), Pair("com.foo", 0), + Pair("foo.Bar", 1), Pair("foo", 0))); } -TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretation) { - ASSERT_OK_AND_ASSIGN(auto generator, - NamespaceGenerator::Create("com.example")); +TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretationAliasMatch) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_THAT(container.AddAlias("foo", "bar.baz"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + std::vector qualified_ident = {"foo", "Bar"}; + std::vector> candidates; + generator.GenerateCandidates( + qualified_ident, [&](absl::string_view candidate, int segment_index) { + candidates.push_back(std::pair(std::string(candidate), segment_index)); + return true; + }); + EXPECT_THAT(candidates, + ElementsAre(Pair("bar.baz.Bar", 1), Pair("bar.baz", 0))); +} + +TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretationAliasNoMatch) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_THAT(container.AddAbbreviation("foo.Bar"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + // No match on the alias (Bar) since it's not the first segment. std::vector qualified_ident = {"foo", "Bar"}; std::vector> candidates; generator.GenerateCandidates( @@ -89,8 +120,9 @@ TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretation) { TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretationRootNamespace) { - ASSERT_OK_AND_ASSIGN(auto generator, - NamespaceGenerator::Create("com.example")); + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); std::vector qualified_ident = {".foo", "Bar"}; std::vector> candidates; generator.GenerateCandidates( diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index df8f83683..6697cc06a 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -1257,8 +1257,8 @@ absl::StatusOr TypeCheckerImpl::Check( google::protobuf::Arena type_arena; std::vector issues; - CEL_ASSIGN_OR_RETURN( - auto generator, NamespaceGenerator::Create(env_.container().container())); + CEL_ASSIGN_OR_RETURN(auto generator, + NamespaceGenerator::Create(env_.container())); TypeInferenceContext type_inference_context( &type_arena, options_.enable_legacy_null_assignment); diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index 714e669cd..a82dd58ee 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -75,6 +75,7 @@ using AstType = cel::TypeSpec; using Severity = TypeCheckIssue::Severity; namespace testpb3 = ::cel::expr::conformance::proto3; +namespace testpb2 = ::cel::expr::conformance::proto2; std::string SevString(Severity severity) { switch (severity) { @@ -583,6 +584,34 @@ TEST(TypeCheckerImplTest, NamespacedFunctionSkipsFieldCheck) { EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); } +TEST(TypeCheckerImplTest, NamespacedFunctionWithAbbreviation) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + // Variables + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + + FunctionDecl foo; + foo.set_name("x.y.foo"); + ASSERT_THAT( + foo.AddOverload(MakeOverloadDecl("x_y_foo_int", + /*return_type=*/IntType(), IntType())), + IsOk()); + env.InsertFunctionIfAbsent(std::move(foo)); + env.set_container(*MakeExpressionContainer("", "x.y.foo")); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_TRUE(checked_ast->root_expr().has_call_expr()) + << absl::StrCat("kind: ", checked_ast->root_expr().kind().index()); + EXPECT_EQ(checked_ast->root_expr().call_expr().function(), "x.y.foo"); + EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); +} + TEST(TypeCheckerImplTest, MixedListTypeToDyn) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); @@ -1689,10 +1718,257 @@ INSTANTIATE_TEST_SUITE_P( .expected_result_type = AstType(PrimitiveType::kBool), })); +TEST(AliasTest, ImportVariable) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("cel.expr.conformance", + "com.example.TestVariable1", + "com.example.TestVariable2")); + env.set_container(std::move(container)); + + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + + ASSERT_TRUE(env.InsertVariableIfAbsent( + MakeVariableDecl("com.example.TestVariable1", + MessageType(testpb3::TestAllTypes::descriptor())))); + ASSERT_TRUE(env.InsertVariableIfAbsent( + MakeVariableDecl("com.example.TestVariable2", + MessageType(testpb2::TestAllTypes::descriptor())))); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAst( + "TestVariable1.single_int64 == TestVariable2.single_int64")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + ASSERT_TRUE(checked_ast->root_expr().has_call_expr()); + ASSERT_EQ(checked_ast->root_expr().call_expr().function(), "_==_"); + ASSERT_THAT(checked_ast->root_expr().call_expr().args(), SizeIs(2)); + ASSERT_EQ(checked_ast->root_expr() + .call_expr() + .args()[0] + .select_expr() + .operand() + .ident_expr() + .name(), + "com.example.TestVariable1"); + ASSERT_EQ(checked_ast->root_expr() + .call_expr() + .args()[1] + .select_expr() + .operand() + .ident_expr() + .name(), + "com.example.TestVariable2"); +} + +TEST(AliasTest, AliasToContainerResolvesMessage) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ExpressionContainer container; + ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); + + env.set_container(std::move(container)); + + google::protobuf::LinkMessageReflection(); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("pb3.TestAllTypes{single_int64: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT( + checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(AstType(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes")))))); + + EXPECT_THAT( + checked_ast->reference_map(), + Contains(Pair(checked_ast->root_expr().id(), + Property(&Reference::name, + "cel.expr.conformance.proto3.TestAllTypes")))); + + EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), + "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(AliasTest, AliasSimpleName) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ExpressionContainer container; + ASSERT_THAT(container.AddAlias("foo", "bar"), IsOk()); + + env.set_container(std::move(container)); + + google::protobuf::LinkMessageReflection(); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + env.InsertOrReplaceVariable(MakeVariableDecl("bar", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_EQ(checked_ast->root_expr().ident_expr().name(), "bar"); +} + +TEST(AliasTest, AliasPreventsContainerResolution) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("cel.expr")); + ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); + env.set_container(std::move(container)); + + ASSERT_TRUE(env.InsertVariableIfAbsent( + MakeVariableDecl("cel.expr.pb3.FooVariable", IntType()))); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("FooVariable")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'FooVariable'"))); + } + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("pb3.FooVariable")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'pb3.FooVariable'"))); + } + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("expr.pb3.FooVariable")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->root_expr().ident_expr().name(), + "cel.expr.pb3.FooVariable"); + } +} + +TEST(AliasTest, AliasPreventsDisambiguation) { + // Copying behavior from cel-go and cel-java. + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ExpressionContainer container; + ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); + env.set_container(std::move(container)); + env.InsertOrReplaceVariable(MakeVariableDecl("pb3.Foo", IntType())); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + + { + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst("pb3.TestAllTypes{single_int64: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), + "cel.expr.conformance.proto3.TestAllTypes"); + } + { + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst(".pb3.TestAllTypes{single_int64: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), + "cel.expr.conformance.proto3.TestAllTypes"); + } + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("pb3.Foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'pb3.Foo'"))); + } + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(".pb3.Foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to '.pb3.Foo'"))); + } +} + class GenericMessagesTest : public testing::TestWithParam { }; -TEST_P(GenericMessagesTest, TypeChecksProto3) { +TEST_P(GenericMessagesTest, TypeChecksProto3Imports) { + const CheckedExprTestCase& test_case = GetParam(); + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_container(*MakeExpressionContainer( + "", "cel.expr.conformance.proto3.TestAllTypes", + "cel.expr.conformance.proto3.NestedTestAllTypes")); + google::protobuf::LinkMessageReflection(); + + ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( + "test_msg", MessageType(testpb3::TestAllTypes::descriptor())))); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + if (!test_case.error_substring.empty()) { + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, + test_case.error_substring))); + return; + } + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(test_case.expected_result_type)))) + << cel::test::FormatBaselineAst(*checked_ast); +} + +TEST_P(GenericMessagesTest, TypeChecksProto3Container) { const CheckedExprTestCase& test_case = GetParam(); google::protobuf::Arena arena; @@ -1715,11 +1991,7 @@ TEST_P(GenericMessagesTest, TypeChecksProto3) { return; } - ASSERT_TRUE(result.IsValid()) - << absl::StrJoin(result.GetIssues(), "\n", - [](std::string* out, const TypeCheckIssue& issue) { - absl::StrAppend(out, issue.message()); - }); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); @@ -1840,6 +2112,12 @@ INSTANTIATE_TEST_SUITE_P( .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, + CheckedExprTestCase{ + .expr = "TestAllTypes{repeated_nested_message: " + "[TestAllTypes.NestedMessage{bb: 42}]}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, CheckedExprTestCase{ .expr = "TestAllTypes{single_duration: duration('1s')}", .expected_result_type = AstType( diff --git a/common/BUILD b/common/BUILD index ea6246b51..db177951d 100644 --- a/common/BUILD +++ b/common/BUILD @@ -1146,6 +1146,7 @@ cc_library( srcs = ["container.cc"], hdrs = ["container.h"], deps = [ + "//internal:lexis", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/common/container.cc b/common/container.cc index dbfa987d0..f69f0cc80 100644 --- a/common/container.cc +++ b/common/container.cc @@ -14,7 +14,6 @@ #include "common/container.h" -#include #include #include @@ -22,48 +21,28 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "internal/lexis.h" namespace cel { namespace { -// Basic validation for accidental misuse. Does not fully validate against the -// CEL grammar rules for identifiers. -bool IsIdentifierChar(char c) { - return c == '_' || std::isalnum(static_cast(c)); -} - bool IsValidQualifiedName(absl::string_view name) { - bool dot_ok = false; - for (char c : name) { - if (c == '.') { - if (!dot_ok) { - return false; - } - dot_ok = false; - continue; - } - if (!IsIdentifierChar(c)) { + auto dot_pos = name.find('.'); + while (dot_pos != absl::string_view::npos) { + if (!internal::LexisIsIdentifier(name.substr(0, dot_pos))) { return false; } - dot_ok = true; + name = name.substr(dot_pos + 1); + dot_pos = name.find('.'); } - // Must not end in a dot. - return dot_ok; + return internal::LexisIsIdentifier(name); } bool IsValidAlias(absl::string_view alias) { - if (alias.empty()) { - return false; - } - for (char c : alias) { - if (!IsIdentifierChar(c)) { - return false; - } - } - return true; + return internal::LexisIsIdentifier(alias); } -bool IsAbreviation(absl::string_view alias, absl::string_view name) { +bool IsAbbreviationImpl(absl::string_view alias, absl::string_view name) { auto pos = name.rfind('.'); return pos != std::string::npos && pos > 0 && pos < name.size() - 1 && alias == name.substr(pos + 1); @@ -72,7 +51,7 @@ bool IsAbreviation(absl::string_view alias, absl::string_view name) { } // namespace bool ExpressionContainer::AliasListing::IsAbbreviation() const { - return IsAbreviation(alias, name); + return IsAbbreviationImpl(alias, name); } absl::StatusOr MakeExpressionContainer( @@ -170,7 +149,7 @@ absl::string_view ExpressionContainer::FindAlias( std::vector ExpressionContainer::ListAbbreviations() const { std::vector res; for (const auto& entry : aliases_) { - if (IsAbreviation(entry.first, entry.second)) { + if (IsAbbreviationImpl(entry.first, entry.second)) { res.push_back(entry.second); } } diff --git a/common/container.h b/common/container.h index cd40aaef9..ad8d91c35 100644 --- a/common/container.h +++ b/common/container.h @@ -33,8 +33,9 @@ namespace cel { // approximately the same resolution rules as protobuf or C++ namespaces. // // Aliases declare short names that can be referenced without resolving against -// the scopes defined by the container. For consistency, an alias cannot be -// a prefix of the container name. Aliases are always unqualified identifiers. +// the scopes defined by the container. An alias cannot be a prefix of the +// container name, (otherwise re-type-checking an expression could +// change the meaning). Aliases are always unqualified identifiers. // // An abbreviation is a special case of alias that behaves like an import or // using declaration in other languages. (pkg.TypeName -> TypeName). diff --git a/common/container_test.cc b/common/container_test.cc index 991362320..e40814f54 100644 --- a/common/container_test.cc +++ b/common/container_test.cc @@ -60,6 +60,8 @@ TEST(ExpressionContainerTest, SetContainer) { EXPECT_THAT(container.container(), Eq("my.container.name")); EXPECT_THAT(container.SetContainer("..invalid"), StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.SetContainer("foo.1invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(ExpressionContainerTest, AddAlias) { From 1ca513feb24061c792b1ed086105df39f4339f5c Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Thu, 23 Apr 2026 23:23:06 -0700 Subject: [PATCH 050/143] Add signature generation functions for types and function overloads PiperOrigin-RevId: 904826666 --- common/BUILD | 1 + common/decl.cc | 187 ++-------------------- common/decl_test.cc | 28 +--- common/internal/BUILD | 31 ++++ common/internal/signature.cc | 211 +++++++++++++++++++++++++ common/internal/signature.h | 61 ++++++++ common/internal/signature_test.cc | 249 ++++++++++++++++++++++++++++++ env/type_info.cc | 6 + env/type_info_test.cc | 4 + 9 files changed, 580 insertions(+), 198 deletions(-) create mode 100644 common/internal/signature.cc create mode 100644 common/internal/signature.h create mode 100644 common/internal/signature_test.cc diff --git a/common/BUILD b/common/BUILD index db177951d..0ead8b15a 100644 --- a/common/BUILD +++ b/common/BUILD @@ -114,6 +114,7 @@ cc_library( ":constant", ":type", ":type_kind", + "//common/internal:signature", "//internal:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", diff --git a/common/decl.cc b/common/decl.cc index 1e06cb703..b338bfd4f 100644 --- a/common/decl.cc +++ b/common/decl.cc @@ -23,8 +23,10 @@ #include "absl/container/flat_hash_set.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "common/internal/signature.h" #include "common/type.h" #include "common/type_kind.h" @@ -104,181 +106,6 @@ bool SignaturesOverlap(const OverloadDecl& lhs, const OverloadDecl& rhs) { return args_overlap; } -void AppendEscaped(std::string* result, absl::string_view str, - bool escape_dot) { - for (char c : str) { - switch (c) { - case '\\': - case '(': - case ')': - case '<': - case '>': - case '"': - case ',': - result->push_back('\\'); - result->push_back(c); - break; - case '.': - if (escape_dot) { - result->push_back('\\'); - } - result->push_back(c); - break; - default: - result->push_back(c); - break; - } - } -} - -void AppendTypeParameters(std::string* result, const Type& type); - -// Recursively appends a string representation of the given `type` to `result`. -// Type parameters are enclosed in angle brackets and separated by commas. -void AppendTypeToOverloadId(std::string* result, const Type& type) { - switch (type.kind()) { - case TypeKind::kNull: - absl::StrAppend(result, "null"); - return; - case TypeKind::kBool: - absl::StrAppend(result, "bool"); - return; - case TypeKind::kInt: - absl::StrAppend(result, "int"); - return; - case TypeKind::kUint: - absl::StrAppend(result, "uint"); - return; - case TypeKind::kDouble: - absl::StrAppend(result, "double"); - return; - case TypeKind::kString: - absl::StrAppend(result, "string"); - return; - case TypeKind::kBytes: - absl::StrAppend(result, "bytes"); - return; - case TypeKind::kDuration: - absl::StrAppend(result, "duration"); - return; - case TypeKind::kTimestamp: - absl::StrAppend(result, "timestamp"); - return; - case TypeKind::kUnknown: - absl::StrAppend(result, "unknown"); - return; - case TypeKind::kError: - absl::StrAppend(result, "error"); - return; - case TypeKind::kAny: - absl::StrAppend(result, "any"); - return; - case TypeKind::kDyn: - absl::StrAppend(result, "dyn"); - return; - case TypeKind::kBoolWrapper: - absl::StrAppend(result, "bool_wrapper"); - return; - case TypeKind::kIntWrapper: - absl::StrAppend(result, "int_wrapper"); - return; - case TypeKind::kUintWrapper: - absl::StrAppend(result, "uint_wrapper"); - return; - case TypeKind::kDoubleWrapper: - absl::StrAppend(result, "double_wrapper"); - return; - case TypeKind::kStringWrapper: - absl::StrAppend(result, "string_wrapper"); - return; - case TypeKind::kBytesWrapper: - absl::StrAppend(result, "bytes_wrapper"); - return; - case TypeKind::kList: - absl::StrAppend(result, "list"); - AppendTypeParameters(result, type); - return; - case TypeKind::kMap: - absl::StrAppend(result, "map"); - AppendTypeParameters(result, type); - return; - case TypeKind::kFunction: - absl::StrAppend(result, "function"); - AppendTypeParameters(result, type); - return; - case TypeKind::kEnum: - absl::StrAppend(result, "enum"); - AppendTypeParameters(result, type); - return; - case TypeKind::kType: - absl::StrAppend(result, "type"); - AppendTypeParameters(result, type); - return; - case TypeKind::kOpaque: - result->push_back('"'); - AppendEscaped(result, type.name(), /*escape_dot=*/false); - result->push_back('"'); - AppendTypeParameters(result, type); - return; - default: // This includes TypeKind::kStruct aka TypeKind::kTypeMessage - AppendEscaped(result, type.name(), /*escape_dot=*/false); - return; - } -} - -void AppendTypeParameters(std::string* result, const Type& type) { - const auto& parameters = type.GetParameters(); - if (!parameters.empty()) { - result->push_back('<'); - for (size_t i = 0; i < parameters.size(); ++i) { - AppendTypeToOverloadId(result, parameters[i]); - if (i < parameters.size() - 1) { - result->push_back(','); - } - } - result->push_back('>'); - } -} - -// Generates an identifier for the overload based on the function name and -// the types of the arguments. If `member` is true, the first argument type -// is used as the receiver and is prepended to the function name, followed by -// a dot. -// -// Examples: -// -// - `foo()` -// - `foo(int)` -// - `bar.foo(int)` -// - `foo(int,string)` -// - `foo(list,list)` -// - `bar.foo(list,list<"my_type">)` -// -std::string GenerateOverloadId(std::string_view function_name, - const std::vector& args, bool member) { - std::string result; - if (member) { - if (!args.empty()) { - AppendTypeToOverloadId(&result, args[0]); - } else { - // This should never happen: a member function with no receiver. - absl::StrAppend(&result, "error"); - } - result.push_back('.'); - } - AppendEscaped(&result, function_name, /*escape_dot=*/true); - result.push_back('('); - for (size_t i = member ? 1 : 0; i < args.size(); ++i) { - AppendTypeToOverloadId(&result, args[i]); - if (i < args.size() - 1) { - result.push_back(','); - } - } - result.push_back(')'); - - return result; -} - template void AddOverloadInternal(std::string_view function_name, std::vector& insertion_order, @@ -290,8 +117,14 @@ void AddOverloadInternal(std::string_view function_name, if (overload.id().empty()) { OverloadDecl overload_decl = overload; - overload_decl.set_id(GenerateOverloadId(function_name, overload_decl.args(), - overload_decl.member())); + absl::StatusOr overload_id = + common_internal::MakeOverloadSignature( + function_name, overload_decl.args(), overload_decl.member()); + if (!overload_id.ok()) { + status = overload_id.status(); + return; + } + overload_decl.set_id(*overload_id); AddOverloadInternal(function_name, insertion_order, overloads, std::move(overload_decl), status); return; diff --git a/common/decl_test.cc b/common/decl_test.cc index 6e5710049..510cd5017 100644 --- a/common/decl_test.cc +++ b/common/decl_test.cc @@ -14,7 +14,10 @@ #include "common/decl.h" -#include "absl/log/die_if_null.h" +#include +#include + +#include "absl/log/die_if_null.h" // IWYU pragma: keep #include "absl/status/status.h" #include "common/constant.h" #include "common/type.h" @@ -186,7 +189,6 @@ TEST(FunctionDecl, OverloadId) { MakeOverloadDecl(IntType{}, TimestampType{}), MakeOverloadDecl(IntType{}, IntWrapperType{}), MakeOverloadDecl(IntType{}, MessageType(descriptor)), - MakeMemberOverloadDecl(IntType{}), MakeMemberOverloadDecl(StringType{}, StringType{}), MakeMemberOverloadDecl(StringType{}, StringType{}, ListType(&arena, BoolType{})), @@ -198,36 +200,20 @@ TEST(FunctionDecl, OverloadId) { ElementsAre(Property(&OverloadDecl::id, "hello()"), Property(&OverloadDecl::id, "hello(string)"), Property(&OverloadDecl::id, "hello(int,uint)"), - Property(&OverloadDecl::id, "hello(list)"), - Property(&OverloadDecl::id, "hello(map)"), - Property(&OverloadDecl::id, "hello(\"bar\">)"), + Property(&OverloadDecl::id, "hello(list<~A>)"), + Property(&OverloadDecl::id, "hello(map<~B,~C>)"), + Property(&OverloadDecl::id, "hello(bar>)"), Property(&OverloadDecl::id, "hello(any)"), Property(&OverloadDecl::id, "hello(duration)"), Property(&OverloadDecl::id, "hello(timestamp)"), Property(&OverloadDecl::id, "hello(int_wrapper)"), Property(&OverloadDecl::id, "hello(cel.expr.conformance.proto3.TestAllTypes)"), - Property(&OverloadDecl::id, "error.hello()"), Property(&OverloadDecl::id, "string.hello()"), Property(&OverloadDecl::id, "string.hello(list)"), Property(&OverloadDecl::id, "string.hello(bool,dyn)"))); } -TEST(FunctionDecl, OverloadIdEscaping) { - google::protobuf::Arena arena; - ASSERT_OK_AND_ASSIGN( - auto function_decl, - MakeFunctionDecl("h.(e),l\\o", - MakeMemberOverloadDecl( - StringType{}, StringType{}, - ListType(&arena, TypeParamType("a,b..(d)\\e"))))); - - EXPECT_THAT(function_decl.overloads(), - ElementsAre(Property(&OverloadDecl::id, - "string.h\\.\\(e\\)\\,l\\\\\\o(list<" - "a\\,b.\\.\\(d\\)\\\\e>)"))); -} - using common_internal::TypeIsAssignable; TEST(TypeIsAssignable, BoolWrapper) { diff --git a/common/internal/BUILD b/common/internal/BUILD index c5ca63564..10084b685 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -137,3 +137,34 @@ cc_library( "@com_google_protobuf//src/google/protobuf/io", ], ) + +cc_library( + name = "signature", + srcs = ["signature.cc"], + hdrs = ["signature.h"], + deps = [ + "//common:type", + "//common:type_kind", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "signature_test", + srcs = ["signature_test.cc"], + deps = [ + ":signature", + "//common:type", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/common/internal/signature.cc b/common/internal/signature.cc new file mode 100644 index 000000000..f63049878 --- /dev/null +++ b/common/internal/signature.cc @@ -0,0 +1,211 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/internal/signature.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/status_macros.h" + +namespace cel::common_internal { + +namespace { + +void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) { + for (char c : str) { + switch (c) { + case '\\': + case '(': + case ')': + case '<': + case '>': + case '"': + case ',': + case '~': + result->push_back('\\'); + break; + case '.': + if (escape_dot) { + result->push_back('\\'); + } + break; + } + result->push_back(c); + } +} + +absl::Status AppendTypeParameters(std::string* result, const Type& type); + +// Recursively appends a string representation of the given `type` to `result`. +// Type parameters are enclosed in angle brackets and separated by commas. + +// Grammar: +// TypeDesc = NamespaceIdentifier [ "<" TypeList ">" ] ; +// NamespaceIdentifier = [ "." ] Identifier { "." Identifier } ; +// TypeList = TypeElem { "," TypeElem } ; +// TypeElem = TypeDesc | TypeParam +// TypeParam = "~" Alpha ; +// Identifier = ( Alpha | "_" ) { AlphaNumeric | "_" } ; +// (* Terminals *) +// Alpha = "a"..."z" | "A"..."Z" ; +// Digit = "0"..."9" ; +// AlphaNumeric = Alpha | Digit ; +// +// For compatibility, the implementation allows unexpected characters in +// type names and parameters and escapes them with a backslash. +absl::Status AppendTypeDesc(std::string* result, const Type& type) { + switch (type.kind()) { + case TypeKind::kNull: + absl::StrAppend(result, "null"); + break; + case TypeKind::kBool: + absl::StrAppend(result, "bool"); + break; + case TypeKind::kInt: + absl::StrAppend(result, "int"); + break; + case TypeKind::kUint: + absl::StrAppend(result, "uint"); + break; + case TypeKind::kDouble: + absl::StrAppend(result, "double"); + break; + case TypeKind::kString: + absl::StrAppend(result, "string"); + break; + case TypeKind::kBytes: + absl::StrAppend(result, "bytes"); + break; + case TypeKind::kDuration: + absl::StrAppend(result, "duration"); + break; + case TypeKind::kTimestamp: + absl::StrAppend(result, "timestamp"); + break; + case TypeKind::kAny: + absl::StrAppend(result, "any"); + break; + case TypeKind::kDyn: + absl::StrAppend(result, "dyn"); + break; + case TypeKind::kBoolWrapper: + absl::StrAppend(result, "bool_wrapper"); + break; + case TypeKind::kIntWrapper: + absl::StrAppend(result, "int_wrapper"); + break; + case TypeKind::kUintWrapper: + absl::StrAppend(result, "uint_wrapper"); + break; + case TypeKind::kDoubleWrapper: + absl::StrAppend(result, "double_wrapper"); + break; + case TypeKind::kStringWrapper: + absl::StrAppend(result, "string_wrapper"); + break; + case TypeKind::kBytesWrapper: + absl::StrAppend(result, "bytes_wrapper"); + break; + case TypeKind::kList: + absl::StrAppend(result, "list"); + CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); + break; + case TypeKind::kMap: + absl::StrAppend(result, "map"); + CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); + break; + case TypeKind::kFunction: + absl::StrAppend(result, "function"); + CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); + break; + case TypeKind::kType: + absl::StrAppend(result, "type"); + CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); + break; + case TypeKind::kTypeParam: + absl::StrAppend(result, "~"); + AppendEscaped(result, type.GetTypeParam().name(), /*escape_dot=*/true); + break; + case TypeKind::kOpaque: + AppendEscaped(result, type.name(), /*escape_dot=*/false); + CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); + break; + case TypeKind::kStruct: + AppendEscaped(result, type.name(), /*escape_dot=*/false); + CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); + break; + default: + return absl::InvalidArgumentError( + absl::StrFormat("Type kind: %s is not supported in CEL declarations", + type.DebugString())); + } + return absl::OkStatus(); +} + +absl::Status AppendTypeParameters(std::string* result, const Type& type) { + const auto& parameters = type.GetParameters(); + if (!parameters.empty()) { + result->push_back('<'); + for (size_t i = 0; i < parameters.size(); ++i) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, parameters[i])); + if (i < parameters.size() - 1) { + result->push_back(','); + } + } + result->push_back('>'); + } + return absl::OkStatus(); +} +} // namespace + +absl::StatusOr MakeTypeSignature(const Type& type) { + std::string result; + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type)); + return result; +} + +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member) { + std::string result; + if (is_member) { + if (!args.empty()) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, args[0])); + } else { + return absl::InvalidArgumentError("Member function with no receiver"); + } + result.push_back('.'); + } + AppendEscaped(&result, function_name, /*escape_dot=*/true); + result.push_back('('); + for (size_t i = is_member ? 1 : 0; i < args.size(); ++i) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, args[i])); + if (i < args.size() - 1) { + result.push_back(','); + } + } + result.push_back(')'); + + return result; +} +} // namespace cel::common_internal diff --git a/common/internal/signature.h b/common/internal/signature.h new file mode 100644 index 000000000..3f31d8fd1 --- /dev/null +++ b/common/internal/signature.h @@ -0,0 +1,61 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "common/type.h" + +namespace cel::common_internal { + +// Generates an signature for a `cel::Type`, which is a string representation of +// the type. +// +// Examples: +// +// - `int` +// - `list` +// - `list>` +absl::StatusOr MakeTypeSignature(const Type& type); + +// Generates an identifier for a function overload based on the function name +// and the types of the arguments. If `is_member` is true, the first argument +// type is used as the receiver and is prepended to the function name, followed +// by a dollar sign. +// +// Examples: +// +// - `foo()` +// - `foo(int)` +// - `bar.foo(int)` +// - `foo(int,string)` +// - `foo(list,list)` +// - `bar.foo(list,list>)` +// +// If the function name contains a period, it is escaped with a backslash, e.g. +// `foo.bar` becomes `foo\.bar`. This allows to disambiguate between a member +// function and qualified target type name. +// +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member); + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ diff --git a/common/internal/signature_test.cc b/common/internal/signature_test.cc new file mode 100644 index 000000000..8e41c70fb --- /dev/null +++ b/common/internal/signature_test.cc @@ -0,0 +1,249 @@ +#include "common/internal/signature.h" +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/type.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::ValuesIn; + +google::protobuf::Arena* GetTestArena() { + static absl::NoDestructor arena; + return &*arena; +} + +struct TypeSignatureTestCase { + Type type; + std::string expected_signature; + std::string expected_error; +}; + +using TypeSignatureTest = testing::TestWithParam; + +TEST_P(TypeSignatureTest, TypeSignature) { + const auto& param = GetParam(); + + absl::StatusOr signature = + common_internal::MakeTypeSignature(param.type); + if (!param.expected_error.empty()) { + EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + } else { + EXPECT_THAT(signature, IsOkAndHolds(param.expected_signature)); + } +} + +std::vector GetTypeSignatureTestCases() { + return { + { + .type = StringType{}, + .expected_signature = "string", + }, + { + .type = IntType{}, + .expected_signature = "int", + }, + { + .type = ListType(GetTestArena(), StringType{}), + .expected_signature = "list", + }, + { + .type = ListType(GetTestArena(), TypeParamType("A")), + .expected_signature = "list<~A>", + }, + { + .type = MapType(GetTestArena(), IntType{}, DynType{}), + .expected_signature = "map", + }, + { + .type = + MapType(GetTestArena(), TypeParamType("B"), TypeParamType("C")), + .expected_signature = "map<~B,~C>", + }, + { + .type = OpaqueType( + GetTestArena(), "bar", + {FunctionType(GetTestArena(), TypeParamType("D"), {})}), + .expected_signature = "bar>", + }, + { + .type = AnyType{}, + .expected_signature = "any", + }, + { + .type = DurationType{}, + .expected_signature = "duration", + }, + { + .type = TimestampType{}, + .expected_signature = "timestamp", + }, + { + .type = IntWrapperType{}, + .expected_signature = "int_wrapper", + }, + { + .type = MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")), + .expected_signature = "cel.expr.conformance.proto3.TestAllTypes", + }, + { + .type = ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)")), + .expected_signature = R"(list<~a\,b\.\\.\(d\)\\e>)", + }, + { + .type = UnknownType{}, + .expected_error = + "Type kind: *unknown* is not supported in CEL declarations", + }, + { + .type = ErrorType{}, + .expected_error = + "Type kind: *error* is not supported in CEL declarations", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(TypeIdTest, TypeSignatureTest, + ValuesIn(GetTypeSignatureTestCases())); + +struct OverloadSignatureTestCase { + std::string function_name = "hello"; + std::vector args; + bool is_member = false; + std::string expected_signature; + std::string expected_error; +}; + +using OverloadSignatureTest = testing::TestWithParam; + +TEST_P(OverloadSignatureTest, OverloadSignature) { + const auto& param = GetParam(); + + absl::StatusOr signature = + common_internal::MakeOverloadSignature(param.function_name, param.args, + param.is_member); + if (!param.expected_error.empty()) { + EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + } else { + EXPECT_THAT(signature, IsOkAndHolds(param.expected_signature)); + } +} + +std::vector GetOverloadSignatureTestCases() { + return { + { + .args = {StringType{}}, + .expected_signature = "hello(string)", + }, + { + .args = {IntType{}, UintType{}}, + .expected_signature = "hello(int,uint)", + }, + { + .args = {ListType(GetTestArena(), StringType{})}, + .expected_signature = "hello(list)", + }, + { + .args = {ListType(GetTestArena(), TypeParamType("A"))}, + .expected_signature = "hello(list<~A>)", + }, + { + .args = {MapType(GetTestArena(), IntType{}, DynType{})}, + .expected_signature = "hello(map)", + }, + { + .args = {MapType(GetTestArena(), TypeParamType("B"), + TypeParamType("C"))}, + .expected_signature = "hello(map<~B,~C>)", + }, + { + .args = {OpaqueType( + GetTestArena(), "bar", + {FunctionType(GetTestArena(), TypeParamType("D"), {})})}, + .expected_signature = "hello(bar>)", + }, + { + .args = {AnyType{}}, + .expected_signature = "hello(any)", + }, + { + .args = {DurationType{}}, + .expected_signature = "hello(duration)", + }, + { + .args = {TimestampType{}}, + .expected_signature = "hello(timestamp)", + }, + { + .args = {IntWrapperType{}}, + .expected_signature = "hello(int_wrapper)", + }, + { + .args = {MessageType( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes"))}, + .expected_signature = + "hello(cel.expr.conformance.proto3.TestAllTypes)", + }, + {.args = {}, + .is_member = true, + .expected_error = "Member function with no receiver"}, + { + .args = {StringType{}}, + .is_member = true, + .expected_signature = "string.hello()", + }, + { + .args = {StringType{}, ListType(GetTestArena(), BoolType{})}, + .is_member = true, + .expected_signature = "string.hello(list)", + }, + { + .args = {StringType{}, BoolType{}, DynType{}}, + .is_member = true, + .expected_signature = "string.hello(bool,dyn)", + }, + { + .function_name = R"(h.(e),l\o)", + .args = {StringType{}, + ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)"))}, + .is_member = true, + .expected_signature = + R"(string.h\.\(e\)\,l\\\o(list<~a\,b\.\\.\(d\)\\e>))", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(OverloadIdTest, OverloadSignatureTest, + ValuesIn(GetOverloadSignatureTestCases())); + +} // namespace +} // namespace cel::common_internal diff --git a/env/type_info.cc b/env/type_info.cc index ed72a842f..a5b47b6f1 100644 --- a/env/type_info.cc +++ b/env/type_info.cc @@ -59,11 +59,17 @@ std::optional TypeNameToTypeKind(absl::string_view type_name) { {"any", TypeKind::kAny}, {"dyn", TypeKind::kDyn}, {BoolWrapperType::kName, TypeKind::kBoolWrapper}, + {"bool_wrapper", TypeKind::kBoolWrapper}, {IntWrapperType::kName, TypeKind::kIntWrapper}, + {"int_wrapper", TypeKind::kIntWrapper}, {UintWrapperType::kName, TypeKind::kUintWrapper}, + {"uint_wrapper", TypeKind::kUintWrapper}, {DoubleWrapperType::kName, TypeKind::kDoubleWrapper}, + {"double_wrapper", TypeKind::kDoubleWrapper}, {StringWrapperType::kName, TypeKind::kStringWrapper}, + {"string_wrapper", TypeKind::kStringWrapper}, {BytesWrapperType::kName, TypeKind::kBytesWrapper}, + {"bytes_wrapper", TypeKind::kBytesWrapper}, {"type", TypeKind::kType}, }); if (auto it = kTypeNameToTypeKind->find(type_name); diff --git a/env/type_info_test.cc b/env/type_info_test.cc index ca9d0467c..015d8a928 100644 --- a/env/type_info_test.cc +++ b/env/type_info_test.cc @@ -105,6 +105,10 @@ std::vector GetTestCases() { .type_info = {.name = "google.protobuf.DoubleValue"}, .expected_type_pb = "wrapper: DOUBLE", }, + TestCase{ + .type_info = {.name = "double_wrapper"}, + .expected_type_pb = "wrapper: DOUBLE", + }, TestCase{ .type_info = {.name = "type", .params = {Config::TypeInfo{.name = "duration"}}}, From c4a927e63bc4cb9f3ad4b61e37f4456c2d883380 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 24 Apr 2026 14:24:47 -0700 Subject: [PATCH 051/143] Add ToBuilder() functions for parser/checker/compiler Adds support for extending a configured compiler similar to cel-java. PiperOrigin-RevId: 905220235 --- checker/BUILD | 1 + checker/internal/BUILD | 2 + checker/internal/type_check_env.h | 17 ++++- checker/internal/type_checker_builder_impl.cc | 30 +++++--- checker/internal/type_checker_builder_impl.h | 31 +++++--- checker/internal/type_checker_impl.cc | 6 ++ checker/internal/type_checker_impl.h | 3 + checker/internal/type_checker_impl_test.cc | 39 ++++++++++ checker/type_checker.h | 5 ++ checker/type_checker_builder.h | 1 - checker/type_checker_builder_factory_test.cc | 59 +++++++++++++++ compiler/BUILD | 1 + compiler/compiler.h | 13 +++- compiler/compiler_factory.cc | 14 +++- compiler/compiler_factory_test.cc | 26 +++++++ parser/macro_registry.cc | 10 +++ parser/macro_registry.h | 4 + parser/parser.cc | 27 ++++++- parser/parser_interface.h | 3 + parser/parser_test.cc | 75 +++++++++++++++++++ 20 files changed, 335 insertions(+), 32 deletions(-) diff --git a/checker/BUILD b/checker/BUILD index 7b151d6a8..f1e0cef3c 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -123,6 +123,7 @@ cc_test( srcs = ["type_checker_builder_factory_test.cc"], deps = [ ":checker_options", + ":optional", ":standard_library", ":type_checker", ":type_checker_builder", diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 1af48af57..1c560cdb9 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -74,6 +74,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -175,6 +176,7 @@ cc_test( ":type_checker_impl", "//checker:checker_options", "//checker:type_check_issue", + "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", "//common:container", diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 491e4b550..15f8ecc4d 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -23,6 +23,7 @@ #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -100,7 +101,8 @@ class TypeCheckEnv { type_providers_.push_back(proto_type_introspector_); } - // Move-only. + TypeCheckEnv(const TypeCheckEnv&) = default; + TypeCheckEnv& operator=(const TypeCheckEnv&) = default; TypeCheckEnv(TypeCheckEnv&&) = default; TypeCheckEnv& operator=(TypeCheckEnv&&) = default; @@ -193,11 +195,15 @@ class TypeCheckEnv { // Used to keep an arena alive if one was needed to allocate types. // - // The TypeCheckEnv does not otherwise use it. - void set_arena(std::shared_ptr arena) { + // Expected to be called exactly once if at all. + void set_arena(std::shared_ptr arena) { + ABSL_DCHECK(arena_ == nullptr || arena == arena_); arena_ = std::move(arena); } + // Returns the arena if one was set, nullptr otherwise. + std::shared_ptr arena() const { return arena_; } + private: absl::StatusOr> LookupEnumConstant( absl::string_view type, absl::string_view value) const; @@ -205,7 +211,10 @@ class TypeCheckEnv { absl_nonnull std::shared_ptr descriptor_pool_; // If set, an arena was needed to allocate types in the environment. - absl_nullable std::shared_ptr arena_; + // + // The TypeCheckEnv does not otherwise use the arena, though it may be used by + // derived TypeCheckerBuilders. + absl_nullable std::shared_ptr arena_; ExpressionContainer container_; // Used to resolve fields on message types. diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 9ebcb4e34..94a05602e 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -34,6 +34,7 @@ #include "checker/internal/type_checker_impl.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" @@ -342,8 +343,16 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( } absl::StatusOr> TypeCheckerBuilderImpl::Build() { - TypeCheckEnv env(descriptor_pool_); - env.set_container(expression_container_); + TypeCheckEnv env(template_env_); + CEL_RETURN_IF_ERROR(ConfigureTypeCheckEnv(env)); + return std::make_unique(std::move(env), + options_); +} + +absl::Status TypeCheckerBuilderImpl::ConfigureTypeCheckEnv(TypeCheckEnv& env) { + if (expression_container_.has_value()) { + env.set_container(*expression_container_); + } if (expected_type_.has_value()) { env.set_expected_type(*expected_type_); } @@ -377,12 +386,10 @@ absl::StatusOr> TypeCheckerBuilderImpl::Build() { /*subset=*/nullptr, env)); CEL_RETURN_IF_ERROR(ApplyConfig(default_config_, /*subset=*/nullptr, env)); - // A library may have been the first to initialize the arena, so we need to - // set it as the last step. - env.set_arena(arena_); - auto checker = std::make_unique( - std::move(env), options_); - return checker; + if (type_arena_ != nullptr) { + env.set_arena(type_arena_); + } + return absl::OkStatus(); } absl::Status TypeCheckerBuilderImpl::AddLibrary(CheckerLibrary library) { @@ -432,7 +439,7 @@ absl::Status TypeCheckerBuilderImpl::AddOrReplaceVariable( absl::Status TypeCheckerBuilderImpl::AddContextDeclaration( absl::string_view type) { const google::protobuf::Descriptor* desc = - descriptor_pool_->FindMessageTypeByName(type); + template_env_.descriptor_pool()->FindMessageTypeByName(type); if (desc == nullptr) { return absl::NotFoundError( absl::StrCat("context declaration '", type, "' not found")); @@ -479,7 +486,10 @@ void TypeCheckerBuilderImpl::AddTypeProvider( } void TypeCheckerBuilderImpl::set_container(absl::string_view container) { - expression_container_.SetContainer(container).IgnoreError(); + if (!expression_container_.has_value()) { + expression_container_.emplace(); + } + expression_container_->SetContainer(container).IgnoreError(); } void TypeCheckerBuilderImpl::SetExpressionContainer( diff --git a/checker/internal/type_checker_builder_impl.h b/checker/internal/type_checker_builder_impl.h index 7a099040b..646a5d16f 100644 --- a/checker/internal/type_checker_builder_impl.h +++ b/checker/internal/type_checker_builder_impl.h @@ -40,8 +40,6 @@ namespace cel::checker_internal { -class TypeCheckerBuilderImpl; - // Builder for TypeChecker instances. class TypeCheckerBuilderImpl : public TypeCheckerBuilder { public: @@ -51,7 +49,18 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { const CheckerOptions& options) : options_(options), target_config_(&default_config_), - descriptor_pool_(std::move(descriptor_pool)) {} + template_env_(std::move(descriptor_pool)) {} + + // Constructor for building an extended TypeChecker. + explicit TypeCheckerBuilderImpl(const CheckerOptions& options, + const TypeCheckEnv& template_env) + : options_(options), + target_config_(&default_config_), + template_env_(template_env) { + if (auto arena = template_env_.arena(); arena != nullptr) { + type_arena_ = std::move(arena); + } + } // Move only. TypeCheckerBuilderImpl(const TypeCheckerBuilderImpl&) = delete; @@ -83,14 +92,14 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { const CheckerOptions& options() const override { return options_; } google::protobuf::Arena* absl_nonnull arena() override { - if (arena_ == nullptr) { - arena_ = std::make_shared(); + if (type_arena_ == nullptr) { + type_arena_ = std::make_shared(); } - return arena_.get(); + return type_arena_.get(); } const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const override { - return descriptor_pool_.get(); + return template_env_.descriptor_pool(); } private: @@ -129,6 +138,8 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { absl::Status ApplyConfig(ConfigRecord config, const TypeCheckerSubset* subset, TypeCheckEnv& env); + absl::Status ConfigureTypeCheckEnv(TypeCheckEnv& env); + CheckerOptions options_; // Default target for configuration changes. Used for direct calls to // AddVariable, AddFunction, etc. @@ -136,12 +147,12 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { // Active target for configuration changes. // This is used to track which library the change is made on behalf of. ConfigRecord* absl_nonnull target_config_; - std::shared_ptr descriptor_pool_; - std::shared_ptr arena_; + TypeCheckEnv template_env_; + std::shared_ptr type_arena_; std::vector libraries_; absl::flat_hash_map subsets_; absl::flat_hash_set library_ids_; - ExpressionContainer expression_container_; + absl::optional expression_container_; absl::optional expected_type_; }; diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 6697cc06a..8f67efbde 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -37,8 +37,10 @@ #include "checker/internal/format_type_name.h" #include "checker/internal/namespace_generator.h" #include "checker/internal/type_check_env.h" +#include "checker/internal/type_checker_builder_impl.h" #include "checker/internal/type_inference_context.h" #include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" #include "common/ast_rewrite.h" @@ -1326,4 +1328,8 @@ absl::StatusOr TypeCheckerImpl::Check( return ValidationResult(std::move(ast), std::move(issues)); } +std::unique_ptr TypeCheckerImpl::ToBuilder() const { + return std::make_unique(options_, env_); +} + } // namespace cel::checker_internal diff --git a/checker/internal/type_checker_impl.h b/checker/internal/type_checker_impl.h index 1b9062ec1..71683276d 100644 --- a/checker/internal/type_checker_impl.h +++ b/checker/internal/type_checker_impl.h @@ -22,6 +22,7 @@ #include "checker/checker_options.h" #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" +#include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" #include "google/protobuf/arena.h" @@ -44,6 +45,8 @@ class TypeCheckerImpl : public TypeChecker { absl::StatusOr Check( std::unique_ptr ast) const override; + std::unique_ptr ToBuilder() const override; + private: TypeCheckEnv env_; google::protobuf::Arena type_arena_; diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index a82dd58ee..e6cd641d6 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -33,6 +33,7 @@ #include "checker/internal/test_ast_helpers.h" #include "checker/internal/type_check_env.h" #include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" #include "common/container.h" @@ -1413,6 +1414,44 @@ TEST(TypeCheckerImplTest, ExpectedTypeDoesntMatch) { "expected type 'map(string, string)' but found 'map(string, int)'"))); } +TEST(TypeCheckerImplTest, ToBuilder) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + TypeCheckerImpl impl(std::move(env)); + auto builder = impl.ToBuilder(); + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); + ASSERT_OK_AND_ASSIGN(auto new_checker, builder->Build()); + + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + new_checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerImplTest, ToBuilderPropagatesArena) { + auto arena = std::make_shared(); + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_arena(arena); + + Type list_type = ListType(arena.get(), IntType()); + ASSERT_TRUE( + env.InsertVariableIfAbsent(MakeVariableDecl("my_list", list_type))); + + auto base_checker = std::make_unique(std::move(env)); + + std::unique_ptr builder = base_checker->ToBuilder(); + + base_checker.reset(); + arena.reset(); + + ASSERT_OK_AND_ASSIGN(auto derived_checker, builder->Build()); + + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("my_list")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + derived_checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + TEST(TypeCheckerImplTest, BadSourcePosition) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); diff --git a/checker/type_checker.h b/checker/type_checker.h index 993eafb71..e47b7dca6 100644 --- a/checker/type_checker.h +++ b/checker/type_checker.h @@ -23,6 +23,8 @@ namespace cel { +class TypeCheckerBuilder; + // TypeChecker interface. // // Checks references and type agreement for a parsed CEL expression. @@ -43,6 +45,9 @@ class TypeChecker { virtual absl::StatusOr Check( std::unique_ptr ast) const = 0; + // Returns a builder initialized with the configuration of this type checker. + virtual std::unique_ptr ToBuilder() const = 0; + // TODO(uncreated-issue/73): add overload for cref AST. }; diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index b3a86f64c..5dd1f5256 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -35,7 +35,6 @@ namespace cel { class TypeCheckerBuilder; -class TypeCheckerBuilderImpl; // Functional implementation to apply the library features to a // TypeCheckerBuilder. diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index 030186c83..38430de5f 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -23,6 +23,7 @@ #include "absl/strings/string_view.h" #include "checker/checker_options.h" #include "checker/internal/test_ast_helpers.h" +#include "checker/optional.h" #include "checker/standard_library.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" @@ -743,5 +744,63 @@ TEST(TypeCheckerBuilderTest, AddFunctionNoOverlapWithStdMacroError) { EXPECT_THAT(builder->AddFunction(fn_decl), IsOk()); } +TEST(TypeCheckerBuilderTest, ToBuilderIndependenceAndInheritance) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl("addOne", + MakeOverloadDecl("addOne_int", IntType(), IntType()))); + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker1, builder->Build()); + + // Exercise checker1. + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("addOne(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result1, + checker1->Check(std::move(ast))); + EXPECT_TRUE(result1.IsValid()); + } + + // Start new builder via ToBuilder. + auto builder2 = checker1->ToBuilder(); + ASSERT_THAT(builder2->AddVariable(MakeVariableDecl("y", IntType())), IsOk()); + ASSERT_THAT(builder2->AddLibrary(OptionalCheckerLibrary()), IsOk()); + builder2->SetExpectedType(IntType()); + + ASSERT_OK_AND_ASSIGN(auto checker2, builder2->Build()); + + { + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst("optional.of(addOne(x)).orValue(0) + y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result2, + checker2->Check(std::move(ast))); + EXPECT_TRUE(result2.IsValid()); + } + + // Demonstrate checker1 is unmodified and independent (still does not know + // about y). + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result_y_checker1_again, + checker1->Check(std::move(ast))); + EXPECT_FALSE(result_y_checker1_again.IsValid()); + } + + // Same for optional library functions. + { + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("optional.none().orValue(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker1->Check(std::move(ast))); + EXPECT_FALSE(result.IsValid()); + } +} + } // namespace } // namespace cel diff --git a/compiler/BUILD b/compiler/BUILD index 44ef4f537..170f1068b 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -66,6 +66,7 @@ cc_test( deps = [ ":compiler", ":compiler_factory", + ":optional", ":standard_library", "//checker:optional", "//checker:standard_library", diff --git a/compiler/compiler.h b/compiler/compiler.h index 6178cf2dc..48fa4e0b1 100644 --- a/compiler/compiler.h +++ b/compiler/compiler.h @@ -99,8 +99,8 @@ struct CompilerOptions { // Interface for CEL CompilerBuilder objects. // -// Builder implementations are thread hostile, but should create -// thread-compatible Compiler instances. +// Builder implementations do not provide any synchronization themselves, +// but create thread-compatible Compiler instances. class CompilerBuilder { public: virtual ~CompilerBuilder() = default; @@ -140,6 +140,15 @@ class Compiler { // Accessor for the underlying validator. virtual const Validator& GetValidator() const = 0; + + // Returns a builder initialized with the configuration of this compiler. + // + // The returned builder is a copy of the validated environment and may + // behave differently than the builder that created this compiler. + // + // The returned builder does not share state with the compiler and may be + // modified independently. + virtual std::unique_ptr ToBuilder() const = 0; }; } // namespace cel diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc index c83633f68..3e9871706 100644 --- a/compiler/compiler_factory.cc +++ b/compiler/compiler_factory.cc @@ -65,6 +65,8 @@ class CompilerImpl : public Compiler { return result; } + std::unique_ptr ToBuilder() const override; + const TypeChecker& GetTypeChecker() const override { return *type_checker_; } const Parser& GetParser() const override { return *parser_; } const Validator& GetValidator() const override { return validator_; } @@ -78,9 +80,11 @@ class CompilerImpl : public Compiler { class CompilerBuilderImpl : public CompilerBuilder { public: CompilerBuilderImpl(std::unique_ptr type_checker_builder, - std::unique_ptr parser_builder) + std::unique_ptr parser_builder, + Validator validator = Validator()) : type_checker_builder_(std::move(type_checker_builder)), - parser_builder_(std::move(parser_builder)) {} + parser_builder_(std::move(parser_builder)), + validator_(std::move(validator)) {} absl::Status AddLibrary(CompilerLibrary library) override { if (!library.id.empty()) { @@ -154,6 +158,12 @@ class CompilerBuilderImpl : public CompilerBuilder { absl::flat_hash_set subsets_; }; +std::unique_ptr CompilerImpl::ToBuilder() const { + auto builder = std::make_unique( + type_checker_->ToBuilder(), parser_->ToBuilder(), validator_); + return builder; +} + } // namespace absl::StatusOr> NewCompilerBuilder( diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc index cfdc68e26..d217e4cc7 100644 --- a/compiler/compiler_factory_test.cc +++ b/compiler/compiler_factory_test.cc @@ -29,6 +29,7 @@ #include "common/source.h" #include "common/type.h" #include "compiler/compiler.h" +#include "compiler/optional.h" #include "compiler/standard_library.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" @@ -364,5 +365,30 @@ TEST(CompilerFactoryTest, FailsIfNullDescriptorPool) { HasSubstr("descriptor_pool must not be null"))); } +TEST(CompilerFactoryTest, ToBuilderWorks) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("a", MapType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + auto derived_builder = compiler->ToBuilder(); + + ASSERT_THAT(derived_builder->AddLibrary(OptionalCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto derived_compiler, derived_builder->Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult result, + derived_compiler->Compile("has(a.b) && a.?b.orValue('foo') == 'foo'")); + EXPECT_TRUE(result.IsValid()); +} + } // namespace } // namespace cel diff --git a/parser/macro_registry.cc b/parser/macro_registry.cc index 3fc77f18c..3a816b10e 100644 --- a/parser/macro_registry.cc +++ b/parser/macro_registry.cc @@ -16,6 +16,7 @@ #include #include +#include #include "absl/status/status.h" #include "absl/strings/match.h" @@ -70,6 +71,15 @@ absl::optional MacroRegistry::FindMacro(absl::string_view name, return absl::nullopt; } +std::vector MacroRegistry::ListMacros() const { + std::vector macros; + macros.reserve(macros_.size()); + for (auto it = macros_.begin(); it != macros_.end(); ++it) { + macros.push_back(it->second); + } + return macros; +} + bool MacroRegistry::RegisterMacroImpl(const Macro& macro) { return macros_.insert(std::pair{macro.key(), macro}).second; } diff --git a/parser/macro_registry.h b/parser/macro_registry.h index 51899bade..01a0634ef 100644 --- a/parser/macro_registry.h +++ b/parser/macro_registry.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_PARSER_MACRO_REGISTRY_H_ #include +#include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" @@ -44,6 +45,9 @@ class MacroRegistry final { absl::optional FindMacro(absl::string_view name, size_t arg_count, bool receiver_style) const; + // Returns a copy of all registered macros. + std::vector ListMacros() const; + private: bool RegisterMacroImpl(const Macro& macro); diff --git a/parser/parser.cc b/parser/parser.cc index d430e3169..d9f74e712 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -1707,8 +1707,12 @@ absl::StatusOr ParseImpl(const cel::Source& source, class ParserImpl : public cel::Parser { public: explicit ParserImpl(const ParserOptions& options, - cel::MacroRegistry macro_registry) - : options_(options), macro_registry_(std::move(macro_registry)) {} + cel::MacroRegistry macro_registry, + absl::flat_hash_set library_ids) + : options_(options), + macro_registry_(std::move(macro_registry)), + library_ids_(std::move(library_ids)) {} + absl::StatusOr> Parse( const cel::Source& source) const override { CEL_ASSIGN_OR_RETURN(auto parse_result, @@ -1717,9 +1721,12 @@ class ParserImpl : public cel::Parser { std::move(parse_result.source_info)); } + std::unique_ptr ToBuilder() const override; + private: const ParserOptions options_; const cel::MacroRegistry macro_registry_; + absl::flat_hash_set library_ids_; }; class ParserBuilderImpl : public cel::ParserBuilder { @@ -1796,21 +1803,28 @@ class ParserBuilderImpl : public cel::ParserBuilder { macros_.clear(); } + absl::flat_hash_set library_ids(library_ids_); + // Hack to support adding the standard library macros either by option or // with a library configurer. if (!options_.disable_standard_macros && !library_ids_.contains("stdlib")) { CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(Macro::AllMacros())); + library_ids.insert("stdlib"); } if (options_.enable_optional_syntax && !library_ids_.contains("optional")) { CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(cel::OptMapMacro())); CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(cel::OptFlatMapMacro())); + library_ids.insert("optional"); } CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(individual_macros)); - return std::make_unique(options_, std::move(macro_registry)); + return std::make_unique(options_, std::move(macro_registry), + std::move(library_ids)); } private: + friend class ParserImpl; + ParserOptions options_; std::vector macros_; absl::flat_hash_set library_ids_; @@ -1818,6 +1832,13 @@ class ParserBuilderImpl : public cel::ParserBuilder { absl::flat_hash_map library_subsets_; }; +std::unique_ptr ParserImpl::ToBuilder() const { + auto ins = std::make_unique(options_); + ins->library_ids_ = library_ids_; + ins->macros_ = macro_registry_.ListMacros(); + return ins; +} + } // namespace absl::StatusOr Parse(absl::string_view expression, diff --git a/parser/parser_interface.h b/parser/parser_interface.h index 0992385f7..7cc21ff26 100644 --- a/parser/parser_interface.h +++ b/parser/parser_interface.h @@ -83,6 +83,9 @@ class Parser { // Parses the given source into a CEL AST. virtual absl::StatusOr> Parse( const cel::Source& source) const = 0; + + // Returns a builder initialized with the configuration of this parser. + virtual std::unique_ptr ToBuilder() const = 0; }; } // namespace cel diff --git a/parser/parser_test.cc b/parser/parser_test.cc index c96845e67..3659fd8fd 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1973,6 +1973,81 @@ TEST(NewParserBuilderTest, ForwardsOptions) { StatusIs(absl::StatusCode::kInvalidArgument)); } +TEST(NewParserBuilderTest, ToBuilderCopiesConfig) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_optional_syntax = true; + builder->GetOptions().disable_standard_macros = true; + ASSERT_THAT(builder->AddLibrary({"custom_lib", + [](cel::ParserBuilder& b) { + return b.AddMacro(cel::HasMacro()); + }}), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + auto derived_builder = parser->ToBuilder(); + EXPECT_TRUE(derived_builder->GetOptions().enable_optional_syntax); + + ASSERT_OK_AND_ASSIGN(auto derived_parser, + std::move(*derived_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a.?b && has(a.b)")); + ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); + EXPECT_FALSE(ast->IsChecked()); +} + +TEST(NewParserBuilderTest, ToBuilderHandlesStdlibAndOptionalByLibrary) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().disable_standard_macros = true; + builder->GetOptions().enable_optional_syntax = false; + + // Abusing the library ids for testing. Real uses should use subsetting. + ASSERT_THAT( + builder->AddLibrary( + {"stdlib", [](cel::ParserBuilder& b) { return absl::OkStatus(); }}), + IsOk()); + ASSERT_THAT( + builder->AddLibrary( + {"optional", [](cel::ParserBuilder& b) { return absl::OkStatus(); }}), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + auto derived_builder = parser->ToBuilder(); + // Should be ignored now. + derived_builder->GetOptions().disable_standard_macros = false; + derived_builder->GetOptions().enable_optional_syntax = true; + + ASSERT_OK_AND_ASSIGN(auto derived_parser, + std::move(*derived_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b)")); + ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); + + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + EXPECT_EQ(w.Print(ast->root_expr()), + "has(\n" + " a^#2:Expr.Ident#.b^#3:Expr.Select#\n" + ")^#1:Expr.Call#"); +} + +TEST(NewParserBuilderTest, ToBuilderPreservesStdlibAndOptionalFromOptions) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().disable_standard_macros = false; + builder->GetOptions().enable_optional_syntax = true; + + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + auto derived_builder = parser->ToBuilder(); + + ASSERT_OK_AND_ASSIGN(auto derived_parser, + std::move(*derived_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b) && [?a]")); + ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); + EXPECT_FALSE(ast->IsChecked()); +} + std::string TestName(const testing::TestParamInfo& test_info) { std::string name = absl::StrCat(test_info.index, "-", test_info.param.I); absl::c_replace_if(name, [](char c) { return !absl::ascii_isalnum(c); }, '_'); From f8ed1bdd658435dee2de9dfd7712a57cf5fe07f2 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 24 Apr 2026 15:16:43 -0700 Subject: [PATCH 052/143] Remove invalid 'size' arg, use alignment as size. PiperOrigin-RevId: 905243616 --- internal/new.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/new.cc b/internal/new.cc index 05396e624..31ec82a08 100644 --- a/internal/new.cc +++ b/internal/new.cc @@ -114,7 +114,7 @@ void AlignedDelete(void* ptr, std::align_val_t alignment) noexcept { ::operator delete(ptr, alignment); #else if (static_cast(alignment) <= kDefaultNewAlignment) { - SizedDelete(ptr, size); + ::operator delete(ptr); } else { #if defined(_MSC_VER) _aligned_free(ptr); From 23c9913ef7f8ed8835165fc6c65617e3090f1630 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 28 Apr 2026 16:26:10 -0700 Subject: [PATCH 053/143] Add option to specify an arena and persist the resolved types from type checking. PiperOrigin-RevId: 907236091 --- checker/BUILD | 5 ++++ checker/internal/type_checker_impl.cc | 34 +++++++++++++++++++------ checker/internal/type_checker_impl.h | 4 +-- checker/type_checker.cc | 36 +++++++++++++++++++++++++++ checker/type_checker.h | 15 ++++++++--- checker/validation_result.h | 20 ++++++++++++++- compiler/BUILD | 2 ++ compiler/compiler.h | 12 +++++++-- compiler/compiler_factory.cc | 7 +++--- compiler/compiler_factory_test.cc | 23 +++++++++++++++++ 10 files changed, 139 insertions(+), 19 deletions(-) create mode 100644 checker/type_checker.cc diff --git a/checker/BUILD b/checker/BUILD index f1e0cef3c..27a1eb84e 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -50,7 +50,9 @@ cc_library( ":type_check_issue", "//common:ast", "//common:source", + "//common:type", "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -74,11 +76,14 @@ cc_test( cc_library( name = "type_checker", + srcs = ["type_checker.cc"], hdrs = ["type_checker.h"], deps = [ ":validation_result", "//common:ast", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", ], ) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 8f67efbde..05601fdbb 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -1176,11 +1177,13 @@ class ResolveRewriter : public AstRewriterBase { explicit ResolveRewriter(const ResolveVisitor& visitor, const TypeInferenceContext& inference_context, const CheckerOptions& options, - Ast::ReferenceMap& references, Ast::TypeMap& types) + Ast::ReferenceMap& references, Ast::TypeMap& types, + ValidationResult::TypeMap& resolved_types) : visitor_(visitor), inference_context_(inference_context), reference_map_(references), type_map_(types), + resolved_types_(resolved_types), options_(options) {} bool PostVisitRewrite(Expr& expr) override { bool rewritten = false; @@ -1235,6 +1238,7 @@ class ResolveRewriter : public AstRewriterBase { return rewritten; } type_map_[expr.id()] = *std::move(flattened_type); + resolved_types_[expr.id()] = iter->second; rewritten = true; } @@ -1249,23 +1253,28 @@ class ResolveRewriter : public AstRewriterBase { const TypeInferenceContext& inference_context_; Ast::ReferenceMap& reference_map_; Ast::TypeMap& type_map_; + ValidationResult::TypeMap& resolved_types_; const CheckerOptions& options_; }; } // namespace -absl::StatusOr TypeCheckerImpl::Check( - std::unique_ptr ast) const { - google::protobuf::Arena type_arena; +absl::StatusOr TypeCheckerImpl::CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* arena) const { + std::optional type_arena; + if (arena == nullptr) { + type_arena.emplace(); + arena = &(*type_arena); + } std::vector issues; CEL_ASSIGN_OR_RETURN(auto generator, NamespaceGenerator::Create(env_.container())); TypeInferenceContext type_inference_context( - &type_arena, options_.enable_legacy_null_assignment); + arena, options_.enable_legacy_null_assignment); ResolveVisitor visitor(std::move(generator), env_, *ast, - type_inference_context, issues, &type_arena); + type_inference_context, issues, arena); TraversalOptions opts; opts.use_comprehension_callbacks = true; @@ -1310,9 +1319,10 @@ absl::StatusOr TypeCheckerImpl::Check( // Apply updates as needed. // Happens in a second pass to simplify validating that pointers haven't // been invalidated by other updates. + ValidationResult::TypeMap resolved_types; ResolveRewriter rewriter(visitor, type_inference_context, options_, ast->mutable_reference_map(), - ast->mutable_type_map()); + ast->mutable_type_map(), resolved_types); AstRewrite(ast->mutable_root_expr(), rewriter); CEL_RETURN_IF_ERROR(rewriter.status()); @@ -1325,7 +1335,15 @@ absl::StatusOr TypeCheckerImpl::Check( {cel::ExtensionSpec::Component::kRuntime})); } - return ValidationResult(std::move(ast), std::move(issues)); + auto result = ValidationResult(std::move(ast), std::move(issues)); + if (!type_arena.has_value()) { + // cel::Type values will expire after this function returns when the local + // arena is destructed. Only set the resolved type map if we're using the + // caller's arena. + result.SetResolvedTypeMap(std::move(resolved_types)); + } + + return result; } std::unique_ptr TypeCheckerImpl::ToBuilder() const { diff --git a/checker/internal/type_checker_impl.h b/checker/internal/type_checker_impl.h index 71683276d..9ee9a50d0 100644 --- a/checker/internal/type_checker_impl.h +++ b/checker/internal/type_checker_impl.h @@ -42,8 +42,8 @@ class TypeCheckerImpl : public TypeChecker { TypeCheckerImpl(TypeCheckerImpl&&) = delete; TypeCheckerImpl& operator=(TypeCheckerImpl&&) = delete; - absl::StatusOr Check( - std::unique_ptr ast) const override; + absl::StatusOr CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* arena) const override; std::unique_ptr ToBuilder() const override; diff --git a/checker/type_checker.cc b/checker/type_checker.cc new file mode 100644 index 000000000..6d59e144d --- /dev/null +++ b/checker/type_checker.cc @@ -0,0 +1,36 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/type_checker.h" + +namespace cel { +absl::StatusOr TypeChecker::Check( + std::unique_ptr ast) const { + return CheckImpl(std::move(ast), nullptr); +} + +absl::StatusOr TypeChecker::Check( + std::unique_ptr ast, google::protobuf::Arena* arena) const { + return CheckImpl(std::move(ast), arena); +} + +absl::StatusOr TypeChecker::Check(const Ast& ast) const { + return CheckImpl(std::make_unique(ast), nullptr); +} + +absl::StatusOr TypeChecker::Check( + const Ast& ast, google::protobuf::Arena* arena) const { + return CheckImpl(std::make_unique(ast), arena); +} +} // namespace cel diff --git a/checker/type_checker.h b/checker/type_checker.h index e47b7dca6..edb6cc91f 100644 --- a/checker/type_checker.h +++ b/checker/type_checker.h @@ -16,10 +16,13 @@ #define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_H_ #include +#include +#include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "checker/validation_result.h" #include "common/ast.h" +#include "google/protobuf/arena.h" namespace cel { @@ -42,13 +45,19 @@ class TypeChecker { // A non-ok status is returned if type checking can't reasonably complete // (e.g. if an internal precondition is violated or an extension returns an // error). - virtual absl::StatusOr Check( - std::unique_ptr ast) const = 0; + absl::StatusOr Check(std::unique_ptr ast) const; + absl::StatusOr Check(std::unique_ptr ast, + google::protobuf::Arena* arena) const; + absl::StatusOr Check(const Ast& ast) const; + absl::StatusOr Check(const Ast& ast, + google::protobuf::Arena* arena) const; // Returns a builder initialized with the configuration of this type checker. virtual std::unique_ptr ToBuilder() const = 0; - // TODO(uncreated-issue/73): add overload for cref AST. + private: + virtual absl::StatusOr CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* absl_nullable arena) const = 0; }; } // namespace cel diff --git a/checker/validation_result.h b/checker/validation_result.h index 8c84a84da..f424e7f6f 100644 --- a/checker/validation_result.h +++ b/checker/validation_result.h @@ -15,26 +15,31 @@ #ifndef THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_ +#include #include #include #include #include #include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "checker/type_check_issue.h" #include "common/ast.h" #include "common/source.h" +#include "common/type.h" namespace cel { -// ValidationResult holds the result of TypeChecking. +// ValidationResult holds the result of type checking. // // Error states are captured as type check issues where possible. class ValidationResult { public: + using TypeMap = absl::flat_hash_map; + ValidationResult(std::unique_ptr ast, std::vector issues) : ast_(std::move(ast)), issues_(std::move(issues)) {} @@ -71,6 +76,18 @@ class ValidationResult { return std::move(source_); } + // Returns the resolved type map for the AST. + // + // Only populated if the AST was checked with an explicit arena. + // + // The type entries may have storage in the arena or reference type + // information from the type checker that produced the AST. This means the map + // is only valid as long as both the type checker and the arena are valid. + const TypeMap& GetResolvedTypeMap() const { return resolved_type_map_; } + void SetResolvedTypeMap(TypeMap resolved_type_map) { + resolved_type_map_ = std::move(resolved_type_map); + } + // Returns a string representation of the issues in the result suitable for // display. // @@ -89,6 +106,7 @@ class ValidationResult { private: absl_nullable std::unique_ptr ast_; + TypeMap resolved_type_map_; std::vector issues_; absl_nullable std::unique_ptr source_; }; diff --git a/compiler/BUILD b/compiler/BUILD index 170f1068b..50bc1c9fa 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -28,9 +28,11 @@ cc_library( "//parser:options", "//parser:parser_interface", "//validator", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", ], ) diff --git a/compiler/compiler.h b/compiler/compiler.h index 48fa4e0b1..6d07e72c2 100644 --- a/compiler/compiler.h +++ b/compiler/compiler.h @@ -19,6 +19,7 @@ #include #include +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -29,6 +30,7 @@ #include "parser/options.h" #include "parser/parser_interface.h" #include "validator/validator.h" +#include "google/protobuf/arena.h" namespace cel { @@ -126,10 +128,16 @@ class Compiler { virtual ~Compiler() = default; virtual absl::StatusOr Compile( - absl::string_view source, absl::string_view description) const = 0; + absl::string_view source, absl::string_view description, + google::protobuf::Arena* absl_nullable arena) const = 0; absl::StatusOr Compile(absl::string_view source) const { - return Compile(source, ""); + return Compile(source, "", nullptr); + } + + absl::StatusOr Compile( + absl::string_view source, absl::string_view description) const { + return Compile(source, description, nullptr); } // Accessor for the underlying type checker. diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc index 3e9871706..14586825e 100644 --- a/compiler/compiler_factory.cc +++ b/compiler/compiler_factory.cc @@ -33,6 +33,7 @@ #include "parser/parser.h" #include "parser/parser_interface.h" #include "validator/validator.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { @@ -50,13 +51,13 @@ class CompilerImpl : public Compiler { validator_(std::move(validator)) {} absl::StatusOr Compile( - absl::string_view expression, - absl::string_view description) const override { + absl::string_view expression, absl::string_view description, + google::protobuf::Arena* arena) const override { CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expression, std::string(description))); CEL_ASSIGN_OR_RETURN(auto ast, parser_->Parse(*source)); CEL_ASSIGN_OR_RETURN(ValidationResult result, - type_checker_->Check(std::move(ast))); + type_checker_->Check(std::move(ast), arena)); result.SetSource(std::move(source)); if (!validator_.validations().empty()) { diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc index d217e4cc7..214c23765 100644 --- a/compiler/compiler_factory_test.cc +++ b/compiler/compiler_factory_test.cc @@ -37,6 +37,7 @@ #include "parser/parser_interface.h" #include "testutil/baseline_tests.h" #include "validator/timestamp_literal_validator.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { @@ -390,5 +391,27 @@ TEST(CompilerFactoryTest, ToBuilderWorks) { EXPECT_TRUE(result.IsValid()); } +TEST(CompilerFactoryTest, SpecifyArenaKeepsResolvedTypes) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(OptionalCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("[[1, 2, 3]][?0]", "", &arena)); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + auto it = result.GetResolvedTypeMap().find(ast->root_expr().id()); + ASSERT_TRUE(it != result.GetResolvedTypeMap().end()); + EXPECT_TRUE( + it->second.IsOptional() && + it->second.GetOptional().GetParameter().IsList() && + it->second.GetOptional().GetParameter().GetList().GetElement().IsInt()); +} + } // namespace } // namespace cel From 9e73d93f77a159a149edd9a465d092cff03d702a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 30 Apr 2026 13:04:30 -0700 Subject: [PATCH 054/143] Deprecate option to switch to old accumulator var name in parser. cel::ParserOption::enable_hidden_accumulator_var now has no effect and will be removed in a later change. The standard / extension macros should always use `@result` now. PiperOrigin-RevId: 908333116 --- common/expr.h | 4 +- common/expr_factory.h | 1 - extensions/comprehensions_v2_macros.cc | 72 +++---- parser/macro.cc | 44 ++-- parser/macro_expr_factory.h | 3 +- parser/macro_expr_factory_test.cc | 2 +- parser/options.h | 4 +- parser/parser.cc | 19 +- parser/parser_test.cc | 270 ------------------------- 9 files changed, 72 insertions(+), 347 deletions(-) diff --git a/common/expr.h b/common/expr.h index 9c6f508c6..7305c2c9f 100644 --- a/common/expr.h +++ b/common/expr.h @@ -45,7 +45,9 @@ class MapExprEntry; class MapExpr; class ComprehensionExpr; -inline constexpr absl::string_view kAccumulatorVariableName = "__result__"; +inline constexpr absl::string_view kAccumulatorVariableName = "@result"; +inline constexpr absl::string_view kDeprecatedAccumulatorVariableName = + "__result__"; bool operator==(const Expr& lhs, const Expr& rhs); diff --git a/common/expr_factory.h b/common/expr_factory.h index c8a9b831f..b9769b457 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -357,7 +357,6 @@ class ExprFactory { friend class ParserMacroExprFactory; ExprFactory() : accu_var_(kAccumulatorVariableName) {} - explicit ExprFactory(absl::string_view accu_var) : accu_var_(accu_var) {} std::string accu_var_; }; diff --git a/extensions/comprehensions_v2_macros.cc b/extensions/comprehensions_v2_macros.cc index a8de3a103..134fb80ff 100644 --- a/extensions/comprehensions_v2_macros.cc +++ b/extensions/comprehensions_v2_macros.cc @@ -56,15 +56,15 @@ absl::optional ExpandAllMacro2(MacroExprFactory& factory, Expr& target, args[0], "all() second variable must be different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("all() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("all() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(true); auto condition = @@ -102,15 +102,15 @@ absl::optional ExpandExistsMacro2(MacroExprFactory& factory, Expr& target, args[0], "exists() second variable must be different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("exists() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(false); auto condition = factory.NewCall( @@ -153,15 +153,15 @@ absl::optional ExpandExistsOneMacro2(MacroExprFactory& factory, "existsOne() second variable must be different " "from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("existsOne() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("existsOne() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewIntConst(0); auto condition = factory.NewBoolConst(true); @@ -205,15 +205,15 @@ absl::optional ExpandTransformList3Macro(MacroExprFactory& factory, "transformList() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformList() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformList() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -254,15 +254,15 @@ absl::optional ExpandTransformList4Macro(MacroExprFactory& factory, "transformList() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformList() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformList() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -305,15 +305,15 @@ absl::optional ExpandTransformMap3Macro(MacroExprFactory& factory, "transformMap() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMap() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMap() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -353,15 +353,15 @@ absl::optional ExpandTransformMap4Macro(MacroExprFactory& factory, "transformMap() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMap() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMap() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -403,17 +403,17 @@ absl::optional ExpandTransformMapEntry3Macro(MacroExprFactory& factory, "transformMapEntry() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMapEntry() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMapEntry() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -453,17 +453,17 @@ absl::optional ExpandTransformMapEntry4Macro(MacroExprFactory& factory, "transformMapEntry() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMapEntry() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMapEntry() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); diff --git a/parser/macro.cc b/parser/macro.cc index eaa1ebd1a..8f8c9e596 100644 --- a/parser/macro.cc +++ b/parser/macro.cc @@ -91,10 +91,10 @@ absl::optional ExpandAllMacro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "all() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { - return factory.ReportErrorAt(args[1], - absl::StrCat("all() variable name cannot be ", - kAccumulatorVariableName)); + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("all() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(true); auto condition = @@ -123,10 +123,10 @@ absl::optional ExpandExistsMacro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "exists() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(false); auto condition = factory.NewCall( @@ -157,10 +157,10 @@ absl::optional ExpandExistsOneMacro(MacroExprFactory& factory, return factory.ReportErrorAt( args[0], "exists_one() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists_one() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewIntConst(0); auto condition = factory.NewBoolConst(true); @@ -196,10 +196,10 @@ absl::optional ExpandMap2Macro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { - return factory.ReportErrorAt(args[1], - absl::StrCat("map() variable name cannot be ", - kAccumulatorVariableName)); + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("map() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewList(); auto condition = factory.NewBoolConst(true); @@ -229,10 +229,10 @@ absl::optional ExpandMap3Macro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { - return factory.ReportErrorAt(args[1], - absl::StrCat("map() variable name cannot be ", - kAccumulatorVariableName)); + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("map() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewList(); auto condition = factory.NewBoolConst(true); @@ -264,10 +264,10 @@ absl::optional ExpandFilterMacro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "filter() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("filter() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto name = args[0].ident_expr().name(); @@ -302,10 +302,10 @@ absl::optional ExpandOptMapMacro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "optMap() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("optMap() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto var_name = args[0].ident_expr().name(); @@ -341,10 +341,10 @@ absl::optional ExpandOptFlatMapMacro(MacroExprFactory& factory, return factory.ReportErrorAt( args[0], "optFlatMap() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("optFlatMap() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto var_name = args[0].ident_expr().name(); diff --git a/parser/macro_expr_factory.h b/parser/macro_expr_factory.h index ffba5e2f2..c66aa4fe0 100644 --- a/parser/macro_expr_factory.h +++ b/parser/macro_expr_factory.h @@ -319,8 +319,7 @@ class MacroExprFactory : protected ExprFactory { friend class ParserMacroExprFactory; friend class TestMacroExprFactory; - explicit MacroExprFactory(absl::string_view accu_var) - : ExprFactory(accu_var) {} + explicit MacroExprFactory() = default; }; } // namespace cel diff --git a/parser/macro_expr_factory_test.cc b/parser/macro_expr_factory_test.cc index 04705eec6..489538be1 100644 --- a/parser/macro_expr_factory_test.cc +++ b/parser/macro_expr_factory_test.cc @@ -27,7 +27,7 @@ namespace cel { class TestMacroExprFactory final : public MacroExprFactory { public: - TestMacroExprFactory() : MacroExprFactory(kAccumulatorVariableName) {} + TestMacroExprFactory() = default; ExprId id() const { return id_; } diff --git a/parser/options.h b/parser/options.h index a41d16104..916a941f0 100644 --- a/parser/options.h +++ b/parser/options.h @@ -51,7 +51,9 @@ struct ParserOptions final { // Disable standard macros (has, all, exists, exists_one, filter, map). bool disable_standard_macros = false; - // Enable hidden accumulator variable '@result' for builtin comprehensions. + // Deprecated: The builtin and extension macros now always use the new + // accumulator variable name. + // This option has no effect. bool enable_hidden_accumulator_var = true; // Enables support for identifier quoting syntax: diff --git a/parser/parser.cc b/parser/parser.cc index d9f74e712..f4ee3a1c5 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -163,9 +163,8 @@ SourceRange SourceRangeFromParserRuleContext( class ParserMacroExprFactory final : public MacroExprFactory { public: - explicit ParserMacroExprFactory(const cel::Source& source, - absl::string_view accu_var) - : MacroExprFactory(accu_var), source_(source) {} + explicit ParserMacroExprFactory(const cel::Source& source) + : source_(source) {} void BeginMacro(SourceRange macro_position) { macro_position_ = macro_position; @@ -607,13 +606,12 @@ class ParserVisitor final : public CelBaseVisitor, public antlr4::BaseErrorListener { public: ParserVisitor(const cel::Source& source, int max_recursion_depth, - absl::string_view accu_var, const cel::MacroRegistry& macro_registry, bool add_macro_calls = false, bool enable_optional_syntax = false, bool enable_quoted_identifiers = false) : source_(source), - factory_(source_, accu_var), + factory_(source_), macro_registry_(macro_registry), recursion_depth_(0), max_recursion_depth_(max_recursion_depth), @@ -1654,14 +1652,9 @@ absl::StatusOr ParseImpl(const cel::Source& source, CommonTokenStream tokens(&lexer); CelParser parser(&tokens); ExprRecursionListener listener(options.max_recursion_depth); - absl::string_view accu_var = cel::kAccumulatorVariableName; - if (options.enable_hidden_accumulator_var) { - accu_var = cel::kHiddenAccumulatorVariableName; - } - ParserVisitor visitor(source, options.max_recursion_depth, accu_var, - registry, options.add_macro_calls, - options.enable_optional_syntax, - options.enable_quoted_identifiers); + ParserVisitor visitor( + source, options.max_recursion_depth, registry, options.add_macro_calls, + options.enable_optional_syntax, options.enable_quoted_identifiers); lexer.removeErrorListeners(); parser.removeErrorListeners(); diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 3659fd8fd..a1a65481d 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1473,7 +1473,6 @@ class ExpressionTest : public testing::TestWithParam {}; TEST_P(ExpressionTest, Parse) { const TestInfo& test_info = GetParam(); ParserOptions options; - options.enable_hidden_accumulator_var = true; if (!test_info.M.empty()) { options.add_macro_calls = true; } @@ -1628,271 +1627,6 @@ TEST(ExpressionTest, RecursionDepthIgnoresParentheses) { EXPECT_THAT(result, IsOk()); } -const std::vector& UpdatedAccuVarTestCases() { - static const std::vector* kInstance = new std::vector{ - {"[].exists(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " false^#7:bool#,\n" - " // LoopCondition\n" - " @not_strictly_false(\n" - " !_(\n" - " __result__^#8:Expr.Ident#\n" - " )^#9:Expr.Call#\n" - " )^#10:Expr.Call#,\n" - " // LoopStep\n" - " _||_(\n" - " __result__^#11:Expr.Ident#,\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#\n" - " )^#12:Expr.Call#,\n" - " // Result\n" - " __result__^#13:Expr.Ident#)^#14:Expr.Comprehension#"}, - {"[].exists_one(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " 0^#7:int64#,\n" - " // LoopCondition\n" - " true^#8:bool#,\n" - " // LoopStep\n" - " _?_:_(\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#,\n" - " _+_(\n" - " __result__^#9:Expr.Ident#,\n" - " 1^#10:int64#\n" - " )^#11:Expr.Call#,\n" - " __result__^#12:Expr.Ident#\n" - " )^#13:Expr.Call#,\n" - " // Result\n" - " _==_(\n" - " __result__^#14:Expr.Ident#,\n" - " 1^#15:int64#\n" - " )^#16:Expr.Call#)^#17:Expr.Comprehension#"}, - {"[].all(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " true^#7:bool#,\n" - " // LoopCondition\n" - " @not_strictly_false(\n" - " __result__^#8:Expr.Ident#\n" - " )^#9:Expr.Call#,\n" - " // LoopStep\n" - " _&&_(\n" - " __result__^#10:Expr.Ident#,\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#\n" - " )^#11:Expr.Call#,\n" - " // Result\n" - " __result__^#12:Expr.Ident#)^#13:Expr.Comprehension#"}, - {"[].map(x, x + 1)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " []^#7:Expr.CreateList#,\n" - " // LoopCondition\n" - " true^#8:bool#,\n" - " // LoopStep\n" - " _+_(\n" - " __result__^#9:Expr.Ident#,\n" - " [\n" - " _+_(\n" - " x^#4:Expr.Ident#,\n" - " 1^#6:int64#\n" - " )^#5:Expr.Call#\n" - " ]^#10:Expr.CreateList#\n" - " )^#11:Expr.Call#,\n" - " // Result\n" - " __result__^#12:Expr.Ident#)^#13:Expr.Comprehension#"}, - {"[].map(x, x > 0, x + 1)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " []^#10:Expr.CreateList#,\n" - " // LoopCondition\n" - " true^#11:bool#,\n" - " // LoopStep\n" - " _?_:_(\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#,\n" - " _+_(\n" - " __result__^#12:Expr.Ident#,\n" - " [\n" - " _+_(\n" - " x^#7:Expr.Ident#,\n" - " 1^#9:int64#\n" - " )^#8:Expr.Call#\n" - " ]^#13:Expr.CreateList#\n" - " )^#14:Expr.Call#,\n" - " __result__^#15:Expr.Ident#\n" - " )^#16:Expr.Call#,\n" - " // Result\n" - " __result__^#17:Expr.Ident#)^#18:Expr.Comprehension#"}, - {"[].filter(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " []^#7:Expr.CreateList#,\n" - " // LoopCondition\n" - " true^#8:bool#,\n" - " // LoopStep\n" - " _?_:_(\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#,\n" - " _+_(\n" - " __result__^#9:Expr.Ident#,\n" - " [\n" - " x^#3:Expr.Ident#\n" - " ]^#10:Expr.CreateList#\n" - " )^#11:Expr.Call#,\n" - " __result__^#12:Expr.Ident#\n" - " )^#13:Expr.Call#,\n" - " // Result\n" - " __result__^#14:Expr.Ident#)^#15:Expr.Comprehension#"}, - // Maintain restriction on '__result__' variable name until the default is - // changed everywhere. - { - "[].map(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:20: map() variable name cannot be __result__\n" - " | [].map(__result__, true)\n" - " | ...................^", - }, - { - "[].map(__result__, true, false)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:20: map() variable name cannot be __result__\n" - " | [].map(__result__, true, false)\n" - " | ...................^", - }, - { - "[].filter(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:23: filter() variable name cannot be __result__\n" - " | [].filter(__result__, true)\n" - " | ......................^", - }, - { - "[].exists(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:23: exists() variable name cannot be __result__\n" - " | [].exists(__result__, true)\n" - " | ......................^", - }, - { - "[].all(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:20: all() variable name cannot be __result__\n" - " | [].all(__result__, true)\n" - " | ...................^", - }, - { - "[].exists_one(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:27: exists_one() variable name cannot be " - "__result__\n" - " | [].exists_one(__result__, true)\n" - " | ..........................^", - }}; - return *kInstance; -} - -class UpdatedAccuVarDisabledTest : public testing::TestWithParam {}; - -TEST_P(UpdatedAccuVarDisabledTest, Parse) { - const TestInfo& test_info = GetParam(); - ParserOptions options; - options.enable_hidden_accumulator_var = false; - if (!test_info.M.empty()) { - options.add_macro_calls = true; - } - - auto result = - EnrichedParse(test_info.I, Macro::AllMacros(), "", options); - if (test_info.E.empty()) { - EXPECT_THAT(result, IsOk()); - } else { - EXPECT_THAT(result, Not(IsOk())); - EXPECT_EQ(test_info.E, result.status().message()); - } - - if (!test_info.P.empty()) { - KindAndIdAdorner kind_and_id_adorner; - ExprPrinter w(kind_and_id_adorner); - std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.P, adorned_string) - << result->parsed_expr().ShortDebugString(); - } - - if (!test_info.L.empty()) { - LocationAdorner location_adorner(result->parsed_expr().source_info()); - ExprPrinter w(location_adorner); - std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.L, adorned_string) - << result->parsed_expr().ShortDebugString(); - } - - if (!test_info.R.empty()) { - EXPECT_EQ(test_info.R, ConvertEnrichedSourceInfoToString( - result->enriched_source_info())); - } - - if (!test_info.M.empty()) { - EXPECT_EQ(test_info.M, ConvertMacroCallsToString( - result.value().parsed_expr().source_info())) - << result->parsed_expr().ShortDebugString(); - } -} - TEST(NewParserBuilderTest, Defaults) { auto builder = cel::NewParserBuilder(); ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); @@ -2058,9 +1792,5 @@ std::string TestName(const testing::TestParamInfo& test_info) { INSTANTIATE_TEST_SUITE_P(CelParserTest, ExpressionTest, testing::ValuesIn(test_cases), TestName); -INSTANTIATE_TEST_SUITE_P(UpdatedAccuVarTest, UpdatedAccuVarDisabledTest, - testing::ValuesIn(UpdatedAccuVarTestCases()), - TestName); - } // namespace } // namespace google::api::expr::parser From 1cf21eec91baa4181b481ff4bf6e25b9b5e9afe9 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 4 May 2026 10:35:24 -0700 Subject: [PATCH 055/143] Optionally track structured error messages in parser. Add a parse overload that reports issues with location to an output parameter where reasonable. This is used make the error handling more consistent when using bundled parse + typecheck. PiperOrigin-RevId: 910112013 --- compiler/BUILD | 2 + compiler/compiler.h | 2 + compiler/compiler_factory.cc | 49 +++++++++++++++----- compiler/compiler_factory_test.cc | 14 ++++++ extensions/math_ext_test.cc | 38 +++++----------- parser/BUILD | 2 + parser/parser.cc | 74 +++++++++++++++++++++---------- parser/parser_interface.h | 50 ++++++++++++++++++++- parser/parser_test.cc | 19 ++++++++ 9 files changed, 186 insertions(+), 64 deletions(-) diff --git a/compiler/BUILD b/compiler/BUILD index 50bc1c9fa..d4a0ab4ac 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -42,10 +42,12 @@ cc_library( hdrs = ["compiler_factory.h"], deps = [ ":compiler", + "//checker:type_check_issue", "//checker:type_checker", "//checker:type_checker_builder", "//checker:type_checker_builder_factory", "//checker:validation_result", + "//common:ast", "//common:source", "//internal:noop_delete", "//internal:status_macros", diff --git a/compiler/compiler.h b/compiler/compiler.h index 6d07e72c2..27237df60 100644 --- a/compiler/compiler.h +++ b/compiler/compiler.h @@ -97,6 +97,8 @@ struct CompilerLibrarySubset { struct CompilerOptions { ParserOptions parser_options; CheckerOptions checker_options; + // If true, parse errors will be adapted to issues where possible. + bool adapt_parser_errors = false; }; // Interface for CEL CompilerBuilder objects. diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc index 14586825e..ed22c5630 100644 --- a/compiler/compiler_factory.cc +++ b/compiler/compiler_factory.cc @@ -17,16 +17,19 @@ #include #include #include +#include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" #include "checker/type_checker_builder_factory.h" #include "checker/validation_result.h" +#include "common/ast.h" #include "common/source.h" #include "compiler/compiler.h" #include "internal/status_macros.h" @@ -45,19 +48,38 @@ class CompilerImpl : public Compiler { CompilerImpl(std::unique_ptr type_checker, std::unique_ptr parser, // Copy the validator in case builder is reused. - Validator validator) + Validator validator, CompilerOptions options) : type_checker_(std::move(type_checker)), parser_(std::move(parser)), - validator_(std::move(validator)) {} + validator_(std::move(validator)), + options_(options) {} absl::StatusOr Compile( absl::string_view expression, absl::string_view description, google::protobuf::Arena* arena) const override { CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expression, std::string(description))); - CEL_ASSIGN_OR_RETURN(auto ast, parser_->Parse(*source)); + std::vector parse_issues; + absl::StatusOr> ast = + parser_->Parse(*source, &parse_issues); + if (!ast.ok()) { + if (!options_.adapt_parser_errors || + ast.status().code() != absl::StatusCode::kInvalidArgument || + parse_issues.empty()) { + return ast.status(); + } + std::vector check_issues; + check_issues.reserve(parse_issues.size()); + for (const auto& issue : parse_issues) { + check_issues.push_back(TypeCheckIssue::CreateError( + issue.location(), std::string(issue.message()))); + } + ValidationResult result(std::move(check_issues)); + result.SetSource(std::move(source)); + return result; + } CEL_ASSIGN_OR_RETURN(ValidationResult result, - type_checker_->Check(std::move(ast), arena)); + type_checker_->Check(*std::move(ast), arena)); result.SetSource(std::move(source)); if (!validator_.validations().empty()) { @@ -76,16 +98,18 @@ class CompilerImpl : public Compiler { std::unique_ptr type_checker_; std::unique_ptr parser_; Validator validator_; + CompilerOptions options_; }; class CompilerBuilderImpl : public CompilerBuilder { public: CompilerBuilderImpl(std::unique_ptr type_checker_builder, std::unique_ptr parser_builder, - Validator validator = Validator()) + Validator validator, CompilerOptions options) : type_checker_builder_(std::move(type_checker_builder)), parser_builder_(std::move(parser_builder)), - validator_(std::move(validator)) {} + validator_(std::move(validator)), + options_(options) {} absl::Status AddLibrary(CompilerLibrary library) override { if (!library.id.empty()) { @@ -146,23 +170,23 @@ class CompilerBuilderImpl : public CompilerBuilder { absl::StatusOr> Build() override { CEL_ASSIGN_OR_RETURN(auto parser, parser_builder_->Build()); CEL_ASSIGN_OR_RETURN(auto type_checker, type_checker_builder_->Build()); - return std::make_unique(std::move(type_checker), - std::move(parser), validator_); + return std::make_unique( + std::move(type_checker), std::move(parser), validator_, options_); } private: std::unique_ptr type_checker_builder_; std::unique_ptr parser_builder_; Validator validator_; + CompilerOptions options_; absl::flat_hash_set library_ids_; absl::flat_hash_set subsets_; }; std::unique_ptr CompilerImpl::ToBuilder() const { - auto builder = std::make_unique( - type_checker_->ToBuilder(), parser_->ToBuilder(), validator_); - return builder; + return std::make_unique( + type_checker_->ToBuilder(), parser_->ToBuilder(), validator_, options_); } } // namespace @@ -179,7 +203,8 @@ absl::StatusOr> NewCompilerBuilder( auto parser_builder = NewParserBuilder(options.parser_options); return std::make_unique(std::move(type_checker_builder), - std::move(parser_builder)); + std::move(parser_builder), + Validator(), options); } } // namespace cel diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc index 214c23765..035fd8aa6 100644 --- a/compiler/compiler_factory_test.cc +++ b/compiler/compiler_factory_test.cc @@ -413,5 +413,19 @@ TEST(CompilerFactoryTest, SpecifyArenaKeepsResolvedTypes) { it->second.GetOptional().GetParameter().GetList().GetElement().IsInt()); } +TEST(CompilerFactoryTest, ReturnsIssuesFromParser) { + CompilerOptions opts; + opts.adapt_parser_errors = true; + ASSERT_OK_AND_ASSIGN( + auto builder, NewCompilerBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("a +")); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), testing::Not(testing::IsEmpty())); +} + } // namespace } // namespace cel diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc index 3088e6fa8..72605648f 100644 --- a/extensions/math_ext_test.cc +++ b/extensions/math_ext_test.cc @@ -23,7 +23,6 @@ #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -110,19 +109,6 @@ struct MacroTestCase { absl::string_view err = ""; }; -std::string FormatIssues(const cel::ValidationResult& result) { - std::string issues; - for (const auto& issue : result.GetIssues()) { - if (!issues.empty()) { - absl::StrAppend(&issues, "\n", - issue.ToDisplayString(*result.GetSource())); - } else { - issues = issue.ToDisplayString(*result.GetSource()); - } - } - return issues; -} - class TestFunction : public CelFunction { public: explicit TestFunction(absl::string_view name) @@ -352,10 +338,11 @@ TEST_P(MathExtMacroParamsTest, ParserTests) { TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { const MacroTestCase& test_case = GetParam(); - - ASSERT_OK_AND_ASSIGN( - auto compiler_builder, - cel::NewCompilerBuilder(internal::GetTestingDescriptorPool())); + CompilerOptions compile_opts; + compile_opts.adapt_parser_errors = true; + ASSERT_OK_AND_ASSIGN(auto compiler_builder, + cel::NewCompilerBuilder( + internal::GetTestingDescriptorPool(), compile_opts)); ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_THAT(compiler_builder->AddLibrary(MathCompilerLibrary()), IsOk()); @@ -381,16 +368,16 @@ TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*compiler_builder).Build()); - auto result = compiler->Compile(test_case.expr, ""); + ASSERT_OK_AND_ASSIGN(auto result, + compiler->Compile(test_case.expr, "")); if (!test_case.err.empty()) { - EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr(test_case.err))); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.err)); return; } - ASSERT_THAT(result, IsOk()); - ASSERT_TRUE(result->IsValid()) << FormatIssues(*result); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); RuntimeOptions opts; ASSERT_OK_AND_ASSIGN( @@ -411,9 +398,8 @@ TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); - - ASSERT_OK_AND_ASSIGN(auto program, - runtime->CreateProgram(*result->ReleaseAst())); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); google::protobuf::Arena arena; cel::Activation activation; diff --git a/parser/BUILD b/parser/BUILD index 63813bb59..6650d9fe9 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -244,9 +244,11 @@ cc_library( ":options", "//common:ast", "//common:source", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) diff --git a/parser/parser.cc b/parser/parser.cc index f4ee3a1c5..709e2fd41 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -112,13 +112,12 @@ struct ParserError { }; std::string DisplayParserError(const cel::Source& source, - const ParserError& error) { - auto location = - source.GetLocation(error.range.begin).value_or(SourceLocation{}); + SourceLocation location, + absl::string_view message) { return absl::StrCat(absl::StrFormat("ERROR: %s:%zu:%zu: %s", source.description(), location.line, // add one to the 0-based column - location.column + 1, error.message), + location.column + 1, message), source.DisplayErrorLocation(location)); } @@ -209,7 +208,7 @@ class ParserMacroExprFactory final : public MacroExprFactory { bool HasErrors() const { return error_count_ != 0; } - std::string ErrorMessage() { + std::vector CollectIssues() { // Errors are collected as they are encountered, not by their location // within the source. To have a more stable error message as implementation // details change, we sort the collected errors by their source location @@ -226,20 +225,23 @@ class ParserMacroExprFactory final : public MacroExprFactory { }); // Build the summary error message using the sorted errors. bool errors_truncated = error_count_ > 100; - std::vector messages; - messages.reserve( + std::vector issues; + issues.reserve( errors_.size() + errors_truncated); // Reserve space for the transform and an // additional element when truncation occurs. - std::transform(errors_.begin(), errors_.end(), std::back_inserter(messages), - [this](const ParserError& error) { - return cel::DisplayParserError(source_, error); - }); + std::transform( + errors_.begin(), errors_.end(), std::back_inserter(issues), + [this](const ParserError& error) { + auto location = + source_.GetLocation(error.range.begin).value_or(SourceLocation{}); + return cel::ParseIssue(location, error.message); + }); if (errors_truncated) { - messages.emplace_back( - absl::StrCat(error_count_ - 100, " more errors were truncated.")); + issues.push_back(cel::ParseIssue( + absl::StrCat(error_count_ - 100, " more errors were truncated."))); } - return absl::StrJoin(messages, "\n"); + return issues; } void AddMacroCall(int64_t macro_id, absl::string_view function, @@ -602,6 +604,15 @@ Expr ExpressionBalancer::BalancedTree(int lo, int hi) { return factory_.NewCall(ops_[mid], function_, std::move(arguments)); } +std::string FormatIssues(const cel::Source& source, + absl::Span issues) { + return absl::StrJoin( + issues, "\n", [&source](std::string* out, const cel::ParseIssue& issue) { + absl::StrAppend(out, cel::DisplayParserError(source, issue.location(), + issue.message())); + }); +} + class ParserVisitor final : public CelBaseVisitor, public antlr4::BaseErrorListener { public: @@ -673,7 +684,7 @@ class ParserVisitor final : public CelBaseVisitor, const std::string& msg, std::exception_ptr e) override; bool HasErrored() const; - std::string ErrorMessage(); + std::vector CollectIssues(); private: template @@ -1434,7 +1445,9 @@ void ParserVisitor::syntaxError(antlr4::Recognizer* recognizer, bool ParserVisitor::HasErrored() const { return factory_.HasErrors(); } -std::string ParserVisitor::ErrorMessage() { return factory_.ErrorMessage(); } +std::vector ParserVisitor::CollectIssues() { + return factory_.CollectIssues(); +} Expr ParserVisitor::GlobalCallOrMacroImpl(int64_t expr_id, absl::string_view function, @@ -1638,9 +1651,10 @@ struct ParseResult { EnrichedSourceInfo enriched_source_info; }; -absl::StatusOr ParseImpl(const cel::Source& source, - const cel::MacroRegistry& registry, - const ParserOptions& options) { +absl::StatusOr ParseImpl( + const cel::Source& source, const cel::MacroRegistry& registry, + const ParserOptions& options, + std::vector* parse_issues = nullptr) { try { CodePointStream input(source.content(), source.description()); if (input.size() > options.expression_size_codepoint_limit) { @@ -1673,13 +1687,23 @@ absl::StatusOr ParseImpl(const cel::Source& source, expr = ExprFromAny(visitor.visit(parser.start())); } catch (const ParseCancellationException& e) { if (visitor.HasErrored()) { - return absl::InvalidArgumentError(visitor.ErrorMessage()); + auto issues = visitor.CollectIssues(); + std::string error_message = FormatIssues(source, issues); + if (parse_issues != nullptr) { + *parse_issues = std::move(issues); + } + return absl::InvalidArgumentError(error_message); } return absl::CancelledError(e.what()); } if (visitor.HasErrored()) { - return absl::InvalidArgumentError(visitor.ErrorMessage()); + auto issues = visitor.CollectIssues(); + std::string error_message = FormatIssues(source, issues); + if (parse_issues != nullptr) { + *parse_issues = std::move(issues); + } + return absl::InvalidArgumentError(error_message); } return { @@ -1706,10 +1730,12 @@ class ParserImpl : public cel::Parser { macro_registry_(std::move(macro_registry)), library_ids_(std::move(library_ids)) {} - absl::StatusOr> Parse( - const cel::Source& source) const override { + absl::StatusOr> ParseImpl( + const cel::Source& source, + std::vector* parse_issues) const override { CEL_ASSIGN_OR_RETURN(auto parse_result, - ParseImpl(source, macro_registry_, options_)); + ::google::api::expr::parser::ParseImpl( + source, macro_registry_, options_, parse_issues)); return std::make_unique(std::move(parse_result.expr), std::move(parse_result.source_info)); } diff --git a/parser/parser_interface.h b/parser/parser_interface.h index 7cc21ff26..ad6e8ca84 100644 --- a/parser/parser_interface.h +++ b/parser/parser_interface.h @@ -16,10 +16,14 @@ #include #include +#include +#include +#include "absl/base/nullability.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "common/ast.h" #include "common/source.h" #include "parser/macro.h" @@ -73,6 +77,26 @@ class ParserBuilder { virtual absl::StatusOr> Build() = 0; }; +// Information about a parse failure. +class ParseIssue { + public: + explicit ParseIssue(std::string message) : message_(std::move(message)) {} + ParseIssue(SourceLocation location, std::string message) + : location_(location), message_(std::move(message)) {} + + ParseIssue(const ParseIssue& other) = default; + ParseIssue& operator=(const ParseIssue& other) = default; + ParseIssue(ParseIssue&& other) = default; + ParseIssue& operator=(ParseIssue&& other) = default; + + SourceLocation location() const { return location_; } + absl::string_view message() const { return message_; } + + private: + SourceLocation location_; + std::string message_; +}; + // Interface for stateful CEL parser objects for use with a `Compiler` // (bundled parse and type check). This is not needed for most users: // prefer using the free functions in `parser.h` for more flexibility. @@ -81,13 +105,35 @@ class Parser { virtual ~Parser() = default; // Parses the given source into a CEL AST. - virtual absl::StatusOr> Parse( - const cel::Source& source) const = 0; + absl::StatusOr> Parse( + const cel::Source& source) const; + + // Parses the given source into a CEL AST, collecting parse errors in + // `issues`. If `issues` is non-null, it will be cleared and all parse + // issues will be appended to it. + absl::StatusOr> Parse( + const cel::Source& source, std::vector* issues) const; // Returns a builder initialized with the configuration of this parser. virtual std::unique_ptr ToBuilder() const = 0; + + protected: + virtual absl::StatusOr> ParseImpl( + const cel::Source& source, + std::vector* absl_nullable parse_issues) const = 0; }; +inline absl::StatusOr> Parser::Parse( + const cel::Source& source) const { + return ParseImpl(source, nullptr); +} + +inline absl::StatusOr> Parser::Parse( + const cel::Source& source, std::vector* issues) const { + if (issues != nullptr) issues->clear(); + return ParseImpl(source, issues); +} + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ diff --git a/parser/parser_test.cc b/parser/parser_test.cc index a1a65481d..587b63a30 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1782,6 +1782,25 @@ TEST(NewParserBuilderTest, ToBuilderPreservesStdlibAndOptionalFromOptions) { EXPECT_FALSE(ast->IsChecked()); } +TEST(ParserTest, ParseFailurePopulatesIssues) { + auto builder = cel::NewParserBuilder(); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a +", "test.cel")); + std::vector issues; + auto ast_result = parser->Parse(*source, &issues); + EXPECT_THAT(ast_result, Not(IsOk())); + ASSERT_THAT(issues, testing::SizeIs(1)); + EXPECT_THAT(ast_result.status().message(), + HasSubstr("ERROR: test.cel:1:4: Syntax error: mismatched input " + "'' expecting")); + EXPECT_THAT(issues[0].message(), + HasSubstr("Syntax error: mismatched input '' expecting")); + EXPECT_EQ(issues[0].location().line, 1); + // 0-based, but adjusted to 1-based in error message. + EXPECT_EQ(issues[0].location().column, 3); +} + std::string TestName(const testing::TestParamInfo& test_info) { std::string name = absl::StrCat(test_info.index, "-", test_info.param.I); absl::c_replace_if(name, [](char c) { return !absl::ascii_isalnum(c); }, '_'); From 2f06d90f5b593269c2b1f58de3bfd5c8fc2fa895 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 4 May 2026 12:32:22 -0700 Subject: [PATCH 056/143] Add checker support for block. This is needed for re-checking expressions that were produced as a part of policy compilation. PiperOrigin-RevId: 910179322 --- checker/internal/BUILD | 6 + checker/internal/type_checker_impl.cc | 77 ++++++++- checker/internal/type_checker_impl_test.cc | 95 +++++++++++ conformance/BUILD | 5 +- conformance/service.cc | 115 +------------- extensions/BUILD | 5 +- extensions/bindings_ext.cc | 32 +++- extensions/bindings_ext.h | 6 +- testutil/BUILD | 20 +++ testutil/test_macros.cc | 175 +++++++++++++++++++++ testutil/test_macros.h | 33 ++++ 11 files changed, 446 insertions(+), 123 deletions(-) create mode 100644 testutil/test_macros.cc create mode 100644 testutil/test_macros.h diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 1c560cdb9..f4c60f937 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -155,6 +155,7 @@ cc_library( "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -179,6 +180,7 @@ cc_test( "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", + "//common:ast_proto", "//common:container", "//common:decl", "//common:expr", @@ -187,13 +189,17 @@ cc_test( "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", + "//parser", + "//parser:macro_registry", "//testutil:baseline_tests", + "//testutil:test_macros", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 05601fdbb..2472d7def 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -25,6 +25,7 @@ #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -59,6 +60,15 @@ namespace cel::checker_internal { namespace { +bool MatchesBlock(const Expr& expr) { + if (!expr.has_call_expr()) { + return false; + } + const auto& call = expr.call_expr(); + return call.function() == "cel.@block" && call.args().size() == 2 && + call.args()[0].has_list_expr(); +} + using AstType = cel::TypeSpec; using Severity = TypeCheckIssue::Severity; @@ -204,13 +214,23 @@ class ResolveVisitor : public AstVisitorBase { arena_(arena), current_scope_(&root_scope_) {} - void PreVisitExpr(const Expr& expr) override { expr_stack_.push_back(&expr); } + void PreVisitExpr(const Expr& expr) override { + expr_stack_.push_back(&expr); + if (expr_stack_.size() == 1 && MatchesBlock(expr)) { + ABSL_DCHECK_EQ(expr.call_expr().args().size(), 2); + ABSL_DCHECK(block_init_list_ == nullptr); + block_init_list_ = &expr.call_expr().args()[0]; + } + } void PostVisitExpr(const Expr& expr) override { if (expr_stack_.empty()) { return; } expr_stack_.pop_back(); + if (expr_stack_.size() == 2 && expr_stack_.back() == block_init_list_) { + HandleBlockIndex(&expr); + } } void PostVisitConst(const Expr& expr, const Constant& constant) override; @@ -389,6 +409,7 @@ class ResolveVisitor : public AstVisitorBase { absl::string_view field_name); void HandleOptSelect(const Expr& expr); + void HandleBlockIndex(const Expr* expr); // Get the assigned type of the given subexpression. Should only be called if // the given subexpression is expected to have already been checked. @@ -421,6 +442,7 @@ class ResolveVisitor : public AstVisitorBase { std::vector expr_stack_; absl::flat_hash_map> maybe_namespaced_functions_; + const Expr* block_init_list_ = nullptr; // Select operations that need to be resolved outside of the traversal. // These are handled separately to disambiguate between namespaces and field // accesses @@ -609,8 +631,15 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { } void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) { - // Follows list type inferencing behavior in Go (see map comments above). + if (&expr == block_init_list_) { + // Don't try to coalesce list type here because it can influence the + // resolved type of the list elements. cel.@block is always list and + // the elements are treated independently at runtime. + types_[&expr] = ListType(); + return; + } + // Follows list type inferencing behavior in Go (see map comments above). Type overall_elem_type = inference_context_->InstantiateTypeParams(TypeParamType("E")); auto assignability_context = inference_context_->CreateAssignabilityContext(); @@ -1172,6 +1201,44 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) { } } +void ResolveVisitor::HandleBlockIndex(const Expr* expr) { + ABSL_DCHECK(block_init_list_ != nullptr); + ABSL_DCHECK(block_init_list_->has_list_expr()); + const auto& elements = block_init_list_->list_expr().elements(); + int index = -1; + for (size_t i = 0; i < elements.size(); ++i) { + if (&elements[i].expr() == expr) { + index = i; + break; + } + } + if (index < 0) { + status_.Update(absl::InternalError( + "could not resolve expression as a cel.@block subexpression")); + return; + } + std::string var_name = absl::StrCat("@index", index); + + // Block is typically manually assembled from logically separate + // expressions so fix the type instead of inferring any remaining free type + // params as for normal subexpressions. + auto type = inference_context_->FinalizeType(GetDeducedType(expr)); + + VariableDecl decl = MakeVariableDecl(var_name, std::move(type)); + + // The C++ runtime requires that the indexes are topologically ordered. + // They just come into scope in order as we walk the AST so we don't need + // to do any additional work to check references to other initializers in + // an init expr. + // + // TODO(uncreated-issue/90): This is slightly inconsistent with the java + // runtime implementation which just requires the references to be acyclic. + auto* scope = + comprehension_vars_.emplace_back(current_scope_->MakeNestedScope()).get(); + scope->InsertVariableIfAbsent(std::move(decl)); + current_scope_ = scope; +} + class ResolveRewriter : public AstRewriterBase { public: explicit ResolveRewriter(const ResolveVisitor& visitor, @@ -1230,15 +1297,15 @@ class ResolveRewriter : public AstRewriterBase { if (auto iter = visitor_.types().find(&expr); iter != visitor_.types().end()) { - auto flattened_type = - FlattenType(inference_context_.FinalizeType(iter->second)); + cel::Type finalized_type = inference_context_.FinalizeType(iter->second); + auto flattened_type = FlattenType(finalized_type); if (!flattened_type.ok()) { status_.Update(flattened_type.status()); return rewritten; } type_map_[expr.id()] = *std::move(flattened_type); - resolved_types_[expr.id()] = iter->second; + resolved_types_[expr.id()] = finalized_type; rewritten = true; } diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index e6cd641d6..893f0689d 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -26,6 +26,7 @@ #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -36,6 +37,7 @@ #include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" +#include "common/ast_proto.h" #include "common/container.h" #include "common/decl.h" #include "common/expr.h" @@ -45,7 +47,10 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" +#include "parser/macro_registry.h" +#include "parser/parser.h" #include "testutil/baseline_tests.h" +#include "testutil/test_macros.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" @@ -108,6 +113,17 @@ google::protobuf::Arena* absl_nonnull TestTypeArena() { return &(*kArena); } +absl::StatusOr> MakeTestParsedAstWithMacros( + absl::string_view expression, const cel::MacroRegistry& registry) { + CEL_ASSIGN_OR_RETURN( + auto source, + cel::NewSource(expression, /*description=*/std::string(expression))); + CEL_ASSIGN_OR_RETURN(auto parsed_expr, google::api::expr::parser::Parse( + *source, registry, + {.enable_optional_syntax = true})); + return cel::CreateAstFromParsedExpr(parsed_expr); +} + FunctionDecl MakeIdentFunction() { auto decl = MakeFunctionDecl( "identity", @@ -272,6 +288,12 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena /*return_type=*/TypeType(arena, TypeParamType("A")), TypeParamType("A")))); + Type kParam(TypeParamType("T")); + CEL_ASSIGN_OR_RETURN( + auto block_decl, + MakeFunctionDecl("cel.@block", MakeOverloadDecl("cel_block_list", kParam, + ListType(), kParam))); + env.InsertFunctionIfAbsent(std::move(not_op)); env.InsertFunctionIfAbsent(std::move(not_strictly_false)); env.InsertFunctionIfAbsent(std::move(add_op)); @@ -289,6 +311,7 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena env.InsertFunctionIfAbsent(std::move(to_type)); env.InsertFunctionIfAbsent(std::move(to_duration)); env.InsertFunctionIfAbsent(std::move(to_timestamp)); + env.InsertFunctionIfAbsent(std::move(block_decl)); return absl::OkStatus(); } @@ -308,6 +331,78 @@ TEST(TypeCheckerImplTest, SmokeTest) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } +TEST(TypeCheckerImplTest, BlockMacroSupport) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAstWithMacros( + "cel.block([1, 2], cel.index(0) + cel.index(1))", registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // Overall type should be int. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + auto root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(), + PrimitiveType::kInt64); +} + +TEST(TypeCheckerImplTest, BlockMacroSupportMixedTypes) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(1))", + registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // cel.index(1) refers to 'a' which is string. + // So overall type should be string. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + auto root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(), + PrimitiveType::kString); +} + +TEST(TypeCheckerImplTest, BadIndex) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(2))", + registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), + HasSubstr("undeclared reference to '@index2' (in container")); +} + TEST(TypeCheckerImplTest, SimpleIdentsResolved) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); diff --git a/conformance/BUILD b/conformance/BUILD index 139739891..9b527cf35 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -69,6 +69,7 @@ cc_library( "//runtime:reference_resolver", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", + "//testutil:test_macros", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -221,7 +222,7 @@ _TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [ "proto3/set_null/list_value", "proto3/set_null/single_struct", - # cel.@block + # no optional support for legacy types "block_ext/basic/optional_list", "block_ext/basic/optional_map", "block_ext/basic/optional_map_chained", @@ -231,7 +232,7 @@ _TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [ _TESTS_TO_SKIP_CHECKED = [ # block is a post-check optimization that inserts internal variables. The C++ type checker # needs support for a proper optimizer for this to work. - "block_ext", + # "block_ext", ] _TESTS_TO_SKIP_LEGACY_DASHBOARD = [ diff --git a/conformance/service.cc b/conformance/service.cc index 3edc214e6..463334bb5 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -14,7 +14,6 @@ #include "conformance/service.h" -#include #include #include #include @@ -36,11 +35,8 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" #include "checker/optional.h" #include "checker/standard_library.h" #include "checker/type_checker_builder.h" @@ -48,7 +44,6 @@ #include "common/ast.h" #include "common/ast_proto.h" #include "common/decl_proto_v1alpha1.h" -#include "common/expr.h" #include "common/internal/value_conversion.h" #include "common/source.h" #include "common/value.h" @@ -72,8 +67,6 @@ #include "extensions/select_optimization.h" #include "extensions/strings.h" #include "internal/status_macros.h" -#include "parser/macro.h" -#include "parser/macro_expr_factory.h" #include "parser/macro_registry.h" #include "parser/options.h" #include "parser/parser.h" @@ -85,6 +78,7 @@ #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" +#include "testutil/test_macros.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "cel/expr/conformance/proto2/test_all_types_extensions.pb.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" @@ -106,109 +100,6 @@ namespace google::api::expr::runtime { namespace { -bool IsCelNamespace(const cel::Expr& target) { - return target.has_ident_expr() && target.ident_expr().name() == "cel"; -} - -absl::optional CelBlockMacroExpander(cel::MacroExprFactory& factory, - cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& bindings_arg = args[0]; - if (!bindings_arg.has_list_expr()) { - return factory.ReportErrorAt( - bindings_arg, "cel.block requires the first arg to be a list literal"); - } - return factory.NewCall("cel.@block", args); -} - -absl::optional CelIndexMacroExpander(cel::MacroExprFactory& factory, - cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& index_arg = args[0]; - if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { - return factory.ReportErrorAt( - index_arg, "cel.index requires a single non-negative int constant arg"); - } - int64_t index = index_arg.const_expr().int_value(); - if (index < 0) { - return factory.ReportErrorAt( - index_arg, "cel.index requires a single non-negative int constant arg"); - } - return factory.NewIdent(absl::StrCat("@index", index)); -} - -absl::optional CelIterVarMacroExpander( - cel::MacroExprFactory& factory, cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& depth_arg = args[0]; - if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || - depth_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - depth_arg, "cel.iterVar requires two non-negative int constant args"); - } - cel::Expr& unique_arg = args[1]; - if (!unique_arg.has_const_expr() || - !unique_arg.const_expr().has_int_value() || - unique_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - unique_arg, "cel.iterVar requires two non-negative int constant args"); - } - return factory.NewIdent( - absl::StrCat("@it:", depth_arg.const_expr().int_value(), ":", - unique_arg.const_expr().int_value())); -} - -absl::optional CelAccuVarMacroExpander( - cel::MacroExprFactory& factory, cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& depth_arg = args[0]; - if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || - depth_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - depth_arg, "cel.accuVar requires two non-negative int constant args"); - } - cel::Expr& unique_arg = args[1]; - if (!unique_arg.has_const_expr() || - !unique_arg.const_expr().has_int_value() || - unique_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - unique_arg, "cel.accuVar requires two non-negative int constant args"); - } - return factory.NewIdent( - absl::StrCat("@ac:", depth_arg.const_expr().int_value(), ":", - unique_arg.const_expr().int_value())); -} - -absl::Status RegisterCelBlockMacros(cel::MacroRegistry& registry) { - CEL_ASSIGN_OR_RETURN(auto block_macro, - cel::Macro::Receiver("block", 2, CelBlockMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(block_macro)); - CEL_ASSIGN_OR_RETURN(auto index_macro, - cel::Macro::Receiver("index", 1, CelIndexMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(index_macro)); - CEL_ASSIGN_OR_RETURN( - auto iter_var_macro, - cel::Macro::Receiver("iterVar", 2, CelIterVarMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(iter_var_macro)); - CEL_ASSIGN_OR_RETURN( - auto accu_var_macro, - cel::Macro::Receiver("accuVar", 2, CelAccuVarMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(accu_var_macro)); - return absl::OkStatus(); -} - google::rpc::Code ToGrpcCode(absl::StatusCode code) { return static_cast(code); } @@ -250,7 +141,7 @@ absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, CEL_RETURN_IF_ERROR(cel::extensions::RegisterBindingsMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtoMacros(macros, options)); - CEL_RETURN_IF_ERROR(RegisterCelBlockMacros(macros)); + CEL_RETURN_IF_ERROR(cel::test::RegisterTestMacros(macros)); CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(request.cel_source(), request.source_location())); CEL_ASSIGN_OR_RETURN(auto parsed_expr, @@ -285,6 +176,8 @@ absl::Status CheckImpl(google::protobuf::Arena* arena, if (!request.no_std_env()) { CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCheckerLibrary())); CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCheckerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCheckerLibrary())); CEL_RETURN_IF_ERROR( builder->AddLibrary(cel::extensions::StringsCheckerLibrary())); CEL_RETURN_IF_ERROR( diff --git a/extensions/BUILD b/extensions/BUILD index ff37e2c3f..05104a4a5 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -215,7 +215,10 @@ cc_library( srcs = ["bindings_ext.cc"], hdrs = ["bindings_ext.h"], deps = [ - "//common:ast", + "//checker:type_checker_builder", + "//common:decl", + "//common:expr", + "//common:type", "//compiler", "//internal:status_macros", "//parser:macro", diff --git a/extensions/bindings_ext.cc b/extensions/bindings_ext.cc index f097709ca..c59f724bd 100644 --- a/extensions/bindings_ext.cc +++ b/extensions/bindings_ext.cc @@ -21,7 +21,10 @@ #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" -#include "common/ast.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/type.h" #include "compiler/compiler.h" #include "internal/status_macros.h" #include "parser/macro.h" @@ -34,6 +37,8 @@ namespace { static constexpr char kCelNamespace[] = "cel"; static constexpr char kBind[] = "bind"; +static constexpr char kBlock[] = "cel.@block"; +static constexpr char kBlockOverloadId[] = "cel_block_list"; static constexpr char kUnusedIterVar[] = "#unused"; bool IsTargetNamespace(const Expr& target) { @@ -47,6 +52,19 @@ inline absl::Status ConfigureParser(ParserBuilder& parser_builder) { return absl::OkStatus(); } +absl::Status ConfigureChecker(int version, + TypeCheckerBuilder& type_checker_builder) { + if (version < 1) { + return absl::OkStatus(); + } + static Type kParam(TypeParamType("T")); + CEL_ASSIGN_OR_RETURN( + auto decl, + MakeFunctionDecl(kBlock, MakeOverloadDecl(kBlockOverloadId, kParam, + ListType(), kParam))); + return type_checker_builder.AddFunction(std::move(decl)); +} + } // namespace std::vector bindings_macros() { @@ -70,8 +88,16 @@ std::vector bindings_macros() { return {*cel_bind}; } -CompilerLibrary BindingsCompilerLibrary() { - return CompilerLibrary("cel.lib.ext.bindings", &ConfigureParser); +CompilerLibrary BindingsCompilerLibrary(int version) { + return CompilerLibrary( + "cel.lib.ext.bindings", &ConfigureParser, + [version](auto& b) { return ConfigureChecker(version, b); }); +} + +CheckerLibrary BindingsCheckerLibrary(int version) { + return CheckerLibrary{"cel.lib.ext.bindings", [version](auto& b) { + return ConfigureChecker(version, b); + }}; } } // namespace cel::extensions diff --git a/extensions/bindings_ext.h b/extensions/bindings_ext.h index a338b24f6..40b83a37f 100644 --- a/extensions/bindings_ext.h +++ b/extensions/bindings_ext.h @@ -25,6 +25,7 @@ namespace cel::extensions { +constexpr int kBindingsVersionLatest = 1; // bindings_macros() returns a macro for cel.bind() which can be used to support // local variable bindings within expressions. std::vector bindings_macros(); @@ -35,7 +36,10 @@ inline absl::Status RegisterBindingsMacros(MacroRegistry& registry, } // Declarations for the bindings extension library. -CompilerLibrary BindingsCompilerLibrary(); +CompilerLibrary BindingsCompilerLibrary(int version = kBindingsVersionLatest); + +// Declarations for the bindings extension library. +CheckerLibrary BindingsCheckerLibrary(int version = kBindingsVersionLatest); } // namespace cel::extensions diff --git a/testutil/BUILD b/testutil/BUILD index 292696033..782c95ca6 100644 --- a/testutil/BUILD +++ b/testutil/BUILD @@ -62,6 +62,26 @@ cc_library( deps = ["//internal:proto_matchers"], ) +cc_library( + name = "test_macros", + testonly = True, + srcs = ["test_macros.cc"], + hdrs = ["test_macros.h"], + deps = [ + "//common:expr", + "//internal:status_macros", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:macro_registry", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "baseline_tests", testonly = True, diff --git a/testutil/test_macros.cc b/testutil/test_macros.cc new file mode 100644 index 000000000..158135762 --- /dev/null +++ b/testutil/test_macros.cc @@ -0,0 +1,175 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "testutil/test_macros.h" + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/expr.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/macro_registry.h" + +namespace cel::test { + +namespace { + +bool IsCelNamespace(const Expr& target) { + return target.has_ident_expr() && target.ident_expr().name() == "cel"; +} + +absl::optional CelBlockMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& bindings_arg = args[0]; + if (!bindings_arg.has_list_expr()) { + return factory.ReportErrorAt( + bindings_arg, "cel.block requires the first arg to be a list literal"); + } + return factory.NewCall("cel.@block", args); +} + +absl::optional CelIndexMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& index_arg = args[0]; + if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + int64_t index = index_arg.const_expr().int_value(); + if (index < 0) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + return factory.NewIdent(absl::StrCat("@index", index)); +} + +absl::optional CelIterVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& depth_arg = args[0]; + if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || + depth_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + depth_arg, "cel.iterVar requires two non-negative int constant args"); + } + Expr& unique_arg = args[1]; + if (!unique_arg.has_const_expr() || + !unique_arg.const_expr().has_int_value() || + unique_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + unique_arg, "cel.iterVar requires two non-negative int constant args"); + } + return factory.NewIdent( + absl::StrCat("@it:", depth_arg.const_expr().int_value(), ":", + unique_arg.const_expr().int_value())); +} + +absl::optional CelAccuVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& depth_arg = args[0]; + if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || + depth_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + depth_arg, "cel.accuVar requires two non-negative int constant args"); + } + Expr& unique_arg = args[1]; + if (!unique_arg.has_const_expr() || + !unique_arg.const_expr().has_int_value() || + unique_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + unique_arg, "cel.accuVar requires two non-negative int constant args"); + } + return factory.NewIdent( + absl::StrCat("@ac:", depth_arg.const_expr().int_value(), ":", + unique_arg.const_expr().int_value())); +} + +Macro MakeCelBlockMacro() { + auto macro_or_status = Macro::Receiver("block", 2, CelBlockMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelIndexMacro() { + auto macro_or_status = Macro::Receiver("index", 1, CelIndexMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelIterVarMacro() { + auto macro_or_status = Macro::Receiver("iterVar", 2, CelIterVarMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelAccuVarMacro() { + auto macro_or_status = Macro::Receiver("accuVar", 2, CelAccuVarMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +} // namespace + +const Macro& CelBlockMacro() { + static const absl::NoDestructor macro(MakeCelBlockMacro()); + return *macro; +} + +const Macro& CelIndexMacro() { + static const absl::NoDestructor macro(MakeCelIndexMacro()); + return *macro; +} + +const Macro& CelIterVarMacro() { + static const absl::NoDestructor macro(MakeCelIterVarMacro()); + return *macro; +} + +const Macro& CelAccuVarMacro() { + static const absl::NoDestructor macro(MakeCelAccuVarMacro()); + return *macro; +} + +absl::Status RegisterTestMacros(MacroRegistry& registry) { + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelBlockMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelIndexMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelIterVarMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelAccuVarMacro())); + return absl::OkStatus(); +} + +} // namespace cel::test diff --git a/testutil/test_macros.h b/testutil/test_macros.h new file mode 100644 index 000000000..cad897999 --- /dev/null +++ b/testutil/test_macros.h @@ -0,0 +1,33 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ + +#include "absl/status/status.h" +#include "parser/macro.h" +#include "parser/macro_registry.h" + +namespace cel::test { + +const Macro& CelBlockMacro(); +const Macro& CelIndexMacro(); +const Macro& CelIterVarMacro(); +const Macro& CelAccuVarMacro(); + +absl::Status RegisterTestMacros(MacroRegistry& registry); + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ From ddcece1479ed78ccde1594e47d94eeb841de115f Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 6 May 2026 13:56:21 -0700 Subject: [PATCH 057/143] Refactor optional dispatch tables. PiperOrigin-RevId: 911533283 --- common/values/optional_value.cc | 255 ++++++++++++++++---------------- 1 file changed, 124 insertions(+), 131 deletions(-) diff --git a/common/values/optional_value.cc b/common/values/optional_value.cc index ad0a65efb..688cf8fb0 100644 --- a/common/values/optional_value.cc +++ b/common/values/optional_value.cc @@ -122,200 +122,185 @@ absl::Status OptionalValueEqual( return absl::OkStatus(); } +google::protobuf::Arena* absl_nullable OptionalValueGetArenaNull( + const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) { + return nullptr; +} + +OpaqueValue OptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) { + return common_internal::MakeOptionalValue(dispatcher, content); +} + +bool OptionalValueHasNoValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content) { + return false; +} + +void EmptyOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = + ErrorValue(absl::FailedPreconditionError("optional.none() dereference")); +} + +void NullOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = NullValue(); +} + +void BoolOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = BoolValue(content.To()); +} + +void IntOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = IntValue(content.To()); +} + +void UintOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UintValue(content.To()); +} + +void DoubleOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = DoubleValue(content.To()); +} + +void DurationOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UnsafeDurationValue(content.To()); +} + +void TimestampOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UnsafeTimestampValue(content.To()); +} + ABSL_CONST_INIT const OptionalValueDispatcher empty_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, - }, - [](const OptionalValueDispatcher* absl_nonnull dispatcher, - CustomValueContent content) -> bool { return false; }, - [](const OptionalValueDispatcher* absl_nonnull dispatcher, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = ErrorValue( - absl::FailedPreconditionError("optional.none() dereference")); + .clone = &OptionalValueClone, }, + &OptionalValueHasNoValue, + &EmptyOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher null_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent, - cel::Value* absl_nonnull result) -> void { *result = NullValue(); }, + &NullOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher bool_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = BoolValue(content.To()); - }, + &BoolOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher int_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = IntValue(content.To()); - }, + &IntOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher uint_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = UintValue(content.To()); - }, + &UintOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher double_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = DoubleValue(content.To()); - }, + &DoubleOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher duration_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = UnsafeDurationValue(content.To()); - }, + &DurationOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher timestamp_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = UnsafeTimestampValue(content.To()); - }, + &TimestampOptionalValueValue, }; struct OptionalValueContent { @@ -323,43 +308,51 @@ struct OptionalValueContent { google::protobuf::Arena* absl_nonnull arena; }; +google::protobuf::Arena* absl_nullable GenericOptionalValueGetArena( + const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent content) { + return content.To().arena; +} + +OpaqueValue GenericOptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena); + +void GenericOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = *content.To().value; +} + ABSL_CONST_INIT const OptionalValueDispatcher optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = - [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent content) -> google::protobuf::Arena* absl_nullable { - return content.To().arena; - }, + .get_arena = &GenericOptionalValueGetArena, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - ABSL_DCHECK(arena != nullptr); - - cel::Value* absl_nonnull result = ::new ( - arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) - cel::Value( - content.To().value->Clone(arena)); - if (!ArenaTraits<>::trivially_destructible(result)) { - arena->OwnDestructor(result); - } - return common_internal::MakeOptionalValue( - &optional_value_dispatcher, - OpaqueValueContent::From( - OptionalValueContent{.value = result, .arena = arena})); - }, + .clone = &GenericOptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = *content.To().value; - }, + &GenericOptionalValueValue, }; +OpaqueValue GenericOptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(arena != nullptr); + + cel::Value* absl_nonnull result = + ::new (arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) + cel::Value(content.To().value->Clone(arena)); + if (!ArenaTraits<>::trivially_destructible(result)) { + arena->OwnDestructor(result); + } + return common_internal::MakeOptionalValue( + &optional_value_dispatcher, OpaqueValueContent::From(OptionalValueContent{ + .value = result, .arena = arena})); +} + } // namespace OptionalValue OptionalValue::Of(cel::Value value, From 5806d30ba86ca40d8ab111e59fa78983afe5319c Mon Sep 17 00:00:00 2001 From: Justin King Date: Fri, 8 May 2026 13:03:46 -0700 Subject: [PATCH 058/143] Update conformance test skip list PiperOrigin-RevId: 912656156 --- conformance/BUILD | 20 ++++++++++++++++++++ conformance/run.bzl | 6 +++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index 9b527cf35..726a11b0b 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -201,6 +201,26 @@ _TESTS_TO_SKIP = [ # precision to preserve value. Not available on older compilers where we just use absl::Format. # We should probably update the spec to allow different formats that parse to the same value. "conversions/string/double_hard", + + # Recent changes + "proto2/set_null/repeated_field_timestamp_null_pruned", + "proto2/set_null/repeated_field_duration_null_pruned", + "proto2/set_null/repeated_field_wrapper_null_pruned", + "proto2/set_null/map_timestamp_null_pruned", + "proto2/set_null/map_duration_null_pruned", + "proto2/set_null/map_wrapper_null_pruned", + "proto3/set_null/repeated_field_timestamp_null_pruned", + "proto3/set_null/repeated_field_duration_null_pruned", + "proto3/set_null/repeated_field_wrapper_null_pruned", + "proto3/set_null/map_timestamp_null_pruned", + "proto3/set_null/map_duration_null_pruned", + "proto3/set_null/map_wrapper_null_pruned", + "string_ext/format/default precision for fixed-point clause with int", + "string_ext/format/default precision for fixed-point clause with uint", + "string_ext/format/default precision for scientific notation with int", + "string_ext/format/default precision for scientific notation with uint", + "namespace/namespace_shadowing/basic", + "namespace/namespace_shadowing/comprehension_shadowing_namespaced_selector_disambiguation", ] _TESTS_TO_SKIP_MODERN = _TESTS_TO_SKIP diff --git a/conformance/run.bzl b/conformance/run.bzl index 4fcf325c6..d53fd539c 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -70,7 +70,7 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, args.append("--skip_check") else: args.append("--noskip_check") - args.append("--skip_tests={}".format(",".join(_expand_tests_to_skip(skip_tests)))) + args.append("--skip_tests=\"{}\"".format(",".join(_expand_tests_to_skip(skip_tests)))) if dashboard: args.append("--dashboard") return args @@ -80,8 +80,8 @@ def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_ name = _conformance_test_name(name, optimize, recursive), args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, skip_tests, dashboard) + ["$(location " + test + ")" for test in data] + select( { - "@platforms//os:windows": ["--skip_tests={}".format(",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS))], - "//conditions:default": ["--skip_tests={}".format(",".join(skip_tests))], + "@platforms//os:windows": ["--skip_tests=\"{}\"".format(",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS))], + "//conditions:default": ["--skip_tests=\"{}\"".format(",".join(skip_tests))], }, ), data = data, From cb9dc8a2e71e503655b1992bdba3debc7fda12a7 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 8 May 2026 15:45:31 -0700 Subject: [PATCH 059/143] Fix command line argument splitting issue for conformance tests. PiperOrigin-RevId: 912731724 --- conformance/run.bzl | 10 +++++----- conformance/run.cc | 8 ++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/conformance/run.bzl b/conformance/run.bzl index d53fd539c..15850b0aa 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -56,7 +56,7 @@ def _conformance_test_name(name, optimize, recursive): ], ) -def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, skip_tests, dashboard): +def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard): args = [] if modern: args.append("--modern") @@ -70,7 +70,6 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, args.append("--skip_check") else: args.append("--noskip_check") - args.append("--skip_tests=\"{}\"".format(",".join(_expand_tests_to_skip(skip_tests)))) if dashboard: args.append("--dashboard") return args @@ -78,10 +77,11 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard): cc_test( name = _conformance_test_name(name, optimize, recursive), - args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, skip_tests, dashboard) + ["$(location " + test + ")" for test in data] + select( + args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard) + ["$(location " + test + ")" for test in data], + env = select( { - "@platforms//os:windows": ["--skip_tests=\"{}\"".format(",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS))], - "//conditions:default": ["--skip_tests=\"{}\"".format(",".join(skip_tests))], + "@platforms//os:windows": {"CEL_SKIP_TESTS": ",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS)}, + "//conditions:default": {"CEL_SKIP_TESTS": ",".join(skip_tests)}, }, ), data = data, diff --git a/conformance/run.cc b/conformance/run.cc index d5a919d76..80164d9a4 100644 --- a/conformance/run.cc +++ b/conformance/run.cc @@ -42,6 +42,7 @@ #include "absl/strings/cord.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "absl/types/span.h" @@ -273,6 +274,13 @@ int main(int argc, char** argv) { { auto service = NewConformanceServiceFromFlags(); auto tests_to_skip = absl::GetFlag(FLAGS_skip_tests); + if (const char* env_skip = std::getenv("CEL_SKIP_TESTS"); + env_skip != nullptr) { + for (absl::string_view test : + absl::StrSplit(env_skip, ',', absl::SkipEmpty())) { + tests_to_skip.push_back(std::string(test)); + } + } for (int argi = 1; argi < argc; argi++) { ABSL_CHECK_OK(RegisterTestsFromFile(service, tests_to_skip, absl::string_view(argv[argi]))); From cf31ddf620b9d809014418e82428863b54190cbb Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 11 May 2026 10:50:27 -0700 Subject: [PATCH 060/143] Introduce `Bind` expression factory helper PiperOrigin-RevId: 913778503 --- common/expr_factory.h | 23 ++++++++++++++ parser/macro_expr_factory_test.cc | 51 +++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/common/expr_factory.h b/common/expr_factory.h index b9769b457..773217ad9 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -352,6 +352,29 @@ class ExprFactory { return expr; } + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + Expr NewBind(NextIdFunc next_id, BindVar bind_var, BindExpr bind_expr, + RestExpr rest_expr) { + Expr expr; + expr.set_id(next_id()); + auto& comprehension_expr = expr.mutable_comprehension_expr(); + comprehension_expr.set_iter_var("#unused"); + comprehension_expr.set_iter_range( + NewList(next_id(), std::vector{})); + comprehension_expr.set_accu_var(bind_var); + comprehension_expr.set_accu_init(std::move(bind_expr)); + comprehension_expr.set_loop_condition(NewBoolConst(next_id(), false)); + comprehension_expr.set_loop_step(NewIdent(next_id(), bind_var)); + comprehension_expr.set_result(std::move(rest_expr)); + return expr; + } + private: friend class MacroExprFactory; friend class ParserMacroExprFactory; diff --git a/parser/macro_expr_factory_test.cc b/parser/macro_expr_factory_test.cc index 489538be1..b95cbe16f 100644 --- a/parser/macro_expr_factory_test.cc +++ b/parser/macro_expr_factory_test.cc @@ -15,6 +15,7 @@ #include "parser/macro_expr_factory.h" #include +#include #include #include "absl/strings/string_view.h" @@ -39,6 +40,7 @@ class TestMacroExprFactory final : public MacroExprFactory { return NewUnspecified(NextId()); } + using MacroExprFactory::NewBind; using MacroExprFactory::NewBoolConst; using MacroExprFactory::NewCall; using MacroExprFactory::NewComprehension; @@ -69,6 +71,8 @@ class TestMacroExprFactory final : public MacroExprFactory { namespace { +using ::testing::IsEmpty; + TEST(MacroExprFactory, CopyUnspecified) { TestMacroExprFactory factory; EXPECT_EQ(factory.Copy(factory.NewUnspecified()), factory.NewUnspecified(2)); @@ -147,5 +151,52 @@ TEST(MacroExprFactory, CopyComprehension) { factory.NewIdent(11, "foo"), factory.NewIdent(12, "bar"))); } +TEST(MacroExprFactory, NewBind) { + TestMacroExprFactory factory; + Expr bind_expr = factory.NewIdent(10, "x"); + Expr rest_expr = factory.NewIdent(20, "y"); + + auto next_id = [id = 100]() mutable { return id++; }; + + Expr expr = + factory.NewBind(next_id, "a", std::move(bind_expr), std::move(rest_expr)); + + EXPECT_EQ(expr.id(), 100); + ASSERT_TRUE(expr.has_comprehension_expr()); + + const auto& comp = expr.comprehension_expr(); + EXPECT_EQ(comp.iter_var(), "#unused"); + + ASSERT_TRUE(comp.has_iter_range()); + EXPECT_EQ(comp.iter_range().id(), 101); + EXPECT_EQ(comp.iter_range().kind_case(), ExprKindCase::kListExpr); + EXPECT_THAT(comp.iter_range().list_expr().elements(), IsEmpty()); + + EXPECT_EQ(comp.accu_var(), "a"); + + ASSERT_TRUE(comp.has_accu_init()); + Expr expected_bind_expr; + expected_bind_expr.set_id(10); + expected_bind_expr.mutable_ident_expr().set_name("x"); + EXPECT_EQ(comp.accu_init(), expected_bind_expr); + + ASSERT_TRUE(comp.has_loop_condition()); + EXPECT_EQ(comp.loop_condition().id(), 102); + EXPECT_EQ(comp.loop_condition().kind_case(), ExprKindCase::kConstant); + EXPECT_TRUE(comp.loop_condition().const_expr().has_bool_value()); + EXPECT_FALSE(comp.loop_condition().const_expr().bool_value()); + + ASSERT_TRUE(comp.has_loop_step()); + EXPECT_EQ(comp.loop_step().id(), 103); + EXPECT_EQ(comp.loop_step().kind_case(), ExprKindCase::kIdentExpr); + EXPECT_EQ(comp.loop_step().ident_expr().name(), "a"); + + ASSERT_TRUE(comp.has_result()); + Expr expected_rest_expr; + expected_rest_expr.set_id(20); + expected_rest_expr.mutable_ident_expr().set_name("y"); + EXPECT_EQ(comp.result(), expected_rest_expr); +} + } // namespace } // namespace cel From 2e6e9ff4493bfbe0baf883107f3fb7ce6f675d88 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 11 May 2026 21:06:47 -0700 Subject: [PATCH 061/143] Add support for abbreviations and aliases in container configuration for CEL C++ environment YAML. This allows specifying name, abbreviations, and aliases in a container config instead of just a string. The string syntax is preserved as an alternative PiperOrigin-RevId: 914038623 --- env/BUILD | 1 + env/config.h | 11 +++- env/env.cc | 12 +++- env/env_test.cc | 30 ++++++++++ env/env_yaml.cc | 107 +++++++++++++++++++++++++++++++-- env/env_yaml_test.cc | 139 ++++++++++++++++++++++++++++++++++++++++++- 6 files changed, 289 insertions(+), 11 deletions(-) diff --git a/env/BUILD b/env/BUILD index 55297b190..41ffc1723 100644 --- a/env/BUILD +++ b/env/BUILD @@ -52,6 +52,7 @@ cc_library( ":config", "//checker:type_checker_builder", "//common:constant", + "//common:container", "//common:decl", "//common:type", "//compiler", diff --git a/env/config.h b/env/config.h index 10b23d030..e427832ff 100644 --- a/env/config.h +++ b/env/config.h @@ -34,9 +34,16 @@ class Config { struct ContainerConfig { std::string name; - // TODO(uncreated-issue/87): add support for aliases and abbreviations. + std::vector abbreviations; + struct Alias { + std::string alias; + std::string qualified_name; + }; + std::vector aliases; - bool IsEmpty() const { return name.empty(); } + bool IsEmpty() const { + return name.empty() && abbreviations.empty() && aliases.empty(); + } }; void SetContainerConfig(ContainerConfig container_config) { diff --git a/env/env.cc b/env/env.cc index 5a4198497..42652ce59 100644 --- a/env/env.cc +++ b/env/env.cc @@ -24,6 +24,7 @@ #include "absl/strings/string_view.h" #include "checker/type_checker_builder.h" #include "common/constant.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "compiler/compiler.h" @@ -130,7 +131,16 @@ absl::StatusOr> Env::NewCompilerBuilder() { cel::TypeCheckerBuilder& checker_builder = compiler_builder->GetCheckerBuilder(); - checker_builder.set_container(config_.GetContainerConfig().name); + ExpressionContainer container; + CEL_RETURN_IF_ERROR( + container.SetContainer(config_.GetContainerConfig().name)); + for (const auto& abbr : config_.GetContainerConfig().abbreviations) { + CEL_RETURN_IF_ERROR(container.AddAbbreviation(abbr)); + } + for (const auto& alias : config_.GetContainerConfig().aliases) { + CEL_RETURN_IF_ERROR(container.AddAlias(alias.alias, alias.qualified_name)); + } + checker_builder.SetExpressionContainer(std::move(container)); if (!config_.GetStandardLibraryConfig().disable) { CEL_RETURN_IF_ERROR( diff --git a/env/env_test.cc b/env/env_test.cc index 076eb57bc..b599aa569 100644 --- a/env/env_test.cc +++ b/env/env_test.cc @@ -314,6 +314,36 @@ TEST(ContainerConfigTest, ContainerConfig) { EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); } +TEST(ContainerConfigTest, ContainerConfigWithAbbreviations) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig( + {.name = "cel.expr.conformance", + .abbreviations = {"cel.expr.conformance.proto2.TestAllTypes"}}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("TestAllTypes{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + +TEST(ContainerConfigTest, ContainerConfigWithAliases) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig( + {.name = "cel.expr.conformance", + .aliases = { + {.alias = "MyTestType", + .qualified_name = "cel.expr.conformance.proto2.TestAllTypes"}}}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("MyTestType{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + struct VariableConfigWithValueTestCase { Config::VariableConfig variable_config; std::string validate_type_expr; diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 4ba16ea84..159786598 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -150,12 +150,72 @@ absl::Status ParseName(Config& config, absl::string_view yaml, absl::Status ParseContainerConfig(Config& config, absl::string_view yaml, const YAML::Node& root) { const YAML::Node container = root["container"]; - if (container.IsDefined()) { - if (!container.IsScalar()) { - return YamlError(yaml, container, "Node 'container' is not a string"); - } + if (!container.IsDefined()) { + return absl::OkStatus(); + } + + if (container.IsScalar()) { config.SetContainerConfig({.name = GetString(yaml, container)}); + return absl::OkStatus(); } + + if (!container.IsMap()) { + return YamlError(yaml, container, + "Node 'container' is neither a string nor a map"); + } + + Config::ContainerConfig container_config; + + const YAML::Node name = container["name"]; + if (name.IsDefined()) { + if (!name.IsScalar()) { + return YamlError(yaml, name, "Node 'name' in container is not a string"); + } + container_config.name = GetString(yaml, name); + } + + const YAML::Node abbreviations = container["abbreviations"]; + if (abbreviations.IsDefined()) { + if (!abbreviations.IsSequence()) { + return YamlError(yaml, abbreviations, + "Node 'abbreviations' is not a sequence"); + } + for (const YAML::Node& abbr : abbreviations) { + if (!abbr.IsScalar()) { + return YamlError(yaml, abbr, "Abbreviation is not a string"); + } + container_config.abbreviations.push_back(GetString(yaml, abbr)); + } + } + + const YAML::Node aliases = container["aliases"]; + if (aliases.IsDefined()) { + if (!aliases.IsSequence()) { + return YamlError(yaml, aliases, "Node 'aliases' is not a sequence"); + } + for (const YAML::Node& alias_node : aliases) { + if (!alias_node.IsMap()) { + return YamlError(yaml, alias_node, "Alias entry is not a map"); + } + const YAML::Node alias_key = alias_node["alias"]; + const YAML::Node qualified_name_key = alias_node["qualified_name"]; + + if (!alias_key.IsDefined() || !alias_key.IsScalar()) { + return YamlError(yaml, alias_node, + "Alias entry missing 'alias' string"); + } + if (!qualified_name_key.IsDefined() || !qualified_name_key.IsScalar()) { + return YamlError(yaml, alias_node, + "Alias entry missing 'qualified_name' string"); + } + + container_config.aliases.push_back( + {.alias = GetString(yaml, alias_key), + .qualified_name = GetString(yaml, qualified_name_key)}); + } + } + + config.SetContainerConfig(std::move(container_config)); return absl::OkStatus(); } @@ -686,7 +746,44 @@ void EmitContainerConfig(const Config& env_config, YAML::Emitter& out) { } out << YAML::Key << "container"; - out << YAML::Value << YAML::DoubleQuoted << container_config.name; + if (container_config.abbreviations.empty() && + container_config.aliases.empty()) { + out << YAML::Value << YAML::DoubleQuoted << container_config.name; + } else { + out << YAML::Value << YAML::BeginMap; + if (!container_config.name.empty()) { + out << YAML::Key << "name" << YAML::Value << YAML::DoubleQuoted + << container_config.name; + } + if (!container_config.abbreviations.empty()) { + std::vector sorted_abbrs = container_config.abbreviations; + absl::c_sort(sorted_abbrs); + out << YAML::Key << "abbreviations" << YAML::Value << YAML::BeginSeq; + for (const auto& abbr : sorted_abbrs) { + out << YAML::Value << YAML::DoubleQuoted << abbr; + } + out << YAML::EndSeq; + } + if (!container_config.aliases.empty()) { + std::vector sorted_aliases = + container_config.aliases; + absl::c_sort(sorted_aliases, [](const Config::ContainerConfig::Alias& a, + const Config::ContainerConfig::Alias& b) { + return a.alias < b.alias; + }); + out << YAML::Key << "aliases" << YAML::Value << YAML::BeginSeq; + for (const auto& alias : sorted_aliases) { + out << YAML::BeginMap; + out << YAML::Key << "alias" << YAML::Value << YAML::DoubleQuoted + << alias.alias; + out << YAML::Key << "qualified_name" << YAML::Value + << YAML::DoubleQuoted << alias.qualified_name; + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + out << YAML::EndMap; + } } void EmitExtensionConfigs(const Config& env_config, YAML::Emitter& out) { diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index 25cc63206..d19c0dbfb 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -55,6 +55,31 @@ TEST(EnvYamlTest, ParseContainerConfig) { Field(&Config::ContainerConfig::name, "test.container")); } +TEST(EnvYamlTest, ParseContainerConfig_AlternativeSyntax) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + container: + name: test.container + abbreviations: + - abbr1.Abbr1 + - abbr2.Abbr2 + aliases: + - alias: alias1 + qualified_name: qual.name1 + - alias: alias2 + qualified_name: qual.name2 + )yaml")); + + const auto& container_config = config.GetContainerConfig(); + EXPECT_EQ(container_config.name, "test.container"); + EXPECT_THAT(container_config.abbreviations, + UnorderedElementsAre("abbr1.Abbr1", "abbr2.Abbr2")); + ASSERT_THAT(container_config.aliases, SizeIs(2)); + EXPECT_EQ(container_config.aliases[0].alias, "alias1"); + EXPECT_EQ(container_config.aliases[0].qualified_name, "qual.name1"); + EXPECT_EQ(container_config.aliases[1].alias, "alias2"); + EXPECT_EQ(container_config.aliases[1].qualified_name, "qual.name2"); +} + TEST(EnvYamlTest, ParseExtensionConfigs) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( extensions: @@ -550,9 +575,78 @@ INSTANTIATE_TEST_SUITE_P( container: - error: "error" )yaml", - .expected_error = "3:19: Node 'container' is not a string\n" - "| - error: \"error\"\n" - "| ^", + .expected_error = + "3:19: Node 'container' is neither a string nor a map\n" + "| - error: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + name: [] + )yaml", + .expected_error = "3:25: Node 'name' in container is not a string\n" + "| name: []\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + abbreviations: "abbr" + )yaml", + .expected_error = "3:34: Node 'abbreviations' is not a sequence\n" + "| abbreviations: \"abbr\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + abbreviations: + - [] + )yaml", + .expected_error = "4:21: Abbreviation is not a string\n" + "| - []\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: "not a sequence" + )yaml", + .expected_error = "3:28: Node 'aliases' is not a sequence\n" + "| aliases: \"not a sequence\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - "not a map" + )yaml", + .expected_error = "4:21: Alias entry is not a map\n" + "| - \"not a map\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - qualified_name: "qual" + )yaml", + .expected_error = "4:21: Alias entry missing 'alias' string\n" + "| - qualified_name: \"qual\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - alias: "my_alias" + )yaml", + .expected_error = "4:21: Alias entry missing" + " 'qualified_name' string\n" + "| - alias: \"my_alias\"\n" + "| ^", }, ParseTestCase{ .yaml = R"yaml( @@ -946,6 +1040,33 @@ std::vector GetExportTestCases() { container: "test.container" )yaml", }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + config.SetName("test.env"); + config.SetContainerConfig( + {.name = "test.container", + .abbreviations = {"foo", "bar"}, + .aliases = { + {.alias = "foo", .qualified_name = "test.foo"}, + {.alias = "bar", .qualified_name = "test.bar"}, + }}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: + name: "test.container" + abbreviations: + - "bar" + - "foo" + aliases: + - alias: "bar" + qualified_name: "test.bar" + - alias: "foo" + qualified_name: "test.foo" + )yaml", + }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; @@ -1385,6 +1506,18 @@ std::vector GetRoundTripTestCases() { overloads: - id: "string_to_timestamp" )yaml", + R"yaml( + container: + name: "test.container" + abbreviations: + - "abbr1.Abbr1" + - "abbr2.Abbr2" + aliases: + - alias: "alias1" + qualified_name: "qual.name1" + - alias: "alias2" + qualified_name: "qual.name2" + )yaml", R"yaml( extensions: - name: "bindings" From cd9f059a5833c92576e85e3ffb2eaee2fd328e76 Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 12 May 2026 12:32:02 -0700 Subject: [PATCH 062/143] Fix repeated field null pruning for proto2/proto3 PiperOrigin-RevId: 914421409 --- common/values/struct_value_builder.cc | 11 +++++++++++ conformance/BUILD | 6 ------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/common/values/struct_value_builder.cc b/common/values/struct_value_builder.cc index 359596267..c342d6478 100644 --- a/common/values/struct_value_builder.cc +++ b/common/values/struct_value_builder.cc @@ -812,6 +812,17 @@ ProtoMessageRepeatedFieldFromValueMutator( const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + // If the value is null and the target repeated field is anything except + // google.protobuf.{Any,ListValue,Struct,Value}, it should be pruned. + if (value.IsNull()) { + const auto well_known_type = field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY) { + return absl::nullopt; + } + } auto* element = reflection->AddMessage(message, field, factory); auto result = ProtoMessageFromValueImpl(value, pool, factory, well_known_types, element); diff --git a/conformance/BUILD b/conformance/BUILD index 726a11b0b..abc0d918a 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -203,15 +203,9 @@ _TESTS_TO_SKIP = [ "conversions/string/double_hard", # Recent changes - "proto2/set_null/repeated_field_timestamp_null_pruned", - "proto2/set_null/repeated_field_duration_null_pruned", - "proto2/set_null/repeated_field_wrapper_null_pruned", "proto2/set_null/map_timestamp_null_pruned", "proto2/set_null/map_duration_null_pruned", "proto2/set_null/map_wrapper_null_pruned", - "proto3/set_null/repeated_field_timestamp_null_pruned", - "proto3/set_null/repeated_field_duration_null_pruned", - "proto3/set_null/repeated_field_wrapper_null_pruned", "proto3/set_null/map_timestamp_null_pruned", "proto3/set_null/map_duration_null_pruned", "proto3/set_null/map_wrapper_null_pruned", From 4749cf81003d9264fd87ca8b0640b5189bcc2b9e Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 12 May 2026 16:00:59 -0700 Subject: [PATCH 063/143] Fix scientific notation and fixed point formatting for int and uint PiperOrigin-RevId: 914525320 --- conformance/BUILD | 4 ---- extensions/formatting.cc | 6 ++++++ extensions/formatting_test.cc | 12 ++++++++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index abc0d918a..4f9232ab6 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -209,10 +209,6 @@ _TESTS_TO_SKIP = [ "proto3/set_null/map_timestamp_null_pruned", "proto3/set_null/map_duration_null_pruned", "proto3/set_null/map_wrapper_null_pruned", - "string_ext/format/default precision for fixed-point clause with int", - "string_ext/format/default precision for fixed-point clause with uint", - "string_ext/format/default precision for scientific notation with int", - "string_ext/format/default precision for scientific notation with uint", "namespace/namespace_shadowing/basic", "namespace/namespace_shadowing/comprehension_shadowing_namespaced_selector_disambiguation", ] diff --git a/extensions/formatting.cc b/extensions/formatting.cc index 935815569..252fdc7bd 100644 --- a/extensions/formatting.cc +++ b/extensions/formatting.cc @@ -419,6 +419,12 @@ absl::StatusOr GetDouble(const Value& value, std::string& scratch) { str)); } } + if (value.kind() == ValueKind::kInt) { + return static_cast(value.GetInt().NativeValue()); + } + if (value.kind() == ValueKind::kUint) { + return static_cast(value.GetUint().NativeValue()); + } if (value.kind() != ValueKind::kDouble) { return absl::InvalidArgumentError( absl::StrCat("expected a double but got a ", value.GetTypeName())); diff --git a/extensions/formatting_test.cc b/extensions/formatting_test.cc index b80fe9bc0..6a7fb300b 100644 --- a/extensions/formatting_test.cc +++ b/extensions/formatting_test.cc @@ -553,6 +553,18 @@ INSTANTIATE_TEST_SUITE_P( .format_args = "2.71828", .expected = "2.718280e+00", }, + { + .name = "FixedPointClauseWithInt", + .format = "%f", + .format_args = "3", + .expected = "3.000000", + }, + { + .name = "ScientificNotationWithUint", + .format = "%e", + .format_args = "uint(3)", + .expected = "3.000000e+00", + }, { .name = "NaNSupportForFixedPoint", .format = "%f", From 352666fba7822dd0d1f54dc00b332cc527aa81b1 Mon Sep 17 00:00:00 2001 From: Justin King Date: Wed, 13 May 2026 13:07:09 -0700 Subject: [PATCH 064/143] Fix map field value null pruning for proto2/proto3 PiperOrigin-RevId: 915015044 --- common/values/struct_value_builder.cc | 23 +++++++++++++++++++ conformance/BUILD | 6 ----- .../structs/proto_message_type_adapter.cc | 16 +++++++++++++ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/common/values/struct_value_builder.cc b/common/values/struct_value_builder.cc index c342d6478..446b18421 100644 --- a/common/values/struct_value_builder.cc +++ b/common/values/struct_value_builder.cc @@ -956,6 +956,19 @@ class MessageValueBuilderImpl { if (error_value) { return false; } + if (map_value_field->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE && + entry_value.IsNull()) { + auto well_known_type = + map_value_field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + return true; + } + } google::protobuf::MapValueRef proto_value; extensions::protobuf_internal::InsertOrLookupMapValue( *reflection_, message_, *field, proto_key, &proto_value); @@ -989,6 +1002,16 @@ class MessageValueBuilderImpl { CEL_RETURN_IF_ERROR(list_value->ForEach( [this, field, accessor, &error_value](const Value& element) -> absl::StatusOr { + if (field->message_type() != nullptr && element.IsNull()) { + auto well_known_type = field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + return true; + } + } CEL_ASSIGN_OR_RETURN(error_value, (*accessor)(descriptor_pool_, message_factory_, &well_known_types_, reflection_, diff --git a/conformance/BUILD b/conformance/BUILD index 4f9232ab6..ccd2844c9 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -203,12 +203,6 @@ _TESTS_TO_SKIP = [ "conversions/string/double_hard", # Recent changes - "proto2/set_null/map_timestamp_null_pruned", - "proto2/set_null/map_duration_null_pruned", - "proto2/set_null/map_wrapper_null_pruned", - "proto3/set_null/map_timestamp_null_pruned", - "proto3/set_null/map_duration_null_pruned", - "proto3/set_null/map_wrapper_null_pruned", "namespace/namespace_shadowing/basic", "namespace/namespace_shadowing/comprehension_shadowing_namespaced_selector_disambiguation", ] diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index a351890c2..6a3417ba3 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -582,6 +582,19 @@ absl::Status ProtoMessageTypeAdapter::SetField( ValidateSetFieldOp(value_field_descriptor != nullptr, field->name(), "failed to find value field descriptor")); + bool prune_when_null = false; + if (value_field_descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + auto well_known_type = + value_field_descriptor->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + prune_when_null = true; + } + } + CEL_ASSIGN_OR_RETURN(const CelList* key_list, cel_map->ListKeys(arena)); for (int i = 0; i < key_list->size(); i++) { CelValue key = (*key_list).Get(arena, i); @@ -589,6 +602,9 @@ absl::Status ProtoMessageTypeAdapter::SetField( auto value = (*cel_map).Get(arena, key); CEL_RETURN_IF_ERROR(ValidateSetFieldOp(value.has_value(), field->name(), "error serializing CelMap")); + if (prune_when_null && value->IsNull()) { + continue; + } Message* entry_msg = message->GetReflection()->AddMessage(message, field); CEL_RETURN_IF_ERROR(internal::SetValueToSingleField( key, key_field_descriptor, entry_msg, arena)); From 037e0bb42339376640024de353451e372bb47820 Mon Sep 17 00:00:00 2001 From: Muhammad Askri Date: Wed, 13 May 2026 13:33:46 -0700 Subject: [PATCH 065/143] Adding a TypeSpec to Type resolver. PiperOrigin-RevId: 915029748 --- common/BUILD | 32 ++++ common/type_spec_resolver.cc | 182 +++++++++++++++++++++ common/type_spec_resolver.h | 37 +++++ common/type_spec_resolver_test.cc | 257 ++++++++++++++++++++++++++++++ 4 files changed, 508 insertions(+) create mode 100644 common/type_spec_resolver.cc create mode 100644 common/type_spec_resolver.h create mode 100644 common/type_spec_resolver_test.cc diff --git a/common/BUILD b/common/BUILD index 0ead8b15a..ffc4ae1e9 100644 --- a/common/BUILD +++ b/common/BUILD @@ -46,6 +46,38 @@ cc_test( ], ) +cc_library( + name = "type_spec_resolver", + srcs = ["type_spec_resolver.cc"], + hdrs = ["type_spec_resolver.h"], + deps = [ + ":ast", + ":type", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_spec_resolver_test", + srcs = ["type_spec_resolver_test.cc"], + deps = [ + ":ast", + ":type", + ":type_kind", + ":type_spec_resolver", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "expr", srcs = ["expr.cc"], diff --git a/common/type_spec_resolver.cc b/common/type_spec_resolver.cc new file mode 100644 index 000000000..97451f390 --- /dev/null +++ b/common/type_spec_resolver.cc @@ -0,0 +1,182 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_spec_resolver.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "internal/status_macros.h" + +namespace cel { + +absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool) { + if (type_spec.has_null()) return Type(NullType{}); + if (type_spec.has_dyn()) return Type(DynType{}); + + if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + return Type(BoolType{}); + case PrimitiveType::kInt64: + return Type(IntType{}); + case PrimitiveType::kUint64: + return Type(UintType{}); + case PrimitiveType::kDouble: + return Type(DoubleType{}); + case PrimitiveType::kString: + return Type(StringType{}); + case PrimitiveType::kBytes: + return Type(BytesType{}); + default: + return absl::InvalidArgumentError("Unsupported primitive type"); + } + } + + if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + return Type(AnyType{}); + case WellKnownTypeSpec::kTimestamp: + return Type(TimestampType{}); + case WellKnownTypeSpec::kDuration: + return Type(DurationType{}); + default: + return absl::InvalidArgumentError("Unsupported well-known type"); + } + } + + if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + return Type(BoolWrapperType{}); + case PrimitiveType::kInt64: + return Type(IntWrapperType{}); + case PrimitiveType::kUint64: + return Type(UintWrapperType{}); + case PrimitiveType::kDouble: + return Type(DoubleWrapperType{}); + case PrimitiveType::kString: + return Type(StringWrapperType{}); + case PrimitiveType::kBytes: + return Type(BytesWrapperType{}); + default: + return absl::InvalidArgumentError("Unsupported wrapper type"); + } + } + + if (type_spec.has_list_type()) { + CEL_ASSIGN_OR_RETURN( + auto elem_type, + ConvertTypeSpecToType(type_spec.list_type().elem_type(), arena, pool)); + return Type(ListType(arena, elem_type)); + } + + if (type_spec.has_map_type()) { + CEL_ASSIGN_OR_RETURN( + auto key_type, + ConvertTypeSpecToType(type_spec.map_type().key_type(), arena, pool)); + CEL_ASSIGN_OR_RETURN( + auto value_type, + ConvertTypeSpecToType(type_spec.map_type().value_type(), arena, pool)); + return Type(MapType(arena, key_type, value_type)); + } + + if (type_spec.has_function()) { + const auto& func_spec = type_spec.function(); + CEL_ASSIGN_OR_RETURN( + auto result_type, + ConvertTypeSpecToType(func_spec.result_type(), arena, pool)); + std::vector arg_types; + for (const auto& arg_spec : func_spec.arg_types()) { + CEL_ASSIGN_OR_RETURN(auto arg_type, + ConvertTypeSpecToType(arg_spec, arena, pool)); + arg_types.push_back(std::move(arg_type)); + } + return Type(FunctionType(arena, result_type, arg_types)); + } + + if (type_spec.has_type_param()) { + const std::string& name = type_spec.type_param().type(); + auto* allocated_name = google::protobuf::Arena::Create(arena, name); + return Type(TypeParamType(absl::string_view(*allocated_name))); + } + + if (type_spec.has_message_type()) { + const std::string& name = type_spec.message_type().type(); + const google::protobuf::Descriptor* descriptor = pool.FindMessageTypeByName(name); + if (descriptor == nullptr) { + return absl::InvalidArgumentError(absl::StrCat( + "Message type '", name, "' not found in descriptor pool")); + } + return Type::Message(descriptor); + } + + if (type_spec.has_abstract_type()) { + const std::string& name = type_spec.abstract_type().name(); + + // Check if it's a message type in the pool + const google::protobuf::Descriptor* descriptor = pool.FindMessageTypeByName(name); + if (descriptor != nullptr) { + if (!type_spec.abstract_type().parameter_types().empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Message type '", name, "' cannot have type parameters")); + } + return Type::Message(descriptor); + } + + // Check if it's an enum type in the pool + const google::protobuf::EnumDescriptor* enum_descriptor = + pool.FindEnumTypeByName(name); + if (enum_descriptor != nullptr) { + if (!type_spec.abstract_type().parameter_types().empty()) { + return absl::InvalidArgumentError( + absl::StrCat("Enum type '", name, "' cannot have type parameters")); + } + return Type::Enum(enum_descriptor); + } + + // Otherwise fallback to OpaqueType + std::vector params; + for (const auto& param_spec : type_spec.abstract_type().parameter_types()) { + CEL_ASSIGN_OR_RETURN(auto param, + ConvertTypeSpecToType(param_spec, arena, pool)); + params.push_back(std::move(param)); + } + auto* allocated_name = google::protobuf::Arena::Create(arena, name); + return Type(OpaqueType(arena, absl::string_view(*allocated_name), params)); + } + + if (type_spec.has_type()) { + CEL_ASSIGN_OR_RETURN(auto contained_type, + ConvertTypeSpecToType(type_spec.type(), arena, pool)); + return Type(TypeType(arena, contained_type)); + } + + if (type_spec.has_error()) { + return Type(ErrorType{}); + } + + return absl::InvalidArgumentError("Unknown TypeSpec kind"); +} + +} // namespace cel diff --git a/common/type_spec_resolver.h b/common/type_spec_resolver.h new file mode 100644 index 000000000..44e1e088f --- /dev/null +++ b/common/type_spec_resolver.h @@ -0,0 +1,37 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Resolves a `cel::TypeSpec` to a `cel::Type`. +// +// TypeSpec only specifies a type while Type provides support for inspecting +// properties of the type when used in CEL. Returns a status with code +// `InvalidArgument` if the input cannot be resolved to a type. +absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ diff --git a/common/type_spec_resolver_test.cc b/common/type_spec_resolver_test.cc new file mode 100644 index 000000000..c7fbb2cf8 --- /dev/null +++ b/common/type_spec_resolver_test.cc @@ -0,0 +1,257 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_spec_resolver.h" + +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::TestWithParam; +using ::testing::Values; + +google::protobuf::Arena* GetTestArena() { + static absl::NoDestructor arena; + return &*arena; +} + +TEST(TypeSpecResolverTest, NullTypeSpec) { + TypeSpec spec(NullTypeSpec{}); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsNull()); +} + +TEST(TypeSpecResolverTest, DynTypeSpec) { + TypeSpec spec(DynTypeSpec{}); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsDyn()); +} + +using ConversionTest = testing::TestWithParam>; + +TEST_P(ConversionTest, TestTypeSpecConversion) { + ASSERT_OK_AND_ASSIGN( + auto t, ConvertTypeSpecToType(std::get<0>(GetParam()), GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_EQ(t.kind(), std::get<1>(GetParam())); +} + +INSTANTIATE_TEST_SUITE_P( + TypeSpecResolverTest, ConversionTest, + testing::Values( + std::make_tuple(TypeSpec(PrimitiveType::kBool), TypeKind::kBool), + std::make_tuple(TypeSpec(PrimitiveType::kInt64), TypeKind::kInt), + std::make_tuple(TypeSpec(PrimitiveType::kUint64), TypeKind::kUint), + std::make_tuple(TypeSpec(PrimitiveType::kDouble), TypeKind::kDouble), + std::make_tuple(TypeSpec(PrimitiveType::kString), TypeKind::kString), + std::make_tuple(TypeSpec(PrimitiveType::kBytes), TypeKind::kBytes), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kAny), TypeKind::kAny), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kTimestamp), + TypeKind::kTimestamp), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kDuration), + TypeKind::kDuration), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)), + TypeKind::kBoolWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + TypeKind::kIntWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)), + TypeKind::kUintWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + TypeKind::kDoubleWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)), + TypeKind::kStringWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)), + TypeKind::kBytesWrapper))); + +TEST(TypeSpecResolverTest, ListTypeConversion) { + auto elem = std::make_unique(PrimitiveType::kInt64); + TypeSpec spec(ListTypeSpec(std::move(elem))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsList()); + EXPECT_TRUE(t->GetList().element().IsInt()); +} + +TEST(TypeSpecResolverTest, MapTypeConversion) { + auto key = std::make_unique(PrimitiveType::kString); + auto val = std::make_unique(PrimitiveType::kBytes); + TypeSpec spec(MapTypeSpec(std::move(key), std::move(val))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMap()); + EXPECT_TRUE(t->GetMap().key().IsString()); + EXPECT_TRUE(t->GetMap().value().IsBytes()); +} + +TEST(TypeSpecResolverTest, FunctionTypeConversion) { + auto result = std::make_unique(PrimitiveType::kBool); + std::vector args; + args.push_back(TypeSpec(PrimitiveType::kString)); + TypeSpec spec(FunctionTypeSpec(std::move(result), std::move(args))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsFunction()); + EXPECT_EQ(t->GetFunction().args().size(), 1); + EXPECT_TRUE(t->GetFunction().result().IsBool()); +} + +TEST(TypeSpecResolverTest, TypeParamConversion) { + TypeSpec spec(ParamTypeSpec("T")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsTypeParam()); + EXPECT_EQ(t->GetTypeParam().name(), "T"); +} + +TEST(TypeSpecResolverTest, MessageTypeConversion) { + TypeSpec spec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", /*params=*/{})); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMessage()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(TypeSpecResolverTest, MessageTypeWithParamsError) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("cel.expr.conformance.proto3.TestAllTypes", + std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(TypeSpecResolverTest, UnresolvedAbstractTypeFallbackToOpaque) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("my.custom.OpaqueType", std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsOpaque()); + EXPECT_EQ(t->name(), "my.custom.OpaqueType"); + EXPECT_EQ(t->GetParameters().size(), 1); + EXPECT_TRUE(t->GetParameters()[0].IsInt()); +} + +TEST(TypeSpecResolverTest, OptionalType) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("optional_type", std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsOpaque()); + EXPECT_EQ(t->name(), "optional_type"); + EXPECT_EQ(t->GetParameters().size(), 1); + EXPECT_TRUE(t->GetParameters()[0].IsInt()); + EXPECT_TRUE(t->IsOptional()); +} + +TEST(TypeSpecResolverTest, TypeTypeConversion) { + auto nested = std::make_unique(PrimitiveType::kInt64); + TypeSpec spec(std::move(nested)); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsType()); + EXPECT_TRUE(t->GetType().GetType().IsInt()); +} + +TEST(TypeSpecResolverTest, ErrorTypeConversion) { + TypeSpec spec(ErrorTypeSpec::kValue); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsError()); +} + +TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { + TypeSpec spec(MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMessage()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(TypeSpecResolverTest, MessageTypeSpecNotFoundError) { + TypeSpec spec(MessageTypeSpec("cel.expr.conformance.proto3.NonExistentType")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("not found in descriptor pool"))); +} + +TEST(TypeSpecResolverTest, EnumTypeConversion) { + TypeSpec spec(AbstractType( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum", /*params=*/{})); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsEnum()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"); +} + +TEST(TypeSpecResolverTest, EnumTypeWithParamsError) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes.NestedEnum", + std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(TypeSpecResolverTest, UnknownTypeSpecKindError) { + TypeSpec spec; + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unknown TypeSpec kind"))); +} + +} // namespace +} // namespace cel From ad18948079b2d3d8b9e62a202889076f872992e7 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 May 2026 22:06:31 -0700 Subject: [PATCH 066/143] No public description PiperOrigin-RevId: 915223448 --- eval/public/ast_rewrite.cc | 2 +- eval/public/ast_traverse.cc | 2 +- eval/public/cel_attribute.cc | 4 ++-- eval/public/equality_function_registrar_test.cc | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/eval/public/ast_rewrite.cc b/eval/public/ast_rewrite.cc index 3c210e607..87c667eb5 100644 --- a/eval/public/ast_rewrite.cc +++ b/eval/public/ast_rewrite.cc @@ -68,7 +68,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/eval/public/ast_traverse.cc b/eval/public/ast_traverse.cc index a86923c67..c18b806b9 100644 --- a/eval/public/ast_traverse.cc +++ b/eval/public/ast_traverse.cc @@ -67,7 +67,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/eval/public/cel_attribute.cc b/eval/public/cel_attribute.cc index 015289bed..70525a04d 100644 --- a/eval/public/cel_attribute.cc +++ b/eval/public/cel_attribute.cc @@ -76,8 +76,8 @@ CelAttributeQualifier CreateCelAttributeQualifier(const CelValue& value) { CelAttributePattern CreateCelAttributePattern( absl::string_view variable, - std::initializer_list> + std::initializer_list> path_spec) { std::vector path; path.reserve(path_spec.size()); diff --git a/eval/public/equality_function_registrar_test.cc b/eval/public/equality_function_registrar_test.cc index 772ddfeba..577c4be22 100644 --- a/eval/public/equality_function_registrar_test.cc +++ b/eval/public/equality_function_registrar_test.cc @@ -86,7 +86,7 @@ MATCHER_P2(DefinesHomogenousOverload, name, argument_type, struct EqualityTestCase { enum class ErrorKind { kMissingOverload, kMissingIdentifier }; absl::string_view expr; - absl::variant result; + std::variant result; CelValue lhs = CelValue::CreateNull(); CelValue rhs = CelValue::CreateNull(); }; From 1f5a7e62900ae2ad1021228df04f2a950744c001 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 May 2026 22:10:11 -0700 Subject: [PATCH 067/143] No public description PiperOrigin-RevId: 915224564 --- common/ast/constant_proto.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/ast/constant_proto.cc b/common/ast/constant_proto.cc index c0fe1c9f6..1982c05b4 100644 --- a/common/ast/constant_proto.cc +++ b/common/ast/constant_proto.cc @@ -35,7 +35,7 @@ using ConstantProto = cel::expr::Constant; absl::Status ConstantToProto(const Constant& constant, ConstantProto* absl_nonnull proto) { return absl::visit(absl::Overload( - [proto](absl::monostate) -> absl::Status { + [proto](std::monostate) -> absl::Status { proto->clear_constant_kind(); return absl::OkStatus(); }, From 6d311f704ade7aea062dd1091dfe3e683938fc78 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 May 2026 22:12:46 -0700 Subject: [PATCH 068/143] No public description PiperOrigin-RevId: 915225296 --- internal/json.cc | 2 +- internal/message_equality.cc | 8 ++++---- internal/well_known_types.cc | 2 +- internal/well_known_types_test.cc | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/json.cc b/internal/json.cc index 630ceb267..cdd4c1a5d 100644 --- a/internal/json.cc +++ b/internal/json.cc @@ -1417,7 +1417,7 @@ class JsonMapIterator final { } private: - absl::variant variant_; + std::variant variant_; }; class JsonAccessor { diff --git a/internal/message_equality.cc b/internal/message_equality.cc index 945cca8df..33ef78089 100644 --- a/internal/message_equality.cc +++ b/internal/message_equality.cc @@ -86,10 +86,10 @@ class EquatableMessage final }; using EquatableValue = - absl::variant; + std::variant; struct NullValueEqualer { bool operator()(std::nullptr_t, std::nullptr_t) const { return true; } diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index dee029534..02e50c3e3 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -2174,7 +2174,7 @@ absl::StatusOr AdaptFromMessage( if (adapted) { return adapted; } - return absl::monostate{}; + return std::monostate{}; } } diff --git a/internal/well_known_types_test.cc b/internal/well_known_types_test.cc index 0d2c9fe33..afc8ce396 100644 --- a/internal/well_known_types_test.cc +++ b/internal/well_known_types_test.cc @@ -806,7 +806,7 @@ TEST_F(AdaptFromMessageTest, Struct) { TEST_F(AdaptFromMessageTest, TestAllTypesProto3) { auto message = DynamicParseTextProto(R"pb()pb"); EXPECT_THAT(AdaptFromMessage(*message), - IsOkAndHolds(VariantWith(absl::monostate()))); + IsOkAndHolds(VariantWith(std::monostate()))); } TEST_F(AdaptFromMessageTest, Any_BoolValue) { From fb51dcdfd1082e67d209c1ba0c84e58b577c378a Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 00:13:04 -0700 Subject: [PATCH 069/143] No public description PiperOrigin-RevId: 915268033 --- runtime/internal/convert_constant.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/internal/convert_constant.cc b/runtime/internal/convert_constant.cc index a9effd229..33f382858 100644 --- a/runtime/internal/convert_constant.cc +++ b/runtime/internal/convert_constant.cc @@ -33,7 +33,7 @@ using ::cel::Constant; struct ConvertVisitor { Allocator<> allocator; - absl::StatusOr operator()(absl::monostate) { + absl::StatusOr operator()(std::monostate) { return absl::InvalidArgumentError("unspecified constant"); } absl::StatusOr operator()(std::nullptr_t) { return NullValue(); } From 877239571674284da3d22bcc7ccfe2e175643de7 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 00:15:52 -0700 Subject: [PATCH 070/143] No public description PiperOrigin-RevId: 915269088 --- common/values/message_value.cc | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/common/values/message_value.cc b/common/values/message_value.cc index e06206407..66dfd9511 100644 --- a/common/values/message_value.cc +++ b/common/values/message_value.cc @@ -46,7 +46,7 @@ const google::protobuf::Descriptor* absl_nonnull MessageValue::GetDescriptor() c ABSL_CHECK(*this); // Crash OK return absl::visit( absl::Overload( - [](absl::monostate) -> const google::protobuf::Descriptor* absl_nonnull { + [](std::monostate) -> const google::protobuf::Descriptor* absl_nonnull { ABSL_UNREACHABLE(); }, [](const ParsedMessageValue& alternative) @@ -58,7 +58,7 @@ const google::protobuf::Descriptor* absl_nonnull MessageValue::GetDescriptor() c std::string MessageValue::DebugString() const { return absl::visit( - absl::Overload([](absl::monostate) -> std::string { return "INVALID"; }, + absl::Overload([](std::monostate) -> std::string { return "INVALID"; }, [](const ParsedMessageValue& alternative) -> std::string { return alternative.DebugString(); }), @@ -68,7 +68,7 @@ std::string MessageValue::DebugString() const { bool MessageValue::IsZeroValue() const { ABSL_DCHECK(*this); return absl::visit( - absl::Overload([](absl::monostate) -> bool { return true; }, + absl::Overload([](std::monostate) -> bool { return true; }, [](const ParsedMessageValue& alternative) -> bool { return alternative.IsZeroValue(); }), @@ -81,7 +81,7 @@ absl::Status MessageValue::SerializeTo( google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ConvertToJson` on " "an invalid `MessageValue`"); @@ -99,7 +99,7 @@ absl::Status MessageValue::ConvertToJson( google::protobuf::Message* absl_nonnull json) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ConvertToJson` on " "an invalid `MessageValue`"); @@ -117,7 +117,7 @@ absl::Status MessageValue::ConvertToJsonObject( google::protobuf::Message* absl_nonnull json) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ConvertToJsonObject` on " "an invalid `MessageValue`"); @@ -136,7 +136,7 @@ absl::Status MessageValue::Equal( google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `Equal` on " "an invalid `MessageValue`"); @@ -155,7 +155,7 @@ absl::Status MessageValue::GetFieldByName( google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `GetFieldByName` on " "an invalid `MessageValue`"); @@ -175,7 +175,7 @@ absl::Status MessageValue::GetFieldByNumber( google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `GetFieldByNumber` on " "an invalid `MessageValue`"); @@ -192,7 +192,7 @@ absl::StatusOr MessageValue::HasFieldByName( absl::string_view name) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::StatusOr { + [](std::monostate) -> absl::StatusOr { return absl::InternalError( "unexpected attempt to invoke `HasFieldByName` on " "an invalid `MessageValue`"); @@ -206,7 +206,7 @@ absl::StatusOr MessageValue::HasFieldByName( absl::StatusOr MessageValue::HasFieldByNumber(int64_t number) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::StatusOr { + [](std::monostate) -> absl::StatusOr { return absl::InternalError( "unexpected attempt to invoke `HasFieldByNumber` on " "an invalid `MessageValue`"); @@ -224,7 +224,7 @@ absl::Status MessageValue::ForEachField( google::protobuf::Arena* absl_nonnull arena) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ForEachField` on " "an invalid `MessageValue`"); @@ -244,7 +244,7 @@ absl::Status MessageValue::Qualify( int* absl_nonnull count) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `Qualify` on " "an invalid `MessageValue`"); From ff45a7c2a096ed1d38e6ed4d80a7180be32874b7 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 03:39:55 -0700 Subject: [PATCH 071/143] No public description PiperOrigin-RevId: 915340723 --- common/ast_rewrite.cc | 2 +- common/ast_traverse.cc | 2 +- common/decl_proto.cc | 2 +- common/decl_proto_test.cc | 4 ++-- common/decl_proto_v1alpha1.cc | 2 +- common/type.cc | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/common/ast_rewrite.cc b/common/ast_rewrite.cc index 14582f44f..b61e1fab6 100644 --- a/common/ast_rewrite.cc +++ b/common/ast_rewrite.cc @@ -54,7 +54,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/common/ast_traverse.cc b/common/ast_traverse.cc index a6ba0d1ba..fb4f9731e 100644 --- a/common/ast_traverse.cc +++ b/common/ast_traverse.cc @@ -53,7 +53,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/common/decl_proto.cc b/common/decl_proto.cc index 89f7f4453..098c5068c 100644 --- a/common/decl_proto.cc +++ b/common/decl_proto.cc @@ -69,7 +69,7 @@ absl::StatusOr FunctionDeclFromProto( return decl; } -absl::StatusOr> DeclFromProto( +absl::StatusOr> DeclFromProto( const cel::expr::Decl& decl, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { diff --git a/common/decl_proto_test.cc b/common/decl_proto_test.cc index 62215f07f..d72d97e09 100644 --- a/common/decl_proto_test.cc +++ b/common/decl_proto_test.cc @@ -49,7 +49,7 @@ TEST_P(DeclFromProtoTest, FromProtoWorks) { cel::expr::Decl decl_pb; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(test_case.proto_decl, &decl_pb)); - absl::StatusOr> decl_or = + absl::StatusOr> decl_or = DeclFromProto(decl_pb, descriptor_pool, &arena); switch (test_case.decl_type) { case DeclType::kVariable: { @@ -79,7 +79,7 @@ TEST_P(DeclFromProtoTest, FromV1Alpha1ProtoWorks) { google::api::expr::v1alpha1::Decl decl_pb; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(test_case.proto_decl, &decl_pb)); - absl::StatusOr> decl_or = + absl::StatusOr> decl_or = DeclFromV1Alpha1Proto(decl_pb, descriptor_pool, &arena); switch (test_case.decl_type) { case DeclType::kVariable: { diff --git a/common/decl_proto_v1alpha1.cc b/common/decl_proto_v1alpha1.cc index 2c6cfb6e4..a8d73e5c2 100644 --- a/common/decl_proto_v1alpha1.cc +++ b/common/decl_proto_v1alpha1.cc @@ -52,7 +52,7 @@ absl::StatusOr FunctionDeclFromV1Alpha1Proto( return FunctionDeclFromProto(name, unversioned, descriptor_pool, arena); } -absl::StatusOr> DeclFromV1Alpha1Proto( +absl::StatusOr> DeclFromV1Alpha1Proto( const google::api::expr::v1alpha1::Decl& decl, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { diff --git a/common/type.cc b/common/type.cc index ce8c7a89a..f94e8bc52 100644 --- a/common/type.cc +++ b/common/type.cc @@ -97,7 +97,7 @@ static constexpr std::array kTypeToKindArray = { TypeKind::kUnknown}; static_assert(kTypeToKindArray.size() == - absl::variant_size(), + std::variant_size(), "Kind indexer must match variant declaration for cel::Type."); } // namespace From 366498bd3820ab8382282ca15279753e4789be31 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 03:46:00 -0700 Subject: [PATCH 072/143] No public description PiperOrigin-RevId: 915342975 --- common/types/struct_type.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/types/struct_type.cc b/common/types/struct_type.cc index 4540cec9c..a1be1f786 100644 --- a/common/types/struct_type.cc +++ b/common/types/struct_type.cc @@ -27,7 +27,7 @@ namespace cel { absl::string_view StructType::name() const { ABSL_DCHECK(*this); return absl::visit( - absl::Overload([](absl::monostate) { return absl::string_view(); }, + absl::Overload([](std::monostate) { return absl::string_view(); }, [](const common_internal::BasicStructType& alt) { return alt.name(); }, @@ -39,7 +39,7 @@ TypeParameters StructType::GetParameters() const { ABSL_DCHECK(*this); return absl::visit( absl::Overload( - [](absl::monostate) { return TypeParameters(); }, + [](std::monostate) { return TypeParameters(); }, [](const common_internal::BasicStructType& alt) { return alt.GetParameters(); }, @@ -49,7 +49,7 @@ TypeParameters StructType::GetParameters() const { std::string StructType::DebugString() const { return absl::visit( - absl::Overload([](absl::monostate) { return std::string(); }, + absl::Overload([](std::monostate) { return std::string(); }, [](common_internal::BasicStructType alt) { return alt.DebugString(); }, @@ -72,7 +72,7 @@ MessageType StructType::GetMessage() const { common_internal::TypeVariant StructType::ToTypeVariant() const { return absl::visit( absl::Overload( - [](absl::monostate) { return common_internal::TypeVariant(); }, + [](std::monostate) { return common_internal::TypeVariant(); }, [](common_internal::BasicStructType alt) { return static_cast(alt) ? common_internal::TypeVariant(alt) : common_internal::TypeVariant(); From 33156b1e59b458ff6c24208dbcb66ace3186ab9b Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 04:28:29 -0700 Subject: [PATCH 073/143] No public description PiperOrigin-RevId: 915359111 --- extensions/select_optimization.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index 0f09773ae..44da4c48a 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -92,7 +92,7 @@ struct SelectInstruction { // Represents a single qualifier in a traversal path. // TODO(uncreated-issue/51): support variable indexes. using QualifierInstruction = - absl::variant; + std::variant; struct SelectPath { Expr* operand; From a88afaca5106943d9f835cb622be9813b6bdee55 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 05:07:40 -0700 Subject: [PATCH 074/143] No public description PiperOrigin-RevId: 915372200 --- runtime/memory_safety_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/memory_safety_test.cc b/runtime/memory_safety_test.cc index 7e864ecf6..2a09be666 100644 --- a/runtime/memory_safety_test.cc +++ b/runtime/memory_safety_test.cc @@ -73,7 +73,7 @@ struct TestCase { std::string name; std::string expression; absl::flat_hash_map> + std::variant> activation; test::ValueMatcher expected_matcher; bool reference_resolver_enabled = false; From 513af3c2c338c0aabbbef21419018880dc9c23c4 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 22:02:12 -0700 Subject: [PATCH 075/143] No public description PiperOrigin-RevId: 915787513 --- eval/internal/cel_value_equal_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval/internal/cel_value_equal_test.cc b/eval/internal/cel_value_equal_test.cc index f52f38916..109a63795 100644 --- a/eval/internal/cel_value_equal_test.cc +++ b/eval/internal/cel_value_equal_test.cc @@ -67,7 +67,7 @@ using ::testing::ValuesIn; struct EqualityTestCase { enum class ErrorKind { kMissingOverload, kMissingIdentifier }; absl::string_view expr; - absl::variant result; + std::variant result; CelValue lhs = CelValue::CreateNull(); CelValue rhs = CelValue::CreateNull(); }; From f62419d04f3a4c12ecf2a802e95d26e33aa2b115 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 22:02:24 -0700 Subject: [PATCH 076/143] No public description PiperOrigin-RevId: 915787600 --- tools/branch_coverage.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/branch_coverage.cc b/tools/branch_coverage.cc index 00ab7cb5a..b5bba3ffe 100644 --- a/tools/branch_coverage.cc +++ b/tools/branch_coverage.cc @@ -71,7 +71,7 @@ struct OtherNode { // Representation for coverage of an AST node. struct CoverageNode { int evaluate_count; - absl::variant kind; + std::variant kind; }; const Type* absl_nullable FindCheckerType(const CheckedExpr& expr, From 7820913cda14e09bbb667c10d65391c1a79fb95d Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 22:16:10 -0700 Subject: [PATCH 077/143] No public description PiperOrigin-RevId: 915792480 --- common/value.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/common/value.cc b/common/value.cc index 535ddead8..1cd3f54e1 100644 --- a/common/value.cc +++ b/common/value.cc @@ -115,7 +115,7 @@ Type Value::GetRuntimeType() const { namespace { template -struct IsMonostate : std::is_same, absl::monostate> {}; +struct IsMonostate : std::is_same, std::monostate> {}; } // namespace @@ -171,7 +171,7 @@ absl::Status Value::ConvertToJsonArray( google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); return variant_.Visit(absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError("use of invalid Value"); }, [descriptor_pool, message_factory, json]( @@ -212,7 +212,7 @@ absl::Status Value::ConvertToJsonObject( google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); return variant_.Visit(absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError("use of invalid Value"); }, [descriptor_pool, message_factory, json]( @@ -1363,7 +1363,7 @@ Value Value::FromMessage( return absl::visit( absl::Overload(OwningWellKnownTypesValueVisitor{ /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { auto* cloned = message.New(arena); cloned->CopyFrom(message); return ParsedMessageValue(cloned, arena); @@ -1391,7 +1391,7 @@ Value Value::FromMessage( return absl::visit( absl::Overload(OwningWellKnownTypesValueVisitor{ /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { auto* cloned = message.New(arena); cloned->GetReflection()->Swap(cloned, &message); return ParsedMessageValue(cloned, arena); @@ -1422,7 +1422,7 @@ Value Value::WrapMessage( absl::Overload(BorrowingWellKnownTypesValueVisitor{ /* .message = */ message, /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { if (message->GetArena() != arena) { auto* cloned = message->New(arena); cloned->CopyFrom(*message); @@ -1456,7 +1456,7 @@ Value Value::WrapMessageUnsafe( absl::Overload(BorrowingWellKnownTypesValueVisitor{ /* .message = */ message, /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { if (message->GetArena() != arena) { return UnsafeParsedMessageValue(message); } From da45d34071c8fe9f77fcc17e33e518841d382cdc Mon Sep 17 00:00:00 2001 From: Antoine Pietri Date: Mon, 18 May 2026 08:43:19 -0700 Subject: [PATCH 078/143] Add missing include for `google/rpc/status.proto.h`. This code was relying on the transitive inclusion of third_party/cel/cpp/* to provide the type information for the Status proto. This makes the code brittle and prone to breakages when doing internal header refactors. PiperOrigin-RevId: 917253701 --- conformance/BUILD | 1 + conformance/service.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/conformance/BUILD b/conformance/BUILD index ccd2844c9..0ca90a4bc 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -83,6 +83,7 @@ cc_library( "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/rpc:code_cc_proto", + "@com_google_googleapis//google/rpc:status_cc_proto", "@com_google_protobuf//:duration_cc_proto", "@com_google_protobuf//:empty_cc_proto", "@com_google_protobuf//:protobuf", diff --git a/conformance/service.cc b/conformance/service.cc index 463334bb5..7e3eded82 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -30,6 +30,7 @@ #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/rpc/code.pb.h" +#include "google/rpc/status.pb.h" #include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/status/status.h" From d475cc6726ef85fefed557c8eb0e400119d13e95 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 20 May 2026 20:28:14 -0700 Subject: [PATCH 079/143] No public description PiperOrigin-RevId: 918791547 --- codelab/network_functions.cc | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/codelab/network_functions.cc b/codelab/network_functions.cc index f4f729827..64f199cb3 100644 --- a/codelab/network_functions.cc +++ b/codelab/network_functions.cc @@ -213,8 +213,7 @@ absl::Status NetworkAddressRepEqual( return absl::OkStatus(); } const NetworkAddressRep rep = content.To(); - absl::optional other_rep = - NetworkAddressRep::Unwrap(other); + std::optional other_rep = NetworkAddressRep::Unwrap(other); ABSL_DCHECK(other_rep.has_value()); *result = cel::BoolValue(rep.IsEqualTo(*other_rep)); return absl::OkStatus(); @@ -311,7 +310,7 @@ cel::Value parseAddress( google::protobuf::Arena* absl_nonnull arena) { std::string buf; absl::string_view addr = str.ToStringView(&buf); - absl::optional rep = NetworkAddressRep::Parse(addr); + std::optional rep = NetworkAddressRep::Parse(addr); if (!rep.has_value()) { return cel::ErrorValue(absl::InvalidArgumentError("invalid address")); } @@ -321,7 +320,7 @@ cel::Value parseAddress( cel::Value parseAddressOrZero(const cel::StringValue& str) { std::string buf; absl::string_view addr = str.ToStringView(&buf); - absl::optional rep = NetworkAddressRep::Parse(addr); + std::optional rep = NetworkAddressRep::Parse(addr); static const NetworkAddressRep kZero; if (!rep.has_value()) { return NetworkAddressRep::MakeValue(kZero); @@ -336,8 +335,7 @@ cel::Value parseAddressMatcher( google::protobuf::Arena* absl_nonnull arena) { std::string buf; absl::string_view addr = str.ToStringView(&buf); - absl::optional rep = - NetworkAddressMatcher::Parse(addr); + std::optional rep = NetworkAddressMatcher::Parse(addr); if (!rep.has_value()) { return cel::ErrorValue( absl::InvalidArgumentError("invalid address matcher")); @@ -365,7 +363,7 @@ cel::Value NetworkAddressRep::MakeValue(const NetworkAddressRep& rep) { cel::OpaqueValueContent::From(rep)); } -absl::optional NetworkAddressRep::Unwrap( +std::optional NetworkAddressRep::Unwrap( const cel::Value& value) { auto opaque = value.AsOpaque(); if (!opaque.has_value() || @@ -381,7 +379,7 @@ absl::optional NetworkAddressRep::Unwrap( return opaque->content().To(); } -absl::optional NetworkAddressRep::Parse( +std::optional NetworkAddressRep::Parse( absl::string_view str) { uint32_t ipv4 = 0; char ipv6[16]; @@ -418,7 +416,7 @@ bool NetworkAddressRep::IsLessThan(const NetworkAddressRep& other) const { return false; } -absl::optional NetworkAddressMatcher::Parse( +std::optional NetworkAddressMatcher::Parse( absl::string_view str) { // range style addr-addr int dash_pos = str.find('-'); From b7096df80e0d7b6facc2943326a1c04cde0f1d27 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 20 May 2026 20:28:15 -0700 Subject: [PATCH 080/143] No public description PiperOrigin-RevId: 918791557 --- eval/public/equality_function_registrar_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eval/public/equality_function_registrar_test.cc b/eval/public/equality_function_registrar_test.cc index 577c4be22..a77a92734 100644 --- a/eval/public/equality_function_registrar_test.cc +++ b/eval/public/equality_function_registrar_test.cc @@ -204,7 +204,7 @@ std::string CelValueEqualTestName( } TEST_P(CelValueEqualImplTypesTest, Basic) { - absl::optional result = CelValueEqualImpl(lhs(), rhs()); + std::optional result = CelValueEqualImpl(lhs(), rhs()); if (lhs().IsNull() || rhs().IsNull()) { if (lhs().IsNull() && rhs().IsNull()) { @@ -286,7 +286,7 @@ const std::vector& NumericValuesNotEqualExample() { using NumericInequalityTest = testing::TestWithParam; TEST_P(NumericInequalityTest, NumericValues) { NumericInequalityTestCase test_case = GetParam(); - absl::optional result = CelValueEqualImpl(test_case.a, test_case.b); + std::optional result = CelValueEqualImpl(test_case.a, test_case.b); EXPECT_TRUE(result.has_value()); EXPECT_EQ(*result, false); } @@ -299,7 +299,7 @@ INSTANTIATE_TEST_SUITE_P( }); TEST(CelValueEqualImplTest, LossyNumericEquality) { - absl::optional result = CelValueEqualImpl( + std::optional result = CelValueEqualImpl( CelValue::CreateDouble( static_cast(std::numeric_limits::max()) - 1), CelValue::CreateInt64(std::numeric_limits::max())); From 719f3eed5919bc964b30c7e06a77a3a7eeb64953 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 20 May 2026 21:47:09 -0700 Subject: [PATCH 081/143] No public description PiperOrigin-RevId: 918818689 --- eval/tests/benchmark_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval/tests/benchmark_test.cc b/eval/tests/benchmark_test.cc index fc0c39294..f188dc0b7 100644 --- a/eval/tests/benchmark_test.cc +++ b/eval/tests/benchmark_test.cc @@ -317,7 +317,7 @@ BENCHMARK(BM_PolicySymbolic); class RequestMap : public CelMap { public: - absl::optional operator[](CelValue key) const override { + std::optional operator[](CelValue key) const override { if (!key.IsString()) { return {}; } From a552526a3c58346438cec05cbfe3afeb20657ed6 Mon Sep 17 00:00:00 2001 From: Muhammad Askri Date: Thu, 21 May 2026 12:28:31 -0700 Subject: [PATCH 082/143] Add functions to parse type and function signatures into cel types. PiperOrigin-RevId: 919194450 --- common/internal/BUILD | 8 +- common/internal/signature.cc | 390 +++++++++++++++++++++++- common/internal/signature.h | 21 ++ common/internal/signature_test.cc | 489 ++++++++++++++++++++++++++++-- 4 files changed, 889 insertions(+), 19 deletions(-) diff --git a/common/internal/BUILD b/common/internal/BUILD index 10084b685..48a8dfe8b 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -143,14 +143,17 @@ cc_library( srcs = ["signature.cc"], hdrs = ["signature.h"], deps = [ + "//common:ast", "//common:type", "//common:type_kind", + "//common:type_spec_resolver", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", ], ) @@ -159,11 +162,14 @@ cc_test( srcs = ["signature_test.cc"], deps = [ ":signature", + "//common:ast", "//common:type", + "//common:type_kind", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], diff --git a/common/internal/signature.cc b/common/internal/signature.cc index f63049878..5c75225f9 100644 --- a/common/internal/signature.cc +++ b/common/internal/signature.cc @@ -15,20 +15,30 @@ #include "common/internal/signature.h" #include +#include +#include #include #include +#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/types/optional.h" +#include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" +#include "common/type_spec_resolver.h" #include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel::common_internal { +// Signature generator helper functions. namespace { void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) { @@ -58,7 +68,7 @@ absl::Status AppendTypeParameters(std::string* result, const Type& type); // Recursively appends a string representation of the given `type` to `result`. // Type parameters are enclosed in angle brackets and separated by commas. - +// // Grammar: // TypeDesc = NamespaceIdentifier [ "<" TypeList ">" ] ; // NamespaceIdentifier = [ "." ] Identifier { "." Identifier } ; @@ -208,4 +218,382 @@ absl::StatusOr MakeOverloadSignature( return result; } + +// Signature parser helper functions. +namespace { + +std::string StripUnescapedWhitespace(std::string_view str) { + std::string result; + result.reserve(str.size()); + bool escaped = false; + for (char c : str) { + if (escaped) { + result.push_back(c); + escaped = false; + continue; + } + if (c == '\\') { + result.push_back(c); + escaped = true; + continue; + } + if (c == ' ' || c == '\t' || c == '\n' || c == '\r') { + continue; + } + result.push_back(c); + } + return result; +} + +absl::optional ParseBuiltinOrWrapper(std::string_view name_str) { + if (name_str == "null") return TypeSpec(NullTypeSpec()); + if (name_str == "bool") return TypeSpec(PrimitiveType::kBool); + if (name_str == "int") return TypeSpec(PrimitiveType::kInt64); + if (name_str == "uint") return TypeSpec(PrimitiveType::kUint64); + if (name_str == "double") return TypeSpec(PrimitiveType::kDouble); + if (name_str == "string") return TypeSpec(PrimitiveType::kString); + if (name_str == "bytes") return TypeSpec(PrimitiveType::kBytes); + if (name_str == "any" || name_str == "google.protobuf.Any") + return TypeSpec(WellKnownTypeSpec::kAny); + if (name_str == "timestamp" || name_str == "google.protobuf.Timestamp") + return TypeSpec(WellKnownTypeSpec::kTimestamp); + if (name_str == "duration" || name_str == "google.protobuf.Duration") + return TypeSpec(WellKnownTypeSpec::kDuration); + if (name_str == "dyn" || name_str == "google.protobuf.Value") + return TypeSpec(DynTypeSpec()); + + // Handle standard Protobuf well-known wrapper types to preserve + // backward compatibility for users migrating YAML configuration files. + if (name_str == "bool_wrapper" || name_str == "google.protobuf.BoolValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + if (name_str == "int_wrapper" || name_str == "google.protobuf.Int64Value" || + name_str == "google.protobuf.Int32Value") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + if (name_str == "uint_wrapper" || name_str == "google.protobuf.UInt64Value" || + name_str == "google.protobuf.UInt32Value") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + if (name_str == "double_wrapper" || + name_str == "google.protobuf.DoubleValue" || + name_str == "google.protobuf.FloatValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + if (name_str == "string_wrapper" || name_str == "google.protobuf.StringValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + if (name_str == "bytes_wrapper" || name_str == "google.protobuf.BytesValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + + if (name_str == "google.protobuf.ListValue") { + return TypeSpec(ListTypeSpec(std::make_unique(DynTypeSpec()))); + } + if (name_str == "google.protobuf.Struct") { + return TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec()))); + } + + return absl::nullopt; +} + +std::string Unescape(std::string_view str) { + size_t first_escape = str.find('\\'); + if (first_escape == std::string_view::npos) { + return std::string(str); + } + std::string result; + result.reserve(str.size()); + result.append(str.substr(0, first_escape)); + bool escaped = false; + for (size_t i = first_escape; i < str.size(); ++i) { + char c = str[i]; + if (escaped) { + result.push_back(c); + escaped = false; + } else if (c == '\\') { + escaped = true; + } else { + result.push_back(c); + } + } + if (escaped) { + result.push_back('\\'); + } + return result; +} + +class SignatureScanner { + public: + explicit SignatureScanner(std::string_view input, + std::string_view error_prefix = "Invalid signature") + : input_(input), error_prefix_(error_prefix) {} + + absl::StatusOr FindTopLevelChar(char target, bool find_last = false) { + size_t found_idx = std::string_view::npos; + int nesting = 0; + bool escaped = false; + // Scanning str for delimiter boundaries while ensuring + // brackets are balanced and escape backslashes are bypassed. + for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + continue; + } + if (c == target && nesting == 0) { + if (find_last || found_idx == std::string_view::npos) { + found_idx = i; + } + } + if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + return found_idx; + } + + absl::StatusOr> SplitTopLevel(char delimiter) { + std::vector result; + int nesting = 0; + bool escaped = false; + size_t start = 0; + // Scanning str for delimiter while ensuring brackets are balanced and + // escape backslashes are bypassed. + for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + continue; + } + if (c == delimiter && nesting == 0) { + result.push_back(input_.substr(start, i - start)); + start = i + 1; + } + if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + result.push_back(input_.substr(start)); + return result; + } + + private: + std::string_view input_; + std::string_view error_prefix_; +}; + +absl::StatusOr> SplitTypeList( + std::string_view params) { + return SignatureScanner(params, "Invalid type signature").SplitTopLevel(','); +} + +absl::StatusOr ParseTypeSignature(std::string_view signature) { + if (signature.empty()) { + return absl::InvalidArgumentError("Empty type signature"); + } + + if (signature[0] == '~') { + std::string_view param_name = signature.substr(1); + if (param_name.empty()) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid type parameter name"); + } + CEL_ASSIGN_OR_RETURN(size_t less_idx, + SignatureScanner(param_name) + .FindTopLevelChar('<', /*find_last=*/false)); + CEL_ASSIGN_OR_RETURN(size_t comma_idx, + SignatureScanner(param_name) + .FindTopLevelChar(',', /*find_last=*/false)); + if (less_idx != std::string_view::npos || + comma_idx != std::string_view::npos) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid type parameter name"); + } + return TypeSpec(ParamTypeSpec(Unescape(param_name))); + } + + CEL_ASSIGN_OR_RETURN(size_t less_idx, + SignatureScanner(signature, "Invalid type signature") + .FindTopLevelChar('<', /*find_last=*/false)); + + std::string name_str; + std::vector params; + + if (less_idx != std::string_view::npos) { + // If the signature contains a '<', it must also contain a matching '>'. + if (signature.back() != '>') { + return absl::InvalidArgumentError( + "Invalid type signature: missing closing >"); + } + name_str = Unescape(signature.substr(0, less_idx)); + std::string_view params_str = + signature.substr(less_idx + 1, signature.size() - less_idx - 2); + CEL_ASSIGN_OR_RETURN(auto param_list, SplitTypeList(params_str)); + for (std::string_view param_str : param_list) { + CEL_ASSIGN_OR_RETURN(auto param, ParseTypeSignature(param_str)); + params.push_back(std::move(param)); + } + } else { + name_str = Unescape(signature); + } + + auto read_param_or_dyn = [¶ms](size_t index) { + auto spec = std::make_unique(DynTypeSpec()); + if (params.size() > index) { + *spec = std::move(params[index]); + } + return spec; + }; + + if (!params.empty()) { + if (ParseBuiltinOrWrapper(name_str).has_value()) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid type signature: ", name_str, + " cannot have type parameters")); + } + } else { + if (auto builtin = ParseBuiltinOrWrapper(name_str); builtin.has_value()) { + return *builtin; + } + } + + if (name_str == "type") { + if (params.size() > 1) { + return absl::InvalidArgumentError( + "Invalid type signature: type expects at most 1 parameter"); + } + return TypeSpec(read_param_or_dyn(0)); + } + + if (name_str == "list") { + if (params.size() > 1) { + return absl::InvalidArgumentError( + "Invalid type signature: list expects at most 1 parameter"); + } + return TypeSpec(ListTypeSpec(read_param_or_dyn(0))); + } + + if (name_str == "map") { + if (!params.empty() && params.size() != 2) { + return absl::InvalidArgumentError( + "Invalid type signature: map expects 0 or 2 parameters"); + } + auto key = read_param_or_dyn(0); + auto value = read_param_or_dyn(1); + return TypeSpec(MapTypeSpec(std::move(key), std::move(value))); + } + + if (name_str == "function") { + auto result_type = read_param_or_dyn(0); + std::vector arg_types; + for (size_t i = 1; i < params.size(); ++i) { + arg_types.push_back(std::move(params[i])); + } + return TypeSpec( + FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + } + + if (name_str.empty() || absl::StrContains(name_str, "..")) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid identifier"); + } + + return TypeSpec(AbstractType(name_str, std::move(params))); +} + +} // namespace + +absl::StatusOr ParseFunctionSignature( + std::string_view signature) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + if (stripped_sig.empty()) { + return absl::InvalidArgumentError("Empty function signature"); + } + + CEL_ASSIGN_OR_RETURN( + size_t paren_idx, + SignatureScanner(stripped_sig, "Invalid function signature") + .FindTopLevelChar('(', /*find_last=*/false)); + + if (paren_idx == std::string_view::npos || stripped_sig.back() != ')') { + return absl::InvalidArgumentError("Invalid function signature"); + } + + std::string_view prefix = std::string_view(stripped_sig).substr(0, paren_idx); + std::string_view args_str = + std::string_view(stripped_sig) + .substr(paren_idx + 1, stripped_sig.size() - paren_idx - 2); + + std::vector arg_types; + ParsedFunctionOverload out; + + CEL_ASSIGN_OR_RETURN(size_t dot_idx, + SignatureScanner(prefix, "Invalid function signature") + .FindTopLevelChar('.', /*find_last=*/true)); + + if (dot_idx != std::string_view::npos) { + out.is_member = true; + std::string_view receiver_str = prefix.substr(0, dot_idx); + std::string_view func_str = prefix.substr(dot_idx + 1); + + CEL_ASSIGN_OR_RETURN(auto receiver_param, ParseTypeSignature(receiver_str)); + arg_types.push_back(std::move(receiver_param)); + out.function_name = Unescape(func_str); + } else { + out.is_member = false; + out.function_name = Unescape(prefix); + } + + if (out.function_name.empty()) { + return absl::InvalidArgumentError( + "Invalid function signature: empty function name"); + } + + if (!args_str.empty()) { + CEL_ASSIGN_OR_RETURN(auto arg_list, SplitTypeList(args_str)); + for (std::string_view arg_str : arg_list) { + CEL_ASSIGN_OR_RETURN(auto arg_param, ParseTypeSignature(arg_str)); + arg_types.push_back(std::move(arg_param)); + } + } + + auto result_type = std::make_unique(DynTypeSpec()); + out.signature_type = + TypeSpec(FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + + return out; +} + +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSignature(stripped_sig)); + return cel::ConvertTypeSpecToType(type_spec, arena, pool); +} + } // namespace cel::common_internal diff --git a/common/internal/signature.h b/common/internal/signature.h index 3f31d8fd1..3fdba4b2e 100644 --- a/common/internal/signature.h +++ b/common/internal/signature.h @@ -20,7 +20,10 @@ #include #include "absl/status/statusor.h" +#include "common/ast.h" #include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel::common_internal { @@ -56,6 +59,24 @@ absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member); +// Parses a string type signature directly into a `cel::Type`. +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool); + +// A parsed function overload signature with the function name, flag for member +// function, and the function signature type. +struct ParsedFunctionOverload { + std::string function_name; + bool is_member = false; + // The function signature type, configured as a `FunctionTypeSpec`. + TypeSpec signature_type; +}; + +// Parses a string function overload signature directly into a +// `cel::TypeSpec` configured as a `FunctionTypeSpec`. +absl::StatusOr ParseFunctionSignature( + std::string_view signature); + } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ diff --git a/common/internal/signature_test.cc b/common/internal/signature_test.cc index 8e41c70fb..765055f75 100644 --- a/common/internal/signature_test.cc +++ b/common/internal/signature_test.cc @@ -13,13 +13,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include "absl/base/no_destructor.h" #include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "common/ast.h" #include "common/type.h" +#include "common/type_kind.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "google/protobuf/arena.h" @@ -38,6 +42,101 @@ google::protobuf::Arena* GetTestArena() { return &*arena; } +void VerifyParsedMatchesType(const TypeSpec& parsed, const Type& original) { + switch (original.kind()) { + case TypeKind::kDyn: + EXPECT_TRUE(parsed.has_dyn()); + break; + case TypeKind::kNull: + EXPECT_TRUE(parsed.has_null()); + break; + case TypeKind::kBool: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kBool); + break; + case TypeKind::kInt: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kInt64); + break; + case TypeKind::kUint: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kUint64); + break; + case TypeKind::kDouble: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kDouble); + break; + case TypeKind::kString: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kString); + break; + case TypeKind::kBytes: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kBytes); + break; + case TypeKind::kAny: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kAny); + break; + case TypeKind::kTimestamp: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kTimestamp); + break; + case TypeKind::kDuration: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kDuration); + break; + case TypeKind::kList: + EXPECT_TRUE(parsed.has_list_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.list_type().elem_type(), + original.GetParameters()[0]); + } + break; + case TypeKind::kMap: + EXPECT_TRUE(parsed.has_map_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.map_type().key_type(), + original.GetParameters()[0]); + } + if (original.GetParameters().size() > 1) { + VerifyParsedMatchesType(parsed.map_type().value_type(), + original.GetParameters()[1]); + } + break; + case TypeKind::kBoolWrapper: + case TypeKind::kIntWrapper: + case TypeKind::kUintWrapper: + case TypeKind::kDoubleWrapper: + case TypeKind::kStringWrapper: + case TypeKind::kBytesWrapper: + EXPECT_TRUE(parsed.has_wrapper()); + break; + case TypeKind::kType: + EXPECT_TRUE(parsed.has_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.type(), original.GetParameters()[0]); + } + break; + case TypeKind::kTypeParam: + EXPECT_TRUE(parsed.has_type_param()); + break; + default: + EXPECT_TRUE(parsed.has_abstract_type()); + break; + } +} + +void VerifyTypesEqual(const Type& lhs, const Type& rhs) { + EXPECT_EQ(lhs.kind(), rhs.kind()); + if (lhs.kind() != rhs.kind()) return; + + if (lhs.kind() == TypeKind::kOpaque || lhs.kind() == TypeKind::kStruct || + lhs.kind() == TypeKind::kTypeParam) { + EXPECT_EQ(lhs.name(), rhs.name()); + } + + const auto& lhs_params = lhs.GetParameters(); + const auto& rhs_params = rhs.GetParameters(); + EXPECT_EQ(lhs_params.size(), rhs_params.size()); + if (lhs_params.size() == rhs_params.size()) { + for (size_t i = 0; i < lhs_params.size(); ++i) { + VerifyTypesEqual(lhs_params[i], rhs_params[i]); + } + } +} + struct TypeSignatureTestCase { Type type; std::string expected_signature; @@ -73,10 +172,18 @@ std::vector GetTypeSignatureTestCases() { .type = ListType(GetTestArena(), StringType{}), .expected_signature = "list", }, + { + .type = TypeType(GetTestArena(), IntType{}), + .expected_signature = "type", + }, { .type = ListType(GetTestArena(), TypeParamType("A")), .expected_signature = "list<~A>", }, + { + .type = ListType(GetTestArena(), TypeParamType("A GetTypeSignatureTestCases() { .expected_signature = "map<~B,~C>", }, { - .type = OpaqueType( - GetTestArena(), "bar", - {FunctionType(GetTestArena(), TypeParamType("D"), {})}), - .expected_signature = "bar>", + .type = OpaqueType(GetTestArena(), "bar", + {FunctionType(GetTestArena(), TypeParamType("D"), + {StringType{}, BoolType{}})}), + .expected_signature = "bar>", }, { .type = AnyType{}, @@ -104,10 +211,18 @@ std::vector GetTypeSignatureTestCases() { .type = TimestampType{}, .expected_signature = "timestamp", }, + { + .type = BoolWrapperType{}, + .expected_signature = "bool_wrapper", + }, { .type = IntWrapperType{}, .expected_signature = "int_wrapper", }, + { + .type = UintWrapperType{}, + .expected_signature = "uint_wrapper", + }, { .type = MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( "cel.expr.conformance.proto3.TestAllTypes")), @@ -117,22 +232,32 @@ std::vector GetTypeSignatureTestCases() { .type = ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)")), .expected_signature = R"(list<~a\,b\.\\.\(d\)\\e>)", }, - { - .type = UnknownType{}, - .expected_error = - "Type kind: *unknown* is not supported in CEL declarations", - }, - { - .type = ErrorType{}, - .expected_error = - "Type kind: *error* is not supported in CEL declarations", - }, }; } +TEST(TypeSignatureTest, UnsupportedTypes) { + EXPECT_THAT(common_internal::MakeTypeSignature(UnknownType{}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Type kind: *unknown* is not supported"))); + + EXPECT_THAT(common_internal::MakeTypeSignature(ErrorType{}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Type kind: *error* is not supported"))); +} + INSTANTIATE_TEST_SUITE_P(TypeIdTest, TypeSignatureTest, ValuesIn(GetTypeSignatureTestCases())); +TEST_P(TypeSignatureTest, ParseTypeCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty() && param.expected_error.empty()) { + auto parsed = ParseType(param.expected_signature, GetTestArena(), + *GetTestingDescriptorPool()); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + VerifyTypesEqual(*parsed, param.type); + } +} + struct OverloadSignatureTestCase { std::string function_name = "hello"; std::vector args; @@ -202,10 +327,18 @@ std::vector GetOverloadSignatureTestCases() { .args = {TimestampType{}}, .expected_signature = "hello(timestamp)", }, + { + .args = {BoolWrapperType{}}, + .expected_signature = "hello(bool_wrapper)", + }, { .args = {IntWrapperType{}}, .expected_signature = "hello(int_wrapper)", }, + { + .args = {UintWrapperType{}}, + .expected_signature = "hello(uint_wrapper)", + }, { .args = {MessageType( GetTestingDescriptorPool()->FindMessageTypeByName( @@ -213,9 +346,6 @@ std::vector GetOverloadSignatureTestCases() { .expected_signature = "hello(cel.expr.conformance.proto3.TestAllTypes)", }, - {.args = {}, - .is_member = true, - .expected_error = "Member function with no receiver"}, { .args = {StringType{}}, .is_member = true, @@ -231,6 +361,18 @@ std::vector GetOverloadSignatureTestCases() { .is_member = true, .expected_signature = "string.hello(bool,dyn)", }, + { + .function_name = "hello", + .args = {OpaqueType(GetTestArena(), "bar", + {TypeParamType("dummy.type")})}, + .is_member = true, + .expected_signature = R"(bar<~dummy\.type>.hello())", + }, + { + .function_name = "inspect", + .args = {Type(TypeType(GetTestArena(), StringType{}))}, + .expected_signature = "inspect(type)", + }, { .function_name = R"(h.(e),l\o)", .args = {StringType{}, @@ -242,8 +384,321 @@ std::vector GetOverloadSignatureTestCases() { }; } +TEST(OverloadSignatureTest, MemberFunctionNoReceiverError) { + auto signature = common_internal::MakeOverloadSignature("hello", {}, true); + EXPECT_THAT(signature, + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Member function with no receiver"))); +} + INSTANTIATE_TEST_SUITE_P(OverloadIdTest, OverloadSignatureTest, ValuesIn(GetOverloadSignatureTestCases())); +TEST_P(OverloadSignatureTest, ExhaustiveFunctionParseCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty()) { + auto parsed = ParseFunctionSignature(param.expected_signature); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + EXPECT_EQ(parsed->function_name, param.function_name); + EXPECT_EQ(parsed->is_member, param.is_member); + EXPECT_TRUE(parsed->signature_type.has_function()); + const auto& func = parsed->signature_type.function(); + for (size_t i = 0; i < param.args.size(); ++i) { + VerifyParsedMatchesType(func.arg_types()[i], param.args[i]); + } + } +} + +TEST(ParseSignatureTest, ProtoParsing) { + ASSERT_OK_AND_ASSIGN( + auto t1, ParseType("int", GetTestArena(), *GetTestingDescriptorPool())); + EXPECT_TRUE(t1.IsInt()); + + ASSERT_OK_AND_ASSIGN(auto t2, ParseType("list<~A>", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(t2.IsList()); + + ASSERT_OK_AND_ASSIGN(auto t3, ParseType(R"(~abc\)", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(t3.IsTypeParam()); + EXPECT_EQ(t3.GetTypeParam().name(), R"(abc\)"); + + ASSERT_OK_AND_ASSIGN(auto w1, + ParseType("google.protobuf.BoolValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w1.IsBoolWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w2, + ParseType("google.protobuf.Int64Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w2.IsIntWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w3, + ParseType("google.protobuf.Int32Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w3.IsIntWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w4, + ParseType("google.protobuf.UInt64Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w4.IsUintWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w5, + ParseType("google.protobuf.UInt32Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w5.IsUintWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w6, + ParseType("google.protobuf.DoubleValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w6.IsDoubleWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w7, + ParseType("google.protobuf.FloatValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w7.IsDoubleWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w8, + ParseType("google.protobuf.StringValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w8.IsStringWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w9, + ParseType("google.protobuf.BytesValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w9.IsBytesWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w10, ParseType("string_wrapper", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w10.IsStringWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w11, ParseType("bytes_wrapper", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w11.IsBytesWrapper()); + + ASSERT_OK_AND_ASSIGN(auto gp_any, + ParseType("google.protobuf.Any", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_any.IsAny()); + + ASSERT_OK_AND_ASSIGN(auto gp_timestamp, + ParseType("google.protobuf.Timestamp", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_timestamp.IsTimestamp()); + + ASSERT_OK_AND_ASSIGN(auto gp_duration, + ParseType("google.protobuf.Duration", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_duration.IsDuration()); + + ASSERT_OK_AND_ASSIGN(auto gp_value, + ParseType("google.protobuf.Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_value.IsDyn()); + + ASSERT_OK_AND_ASSIGN(auto gp_list_value, + ParseType("google.protobuf.ListValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_list_value.IsList()); + + ASSERT_OK_AND_ASSIGN(auto gp_struct, + ParseType("google.protobuf.Struct", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_struct.IsMap()); + + // Legal whitespace handling tests + ASSERT_OK_AND_ASSIGN(auto ws_type1, + ParseType("map < int , string > ", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(ws_type1.IsMap()); + + ASSERT_OK_AND_ASSIGN(auto ws_type2, + ParseType("map\t<\nint\r,\tstring\n>\r", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(ws_type2.IsMap()); +} + +TEST(ParseSignatureTest, FunctionParsing) { + ASSERT_OK_AND_ASSIGN(auto f1, ParseFunctionSignature("hello(string)")); + EXPECT_TRUE(f1.signature_type.has_function()); + EXPECT_EQ(f1.signature_type.function().arg_types().size(), 1); + + // Legal whitespace handling tests + ASSERT_OK_AND_ASSIGN(auto ws_func1, + ParseFunctionSignature(" hello ( string ) ")); + EXPECT_TRUE(ws_func1.signature_type.has_function()); + EXPECT_EQ(ws_func1.signature_type.function().arg_types().size(), 1); + + ASSERT_OK_AND_ASSIGN(auto ws_func2, + ParseFunctionSignature("\thello\n(\rstring\t)\n\r")); + EXPECT_TRUE(ws_func2.signature_type.has_function()); + EXPECT_EQ(ws_func2.signature_type.function().arg_types().size(), 1); + + ASSERT_OK_AND_ASSIGN(auto f2, ParseFunctionSignature("a.b.c()")); + EXPECT_TRUE(f2.is_member); + EXPECT_EQ(f2.function_name, "c"); +} + +TEST(ParseSignatureTest, ParsingErrors) { + // Mismatched template brackets and parentheses. + EXPECT_THAT( + ParseType("list>", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT( + ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseType("list><", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("hello(list>)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("hello(list)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("foo", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("list expects at most 1 parameter"))); + EXPECT_THAT( + ParseType("map", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + EXPECT_THAT(ParseType("map", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + + // Enforcing valid function and identifier names. + EXPECT_THAT(ParseFunctionSignature("()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + EXPECT_THAT(ParseFunctionSignature("string.()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + + // Missing closing operators and boundary checks. + EXPECT_THAT( + ParseType("listfoo", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("missing closing >"))); + + EXPECT_THAT(ParseFunctionSignature("hello>(string)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT( + ParseType("list<", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("map", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("map int, string>", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("list", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + + EXPECT_THAT(ParseFunctionSignature("a..b.c()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + EXPECT_THAT( + ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); + + EXPECT_THAT( + ParseType("~list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + + // Checks that builtin types cannot have type parameters. + EXPECT_THAT( + ParseType("int", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(ParseSignatureTest, MessageTypeWithParamsError) { + EXPECT_THAT(ParseType("cel.expr.conformance.proto3.TestAllTypes", + GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(ParseSignatureTest, MissingClosingParenthesisError) { + EXPECT_THAT(ParseFunctionSignature("hello(string"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); + EXPECT_THAT(ParseFunctionSignature("hello)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); +} + +TEST(ParseSignatureTest, NestedDotsNonMember) { + auto f1 = ParseFunctionSignature( + "my_opaque()"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_FALSE(f1->is_member); + EXPECT_EQ(f1->function_name, + "my_opaque"); +} + +TEST(ParseSignatureTest, OverlyComplexSignatures) { + auto t1 = ParseType("map>,map>>", + GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t1, ::absl_testing::IsOk()); + EXPECT_TRUE(t1->IsMap()); + + auto t2 = ParseType(R"(~abc\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t2, ::absl_testing::IsOk()); + EXPECT_TRUE(t2->IsTypeParam()); + EXPECT_EQ(t2->GetTypeParam().name(), R"(abc\)"); + + auto t3 = + ParseType(R"(~abc\\\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t3, ::absl_testing::IsOk()); + EXPECT_TRUE(t3->IsTypeParam()); + EXPECT_EQ(t3->GetTypeParam().name(), R"(abc\\)"); + + auto f1 = ParseFunctionSignature( + "bar>,map>.func(string)"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_TRUE(f1->is_member); + EXPECT_EQ(f1->function_name, "func"); + EXPECT_TRUE(f1->signature_type.has_function()); + EXPECT_EQ(f1->signature_type.function().arg_types().size(), 2); +} + +TEST(ParseSignatureTest, EmptyOrWhitespaceErrors) { + EXPECT_THAT(ParseType("", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); + EXPECT_THAT(ParseFunctionSignature(""), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty function signature"))); + EXPECT_THAT(ParseType("list>", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); +} + } // namespace } // namespace cel::common_internal From 6fd7030b954562c5e5c1c1185066e80cad29fd25 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 21 May 2026 18:17:46 -0700 Subject: [PATCH 083/143] not yet exported PiperOrigin-RevId: 919358445 --- common/expr_factory.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/expr_factory.h b/common/expr_factory.h index 773217ad9..5607d8deb 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -32,6 +32,7 @@ namespace cel { class MacroExprFactory; class ParserMacroExprFactory; +class OptimizerExprFactory; class ExprFactory { protected: @@ -378,6 +379,7 @@ class ExprFactory { private: friend class MacroExprFactory; friend class ParserMacroExprFactory; + friend class OptimizerExprFactory; ExprFactory() : accu_var_(kAccumulatorVariableName) {} From ec82288de1338c6d7763fd722d52c3636965ca1e Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 21 May 2026 20:41:46 -0700 Subject: [PATCH 084/143] No public description PiperOrigin-RevId: 919408845 --- .../descriptor_pool_type_introspector.cc | 12 +++--- .../descriptor_pool_type_introspector_test.cc | 4 +- checker/internal/type_check_env.cc | 10 ++--- checker/internal/type_checker_builder_impl.cc | 6 +-- .../type_checker_builder_impl_test.cc | 2 +- checker/internal/type_checker_impl.cc | 20 +++++----- checker/internal/type_inference_context.cc | 8 ++-- .../internal/type_inference_context_test.cc | 40 +++++++++---------- 8 files changed, 51 insertions(+), 51 deletions(-) diff --git a/checker/internal/descriptor_pool_type_introspector.cc b/checker/internal/descriptor_pool_type_introspector.cc index f6001e947..da4f4430b 100644 --- a/checker/internal/descriptor_pool_type_introspector.cc +++ b/checker/internal/descriptor_pool_type_introspector.cc @@ -35,7 +35,7 @@ namespace { // Standard implementation for field lookups. // Avoids building a FieldTable and just checks the DescriptorPool directly. -absl::StatusOr> +absl::StatusOr> FindStructTypeFieldByNameDirectly( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, absl::string_view type, absl::string_view name) { @@ -60,7 +60,7 @@ FindStructTypeFieldByNameDirectly( // Standard implementation for listing fields. // Avoids building a FieldTable and just checks the DescriptorPool directly. absl::StatusOr< - absl::optional>> + std::optional>> ListStructTypeFieldsDirectly( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, absl::string_view type) { @@ -88,7 +88,7 @@ ListStructTypeFieldsDirectly( using Field = DescriptorPoolTypeIntrospector::Field; -absl::StatusOr> +absl::StatusOr> DescriptorPoolTypeIntrospector::FindTypeImpl(absl::string_view name) const { const google::protobuf::Descriptor* absl_nullable descriptor = descriptor_pool_->FindMessageTypeByName(name); @@ -103,7 +103,7 @@ DescriptorPoolTypeIntrospector::FindTypeImpl(absl::string_view name) const { return absl::nullopt; } -absl::StatusOr> +absl::StatusOr> DescriptorPoolTypeIntrospector::FindEnumConstantImpl( absl::string_view type, absl::string_view value) const { const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = @@ -124,7 +124,7 @@ DescriptorPoolTypeIntrospector::FindEnumConstantImpl( return absl::nullopt; } -absl::StatusOr> +absl::StatusOr> DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( absl::string_view type, absl::string_view name) const { if (!use_json_name_) { @@ -151,7 +151,7 @@ DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( } absl::StatusOr< - absl::optional>> + std::optional>> DescriptorPoolTypeIntrospector::ListFieldsForStructTypeImpl( absl::string_view type) const { if (!use_json_name_) { diff --git a/checker/internal/descriptor_pool_type_introspector_test.cc b/checker/internal/descriptor_pool_type_introspector_test.cc index e2fdc9d40..456798744 100644 --- a/checker/internal/descriptor_pool_type_introspector_test.cc +++ b/checker/internal/descriptor_pool_type_introspector_test.cc @@ -117,7 +117,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, internal::GetTestingDescriptorPool()); introspector.set_use_json_name(true); - absl::StatusOr> field = + absl::StatusOr> field = introspector.FindStructTypeFieldByName( "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); @@ -132,7 +132,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructType) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); absl::StatusOr< - absl::optional>> + std::optional>> fields = introspector.ListFieldsForStructType( "cel.expr.conformance.proto3.TestAllTypes"); ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(260)))); diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index c080326cb..763d9ba46 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -48,7 +48,7 @@ const FunctionDecl* absl_nullable TypeCheckEnv::LookupFunction( return nullptr; } -absl::StatusOr> TypeCheckEnv::LookupTypeName( +absl::StatusOr> TypeCheckEnv::LookupTypeName( absl::string_view name) const { for (auto iter = type_providers_.begin(); iter != type_providers_.end(); ++iter) { @@ -60,7 +60,7 @@ absl::StatusOr> TypeCheckEnv::LookupTypeName( return absl::nullopt; } -absl::StatusOr> TypeCheckEnv::LookupEnumConstant( +absl::StatusOr> TypeCheckEnv::LookupEnumConstant( absl::string_view type, absl::string_view value) const { for (auto iter = type_providers_.begin(); iter != type_providers_.end(); ++iter) { @@ -77,9 +77,9 @@ absl::StatusOr> TypeCheckEnv::LookupEnumConstant( return absl::nullopt; } -absl::StatusOr> TypeCheckEnv::LookupTypeConstant( +absl::StatusOr> TypeCheckEnv::LookupTypeConstant( google::protobuf::Arena* absl_nonnull arena, absl::string_view name) const { - CEL_ASSIGN_OR_RETURN(absl::optional type, LookupTypeName(name)); + CEL_ASSIGN_OR_RETURN(std::optional type, LookupTypeName(name)); if (type.has_value()) { return MakeVariableDecl(type->name(), TypeType(arena, *type)); } @@ -94,7 +94,7 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( return absl::nullopt; } -absl::StatusOr> TypeCheckEnv::LookupStructField( +absl::StatusOr> TypeCheckEnv::LookupStructField( absl::string_view type_name, absl::string_view field_name) const { // Check the type providers in registration order. // Note: this doesn't allow for shadowing a type with a subset type of the diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 94a05602e..85b581e83 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -158,8 +158,8 @@ absl::StatusOr MergeFunctionDecls( return merged_decl; } -absl::optional FilterDecl(FunctionDecl decl, - const TypeCheckerSubset& subset) { +std::optional FilterDecl(FunctionDecl decl, + const TypeCheckerSubset& subset) { FunctionDecl filtered; std::string name = decl.release_name(); std::vector overloads = decl.release_overloads(); @@ -283,7 +283,7 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( for (FunctionDeclRecord& fn : config.functions) { FunctionDecl decl = std::move(fn.decl); if (subset != nullptr) { - absl::optional filtered = + std::optional filtered = FilterDecl(std::move(decl), *subset); if (!filtered.has_value()) { continue; diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc index f7a3dff97..494e7e440 100644 --- a/checker/internal/type_checker_builder_impl_test.cc +++ b/checker/internal/type_checker_builder_impl_test.cc @@ -144,7 +144,7 @@ TEST(ContextDeclsTest, CustomStructNotSupported) { {}); class MyTypeProvider : public cel::TypeIntrospector { public: - absl::StatusOr> FindTypeImpl( + absl::StatusOr> FindTypeImpl( absl::string_view name) const override { if (name == "com.example.MyStruct") { return common_internal::MakeBasicStructType("com.example.MyStruct"); diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 2472d7def..1ce871255 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -379,7 +379,7 @@ class ResolveVisitor : public AstVisitorBase { // Lookup message type by name to support WellKnownType creation. CEL_ASSIGN_OR_RETURN( - absl::optional field_info, + std::optional field_info, env_->LookupStructField(resolved_name, field.name())); if (!field_info.has_value()) { ReportUndefinedField(field.id(), field.name(), resolved_name); @@ -405,8 +405,8 @@ class ResolveVisitor : public AstVisitorBase { return absl::OkStatus(); } - absl::optional CheckFieldType(int64_t expr_id, const Type& operand_type, - absl::string_view field_name); + std::optional CheckFieldType(int64_t expr_id, const Type& operand_type, + absl::string_view field_name); void HandleOptSelect(const Expr& expr); void HandleBlockIndex(const Expr* expr); @@ -919,7 +919,7 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, arg_types.push_back(GetDeducedType(&expr.call_expr().args()[i])); } - absl::optional resolution = + std::optional resolution = inference_context_->ResolveOverload(decl, arg_types, is_receiver); if (!resolution.has_value()) { @@ -968,7 +968,7 @@ const VariableDecl* absl_nullable ResolveVisitor::LookupGlobalIdentifier( if (const VariableDecl* decl = env_->LookupVariable(name); decl != nullptr) { return decl; } - absl::StatusOr> constant = + absl::StatusOr> constant = env_->LookupTypeConstant(arena_, name); if (!constant.ok()) { @@ -1079,9 +1079,9 @@ void ResolveVisitor::ResolveQualifiedIdentifier( } } -absl::optional ResolveVisitor::CheckFieldType(int64_t id, - const Type& operand_type, - absl::string_view field) { +std::optional ResolveVisitor::CheckFieldType(int64_t id, + const Type& operand_type, + absl::string_view field) { if (operand_type.kind() == TypeKind::kDyn || operand_type.kind() == TypeKind::kAny) { return DynType(); @@ -1137,7 +1137,7 @@ void ResolveVisitor::ResolveSelectOperation(const Expr& expr, const Expr& operand) { const Type& operand_type = GetDeducedType(&operand); - absl::optional result_type; + std::optional result_type; int64_t id = expr.id(); // Support short-hand optional chaining. if (operand_type.IsOptional()) { @@ -1184,7 +1184,7 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) { operand_type = operand_type.GetOptional().GetParameter(); } - absl::optional field_type = CheckFieldType( + std::optional field_type = CheckFieldType( expr.id(), operand_type, field->const_expr().string_value()); if (!field_type.has_value()) { types_[&expr] = ErrorType(); diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 96d985071..5b909d982 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -133,7 +133,7 @@ FunctionOverloadInstance InstantiateFunctionOverload( // Converts a wrapper type to its corresponding primitive type. // Returns nullopt if the type is not a wrapper type. -absl::optional WrapperToPrimitive(const Type& t) { +std::optional WrapperToPrimitive(const Type& t) { switch (t.kind()) { case TypeKind::kBoolWrapper: return BoolType(); @@ -286,7 +286,7 @@ bool TypeInferenceContext::IsAssignableInternal( } // Type is as concrete as it can be under current substitutions. - if (absl::optional wrapped_type = WrapperToPrimitive(to_subs); + if (std::optional wrapped_type = WrapperToPrimitive(to_subs); wrapped_type.has_value()) { return from_subs.IsNull() || IsAssignableInternal(*wrapped_type, from_subs, @@ -531,11 +531,11 @@ bool TypeInferenceContext::IsAssignableWithConstraints( return false; } -absl::optional +std::optional TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, absl::Span argument_types, bool is_receiver) { - absl::optional result_type; + std::optional result_type; std::vector matching_overloads; for (const auto& ovl : decl.overloads()) { diff --git a/checker/internal/type_inference_context_test.cc b/checker/internal/type_inference_context_test.cc index d1bf7fa6d..458d08ff1 100644 --- a/checker/internal/type_inference_context_test.cc +++ b/checker/internal/type_inference_context_test.cc @@ -291,7 +291,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadBasic) { MakeOverloadDecl("add_double", DoubleType(), DoubleType(), DoubleType()))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), IntType()}, false); ASSERT_TRUE(resolution.has_value()); EXPECT_THAT(resolution->result_type, IsTypeKind(TypeKind::kInt)); @@ -309,7 +309,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadFails) { MakeOverloadDecl("add_double", DoubleType(), DoubleType(), DoubleType()))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), DoubleType()}, false); ASSERT_FALSE(resolution.has_value()); } @@ -324,7 +324,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithParamsNoMatch) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), DoubleType()}, false); ASSERT_FALSE(resolution.has_value()); } @@ -341,7 +341,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithMixedParamsMatch) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_a, list_of_a}, false); ASSERT_TRUE(resolution.has_value()) << context.DebugString(); } @@ -359,7 +359,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithMixedParamsMatch2) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_a, list_of_int}, false); ASSERT_TRUE(resolution.has_value()) << context.DebugString(); EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("equals"))); @@ -375,7 +375,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithParamsMatches) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), IntType()}, false); ASSERT_TRUE(resolution.has_value()); EXPECT_TRUE(resolution->result_type.IsBool()); @@ -394,7 +394,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsMatch) { Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); - absl::optional resolution = + std::optional resolution = context.ResolveOverload( decl, {list_of_a_instance, ListType(&arena, IntType())}, false); ASSERT_TRUE(resolution.has_value()); @@ -407,7 +407,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsMatch) { EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("add_list"))); - absl::optional resolution2 = + std::optional resolution2 = context.ResolveOverload( decl, {ListType(&arena, IntType()), list_of_a_instance}, false); ASSERT_TRUE(resolution2.has_value()); @@ -433,7 +433,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsNoMatch) { Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_a_instance, IntType()}, false); EXPECT_FALSE(resolution.has_value()); } @@ -450,13 +450,13 @@ TEST(TypeInferenceContextTest, InferencesAccumulate) { Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); - absl::optional resolution1 = + std::optional resolution1 = context.ResolveOverload(decl, {list_of_a_instance, list_of_a_instance}, false); ASSERT_TRUE(resolution1.has_value()); EXPECT_TRUE(resolution1->result_type.IsList()); - absl::optional resolution2 = + std::optional resolution2 = context.ResolveOverload( decl, {resolution1->result_type, ListType(&arena, IntType())}, false); ASSERT_TRUE(resolution2.has_value()); @@ -480,7 +480,7 @@ TEST(TypeInferenceContextTest, DebugString) { MakeFunctionDecl("_+_", MakeOverloadDecl("add_list", list_of_a, list_of_a, list_of_a))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_int, list_of_int}, false); ASSERT_TRUE(resolution.has_value()); EXPECT_TRUE(resolution->result_type.IsList()); @@ -517,7 +517,7 @@ class TypeInferenceContextWrapperTypesTest TEST_P(TypeInferenceContextWrapperTypesTest, ResolvePrimitiveArg) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), test_case.wrapper_type, test_case.wrapped_primitive_type}, @@ -534,7 +534,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, ResolvePrimitiveArg) { TEST_P(TypeInferenceContextWrapperTypesTest, ResolveWrapperArg) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload( ternary_decl_, {BoolType(), test_case.wrapper_type, test_case.wrapper_type}, false); @@ -550,7 +550,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, ResolveWrapperArg) { TEST_P(TypeInferenceContextWrapperTypesTest, ResolveNullArg) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), test_case.wrapper_type, NullType()}, false); @@ -566,7 +566,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, ResolveNullArg) { TEST_P(TypeInferenceContextWrapperTypesTest, NullWidens) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), NullType(), test_case.wrapper_type}, false); @@ -582,7 +582,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, NullWidens) { TEST_P(TypeInferenceContextWrapperTypesTest, PrimitiveWidens) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), test_case.wrapped_primitive_type, test_case.wrapper_type}, @@ -622,7 +622,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithUnionTypePromotion) { /*result_type=*/TypeParamType("A"), BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {BoolType(), NullType(), IntWrapperType()}, false); ASSERT_TRUE(resolution.has_value()); @@ -648,7 +648,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithTypeType) { TypeType(&arena, TypeParamType("A")), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {StringType()}, false); ASSERT_TRUE(resolution.has_value()); @@ -680,7 +680,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithInferredTypeType) { BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(to_type_decl, {StringType()}, false); ASSERT_TRUE(resolution.has_value()); From 833bb0c8fe93dc9bf5e2971c38254f80d3a42c1f Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 21 May 2026 20:46:26 -0700 Subject: [PATCH 085/143] No public description PiperOrigin-RevId: 919410392 --- testutil/test_macros.cc | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/testutil/test_macros.cc b/testutil/test_macros.cc index 158135762..672439dc5 100644 --- a/testutil/test_macros.cc +++ b/testutil/test_macros.cc @@ -37,9 +37,8 @@ bool IsCelNamespace(const Expr& target) { return target.has_ident_expr() && target.ident_expr().name() == "cel"; } -absl::optional CelBlockMacroExpander(MacroExprFactory& factory, - Expr& target, - absl::Span args) { +std::optional CelBlockMacroExpander(MacroExprFactory& factory, + Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { return absl::nullopt; } @@ -51,9 +50,8 @@ absl::optional CelBlockMacroExpander(MacroExprFactory& factory, return factory.NewCall("cel.@block", args); } -absl::optional CelIndexMacroExpander(MacroExprFactory& factory, - Expr& target, - absl::Span args) { +std::optional CelIndexMacroExpander(MacroExprFactory& factory, + Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { return absl::nullopt; } @@ -70,9 +68,9 @@ absl::optional CelIndexMacroExpander(MacroExprFactory& factory, return factory.NewIdent(absl::StrCat("@index", index)); } -absl::optional CelIterVarMacroExpander(MacroExprFactory& factory, - Expr& target, - absl::Span args) { +std::optional CelIterVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { if (!IsCelNamespace(target)) { return absl::nullopt; } @@ -94,9 +92,9 @@ absl::optional CelIterVarMacroExpander(MacroExprFactory& factory, unique_arg.const_expr().int_value())); } -absl::optional CelAccuVarMacroExpander(MacroExprFactory& factory, - Expr& target, - absl::Span args) { +std::optional CelAccuVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { if (!IsCelNamespace(target)) { return absl::nullopt; } From e00189c323d42b0cacafc320aa890eb4d630d394 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Sat, 23 May 2026 02:26:10 -0700 Subject: [PATCH 086/143] No public description PiperOrigin-RevId: 920098804 --- runtime/activation_test.cc | 4 ++-- runtime/function_registry.cc | 5 ++--- runtime/function_registry_test.cc | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/runtime/activation_test.cc b/runtime/activation_test.cc index e6a74f027..4303116a3 100644 --- a/runtime/activation_test.cc +++ b/runtime/activation_test.cc @@ -326,7 +326,7 @@ TEST_F(ActivationTest, MoveAssignment) { "val_provided", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) - -> absl::StatusOr> { return IntValue(42); })); + -> absl::StatusOr> { return IntValue(42); })); moved_from.SetUnknownPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), @@ -377,7 +377,7 @@ TEST_F(ActivationTest, MoveCtor) { "val_provided", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) - -> absl::StatusOr> { return IntValue(42); })); + -> absl::StatusOr> { return IntValue(42); })); moved_from.SetUnknownPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), diff --git a/runtime/function_registry.cc b/runtime/function_registry.cc index ac1e53eb5..59f267255 100644 --- a/runtime/function_registry.cc +++ b/runtime/function_registry.cc @@ -44,14 +44,13 @@ class ActivationFunctionProviderImpl public: ActivationFunctionProviderImpl() = default; - absl::StatusOr> GetFunction( + absl::StatusOr> GetFunction( const cel::FunctionDescriptor& descriptor, const cel::ActivationInterface& activation) const override { std::vector overloads = activation.FindFunctionOverloads(descriptor.name()); - absl::optional matching_overload = - absl::nullopt; + std::optional matching_overload = absl::nullopt; for (const auto& overload : overloads) { if (overload.descriptor.ShapeMatches(descriptor)) { diff --git a/runtime/function_registry_test.cc b/runtime/function_registry_test.cc index af7f5bc06..53916777a 100644 --- a/runtime/function_registry_test.cc +++ b/runtime/function_registry_test.cc @@ -120,7 +120,7 @@ TEST(FunctionRegistryTest, DefaultLazyProviderNoOverloadFound) { ASSERT_THAT(providers, SizeIs(1)); const FunctionProvider& provider = providers[0].provider; ASSERT_OK_AND_ASSIGN( - absl::optional func, + std::optional func, provider.GetFunction({"LazyFunc", false, {cel::Kind::kInt64}}, activation)); @@ -146,7 +146,7 @@ TEST(FunctionRegistryTest, DefaultLazyProviderReturnsImpl) { ASSERT_THAT(providers, SizeIs(1)); const FunctionProvider& provider = providers[0].provider; ASSERT_OK_AND_ASSIGN( - absl::optional func, + std::optional func, provider.GetFunction( FunctionDescriptor("LazyFunction", false, {Kind::kInt}), activation)); From 267f4de9814c320e61b20cdd8fbe6580f2e57ac3 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Sat, 23 May 2026 02:51:01 -0700 Subject: [PATCH 087/143] No public description PiperOrigin-RevId: 920105999 --- extensions/select_optimization.cc | 16 ++++++++-------- extensions/select_optimization_test.cc | 5 +++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index 44da4c48a..42cad0f92 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -153,7 +153,7 @@ Expr MakeSelectPathExpr( // Returns a single select operation based on the inferred type of the operand // and the field name. If the operand type doesn't define the field, returns // nullopt. -absl::optional GetSelectInstruction( +std::optional GetSelectInstruction( const StructType& runtime_type, PlannerContext& planner_context, absl::string_view field_name) { auto field_or = planner_context.type_reflector() @@ -407,13 +407,13 @@ class RewriterImpl : public AstRewriterBase { // support message traversal. const TypeSpec checker_type = ast_.GetTypeOrDyn(operand.id()); - absl::optional rt_type = + std::optional rt_type = (checker_type.has_message_type()) ? GetRuntimeType(checker_type.message_type().type()) : absl::nullopt; if (rt_type.has_value() && (*rt_type).Is()) { const StructType& runtime_type = rt_type->GetStruct(); - absl::optional field_or = + std::optional field_or = GetSelectInstruction(runtime_type, planner_context_, field_name); if (field_or.has_value()) { candidates_[&expr] = std::move(field_or).value(); @@ -538,7 +538,7 @@ class RewriterImpl : public AstRewriterBase { return candidates_.find(operand) != candidates_.end(); } - absl::optional GetRuntimeType(absl::string_view type_name) { + std::optional GetRuntimeType(absl::string_view type_name) { return planner_context_.type_reflector().FindType(type_name).value_or( absl::nullopt); } @@ -582,14 +582,14 @@ class OptimizedSelectImpl { AttributeTrail GetAttributeTrail(const AttributeTrail& operand_trail) const; - absl::optional attribute() const { return attribute_; } + std::optional attribute() const { return attribute_; } const std::vector& qualifiers() const { return qualifiers_; } private: - absl::optional attribute_; + std::optional attribute_; std::vector select_path_; std::vector qualifiers_; bool presence_test_; @@ -597,7 +597,7 @@ class OptimizedSelectImpl { }; // Check for unknowns or missing attributes. -absl::StatusOr> CheckForMarkedAttributes( +absl::StatusOr> CheckForMarkedAttributes( ExecutionFrameBase& frame, const AttributeTrail& attribute_trail) { if (attribute_trail.empty()) { return absl::nullopt; @@ -715,7 +715,7 @@ absl::Status StackMachineImpl::Evaluate(ExecutionFrame* frame) const { // select arguments. // TODO(uncreated-issue/51): add support variable qualifiers attribute_trail = GetAttributeTrail(frame); - CEL_ASSIGN_OR_RETURN(absl::optional value, + CEL_ASSIGN_OR_RETURN(std::optional value, CheckForMarkedAttributes(*frame, attribute_trail)); if (value.has_value()) { frame->value_stack().Pop(kStackInputs); diff --git a/extensions/select_optimization_test.cc b/extensions/select_optimization_test.cc index c07f4c6ad..9d4024098 100644 --- a/extensions/select_optimization_test.cc +++ b/extensions/select_optimization_test.cc @@ -254,8 +254,9 @@ class MockAccessApis : public LegacyTypeInfoApis, public LegacyTypeAccessApis { return nullptr; } - absl::optional FindFieldByName( - absl::string_view field_name) const override { + std::optional< + google::api::expr::runtime::LegacyTypeInfoApis::FieldDescription> + FindFieldByName(absl::string_view field_name) const override { return absl::nullopt; } From b5db1b3dffb7b890f1a05110ad5833ebe0ebdbee Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Sat, 23 May 2026 23:10:55 -0700 Subject: [PATCH 088/143] No public description PiperOrigin-RevId: 920403526 --- eval/compiler/flat_expr_builder.cc | 10 +++++----- eval/compiler/flat_expr_builder_extensions.cc | 2 +- eval/compiler/qualified_reference_resolver.cc | 10 +++++----- eval/compiler/regex_precompilation_optimization.cc | 6 +++--- eval/compiler/resolver.cc | 6 +++--- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index e38c912c0..8558c7007 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -840,7 +840,7 @@ class FlatExprVisitor : public cel::AstVisitor { // Attempt to resolve a select expression as a namespaced identifier for an // enum or type constant value. - absl::optional const_value; + std::optional const_value; int64_t select_root_id = -1; std::string path_candidate; @@ -1080,7 +1080,7 @@ class FlatExprVisitor : public cel::AstVisitor { // Returns the maximum recursion depth of the current program if it is // eligible for recursion, or nullopt if it is not. - absl::optional RecursionEligible() { + std::optional RecursionEligible() { if (!PlanRecursiveProgram() || program_builder_.current() == nullptr) { return absl::nullopt; } @@ -1525,7 +1525,7 @@ class FlatExprVisitor : public cel::AstVisitor { } } } - if (absl::optional depth = RecursionEligible(); depth.has_value()) { + if (std::optional depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != list_expr.elements().size()) { SetProgressStatusError(absl::InternalError( @@ -1855,7 +1855,7 @@ class FlatExprVisitor : public cel::AstVisitor { int64_t expr_id) { absl::string_view ast_name = create_struct_expr.name(); - absl::optional> type; + std::optional> type; CEL_ASSIGN_OR_RETURN(type, resolver_.FindType(ast_name, expr_id)); if (!type.has_value()) { @@ -1932,7 +1932,7 @@ class FlatExprVisitor : public cel::AstVisitor { IndexManager index_manager_; bool enable_optional_types_; - absl::optional block_; + std::optional block_; int max_recursion_depth_ = 0; }; diff --git a/eval/compiler/flat_expr_builder_extensions.cc b/eval/compiler/flat_expr_builder_extensions.cc index 463b48425..e51b64023 100644 --- a/eval/compiler/flat_expr_builder_extensions.cc +++ b/eval/compiler/flat_expr_builder_extensions.cc @@ -98,7 +98,7 @@ size_t Subexpression::ComputeSize() const { return size; } -absl::optional Subexpression::RecursiveDependencyDepth() const { +std::optional Subexpression::RecursiveDependencyDepth() const { auto* tree = absl::get_if(&program_); int depth = 0; if (tree == nullptr) { diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 67f86ebb6..67c14d9b2 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -81,9 +81,9 @@ bool OverloadExists(const Resolver& resolver, absl::string_view name, // Return the qualified name of the most qualified matching overload, or // nullopt if no matches are found. -absl::optional BestOverloadMatch(const Resolver& resolver, - absl::string_view base_name, - int argument_count) { +std::optional BestOverloadMatch(const Resolver& resolver, + absl::string_view base_name, + int argument_count) { if (IsSpecialFunction(base_name)) { return std::string(base_name); } @@ -262,8 +262,8 @@ class ReferenceResolver : public cel::AstRewriterBase { // Convert a select expr sub tree into a namespace name if possible. // If any operand of the top element is a not a select or an ident node, // return nullopt. - absl::optional ToNamespace(const Expr& expr) { - absl::optional maybe_parent_namespace; + std::optional ToNamespace(const Expr& expr) { + std::optional maybe_parent_namespace; if (rewritten_reference_.find(expr.id()) != rewritten_reference_.end()) { // The target expr matches a reference (resolved to an ident decl). // This should not be treated as a function qualifier. diff --git a/eval/compiler/regex_precompilation_optimization.cc b/eval/compiler/regex_precompilation_optimization.cc index b94cae383..455796131 100644 --- a/eval/compiler/regex_precompilation_optimization.cc +++ b/eval/compiler/regex_precompilation_optimization.cc @@ -145,7 +145,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { // Try to check if the regex is valid, whether or not we can actually update // the plan. - absl::optional pattern = + std::optional pattern = GetConstantString(context, subexpression, node, pattern_expr); if (!pattern.has_value()) { return absl::OkStatus(); @@ -168,7 +168,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { } private: - absl::optional GetConstantString( + std::optional GetConstantString( PlannerContext& context, ProgramBuilder::Subexpression* absl_nullable subexpression, const Expr& call_expr, const Expr& re_expr) const { @@ -180,7 +180,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { // Already modified, can't recover the input pattern. return absl::nullopt; } - absl::optional constant; + std::optional constant; if (subexpression->IsRecursive()) { const auto& program = subexpression->recursive_program(); auto deps = program.step->GetDependencies(); diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index 4e3fa3841..17f60eaad 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -102,8 +102,8 @@ absl::Span Resolver::GetPrefixesFor( return namespace_prefixes_; } -absl::optional Resolver::FindConstant(absl::string_view name, - int64_t expr_id) const { +std::optional Resolver::FindConstant(absl::string_view name, + int64_t expr_id) const { auto prefixes = GetPrefixesFor(name); for (const auto& prefix : prefixes) { std::string qualified_name = absl::StrCat(prefix, name); @@ -205,7 +205,7 @@ std::vector Resolver::FindLazyOverloads( return funcs; } -absl::StatusOr>> +absl::StatusOr>> Resolver::FindType(absl::string_view name, int64_t expr_id) const { auto prefixes = GetPrefixesFor(name); for (auto& prefix : prefixes) { From 9911e079ac31ac087e954339170fdd4bc05d3206 Mon Sep 17 00:00:00 2001 From: Muhammad Askri Date: Tue, 26 May 2026 11:21:49 -0700 Subject: [PATCH 089/143] Map well-known protobuf types to their CEL equivalents in LegacyRuntimeType. PiperOrigin-RevId: 921579763 --- common/BUILD | 1 - common/type.cc | 55 ++++++++++++++++++++++++++++++++++++++------- common/type_test.cc | 34 ++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 9 deletions(-) diff --git a/common/BUILD b/common/BUILD index ffc4ae1e9..a016d2cb5 100644 --- a/common/BUILD +++ b/common/BUILD @@ -583,7 +583,6 @@ cc_library( "//internal:string_pool", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/common/type.cc b/common/type.cc index f94e8bc52..684c5ba09 100644 --- a/common/type.cc +++ b/common/type.cc @@ -640,8 +640,6 @@ constexpr absl::string_view kUInt64TypeName = "uint"; constexpr absl::string_view kDoubleTypeName = "double"; constexpr absl::string_view kStringTypeName = "string"; constexpr absl::string_view kBytesTypeName = "bytes"; -constexpr absl::string_view kDurationTypeName = "google.protobuf.Duration"; -constexpr absl::string_view kTimestampTypeName = "google.protobuf.Timestamp"; constexpr absl::string_view kListTypeName = "list"; constexpr absl::string_view kMapTypeName = "map"; constexpr absl::string_view kCelTypeTypeName = "type"; @@ -670,12 +668,6 @@ Type LegacyRuntimeType(absl::string_view name) { if (name == kBytesTypeName) { return BytesType{}; } - if (name == kDurationTypeName) { - return DurationType{}; - } - if (name == kTimestampTypeName) { - return TimestampType{}; - } if (name == kListTypeName) { return ListType{}; } @@ -685,6 +677,53 @@ Type LegacyRuntimeType(absl::string_view name) { if (name == kCelTypeTypeName) { return TypeType{}; } + if (cel::IsWellKnownMessageType(name)) { + if (name == "google.protobuf.Any") { + return AnyType(); + } + if (name == "google.protobuf.BoolValue") { + return BoolWrapperType(); + } + if (name == "google.protobuf.BytesValue") { + return BytesWrapperType(); + } + if (name == "google.protobuf.DoubleValue") { + return DoubleWrapperType(); + } + if (name == "google.protobuf.Duration") { + return DurationType(); + } + if (name == "google.protobuf.FloatValue") { + return DoubleWrapperType(); + } + if (name == "google.protobuf.Int32Value") { + return IntWrapperType(); + } + if (name == "google.protobuf.Int64Value") { + return IntWrapperType(); + } + if (name == "google.protobuf.ListValue") { + return ListType(); + } + if (name == "google.protobuf.StringValue") { + return StringWrapperType(); + } + if (name == "google.protobuf.Struct") { + return JsonMapType(); + } + if (name == "google.protobuf.Timestamp") { + return TimestampType(); + } + if (name == "google.protobuf.UInt32Value") { + return UintWrapperType(); + } + if (name == "google.protobuf.UInt64Value") { + return UintWrapperType(); + } + if (name == "google.protobuf.Value") { + return DynType(); + } + } return common_internal::MakeBasicStructType(name); } diff --git a/common/type_test.cc b/common/type_test.cc index 2cebf27ba..d6a613c3c 100644 --- a/common/type_test.cc +++ b/common/type_test.cc @@ -638,5 +638,39 @@ TEST(Type, Wrap) { EXPECT_EQ(Type(AnyType()).Wrap(), AnyType()); } +TEST(Type, LegacyRuntimeType) { + EXPECT_EQ(common_internal::LegacyRuntimeType("bool"), BoolType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Any"), + AnyType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.BoolValue"), + BoolWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.BytesValue"), + BytesWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.DoubleValue"), + DoubleWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Duration"), + DurationType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.FloatValue"), + DoubleWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Int32Value"), + IntWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Int64Value"), + IntWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.ListValue"), + ListType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.StringValue"), + StringWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Struct"), + JsonMapType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Timestamp"), + TimestampType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.UInt32Value"), + UintWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.UInt64Value"), + UintWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Value"), + DynType()); +} + } // namespace } // namespace cel From e8f6c48fd01acc29dc67b8b0d2d6ce68ea36c94c Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Tue, 26 May 2026 15:05:42 -0700 Subject: [PATCH 090/143] Repo move announcement PiperOrigin-RevId: 921705653 --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 23afe2b00..41b44388d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,12 @@ # C++ Implementations of the Common Expression Language +> [!WARNING] +> **On June 16, 2026, this repository will move to +> github.com/cel-expr/cel-cpp!** +> +> Please update your links and dependencies. See the [pinned +> issue](https://github.com/google/cel-cpp/issues/2029) for details. + For background on the Common Expression Language see the [cel-spec][1] repo. This is a C++ implementation of a [Common Expression Language][1] runtime, From 2f0944e9f25cbe23b202c280f48f0e47f638668c Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 27 May 2026 10:42:49 -0700 Subject: [PATCH 091/143] Remove constraint that block initializer is non-empty. PiperOrigin-RevId: 922222081 --- eval/compiler/flat_expr_builder.cc | 10 ++++------ eval/compiler/flat_expr_builder_test.cc | 7 +++---- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 8558c7007..1e3f4ecd3 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -1047,11 +1047,7 @@ class FlatExprVisitor : public cel::AstVisitor { } const auto& list_expr = call_expr.args().front().list_expr(); block.size = list_expr.elements().size(); - if (block.size == 0) { - SetProgressStatusError(absl::InvalidArgumentError( - "malformed cel.@block: list of bound expressions is empty")); - return; - } + block.bindings_set.reserve(block.size); for (const auto& list_expr_element : list_expr.elements()) { if (list_expr_element.optional()) { @@ -2052,7 +2048,9 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock( } // Otherwise, iterative plan. - AddStep(CreateClearSlotsStep(block.index, block.slot_count, expr.id())); + if (block.slot_count > 0) { + AddStep(CreateClearSlotsStep(block.index, block.slot_count, expr.id())); + } return CallHandlerResult::kIntercepted; } diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 5fc20f01e..d84007485 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -2820,6 +2820,7 @@ TEST(FlatExprBuilderTest, BlockNotListOfBoundExpressions) { TEST(FlatExprBuilderTest, BlockEmptyListOfBoundExpressions) { ParsedExpr parsed_expr; + // Allowed, but degenerate case. ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr: { @@ -2835,10 +2836,8 @@ TEST(FlatExprBuilderTest, BlockEmptyListOfBoundExpressions) { CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), - StatusIs( - absl::StatusCode::kInvalidArgument, - HasSubstr( - "malformed cel.@block: list of bound expressions is empty"))); + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid @index greater than number of bindings:"))); } TEST(FlatExprBuilderTest, BlockOptionalListOfBoundExpressions) { From 9fb8e1068781eeb988cf5fcd33aaa05a0c8356b6 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 27 May 2026 12:56:46 -0700 Subject: [PATCH 092/143] Return the first match from `ResolveStatic()` PiperOrigin-RevId: 922297232 --- eval/eval/function_step.cc | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index 2a10e9674..fcf429378 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -286,20 +286,12 @@ absl::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { absl::StatusOr ResolveStatic( absl::Span input_args, absl::Span overloads) { - ResolveResult result = absl::nullopt; - for (const auto& overload : overloads) { if (ArgumentKindsMatch(overload.descriptor, input_args)) { - // More than one overload matches our arguments. - if (result.has_value()) { - return absl::Status(absl::StatusCode::kInternal, - "Cannot resolve overloads"); - } - - result.emplace(overload); + return overload; } } - return result; + return absl::nullopt; } absl::StatusOr ResolveLazy( @@ -315,7 +307,7 @@ absl::StatusOr ResolveLazy( input_args.begin(), input_args.end(), arg_types.begin(), [](const cel::Value& value) { return ValueKindToKind(value->kind()); }); - cel::FunctionDescriptor matcher{name, receiver_style, arg_types}; + cel::FunctionDescriptor matcher{name, receiver_style, std::move(arg_types)}; const cel::ActivationInterface& activation = frame.activation(); for (auto provider : providers) { From 83ce6bd890aef1c2ae41d8179490dd735d2549f2 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 28 May 2026 11:15:52 -0700 Subject: [PATCH 093/143] Add benchmark test for field access implementation. Adds a `cc_test` target for benchmarking `field_access_impl.cc`. PiperOrigin-RevId: 922875336 --- eval/public/structs/BUILD | 20 ++ .../field_access_impl_benchmark_test.cc | 239 ++++++++++++++++++ 2 files changed, 259 insertions(+) create mode 100644 eval/public/structs/field_access_impl_benchmark_test.cc diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index d301ff0ca..d722559e3 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -442,3 +442,23 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_test( + name = "field_access_impl_benchmark_test", + srcs = ["field_access_impl_benchmark_test.cc"], + tags = [ + "benchmark", + "manual", + ], + deps = [ + ":cel_proto_wrapper", + ":field_access_impl", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//extensions/protobuf/internal:map_reflection", + "//internal:benchmark", + "//internal:testing", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/eval/public/structs/field_access_impl_benchmark_test.cc b/eval/public/structs/field_access_impl_benchmark_test.cc new file mode 100644 index 000000000..888e424b1 --- /dev/null +++ b/eval/public/structs/field_access_impl_benchmark_test.cc @@ -0,0 +1,239 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/field_access_impl.h" +#include "extensions/protobuf/internal/map_reflection.h" +#include "internal/benchmark.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime::internal { +namespace { + +using ::cel::expr::conformance::proto3::TestAllTypes; + +void BM_CreateValueFromSingleField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.set_single_int64(42); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_int64"); + + for (auto _ : state) { + auto value = CreateValueFromSingleField( + &msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromSingleField_Int64); + +void BM_CreateValueFromSingleField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.set_single_string("hello world"); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_string"); + + for (auto _ : state) { + auto value = CreateValueFromSingleField( + &msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromSingleField_String); + +void BM_CreateValueFromSingleField_Message(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.mutable_standalone_message()->set_bb(123); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + + for (auto _ : state) { + auto value = CreateValueFromSingleField( + &msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromSingleField_Message); + +void BM_CreateValueFromRepeatedField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.add_repeated_int64(42); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_int64"); + + for (auto _ : state) { + auto value = CreateValueFromRepeatedField( + &msg, desc, 0, &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromRepeatedField_Int64); + +void BM_CreateValueFromRepeatedField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.add_repeated_string("hello world"); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string"); + + for (auto _ : state) { + auto value = CreateValueFromRepeatedField( + &msg, desc, 0, &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromRepeatedField_String); + +void BM_CreateValueFromMapValue_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + (*msg.mutable_map_int64_int64())[42] = 100; + const google::protobuf::FieldDescriptor* map_desc = + TestAllTypes::descriptor()->FindFieldByName("map_int64_int64"); + const google::protobuf::FieldDescriptor* value_desc = + map_desc->message_type()->FindFieldByName("value"); + + google::protobuf::ConstMapIterator iter = + cel::extensions::protobuf_internal::ConstMapBegin(*msg.GetReflection(), + msg, *map_desc); + google::protobuf::MapValueConstRef value_ref = iter.GetValueRef(); + + for (auto _ : state) { + auto value = + CreateValueFromMapValue(&msg, value_desc, &value_ref, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromMapValue_Int64); + +void BM_SetValueToSingleField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_int64"); + CelValue val = CelValue::CreateInt64(42); + + for (auto _ : state) { + auto status = SetValueToSingleField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_SetValueToSingleField_Int64); + +void BM_SetValueToSingleField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_string"); + CelValue val = CelValue::CreateStringView("hello world"); + + for (auto _ : state) { + auto status = SetValueToSingleField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_SetValueToSingleField_String); + +void BM_SetValueToSingleField_Message(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + + TestAllTypes::NestedMessage nested_msg; + nested_msg.set_bb(123); + CelValue val = CelProtoWrapper::CreateMessage(&nested_msg, &arena); + + for (auto _ : state) { + auto status = SetValueToSingleField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_SetValueToSingleField_Message); + +void BM_AddValueToRepeatedField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_int64"); + CelValue val = CelValue::CreateInt64(42); + + for (auto _ : state) { + msg.clear_repeated_int64(); + auto status = AddValueToRepeatedField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_AddValueToRepeatedField_Int64); + +void BM_AddValueToRepeatedField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string"); + CelValue val = CelValue::CreateStringView("hello world"); + + for (auto _ : state) { + msg.clear_repeated_string(); + auto status = AddValueToRepeatedField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_AddValueToRepeatedField_String); + +void BM_CreateValueFromRepeatedField_StringPiece(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.add_repeated_string_piece("hello world"); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string_piece"); + + for (auto _ : state) { + auto value = CreateValueFromRepeatedField( + &msg, desc, 0, &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromRepeatedField_StringPiece); + +void BM_AddValueToRepeatedField_StringPiece(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string_piece"); + CelValue val = CelValue::CreateStringView("hello world"); + + for (auto _ : state) { + msg.clear_repeated_string_piece(); + auto status = AddValueToRepeatedField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_AddValueToRepeatedField_StringPiece); + +} // namespace +} // namespace google::api::expr::runtime::internal From 5436842151a69b1dd3e0c4350c7406b4fba36805 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 1 Jun 2026 12:51:26 -0700 Subject: [PATCH 094/143] Update presubmit docker image. Use bazelisk to simplify testing bazel upgrades. PiperOrigin-RevId: 924861509 --- Dockerfile | 45 +++++++++++++++++++++++++++++---------------- cloudbuild.yaml | 4 ++-- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/Dockerfile b/Dockerfile index c2c2915be..97611fc75 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,25 +12,25 @@ # # Run the following command from the root of the CEL repository: # -# gcloud builds submit --region=us -t gcr.io/cel-analysis/gcc9 . +# gcloud builds submit --region=us -t gcr.io/cel-analysis/cel-cpp/ubuntu_floor . # # Once complete get the sha256 digest from the output using the following # command: # -# gcloud artifacts versions list --package=gcc9 --repository=gcr.io \ +# gcloud artifacts versions list --package=cel-cpp/ubuntu_floor --repository=gcr.io \ # --location=us # # The cloudbuild.yaml file must be updated to use the new digest like so: # -# - name: 'gcr.io/cel-analysis/gcc9@' -FROM gcc:9 +# - name: 'gcr.io/cel-analysis/cel-cpp/ubuntu_floor@' +FROM gcr.io/cloud-marketplace/google/ubuntu2204:latest # Install Bazel prerequesites and required tools. # See https://docs.bazel.build/versions/master/install-ubuntu.html -RUN apt-get update && \ - apt-get upgrade -y && \ - apt-get install -y --no-install-recommends \ - ca-certificates \ +RUN apt-get update && apt-get upgrade -y && \ + DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + bash \ + ca-certificates \ git \ libssl-dev \ make \ @@ -41,16 +41,29 @@ RUN apt-get update && \ zip \ zlib1g-dev \ default-jdk-headless \ - clang-11 && \ - apt-get clean + clang-11 \ + gcc-9 g++-9 \ + tzdata \ + && apt-get clean -# Install Bazel. -# https://github.com/bazelbuild/bazel/releases -ARG BAZEL_VERSION="7.3.2" -ADD https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh /tmp/install_bazel.sh -RUN /bin/bash /tmp/install_bazel.sh && rm /tmp/install_bazel.sh +# Install Bazelisk. +# https://github.com/bazelbuild/bazelisk/releases +ARG BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.27.0/bazelisk-amd64.deb" +ARG BAZELISK_CHKSUM="d8b00ea975c823e15263c80200ac42979e17368547fbff4ab177af035badfa83" +ADD ${BAZELISK_URL} /tmp/bazelisk.deb + +ENV BAZELISK_CHKSUM=${BAZELISK_CHKSUM} +RUN echo "${BAZELISK_CHKSUM} */tmp/bazelisk.deb" | sha256sum --check + +RUN apt-get install /tmp/bazelisk.deb RUN mkdir -p /workspace RUN mkdir -p /bazel -ENTRYPOINT ["/usr/local/bin/bazel"] +RUN USE_BAZEL_VERSION=8.7.0 bazelisk help +RUN USE_BAZEL_VERSION=7.3.2 bazelisk help + +ENV CC=gcc-9 +ENV CXX=g++-9 + +ENTRYPOINT ["/usr/bin/bazelisk"] diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 8272378f6..dec359f25 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -1,5 +1,5 @@ steps: -- name: 'gcr.io/cel-analysis/gcc9@sha256:4d5ff2e55224398807235a44b57e9c5793e922ac46e9ff428536bb8f8e5790ce' +- name: 'gcr.io/cel-analysis/cel-cpp/ubuntu_floor@sha256:211a0c505b361d987b3d8b08a5144a84e62cb95edc3f897fe46d5cd3f556f79d' args: - '--output_base=/bazel' # This is mandatory to avoid steps accidently sharing data. - 'test' @@ -16,7 +16,7 @@ steps: - '--google_default_credentials' id: gcc-9 waitFor: ['-'] -- name: 'gcr.io/cel-analysis/gcc9@sha256:4d5ff2e55224398807235a44b57e9c5793e922ac46e9ff428536bb8f8e5790ce' +- name: 'gcr.io/cel-analysis/cel-cpp/ubuntu_floor@sha256:211a0c505b361d987b3d8b08a5144a84e62cb95edc3f897fe46d5cd3f556f79d' env: - 'CC=clang-11' - 'CXX=clang++-11' From f1e5042c97e641457b92dcd4cdee57d1facf96fa Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 1 Jun 2026 14:06:53 -0700 Subject: [PATCH 095/143] Bump bazel version to 8.x PiperOrigin-RevId: 924902706 --- .bazelversion | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.bazelversion b/.bazelversion index eab246c06..df5119ec6 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -7.3.2 +8.7.0 From dfea16a3fefaace88eafbdb294bc8ea376bb1c9b Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 1 Jun 2026 14:38:20 -0700 Subject: [PATCH 096/143] Update deps. Patches to antlr were upstreamed to bcr. PiperOrigin-RevId: 924919565 --- MODULE.bazel | 17 ++++------------- bazel/antlr.patch | 30 ------------------------------ 2 files changed, 4 insertions(+), 43 deletions(-) delete mode 100644 bazel/antlr.patch diff --git a/MODULE.bazel b/MODULE.bazel index 02404b645..43d0485d2 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -33,7 +33,7 @@ bazel_dep( ) bazel_dep( name = "protobuf", - version = "33.4", + version = "34.1", repo_name = "com_google_protobuf", ) bazel_dep( @@ -41,20 +41,16 @@ bazel_dep( version = "20260107.0", repo_name = "com_google_absl", ) - bazel_dep( name = "googletest", version = "1.17.0.bcr.2", - dev_dependency = True, repo_name = "com_google_googletest", ) bazel_dep( name = "google_benchmark", version = "1.9.2", - dev_dependency = True, repo_name = "com_github_google_benchmark", ) - bazel_dep( name = "re2", version = "2025-11-05.bcr.1", @@ -74,16 +70,9 @@ bazel_dep( name = "platforms", version = "1.0.0", ) - -ANTLR4_VERSION = "4.13.2" - bazel_dep( name = "antlr4-cpp-runtime", - version = ANTLR4_VERSION, -) -single_version_override( - module_name = "antlr4-cpp-runtime", - patches = ["//bazel:antlr.patch"], + version = "4.13.2.bcr.2", ) python = use_extension("@rules_python//python/extensions:python.bzl", "python") @@ -95,6 +84,8 @@ python.toolchain( http_jar = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_jar") +ANTLR4_VERSION = "4.13.2" + http_jar( name = "antlr4_jar", sha256 = "eae2dfa119a64327444672aff63e9ec35a20180dc5b8090b7a6ab85125df4d76", diff --git a/bazel/antlr.patch b/bazel/antlr.patch deleted file mode 100644 index c1aa9080c..000000000 --- a/bazel/antlr.patch +++ /dev/null @@ -1,30 +0,0 @@ ---- BUILD.bazel -+++ BUILD.bazel -@@ -17,21 +17,21 @@ - cc_library( - name = "antlr4-cpp-runtime", - srcs = glob(["runtime/src/**/*.cpp"]), - hdrs = ["runtime/src/antlr4-runtime.h"], - copts = ["-fexceptions"], -- defines = ["ANTLR4CPP_USING_ABSEIL"], -+ defines = ["ANTLR4CPP_USING_ABSEIL", "ANTLR4CPP_STATIC"], - features = ["-use_header_modules"], - includes = ["runtime/src"], - textual_hdrs = glob( - ["runtime/src/**/*.h"], - exclude = ["runtime/src/antlr4-runtime.h"], - ), - visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/synchronization", - ], - ) - ---- VERSION -+++ /dev/null -@@ -1,1 +1,0 @@ --4.13.2 \ No newline at end of file From db0725c120fab156cedf32776bba9d35d9406d85 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 2 Jun 2026 13:31:45 -0700 Subject: [PATCH 097/143] Add basic workflow for windows/bazel test (#2039) Add basic workflow for running conformance tests on windows. Only run on post merge or manually for now. Closes #2039 PiperOrigin-RevId: 925547950 --- .bazelrc | 2 + .github/workflows/windows_bazel_test.yml | 28 +++++++++++++ .../windows_bazel_test_post_merge.yml | 13 ++++++ conformance/BUILD | 1 + conformance/run.bzl | 2 +- conformance/run.cc | 7 ++-- internal/BUILD | 11 +++++ internal/runfiles.cc | 40 +++++++++++++++++++ internal/runfiles.h | 30 ++++++++++++++ 9 files changed, 130 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/windows_bazel_test.yml create mode 100644 .github/workflows/windows_bazel_test_post_merge.yml create mode 100644 internal/runfiles.cc create mode 100644 internal/runfiles.h diff --git a/.bazelrc b/.bazelrc index 475706072..1246d336b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -10,6 +10,8 @@ build:linux --copt=-Wno-deprecated-declarations # you will typically need to spell out the compiler for local dev # BAZEL_VC= # BAZEL_VC_FULL_VERSION=14.44.3520 +# Some dependencies rely on bash so you will likely need msys2 +# BAZEL_SH=C:\msys64\usr\bin\bash.exe build:msvc --cxxopt="-std:c++20" --cxxopt="-utf-8" --host_cxxopt="-std:c++20" build:msvc --define=protobuf_allow_msvc=true build:msvc --test_tag_filters=-benchmark,-notap,-no_test_msvc diff --git a/.github/workflows/windows_bazel_test.yml b/.github/workflows/windows_bazel_test.yml new file mode 100644 index 000000000..4ac7f2eec --- /dev/null +++ b/.github/workflows/windows_bazel_test.yml @@ -0,0 +1,28 @@ +name: Windows Bazel Test + +on: + workflow_call: + workflow_dispatch: + +jobs: + test: + name: Run Bazel Tests + runs-on: windows-latest + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Setup Bazel and Bazelisk + uses: bazel-contrib/setup-bazel@0.19.0 + with: + bazelisk-cache: true + disk-cache: ${{ github.workflow }} + repository-cache: true + + - name: Run Tests + # msys2 'bash' on Windows will try to 'fix' the label prefix to + # work as a directory. + # //... won't work. + shell: bash + run: | + bazelisk test --config=msvc conformance:all \ No newline at end of file diff --git a/.github/workflows/windows_bazel_test_post_merge.yml b/.github/workflows/windows_bazel_test_post_merge.yml new file mode 100644 index 000000000..11801011e --- /dev/null +++ b/.github/workflows/windows_bazel_test_post_merge.yml @@ -0,0 +1,13 @@ +name: Windows Bazel Test (Post-Merge) + +on: + push: + branches: + - master + +jobs: + trigger-test: + # This prevents the workflow from running automatically when someone + # pushes to their fork. + if: github.repository == 'google/cel-cpp' + uses: ./.github/workflows/windows_bazel_test.yml \ No newline at end of file diff --git a/conformance/BUILD b/conformance/BUILD index 0ca90a4bc..a6f25e001 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -99,6 +99,7 @@ cc_library( deps = [ ":service", ":utils", + "//internal:runfiles", "//internal:testing_no_main", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log:absl_check", diff --git a/conformance/run.bzl b/conformance/run.bzl index 15850b0aa..2c0b51c0e 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -77,7 +77,7 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard): cc_test( name = _conformance_test_name(name, optimize, recursive), - args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard) + ["$(location " + test + ")" for test in data], + args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard) + ["$(rlocationpath {})".format(test) for test in data], env = select( { "@platforms//os:windows": {"CEL_SKIP_TESTS": ",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS)}, diff --git a/conformance/run.cc b/conformance/run.cc index 80164d9a4..4a0493494 100644 --- a/conformance/run.cc +++ b/conformance/run.cc @@ -48,6 +48,7 @@ #include "absl/types/span.h" #include "conformance/service.h" #include "conformance/utils.h" +#include "internal/runfiles.h" #include "internal/testing.h" #include "cel/expr/conformance/test/simple.pb.h" #include "google/protobuf/io/zero_copy_stream_impl.h" @@ -68,8 +69,6 @@ ABSL_FLAG(bool, select_optimization, false, "Enable select optimization."); namespace { -using ::testing::IsEmpty; - using cel::expr::conformance::test::SimpleTest; using cel::expr::conformance::test::SimpleTestFile; using google::api::expr::conformance::v1alpha1::CheckRequest; @@ -78,6 +77,7 @@ using google::api::expr::conformance::v1alpha1::EvalRequest; using google::api::expr::conformance::v1alpha1::EvalResponse; using google::api::expr::conformance::v1alpha1::ParseRequest; using google::api::expr::conformance::v1alpha1::ParseResponse; +using ::testing::IsEmpty; google::rpc::Code ToGrpcCode(absl::StatusCode code) { return static_cast(code); @@ -282,8 +282,9 @@ int main(int argc, char** argv) { } } for (int argi = 1; argi < argc; argi++) { + std::string path = cel::internal::ResolveRunfilesPath(argv[argi]); ABSL_CHECK_OK(RegisterTestsFromFile(service, tests_to_skip, - absl::string_view(argv[argi]))); + absl::string_view(path))); } } int exit_code = RUN_ALL_TESTS(); diff --git a/internal/BUILD b/internal/BUILD index 3891c635d..0ac5c4e46 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -86,6 +86,17 @@ cc_library( ], ) +cc_library( + name = "runfiles", + srcs = ["runfiles.cc"], + hdrs = ["runfiles.h"], + deps = [ + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@rules_cc//cc/runfiles", + ], +) + cc_library( name = "status_builder", hdrs = ["status_builder.h"], diff --git a/internal/runfiles.cc b/internal/runfiles.cc new file mode 100644 index 000000000..259e2e7ca --- /dev/null +++ b/internal/runfiles.cc @@ -0,0 +1,40 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/runfiles.h" + +#include + +#include "rules_cc/cc/runfiles/runfiles.h" + +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace cel::internal { + +std::string ResolveRunfilesPath(absl::string_view path) { + using ::rules_cc::cc::runfiles::Runfiles; + static Runfiles* runfiles = []() { + std::string error; + auto runfiles = + Runfiles::CreateForTest(BAZEL_CURRENT_REPOSITORY, &error); + ABSL_QCHECK(runfiles != nullptr) + << absl::StrCat("failed to init runfiles", error); + return runfiles; + }(); + return runfiles->Rlocation(std::string(path)); +} + +} // namespace cel::internal diff --git a/internal/runfiles.h b/internal/runfiles.h new file mode 100644 index 000000000..643c677b4 --- /dev/null +++ b/internal/runfiles.h @@ -0,0 +1,30 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ + +#include + +#include "absl/strings/string_view.h" + +namespace cel::internal { + +// Resolves a path relative to the runfiles directory. +// Intended for resolving test cases from cel-spec and cel-policy. +std::string ResolveRunfilesPath(absl::string_view path); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ From e264932db519034136ba38cb7381405b32ee9da6 Mon Sep 17 00:00:00 2001 From: sahvx655-wq Date: Wed, 3 Jun 2026 23:04:15 +0530 Subject: [PATCH 098/143] fix undefined left shift of negative int in math.bitShiftLeft --- extensions/math_ext.cc | 6 +++++- extensions/math_ext_test.cc | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/extensions/math_ext.cc b/extensions/math_ext.cc index 4d133d90c..a7773da19 100644 --- a/extensions/math_ext.cc +++ b/extensions/math_ext.cc @@ -266,7 +266,11 @@ Value BitShiftLeftInt(int64_t lhs, int64_t rhs) { if (rhs > 63) { return IntValue(0); } - return IntValue(lhs << static_cast(rhs)); + // Shift in the unsigned domain to avoid undefined behaviour when lhs is + // negative or the shift moves bits into the sign bit, matching the bit + // pattern semantics already used by bitShiftRight. + return IntValue(absl::bit_cast(absl::bit_cast(lhs) + << static_cast(rhs))); } Value BitShiftLeftUint(uint64_t lhs, int64_t rhs) { diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc index 72605648f..ea9331970 100644 --- a/extensions/math_ext_test.cc +++ b/extensions/math_ext_test.cc @@ -563,6 +563,8 @@ INSTANTIATE_TEST_SUITE_P( {"math.bitNot(2) == -3"}, {"math.bitAnd(math.bitNot(0x3u), 0xFFu) == 0xFCu"}, {"math.bitShiftLeft(1, 1) == 2"}, + {"math.bitShiftLeft(-1, 1) == -2"}, + {"math.bitShiftLeft(-4, 2) == -16"}, {"math.bitShiftLeft(1u, 1) == 2u"}, {"math.bitShiftRight(4, 1) == 2"}, {"math.bitShiftRight(4u, 1) == 2u"}})); From ad137617e146bf2d050747ca3a7fedbcef0e4018 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 3 Jun 2026 14:00:23 -0700 Subject: [PATCH 099/143] Add memory test for managed constant folding arena PiperOrigin-RevId: 926250668 --- eval/tests/memory_safety_test.cc | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/eval/tests/memory_safety_test.cc b/eval/tests/memory_safety_test.cc index 9c0a683e4..a88844fed 100644 --- a/eval/tests/memory_safety_test.cc +++ b/eval/tests/memory_safety_test.cc @@ -51,7 +51,12 @@ struct TestCase { bool reference_resolver_enabled = false; }; -enum Options { kDefault, kExhaustive, kFoldConstants }; +enum Options { + kDefault, + kExhaustive, + kFoldConstants, + kFoldConstantsManagedArena +}; using ParamType = std::tuple; @@ -68,6 +73,9 @@ std::string TestCaseName(const testing::TestParamInfo& param_info) { case Options::kFoldConstants: opt = "opt"; break; + case Options::kFoldConstantsManagedArena: + opt = "opt_managed_arena"; + break; } return absl::StrCat(std::get<0>(param).name, "_", opt); @@ -110,6 +118,14 @@ class EvaluatorMemorySafetyTest : public testing::TestWithParam { options.enable_comprehension_vulnerability_check = false; options.short_circuiting = true; break; + case Options::kFoldConstantsManagedArena: + options.enable_regex_precompilation = true; + options.constant_folding = true; + options.enable_comprehension_list_append = true; + options.enable_comprehension_vulnerability_check = false; + options.short_circuiting = true; + options.constant_arena = nullptr; + break; } options.enable_qualified_identifier_rewrites = @@ -295,7 +311,8 @@ INSTANTIATE_TEST_SUITE_P( test::IsCelBool(true), }}), testing::Values(Options::kDefault, Options::kExhaustive, - Options::kFoldConstants)), + Options::kFoldConstants, + Options::kFoldConstantsManagedArena)), &TestCaseName); } // namespace From f26fd7d9ddf6cace625627e9b6d94d40507e823a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 4 Jun 2026 11:19:41 -0700 Subject: [PATCH 100/143] Fix leak for optional.of(string). PiperOrigin-RevId: 926790246 --- common/values/optional_value.cc | 4 ++-- runtime/BUILD | 1 + runtime/memory_safety_test.cc | 10 ++++++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/common/values/optional_value.cc b/common/values/optional_value.cc index 688cf8fb0..7c214b9cb 100644 --- a/common/values/optional_value.cc +++ b/common/values/optional_value.cc @@ -345,7 +345,7 @@ OpaqueValue GenericOptionalValueClone( cel::Value* absl_nonnull result = ::new (arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) cel::Value(content.To().value->Clone(arena)); - if (!ArenaTraits<>::trivially_destructible(result)) { + if (!ArenaTraits<>::trivially_destructible(*result)) { arena->OwnDestructor(result); } return common_internal::MakeOptionalValue( @@ -395,7 +395,7 @@ OptionalValue OptionalValue::Of(cel::Value value, cel::Value* absl_nonnull result = ::new ( arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) cel::Value(std::move(value)); - if (!ArenaTraits<>::trivially_destructible(result)) { + if (!ArenaTraits<>::trivially_destructible(*result)) { arena->OwnDestructor(result); } return OptionalValue(&optional_value_dispatcher, diff --git a/runtime/BUILD b/runtime/BUILD index 776a8223d..34ff411a1 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -615,6 +615,7 @@ cc_test( ":activation", ":constant_folding", ":function_adapter", + ":optional_types", ":reference_resolver", ":regex_precompilation", ":runtime", diff --git a/runtime/memory_safety_test.cc b/runtime/memory_safety_test.cc index 2a09be666..a60b4ce60 100644 --- a/runtime/memory_safety_test.cc +++ b/runtime/memory_safety_test.cc @@ -45,6 +45,7 @@ #include "runtime/activation.h" #include "runtime/constant_folding.h" #include "runtime/function_adapter.h" +#include "runtime/optional_types.h" #include "runtime/reference_resolver.h" #include "runtime/regex_precompilation.h" #include "runtime/runtime.h" @@ -174,6 +175,7 @@ absl::StatusOr> ConfigureRuntimeImpl( if (resolve_references) { CEL_RETURN_IF_ERROR(EnableReferenceResolver( runtime_builder, ReferenceResolverEnabled::kAlways)); + CEL_RETURN_IF_ERROR(extensions::EnableOptionalTypes(runtime_builder)); } if (evaluation_options == Options::kFoldConstants) { CEL_RETURN_IF_ERROR(extensions::EnableConstantFolding(runtime_builder)); @@ -315,6 +317,14 @@ INSTANTIATE_TEST_SUITE_P( {{"condition", BoolValue(false)}}, test::StringValueIs("long_right_hand_string_0123456789"), }, + {"optional_of_long_const_string", + "condition ? optional.of('lhs_short') : " + "optional.of('long_right_hand_string_0123456789')", + {{"condition", BoolValue(false)}}, + test::OptionalValueIs( + test::StringValueIs("long_right_hand_string_0123456789")), + // optional.of is a namespaced function. + /*enable_reference_resolver=*/true}, { "computed_string", "(condition ? 'a.b' : 'b.c') + '.d.e.f'", From e4ff771e68f58433ca20f4471c4a5c9dd86923eb Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 4 Jun 2026 15:39:42 -0700 Subject: [PATCH 101/143] Add ProtoTypeMaskRegistry, ProtoTypeMask, and FieldPath classes. The ProtoTypeMaskRegistry has functions that can be used to validate the input field masks and to check whether a field is visible. PiperOrigin-RevId: 926925505 --- checker/internal/BUILD | 92 ++++ checker/internal/field_path.cc | 30 ++ checker/internal/field_path.h | 77 ++++ checker/internal/field_path_test.cc | 85 ++++ checker/internal/proto_type_mask.cc | 87 ++++ checker/internal/proto_type_mask.h | 111 +++++ checker/internal/proto_type_mask_registry.cc | 180 ++++++++ checker/internal/proto_type_mask_registry.h | 83 ++++ .../internal/proto_type_mask_registry_test.cc | 402 ++++++++++++++++++ checker/internal/proto_type_mask_test.cc | 143 +++++++ 10 files changed, 1290 insertions(+) create mode 100644 checker/internal/field_path.cc create mode 100644 checker/internal/field_path.h create mode 100644 checker/internal/field_path_test.cc create mode 100644 checker/internal/proto_type_mask.cc create mode 100644 checker/internal/proto_type_mask.h create mode 100644 checker/internal/proto_type_mask_registry.cc create mode 100644 checker/internal/proto_type_mask_registry.h create mode 100644 checker/internal/proto_type_mask_registry_test.cc create mode 100644 checker/internal/proto_type_mask_test.cc diff --git a/checker/internal/BUILD b/checker/internal/BUILD index f4c60f937..25550616a 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -310,3 +310,95 @@ cc_test( "@com_google_absl//absl/types:optional", ], ) + +cc_library( + name = "field_path", + srcs = ["field_path.cc"], + hdrs = ["field_path.h"], + deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "field_path_test", + srcs = ["field_path_test.cc"], + deps = [ + ":field_path", + "//internal:testing", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "proto_type_mask", + srcs = ["proto_type_mask.cc"], + hdrs = ["proto_type_mask.h"], + deps = [ + ":field_path", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_type_mask_test", + srcs = ["proto_type_mask_test.cc"], + deps = [ + ":field_path", + ":proto_type_mask", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "proto_type_mask_registry", + srcs = ["proto_type_mask_registry.cc"], + hdrs = ["proto_type_mask_registry.h"], + deps = [ + ":field_path", + ":proto_type_mask", + "//common:type", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_type_mask_registry_test", + srcs = ["proto_type_mask_registry_test.cc"], + deps = [ + ":proto_type_mask", + ":proto_type_mask_registry", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/checker/internal/field_path.cc b/checker/internal/field_path.cc new file mode 100644 index 000000000..5ecc4219b --- /dev/null +++ b/checker/internal/field_path.cc @@ -0,0 +1,30 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/field_path.h" + +#include + +#include "absl/strings/str_join.h" +#include "absl/strings/substitute.h" + +namespace cel::checker_internal { + +std::string FieldPath::DebugString() const { + return absl::Substitute( + "FieldPath { field path: '$0', field selection: {'$1'} }", path_, + absl::StrJoin(field_selection_, "', '")); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/field_path.h b/checker/internal/field_path.h new file mode 100644 index 000000000..d67d9b935 --- /dev/null +++ b/checker/internal/field_path.h @@ -0,0 +1,77 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ + +#include +#include +#include + +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace cel::checker_internal { + +// Represents a single path within a FieldMask. +class FieldPath { + public: + explicit FieldPath(std::string path) + : path_(std::move(path)), + field_selection_(absl::StrSplit(path_, kPathDelimiter)) {} + + // Returns the input path. + // For example: "f.b.d". + absl::string_view GetPath() const { return path_; } + + // Returns the list of nested field names in the path. + // For example: {"f", "b", "d"}. + absl::Span GetFieldSelection() const { + return field_selection_; + } + + // Returns the first field name in the path. + // For example: "f". + std::string GetFieldName() const { return field_selection_.front(); } + + template + friend void AbslStringify(Sink& sink, const FieldPath& field_path) { + sink.Append(field_path.DebugString()); + } + + private: + static constexpr char kPathDelimiter = '.'; + + std::string DebugString() const; + + // The input path. For example: "f.b.d". + std::string path_; + // The list of nested field names in the path. For example: {"f", "b", "d"}. + std::vector field_selection_; +}; + +inline bool operator==(const FieldPath& lhs, const FieldPath& rhs) { + return lhs.GetFieldSelection() == rhs.GetFieldSelection(); +} + +// Compares the field selections in the field paths. +// This is only intended as an arbitrary ordering for a set. +inline bool operator<(const FieldPath& lhs, const FieldPath& rhs) { + return lhs.GetFieldSelection() < rhs.GetFieldSelection(); +} + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ diff --git a/checker/internal/field_path_test.cc b/checker/internal/field_path_test.cc new file mode 100644 index 000000000..9a1434954 --- /dev/null +++ b/checker/internal/field_path_test.cc @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/field_path.h" + +#include "absl/strings/str_cat.h" +#include "internal/testing.h" + +namespace cel::checker_internal { +namespace { + +using ::testing::ElementsAre; + +TEST(FieldPathTest, EmptyPathReturnsEmptyString) { + FieldPath field_path(""); + EXPECT_EQ(field_path.GetPath(), ""); + EXPECT_THAT(field_path.GetFieldSelection(), ElementsAre("")); + EXPECT_EQ(field_path.GetFieldName(), ""); +} + +TEST(FieldPathTest, DelimiterPathReturnsEmptyStrings) { + FieldPath field_path("."); + EXPECT_EQ(field_path.GetPath(), "."); + EXPECT_THAT(field_path.GetFieldSelection(), ElementsAre("", "")); + EXPECT_EQ(field_path.GetFieldName(), ""); +} + +TEST(FieldPathTest, FieldPathReturnsFields) { + FieldPath field_path("resource.name.other_field"); + EXPECT_EQ(field_path.GetPath(), "resource.name.other_field"); + EXPECT_THAT(field_path.GetFieldSelection(), + ElementsAre("resource", "name", "other_field")); + EXPECT_EQ(field_path.GetFieldName(), "resource"); +} + +TEST(FieldPathTest, AbslStringifyPrintsFieldSelection) { + FieldPath field_path("resource.name"); + EXPECT_EQ(absl::StrCat(field_path), + "FieldPath { field path: 'resource.name', field selection: " + "{'resource', 'name'} }"); +} + +TEST(FieldPathTest, EqualsComparesFieldSelectionAndReturnsTrue) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.name"); + EXPECT_TRUE(field_path_1 == field_path_2); +} + +TEST(FieldPathTest, EqualsComparesFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.type"); + EXPECT_FALSE(field_path_1 == field_path_2); +} + +TEST(FieldPathTest, LessThanComparesFieldSelectionAndReturnsTrue) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.type"); + EXPECT_TRUE(field_path_1 < field_path_2); +} + +TEST(FieldPathTest, LessThanComparesIdenticalFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.name"); + EXPECT_FALSE(field_path_1 < field_path_2); +} + +TEST(FieldPathTest, LessThanComparesFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.type"); + FieldPath field_path_2("resource.name"); + EXPECT_FALSE(field_path_1 < field_path_2); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask.cc b/checker/internal/proto_type_mask.cc new file mode 100644 index 000000000..85e39cb69 --- /dev/null +++ b/checker/internal/proto_type_mask.cc @@ -0,0 +1,87 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "checker/internal/field_path.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +using ::google::protobuf::Descriptor; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::FieldDescriptor; + +absl::StatusOr FindMessage( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type_name) { + const Descriptor* descriptor = + descriptor_pool->FindMessageTypeByName(type_name); + if (descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::Substitute("type '$0' not found", type_name)); + } + return descriptor; +} + +absl::StatusOr FindField(const Descriptor* descriptor, + absl::string_view field_name) { + const FieldDescriptor* field_descriptor = + descriptor->FindFieldByName(field_name); + if (field_descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::Substitute("could not select field '$0' from type '$1'", + field_name, descriptor->full_name())); + } + return field_descriptor; +} + +absl::StatusOr> ProtoTypeMask::GetFieldNames( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) const { + CEL_ASSIGN_OR_RETURN(const Descriptor* descriptor, + FindMessage(descriptor_pool, this->GetTypeName())); + absl::btree_set field_names; + for (const FieldPath& field_path : this->GetFieldPaths()) { + std::string field_name = field_path.GetFieldName(); + CEL_ASSIGN_OR_RETURN(const FieldDescriptor* field_descriptor, + FindField(descriptor, field_name)); + field_names.insert(field_descriptor->name()); + } + return field_names; +} + +std::string ProtoTypeMask::DebugString() const { + // Represent each FieldPath by its path because it is easiest to read. + std::vector paths; + paths.reserve(field_paths_.size()); + for (const FieldPath& field_path : field_paths_) { + paths.emplace_back(field_path.GetPath()); + } + return absl::Substitute( + "ProtoTypeMask { type name: '$0', field paths: { '$1' } }", type_name_, + absl::StrJoin(paths, "', '")); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask.h b/checker/internal/proto_type_mask.h new file mode 100644 index 000000000..f7d522cba --- /dev/null +++ b/checker/internal/proto_type_mask.h @@ -0,0 +1,111 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/internal/field_path.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Returns a descriptor for the input type name. +// Returns an error if the type name is not found. +absl::StatusOr FindMessage( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type_name); + +// Returns a field descriptor for the input field name. +// Returns an error if the field name is not found. +absl::StatusOr FindField( + const google::protobuf::Descriptor* descriptor, absl::string_view field_name); + +// Represents the fraction of a protobuf type's object graph that should be +// visible within CEL expressions. +class ProtoTypeMask { + public: + explicit ProtoTypeMask(std::string type_name, + const std::vector& field_paths) + : type_name_(std::move(type_name)) { + for (const std::string& field_path : field_paths) { + field_paths_.insert(FieldPath(field_path)); + } + } + + // Returns a set of field names. The set includes the first field name from + // each field path. We are able to return a set of absl::string_view because + // the result is backed by the descriptor pool. + absl::StatusOr> GetFieldNames( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) const; + + // Returns the type's full name. + // For example: "google.rpc.context.AttributeContext". + absl::string_view GetTypeName() const { return type_name_; } + + // Returns a representation of the FieldMask, which is a set of field paths. + // For example: + // { + // FieldPath { + // field path: 'resource.name', + // field selection: {'resource', 'name'} + // }, + // FieldPath { + // field path: 'request.auth.claims', + // field selection: {'request', 'auth', 'claims'} + // } + // } + const absl::btree_set& GetFieldPaths() const { + return field_paths_; + } + + template + friend void AbslStringify(Sink& sink, const ProtoTypeMask& proto_type_mask) { + sink.Append(proto_type_mask.DebugString()); + } + + private: + std::string DebugString() const; + + // A type's full name. For example: "google.rpc.context.AttributeContext". + std::string type_name_; + // A representation of a FieldMask, which is a set of field paths. + // For example: + // { + // FieldPath { + // field path: 'resource.name', + // field selection: {'resource', 'name'} + // }, + // FieldPath { + // field path: 'request.auth.claims', + // field selection: {'request', 'auth', 'claims'} + // } + // } + // A FieldMask contains one or more paths which contain identifier characters + // that have been dot delimited, e.g. resource.name, request.auth.claims. + // For each path, all descendent fields after the last element in the path are + // visible. An empty set means all fields are hidden. + absl::btree_set field_paths_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ diff --git a/checker/internal/proto_type_mask_registry.cc b/checker/internal/proto_type_mask_registry.cc new file mode 100644 index 000000000..9c50c9784 --- /dev/null +++ b/checker/internal/proto_type_mask_registry.cc @@ -0,0 +1,180 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask_registry.h" + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "checker/internal/field_path.h" +#include "checker/internal/proto_type_mask.h" +#include "common/type.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { +namespace { + +using ::google::protobuf::Descriptor; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::FieldDescriptor; +using TypeMap = + absl::flat_hash_map>; + +// Returns a message type descriptor for the input field descriptor. +// Returns an error if the field is not a message type. +absl::StatusOr GetMessage( + const FieldDescriptor* field_descriptor) { + cel::MessageTypeField field(field_descriptor); + cel::Type type = field.GetType(); + absl::optional message_type = type.AsMessage(); + if (!message_type.has_value()) { + return absl::InvalidArgumentError(absl::Substitute( + "field '$0' is not a message type", field_descriptor->name())); + } + return &(*message_type.value()); +} + +// Inserts the type name with an empty set into types_and_visible_fields. +// Returns an error if the type name is already present with a non-empty set. +absl::Status AddAllHiddenFields(TypeMap& types_and_visible_fields, + absl::string_view type_name) { + auto result = types_and_visible_fields.find(type_name); + if (result != types_and_visible_fields.end()) { + if (!result->second.empty()) { + return absl::InvalidArgumentError( + absl::Substitute("cannot insert a proto type mask with all hidden " + "fields when type '$0' has already been inserted " + "with a proto type mask with a visible field", + type_name)); + } + return absl::OkStatus(); + } + types_and_visible_fields.insert({std::string(type_name), {}}); + return absl::OkStatus(); +} + +// Inserts the type name and field name into types_and_visible_fields. +// Returns an error if the type name is already present with an empty set. +absl::Status AddVisibleField(TypeMap& types_and_visible_fields, + absl::string_view type_name, + absl::string_view field_name) { + auto result = types_and_visible_fields.find(type_name); + if (result != types_and_visible_fields.end()) { + if (result->second.empty()) { + return absl::InvalidArgumentError(absl::Substitute( + "cannot insert a proto type mask with visible " + "field '$0' when type '$1' has already been inserted " + "with a proto type mask with all hidden fields", + field_name, type_name)); + } + result->second.insert(std::string(field_name)); + return absl::OkStatus(); + } + types_and_visible_fields.insert( + {std::string(type_name), {std::string(field_name)}}); + return absl::OkStatus(); +} + +// Processes the input proto type masks to create and return the +// types_and_visible_fields map. +// Returns an error if one of the proto type masks is not valid. For example, +// if a type is not found in the descriptor pool, if a field name is not +// found, or if a field is not a message type when we are expecting it to be. +// Returns an error if there is a conflict in field visibility when +// updating the map. +absl::StatusOr ComputeVisibleFieldsMap( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks) { + TypeMap types_and_visible_fields; + for (const ProtoTypeMask& proto_type_mask : proto_type_masks) { + absl::string_view type_name = proto_type_mask.GetTypeName(); + CEL_ASSIGN_OR_RETURN(const Descriptor* descriptor, + FindMessage(descriptor_pool, type_name)); + const absl::btree_set& field_paths = + proto_type_mask.GetFieldPaths(); + if (field_paths.empty()) { + CEL_RETURN_IF_ERROR( + AddAllHiddenFields(types_and_visible_fields, type_name)); + } + for (const FieldPath& field_path : field_paths) { + const Descriptor* target_descriptor = descriptor; + absl::Span field_selection = + field_path.GetFieldSelection(); + for (auto iterator = field_selection.begin(); + iterator != field_selection.end(); ++iterator) { + CEL_ASSIGN_OR_RETURN(const FieldDescriptor* field_descriptor, + FindField(target_descriptor, *iterator)); + CEL_RETURN_IF_ERROR(AddVisibleField(types_and_visible_fields, + target_descriptor->full_name(), + *iterator)); + if (std::next(iterator) != field_selection.end()) { + CEL_ASSIGN_OR_RETURN(target_descriptor, GetMessage(field_descriptor)); + } + } + } + } + return types_and_visible_fields; +} + +} // namespace + +absl::StatusOr> +ProtoTypeMaskRegistry::Create( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks) { + CEL_ASSIGN_OR_RETURN( + auto types_and_visible_fields, + ComputeVisibleFieldsMap(descriptor_pool, proto_type_masks)); + std::shared_ptr proto_type_mask_registry = + absl::WrapUnique(new ProtoTypeMaskRegistry(types_and_visible_fields)); + return proto_type_mask_registry; +} + +bool ProtoTypeMaskRegistry::FieldIsVisible(absl::string_view type_name, + absl::string_view field_name) const { + auto iterator = types_and_visible_fields_.find(type_name); + if (iterator != types_and_visible_fields_.end() && + !iterator->second.contains(field_name)) { + return false; + } + return true; +} + +std::string ProtoTypeMaskRegistry::DebugString() const { + std::string output = "ProtoTypeMaskRegistry { "; + for (auto& element : types_and_visible_fields_) { + absl::StrAppend(&output, "{type: '", element.first, "', visible_fields: '", + absl::StrJoin(element.second, "', '"), "'} "); + } + absl::StrAppend(&output, "}"); + return output; +} + +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask_registry.h b/checker/internal/proto_type_mask_registry.h new file mode 100644 index 000000000..338353e7d --- /dev/null +++ b/checker/internal/proto_type_mask_registry.h @@ -0,0 +1,83 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/internal/proto_type_mask.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Stores information related to ProtoTypeMasks. Visibility is defined per type, +// meaning that all messages of a type have the same visible fields. +class ProtoTypeMaskRegistry { + public: + // Processes the input proto type masks to create a ProtoTypeMaskRegistry. + // Returns an error if one of the proto type masks is not valid. For example, + // if a type is not found in the descriptor pool, if a field name is not + // found, or if a field is not a message type when we are expecting it to be. + // Returns an error if there is a conflict in field visibility when + // updating the map. + static absl::StatusOr> Create( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks); + + const absl::flat_hash_map>& + GetTypesAndVisibleFields() const { + return types_and_visible_fields_; + } + + // Returns true when the field name is visible. A field is visible if: + // 1. The type name is not a key in the map. + // 2. The type name is a key in the map and the field name is in the set of + // field names that are visible for the type. + bool FieldIsVisible(absl::string_view type_name, + absl::string_view field_name) const; + + template + friend void AbslStringify( + Sink& sink, + const std::shared_ptr& proto_type_mask_registry) { + sink.Append(proto_type_mask_registry->DebugString()); + } + + private: + explicit ProtoTypeMaskRegistry( + absl::flat_hash_map> + types_and_visible_fields) + : types_and_visible_fields_(std::move(types_and_visible_fields)) {} + + std::string DebugString() const; + + // Map of types that have a field mask where the keys are + // fully qualified type names and the values are the set of field names that + // are visible for the type. + absl::flat_hash_map> + types_and_visible_fields_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ diff --git a/checker/internal/proto_type_mask_registry_test.cc b/checker/internal/proto_type_mask_registry_test.cc new file mode 100644 index 000000000..3a73c8823 --- /dev/null +++ b/checker/internal/proto_type_mask_registry_test.cc @@ -0,0 +1,402 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask_registry.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/internal/proto_type_mask.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::AllOf; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +using TypeMap = + absl::flat_hash_map>; + +TEST(ProtoTypeMaskRegistryTest, + CreateWithEmptyInputSucceedsAndAllFieldsAreVisible) { + std::vector proto_type_masks = {}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), IsEmpty()); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithEmptyTypeReturnsError) { + std::vector proto_type_masks = {ProtoTypeMask("", {})}; + EXPECT_THAT(ProtoTypeMaskRegistry::Create( + GetSharedTestingDescriptorPool().get(), proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type '' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithUnknownTypeReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("com.example.UnknownType", {})}; + EXPECT_THAT(ProtoTypeMaskRegistry::Create( + GetSharedTestingDescriptorPool().get(), proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type 'com.example.UnknownType' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithEmptySetFieldPathSucceedsAndFieldsAreHidden) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", IsEmpty()))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDuplicateEmptySetFieldPathSucceedsAndFieldsAreHidden) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", IsEmpty()))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithEmptyFieldPathReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {""})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithDelimiterFieldPathReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {"."})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithUnknownFieldReturnsError) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"unknown_field"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneNonMessageFieldsSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "single_any", "single_timestamp"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("single_int32", "single_any", + "single_timestamp")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_int32")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_any")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_timestamp")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithDepthTwoNonMessageFieldReturnsError) { + std::vector proto_type_masks; + proto_type_masks.push_back( + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32.any_field_name"})); + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("field 'single_int32' is not a message type"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre(Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoMessageUnknownFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.unknown_field"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes.NestedMessage'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthThreeMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.NestedTestAllTypes", + UnorderedElementsAre("payload")), + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "payload")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneRepeatedMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"repeated_nested_message"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("repeated_nested_message")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "repeated_nested_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoRepeatedMessageFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"repeated_nested_message.bb"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("field 'repeated_nested_message' is not a message type"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithListOfFieldPathsSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.NestedTestAllTypes", + UnorderedElementsAre("payload")), + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message", "single_int32")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "payload")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_int32")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateAddVisibleFieldThenAllHiddenFieldsReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.bb"}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + {})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "cannot insert a proto type mask with all hidden fields when " + "type 'cel.expr.conformance.proto3.TestAllTypes.NestedMessage' " + "has already been inserted with a proto type mask with a visible " + "field"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateAddAllHiddenThenVisibleFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + {}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.bb"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "cannot insert a proto type mask with visible field 'bb' when " + "type 'cel.expr.conformance.proto3.TestAllTypes.NestedMessage' " + "has already been inserted with a proto type mask with all " + "hidden fields"))); +} + +TEST(ProtoTypeMaskRegistryTest, AbslStringifyPrintsTypesAndVisibleFieldsMap) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + absl::StrCat(proto_type_mask_registry), + AllOf(HasSubstr("ProtoTypeMaskRegistry {"), + HasSubstr("{type: 'cel.expr.conformance.proto3.TestAllTypes', " + "visible_fields: 'standalone_message'}"), + HasSubstr("{type: " + "'cel.expr.conformance.proto3.TestAllTypes.NestedMessage'" + ", visible_fields: 'bb'}"))); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask_test.cc b/checker/internal/proto_type_mask_test.cc new file mode 100644 index 000000000..0c534f8cf --- /dev/null +++ b/checker/internal/proto_type_mask_test.cc @@ -0,0 +1,143 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask.h" + +#include +#include + +#include "absl/container/btree_set.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/internal/field_path.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +TEST(ProtoTypeMaskTest, EmptyTypeNameAndEmptyFieldPathsSucceeds) { + std::string type_name = ""; + std::vector field_paths; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_EQ(proto_type_mask.GetTypeName(), ""); + EXPECT_THAT(proto_type_mask.GetFieldPaths(), IsEmpty()); +} + +TEST(ProtoTypeMaskTest, NotEmptyTypeNameAndNotEmptyFieldPathsSucceeds) { + std::string type_name = "google.type.Expr"; + std::vector field_paths = {"resource.name", "resource.type"}; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_EQ(proto_type_mask.GetTypeName(), "google.type.Expr"); + EXPECT_THAT(proto_type_mask.GetFieldPaths(), + UnorderedElementsAre(FieldPath("resource.name"), + FieldPath("resource.type"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithEmptyTypeReturnsError) { + ProtoTypeMask proto_type_mask("", {}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type '' not found"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithUnknownTypeReturnsError) { + ProtoTypeMask proto_type_mask("com.example.UnknownType", {}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type 'com.example.UnknownType' not found"))); +} + +TEST(ProtoTypeMaskTest, + GetFieldNamesWithEmptySetFieldPathSucceedsAndReturnsEmptySet) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", {}); + ASSERT_OK_AND_ASSIGN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get())); + EXPECT_THAT(field_names, IsEmpty()); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithEmptyFieldPathReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {""}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithDelimiterFieldPathReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "."}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithUnknownFieldReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"unknown_field"}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskTest, + GetFieldNamesWithValidFieldsSucceedsAndReturnsFieldNames) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "single_string"}); + ASSERT_OK_AND_ASSIGN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get())); + EXPECT_THAT(field_names, + UnorderedElementsAre("single_int32", "single_string")); +} + +TEST(ProtoTypeMaskTest, + GetFieldNamesWithValidFieldPathsSucceedsAndReturnsFieldNames) { + ProtoTypeMask proto_type_mask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32", + "child.any_field_name"}); + ASSERT_OK_AND_ASSIGN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get())); + EXPECT_THAT(field_names, UnorderedElementsAre("payload", "child")); +} + +TEST(ProtoTypeMaskTest, AbslStringifyPrintsTypeNameAndFieldPaths) { + std::string type_name = "google.type.Expr"; + std::vector field_paths = {"resource.name", "resource.type"}; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_THAT(absl::StrCat(proto_type_mask), + HasSubstr("ProtoTypeMask { type name: 'google.type.Expr', field " + "paths: { 'resource.name', 'resource.type' } }")); +} + +} // namespace +} // namespace cel::checker_internal From dd4f36816db185068e3ec8ba3663c95eb3f930db Mon Sep 17 00:00:00 2001 From: Mike Kruskal Date: Thu, 4 Jun 2026 15:58:40 -0700 Subject: [PATCH 102/143] Fix misuses of __cxa_demangle length parameters PiperOrigin-RevId: 926935330 --- common/typeinfo.cc | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/common/typeinfo.cc b/common/typeinfo.cc index 86bae1934..b07275712 100644 --- a/common/typeinfo.cc +++ b/common/typeinfo.cc @@ -57,18 +57,13 @@ std::string TypeInfo::DebugString() const { } return std::string(demangled.get()); #else - size_t length = 0; int status = 0; std::unique_ptr demangled( - abi::__cxa_demangle(rep_->name(), nullptr, &length, &status)); + abi::__cxa_demangle(rep_->name(), nullptr, nullptr, &status)); if (status != 0 || demangled == nullptr) { return std::string(rep_->name()); } - while (length != 0 && demangled.get()[length - 1] == '\0') { - // length includes the null terminator, remove it. - --length; - } - return std::string(demangled.get(), length); + return std::string(demangled.get()); #endif #else return absl::StrCat("0x", absl::Hex(absl::bit_cast(rep_))); From 35957777f1d8c22f3c4ac5e9cc333c81f559fc8c Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 8 Jun 2026 14:25:17 -0700 Subject: [PATCH 103/143] Variadic logical operators PiperOrigin-RevId: 928770216 --- checker/internal/BUILD | 2 + checker/internal/type_checker_impl.cc | 7 +- checker/internal/type_checker_impl_test.cc | 88 ++++++ checker/internal/type_inference_context.cc | 16 +- eval/compiler/flat_expr_builder.cc | 203 +++++++------ eval/compiler/flat_expr_builder_test.cc | 319 ++++++++++++++++++--- parser/options.h | 4 + parser/parser.cc | 21 +- parser/parser_test.cc | 53 ++++ tools/cel_unparser.cc | 25 +- tools/cel_unparser_test.cc | 19 ++ 11 files changed, 613 insertions(+), 144 deletions(-) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 25550616a..26c7b543f 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -145,6 +145,7 @@ cc_library( "//common:container", "//common:decl", "//common:expr", + "//common:standard_definitions", "//common:type", "//common:type_kind", "//internal:lexis", @@ -238,6 +239,7 @@ cc_library( deps = [ ":format_type_name", "//common:decl", + "//common:standard_definitions", "//common:type", "//common:type_kind", "@com_google_absl//absl/container:flat_hash_map", diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 1ce871255..6b6b051b1 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -52,6 +52,7 @@ #include "common/constant.h" #include "common/decl.h" #include "common/expr.h" +#include "common/standard_definitions.h" #include "common/type.h" #include "common/type_kind.h" #include "internal/status_macros.h" @@ -894,8 +895,12 @@ const FunctionDecl* ResolveVisitor::ResolveFunctionCallShape( if (decl == nullptr) { return true; } + bool is_logical_op = (candidate == cel::StandardFunctions::kAnd || + candidate == cel::StandardFunctions::kOr) && + arg_count >= 2; for (const auto& ovl : decl->overloads()) { - if (ovl.member() == is_receiver && ovl.args().size() == arg_count) { + if (ovl.member() == is_receiver && + (ovl.args().size() == arg_count || is_logical_op)) { return false; } } diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index 893f0689d..61ef7d55b 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -55,6 +55,7 @@ #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" namespace cel { namespace checker_internal { @@ -1471,6 +1472,93 @@ TEST(TypeCheckerImplTest, TypeInferredFromStructCreation) { std::make_unique(DynTypeSpec()))))))); } +struct VariadicLogicalCheckerTestCase { + std::string expr; +}; + +class VariadicLogicalCheckerTest + : public testing::TestWithParam {}; + +TEST_P(VariadicLogicalCheckerTest, Check) { + const auto& test_case = GetParam(); + + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource(test_case.expr)); + ASSERT_OK_AND_ASSIGN(auto parsed_ast, parser->Parse(*source)); + + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + TypeCheckerImpl impl(std::move(env)); + auto checker_builder = impl.ToBuilder(); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("a", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("b", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("c", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("d", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("e", BoolType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, checker_builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker->Check(std::move(parsed_ast))); + + ASSERT_TRUE(result.IsValid()) + << absl::StrJoin(result.GetIssues(), "\n", + [](std::string* out, const TypeCheckIssue& issue) { + absl::StrAppend(out, issue.message()); + }); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(AstType(PrimitiveType::kBool))))); +} + +INSTANTIATE_TEST_SUITE_P( + VariadicLogicalChecker, VariadicLogicalCheckerTest, + testing::Values(VariadicLogicalCheckerTestCase{"true && false && true"}, + VariadicLogicalCheckerTestCase{"a && b && c && d"}, + VariadicLogicalCheckerTestCase{"a || b || c || d"}, + VariadicLogicalCheckerTestCase{"a && b && (c || d || e)"}, + VariadicLogicalCheckerTestCase{"a && b && c"}, + VariadicLogicalCheckerTestCase{"a || b || c"}, + VariadicLogicalCheckerTestCase{"[a, b, c].exists(x, x)"}, + VariadicLogicalCheckerTestCase{"[a, b, c].all(x, x)"})); + +TEST(TypeCheckerImplTest, VariadicLogicalOperatorsError) { + cel::expr::ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + call_expr { + function: "_&&_" + args { const_expr { bool_value: true } } + } + } + )pb", + &parsed_expr)); + ASSERT_OK_AND_ASSIGN(auto parsed_ast, + cel::CreateAstFromParsedExpr(parsed_expr)); + + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + impl.Check(std::move(parsed_ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, "undeclared reference"))); +} + TEST(TypeCheckerImplTest, ExpectedTypeMatches) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 5b909d982..1a87d9e15 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -30,6 +30,7 @@ #include "absl/types/span.h" #include "checker/internal/format_type_name.h" #include "common/decl.h" +#include "common/standard_definitions.h" #include "common/type.h" #include "common/type_kind.h" @@ -537,21 +538,28 @@ TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, bool is_receiver) { std::optional result_type; + bool is_logical_op = (decl.name() == cel::StandardFunctions::kAnd || + decl.name() == cel::StandardFunctions::kOr) && + argument_types.size() >= 2; + std::vector matching_overloads; for (const auto& ovl : decl.overloads()) { if (ovl.member() != is_receiver || - argument_types.size() != ovl.args().size()) { + (!is_logical_op && argument_types.size() != ovl.args().size())) { continue; } auto call_type_instance = InstantiateFunctionOverload(*this, ovl); - ABSL_DCHECK_EQ(argument_types.size(), - call_type_instance.param_types.size()); + if (!is_logical_op) { + ABSL_DCHECK_EQ(argument_types.size(), + call_type_instance.param_types.size()); + } bool is_match = true; AssignabilityContext assignability_context = CreateAssignabilityContext(); for (int i = 0; i < argument_types.size(); ++i) { + int param_index = is_logical_op ? 0 : i; if (!assignability_context.IsAssignable( - argument_types[i], call_type_instance.param_types[i])) { + argument_types[i], call_type_instance.param_types[param_index])) { is_match = false; break; } diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 1e3f4ecd3..d6ccdf040 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -232,7 +232,7 @@ class BinaryCondVisitor : public CondVisitor { private: FlatExprVisitor* visitor_; const BinaryCond cond_; - Jump jump_step_; + std::vector jump_steps_; bool short_circuiting_; }; @@ -622,7 +622,7 @@ class FlatExprVisitor : public cel::AstVisitor { program_optimizers_) { absl::Status status = optimizer->OnPreVisit(extension_context_, expr); if (!status.ok()) { - SetProgressStatusError(status); + SetProgressStatusIfError(status); } } } @@ -639,7 +639,7 @@ class FlatExprVisitor : public cel::AstVisitor { program_optimizers_) { absl::Status status = optimizer->OnPostVisit(extension_context_, expr); if (!status.ok()) { - SetProgressStatusError(status); + SetProgressStatusIfError(status); return; } } @@ -657,7 +657,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (!comprehension_stack_.empty() && comprehension_stack_.back().is_optimizable_bind && (&comprehension_stack_.back().comprehension->accu_init() == &expr)) { - SetProgressStatusError( + SetProgressStatusIfError( MaybeExtractSubexpression(&expr, comprehension_stack_.back())); } @@ -666,7 +666,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (block.current_binding == &expr) { int index = program_builder_.ExtractSubexpression(&expr); if (index == -1) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("failed to extract subexpression")); return; } @@ -686,7 +686,7 @@ class FlatExprVisitor : public cel::AstVisitor { ConvertConstant(const_expr, cel::NewDeleteAllocator()); if (!converted_value.ok()) { - SetProgressStatusError(converted_value.status()); + SetProgressStatusIfError(converted_value.status()); return; } @@ -722,13 +722,13 @@ class FlatExprVisitor : public cel::AstVisitor { if (absl::ConsumePrefix(&index_suffix, "@index")) { size_t index; if (!absl::SimpleAtoi(index_suffix, &index)) { - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError("bad @index")))); return {-1, -1}; } if (index >= block.size) { - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError(absl::StrCat( "invalid @index greater than number of bindings: ", @@ -736,7 +736,7 @@ class FlatExprVisitor : public cel::AstVisitor { return {-1, -1}; } if (index >= block.current_index) { - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError(absl::StrCat( "@index references current or future binding: ", index, @@ -754,7 +754,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (record.iter_var_in_scope && record.comprehension->iter_var() == path) { if (record.is_optimizable_bind) { - SetProgressStatusError(issue_collector_.AddIssue( + SetProgressStatusIfError(issue_collector_.AddIssue( RuntimeIssue::CreateWarning(absl::InvalidArgumentError( "Unexpected iter_var access in trivial comprehension")))); return {-1, -1}; @@ -781,7 +781,7 @@ class FlatExprVisitor : public cel::AstVisitor { // If we see a CSE generated comprehension variable that was not // resolvable through the normal comprehension scope resolution, reject it // now rather than surfacing errors at activation time. - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError("out of scope reference to CSE " "generated comprehension variable")))); @@ -811,7 +811,7 @@ class FlatExprVisitor : public cel::AstVisitor { auto* subexpression = program_builder_.GetExtractedSubexpression(slot.subexpression); if (subexpression == nullptr) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InternalError("bad subexpression reference")); return; } @@ -965,7 +965,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 1) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "unexpected number of dependencies for select operation.")); return; } @@ -1022,7 +1022,7 @@ class FlatExprVisitor : public cel::AstVisitor { // cel.@block if (block_.has_value()) { // There can only be one for now. - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("multiple cel.@block are not allowed")); return; } @@ -1030,17 +1030,17 @@ class FlatExprVisitor : public cel::AstVisitor { BlockInfo& block = *block_; block.in = true; if (call_expr.args().empty()) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "malformed cel.@block: missing list of bound expressions")); return; } if (call_expr.args().size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "malformed cel.@block: missing bound expression")); return; } if (!call_expr.args()[0].has_list_expr()) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("malformed cel.@block: first argument " "is not a list of bound expressions")); return; @@ -1051,7 +1051,7 @@ class FlatExprVisitor : public cel::AstVisitor { block.bindings_set.reserve(block.size); for (const auto& list_expr_element : list_expr.elements()) { if (list_expr_element.optional()) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("malformed cel.@block: list of bound " "expressions contains an optional")); return; @@ -1093,7 +1093,7 @@ class FlatExprVisitor : public cel::AstVisitor { void MakeTernaryRecursive(const cel::Expr* expr) { if (expr->call_expr().args().size() != 3) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin ternary")); return; } @@ -1109,7 +1109,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (condition_plan == nullptr || !condition_plan->IsRecursive() || left_plan == nullptr || !left_plan->IsRecursive() || right_plan == nullptr || !right_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + SetProgressStatusIfError(FailedRecursivePlanning()); return; } @@ -1126,45 +1126,52 @@ class FlatExprVisitor : public cel::AstVisitor { } void MakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { - if (expr->call_expr().args().size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + int args_size = expr->call_expr().args().size(); + if (args_size < 2) { + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin boolean operator &&/||")); return; } - const cel::Expr* left_expr = &expr->call_expr().args()[0]; - const cel::Expr* right_expr = &expr->call_expr().args()[1]; - auto* left_plan = program_builder_.GetSubexpression(left_expr); - auto* right_plan = program_builder_.GetSubexpression(right_expr); - - if (left_plan == nullptr || !left_plan->IsRecursive() || - right_plan == nullptr || !right_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + auto* current_plan = + program_builder_.GetSubexpression(&expr->call_expr().args()[0]); + if (current_plan == nullptr || !current_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); return; } + int current_depth = current_plan->recursive_program().depth; + std::unique_ptr current_step = + current_plan->ExtractRecursiveProgram().step; - int max_depth = std::max({0, left_plan->recursive_program().depth, - right_plan->recursive_program().depth}); - - if (is_or) { - SetRecursiveStep( - CreateDirectOrStep(left_plan->ExtractRecursiveProgram().step, - right_plan->ExtractRecursiveProgram().step, - expr->id(), options_.short_circuiting), - max_depth + 1); - } else { - SetRecursiveStep( - CreateDirectAndStep(left_plan->ExtractRecursiveProgram().step, - right_plan->ExtractRecursiveProgram().step, - expr->id(), options_.short_circuiting), - max_depth + 1); + for (int i = 1; i < args_size; ++i) { + auto* next_plan = + program_builder_.GetSubexpression(&expr->call_expr().args()[i]); + if (next_plan == nullptr || !next_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); + return; + } + current_depth = + std::max(current_depth, next_plan->recursive_program().depth); + std::unique_ptr next_step = + next_plan->ExtractRecursiveProgram().step; + if (is_or) { + current_step = + CreateDirectOrStep(std::move(current_step), std::move(next_step), + expr->id(), options_.short_circuiting); + } else { + current_step = + CreateDirectAndStep(std::move(current_step), std::move(next_step), + expr->id(), options_.short_circuiting); + } + current_depth++; } + SetRecursiveStep(std::move(current_step), current_depth); } void MakeOptionalShortcircuit(const cel::Expr* expr, bool is_or_value) { if (!expr->call_expr().has_target() || expr->call_expr().args().size() != 1) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for optional.or{Value}")); return; } @@ -1176,7 +1183,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (left_plan == nullptr || !left_plan->IsRecursive() || right_plan == nullptr || !right_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + SetProgressStatusIfError(FailedRecursivePlanning()); return; } int max_depth = std::max({0, left_plan->recursive_program().depth, @@ -1200,7 +1207,7 @@ class FlatExprVisitor : public cel::AstVisitor { program_builder_.GetSubexpression(&comprehension->result()); if (result_plan == nullptr || !result_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + SetProgressStatusIfError(FailedRecursivePlanning()); return; } @@ -1234,7 +1241,7 @@ class FlatExprVisitor : public cel::AstVisitor { loop_plan == nullptr || !loop_plan->IsRecursive() || condition_plan == nullptr || !condition_plan->IsRecursive() || result_plan == nullptr || !result_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + SetProgressStatusIfError(FailedRecursivePlanning()); return; } @@ -1462,7 +1469,7 @@ class FlatExprVisitor : public cel::AstVisitor { return; } - SetProgressStatusError(comprehension_stack_.back().visitor->PostVisitArg( + SetProgressStatusIfError(comprehension_stack_.back().visitor->PostVisitArg( comprehension_arg, comprehension_stack_.back().expr)); } @@ -1524,7 +1531,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (std::optional depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != list_expr.elements().size()) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "Unexpected number of plan elements for CreateList expr")); return; } @@ -1547,7 +1554,7 @@ class FlatExprVisitor : public cel::AstVisitor { auto status_or_resolved_fields = ResolveCreateStructFields(struct_expr, expr.id()); if (!status_or_resolved_fields.ok()) { - SetProgressStatusError(status_or_resolved_fields.status()); + SetProgressStatusIfError(status_or_resolved_fields.status()); return; } std::string resolved_name = @@ -1558,7 +1565,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != struct_expr.fields().size()) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "Unexpected number of plan elements for CreateStruct expr")); return; } @@ -1599,7 +1606,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 2 * map_expr.entries().size()) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "Unexpected number of plan elements for CreateStruct expr")); return; } @@ -1661,7 +1668,7 @@ class FlatExprVisitor : public cel::AstVisitor { "No overloads provided for FunctionStep creation"), RuntimeIssue::ErrorCode::kNoMatchingOverload)); if (!status.ok()) { - SetProgressStatusError(status); + SetProgressStatusIfError(status); return; } } @@ -1692,7 +1699,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (step.ok()) { return AddStep(*std::move(step)); } else { - SetProgressStatusError(step.status()); + SetProgressStatusIfError(step.status()); } return nullptr; } @@ -1711,19 +1718,19 @@ class FlatExprVisitor : public cel::AstVisitor { return; } if (program_builder_.current() == nullptr) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "CEL AST traversal out of order in flat_expr_builder.")); return; } program_builder_.current()->set_recursive_program(std::move(step), depth); if (depth > max_recursion_depth_) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( absl::StrCat("Maximum recursion depth of ", options_.max_recursion_depth, " exceeded"))); } } - void SetProgressStatusError(const absl::Status& status) { + void SetProgressStatusIfError(const absl::Status& status) { if (progress_status_.ok() && !status.ok()) { progress_status_ = status; } @@ -1765,7 +1772,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (valid_expression) { return true; } - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( absl::StrCat(error_message, message_parts...))); return false; } @@ -1947,7 +1954,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin index operator")); return CallHandlerResult::kIntercepted; } @@ -1974,7 +1981,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin not operator")); return CallHandlerResult::kIntercepted; } @@ -1997,7 +2004,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNotStrictlyFalse( if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("unexpected number of args for builtin " "@not_strictly_false operator")); return CallHandlerResult::kIntercepted; @@ -2016,7 +2023,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock( ABSL_DCHECK(call_expr.function() == kBlock); if (!block_.has_value() || block_->expr != &expr || call_expr.args().size() != 2 || call_expr.has_target()) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("unexpected call to internal cel.@block")); return CallHandlerResult::kIntercepted; } @@ -2101,7 +2108,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleHeterogeneousEquality( if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin equality operator")); return CallHandlerResult::kIntercepted; } @@ -2126,7 +2133,7 @@ FlatExprVisitor::HandleHeterogeneousEqualityIn(const cel::Expr& expr, if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin 'in' operator")); return CallHandlerResult::kIntercepted; } @@ -2164,13 +2171,14 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { if (visitor_->PlanRecursiveProgram()) { return; } - if (short_circuiting_ && arg_num == 0 && + const int last_arg_index = expr->call_expr().args().size() - 1; + if (short_circuiting_ && arg_num < last_arg_index && (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) { // If first branch evaluation result is enough to determine output, // jump over the second branch and provide result of the first argument as // final output. - // Retain a pointer to the jump step so we can update the target after - // planning the second argument. + // Retain pointers to the jump steps so we can update the target after + // planning the next arguments. std::unique_ptr jump_step; switch (cond_) { case BinaryCond::kAnd: @@ -2185,7 +2193,7 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { ProgramStepIndex index = visitor_->GetCurrentIndex(); if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); jump_step_ptr) { - jump_step_ = Jump(index, jump_step_ptr); + jump_steps_.push_back(Jump(index, jump_step_ptr)); } } } @@ -2215,7 +2223,7 @@ void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { ProgramStepIndex index = visitor_->GetCurrentIndex(); if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); jump_step_ptr) { - jump_step_ = Jump(index, jump_step_ptr); + jump_steps_.push_back(Jump(index, jump_step_ptr)); } } } @@ -2243,28 +2251,36 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { return; } - switch (cond_) { - case BinaryCond::kAnd: - visitor_->AddStep(CreateAndStep(expr->id())); - break; - case BinaryCond::kOr: - visitor_->AddStep(CreateOrStep(expr->id())); - break; - case BinaryCond::kOptionalOr: - visitor_->AddStep( - CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); - break; - case BinaryCond::kOptionalOrValue: - visitor_->AddStep(CreateOptionalOrStep(/*is_or_value=*/true, expr->id())); - break; - default: - ABSL_UNREACHABLE(); + int args_count = (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr) + ? expr->call_expr().args().size() + : 2; + for (int i = 0; i < args_count - 1; ++i) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->AddStep(CreateAndStep(expr->id())); + break; + case BinaryCond::kOr: + visitor_->AddStep(CreateOrStep(expr->id())); + break; + case BinaryCond::kOptionalOr: + visitor_->AddStep( + CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); + break; + case BinaryCond::kOptionalOrValue: + visitor_->AddStep( + CreateOptionalOrStep(/*is_or_value=*/true, expr->id())); + break; + default: + ABSL_UNREACHABLE(); + } } if (short_circuiting_) { // If short-circuiting is enabled, point the conditional jump past the // boolean operator step. - visitor_->SetProgressStatusError( - jump_step_.set_target(visitor_->GetCurrentIndex())); + for (auto& jump : jump_steps_) { + visitor_->SetProgressStatusIfError( + jump.set_target(visitor_->GetCurrentIndex())); + } } } @@ -2321,7 +2337,7 @@ void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { if (visitor_->ValidateOrError( jump_to_second_.exists(), "Error configuring ternary operator: jump_to_second_ is null")) { - visitor_->SetProgressStatusError( + visitor_->SetProgressStatusIfError( jump_to_second_.set_target(visitor_->GetCurrentIndex())); } } @@ -2339,13 +2355,13 @@ void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { if (visitor_->ValidateOrError( error_jump_.exists(), "Error configuring ternary operator: error_jump_ is null")) { - visitor_->SetProgressStatusError( + visitor_->SetProgressStatusIfError( error_jump_.set_target(visitor_->GetCurrentIndex())); } if (visitor_->ValidateOrError( jump_after_first_.exists(), "Error configuring ternary operator: jump_after_first_ is null")) { - visitor_->SetProgressStatusError( + visitor_->SetProgressStatusIfError( jump_after_first_.set_target(visitor_->GetCurrentIndex())); } } @@ -2403,7 +2419,8 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault( break; } Jump jump_helper(index, jump_to_next); - visitor_->SetProgressStatusError(jump_helper.set_target(next_step_pos_)); + visitor_->SetProgressStatusIfError( + jump_helper.set_target(next_step_pos_)); // Set offsets jumping to the result step. if (cond_step_) { diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index d84007485..e2581e3fd 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -469,7 +469,7 @@ TEST(FlatExprBuilderTest, IdentExprUnsetName) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(ident_expr {})", &expr); + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(ident_expr {})", &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -482,10 +482,10 @@ TEST(FlatExprBuilderTest, SelectExprUnsetField) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(select_expr{ + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(select_expr{ operand{ ident_expr {name: 'var'} } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -498,11 +498,11 @@ TEST(FlatExprBuilderTest, SelectExprUnsetOperand) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(select_expr{ + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(select_expr{ field: 'field' operand { id: 1 } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -515,7 +515,8 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuVar) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(comprehension_expr{})", &expr); + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"(comprehension_expr{})", &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -527,10 +528,10 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetIterVar) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{accu_var: "a"} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -542,12 +543,12 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuInit) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: "a" iter_var: "b"} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -559,7 +560,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopCondition) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' @@ -567,7 +568,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopCondition) { const_expr {bool_value: true} }} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -579,7 +580,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopStep) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' @@ -590,7 +591,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopStep) { const_expr {bool_value: true} }} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -602,7 +603,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetResult) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' @@ -616,7 +617,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetResult) { const_expr {bool_value: false} }} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -628,7 +629,7 @@ TEST(FlatExprBuilderTest, MapComprehension) { Expr expr; SourceInfo source_info; // {1: "", 2: ""}.all(x, x > 0) - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr { iter_var: "k" accu_var: "accu" @@ -665,7 +666,7 @@ TEST(FlatExprBuilderTest, MapComprehension) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -683,7 +684,7 @@ TEST(FlatExprBuilderTest, InvalidContainer) { Expr expr; SourceInfo source_info; // foo && bar - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( call_expr { function: "_&&_" args { @@ -697,7 +698,7 @@ TEST(FlatExprBuilderTest, InvalidContainer) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -909,7 +910,7 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportDisabled) { TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { CheckedExpr expr; // foo && bar - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( expr { id: 1 call_expr { @@ -928,7 +929,7 @@ TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -946,7 +947,7 @@ TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { CheckedExpr expr; // `foo.var1` && `bar.var2` - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 2 value { @@ -988,7 +989,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1008,7 +1009,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { CheckedExpr expr; // ext.and(var1, bar.var2) - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 1 value { @@ -1057,7 +1058,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1082,7 +1083,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { CheckedExpr expr; // && . - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 2 value { @@ -1125,7 +1126,7 @@ TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1160,7 +1161,7 @@ TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { CheckedExpr expr; // {`var1`: 'hello'} - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 3 value { @@ -1190,7 +1191,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1213,7 +1214,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForError) { Expr expr; SourceInfo source_info; // {}[0].all(x, x) should evaluate OK but return an error value - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 4 comprehension_expr { iter_var: "x" @@ -1278,7 +1279,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForError) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -1295,7 +1296,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { Expr expr; SourceInfo source_info; // 0.all(x, x) should evaluate OK but return an error value. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 4 comprehension_expr { iter_var: "x" @@ -1349,7 +1350,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -1721,7 +1722,7 @@ TEST(FlatExprBuilderTest, NameCollisionWithComprehensionVarLeadingDot) { TEST(FlatExprBuilderTest, MapFieldPresence) { Expr expr; SourceInfo source_info; - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 1, select_expr{ operand { @@ -1731,7 +1732,7 @@ TEST(FlatExprBuilderTest, MapFieldPresence) { field: "string_int32_map" test_only: true })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, @@ -1765,7 +1766,7 @@ TEST(FlatExprBuilderTest, MapFieldPresence) { TEST(FlatExprBuilderTest, RepeatedFieldPresence) { Expr expr; SourceInfo source_info; - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 1, select_expr{ operand { @@ -1775,7 +1776,7 @@ TEST(FlatExprBuilderTest, RepeatedFieldPresence) { field: "int32_list" test_only: true })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, @@ -2900,6 +2901,248 @@ TEST(FlatExprBuilderTest, BlockNested) { HasSubstr("multiple cel.@block are not allowed"))); } +struct VariadicLogicalEvalTestCase { + std::string label; + std::string expr; + std::string a_val; + std::string b_val; + std::string c_val; + std::string expected_type; // "bool", "error", "unknown" + bool expected_bool = false; +}; + +class FlatExprBuilderVariadicLogicalTest + : public testing::TestWithParam {}; + +TEST_P(FlatExprBuilderVariadicLogicalTest, Evaluate) { + const auto& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(test_case.expr)); + + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + std::vector unknown_patterns; + + // Set up variables: + auto insert_value = [&](absl::string_view name, const std::string& val) { + if (val == "true") { + activation.InsertValue(name, CelValue::CreateBool(true)); + } else if (val == "false") { + activation.InsertValue(name, CelValue::CreateBool(false)); + } else if (val == "error") { + activation.InsertValue(name, CreateErrorValue(&arena, "test error")); + } else if (val == "unknown1" || val == "unknown2") { + activation.InsertValue(name, CelValue::CreateBool(true)); + unknown_patterns.push_back(CreateCelAttributePattern(name, {})); + } + }; + + insert_value("a", test_case.a_val); + insert_value("b", test_case.b_val); + insert_value("c", test_case.c_val); + + if (!unknown_patterns.empty()) { + activation.set_unknown_attribute_patterns(std::move(unknown_patterns)); + } + + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + + if (test_case.expected_type == "bool") { + ASSERT_TRUE(result.IsBool()) << result.DebugString(); + EXPECT_EQ(result.BoolOrDie(), test_case.expected_bool); + } else if (test_case.expected_type == "error") { + EXPECT_TRUE(result.IsError()) << result.DebugString(); + } else if (test_case.expected_type == "unknown") { + EXPECT_TRUE(result.IsUnknownSet()) << result.DebugString(); + } +} + +INSTANTIATE_TEST_SUITE_P( + FlatExprBuilderVariadicLogicalTest, FlatExprBuilderVariadicLogicalTest, + testing::Values( + VariadicLogicalEvalTestCase{"AND_AllTrue", "a && b && c", "true", + "true", "true", "bool", true}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitFalse", "a && b && c", + "true", "false", "unset", "bool", false}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitFirstFalse", "a && b && c", + "false", "unset", "unset", "bool", false}, + VariadicLogicalEvalTestCase{"OR_AllFalse", "a || b || c", "false", + "false", "false", "bool", false}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitTrue", "a || b || c", + "false", "true", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitFirstTrue", "a || b || c", + "true", "unset", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"AND_Error", "a && b && c", "true", "error", + "true", "error"}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitBeforeError", + "a && b && c", "false", "error", "unset", + "bool", false}, + VariadicLogicalEvalTestCase{"OR_Error", "a || b || c", "false", "error", + "false", "error"}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitBeforeError", "a || b || c", + "true", "error", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"AND_Unknown", "a && b && c", "true", + "unknown1", "true", "unknown"}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitBeforeUnknown", + "a && b && c", "false", "unknown1", "unset", + "bool", false}, + VariadicLogicalEvalTestCase{"OR_Unknown", "a || b || c", "false", + "unknown1", "false", "unknown"}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitBeforeUnknown", + "a || b || c", "true", "unknown1", "unset", + "bool", true}, + VariadicLogicalEvalTestCase{"AND_UnknownAggregation", "a && b && c", + "unknown1", "unknown2", "true", "unknown"}, + VariadicLogicalEvalTestCase{"OR_UnknownAggregation", "a || b || c", + "unknown1", "unknown2", "false", "unknown"}, + VariadicLogicalEvalTestCase{"Exists_True", "[a, b, c].exists(x, x)", + "false", "false", "true", "bool", true}, + VariadicLogicalEvalTestCase{"Exists_Unknown", "[a, b, c].exists(x, x)", + "false", "unknown1", "false", "unknown"}, + VariadicLogicalEvalTestCase{"All_False", "[a, b, c].all(x, x)", "true", + "true", "false", "bool", false}, + VariadicLogicalEvalTestCase{"All_Unknown", "[a, b, c].all(x, x)", + "true", "unknown1", "true", "unknown"})); + +struct RecursionDepthTestCase { + std::string label; + std::string expr; + int max_recursion_depth; + absl::StatusCode expected_status_code; + std::string expected_error_msg; +}; + +class FlatExprBuilderRecursionDepthTest + : public testing::TestWithParam {}; + +TEST_P(FlatExprBuilderRecursionDepthTest, CheckRecursionLimit) { + const auto& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(test_case.expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = test_case.max_recursion_depth; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + auto result = + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()); + if (test_case.expected_status_code == absl::StatusCode::kOk) { + EXPECT_THAT(result, IsOk()); + } else { + EXPECT_THAT(result, StatusIs(test_case.expected_status_code, + HasSubstr(test_case.expected_error_msg))); + } +} + +INSTANTIATE_TEST_SUITE_P( + FlatExprBuilderRecursionDepthTest, FlatExprBuilderRecursionDepthTest, + testing::Values( + RecursionDepthTestCase{"AndChildLimitExceeded", "(1 + 1) && true", 1, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 1 exceeded"}, + RecursionDepthTestCase{"AndParentLimitExceeded", "(1 + 1) && true", 2, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 2 exceeded"}, + RecursionDepthTestCase{"AndLimitSuccess", "(1 + 1) && true", 3, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndLimitSuccessGenerous", "(1 + 1) && true", 10, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndLimitSuccessUnlimited", "(1 + 1) && true", + -1, absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"OrChildLimitExceeded", "(1 + 1) || true", 1, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 1 exceeded"}, + RecursionDepthTestCase{"OrParentLimitExceeded", "(1 + 1) || true", 2, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 2 exceeded"}, + RecursionDepthTestCase{"OrLimitSuccess", "(1 + 1) || true", 3, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"OrLimitSuccessGenerous", + "(1 + 1) || false || false || false || false || " + "(true && true && true && true && false)", + 10, absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"OrLimitSuccessUnlimited", "(1 + 1) || true", -1, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndDepthUpdateFromSubsequentArg", + "true && (1 + 1 + 1 + 1)", 4, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 4 exceeded"}, + RecursionDepthTestCase{"OrDepthUpdateFromSubsequentArg", + "true || (1 + 1 + 1 + 1)", 4, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 4 exceeded"})); + +TEST(FlatExprBuilderTest, NonRecursiveChildBlockAndError) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "_&&_" + args { const_expr: { bool_value: true } } + args { + call_expr: { + function: "cel.@block" + args { + list_expr { elements { const_expr: { int64_value: 1 } } } + } + args { ident_expr: { name: "@index0" } } + } + } + } + } + )pb", + &parsed_expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = 2; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("failed to build recursive program"))); +} + +TEST(FlatExprBuilderTest, NonRecursiveChildBlockOrError) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "_||_" + args { const_expr: { bool_value: true } } + args { + call_expr: { + function: "cel.@block" + args { + list_expr { elements { const_expr: { int64_value: 1 } } } + } + args { ident_expr: { name: "@index0" } } + } + } + } + } + )pb", + &parsed_expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = 2; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("failed to build recursive program"))); +} + } // namespace } // namespace google::api::expr::runtime diff --git a/parser/options.h b/parser/options.h index 916a941f0..719bed454 100644 --- a/parser/options.h +++ b/parser/options.h @@ -62,6 +62,10 @@ struct ParserOptions final { // Limited to field specifiers in select and message creation, // enabled by default bool enable_quoted_identifiers = true; + + // Enables parsing logical AND & OR operators as a single flat variadic call + // instead of a balanced/nested binary AST structure. + bool enable_variadic_logical_operators = false; }; } // namespace cel diff --git a/parser/parser.cc b/parser/parser.cc index 709e2fd41..6c6434319 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -552,7 +552,7 @@ class ExpressionBalancer final { // balance creates a balanced tree from the sub-terms and returns the final // Expr value. - Expr Balance(); + Expr Balance(bool enable_variadic = false); private: // balancedTree recursively balances the terms provided to a commutative @@ -577,10 +577,13 @@ void ExpressionBalancer::AddTerm(int64_t op, Expr term) { ops_.push_back(op); } -Expr ExpressionBalancer::Balance() { +Expr ExpressionBalancer::Balance(bool enable_variadic) { if (terms_.size() == 1) { return std::move(terms_[0]); } + if (enable_variadic) { + return factory_.NewCall(ops_[0], function_, std::move(terms_)); + } return BalancedTree(0, ops_.size() - 1); } @@ -620,7 +623,8 @@ class ParserVisitor final : public CelBaseVisitor, const cel::MacroRegistry& macro_registry, bool add_macro_calls = false, bool enable_optional_syntax = false, - bool enable_quoted_identifiers = false) + bool enable_quoted_identifiers = false, + bool enable_variadic_logical_operators = false) : source_(source), factory_(source_), macro_registry_(macro_registry), @@ -628,7 +632,8 @@ class ParserVisitor final : public CelBaseVisitor, max_recursion_depth_(max_recursion_depth), add_macro_calls_(add_macro_calls), enable_optional_syntax_(enable_optional_syntax), - enable_quoted_identifiers_(enable_quoted_identifiers) {} + enable_quoted_identifiers_(enable_quoted_identifiers), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} ~ParserVisitor() override = default; @@ -719,6 +724,7 @@ class ParserVisitor final : public CelBaseVisitor, const bool add_macro_calls_; const bool enable_optional_syntax_; const bool enable_quoted_identifiers_; + const bool enable_variadic_logical_operators_; }; template ParseImpl( ExprRecursionListener listener(options.max_recursion_depth); ParserVisitor visitor( source, options.max_recursion_depth, registry, options.add_macro_calls, - options.enable_optional_syntax, options.enable_quoted_identifiers); + options.enable_optional_syntax, options.enable_quoted_identifiers, + options.enable_variadic_logical_operators); lexer.removeErrorListeners(); parser.removeErrorListeners(); diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 587b63a30..33c52b1d2 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1782,6 +1782,59 @@ TEST(NewParserBuilderTest, ToBuilderPreservesStdlibAndOptionalFromOptions) { EXPECT_FALSE(ast->IsChecked()); } +struct VariadicLogicalOperatorsTestCase { + std::string input; + std::string expected_adorned_string; +}; + +class VariadicLogicalOperatorsTest + : public testing::TestWithParam {}; + +TEST_P(VariadicLogicalOperatorsTest, Parse) { + const auto& test_case = GetParam(); + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource(test_case.input)); + ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); + + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + std::string adorned_string = w.Print(ast->root_expr()); + EXPECT_EQ(adorned_string, test_case.expected_adorned_string); +} + +INSTANTIATE_TEST_SUITE_P( + VariadicLogicalOperators, VariadicLogicalOperatorsTest, + testing::Values( + VariadicLogicalOperatorsTestCase{ + .input = "a && b && c && d", + .expected_adorned_string = "_&&_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " c^#4:Expr.Ident#,\n" + " d^#6:Expr.Ident#\n" + ")^#3:Expr.Call#"}, + VariadicLogicalOperatorsTestCase{ + .input = "a || b || c || d", + .expected_adorned_string = "_||_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " c^#4:Expr.Ident#,\n" + " d^#6:Expr.Ident#\n" + ")^#3:Expr.Call#"}, + VariadicLogicalOperatorsTestCase{ + .input = "a && b && (c || d || e)", + .expected_adorned_string = "_&&_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " _||_(\n" + " c^#4:Expr.Ident#,\n" + " d^#5:Expr.Ident#,\n" + " e^#7:Expr.Ident#\n" + " )^#6:Expr.Call#\n" + ")^#3:Expr.Call#"})); + TEST(ParserTest, ParseFailurePopulatesIssues) { auto builder = cel::NewParserBuilder(); ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); diff --git a/tools/cel_unparser.cc b/tools/cel_unparser.cc index 28a1187bb..741d91208 100644 --- a/tools/cel_unparser.cc +++ b/tools/cel_unparser.cc @@ -150,6 +150,8 @@ class Unparser { // - a ternary conditional operator bool IsBinaryOrTernaryOperator(const Expr& expr); + bool IsLogicalOperator(absl::string_view op); + template void Print(Ts&&... args) { absl::StrAppend(&output_, std::forward(args)...); @@ -436,6 +438,24 @@ absl::Status Unparser::VisitUnary(const Expr::Call& expr, absl::Status Unparser::VisitBinary(const Expr::Call& expr, const std::string& op) { + if (expr.args_size() < 2) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected binary: ", expr.ShortDebugString())); + } + + const auto& fun = expr.function(); + if (IsLogicalOperator(fun)) { + for (int i = 0; i < expr.args_size(); ++i) { + if (i > 0) { + Print(kSpace, op, kSpace); + } + const auto& arg = expr.args(i); + bool arg_paren = IsComplexOperatorWithRespectTo(arg, fun); + CEL_RETURN_IF_ERROR(VisitMaybeNested(arg, arg_paren)); + } + return absl::OkStatus(); + } + if (expr.args_size() != 2) { return absl::InvalidArgumentError( absl::StrCat("Unexpected binary: ", expr.ShortDebugString())); @@ -443,7 +463,6 @@ absl::Status Unparser::VisitBinary(const Expr::Call& expr, const auto& lhs = expr.args(0); const auto& rhs = expr.args(1); - const auto& fun = expr.function(); // add parens if the current operator is lower precedence than the lhs expr // operator. @@ -549,6 +568,10 @@ bool Unparser::IsBinaryOrTernaryOperator(const Expr& expr) { IsOperatorSamePrecedence(CelOperator::CONDITIONAL, expr); } +bool Unparser::IsLogicalOperator(absl::string_view op) { + return op == CelOperator::LOGICAL_AND || op == CelOperator::LOGICAL_OR; +} + } // namespace absl::StatusOr Unparse(const Expr& expr, diff --git a/tools/cel_unparser_test.cc b/tools/cel_unparser_test.cc index 4cba4ce4d..aca6e91fd 100644 --- a/tools/cel_unparser_test.cc +++ b/tools/cel_unparser_test.cc @@ -67,6 +67,22 @@ INSTANTIATE_TEST_SUITE_P( {// Empty Expr error {"", absl::InvalidArgumentError("Unsupported Expr")}, + // Logical operators with too few arguments (single argument) + { + R"pb( + call_expr { + function: "_&&_" + args { const_expr { bool_value: true } } + })pb", + absl::InvalidArgumentError("Unexpected binary")}, + { + R"pb( + call_expr { + function: "_||_" + args { const_expr { bool_value: true } } + })pb", + absl::InvalidArgumentError("Unexpected binary")}, + // Constants {"const_expr{}", absl::InvalidArgumentError("Unsupported Constant")}, {"const_expr{bool_value: true}", "true"}, @@ -619,6 +635,7 @@ TEST_P(UnparserTestTextExpr, Test) { options.add_macro_calls = true; options.enable_optional_syntax = true; options.enable_quoted_identifiers = true; + options.enable_variadic_logical_operators = true; ASSERT_OK_AND_ASSIGN(ParsedExpr result, Parse(GetParam().expr, "unparser", options)); @@ -779,6 +796,8 @@ INSTANTIATE_TEST_SUITE_P( {"has(a.`b.c`)", ""}, {"a.`b/c`", ""}, {"a.?`b/c`", ""}, + {"a && b && c && d", ""}, + {"a || b || c || d", ""}, })); } // namespace From 09dba1a9187b58e9bd7a8c74f8b6e006d16cc99f Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 8 Jun 2026 14:25:30 -0700 Subject: [PATCH 104/143] Refactor type signature generation to use TypeSpec. Refactors the internal signature generation logic to operate on `cel::TypeSpec` instead of `cel::Type`. Adds a utility to convert `cel::Type` to `cel::TypeSpec` PiperOrigin-RevId: 928770349 --- common/BUILD | 1 + common/ast/metadata.h | 4 + common/internal/BUILD | 3 +- common/internal/signature.cc | 275 ++++++++++++---------- common/internal/signature.h | 23 +- common/internal/signature_test.cc | 367 ++++++++++++++++++------------ common/type_spec_resolver.cc | 143 +++++++++++- common/type_spec_resolver.h | 3 + common/type_spec_resolver_test.cc | 27 +++ 9 files changed, 571 insertions(+), 275 deletions(-) diff --git a/common/BUILD b/common/BUILD index a016d2cb5..01710329b 100644 --- a/common/BUILD +++ b/common/BUILD @@ -53,6 +53,7 @@ cc_library( deps = [ ":ast", ":type", + ":type_kind", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/common/ast/metadata.h b/common/ast/metadata.h index 197790ff3..1a69b5b50 100644 --- a/common/ast/metadata.h +++ b/common/ast/metadata.h @@ -573,6 +573,10 @@ class TypeSpec { TypeSpecKind& mutable_type_kind() { return type_kind_; } + bool is_specified() const { + return !absl::holds_alternative(type_kind_); + } + bool has_dyn() const { return absl::holds_alternative(type_kind_); } diff --git a/common/internal/BUILD b/common/internal/BUILD index 48a8dfe8b..73cbf37e9 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -145,13 +145,11 @@ cc_library( deps = [ "//common:ast", "//common:type", - "//common:type_kind", "//common:type_spec_resolver", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], @@ -165,6 +163,7 @@ cc_test( "//common:ast", "//common:type", "//common:type_kind", + "//common:type_spec_resolver", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/base:no_destructor", diff --git a/common/internal/signature.cc b/common/internal/signature.cc index 5c75225f9..fe315bb04 100644 --- a/common/internal/signature.cc +++ b/common/internal/signature.cc @@ -26,11 +26,9 @@ #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/types/optional.h" #include "common/ast.h" #include "common/type.h" -#include "common/type_kind.h" #include "common/type_spec_resolver.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" @@ -64,125 +62,145 @@ void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) { } } -absl::Status AppendTypeParameters(std::string* result, const Type& type); +absl::Status AppendTypeDesc(std::string* result, const TypeSpec& type_spec); -// Recursively appends a string representation of the given `type` to `result`. -// Type parameters are enclosed in angle brackets and separated by commas. -// -// Grammar: -// TypeDesc = NamespaceIdentifier [ "<" TypeList ">" ] ; -// NamespaceIdentifier = [ "." ] Identifier { "." Identifier } ; -// TypeList = TypeElem { "," TypeElem } ; -// TypeElem = TypeDesc | TypeParam -// TypeParam = "~" Alpha ; -// Identifier = ( Alpha | "_" ) { AlphaNumeric | "_" } ; -// (* Terminals *) -// Alpha = "a"..."z" | "A"..."Z" ; -// Digit = "0"..."9" ; -// AlphaNumeric = Alpha | Digit ; -// -// For compatibility, the implementation allows unexpected characters in -// type names and parameters and escapes them with a backslash. -absl::Status AppendTypeDesc(std::string* result, const Type& type) { - switch (type.kind()) { - case TypeKind::kNull: - absl::StrAppend(result, "null"); - break; - case TypeKind::kBool: - absl::StrAppend(result, "bool"); - break; - case TypeKind::kInt: - absl::StrAppend(result, "int"); - break; - case TypeKind::kUint: - absl::StrAppend(result, "uint"); - break; - case TypeKind::kDouble: - absl::StrAppend(result, "double"); - break; - case TypeKind::kString: - absl::StrAppend(result, "string"); - break; - case TypeKind::kBytes: - absl::StrAppend(result, "bytes"); - break; - case TypeKind::kDuration: - absl::StrAppend(result, "duration"); - break; - case TypeKind::kTimestamp: - absl::StrAppend(result, "timestamp"); - break; - case TypeKind::kAny: - absl::StrAppend(result, "any"); - break; - case TypeKind::kDyn: - absl::StrAppend(result, "dyn"); - break; - case TypeKind::kBoolWrapper: - absl::StrAppend(result, "bool_wrapper"); - break; - case TypeKind::kIntWrapper: - absl::StrAppend(result, "int_wrapper"); - break; - case TypeKind::kUintWrapper: - absl::StrAppend(result, "uint_wrapper"); - break; - case TypeKind::kDoubleWrapper: - absl::StrAppend(result, "double_wrapper"); - break; - case TypeKind::kStringWrapper: - absl::StrAppend(result, "string_wrapper"); - break; - case TypeKind::kBytesWrapper: - absl::StrAppend(result, "bytes_wrapper"); - break; - case TypeKind::kList: - absl::StrAppend(result, "list"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kMap: - absl::StrAppend(result, "map"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kFunction: - absl::StrAppend(result, "function"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kType: - absl::StrAppend(result, "type"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kTypeParam: - absl::StrAppend(result, "~"); - AppendEscaped(result, type.GetTypeParam().name(), /*escape_dot=*/true); - break; - case TypeKind::kOpaque: - AppendEscaped(result, type.name(), /*escape_dot=*/false); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kStruct: - AppendEscaped(result, type.name(), /*escape_dot=*/false); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - default: - return absl::InvalidArgumentError( - absl::StrFormat("Type kind: %s is not supported in CEL declarations", - type.DebugString())); +absl::Status AppendTypeSpecList(std::string* result, + const std::vector& params) { + if (!params.empty()) { + result->push_back('<'); + for (size_t i = 0; i < params.size(); ++i) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, params[i])); + if (i < params.size() - 1) { + result->push_back(','); + } + } + result->push_back('>'); } return absl::OkStatus(); } -absl::Status AppendTypeParameters(std::string* result, const Type& type) { - const auto& parameters = type.GetParameters(); - if (!parameters.empty()) { - result->push_back('<'); - for (size_t i = 0; i < parameters.size(); ++i) { - CEL_RETURN_IF_ERROR(AppendTypeDesc(result, parameters[i])); - if (i < parameters.size() - 1) { - result->push_back(','); - } +absl::Status AppendTypeDesc(std::string* result, const TypeSpec& type_spec) { + if (type_spec.has_null()) { + absl::StrAppend(result, "null"); + } else if (type_spec.has_dyn()) { + absl::StrAppend(result, "dyn"); + } else if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + absl::StrAppend(result, "bool"); + break; + case PrimitiveType::kInt64: + absl::StrAppend(result, "int"); + break; + case PrimitiveType::kUint64: + absl::StrAppend(result, "uint"); + break; + case PrimitiveType::kDouble: + absl::StrAppend(result, "double"); + break; + case PrimitiveType::kString: + absl::StrAppend(result, "string"); + break; + case PrimitiveType::kBytes: + absl::StrAppend(result, "bytes"); + break; + default: + return absl::InvalidArgumentError("Unsupported primitive type"); + } + } else if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + absl::StrAppend(result, "any"); + break; + case WellKnownTypeSpec::kTimestamp: + absl::StrAppend(result, "timestamp"); + break; + case WellKnownTypeSpec::kDuration: + absl::StrAppend(result, "duration"); + break; + default: + return absl::InvalidArgumentError("Unsupported well-known type"); } + } else if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + absl::StrAppend(result, "bool_wrapper"); + break; + case PrimitiveType::kInt64: + absl::StrAppend(result, "int_wrapper"); + break; + case PrimitiveType::kUint64: + absl::StrAppend(result, "uint_wrapper"); + break; + case PrimitiveType::kDouble: + absl::StrAppend(result, "double_wrapper"); + break; + case PrimitiveType::kString: + absl::StrAppend(result, "string_wrapper"); + break; + case PrimitiveType::kBytes: + absl::StrAppend(result, "bytes_wrapper"); + break; + default: + return absl::InvalidArgumentError("Unsupported wrapper type"); + } + } else if (type_spec.has_list_type()) { + absl::StrAppend(result, "list<"); + if (type_spec.list_type().elem_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.list_type().elem_type())); + } else { + absl::StrAppend(result, "dyn"); + } + result->push_back('>'); + } else if (type_spec.has_map_type()) { + absl::StrAppend(result, "map<"); + if (type_spec.map_type().key_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.map_type().key_type())); + } else { + absl::StrAppend(result, "dyn"); + } + result->push_back(','); + if (type_spec.map_type().value_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.map_type().value_type())); + } else { + absl::StrAppend(result, "dyn"); + } + result->push_back('>'); + } else if (type_spec.has_function()) { + absl::StrAppend(result, "function<"); + if (type_spec.function().result_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.function().result_type())); + } else { + absl::StrAppend(result, "dyn"); + } + for (const auto& arg : type_spec.function().arg_types()) { + result->push_back(','); + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, arg)); + } + result->push_back('>'); + } else if (type_spec.has_type()) { + absl::StrAppend(result, "type"); + result->push_back('<'); + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, type_spec.type())); result->push_back('>'); + } else if (type_spec.has_type_param()) { + absl::StrAppend(result, "~"); + AppendEscaped(result, type_spec.type_param().type(), /*escape_dot=*/true); + } else if (type_spec.has_abstract_type()) { + AppendEscaped(result, type_spec.abstract_type().name(), + /*escape_dot=*/false); + CEL_RETURN_IF_ERROR(AppendTypeSpecList( + result, type_spec.abstract_type().parameter_types())); + } else if (type_spec.has_message_type()) { + AppendEscaped(result, type_spec.message_type().type(), + /*escape_dot=*/false); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported type in signature: ", FormatTypeSpec(type_spec))); } return absl::OkStatus(); } @@ -190,13 +208,32 @@ absl::Status AppendTypeParameters(std::string* result, const Type& type) { absl::StatusOr MakeTypeSignature(const Type& type) { std::string result; - CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type)); + CEL_ASSIGN_OR_RETURN(TypeSpec type_spec, ConvertTypeToTypeSpec(type)); + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type_spec)); + return result; +} + +absl::StatusOr MakeTypeSpecSignature(const TypeSpec& type_spec) { + std::string result; + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type_spec)); return result; } absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member) { + std::vector arg_type_specs; + arg_type_specs.reserve(args.size()); + for (const auto& arg : args) { + CEL_ASSIGN_OR_RETURN(TypeSpec type_spec, ConvertTypeToTypeSpec(arg)); + arg_type_specs.push_back(type_spec); + } + return MakeOverloadSignature(function_name, arg_type_specs, is_member); +} + +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member) { std::string result; if (is_member) { if (!args.empty()) { @@ -589,10 +626,14 @@ absl::StatusOr ParseFunctionSignature( return out; } +absl::StatusOr ParseTypeSpec(std::string_view signature) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + return ParseTypeSignature(stripped_sig); +} + absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool& pool) { - std::string stripped_sig = StripUnescapedWhitespace(signature); - CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSignature(stripped_sig)); + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSpec(signature)); return cel::ConvertTypeSpecToType(type_spec, arena, pool); } diff --git a/common/internal/signature.h b/common/internal/signature.h index 3fdba4b2e..8a44fbd5c 100644 --- a/common/internal/signature.h +++ b/common/internal/signature.h @@ -27,7 +27,7 @@ namespace cel::common_internal { -// Generates an signature for a `cel::Type`, which is a string representation of +// Generates a signature for a `cel::Type`, which is a string representation of // the type. // // Examples: @@ -37,7 +37,17 @@ namespace cel::common_internal { // - `list>` absl::StatusOr MakeTypeSignature(const Type& type); -// Generates an identifier for a function overload based on the function name +// Generates a signature for a `cel::TypeSpec`, which is a string +// representation of the type. +// +// Examples: +// +// - `int` +// - `list` +// - `list>` +absl::StatusOr MakeTypeSpecSignature(const TypeSpec& type_spec); + +// Generates a signature for a function overload based on the function name // and the types of the arguments. If `is_member` is true, the first argument // type is used as the receiver and is prepended to the function name, followed // by a dollar sign. @@ -59,6 +69,15 @@ absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member); +// Generates a signature for a function overload based on the function name +// and the type specs of the arguments. See above for more details. +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member); + +// Parses a string type signature directly into a `cel::TypeSpec`. +absl::StatusOr ParseTypeSpec(std::string_view signature); + // Parses a string type signature directly into a `cel::Type`. absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool& pool); diff --git a/common/internal/signature_test.cc b/common/internal/signature_test.cc index 765055f75..17b628d88 100644 --- a/common/internal/signature_test.cc +++ b/common/internal/signature_test.cc @@ -14,6 +14,7 @@ // limitations under the License. #include +#include #include #include @@ -24,6 +25,7 @@ #include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" +#include "common/type_spec_resolver.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "google/protobuf/arena.h" @@ -42,82 +44,9 @@ google::protobuf::Arena* GetTestArena() { return &*arena; } -void VerifyParsedMatchesType(const TypeSpec& parsed, const Type& original) { - switch (original.kind()) { - case TypeKind::kDyn: - EXPECT_TRUE(parsed.has_dyn()); - break; - case TypeKind::kNull: - EXPECT_TRUE(parsed.has_null()); - break; - case TypeKind::kBool: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kBool); - break; - case TypeKind::kInt: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kInt64); - break; - case TypeKind::kUint: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kUint64); - break; - case TypeKind::kDouble: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kDouble); - break; - case TypeKind::kString: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kString); - break; - case TypeKind::kBytes: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kBytes); - break; - case TypeKind::kAny: - EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kAny); - break; - case TypeKind::kTimestamp: - EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kTimestamp); - break; - case TypeKind::kDuration: - EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kDuration); - break; - case TypeKind::kList: - EXPECT_TRUE(parsed.has_list_type()); - if (!original.GetParameters().empty()) { - VerifyParsedMatchesType(parsed.list_type().elem_type(), - original.GetParameters()[0]); - } - break; - case TypeKind::kMap: - EXPECT_TRUE(parsed.has_map_type()); - if (!original.GetParameters().empty()) { - VerifyParsedMatchesType(parsed.map_type().key_type(), - original.GetParameters()[0]); - } - if (original.GetParameters().size() > 1) { - VerifyParsedMatchesType(parsed.map_type().value_type(), - original.GetParameters()[1]); - } - break; - case TypeKind::kBoolWrapper: - case TypeKind::kIntWrapper: - case TypeKind::kUintWrapper: - case TypeKind::kDoubleWrapper: - case TypeKind::kStringWrapper: - case TypeKind::kBytesWrapper: - EXPECT_TRUE(parsed.has_wrapper()); - break; - case TypeKind::kType: - EXPECT_TRUE(parsed.has_type()); - if (!original.GetParameters().empty()) { - VerifyParsedMatchesType(parsed.type(), original.GetParameters()[0]); - } - break; - case TypeKind::kTypeParam: - EXPECT_TRUE(parsed.has_type_param()); - break; - default: - EXPECT_TRUE(parsed.has_abstract_type()); - break; - } +void VerifyParsedMatchesType(const TypeSpec& parsed, const TypeSpec& expected) { + EXPECT_EQ(parsed, expected); } - void VerifyTypesEqual(const Type& lhs, const Type& rhs) { EXPECT_EQ(lhs.kind(), rhs.kind()); if (lhs.kind() != rhs.kind()) return; @@ -138,7 +67,7 @@ void VerifyTypesEqual(const Type& lhs, const Type& rhs) { } struct TypeSignatureTestCase { - Type type; + TypeSpec type; std::string expected_signature; std::string expected_error; }; @@ -149,104 +78,208 @@ TEST_P(TypeSignatureTest, TypeSignature) { const auto& param = GetParam(); absl::StatusOr signature = - common_internal::MakeTypeSignature(param.type); + common_internal::MakeTypeSpecSignature(param.type); if (!param.expected_error.empty()) { EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(param.expected_error))); } else { EXPECT_THAT(signature, IsOkAndHolds(param.expected_signature)); + + absl::StatusOr type = ConvertTypeSpecToType( + param.type, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(type, ::absl_testing::IsOk()); + EXPECT_THAT(MakeTypeSignature(*type), + IsOkAndHolds(param.expected_signature)); } } std::vector GetTypeSignatureTestCases() { return { { - .type = StringType{}, - .expected_signature = "string", + .type = TypeSpec(NullTypeSpec{}), + .expected_signature = "null", + }, + { + .type = TypeSpec(PrimitiveType::kBool), + .expected_signature = "bool", }, { - .type = IntType{}, + .type = TypeSpec(PrimitiveType::kInt64), .expected_signature = "int", }, { - .type = ListType(GetTestArena(), StringType{}), - .expected_signature = "list", + .type = TypeSpec(PrimitiveType::kUint64), + .expected_signature = "uint", }, { - .type = TypeType(GetTestArena(), IntType{}), - .expected_signature = "type", + .type = TypeSpec(PrimitiveType::kDouble), + .expected_signature = "double", + }, + { + .type = TypeSpec(PrimitiveType::kString), + .expected_signature = "string", }, { - .type = ListType(GetTestArena(), TypeParamType("A")), + .type = TypeSpec(PrimitiveType::kBytes), + .expected_signature = "bytes", + }, + { + .type = TypeSpec( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + .expected_signature = "cel.expr.conformance.proto3.TestAllTypes", + }, + { + .type = TypeSpec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", {})), + .expected_signature = "cel.expr.conformance.proto3.TestAllTypes", + }, + { + .type = TypeSpec(WellKnownTypeSpec::kDuration), + .expected_signature = "duration", + }, + { + .type = TypeSpec(WellKnownTypeSpec::kTimestamp), + .expected_signature = "timestamp", + }, + { + .type = TypeSpec( + ListTypeSpec(std::make_unique(PrimitiveType::kString))), + .expected_signature = "list", + }, + { + .type = TypeSpec( + ListTypeSpec(std::make_unique(ParamTypeSpec("A")))), .expected_signature = "list<~A>", }, { - .type = ListType(GetTestArena(), TypeParamType("A(ParamTypeSpec("A(ParamTypeSpec(R"(a,b..(d)\e)")))), + .expected_signature = R"(list<~a\,b\.\\.\(d\)\\e>)", + }, + { + .type = TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kInt64), + std::make_unique(DynTypeSpec()))), .expected_signature = "map", }, { - .type = - MapType(GetTestArena(), TypeParamType("B"), TypeParamType("C")), + .type = TypeSpec( + MapTypeSpec(std::make_unique(ParamTypeSpec("B")), + std::make_unique(ParamTypeSpec("C")))), .expected_signature = "map<~B,~C>", }, { - .type = OpaqueType(GetTestArena(), "bar", - {FunctionType(GetTestArena(), TypeParamType("D"), - {StringType{}, BoolType{}})}), - .expected_signature = "bar>", + .type = TypeSpec(MapTypeSpec( + std::make_unique(PrimitiveType::kInt64), nullptr)), + .expected_signature = "map", + }, + { + .type = TypeSpec(MapTypeSpec(nullptr, nullptr)), + .expected_signature = "map", + }, + { + .type = TypeSpec(std::make_unique(PrimitiveType::kInt64)), + .expected_signature = "type", }, { - .type = AnyType{}, + .type = TypeSpec(WellKnownTypeSpec::kAny), .expected_signature = "any", }, { - .type = DurationType{}, - .expected_signature = "duration", + .type = TypeSpec(DynTypeSpec{}), + .expected_signature = "dyn", }, { - .type = TimestampType{}, - .expected_signature = "timestamp", + .type = TypeSpec(AbstractType( + "bar", {TypeSpec(FunctionTypeSpec( + std::make_unique(ParamTypeSpec("D")), + {TypeSpec(PrimitiveType::kString), + TypeSpec(PrimitiveType::kBool)}))})), + .expected_signature = "bar>", + }, + { + .type = + TypeSpec(AbstractType("bar", {TypeSpec(PrimitiveType::kInt64), + TypeSpec(PrimitiveType::kString)})), + .expected_signature = "bar", }, { - .type = BoolWrapperType{}, + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)), .expected_signature = "bool_wrapper", }, { - .type = IntWrapperType{}, + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)), .expected_signature = "int_wrapper", }, { - .type = UintWrapperType{}, + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)), .expected_signature = "uint_wrapper", }, { - .type = MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( - "cel.expr.conformance.proto3.TestAllTypes")), - .expected_signature = "cel.expr.conformance.proto3.TestAllTypes", + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + .expected_signature = "double_wrapper", }, { - .type = ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)")), - .expected_signature = R"(list<~a\,b\.\\.\(d\)\\e>)", + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)), + .expected_signature = "string_wrapper", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)), + .expected_signature = "bytes_wrapper", + }, + { + .type = TypeSpec( + FunctionTypeSpec(nullptr, {TypeSpec(PrimitiveType::kInt64)})), + .expected_signature = "function", + }, + { + .type = TypeSpec(FunctionTypeSpec( + std::make_unique(PrimitiveType::kInt64), {})), + .expected_signature = "function", + }, + { + .type = TypeSpec(FunctionTypeSpec(nullptr, {})), + .expected_signature = "function", }, }; } +INSTANTIATE_TEST_SUITE_P(TypeSignatureTest, TypeSignatureTest, + ValuesIn(GetTypeSignatureTestCases())); + TEST(TypeSignatureTest, UnsupportedTypes) { EXPECT_THAT(common_internal::MakeTypeSignature(UnknownType{}), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Type kind: *unknown* is not supported"))); + HasSubstr("Unsupported Type kind: *unknown*"))); EXPECT_THAT(common_internal::MakeTypeSignature(ErrorType{}), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Type kind: *error* is not supported"))); -} + HasSubstr("Unsupported type in signature: *error*"))); -INSTANTIATE_TEST_SUITE_P(TypeIdTest, TypeSignatureTest, - ValuesIn(GetTypeSignatureTestCases())); + EXPECT_THAT(common_internal::MakeTypeSpecSignature( + TypeSpec(static_cast(999))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported primitive type"))); + + EXPECT_THAT(common_internal::MakeTypeSpecSignature( + TypeSpec(static_cast(999))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported well-known type"))); + + EXPECT_THAT(common_internal::MakeTypeSpecSignature(TypeSpec( + PrimitiveTypeWrapper(static_cast(999)))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported wrapper type"))); +} TEST_P(TypeSignatureTest, ParseTypeCheck) { const auto& param = GetParam(); @@ -254,13 +287,16 @@ TEST_P(TypeSignatureTest, ParseTypeCheck) { auto parsed = ParseType(param.expected_signature, GetTestArena(), *GetTestingDescriptorPool()); ASSERT_THAT(parsed, ::absl_testing::IsOk()); - VerifyTypesEqual(*parsed, param.type); + ASSERT_OK_AND_ASSIGN(auto expected_type, + ConvertTypeSpecToType(param.type, GetTestArena(), + *GetTestingDescriptorPool())); + VerifyTypesEqual(*parsed, expected_type); } } struct OverloadSignatureTestCase { std::string function_name = "hello"; - std::vector args; + std::vector args; bool is_member = false; std::string expected_signature; std::string expected_error; @@ -285,98 +321,110 @@ TEST_P(OverloadSignatureTest, OverloadSignature) { std::vector GetOverloadSignatureTestCases() { return { { - .args = {StringType{}}, + .args = {TypeSpec(PrimitiveType::kString)}, .expected_signature = "hello(string)", }, { - .args = {IntType{}, UintType{}}, + .args = {TypeSpec(PrimitiveType::kInt64), + TypeSpec(PrimitiveType::kUint64)}, .expected_signature = "hello(int,uint)", }, { - .args = {ListType(GetTestArena(), StringType{})}, + .args = {TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kString)))}, .expected_signature = "hello(list)", }, { - .args = {ListType(GetTestArena(), TypeParamType("A"))}, + .args = {TypeSpec( + ListTypeSpec(std::make_unique(ParamTypeSpec("A"))))}, .expected_signature = "hello(list<~A>)", }, { - .args = {MapType(GetTestArena(), IntType{}, DynType{})}, + .args = {TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kInt64), + std::make_unique(DynTypeSpec())))}, .expected_signature = "hello(map)", }, { - .args = {MapType(GetTestArena(), TypeParamType("B"), - TypeParamType("C"))}, + .args = {TypeSpec( + MapTypeSpec(std::make_unique(ParamTypeSpec("B")), + std::make_unique(ParamTypeSpec("C"))))}, .expected_signature = "hello(map<~B,~C>)", }, + { - .args = {OpaqueType( - GetTestArena(), "bar", - {FunctionType(GetTestArena(), TypeParamType("D"), {})})}, + .args = {TypeSpec(AbstractType( + "bar", + {TypeSpec(FunctionTypeSpec( + std::make_unique(ParamTypeSpec("D")), {}))}))}, .expected_signature = "hello(bar>)", }, { - .args = {AnyType{}}, + .args = {TypeSpec(WellKnownTypeSpec::kAny)}, .expected_signature = "hello(any)", }, { - .args = {DurationType{}}, + .args = {TypeSpec(WellKnownTypeSpec::kDuration)}, .expected_signature = "hello(duration)", }, { - .args = {TimestampType{}}, + .args = {TypeSpec(WellKnownTypeSpec::kTimestamp)}, .expected_signature = "hello(timestamp)", }, { - .args = {BoolWrapperType{}}, + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool))}, .expected_signature = "hello(bool_wrapper)", }, { - .args = {IntWrapperType{}}, + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64))}, .expected_signature = "hello(int_wrapper)", }, { - .args = {UintWrapperType{}}, + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64))}, .expected_signature = "hello(uint_wrapper)", }, { - .args = {MessageType( - GetTestingDescriptorPool()->FindMessageTypeByName( - "cel.expr.conformance.proto3.TestAllTypes"))}, + .args = {TypeSpec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", {}))}, .expected_signature = "hello(cel.expr.conformance.proto3.TestAllTypes)", }, { - .args = {StringType{}}, + .args = {TypeSpec(PrimitiveType::kString)}, .is_member = true, .expected_signature = "string.hello()", }, { - .args = {StringType{}, ListType(GetTestArena(), BoolType{})}, + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kBool)))}, .is_member = true, .expected_signature = "string.hello(list)", }, { - .args = {StringType{}, BoolType{}, DynType{}}, + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(PrimitiveType::kBool), TypeSpec(DynTypeSpec())}, .is_member = true, .expected_signature = "string.hello(bool,dyn)", }, { .function_name = "hello", - .args = {OpaqueType(GetTestArena(), "bar", - {TypeParamType("dummy.type")})}, + .args = {TypeSpec( + AbstractType("bar", {TypeSpec(ParamTypeSpec("dummy.type"))}))}, .is_member = true, .expected_signature = R"(bar<~dummy\.type>.hello())", }, { .function_name = "inspect", - .args = {Type(TypeType(GetTestArena(), StringType{}))}, + .args = {TypeSpec( + std::make_unique(PrimitiveType::kString))}, .expected_signature = "inspect(type)", }, { .function_name = R"(h.(e),l\o)", - .args = {StringType{}, - ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)"))}, + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(ListTypeSpec(std::make_unique( + ParamTypeSpec(R"(a,b..(d)\e)"))))}, .is_member = true, .expected_signature = R"(string.h\.\(e\)\,l\\\o(list<~a\,b\.\\.\(d\)\\e>))", @@ -385,7 +433,8 @@ std::vector GetOverloadSignatureTestCases() { } TEST(OverloadSignatureTest, MemberFunctionNoReceiverError) { - auto signature = common_internal::MakeOverloadSignature("hello", {}, true); + auto signature = common_internal::MakeOverloadSignature( + "hello", std::vector{}, true); EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Member function with no receiver"))); @@ -564,8 +613,15 @@ TEST(ParseSignatureTest, ParsingErrors) { EXPECT_THAT(ParseFunctionSignature("foo"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); + EXPECT_THAT( + ParseType("list b < c>", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); - // Parameter count validations for list and map types. + // Parameter count validations for list, map and type types. EXPECT_THAT(ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -578,6 +634,18 @@ TEST(ParseSignatureTest, ParsingErrors) { *GetTestingDescriptorPool()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("map expects 0 or 2 parameters"))); + EXPECT_THAT(ParseType("type", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type expects at most 1 parameter"))); + + // Invalid parameter name validations. + EXPECT_THAT(ParseType("~", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid type parameter name"))); + EXPECT_THAT(ParseType("~A", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid type parameter name"))); // Enforcing valid function and identifier names. EXPECT_THAT(ParseFunctionSignature("()"), @@ -700,5 +768,20 @@ TEST(ParseSignatureTest, EmptyOrWhitespaceErrors) { HasSubstr("Empty type signature"))); } +TEST(OverloadSignatureTest, ArgumentTypeVector) { + std::vector args; + args.push_back(Type(IntType())); + args.push_back(Type(StringType())); + args.push_back(Type(ListType(GetTestArena(), IntType()))); + args.push_back( + Type(MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))); + args.push_back(Type(OpaqueType(GetTestArena(), "Foo", {TypeParamType("T")}))); + ASSERT_OK_AND_ASSIGN(auto sig, MakeOverloadSignature("foo", args, false)); + EXPECT_EQ(sig, + "foo(int,string,list,cel.expr.conformance.proto3.TestAllTypes," + "Foo<~T>)"); +} + } // namespace } // namespace cel::common_internal diff --git a/common/type_spec_resolver.cc b/common/type_spec_resolver.cc index 97451f390..90c9930a8 100644 --- a/common/type_spec_resolver.cc +++ b/common/type_spec_resolver.cc @@ -14,6 +14,7 @@ #include "common/type_spec_resolver.h" +#include #include #include #include @@ -22,8 +23,12 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "common/ast.h" #include "common/type.h" +#include "common/type_kind.h" #include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel { @@ -85,28 +90,42 @@ absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, } if (type_spec.has_list_type()) { - CEL_ASSIGN_OR_RETURN( - auto elem_type, - ConvertTypeSpecToType(type_spec.list_type().elem_type(), arena, pool)); + Type elem_type; + if (type_spec.list_type().elem_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + elem_type, ConvertTypeSpecToType(type_spec.list_type().elem_type(), + arena, pool)); + } return Type(ListType(arena, elem_type)); } if (type_spec.has_map_type()) { - CEL_ASSIGN_OR_RETURN( - auto key_type, - ConvertTypeSpecToType(type_spec.map_type().key_type(), arena, pool)); - CEL_ASSIGN_OR_RETURN( - auto value_type, - ConvertTypeSpecToType(type_spec.map_type().value_type(), arena, pool)); + Type key_type; + if (type_spec.map_type().key_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + key_type, + ConvertTypeSpecToType(type_spec.map_type().key_type(), arena, pool)); + } + + Type value_type; + if (type_spec.map_type().value_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + value_type, ConvertTypeSpecToType(type_spec.map_type().value_type(), + arena, pool)); + } return Type(MapType(arena, key_type, value_type)); } if (type_spec.has_function()) { const auto& func_spec = type_spec.function(); - CEL_ASSIGN_OR_RETURN( - auto result_type, - ConvertTypeSpecToType(func_spec.result_type(), arena, pool)); + Type result_type; + if (func_spec.result_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + result_type, + ConvertTypeSpecToType(func_spec.result_type(), arena, pool)); + } std::vector arg_types; + arg_types.reserve(func_spec.arg_types().size()); for (const auto& arg_spec : func_spec.arg_types()) { CEL_ASSIGN_OR_RETURN(auto arg_type, ConvertTypeSpecToType(arg_spec, arena, pool)); @@ -179,4 +198,104 @@ absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, return absl::InvalidArgumentError("Unknown TypeSpec kind"); } +absl::StatusOr ConvertTypeToTypeSpec(const Type& type) { + switch (type.kind()) { + case TypeKind::kNull: + return TypeSpec(NullTypeSpec{}); + case TypeKind::kDyn: + return TypeSpec(DynTypeSpec{}); + case TypeKind::kBool: + return TypeSpec(PrimitiveType::kBool); + case TypeKind::kInt: + return TypeSpec(PrimitiveType::kInt64); + case TypeKind::kUint: + return TypeSpec(PrimitiveType::kUint64); + case TypeKind::kDouble: + return TypeSpec(PrimitiveType::kDouble); + case TypeKind::kString: + return TypeSpec(PrimitiveType::kString); + case TypeKind::kBytes: + return TypeSpec(PrimitiveType::kBytes); + case TypeKind::kAny: + return TypeSpec(WellKnownTypeSpec::kAny); + case TypeKind::kTimestamp: + return TypeSpec(WellKnownTypeSpec::kTimestamp); + case TypeKind::kDuration: + return TypeSpec(WellKnownTypeSpec::kDuration); + case TypeKind::kBoolWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + case TypeKind::kIntWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + case TypeKind::kUintWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + case TypeKind::kDoubleWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + case TypeKind::kStringWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + case TypeKind::kBytesWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + case TypeKind::kList: { + CEL_ASSIGN_OR_RETURN(auto elem_type, + ConvertTypeToTypeSpec(type.GetList().element())); + return TypeSpec( + ListTypeSpec(std::make_unique(std::move(elem_type)))); + } + case TypeKind::kMap: { + CEL_ASSIGN_OR_RETURN(auto key_type, + ConvertTypeToTypeSpec(type.GetMap().key())); + CEL_ASSIGN_OR_RETURN(auto value_type, + ConvertTypeToTypeSpec(type.GetMap().value())); + return TypeSpec( + MapTypeSpec(std::make_unique(std::move(key_type)), + std::make_unique(std::move(value_type)))); + } + case TypeKind::kFunction: { + auto func_type = type.GetFunction(); + CEL_ASSIGN_OR_RETURN(auto result_type, + ConvertTypeToTypeSpec(func_type.result())); + std::vector arg_types; + arg_types.reserve(func_type.args().size()); + for (const auto& arg : func_type.args()) { + CEL_ASSIGN_OR_RETURN(auto arg_type, ConvertTypeToTypeSpec(arg)); + arg_types.push_back(std::move(arg_type)); + } + return TypeSpec( + FunctionTypeSpec(std::make_unique(std::move(result_type)), + std::move(arg_types))); + } + case TypeKind::kTypeParam: + return TypeSpec(ParamTypeSpec(std::string(type.GetTypeParam().name()))); + case TypeKind::kStruct: { + if (type.IsMessage()) { + return TypeSpec(MessageTypeSpec(std::string(type.GetMessage().name()))); + } + return absl::InvalidArgumentError("Unsupported struct type"); + } + case TypeKind::kOpaque: { + auto opaque_type = type.GetOpaque(); + std::vector params; + params.reserve(opaque_type.GetParameters().size()); + for (const auto& param : opaque_type.GetParameters()) { + CEL_ASSIGN_OR_RETURN(auto param_type, ConvertTypeToTypeSpec(param)); + params.push_back(std::move(param_type)); + } + return TypeSpec( + AbstractType(std::string(opaque_type.name()), std::move(params))); + } + case TypeKind::kType: { + CEL_ASSIGN_OR_RETURN(auto nested_type, + ConvertTypeToTypeSpec(type.GetType().GetType())); + return TypeSpec(std::make_unique(std::move(nested_type))); + } + case TypeKind::kError: + return TypeSpec(ErrorTypeSpec::kValue); + case TypeKind::kEnum: + return TypeSpec( + AbstractType(std::string(type.GetEnum().name()), /*params=*/{})); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported Type kind: ", TypeKindToString(type.kind()))); + } +} + } // namespace cel diff --git a/common/type_spec_resolver.h b/common/type_spec_resolver.h index 44e1e088f..edbfa3bde 100644 --- a/common/type_spec_resolver.h +++ b/common/type_spec_resolver.h @@ -32,6 +32,9 @@ absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool& pool); +// Resolves a `cel::Type` to a `cel::TypeSpec`. +absl::StatusOr ConvertTypeToTypeSpec(const Type& type); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ diff --git a/common/type_spec_resolver_test.cc b/common/type_spec_resolver_test.cc index c7fbb2cf8..1cda7280f 100644 --- a/common/type_spec_resolver_test.cc +++ b/common/type_spec_resolver_test.cc @@ -23,6 +23,7 @@ #include "absl/base/no_destructor.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" +#include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" #include "internal/testing.h" @@ -33,6 +34,7 @@ namespace cel { namespace { using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::internal::GetTestingDescriptorPool; using ::testing::HasSubstr; @@ -67,6 +69,7 @@ TEST_P(ConversionTest, TestTypeSpecConversion) { auto t, ConvertTypeSpecToType(std::get<0>(GetParam()), GetTestArena(), *GetTestingDescriptorPool())); EXPECT_EQ(t.kind(), std::get<1>(GetParam())); + EXPECT_THAT(ConvertTypeToTypeSpec(t), IsOkAndHolds(std::get<0>(GetParam()))); } INSTANTIATE_TEST_SUITE_P( @@ -104,6 +107,8 @@ TEST(TypeSpecResolverTest, ListTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsList()); EXPECT_TRUE(t->GetList().element().IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MapTypeConversion) { @@ -116,6 +121,8 @@ TEST(TypeSpecResolverTest, MapTypeConversion) { EXPECT_TRUE(t->IsMap()); EXPECT_TRUE(t->GetMap().key().IsString()); EXPECT_TRUE(t->GetMap().value().IsBytes()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, FunctionTypeConversion) { @@ -129,6 +136,8 @@ TEST(TypeSpecResolverTest, FunctionTypeConversion) { EXPECT_TRUE(t->IsFunction()); EXPECT_EQ(t->GetFunction().args().size(), 1); EXPECT_TRUE(t->GetFunction().result().IsBool()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, TypeParamConversion) { @@ -138,6 +147,8 @@ TEST(TypeSpecResolverTest, TypeParamConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsTypeParam()); EXPECT_EQ(t->GetTypeParam().name(), "T"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MessageTypeConversion) { @@ -148,6 +159,10 @@ TEST(TypeSpecResolverTest, MessageTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsMessage()); EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ( + spec2, + TypeSpec(MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes"))); } TEST(TypeSpecResolverTest, MessageTypeWithParamsError) { @@ -172,6 +187,8 @@ TEST(TypeSpecResolverTest, UnresolvedAbstractTypeFallbackToOpaque) { EXPECT_EQ(t->name(), "my.custom.OpaqueType"); EXPECT_EQ(t->GetParameters().size(), 1); EXPECT_TRUE(t->GetParameters()[0].IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, OptionalType) { @@ -186,6 +203,8 @@ TEST(TypeSpecResolverTest, OptionalType) { EXPECT_EQ(t->GetParameters().size(), 1); EXPECT_TRUE(t->GetParameters()[0].IsInt()); EXPECT_TRUE(t->IsOptional()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, TypeTypeConversion) { @@ -196,6 +215,8 @@ TEST(TypeSpecResolverTest, TypeTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsType()); EXPECT_TRUE(t->GetType().GetType().IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, ErrorTypeConversion) { @@ -204,6 +225,8 @@ TEST(TypeSpecResolverTest, ErrorTypeConversion) { ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsError()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { @@ -213,6 +236,8 @@ TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsMessage()); EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MessageTypeSpecNotFoundError) { @@ -231,6 +256,8 @@ TEST(TypeSpecResolverTest, EnumTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsEnum()); EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, EnumTypeWithParamsError) { From f5d0d5ff13e770bd35be79e5ce7e81fe35f63e91 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 8 Jun 2026 15:38:29 -0700 Subject: [PATCH 105/143] Clean up unused dependencies in cel/cpp/common BUILD files. Remove unused dependencies from cel/cpp/common/BUILD and cel/cpp/common/internal/BUILD. PiperOrigin-RevId: 928808099 --- common/BUILD | 9 --------- common/internal/BUILD | 2 -- 2 files changed, 11 deletions(-) diff --git a/common/BUILD b/common/BUILD index 01710329b..f7c897e57 100644 --- a/common/BUILD +++ b/common/BUILD @@ -403,7 +403,6 @@ cc_library( ":allocator", ":arena", ":data", - ":native_type", ":reference_count", "//common/internal:metadata", "//common/internal:reference_count", @@ -425,13 +424,9 @@ cc_test( ":allocator", ":data", ":memory", - ":native_type", "//common/internal:reference_count", "//internal:testing", "@com_google_absl//absl/base:nullability", - "@com_google_absl//absl/debugging:leak_check", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:struct_cc_proto", ], @@ -1024,9 +1019,6 @@ cc_library( deps = [ ":decl", ":decl_proto", - ":type", - ":type_proto", - "//internal:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1194,6 +1186,5 @@ cc_test( ":container", "//internal:testing", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:status_matchers", ], ) diff --git a/common/internal/BUILD b/common/internal/BUILD index 73cbf37e9..b07faf229 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -21,10 +21,8 @@ cc_library( name = "casting", hdrs = ["casting.h"], deps = [ - "//common:native_type", "//internal:casts", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/types:optional", ], From 87fea87272429b9c85655caeda7ab5e4804bf5f4 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 8 Jun 2026 22:27:41 -0700 Subject: [PATCH 106/143] Add support for type signatures in CEL environment YAML configuration. The `env_yaml` parser now accepts type signatures for variable types and function overload signatures. The `type` field can be used instead of `type_name` for variables, allowing a more compact representation of types, including type parameters and parameterized types. The `signature` field can be used for function overloads, providing a single string to define the overload's target, arguments, and member status. The `return` type in function overloads can now also be specified as a type signature string. PiperOrigin-RevId: 928959415 --- env/BUILD | 7 +- env/env_yaml.cc | 311 +++++++++++++++++++++++++--------- env/env_yaml.h | 37 +++- env/env_yaml_test.cc | 381 +++++++++++++++++++++++++++++++++--------- env/type_info.cc | 226 +++++++++++++++++++++++++ env/type_info.h | 7 + env/type_info_test.cc | 169 +++++++++++++++++++ 7 files changed, 978 insertions(+), 160 deletions(-) diff --git a/env/BUILD b/env/BUILD index 41ffc1723..3035e11ac 100644 --- a/env/BUILD +++ b/env/BUILD @@ -28,6 +28,7 @@ cc_library( "type_info.h", ], deps = [ + "//common:ast", "//common:constant", "//common:type", "//common:type_kind", @@ -120,7 +121,9 @@ cc_library( features = ["-use_header_modules"], deps = [ ":config", + "//common:ast", "//common:constant", + "//common/internal:signature", "//internal:status_macros", "//internal:strings", "@com_google_absl//absl/algorithm:container", @@ -178,9 +181,11 @@ cc_test( ":config", "//common:type", "//common:type_proto", + "//common/ast:metadata", "//internal:proto_matchers", "//internal:testing", "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", ], ) @@ -201,7 +206,6 @@ cc_test( "//common:type", "//common:value", "//compiler", - "//internal:proto_matchers", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", @@ -241,7 +245,6 @@ cc_test( "//common:value", "//compiler", "//extensions:math_ext", - "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//runtime", diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 159786598..8c635e65f 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -35,8 +35,11 @@ #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "common/ast.h" #include "common/constant.h" +#include "common/internal/signature.h" #include "env/config.h" +#include "env/type_info.h" #include "internal/status_macros.h" #include "internal/strings.h" #include "yaml-cpp/emitter.h" @@ -117,8 +120,8 @@ absl::StatusOr GetBinary(absl::string_view yaml, return binary; } else { return YamlError(yaml, node, - "Node '" + GetString(yaml, node) + - "' is not a valid Base64 encoded binary"); + absl::StrCat("Node '", GetString(yaml, node), + "' is not a valid Base64 encoded binary")); } } @@ -131,10 +134,22 @@ absl::StatusOr GetBool(absl::string_view yaml, absl::string_view key, return node.as(); } catch (YAML::Exception& e) { return YamlError(yaml, node, - "Node '" + std::string(key) + "' is not a boolean"); + absl::StrCat("Node '", key, "' is not a boolean")); } } +// Returns the key in the map `node` that has the given `value_node` as its +// value. If no such key exists, returns `value_node` itself. +YAML::Node GetContextNodeForKeyValue(const YAML::Node& node, + const YAML::Node& value_node) { + for (const auto& kv : node) { + if (kv.second.IsDefined() && kv.second.is(value_node)) { + return kv.first; + } + } + return value_node; +} + absl::Status ParseName(Config& config, absl::string_view yaml, const YAML::Node& root) { const YAML::Node name = root["name"]; @@ -407,7 +422,23 @@ absl::Status ParseStandardLibraryConfig(Config& config, absl::string_view yaml, absl::StatusOr ParseTypeInfo(const YAML::Node& node, absl::string_view yaml) { Config::TypeInfo type_config; + const YAML::Node type = node["type"]; const YAML::Node type_name = node["type_name"]; + if (type.IsDefined() && type_name.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(node, type_name), + "Node 'type' and 'type_name' are mutually exclusive"); + } + + if (type.IsDefined()) { + if (!type.IsScalar()) { + return YamlError(yaml, type, "Node 'type' is not a string"); + } + CEL_ASSIGN_OR_RETURN(auto type_spec, + common_internal::ParseTypeSpec(GetString(yaml, type))); + CEL_ASSIGN_OR_RETURN(auto type_config, TypeSpecToTypeInfo(type_spec)); + return type_config; + } + if (!type_name.IsDefined()) { return type_config; } @@ -627,7 +658,8 @@ absl::Status ParseVariableConfigs(Config& config, absl::string_view yaml, } absl::StatusOr ParseFunctionOverloadConfig( - absl::string_view yaml, const YAML::Node& overload) { + absl::string_view yaml, const YAML::Node& overload, + absl::string_view function_name) { Config::FunctionOverloadConfig overload_config; if (!overload || !overload.IsMap()) { return YamlError(yaml, overload, "Function overload is not a map"); @@ -654,40 +686,89 @@ absl::StatusOr ParseFunctionOverloadConfig( } } + const YAML::Node signature_node = overload["signature"]; const YAML::Node target = overload["target"]; - if (target.IsDefined()) { - if (!target.IsMap()) { - return YamlError(yaml, target, "Function overload target is not a map"); + const YAML::Node args = overload["args"]; + if (signature_node.IsDefined()) { + if (!signature_node.IsScalar()) { + return YamlError(yaml, signature_node, + "Function overload signature is not a string"); } - CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, - ParseTypeInfo(target, yaml)); - overload_config.is_member_function = true; - overload_config.parameters.push_back(type_info); - } - const YAML::Node args = overload["args"]; - if (args.IsDefined()) { - if (!args.IsSequence()) { - return YamlError(yaml, args, "Function overload args is not a sequence"); + if (target.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(overload, target), + "Function overload signature and target are mutually " + "exclusive"); + } + if (args.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(overload, args), + "Function overload signature and args are mutually " + "exclusive"); + } + + std::string signature = GetString(yaml, signature_node); + CEL_ASSIGN_OR_RETURN( + common_internal::ParsedFunctionOverload parsed_signature, + common_internal::ParseFunctionSignature(signature)); + if (parsed_signature.function_name != function_name) { + return YamlError(yaml, signature_node, + absl::StrCat("Function overload name \"", + parsed_signature.function_name, + "\" does not match function name \"", + function_name, "\"")); + } + overload_config.is_member_function = parsed_signature.is_member; + if (!parsed_signature.signature_type.has_function()) { + return absl::InternalError(absl::StrCat( + "Function overload signature has no function type: ", signature)); } - for (const YAML::Node& arg : args) { - if (!arg.IsMap()) { - return YamlError(yaml, arg, "Function overload arg is not a map"); + const FunctionTypeSpec& function_type_spec = + parsed_signature.signature_type.function(); + for (const auto& arg : function_type_spec.arg_types()) { + CEL_ASSIGN_OR_RETURN(auto type_info, TypeSpecToTypeInfo(arg)); + overload_config.parameters.push_back(std::move(type_info)); + } + } else { + if (target.IsDefined()) { + if (!target.IsMap()) { + return YamlError(yaml, target, "Function overload target is not a map"); } CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, - ParseTypeInfo(arg, yaml)); + ParseTypeInfo(target, yaml)); + overload_config.is_member_function = true; overload_config.parameters.push_back(type_info); } - } + if (args.IsDefined()) { + if (!args.IsSequence()) { + return YamlError(yaml, args, + "Function overload args is not a sequence"); + } + for (const YAML::Node& arg : args) { + if (!arg.IsMap()) { + return YamlError(yaml, arg, "Function overload arg is not a map"); + } + CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, + ParseTypeInfo(arg, yaml)); + overload_config.parameters.push_back(type_info); + } + } + } const YAML::Node return_type = overload["return"]; if (return_type.IsDefined()) { - if (!return_type.IsMap()) { - return YamlError(yaml, return_type, - "Function overload return type is not a map"); + if (return_type.IsScalar()) { + CEL_ASSIGN_OR_RETURN(auto type_spec, common_internal::ParseTypeSpec( + GetString(yaml, return_type))); + CEL_ASSIGN_OR_RETURN(overload_config.return_type, + TypeSpecToTypeInfo(type_spec)); + } else if (return_type.IsMap()) { + CEL_ASSIGN_OR_RETURN(overload_config.return_type, + ParseTypeInfo(return_type, yaml)); + } else { + return YamlError( + yaml, return_type, + "Function overload return type is neither a string nor a map"); } - CEL_ASSIGN_OR_RETURN(overload_config.return_type, - ParseTypeInfo(return_type, yaml)); } return overload_config; } @@ -728,8 +809,9 @@ absl::Status ParseFunctionConfigs(Config& config, absl::string_view yaml, } for (const YAML::Node& overload : overloads) { - CEL_ASSIGN_OR_RETURN(Config::FunctionOverloadConfig overload_config, - ParseFunctionOverloadConfig(yaml, overload)); + CEL_ASSIGN_OR_RETURN( + Config::FunctionOverloadConfig overload_config, + ParseFunctionOverloadConfig(yaml, overload, function_config.name)); function_config.overload_configs.push_back(std::move(overload_config)); } } @@ -893,26 +975,43 @@ void EmitStandardLibraryConfig(const Config& env_config, YAML::Emitter& out) { out << YAML::EndMap; } -void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out) { +void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { // Note: the map is already started when this is called, so we don't emit // BeginMap here or EndMap at the end. - out << YAML::Key << "type_name"; - out << YAML::Value << YAML::DoubleQuoted << type_info.name; - if (type_info.is_type_param) { - out << YAML::Key << "is_type_param" << YAML::Value << true; - } - if (!type_info.params.empty()) { - out << YAML::Key << "params" << YAML::Value << YAML::BeginSeq; - for (const Config::TypeInfo& param : type_info.params) { - out << YAML::BeginMap; - EmitTypeInfo(param, out); - out << YAML::EndMap; + bool signature_generated = false; + if (options.use_type_signatures) { + absl::StatusOr type_spec = TypeInfoToTypeSpec(type_info); + if (type_spec.ok()) { + absl::StatusOr signature = + common_internal::MakeTypeSpecSignature(*type_spec); + if (signature.ok()) { + out << YAML::Key << "type"; + out << YAML::Value << YAML::DoubleQuoted << *signature; + signature_generated = true; + } + } + } + if (!signature_generated) { + out << YAML::Key << "type_name"; + out << YAML::Value << YAML::DoubleQuoted << type_info.name; + if (type_info.is_type_param) { + out << YAML::Key << "is_type_param" << YAML::Value << true; + } + if (!type_info.params.empty()) { + out << YAML::Key << "params" << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& param : type_info.params) { + out << YAML::BeginMap; + EmitTypeInfo(param, out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; } - out << YAML::EndSeq; } } -void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { +void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { const auto& variable_configs = env_config.GetVariableConfigs(); if (variable_configs.empty()) { return; @@ -936,7 +1035,7 @@ void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { out << YAML::Key << "description"; out << YAML::Value << YAML::DoubleQuoted << variable_config.description; } - EmitTypeInfo(variable_config.type_info, out); + EmitTypeInfo(variable_config.type_info, out, options); if (variable_config.value.has_value()) { const Constant& constant = variable_config.value; switch (constant.kind_case()) { @@ -991,51 +1090,97 @@ void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { } void EmitFunctionOverloadConfig( - const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out) { + absl::string_view function_name, + const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { out << YAML::BeginMap; - out << YAML::Key << "id"; - out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; - if (overload_config.is_member_function) { - out << YAML::Key << "target" << YAML::Value; - out << YAML::BeginMap; - if (overload_config.parameters.empty()) { - // This should never happen, but if it does, emit a dynamic type. - EmitTypeInfo({.name = "dyn"}, out); - } else { - EmitTypeInfo(overload_config.parameters[0], out); + if (!overload_config.overload_id.empty()) { + out << YAML::Key << "id"; + out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; + } + bool signature_generated = false; + if (options.use_type_signatures) { + bool param_type_spec_generated = true; + std::vector params; + params.reserve(overload_config.parameters.size()); + for (const auto& parameter : overload_config.parameters) { + absl::StatusOr type_spec = TypeInfoToTypeSpec(parameter); + if (!type_spec.ok()) { + param_type_spec_generated = false; + break; + } + params.push_back(std::move(*type_spec)); } - out << YAML::EndMap; - if (overload_config.parameters.size() > 1) { - out << YAML::Key << "args"; - out << YAML::Value << YAML::BeginSeq; - for (size_t i = 1; i < overload_config.parameters.size(); ++i) { - out << YAML::BeginMap; - EmitTypeInfo(overload_config.parameters[i], out); - out << YAML::EndMap; + if (param_type_spec_generated) { + absl::StatusOr signature = + common_internal::MakeOverloadSignature( + function_name, params, overload_config.is_member_function); + if (signature.ok()) { + out << YAML::Key << "signature"; + out << YAML::Value << YAML::DoubleQuoted << *signature; + signature_generated = true; } - out << YAML::EndSeq; } - } else { - if (!overload_config.parameters.empty()) { - out << YAML::Key << "args"; - out << YAML::Value << YAML::BeginSeq; - for (const Config::TypeInfo& parameter : overload_config.parameters) { - out << YAML::BeginMap; - EmitTypeInfo(parameter, out); - out << YAML::EndMap; + } + if (!signature_generated) { + if (overload_config.is_member_function) { + out << YAML::Key << "target" << YAML::Value; + out << YAML::BeginMap; + if (overload_config.parameters.empty()) { + // This should never happen, but if it does, emit a dynamic type. + EmitTypeInfo({.name = "dyn"}, out, options); + } else { + EmitTypeInfo(overload_config.parameters[0], out, options); + } + out << YAML::EndMap; + if (overload_config.parameters.size() > 1) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (size_t i = 1; i < overload_config.parameters.size(); ++i) { + out << YAML::BeginMap; + EmitTypeInfo(overload_config.parameters[i], out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + } else { + if (!overload_config.parameters.empty()) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& parameter : overload_config.parameters) { + out << YAML::BeginMap; + EmitTypeInfo(parameter, out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; } - out << YAML::EndSeq; } } - out << YAML::Key << "return"; - out << YAML::Value << YAML::BeginMap; - EmitTypeInfo(overload_config.return_type, out); - out << YAML::EndMap; - + bool return_type_signature_generated = false; + if (options.use_type_signatures) { + absl::StatusOr type_spec = + TypeInfoToTypeSpec(overload_config.return_type); + if (type_spec.ok()) { + absl::StatusOr signature = + common_internal::MakeTypeSpecSignature(*type_spec); + if (signature.ok()) { + out << YAML::Key << "return"; + out << YAML::Value << YAML::DoubleQuoted << *signature; + return_type_signature_generated = true; + } + } + } + if (!return_type_signature_generated) { + out << YAML::Key << "return"; + out << YAML::Value << YAML::BeginMap; + EmitTypeInfo(overload_config.return_type, out, options); + out << YAML::EndMap; + } out << YAML::EndMap; } -void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out) { +void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { const std::vector& function_configs = env_config.GetFunctionConfigs(); if (function_configs.empty()) { @@ -1085,7 +1230,8 @@ void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out) { out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq; for (const Config::FunctionOverloadConfig& overload_config : sorted_overloads) { - EmitFunctionOverloadConfig(overload_config, out); + EmitFunctionOverloadConfig(function_config.name, overload_config, out, + options); } out << YAML::EndSeq; } @@ -1116,7 +1262,8 @@ absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { return config; } -void EnvConfigToYaml(const Config& env_config, std::ostream& os) { +void EnvConfigToYaml(const Config& env_config, std::ostream& os, + const EnvConfigToYamlOptions& options) { YAML::Emitter out(os); out.SetIndent(2); out << YAML::BeginMap; @@ -1127,8 +1274,8 @@ void EnvConfigToYaml(const Config& env_config, std::ostream& os) { EmitContainerConfig(env_config, out); EmitExtensionConfigs(env_config, out); EmitStandardLibraryConfig(env_config, out); - EmitVariableConfigs(env_config, out); - EmitFunctionConfigs(env_config, out); + EmitVariableConfigs(env_config, out, options); + EmitFunctionConfigs(env_config, out, options); out << YAML::EndMap; } diff --git a/env/env_yaml.h b/env/env_yaml.h index c96b45933..7bf7bf6b4 100644 --- a/env/env_yaml.h +++ b/env/env_yaml.h @@ -31,8 +31,43 @@ namespace cel { // expensive expressions. absl::StatusOr EnvConfigFromYaml(const std::string& yaml); +struct EnvConfigToYamlOptions { + // Whether to use type and overload signatures instead of arg/return types in + // the output YAML. + // Example of type signature: "map>" vs + // type_name: "map" + // params: + // - type_name: "int" + // - type_name: "A" + // params: + // - type_name: "B" + // is_type_param: true + // + // Example of overload signature config: + // name: "foo" + // overloads: + // - signature: "timestamp.foo(A<~B>)" + // return: "int" + // vs + // name: "foo" + // overloads: + // - id: "foo_id" + // target: + // type_name: "timestamp" + // args: + // - type_name: "A" + // params: + // - type_name: "B" + // is_type_param: true + // return: + // type_name: "int" + // TODO(uncreated-issue/91): default to true after all dependencies are updated + bool use_type_signatures = false; +}; + // EnvConfigToYaml serializes an environment configuration as a YAML string. -void EnvConfigToYaml(const Config& env_config, std::ostream& os); +void EnvConfigToYaml(const Config& env_config, std::ostream& os, + const EnvConfigToYamlOptions& options = {}); } // namespace cel diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index d19c0dbfb..f6bde59c9 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -195,6 +195,28 @@ TEST(EnvYamlTest, ParseVariableConfigs) { } TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "dict" + type: "map" + )yaml")); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "dict"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "map"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, SizeIs(2)); + EXPECT_EQ(type_info.params[0].name, "string"); + EXPECT_FALSE(type_info.params[0].is_type_param); + EXPECT_THAT(type_info.params[0].params, IsEmpty()); + EXPECT_EQ(type_info.params[1].name, "A"); + EXPECT_TRUE(type_info.params[1].is_type_param); + EXPECT_THAT(type_info.params[1].params, IsEmpty()); +} + +TEST(EnvYamlTest, ParseVariableConfigWithTypeParamsLegacySyntax) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( variables: - name: "dict" @@ -221,7 +243,7 @@ TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { } struct ParseConstantTestCase { - std::string type_name; + std::string type; std::string value; std::string expected_error; // Empty if no error. Constant expected_constant; @@ -236,10 +258,10 @@ TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { R"yaml( variables: - name: "const" - type_name: "%s" + type: "%s" value: %s )yaml", - param.type_name, param.value); + param.type, param.value); absl::StatusOr status_or_config = EnvConfigFromYaml(yaml); if (!param.expected_error.empty()) { EXPECT_THAT(status_or_config, StatusIs(absl::StatusCode::kInvalidArgument, @@ -251,8 +273,7 @@ TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { const Config::VariableConfig& variable_config = config.GetVariableConfigs()[0]; EXPECT_EQ(variable_config.name, "const"); - EXPECT_EQ(variable_config.type_info.name, param.type_name) - << " yaml: " << yaml; + EXPECT_EQ(variable_config.type_info.name, param.type) << " yaml: " << yaml; EXPECT_EQ(variable_config.value, param.expected_constant) << " yaml: " << yaml; } @@ -260,119 +281,119 @@ TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { std::vector GetParseConstantTestCases() { return { ParseConstantTestCase{ - .type_name = "null", + .type = "null", .value = "\"\"", .expected_constant = Constant(nullptr), }, ParseConstantTestCase{ - .type_name = "null", + .type = "null", .value = "anything", .expected_error = "Failed to parse null constant", }, ParseConstantTestCase{ - .type_name = "bool", + .type = "bool", .value = "TRUE", .expected_constant = Constant(true), }, ParseConstantTestCase{ - .type_name = "bool", + .type = "bool", .value = "false", .expected_constant = Constant(false), }, ParseConstantTestCase{ - .type_name = "bool", + .type = "bool", .value = "yes", .expected_error = "Failed to parse bool constant", }, ParseConstantTestCase{ - .type_name = "int", + .type = "int", .value = "42", .expected_constant = Constant(int64_t{42}), }, ParseConstantTestCase{ - .type_name = "int", + .type = "int", .value = "41.999", .expected_error = "Failed to parse int constant", }, ParseConstantTestCase{ - .type_name = "uint", + .type = "uint", .value = "42", .expected_constant = Constant(uint64_t{42}), }, ParseConstantTestCase{ - .type_name = "uint", + .type = "uint", .value = "42u", .expected_constant = Constant(uint64_t{42}), }, ParseConstantTestCase{ - .type_name = "uint", + .type = "uint", .value = "-1", .expected_error = "Failed to parse uint constant", }, ParseConstantTestCase{ - .type_name = "double", + .type = "double", .value = "42.42", .expected_constant = Constant(42.42), }, ParseConstantTestCase{ - .type_name = "double", + .type = "double", .value = "abc", .expected_error = "Failed to parse double constant", }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "abc", .expected_constant = Constant(BytesConstant("abc")), }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "b\"\\xFF\\x00\\x01\"", .expected_constant = Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "!!binary /wAB", .expected_constant = Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "!!binary YWJj=", .expected_error = "Node 'YWJj=' is not a valid Base64 encoded binary", }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "abc", .expected_constant = Constant(BytesConstant("abc")), }, ParseConstantTestCase{ - .type_name = "string", + .type = "string", .value = "abc", .expected_constant = Constant(StringConstant("abc")), }, ParseConstantTestCase{ - .type_name = "string", + .type = "string", .value = "\"\\\"abc\\\"\"", .expected_constant = Constant(StringConstant("\"abc\"")), }, ParseConstantTestCase{ - .type_name = "duration", + .type = "duration", .value = "1s", .expected_constant = Constant(absl::Seconds(1)), }, ParseConstantTestCase{ - .type_name = "duration", + .type = "duration", .value = "abc", .expected_error = "Failed to parse duration constant", }, ParseConstantTestCase{ - .type_name = "timestamp", + .type = "timestamp", .value = "2023-01-01T00:00:00Z", .expected_constant = Constant(absl::FromUnixSeconds(1672531200)), }, ParseConstantTestCase{ - .type_name = "timestamp", + .type = "timestamp", .value = "abc", .expected_error = "Failed to parse timestamp constant", }, @@ -439,6 +460,50 @@ TEST_P(EnvYamlParseFunctionTest, EnvYamlParseFunction) { std::vector GetParseFunctionTestCases() { return { + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "isEmpty" + description: |- + determines whether a list is empty, + or a string has no characters + overloads: + - signature: "google.protobuf.StringValue.isEmpty()" + examples: + - "''.isEmpty() // true" + return: "bool" + - signature: "list<~T>.isEmpty()" + examples: + - "[].isEmpty() // true" + - "[1].isEmpty() // false" + return: "bool" + )yaml", + .expected_function_config = + { + .name = "isEmpty", + .description = "determines whether a list is empty,\nor a " + "string has no characters", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .examples = {"''.isEmpty() // true"}, + .is_member_function = true, + .parameters = {{.name = "string_wrapper"}}, + .return_type = {.name = "bool"}, + }, + Config::FunctionOverloadConfig{ + .examples = {"[].isEmpty() // true", + "[1].isEmpty() // false"}, + .is_member_function = true, + .parameters = {{.name = "list", + .params = {{.name = "T", + .is_type_param = + true}}}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, ParseFunctionTestCase{ .yaml = R"yaml( functions: @@ -495,6 +560,34 @@ std::vector GetParseFunctionTestCases() { }, }, }, + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "contains" + overloads: + - signature: "contains(list<~T>, ~T)" + examples: + - "contains([1, 2, 3], 2) // true" + return: "bool" + )yaml", + .expected_function_config = + { + .name = "contains", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .examples = {"contains([1, 2, 3], 2) // true"}, + .is_member_function = false, + .parameters = + {{.name = "list", + .params = {{.name = "T", + .is_type_param = true}}}, + {.name = "T", .is_type_param = true}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, ParseFunctionTestCase{ .yaml = R"yaml( functions: @@ -865,6 +958,18 @@ INSTANTIATE_TEST_SUITE_P( "| is_type_param: maybe\n" "| ^", }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "opaque" + type: "opaque" + )yaml", + .expected_error = "4:19: Node 'type' and 'type_name'" + " are mutually exclusive\n" + "| type_name: \"opaque\"\n" + "| ^", + }, ParseTestCase{ .yaml = R"yaml( variables: @@ -965,12 +1070,65 @@ INSTANTIATE_TEST_SUITE_P( - name: "foo" overloads: - id: "foo_int64" - return: "to sender" + return: [1] )yaml", .expected_error = "6:31: Function overload return type" - " is not a map\n" - "| return: \"to sender\"\n" + " is neither a string nor a map\n" + "| return: [1]\n" "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + signature: "bar()" + )yaml", + .expected_error = "6:34: Function overload name \"bar\" " + "does not match function name \"foo\"\n" + "| signature: \"bar()\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: [ "foo()" ] + )yaml", + .expected_error = + "5:34: Function overload signature is not a string\n" + "| - signature: [ \"foo()\" ]\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: "foo()" + target: + type_name: "int" + )yaml", + .expected_error = "6:23: Function overload signature and target " + "are mutually exclusive\n" + "| target:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: "foo()" + args: + - type_name: "int" + )yaml", + .expected_error = "6:23: Function overload signature and args are " + "mutually exclusive\n" + "| args:\n" + "| ^", })); std::string Unindent(std::string_view yaml) { @@ -999,6 +1157,7 @@ std::string Unindent(std::string_view yaml) { struct ExportTestCase { absl::StatusOr config; std::string expected_yaml; + std::string expected_alt_yaml; }; class EnvYamlExportTest : public testing::TestWithParam {}; @@ -1007,10 +1166,18 @@ TEST_P(EnvYamlExportTest, EnvYamlExport) { const ExportTestCase& param = GetParam(); ASSERT_OK_AND_ASSIGN(Config config, param.config); std::stringstream ss; - EnvConfigToYaml(config, ss); + EnvConfigToYaml(config, ss, {.use_type_signatures = true}); std::string yaml_output = Unindent(ss.str()); std::string expected_yaml = Unindent(param.expected_yaml); EXPECT_EQ(yaml_output, expected_yaml); + + if (!param.expected_alt_yaml.empty()) { + std::stringstream alt_ss; + EnvConfigToYaml(config, alt_ss, {.use_type_signatures = false}); + std::string alt_yaml_output = Unindent(alt_ss.str()); + std::string expected_alt_yaml = Unindent(param.expected_alt_yaml); + EXPECT_EQ(alt_yaml_output, expected_alt_yaml); + } } std::vector GetExportTestCases() { @@ -1211,7 +1378,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "null" + type: "null" )yaml", }, ExportTestCase{ @@ -1224,6 +1391,12 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "bool" + value: true + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "bool" @@ -1240,6 +1413,12 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "int" + value: 42 + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "int" @@ -1258,7 +1437,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "uint" + type: "uint" value: 777 )yaml", }, @@ -1274,7 +1453,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "double" + type: "double" value: 0.75 )yaml", }, @@ -1291,7 +1470,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "bytes" + type: "bytes" value: b"\xff\x00\x01" )yaml", }, @@ -1309,7 +1488,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "string" + type: "string" value: "'single' \"double\"" )yaml", }, @@ -1324,6 +1503,12 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "duration" + value: 1h2m3s + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "duration" @@ -1340,6 +1525,12 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "timestamp" @@ -1358,7 +1549,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "google.expr.proto3.test.TestAllTypes" + type: "google.expr.proto3.test.TestAllTypes" )yaml", }, ExportTestCase{ @@ -1373,6 +1564,11 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "A" + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "A" @@ -1402,12 +1598,22 @@ std::vector GetExportTestCases() { {.overload_id = "foo_overload_id", .is_member_function = true, .parameters = {{.name = "timestamp"}, - {.name = "A", .params = {{.name = "B"}}}}, + {.name = "A", + .params = {{.name = "B", + .is_type_param = true}}}}, .return_type = {.name = "int"}}, }})); return config; }(), .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", + .expected_alt_yaml = R"yaml( functions: - name: "foo" overloads: @@ -1418,6 +1624,7 @@ std::vector GetExportTestCases() { - type_name: "A" params: - type_name: "B" + is_type_param: true return: type_name: "int" )yaml", @@ -1427,6 +1634,7 @@ std::vector GetExportTestCases() { Config config; CEL_RETURN_IF_ERROR(config.AddFunctionConfig( {.name = "foo", + .description = "my desc", .overload_configs = { {.overload_id = "foo_overload_a", .parameters = {{.name = "timestamp"}}, @@ -1442,6 +1650,19 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( functions: - name: "foo" + description: "my desc" + overloads: + - id: "foo_overload_b" + signature: "foo(double,A)" + return: "string" + - id: "foo_overload_a" + signature: "foo(timestamp)" + return: "list" + )yaml", + .expected_alt_yaml = R"yaml( + functions: + - name: "foo" + description: "my desc" overloads: - id: "foo_overload_b" args: @@ -1466,9 +1687,10 @@ std::vector GetExportTestCases() { INSTANTIATE_TEST_SUITE_P(EnvYamlExportTest, EnvYamlExportTest, ::testing::ValuesIn(GetExportTestCases())); -class EnvYamlRoundTripTest : public testing::TestWithParam {}; +class EnvYamlStructuredRoundTripTest + : public testing::TestWithParam {}; -TEST_P(EnvYamlRoundTripTest, EnvYamlRoundTrip) { +TEST_P(EnvYamlStructuredRoundTripTest, EnvYamlRoundTrip) { const std::string& yaml = Unindent(GetParam()); ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(yaml)); @@ -1477,7 +1699,7 @@ TEST_P(EnvYamlRoundTripTest, EnvYamlRoundTrip) { EXPECT_EQ(ss.str(), yaml); } -std::vector GetRoundTripTestCases() { +std::vector GetStructuredRoundTripTestCases() { return { R"yaml( stdlib: @@ -1536,74 +1758,83 @@ std::vector GetRoundTripTestCases() { overloads: - id: "string_to_timestamp" )yaml", + R"yaml( + functions: + - name: "bar" + - name: "foo" + )yaml", + }; +} + +INSTANTIATE_TEST_SUITE_P( + EnvYamlStructuredRoundTripTest, EnvYamlStructuredRoundTripTest, + ::testing::ValuesIn(GetStructuredRoundTripTestCases())); + +class EnvYamlSignatureRoundTripTest + : public testing::TestWithParam {}; + +TEST_P(EnvYamlSignatureRoundTripTest, EnvYamlRoundTrip) { + const std::string& yaml = Unindent(GetParam()); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(yaml)); + + std::stringstream ss; + EnvConfigToYaml(config, ss, {.use_type_signatures = true}); + EXPECT_EQ(ss.str(), yaml); +} + +std::vector GetSignatureRoundTripTestCases() { + return { R"yaml( variables: - name: "a" - type_name: "null" + type: "null" - name: "b" - type_name: "bool" + type: "bool" value: true - name: "c" - type_name: "int" + type: "int" value: 42 - name: "d" - type_name: "uint" + type: "uint" value: 777 - name: "e" - type_name: "double" + type: "double" value: 0.75 - name: "f" - type_name: "bytes" + type: "bytes" value: b"\xff\x00\x01" - name: "g" - type_name: "string" + type: "string" value: "plain 'single' \"double\"" - name: "h" - type_name: "duration" + type: "duration" value: 1h2m3s - name: "i" - type_name: "timestamp" + type: "timestamp" value: 2026-01-02T03:04:05Z )yaml", - R"yaml( - functions: - - name: "bar" - - name: "foo" - )yaml", R"yaml( functions: - name: "foo" overloads: - id: "foo_overload_id" - target: - type_name: "timestamp" - args: - - type_name: "A" - params: - - type_name: "B" - return: - type_name: "int" + signature: "timestamp.foo(A<~B>)" + return: "int" )yaml", R"yaml( functions: - name: "foo" overloads: - id: "foo_overload_id" - args: - - type_name: "timestamp" - - type_name: "A" - params: - - type_name: "B" - return: - type_name: "list" - params: - - type_name: "int" + signature: "foo(timestamp,A<~B>)" + return: "list" )yaml", }; } -INSTANTIATE_TEST_SUITE_P(EnvYamlRoundTripTest, EnvYamlRoundTripTest, - ::testing::ValuesIn(GetRoundTripTestCases())); +INSTANTIATE_TEST_SUITE_P(EnvYamlSignatureRoundTripTest, + EnvYamlSignatureRoundTripTest, + ::testing::ValuesIn(GetSignatureRoundTripTestCases())); } // namespace } // namespace cel diff --git a/env/type_info.cc b/env/type_info.cc index a5b47b6f1..f49fab9f4 100644 --- a/env/type_info.cc +++ b/env/type_info.cc @@ -14,13 +14,17 @@ #include "env/type_info.h" +#include #include +#include #include #include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" #include "env/config.h" @@ -180,5 +184,227 @@ absl::StatusOr TypeInfoToType( return DynType(); } } +absl::StatusOr TypeInfoToTypeSpec(const Config::TypeInfo& type_info) { + if (type_info.is_type_param) { + return TypeSpec(ParamTypeSpec(type_info.name)); + } + + std::optional type_kind = TypeNameToTypeKind(type_info.name); + if (!type_kind.has_value()) { + if (type_info.params.empty()) { + return TypeSpec(MessageTypeSpec(type_info.name)); + } else { + std::vector param_specs; + param_specs.reserve(type_info.params.size()); + for (const Config::TypeInfo& param : type_info.params) { + CEL_ASSIGN_OR_RETURN(TypeSpec param_spec, TypeInfoToTypeSpec(param)); + param_specs.push_back(std::move(param_spec)); + } + return TypeSpec(AbstractType(type_info.name, std::move(param_specs))); + } + } + + switch (*type_kind) { + case TypeKind::kNull: + return TypeSpec(NullTypeSpec()); + case TypeKind::kBool: + return TypeSpec(PrimitiveType::kBool); + case TypeKind::kInt: + return TypeSpec(PrimitiveType::kInt64); + case TypeKind::kUint: + return TypeSpec(PrimitiveType::kUint64); + case TypeKind::kDouble: + return TypeSpec(PrimitiveType::kDouble); + case TypeKind::kString: + return TypeSpec(PrimitiveType::kString); + case TypeKind::kBytes: + return TypeSpec(PrimitiveType::kBytes); + case TypeKind::kTimestamp: + return TypeSpec(WellKnownTypeSpec::kTimestamp); + case TypeKind::kDuration: + return TypeSpec(WellKnownTypeSpec::kDuration); + case TypeKind::kList: { + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN(TypeSpec elem_type, + TypeInfoToTypeSpec(type_info.params[0])); + return TypeSpec( + ListTypeSpec(std::make_unique(std::move(elem_type)))); + } else { + return TypeSpec(ListTypeSpec()); + } + } + case TypeKind::kMap: { + if (type_info.params.empty()) { + return TypeSpec(MapTypeSpec()); + } + CEL_ASSIGN_OR_RETURN(TypeSpec key_type, + TypeInfoToTypeSpec(type_info.params[0])); + if (type_info.params.size() > 1) { + CEL_ASSIGN_OR_RETURN(TypeSpec value_type, + TypeInfoToTypeSpec(type_info.params[1])); + return TypeSpec( + MapTypeSpec(std::make_unique(std::move(key_type)), + std::make_unique(std::move(value_type)))); + } + return TypeSpec(MapTypeSpec( + std::make_unique(std::move(key_type)), nullptr)); + } + case TypeKind::kDyn: + return TypeSpec(DynTypeSpec()); + case TypeKind::kAny: + return TypeSpec(WellKnownTypeSpec::kAny); + case TypeKind::kBoolWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + case TypeKind::kIntWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + case TypeKind::kUintWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + case TypeKind::kDoubleWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + case TypeKind::kStringWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + case TypeKind::kBytesWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + case TypeKind::kType: { + if (type_info.params.empty()) { + return TypeSpec(std::make_unique(DynTypeSpec())); + } + CEL_ASSIGN_OR_RETURN(TypeSpec type_param, + TypeInfoToTypeSpec(type_info.params[0])); + return TypeSpec(std::make_unique(std::move(type_param))); + } + default: + return TypeSpec(DynTypeSpec()); + } +} + +absl::StatusOr TypeSpecToTypeInfo(const TypeSpec& type_spec) { + Config::TypeInfo type_info; + + if (type_spec.has_dyn()) { + type_info.name = "dyn"; + } else if (type_spec.has_null()) { + type_info.name = "null"; + } else if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + type_info.name = "bool"; + break; + case PrimitiveType::kInt64: + type_info.name = "int"; + break; + case PrimitiveType::kUint64: + type_info.name = "uint"; + break; + case PrimitiveType::kDouble: + type_info.name = "double"; + break; + case PrimitiveType::kString: + type_info.name = "string"; + break; + case PrimitiveType::kBytes: + type_info.name = "bytes"; + break; + default: + return absl::InvalidArgumentError("Unspecified primitive type"); + } + } else if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + type_info.name = "bool_wrapper"; + break; + case PrimitiveType::kInt64: + type_info.name = "int_wrapper"; + break; + case PrimitiveType::kUint64: + type_info.name = "uint_wrapper"; + break; + case PrimitiveType::kDouble: + type_info.name = "double_wrapper"; + break; + case PrimitiveType::kString: + type_info.name = "string_wrapper"; + break; + case PrimitiveType::kBytes: + type_info.name = "bytes_wrapper"; + break; + default: + return absl::InvalidArgumentError("Unspecified wrapper type"); + } + } else if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + type_info.name = "any"; + break; + case WellKnownTypeSpec::kTimestamp: + type_info.name = "timestamp"; + break; + case WellKnownTypeSpec::kDuration: + type_info.name = "duration"; + break; + default: + return absl::InvalidArgumentError("Unspecified well known type"); + } + } else if (type_spec.has_list_type()) { + type_info.name = "list"; + const ListTypeSpec& list_type = type_spec.list_type(); + if (list_type.has_elem_type() && list_type.elem_type().is_specified()) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(list_type.elem_type())); + type_info.params.push_back(std::move(param)); + } + } else if (type_spec.has_map_type()) { + type_info.name = "map"; + const MapTypeSpec& map_type = type_spec.map_type(); + bool has_key = + map_type.has_key_type() && map_type.key_type().is_specified(); + bool has_value = + map_type.has_value_type() && map_type.value_type().is_specified(); + if (has_key || has_value) { + if (has_key) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(map_type.key_type())); + type_info.params.push_back(std::move(param)); + } else { + type_info.params.push_back(Config::TypeInfo{.name = "dyn"}); + } + if (has_value) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param_value, + TypeSpecToTypeInfo(map_type.value_type())); + type_info.params.push_back(std::move(param_value)); + } else { + type_info.params.push_back(Config::TypeInfo{.name = "dyn"}); + } + } + } else if (type_spec.has_message_type()) { + type_info.name = type_spec.message_type().type(); + } else if (type_spec.has_type_param()) { + type_info.name = type_spec.type_param().type(); + type_info.is_type_param = true; + } else if (type_spec.has_type()) { + type_info.name = "type"; + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(type_spec.type())); + type_info.params.push_back(std::move(param)); + } else if (type_spec.has_abstract_type()) { + type_info.name = type_spec.abstract_type().name(); + for (const TypeSpec& param_spec : + type_spec.abstract_type().parameter_types()) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(param_spec)); + type_info.params.push_back(std::move(param)); + } + } else if (type_spec.has_error()) { + return absl::InvalidArgumentError( + "ErrorType cannot be converted to TypeInfo"); + } else if (type_spec.has_function()) { + return absl::InvalidArgumentError( + "FunctionType cannot be converted to TypeInfo"); + } else { + return absl::InvalidArgumentError("Unknown TypeSpec kind"); + } + + return type_info; +} } // namespace cel diff --git a/env/type_info.h b/env/type_info.h index bb3cfde43..3f802ce1a 100644 --- a/env/type_info.h +++ b/env/type_info.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ #include "absl/status/statusor.h" +#include "common/ast.h" #include "common/type.h" #include "env/config.h" #include "google/protobuf/arena.h" @@ -30,6 +31,12 @@ absl::StatusOr TypeInfoToType( const Config::TypeInfo& type_info, const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::Arena* arena); +// Converts a Config::TypeInfo to a cel::TypeSpec. +absl::StatusOr TypeInfoToTypeSpec(const Config::TypeInfo& type_info); + +// Converts a cel::TypeSpec to a Config::TypeInfo. +absl::StatusOr TypeSpecToTypeInfo(const TypeSpec& type_spec); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ diff --git a/env/type_info_test.cc b/env/type_info_test.cc index 015d8a928..f9d46f9a9 100644 --- a/env/type_info_test.cc +++ b/env/type_info_test.cc @@ -14,9 +14,14 @@ #include "env/type_info.h" +#include +#include +#include #include #include +#include "absl/status/status.h" +#include "common/ast/metadata.h" #include "common/type.h" #include "common/type_proto.h" #include "env/config.h" @@ -28,9 +33,27 @@ #include "google/protobuf/text_format.h" namespace cel { + +std::ostream& operator<<(std::ostream& os, const Config::TypeInfo& type_info) { + if (type_info.is_type_param) { + os << "?"; + } + os << type_info.name; + if (!type_info.params.empty()) { + os << "<"; + for (size_t i = 0; i < type_info.params.size(); ++i) { + if (i > 0) os << ", "; + os << type_info.params[i]; + } + os << ">"; + } + return os; +} + namespace { using absl_testing::IsOk; +using absl_testing::StatusIs; using testing::ValuesIn; struct TestCase { @@ -127,5 +150,151 @@ std::vector GetTestCases() { INSTANTIATE_TEST_SUITE_P(TypeInfoTest, TypeInfoTest, ValuesIn(GetTestCases())); +bool TypeInfoEqImpl(const Config::TypeInfo& actual, + const Config::TypeInfo& expected) { + if (actual.name != expected.name) return false; + if (actual.is_type_param != expected.is_type_param) return false; + if (actual.params.size() != expected.params.size()) return false; + for (size_t i = 0; i < actual.params.size(); ++i) { + if (!TypeInfoEqImpl(actual.params[i], expected.params[i])) return false; + } + return true; +} + +MATCHER_P(TypeInfoEq, expected, "") { return TypeInfoEqImpl(arg, expected); } + +struct TypeSpecTestCase { + TypeSpec type_spec; + Config::TypeInfo expected_type_info; +}; + +using TypeSpecToTypeInfoTest = testing::TestWithParam; + +TEST_P(TypeSpecToTypeInfoTest, Convert) { + const TypeSpecTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config::TypeInfo actual_type_info, + TypeSpecToTypeInfo(param.type_spec)); + EXPECT_THAT(actual_type_info, TypeInfoEq(param.expected_type_info)); +} + +std::vector GetTypeSpecTestCases() { + return { + TypeSpecTestCase{ + .type_spec = TypeSpec(PrimitiveType::kInt64), + .expected_type_info = {.name = "int"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + ListTypeSpec(std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "list", + .params = {Config::TypeInfo{.name = "int"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(ListTypeSpec()), + .expected_type_info = {.name = "list"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "int"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(MapTypeSpec()), + .expected_type_info = {.name = "map"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MessageTypeSpec("cel.expr.conformance.proto2.TestAllTypes")), + .expected_type_info = + {.name = "cel.expr.conformance.proto2.TestAllTypes"}, + }, + TypeSpecTestCase{ + .type_spec = + TypeSpec(AbstractType("A", {TypeSpec(ParamTypeSpec("B"))})), + .expected_type_info = {.name = "A", + .params = {Config::TypeInfo{ + .name = "B", .is_type_param = true}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(WellKnownTypeSpec::kAny), + .expected_type_info = {.name = "any"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(WellKnownTypeSpec::kTimestamp), + .expected_type_info = {.name = "timestamp"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + .expected_type_info = {.name = "double_wrapper"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + std::make_unique(WellKnownTypeSpec::kDuration)), + .expected_type_info = {.name = "type", + .params = {Config::TypeInfo{.name = + "duration"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(std::make_unique(DynTypeSpec())), + .expected_type_info = {.name = "type", + .params = {Config::TypeInfo{.name = "dyn"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(DynTypeSpec{}), + .expected_type_info = {.name = "dyn"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(NullTypeSpec{}), + .expected_type_info = {.name = "null"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec()))), + .expected_type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "dyn"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MapTypeSpec(std::make_unique(DynTypeSpec()), + std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "dyn"}, + Config::TypeInfo{.name = "int"}}}, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(TypeSpecToTypeInfoTest, TypeSpecToTypeInfoTest, + ValuesIn(GetTypeSpecTestCases())); + +using TypeInfoToTypeSpecTest = testing::TestWithParam; + +TEST_P(TypeInfoToTypeSpecTest, Convert) { + const TypeSpecTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(TypeSpec actual_type_spec, + TypeInfoToTypeSpec(param.expected_type_info)); + EXPECT_EQ(actual_type_spec, param.type_spec); +} + +INSTANTIATE_TEST_SUITE_P(TypeInfoToTypeSpecTest, TypeInfoToTypeSpecTest, + ValuesIn(GetTypeSpecTestCases())); + +TEST(TypeSpecToTypeInfoTest, ErrorConversions) { + EXPECT_THAT(TypeSpecToTypeInfo(TypeSpec(ErrorTypeSpec::kValue)), + StatusIs(absl::StatusCode::kInvalidArgument, + "ErrorType cannot be converted to TypeInfo")); + EXPECT_THAT(TypeSpecToTypeInfo(TypeSpec(FunctionTypeSpec())), + StatusIs(absl::StatusCode::kInvalidArgument, + "FunctionType cannot be converted to TypeInfo")); + EXPECT_THAT( + TypeSpecToTypeInfo(TypeSpec(UnsetTypeSpec())), + StatusIs(absl::StatusCode::kInvalidArgument, "Unknown TypeSpec kind")); +} + } // namespace } // namespace cel From a8b3224f86712c8c564b9a54abf4eac7cc5ee39a Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 9 Jun 2026 11:04:51 -0700 Subject: [PATCH 107/143] Add the function AddContextDeclarationWithProtoTypeMask to the type checker. PiperOrigin-RevId: 929295083 --- checker/internal/BUILD | 6 + checker/internal/type_check_env.cc | 5 + checker/internal/type_check_env.h | 14 + checker/internal/type_checker_builder_impl.cc | 56 +++- checker/internal/type_checker_builder_impl.h | 9 + .../type_checker_builder_impl_test.cc | 311 ++++++++++++++++++ checker/type_checker_builder.h | 22 ++ checker/type_checker_builder_factory_test.cc | 47 +++ 8 files changed, 464 insertions(+), 6 deletions(-) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 26c7b543f..777457830 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -66,6 +66,8 @@ cc_library( hdrs = ["type_check_env.h"], deps = [ ":descriptor_pool_type_introspector", + ":proto_type_mask", + ":proto_type_mask_registry", "//common:constant", "//common:container", "//common:decl", @@ -76,6 +78,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", @@ -129,6 +132,7 @@ cc_library( deps = [ ":format_type_name", ":namespace_generator", + ":proto_type_mask", ":type_check_env", ":type_inference_context", "//checker:checker_options", @@ -154,6 +158,7 @@ cc_library( "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:absl_check", @@ -226,6 +231,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index 763d9ba46..47487220c 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -16,6 +16,7 @@ #include #include +#include #include "absl/base/nullability.h" #include "absl/status/statusor.h" @@ -96,6 +97,10 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( absl::StatusOr> TypeCheckEnv::LookupStructField( absl::string_view type_name, absl::string_view field_name) const { + if (proto_type_mask_registry_ != nullptr && + !proto_type_mask_registry_->FieldIsVisible(type_name, field_name)) { + return absl::nullopt; + } // Check the type providers in registration order. // Note: this doesn't allow for shadowing a type with a subset type of the // same name -- the later type provider will still be considered when diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 15f8ecc4d..00fea0ba3 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -25,16 +25,20 @@ #include "absl/container/flat_hash_map.h" #include "absl/log/absl_check.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "checker/internal/descriptor_pool_type_introspector.h" +#include "checker/internal/proto_type_mask.h" +#include "checker/internal/proto_type_mask_registry.h" #include "common/constant.h" #include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" +#include "internal/status_macros.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" @@ -154,6 +158,14 @@ class TypeCheckEnv { variables_[decl.name()] = std::move(decl); } + absl::Status CreateProtoTypeMaskRegistry( + const std::vector& proto_type_masks) { + CEL_ASSIGN_OR_RETURN(proto_type_mask_registry_, + ProtoTypeMaskRegistry::Create(descriptor_pool_.get(), + proto_type_masks)); + return absl::OkStatus(); + } + const absl::flat_hash_map& functions() const { return functions_; } @@ -224,6 +236,8 @@ class TypeCheckEnv { absl::flat_hash_map variables_; absl::flat_hash_map functions_; + std::shared_ptr proto_type_mask_registry_; + // Type providers for custom types. std::vector> type_providers_; diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 85b581e83..9b91fc926 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -23,13 +24,16 @@ #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/cleanup/cleanup.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "checker/internal/proto_type_mask.h" #include "checker/internal/type_check_env.h" #include "checker/internal/type_checker_impl.h" #include "checker/type_checker.h" @@ -86,10 +90,19 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) { } absl::Status AddWellKnownContextDeclarationVariables( - const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env, - bool use_json_name) { + const google::protobuf::Descriptor* absl_nonnull descriptor, + const absl::flat_hash_map>& + context_type_fields, + TypeCheckEnv& env, bool use_json_name) { for (int i = 0; i < descriptor->field_count(); ++i) { const google::protobuf::FieldDescriptor* field = descriptor->field(i); + // Skip fields that are hidden because of a proto type mask. + auto map_iterator = context_type_fields.find(descriptor->full_name()); + if (map_iterator != context_type_fields.end() && + !map_iterator->second.contains(field->name())) { + continue; + } Type type = MessageTypeField(field).GetType(); if (type.IsEnum()) { type = IntType(); @@ -109,11 +122,15 @@ absl::Status AddWellKnownContextDeclarationVariables( } absl::Status AddContextDeclarationVariables( - const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env) { + const google::protobuf::Descriptor* absl_nonnull descriptor, + const absl::flat_hash_map>& + context_type_fields, + TypeCheckEnv& env) { const bool use_json_name = env.proto_type_introspector().use_json_name(); if (IsWellKnownMessageType(descriptor)) { - return AddWellKnownContextDeclarationVariables(descriptor, env, - use_json_name); + return AddWellKnownContextDeclarationVariables( + descriptor, context_type_fields, env, use_json_name); } CEL_ASSIGN_OR_RETURN(auto fields, env.proto_type_introspector().ListFieldsForStructType( @@ -131,6 +148,13 @@ absl::Status AddContextDeclarationVariables( absl::string_view name = field_entry.name; + // Skip fields that are hidden because of a proto type mask. + auto map_iterator = context_type_fields.find(descriptor->full_name()); + if (map_iterator != context_type_fields.end() && + !map_iterator->second.contains(name)) { + continue; + } + if (!env.InsertVariableIfAbsent(MakeVariableDecl(name, type))) { return absl::AlreadyExistsError( absl::StrCat("variable '", name, @@ -317,7 +341,8 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( } for (const google::protobuf::Descriptor* context_type : config.context_types) { - CEL_RETURN_IF_ERROR(AddContextDeclarationVariables(context_type, env)); + CEL_RETURN_IF_ERROR(AddContextDeclarationVariables( + context_type, config.context_type_fields, env)); } for (VariableDeclRecord& var : config.variables) { @@ -339,6 +364,8 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( } } + CEL_RETURN_IF_ERROR(env.CreateProtoTypeMaskRegistry(config.proto_type_masks)); + return absl::OkStatus(); } @@ -462,6 +489,23 @@ absl::Status TypeCheckerBuilderImpl::AddContextDeclaration( return absl::OkStatus(); } +absl::Status TypeCheckerBuilderImpl::AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) { + if (field_paths.empty()) { + return absl::InvalidArgumentError("field paths cannot be the empty set"); + } + + ProtoTypeMask proto_type_mask(std::string(type), field_paths); + target_config_->proto_type_masks.push_back(proto_type_mask); + + CEL_RETURN_IF_ERROR(AddContextDeclaration(type)); + CEL_ASSIGN_OR_RETURN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(template_env_.descriptor_pool())); + target_config_->context_type_fields.insert({type, std::move(field_names)}); + return absl::OkStatus(); +} + absl::Status TypeCheckerBuilderImpl::AddFunction(const FunctionDecl& decl) { CEL_RETURN_IF_ERROR( ValidateFunctionDecl(decl, options_.enable_type_parameter_name_validation, diff --git a/checker/internal/type_checker_builder_impl.h b/checker/internal/type_checker_builder_impl.h index 646a5d16f..9895a8aee 100644 --- a/checker/internal/type_checker_builder_impl.h +++ b/checker/internal/type_checker_builder_impl.h @@ -21,6 +21,7 @@ #include #include "absl/base/nullability.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" @@ -28,6 +29,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "checker/checker_options.h" +#include "checker/internal/proto_type_mask.h" #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" @@ -76,6 +78,8 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { absl::Status AddVariable(const VariableDecl& decl) override; absl::Status AddOrReplaceVariable(const VariableDecl& decl) override; absl::Status AddContextDeclaration(absl::string_view type) override; + absl::Status AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) override; absl::Status AddFunction(const FunctionDecl& decl) override; absl::Status MergeFunction(const FunctionDecl& decl) override; @@ -130,6 +134,11 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { std::vector functions; std::vector> type_providers; std::vector context_types; + // Maps context type names to fields names to add as variables. + // Only includes context types that are defined with proto type masks. + absl::flat_hash_map> + context_type_fields; + std::vector proto_type_masks; }; absl::Status BuildLibraryConfig(const CheckerLibrary& library, diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc index 494e7e440..913e704ee 100644 --- a/checker/internal/type_checker_builder_impl_test.cc +++ b/checker/internal/type_checker_builder_impl_test.cc @@ -15,12 +15,15 @@ #include "checker/internal/type_checker_builder_impl.h" #include +#include #include #include +#include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "checker/checker_options.h" @@ -107,6 +110,168 @@ INSTANTIATE_TEST_SUITE_P( MapTypeSpec(std::make_unique(PrimitiveType::kString), std::make_unique(DynTypeSpec())))})); +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnEmptyFieldPaths) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {}), + StatusIs(absl::StatusCode::kInvalidArgument, + "field paths cannot be the empty set")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnUnknownFieldPath) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"unknown_field"}), + StatusIs(absl::StatusCode::kInvalidArgument, + "could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'")); +} + +class ContextDeclsWithProtoTypeMaskFieldsDefinedTest + : public testing::TestWithParam {}; + +std::string LogFieldName(absl::string_view field_name, absl::string_view expr) { + return absl::StrCat("field_name: ", field_name, ", expr: ", expr); +} + +TEST_P(ContextDeclsWithProtoTypeMaskFieldsDefinedTest, + ContextDeclsWithProtoTypeMaskFieldsDefined) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {GetParam().expr}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + std::vector field_names = { + "single_int64", "single_uint32", "single_double", + "single_string", "single_any", "single_duration", + "single_bool_wrapper", "list_value", "standalone_message", + "standalone_enum", "repeated_bytes", "repeated_nested_message", + "map_int32_timestamp", "single_struct"}; + for (auto& field_name : field_names) { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(field_name)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + if (field_name == GetParam().expr) { + // The field name that is part of the proto type mask is visible. + ASSERT_TRUE(result.IsValid()) + << LogFieldName(field_name, GetParam().expr); + EXPECT_EQ(result.GetAst()->GetReturnType(), GetParam().expected_type) + << LogFieldName(field_name, GetParam().expr); + } else { + // The field names that are not part of the proto type mask are not + // visible. + EXPECT_FALSE(result.IsValid()) + << LogFieldName(field_name, GetParam().expr); + } + } +} + +INSTANTIATE_TEST_SUITE_P( + TestAllTypes, ContextDeclsWithProtoTypeMaskFieldsDefinedTest, + testing::Values( + ContextDeclsTestCase{"single_int64", TypeSpec(PrimitiveType::kInt64)}, + ContextDeclsTestCase{"single_uint32", TypeSpec(PrimitiveType::kUint64)}, + ContextDeclsTestCase{"single_double", TypeSpec(PrimitiveType::kDouble)}, + ContextDeclsTestCase{"single_string", TypeSpec(PrimitiveType::kString)}, + ContextDeclsTestCase{"single_any", TypeSpec(WellKnownTypeSpec::kAny)}, + ContextDeclsTestCase{"single_duration", + TypeSpec(WellKnownTypeSpec::kDuration)}, + ContextDeclsTestCase{ + "single_bool_wrapper", + TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool))}, + ContextDeclsTestCase{ + "list_value", + TypeSpec(ListTypeSpec(std::make_unique(DynTypeSpec())))}, + ContextDeclsTestCase{ + "standalone_message", + TypeSpec(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))}, + ContextDeclsTestCase{"standalone_enum", + TypeSpec(PrimitiveType::kInt64)}, + ContextDeclsTestCase{"repeated_bytes", + TypeSpec(ListTypeSpec(std::make_unique( + PrimitiveType::kBytes)))}, + ContextDeclsTestCase{ + "repeated_nested_message", + TypeSpec(ListTypeSpec(std::make_unique(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))))}, + ContextDeclsTestCase{ + "map_int32_timestamp", + TypeSpec(MapTypeSpec( + std::make_unique(PrimitiveType::kInt64), + std::make_unique(WellKnownTypeSpec::kTimestamp)))}, + ContextDeclsTestCase{ + "single_struct", + TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec())))})); + +TEST(ContextDeclsWithProtoTypeMaskTest, FieldsInMaskAreVisibleFieldAccess) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + // Visible field: standalone_message.bb + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("payload.standalone_message.bb")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + EXPECT_EQ(result.GetAst()->GetReturnType(), TypeSpec(PrimitiveType::kInt64)); + // Visible field: single_int32 + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int32")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + EXPECT_EQ(result.GetAst()->GetReturnType(), TypeSpec(PrimitiveType::kInt64)); + // Not Visible field: single_int64 + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int64")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, FieldsInMaskAreVisibleFieldAssignment) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + // Visible field: standalone_message.bb + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAst( + R"(cel.expr.conformance.proto3.TestAllTypes.NestedMessage{bb: 12345})")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + // Visible field: single_int32 + ASSERT_OK_AND_ASSIGN( + ast, + MakeTestParsedAst( + R"(cel.expr.conformance.proto3.TestAllTypes{single_int32: 12345})")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + // Not Visible field: single_int64 + ASSERT_OK_AND_ASSIGN( + ast, + MakeTestParsedAst( + R"(cel.expr.conformance.proto3.TestAllTypes{single_int64: 12345})")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + TEST(ContextDeclsTest, ErrorOnDuplicateContextDeclaration) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -120,6 +285,20 @@ TEST(ContextDeclsTest, ErrorOnDuplicateContextDeclaration) { "already exists")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnDuplicateContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message"}), + IsOk()); + EXPECT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + StatusIs(absl::StatusCode::kAlreadyExists, + "context declaration 'cel.expr.conformance.proto3.TestAllTypes' " + "already exists")); +} + TEST(ContextDeclsTest, ErrorOnContextDeclarationNotFound) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -129,6 +308,16 @@ TEST(ContextDeclsTest, ErrorOnContextDeclarationNotFound) { "context declaration 'com.example.UnknownType' not found")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnContextDeclarationNotFound) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclarationWithProtoTypeMask("com.example.UnknownType", + {"any_field_name"}), + StatusIs(absl::StatusCode::kNotFound, + "context declaration 'com.example.UnknownType' not found")); +} + TEST(ContextDeclsTest, ErrorOnNonStructMessageType) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -139,6 +328,17 @@ TEST(ContextDeclsTest, ErrorOnNonStructMessageType) { "context declaration 'google.protobuf.Timestamp' is not a struct")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnNonStructMessageType) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "google.protobuf.Timestamp", {"any_field_name"}), + StatusIs( + absl::StatusCode::kInvalidArgument, + "context declaration 'google.protobuf.Timestamp' is not a struct")); +} + TEST(ContextDeclsTest, CustomStructNotSupported) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -160,6 +360,28 @@ TEST(ContextDeclsTest, CustomStructNotSupported) { "context declaration 'com.example.MyStruct' not found")); } +TEST(ContextDeclsWithProtoTypeMaskTest, CustomStructNotSupported) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + class MyTypeProvider : public cel::TypeIntrospector { + public: + absl::StatusOr> FindTypeImpl( + absl::string_view name) const override { + if (name == "com.example.MyStruct") { + return common_internal::MakeBasicStructType("com.example.MyStruct"); + } + return absl::nullopt; + } + }; + + builder.AddTypeProvider(std::make_unique()); + + EXPECT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "com.example.MyStruct", {"any_field_name"}), + StatusIs(absl::StatusCode::kNotFound, + "context declaration 'com.example.MyStruct' not found")); +} + TEST(ContextDeclsTest, ErrorOnOverlappingContextDeclaration) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -179,6 +401,69 @@ TEST(ContextDeclsTest, ErrorOnOverlappingContextDeclaration) { "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnOverlappingContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.TestAllTypes", {"single_int32"}), + IsOk()); + + EXPECT_THAT( + builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int32' declared multiple times (from context " + "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, + ErrorOnOverlappingContextDeclarationBothProtoTypeMasks) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.TestAllTypes", {"single_int32"}), + IsOk()); + + EXPECT_THAT( + builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int32' declared multiple times (from context " + "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, + NonOverlappingContextDeclarationBothProtoTypeMasks) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.NestedTestAllTypes", + {"payload.single_int64"}), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("single_int32")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int64")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); +} + TEST(ContextDeclsTest, ErrorOnOverlappingVariableDeclaration) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -193,6 +478,32 @@ TEST(ContextDeclsTest, ErrorOnOverlappingVariableDeclaration) { "variable 'single_int64' declared multiple times")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnOverlappingVariableDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int64"}), + IsOk()); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())), + IsOk()); + + EXPECT_THAT(builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int64' declared multiple times")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, NonOverlappingVariableDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())), + IsOk()); + + EXPECT_THAT(builder.Build(), IsOk()); +} + TEST(TypeCheckerBuilderImplTest, InvalidTypeParamNameVariableValidationDisabled) { CheckerOptions options; diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index 5dd1f5256..f145b8a98 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -17,6 +17,7 @@ #include #include +#include #include "absl/base/nullability.h" #include "absl/functional/any_invocable.h" @@ -102,6 +103,27 @@ class TypeCheckerBuilder { // Note: only protobuf backed struct types are supported at this time. virtual absl::Status AddContextDeclaration(absl::string_view type) = 0; + // Declares struct type by fully qualified name as a context declaration. + // + // This version accepts a mask in terms of field selections from the + // context type. The mask specifies which fields are visible on the + // struct and its members. The visible fields for a type accumulate + // across calls. This is a lightweight way to adjust the type checking + // behavior for a group of related types. + // + // Context declarations are a way to declare a group of variables based on the + // definition of a struct type. Each top level field of the struct that is + // also the first field name in a field path is declared as an individual + // variable of the field type. + // + // It is an error if the type contains a field that overlaps with another + // declared variable. It is an error if the input field paths is the empty + // set. + // + // Note: only protobuf backed struct types are supported at this time. + virtual absl::Status AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) = 0; + // Adds a function declaration that may be referenced in expressions checked // with the resulting TypeChecker. virtual absl::Status AddFunction(const FunctionDecl& decl) = 0; diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index 38430de5f..9c4775e7f 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -396,6 +396,27 @@ TEST(TypeCheckerBuilderTest, AddContextDeclaration) { EXPECT_TRUE(result.IsValid()); } +TEST(TypeCheckerBuilderTest, AddContextDeclarationWithProtoTypeMask) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl("increment", MakeOverloadDecl("increment_int", IntType(), + IntType()))); + + ASSERT_THAT(builder->AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int64"}), + IsOk()); + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("increment(single_int64)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + TEST(TypeCheckerBuilderTest, WellKnownTypeContextDeclarationError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, @@ -428,6 +449,32 @@ TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclaration) { ASSERT_TRUE(result.IsValid()); } +TEST(TypeCheckerBuilderTest, + AllowWellKnownTypeContextDeclarationWithProtoTypeMask) { + CheckerOptions options; + options.allow_well_known_type_context_declarations = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + + ASSERT_THAT(builder->AddContextDeclarationWithProtoTypeMask( + "google.protobuf.Any", {"value"}), + IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + // Visible field: value + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("value")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + // Not visible field: type_url + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("type_url")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationStruct) { CheckerOptions options; options.allow_well_known_type_context_declarations = true; From a721a9a8d287edd074cd7c4fd8157aef9fa3074a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 9 Jun 2026 13:28:55 -0700 Subject: [PATCH 108/143] Add support for old style variable format in env yaml parser. PiperOrigin-RevId: 929372359 --- env/env_yaml.cc | 10 ++++++++-- env/env_yaml_test.cc | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 8c635e65f..1bbfe6b36 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -620,8 +620,14 @@ absl::Status ParseVariableConfigs(Config& config, absl::string_view yaml, } variable_config.description = GetString(yaml, description); } - - CEL_ASSIGN_OR_RETURN(auto type_info, ParseTypeInfo(variable, yaml)); + const YAML::Node type = variable["type"]; + Config::TypeInfo type_info; + if (type.IsDefined() && !type.IsScalar()) { + // Old format, type spec is in 'type' instead of directly embedded. + CEL_ASSIGN_OR_RETURN(type_info, ParseTypeInfo(variable["type"], yaml)); + } else { + CEL_ASSIGN_OR_RETURN(type_info, ParseTypeInfo(variable, yaml)); + } ConstantKindCase constant_kind_case = GetConstantKindCase(type_info.name); std::string value_str; YAML::Node value = variable["value"]; diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index f6bde59c9..a60048617 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -242,6 +242,24 @@ TEST(EnvYamlTest, ParseVariableConfigWithTypeParamsLegacySyntax) { EXPECT_THAT(type_info.params[1].params, IsEmpty()); } +TEST(EnvYamlTest, ParseVariableConfigWithNestedRuleOldFormat) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "x" + type: + type_name: "int" + )yaml")); + + ASSERT_THAT(config.GetVariableConfigs(), SizeIs(1)); + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "x"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "int"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, IsEmpty()); +} + struct ParseConstantTestCase { std::string type; std::string value; From d301931cafb674699daa54d922691ccddc6c55ba Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 9 Jun 2026 16:19:27 -0700 Subject: [PATCH 109/143] internal PiperOrigin-RevId: 929463453 --- env/env.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/env/env.cc b/env/env.cc index 42652ce59..6cd3a3cdc 100644 --- a/env/env.cc +++ b/env/env.cc @@ -122,6 +122,7 @@ absl::StatusOr FunctionConfigToFunctionDecl( Env::Env() { compiler_options_.parser_options.enable_quoted_identifiers = true; + compiler_options_.adapt_parser_errors = true; } absl::StatusOr> Env::NewCompilerBuilder() { From 167e797bdd017b074c1f4fe4a80598d2fa1e1613 Mon Sep 17 00:00:00 2001 From: Antoine Pietri Date: Thu, 11 Jun 2026 00:28:02 -0700 Subject: [PATCH 110/143] Fix source range offsets for lists, maps and messages. Update the parser to derive the source range for list, map, and struct/message creation expressions from the full parser rule context rather than just the opening token. This ensures the recorded offsets in `EnrichedSourceInfo` span the entire composite expression. PiperOrigin-RevId: 930338742 --- parser/parser.cc | 6 +++--- parser/parser_test.cc | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/parser/parser.cc b/parser/parser.cc index 6c6434319..a858337a4 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -1104,7 +1104,7 @@ std::any ParserVisitor::visitCreateMessage( } else { name = absl::StrJoin(parts, "."); } - int64_t obj_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + int64_t obj_id = factory_.NextId(SourceRangeFromParserRuleContext(ctx)); std::vector fields; if (ctx->entries) { fields = visitFields(ctx->entries); @@ -1206,7 +1206,7 @@ std::any ParserVisitor::visitNested(CelParser::NestedContext* ctx) { } std::any ParserVisitor::visitCreateList(CelParser::CreateListContext* ctx) { - int64_t list_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + int64_t list_id = factory_.NextId(SourceRangeFromParserRuleContext(ctx)); auto elems = visitList(ctx->elems); return ExprToAny(factory_.NewList(list_id, std::move(elems))); } @@ -1244,7 +1244,7 @@ std::vector ParserVisitor::visitList(CelParser::ExprListContext* ctx) { } std::any ParserVisitor::visitCreateMap(CelParser::CreateMapContext* ctx) { - int64_t struct_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + int64_t struct_id = factory_.NextId(SourceRangeFromParserRuleContext(ctx)); std::vector entries; if (ctx->entries) { entries = visitEntries(ctx->entries); diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 33c52b1d2..1add80f84 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1520,6 +1520,29 @@ TEST_P(ExpressionTest, Parse) { } } +TEST(ExpressionTest, CompositeExpressionOffsets) { + ParserOptions options; + std::vector macros = Macro::AllMacros(); + + std::string list_expr = "[1, 2]"; + auto list_result = EnrichedParse(list_expr, macros, "", options); + ASSERT_THAT(list_result, IsOk()); + auto list_offsets = list_result->enriched_source_info().offsets(); + EXPECT_EQ(list_offsets.at(1), std::make_pair(0, 5)); + + std::string map_expr = "{'a': 1}"; + auto map_result = EnrichedParse(map_expr, macros, "", options); + ASSERT_THAT(map_result, IsOk()); + auto map_offsets = map_result->enriched_source_info().offsets(); + EXPECT_EQ(map_offsets.at(1), std::make_pair(0, 7)); + + std::string msg_expr = "Msg{f: 1}"; + auto msg_result = EnrichedParse(msg_expr, macros, "", options); + ASSERT_THAT(msg_result, IsOk()); + auto msg_offsets = msg_result->enriched_source_info().offsets(); + EXPECT_EQ(msg_offsets.at(1), std::make_pair(0, 8)); +} + TEST(ExpressionTest, TsanOom) { Parse( "[[a([[???[a[[??[a([[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" From e51522fe9eaea8d5005ee58b4385b16ee642146f Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 11 Jun 2026 10:53:19 -0700 Subject: [PATCH 111/143] Update abbreviation / import validation. More closely match the validation in the go and java implementations. PiperOrigin-RevId: 930622025 --- common/container.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/common/container.cc b/common/container.cc index f69f0cc80..e1db8f86c 100644 --- a/common/container.cc +++ b/common/container.cc @@ -19,6 +19,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "internal/lexis.h" @@ -92,9 +93,11 @@ absl::Status ExpressionContainer::SetContainer(absl::string_view name) { } absl::Status ExpressionContainer::AddAbbreviation(absl::string_view abrev) { + abrev = absl::StripAsciiWhitespace(abrev); if (!IsValidQualifiedName(abrev)) { return absl::InvalidArgumentError( - absl::StrCat("invalid qualified name: ", abrev)); + absl::StrCat("invalid qualified name: ", abrev, + ", wanted name of the form 'qualified.name'")); } auto pos = abrev.rfind('.'); From 6975536ba104d314fcaab2cca1cfa179239554c9 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 11 Jun 2026 11:02:18 -0700 Subject: [PATCH 112/143] Fix variadic logical operator planning PiperOrigin-RevId: 930627361 --- conformance/BUILD | 23 +++++-- conformance/run.bzl | 17 ++++-- conformance/run.cc | 5 ++ conformance/service.cc | 46 +++++++++----- conformance/service.h | 1 + eval/compiler/BUILD | 1 + eval/compiler/flat_expr_builder.cc | 80 +++++++++++++------------ eval/compiler/flat_expr_builder_test.cc | 7 ++- 8 files changed, 113 insertions(+), 67 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index a6f25e001..35d554c7b 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -32,7 +32,6 @@ cc_library( "//common:ast", "//common:ast_proto", "//common:decl_proto_v1alpha1", - "//common:expr", "//common:source", "//common:value", "//common/internal:value_conversion", @@ -57,8 +56,6 @@ cc_library( "//extensions/protobuf:enum_adapter", "//internal:status_macros", "//parser", - "//parser:macro", - "//parser:macro_expr_factory", "//parser:macro_registry", "//parser:options", "//parser:standard_macros", @@ -75,8 +72,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", @@ -302,6 +297,24 @@ gen_conformance_tests( skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED, ) +gen_conformance_tests( + name = "conformance_variadic", + checked = True, + data = _ALL_TESTS, + enable_variadic_logical_operators = True, + modern = True, + skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED, +) + +gen_conformance_tests( + name = "conformance_legacy_variadic", + checked = True, + data = _ALL_TESTS, + enable_variadic_logical_operators = True, + modern = False, + skip_tests = _TESTS_TO_SKIP_LEGACY + _TESTS_TO_SKIP_CHECKED, +) + # Generates a bunch of `cc_test` whose names follow the pattern # `conformance_dashboard_..._{arena|refcount}_{optimized|unoptimized}_{recursive|iterative}`. gen_conformance_tests( diff --git a/conformance/run.bzl b/conformance/run.bzl index 2c0b51c0e..8faeb6c16 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -56,7 +56,7 @@ def _conformance_test_name(name, optimize, recursive): ], ) -def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard): +def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators): args = [] if modern: args.append("--modern") @@ -72,12 +72,14 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, args.append("--noskip_check") if dashboard: args.append("--dashboard") + if enable_variadic_logical_operators: + args.append("--enable_variadic_logical_operators") return args -def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard): +def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard, enable_variadic_logical_operators): cc_test( name = _conformance_test_name(name, optimize, recursive), - args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard) + ["$(rlocationpath {})".format(test) for test in data], + args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators) + ["$(rlocationpath {})".format(test) for test in data], env = select( { "@platforms//os:windows": {"CEL_SKIP_TESTS": ",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS)}, @@ -89,18 +91,20 @@ def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_ tags = tags, ) -def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = []): +def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = [], enable_variadic_logical_operators = False): """Generates conformance tests. Args: name: prefix for all tests + data: textproto targets describing conformance tests modern: run using modern APIs checked: whether to apply type checking - data: textproto targets describing conformance tests + select_opt: enable select optimization + dashboard: enable dashboard mode skip_tests: tests to skip in the format of the cel-spec test runner. See documentation in github.com/google/cel-spec/tests/simple/simple_test.go tags: tags added to the generated targets - dashboard: enable dashboard mode + enable_variadic_logical_operators: enable variadic logical operators """ skip_check = not checked tests = [] @@ -119,6 +123,7 @@ def gen_conformance_tests(name, data, modern = False, checked = False, select_op skip_tests = _expand_tests_to_skip(skip_tests), tags = tags, dashboard = dashboard, + enable_variadic_logical_operators = enable_variadic_logical_operators, ) native.test_suite( name = name, diff --git a/conformance/run.cc b/conformance/run.cc index 4a0493494..1be16ba60 100644 --- a/conformance/run.cc +++ b/conformance/run.cc @@ -66,6 +66,9 @@ ABSL_FLAG(std::vector, skip_tests, {}, "Tests to skip"); ABSL_FLAG(bool, dashboard, false, "Dashboard mode, ignore test failures"); ABSL_FLAG(bool, skip_check, true, "Skip type checking the expressions"); ABSL_FLAG(bool, select_optimization, false, "Enable select optimization."); +ABSL_FLAG(bool, enable_variadic_logical_operators, false, + "Enable parsing logical AND & OR operators as a single flat variadic " + "call."); namespace { @@ -261,6 +264,8 @@ NewConformanceServiceFromFlags() { .modern = absl::GetFlag(FLAGS_modern), .recursive = absl::GetFlag(FLAGS_recursive), .select_optimization = absl::GetFlag(FLAGS_select_optimization), + .enable_variadic_logical_operators = + absl::GetFlag(FLAGS_enable_variadic_logical_operators), }); ABSL_CHECK_OK(status_or_service); return std::shared_ptr( diff --git a/conformance/service.cc b/conformance/service.cc index 7e3eded82..d81200cad 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -128,13 +128,15 @@ cel::expr::Expr ExtractExpr( absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, conformance::v1alpha1::ParseResponse& response, - bool enable_optional_syntax) { + bool enable_optional_syntax, + bool enable_variadic_logical_operators) { if (request.cel_source().empty()) { return absl::InvalidArgumentError("no source code"); } cel::ParserOptions options; options.enable_optional_syntax = enable_optional_syntax; options.enable_quoted_identifiers = true; + options.enable_variadic_logical_operators = enable_variadic_logical_operators; cel::MacroRegistry macros; CEL_RETURN_IF_ERROR(cel::RegisterStandardMacros(macros, options)); CEL_RETURN_IF_ERROR( @@ -236,7 +238,8 @@ absl::Status CheckImpl(google::protobuf::Arena* arena, class LegacyConformanceServiceImpl : public ConformanceServiceInterface { public: static absl::StatusOr> Create( - bool optimize, bool recursive, bool select_optimization) { + bool optimize, bool recursive, bool select_optimization, + bool enable_variadic_logical_operators) { static auto* constant_arena = new Arena(); google::protobuf::LinkMessageReflection< @@ -313,14 +316,15 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathExtensionFunctions( builder->GetRegistry(), options)); - return absl::WrapUnique( - new LegacyConformanceServiceImpl(std::move(builder))); + return absl::WrapUnique(new LegacyConformanceServiceImpl( + std::move(builder), enable_variadic_logical_operators)); } void Parse(const conformance::v1alpha1::ParseRequest& request, conformance::v1alpha1::ParseResponse& response) override { auto status = - LegacyParse(request, response, /*enable_optional_syntax=*/false); + LegacyParse(request, response, /*enable_optional_syntax=*/false, + enable_variadic_logical_operators_); if (!status.ok()) { auto* issue = response.add_issues(); issue->set_code(ToGrpcCode(status.code())); @@ -418,17 +422,20 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { } private: - explicit LegacyConformanceServiceImpl( - std::unique_ptr builder) - : builder_(std::move(builder)) {} + LegacyConformanceServiceImpl(std::unique_ptr builder, + bool enable_variadic_logical_operators) + : builder_(std::move(builder)), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} std::unique_ptr builder_; + bool enable_variadic_logical_operators_; }; class ModernConformanceServiceImpl : public ConformanceServiceInterface { public: static absl::StatusOr> Create( - bool optimize, bool recursive, bool select_optimization) { + bool optimize, bool recursive, bool select_optimization, + bool enable_variadic_logical_operators) { google::protobuf::LinkMessageReflection< cel::expr::conformance::proto3::TestAllTypes>(); google::protobuf::LinkMessageReflection< @@ -470,8 +477,9 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { options.max_recursion_depth = 48; } - return absl::WrapUnique(new ModernConformanceServiceImpl( - options, optimize, select_optimization)); + return absl::WrapUnique( + new ModernConformanceServiceImpl(options, optimize, select_optimization, + enable_variadic_logical_operators)); } absl::StatusOr> Setup( @@ -523,7 +531,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { void Parse(const conformance::v1alpha1::ParseRequest& request, conformance::v1alpha1::ParseResponse& response) override { auto status = - LegacyParse(request, response, /*enable_optional_syntax=*/true); + LegacyParse(request, response, /*enable_optional_syntax=*/true, + enable_variadic_logical_operators_); if (!status.ok()) { auto* issue = response.add_issues(); issue->set_code(ToGrpcCode(status.code())); @@ -614,10 +623,12 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { private: ModernConformanceServiceImpl(const RuntimeOptions& options, bool enable_optimizations, - bool enable_select_optimization) + bool enable_select_optimization, + bool enable_variadic_logical_operators) : options_(options), enable_optimizations_(enable_optimizations), - enable_select_optimization_(enable_select_optimization) {} + enable_select_optimization_(enable_select_optimization), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} static absl::StatusOr> Plan( const cel::Runtime& runtime, @@ -648,6 +659,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { RuntimeOptions options_; bool enable_optimizations_; bool enable_select_optimization_; + bool enable_variadic_logical_operators_; }; } // namespace @@ -660,10 +672,12 @@ absl::StatusOr> NewConformanceService(const ConformanceServiceOptions& options) { if (options.modern) { return google::api::expr::runtime::ModernConformanceServiceImpl::Create( - options.optimize, options.recursive, options.select_optimization); + options.optimize, options.recursive, options.select_optimization, + options.enable_variadic_logical_operators); } else { return google::api::expr::runtime::LegacyConformanceServiceImpl::Create( - options.optimize, options.recursive, options.select_optimization); + options.optimize, options.recursive, options.select_optimization, + options.enable_variadic_logical_operators); } } diff --git a/conformance/service.h b/conformance/service.h index 2dd2abf32..8eb97296e 100644 --- a/conformance/service.h +++ b/conformance/service.h @@ -46,6 +46,7 @@ struct ConformanceServiceOptions { bool arena; bool recursive; bool select_optimization; + bool enable_variadic_logical_operators = false; }; absl::StatusOr> diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index ed8e4d20c..f7300cb58 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -193,6 +193,7 @@ cc_test( "//internal:status_macros", "//internal:testing", "//parser", + "//parser:options", "//runtime:function", "//runtime:function_adapter", "//runtime:runtime_options", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index d6ccdf040..fc6d87b16 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -2154,7 +2154,7 @@ void BinaryCondVisitor::PreVisit(const cel::Expr* expr) { case BinaryCond::kOr: visitor_->ValidateOrError( !expr->call_expr().has_target() && - expr->call_expr().args().size() == 2, + expr->call_expr().args().size() >= 2, "Invalid argument count for a binary function call."); break; case BinaryCond::kOptionalOr: @@ -2172,28 +2172,40 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { return; } const int last_arg_index = expr->call_expr().args().size() - 1; - if (short_circuiting_ && arg_num < last_arg_index && - (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) { - // If first branch evaluation result is enough to determine output, - // jump over the second branch and provide result of the first argument as - // final output. - // Retain pointers to the jump steps so we can update the target after - // planning the next arguments. - std::unique_ptr jump_step; - switch (cond_) { - case BinaryCond::kAnd: - jump_step = CreateCondJumpStep(false, true, {}, expr->id()); - break; - case BinaryCond::kOr: - jump_step = CreateCondJumpStep(true, true, {}, expr->id()); - break; - default: - ABSL_UNREACHABLE(); + if (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr) { + if (arg_num > 0) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->AddStep(CreateAndStep(expr->id())); + break; + case BinaryCond::kOr: + visitor_->AddStep(CreateOrStep(expr->id())); + break; + default: + break; + } + if (short_circuiting_ && !jump_steps_.empty()) { + visitor_->SetProgressStatusIfError( + jump_steps_.back().set_target(visitor_->GetCurrentIndex())); + } } - ProgramStepIndex index = visitor_->GetCurrentIndex(); - if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); - jump_step_ptr) { - jump_steps_.push_back(Jump(index, jump_step_ptr)); + if (short_circuiting_ && arg_num < last_arg_index) { + std::unique_ptr jump_step; + switch (cond_) { + case BinaryCond::kAnd: + jump_step = CreateCondJumpStep(false, true, {}, expr->id()); + break; + case BinaryCond::kOr: + jump_step = CreateCondJumpStep(true, true, {}, expr->id()); + break; + default: + ABSL_UNREACHABLE(); + } + ProgramStepIndex index = visitor_->GetCurrentIndex(); + if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); + jump_step_ptr) { + jump_steps_.push_back(Jump(index, jump_step_ptr)); + } } } } @@ -2251,17 +2263,9 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { return; } - int args_count = (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr) - ? expr->call_expr().args().size() - : 2; - for (int i = 0; i < args_count - 1; ++i) { + if (cond_ == BinaryCond::kOptionalOr || + cond_ == BinaryCond::kOptionalOrValue) { switch (cond_) { - case BinaryCond::kAnd: - visitor_->AddStep(CreateAndStep(expr->id())); - break; - case BinaryCond::kOr: - visitor_->AddStep(CreateOrStep(expr->id())); - break; case BinaryCond::kOptionalOr: visitor_->AddStep( CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); @@ -2273,13 +2277,11 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { default: ABSL_UNREACHABLE(); } - } - if (short_circuiting_) { - // If short-circuiting is enabled, point the conditional jump past the - // boolean operator step. - for (auto& jump : jump_steps_) { - visitor_->SetProgressStatusIfError( - jump.set_target(visitor_->GetCurrentIndex())); + if (short_circuiting_) { + for (auto& jump : jump_steps_) { + visitor_->SetProgressStatusIfError( + jump.set_target(visitor_->GetCurrentIndex())); + } } } } diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index e2581e3fd..105060282 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -64,6 +64,7 @@ #include "internal/proto_matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "parser/options.h" #include "parser/parser.h" #include "runtime/function.h" #include "runtime/function_adapter.h" @@ -2916,7 +2917,11 @@ class FlatExprBuilderVariadicLogicalTest TEST_P(FlatExprBuilderVariadicLogicalTest, Evaluate) { const auto& test_case = GetParam(); - ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(test_case.expr)); + parser::ParserOptions parser_options; + parser_options.enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN( + ParsedExpr parsed_expr, + parser::Parse(test_case.expr, test_case.label, parser_options)); cel::RuntimeOptions options; options.unknown_processing = From c0e43073ea7f4fbbcdc59b45f34abe8bd5b9364a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 11 Jun 2026 11:22:17 -0700 Subject: [PATCH 113/143] refactor: Move type formatter to common. PiperOrigin-RevId: 930639404 --- checker/internal/BUILD | 15 ++---------- checker/internal/type_checker_impl.cc | 2 +- checker/internal/type_inference_context.cc | 18 +++++++-------- common/BUILD | 23 +++++++++++++++++++ .../internal => common}/format_type_name.cc | 6 ++--- .../internal => common}/format_type_name.h | 10 ++++---- .../format_type_name_test.cc | 6 ++--- 7 files changed, 46 insertions(+), 34 deletions(-) rename {checker/internal => common}/format_type_name.cc (97%) rename {checker/internal => common}/format_type_name.h (74%) rename {checker/internal => common}/format_type_name_test.cc (97%) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 777457830..20c476db2 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -130,7 +130,6 @@ cc_library( "type_checker_impl.h", ], deps = [ - ":format_type_name", ":namespace_generator", ":proto_type_mask", ":type_check_env", @@ -149,6 +148,7 @@ cc_library( "//common:container", "//common:decl", "//common:expr", + "//common:format_type_name", "//common:standard_definitions", "//common:type", "//common:type_kind", @@ -243,8 +243,8 @@ cc_library( srcs = ["type_inference_context.cc"], hdrs = ["type_inference_context.h"], deps = [ - ":format_type_name", "//common:decl", + "//common:format_type_name", "//common:standard_definitions", "//common:type", "//common:type_kind", @@ -275,17 +275,6 @@ cc_test( ], ) -cc_library( - name = "format_type_name", - srcs = ["format_type_name.cc"], - hdrs = ["format_type_name.h"], - deps = [ - "//common:type", - "//common:type_kind", - "@com_google_absl//absl/strings", - ], -) - cc_library( name = "descriptor_pool_type_introspector", srcs = ["descriptor_pool_type_introspector.cc"], diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 6b6b051b1..f3a06a28d 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -36,7 +36,6 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "checker/checker_options.h" -#include "checker/internal/format_type_name.h" #include "checker/internal/namespace_generator.h" #include "checker/internal/type_check_env.h" #include "checker/internal/type_checker_builder_impl.h" @@ -52,6 +51,7 @@ #include "common/constant.h" #include "common/decl.h" #include "common/expr.h" +#include "common/format_type_name.h" #include "common/standard_definitions.h" #include "common/type.h" #include "common/type_kind.h" diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 1a87d9e15..4681784af 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -28,8 +28,8 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" -#include "checker/internal/format_type_name.h" #include "common/decl.h" +#include "common/format_type_name.h" #include "common/standard_definitions.h" #include "common/type.h" #include "common/type_kind.h" @@ -657,14 +657,14 @@ bool TypeInferenceContext::AssignabilityContext::IsAssignable(const Type& from, std::string TypeInferenceContext::DebugString() const { return absl::StrCat( "type_parameter_bindings: ", - absl::StrJoin( - type_parameter_bindings_, "\n ", - [](std::string* out, const auto& binding) { - absl::StrAppend( - out, binding.first, " (", binding.second.name, ") -> ", - checker_internal::FormatTypeName( - binding.second.type.value_or(Type(TypeParamType("none"))))); - })); + absl::StrJoin(type_parameter_bindings_, "\n ", + [](std::string* out, const auto& binding) { + absl::StrAppend( + out, binding.first, " (", binding.second.name, + ") -> ", + cel::FormatTypeName(binding.second.type.value_or( + Type(TypeParamType("none"))))); + })); } void TypeInferenceContext::AssignabilityContext:: diff --git a/common/BUILD b/common/BUILD index f7c897e57..93410306f 100644 --- a/common/BUILD +++ b/common/BUILD @@ -601,6 +601,17 @@ cc_library( ], ) +cc_library( + name = "format_type_name", + srcs = ["format_type_name.cc"], + hdrs = ["format_type_name.h"], + deps = [ + ":type", + ":type_kind", + "@com_google_absl//absl/strings", + ], +) + cc_test( name = "type_test", srcs = glob([ @@ -623,6 +634,18 @@ cc_test( ], ) +cc_test( + name = "format_type_name_test", + srcs = ["format_type_name_test.cc"], + deps = [ + ":format_type_name", + ":type", + "//internal:testing", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "value", srcs = glob( diff --git a/checker/internal/format_type_name.cc b/common/format_type_name.cc similarity index 97% rename from checker/internal/format_type_name.cc rename to common/format_type_name.cc index 7cd17251f..4bd6c2e61 100644 --- a/checker/internal/format_type_name.cc +++ b/common/format_type_name.cc @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "checker/internal/format_type_name.h" +#include "common/format_type_name.h" #include #include @@ -20,7 +20,7 @@ #include "common/type.h" #include "common/type_kind.h" -namespace cel::checker_internal { +namespace cel { namespace { struct FormatImplRecord { @@ -177,4 +177,4 @@ std::string FormatTypeName(const Type& type) { return out; } -} // namespace cel::checker_internal +} // namespace cel diff --git a/checker/internal/format_type_name.h b/common/format_type_name.h similarity index 74% rename from checker/internal/format_type_name.h rename to common/format_type_name.h index c31e1c4d0..723ac20fd 100644 --- a/checker/internal/format_type_name.h +++ b/common/format_type_name.h @@ -12,19 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_FORMAT_TYPE_NAME_H_ -#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_FORMAT_TYPE_NAME_H_ +#ifndef THIRD_PARTY_CEL_CPP_COMMON_FORMAT_TYPE_NAME_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_FORMAT_TYPE_NAME_H_ #include #include "common/type.h" -namespace cel::checker_internal { +namespace cel { // Format the type name for presentation in error messages. Matches the // formatting used in github.com/cel-spec. std::string FormatTypeName(const Type& type); -} // namespace cel::checker_internal +} // namespace cel -#endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_FORMAT_TYPE_NAME_H_ +#endif // THIRD_PARTY_CEL_CPP_COMMON_FORMAT_TYPE_NAME_H_ diff --git a/checker/internal/format_type_name_test.cc b/common/format_type_name_test.cc similarity index 97% rename from checker/internal/format_type_name_test.cc rename to common/format_type_name_test.cc index ff04e04d2..ca63f60b0 100644 --- a/checker/internal/format_type_name_test.cc +++ b/common/format_type_name_test.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "checker/internal/format_type_name.h" +#include "common/format_type_name.h" #include "common/type.h" #include "internal/testing.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "google/protobuf/arena.h" -namespace cel::checker_internal { +namespace cel { namespace { using ::cel::expr::conformance::proto2::GlobalEnum_descriptor; @@ -115,4 +115,4 @@ TEST(FormatTypeNameTest, ArbitraryNesting) { #endif } // namespace -} // namespace cel::checker_internal +} // namespace cel From d2a8a7831b76fb128de4d03254bc8f81dfcf5652 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 11 Jun 2026 12:12:47 -0700 Subject: [PATCH 114/143] Create a `proto_to_predicate` compiler for converting proto messages into CEL expressions PiperOrigin-RevId: 930668150 --- common/expr_factory.h | 5 + internal/json.h | 4 +- tools/BUILD | 50 +++ tools/proto_to_predicate.cc | 459 ++++++++++++++++++++++++ tools/proto_to_predicate.h | 48 +++ tools/proto_to_predicate_test.cc | 593 +++++++++++++++++++++++++++++++ tools/testdata/BUILD | 19 +- tools/testdata/test_policy.proto | 73 ++++ 8 files changed, 1245 insertions(+), 6 deletions(-) create mode 100644 tools/proto_to_predicate.cc create mode 100644 tools/proto_to_predicate.h create mode 100644 tools/proto_to_predicate_test.cc create mode 100644 tools/testdata/test_policy.proto diff --git a/common/expr_factory.h b/common/expr_factory.h index 5607d8deb..757318545 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -34,6 +34,10 @@ class MacroExprFactory; class ParserMacroExprFactory; class OptimizerExprFactory; +namespace tools { +class ProtoToPredicateBuilder; +} + class ExprFactory { protected: // `IsExprLike` determines whether `T` is some `Expr`. Currently that means @@ -380,6 +384,7 @@ class ExprFactory { friend class MacroExprFactory; friend class ParserMacroExprFactory; friend class OptimizerExprFactory; + friend class tools::ProtoToPredicateBuilder; ExprFactory() : accu_var_(kAccumulatorVariableName) {} diff --git a/internal/json.h b/internal/json.h index d32c42741..e35909d0e 100644 --- a/internal/json.h +++ b/internal/json.h @@ -26,7 +26,7 @@ namespace cel::internal { // Converts the given message to its `google.protobuf.Value` equivalent -// representation. This is similar to `proto2::json::MessageToJsonString()`, +// representation. This is similar to `google::protobuf::json::MessageToJsonString()`, // except that this results in structured serialization. absl::Status MessageToJson( const google::protobuf::Message& message, @@ -45,7 +45,7 @@ absl::Status MessageToJson( google::protobuf::Message* absl_nonnull result); // Converts the given message field to its `google.protobuf.Value` equivalent -// representation. This is similar to `proto2::json::MessageToJsonString()`, +// representation. This is similar to `google::protobuf::json::MessageToJsonString()`, // except that this results in structured serialization. absl::Status MessageFieldToJson( const google::protobuf::Message& message, diff --git a/tools/BUILD b/tools/BUILD index ceb2befc5..af006a67b 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -204,6 +204,56 @@ cc_library( ], ) +cc_library( + name = "proto_to_predicate", + srcs = ["proto_to_predicate.cc"], + hdrs = ["proto_to_predicate.h"], + deps = [ + "//common:ast", + "//common:expr", + "//common:expr_factory", + "//common:operators", + "//internal:status_macros", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_to_predicate_test", + srcs = ["proto_to_predicate_test.cc"], + deps = [ + ":cel_unparser", + ":proto_to_predicate", + "//common:ast", + "//common:ast_proto", + "//common:value", + "//env:config", + "//env:env_runtime", + "//env:env_yaml", + "//env:runtime_std_extensions", + "//eval/testutil:test_message_cc_proto", + "//extensions/protobuf:value", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//runtime", + "//runtime:activation", + "//tools/testdata:test_policy_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "descriptor_pool_builder_test", srcs = ["descriptor_pool_builder_test.cc"], diff --git a/tools/proto_to_predicate.cc b/tools/proto_to_predicate.cc new file mode 100644 index 000000000..8c89ee2f0 --- /dev/null +++ b/tools/proto_to_predicate.cc @@ -0,0 +1,459 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/proto_to_predicate.h" + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/descriptor.pb.h" +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/operators.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/reflection.h" + +namespace cel::tools { + +using ::google::api::expr::common::CelOperator; +using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::Message; +using ::google::protobuf::Reflection; + +class ProtoToPredicateBuilder final : private ExprFactory { + public: + ProtoToPredicateBuilder() : id_(1) {} + + absl::StatusOr Build(absl::string_view input_name, + const Message& message) { + std::vector predicates; + Expr base_expr = NewIdent(NextId(), input_name); + + CEL_RETURN_IF_ERROR(Walk(message, base_expr, predicates)); + Expr root = LogicalAnd(predicates); + return Ast(std::move(root), std::move(source_info_)); + } + + absl::StatusOr Build(absl::string_view input_name, + absl::Span messages) { + if (messages.empty()) { + return Ast(NewBoolConst(NextId(), true), std::move(source_info_)); + } + + std::vector message_asts; + message_asts.reserve(messages.size()); + for (const auto* message : messages) { + std::vector predicates; + Expr base_expr = NewIdent(NextId(), input_name); + + CEL_RETURN_IF_ERROR(Walk(*message, base_expr, predicates)); + message_asts.push_back(LogicalAnd(predicates)); + } + + return Ast(LogicalOr(message_asts), std::move(source_info_)); + } + + private: + // Retrieves the "match_path" string option from the field options if + // defined, returning an empty string otherwise. + std::string GetMatchPath(const ::google::protobuf::FieldDescriptor* field) { + const ::google::protobuf::Message& options = field->options(); + const ::google::protobuf::Reflection* refl = options.GetReflection(); + std::vector fields; + refl->ListFields(options, &fields); + for (const auto* f : fields) { + if (f->name() == "match_path") { + return refl->GetString(options, f); + } + } + return ""; + } + + // Parses a dot-separated string representation of a path (e.g. "dest.region") + // and builds a corresponding select chain AST. + Expr ParseAndBuildPath(absl::string_view path_str) { + std::vector parts = absl::StrSplit(path_str, '.'); + Expr e = NewIdent(NextId(), parts[0]); + for (size_t i = 1; i < parts.size(); ++i) { + e = NewSelect(NextId(), std::move(e), parts[i]); + } + return e; + } + ExprId NextId() { return id_++; } + + // --------------------------------------------------------------------------- + // Field value extraction + // --------------------------------------------------------------------------- + + // Converts a singular field value to a CEL constant expression. + Expr PrimitiveToExpr(ExprId expr_id, const Message& message, + const Reflection* reflection, + const FieldDescriptor* field) { + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return NewIntConst(expr_id, reflection->GetInt32(message, field)); + case FieldDescriptor::CPPTYPE_INT64: + return NewIntConst(expr_id, reflection->GetInt64(message, field)); + case FieldDescriptor::CPPTYPE_UINT32: + return NewUintConst(expr_id, reflection->GetUInt32(message, field)); + case FieldDescriptor::CPPTYPE_UINT64: + return NewUintConst(expr_id, reflection->GetUInt64(message, field)); + case FieldDescriptor::CPPTYPE_DOUBLE: + return NewDoubleConst(expr_id, reflection->GetDouble(message, field)); + case FieldDescriptor::CPPTYPE_FLOAT: + return NewDoubleConst(expr_id, reflection->GetFloat(message, field)); + case FieldDescriptor::CPPTYPE_BOOL: + return NewBoolConst(expr_id, reflection->GetBool(message, field)); + case FieldDescriptor::CPPTYPE_ENUM: + return NewIntConst(expr_id, reflection->GetEnumValue(message, field)); + case FieldDescriptor::CPPTYPE_STRING: { + std::string str_val = reflection->GetString(message, field); + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return NewBytesConst(expr_id, std::move(str_val)); + } + return NewStringConst(expr_id, std::move(str_val)); + } + default: + // Log a warning as message should be handled by Walk. + ABSL_LOG(WARNING) << "PrimitiveToExpr: Unhandled field type: " + << FieldDescriptor::TypeName(field->type()); + break; + } + return NewNullConst(expr_id); + } + + Expr PrimitiveToExpr(const Message& message, const Reflection* reflection, + const FieldDescriptor* field) { + return PrimitiveToExpr(NextId(), message, reflection, field); + } + + // Converts a repeated field element to a CEL constant expression. + Expr RepeatedPrimitiveToExpr(const Message& message, + const Reflection* reflection, + const FieldDescriptor* field, int index) { + const ExprId id = NextId(); + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return NewIntConst(id, + reflection->GetRepeatedInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_INT64: + return NewIntConst(id, + reflection->GetRepeatedInt64(message, field, index)); + case FieldDescriptor::CPPTYPE_UINT32: + return NewUintConst( + id, reflection->GetRepeatedUInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_UINT64: + return NewUintConst( + id, reflection->GetRepeatedUInt64(message, field, index)); + case FieldDescriptor::CPPTYPE_DOUBLE: + return NewDoubleConst( + id, reflection->GetRepeatedDouble(message, field, index)); + case FieldDescriptor::CPPTYPE_FLOAT: + return NewDoubleConst( + id, reflection->GetRepeatedFloat(message, field, index)); + case FieldDescriptor::CPPTYPE_BOOL: + return NewBoolConst(id, + reflection->GetRepeatedBool(message, field, index)); + case FieldDescriptor::CPPTYPE_ENUM: + return NewIntConst( + id, reflection->GetRepeatedEnumValue(message, field, index)); + case FieldDescriptor::CPPTYPE_STRING: { + std::string str_val = + reflection->GetRepeatedString(message, field, index); + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return NewBytesConst(id, std::move(str_val)); + } + return NewStringConst(id, std::move(str_val)); + } + default: + break; + } + return NewNullConst(id); + } + + // --------------------------------------------------------------------------- + // Expression construction helpers + // --------------------------------------------------------------------------- + + // Creates a binary operator call: `lhs rhs`. + Expr ConstructBinaryOp(absl::string_view op, Expr lhs, Expr rhs) { + std::vector args = {std::move(lhs), std::move(rhs)}; + return NewCall(NextId(), op, std::move(args)); + } + + Expr ConstructEquality(Expr lhs, Expr rhs) { + return ConstructBinaryOp(CelOperator::EQUALS, std::move(lhs), + std::move(rhs)); + } + + Expr LogicalOr(std::vector& exprs) { + return LogicalOp(CelOperator::LOGICAL_OR, exprs); + } + + Expr LogicalAnd(std::vector& exprs) { + return LogicalOp(CelOperator::LOGICAL_AND, exprs); + } + + // Left-folds a vector of expressions with a binary operator. + // Requires: `exprs` is non-empty. + Expr LogicalOp(absl::string_view op, std::vector& exprs) { + if (exprs.empty()) { + return NewBoolConst(NextId(), true); + } + if (exprs.size() == 1) { + return std::move(exprs[0]); + } + return NewCall(NextId(), op, std::move(exprs)); + } + + // --------------------------------------------------------------------------- + // Map field predicate (extracted from Walk for readability) + // --------------------------------------------------------------------------- + + // Builds the predicate for a map field to assert that all key-value pairs + // specified in the policy are present in the input map field: + // "key" in input.map && input.map["key"] == value + absl::Status WalkMapField(const Reflection* reflection, + const Message& message, + const FieldDescriptor* field, const Expr& base_expr, + int size, std::vector& predicates) { + const FieldDescriptor* const key_field = + field->message_type()->FindFieldByName("key"); + const FieldDescriptor* const value_field = + field->message_type()->FindFieldByName("value"); + + Expr map_path = NewSelect(NextId(), base_expr, field->name()); + + struct MapEntry { + const Message* message; + }; + std::vector entries; + entries.reserve(size); + for (int i = 0; i < size; ++i) { + entries.push_back({&reflection->GetRepeatedMessage(message, field, i)}); + } + + if (!entries.empty()) { + const Reflection* const entry_ref = entries[0].message->GetReflection(); + std::sort(entries.begin(), entries.end(), + [entry_ref, key_field](const MapEntry& a, const MapEntry& b) { + switch (key_field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return entry_ref->GetInt32(*a.message, key_field) < + entry_ref->GetInt32(*b.message, key_field); + case FieldDescriptor::CPPTYPE_INT64: + return entry_ref->GetInt64(*a.message, key_field) < + entry_ref->GetInt64(*b.message, key_field); + case FieldDescriptor::CPPTYPE_UINT32: + return entry_ref->GetUInt32(*a.message, key_field) < + entry_ref->GetUInt32(*b.message, key_field); + case FieldDescriptor::CPPTYPE_UINT64: + return entry_ref->GetUInt64(*a.message, key_field) < + entry_ref->GetUInt64(*b.message, key_field); + case FieldDescriptor::CPPTYPE_BOOL: + return !entry_ref->GetBool(*a.message, key_field) && + entry_ref->GetBool(*b.message, key_field); + case FieldDescriptor::CPPTYPE_STRING: + return entry_ref->GetString(*a.message, key_field) < + entry_ref->GetString(*b.message, key_field); + default: + return false; + } + }); + } + + std::vector map_checks; + map_checks.reserve(size); + for (const auto& entry : entries) { + const Message& entry_msg = *entry.message; + const Reflection* const entry_ref = entry_msg.GetReflection(); + + Expr key_expr = PrimitiveToExpr(entry_msg, entry_ref, key_field); + + // Represents `"key" in input.map` to assert the key exists. + Expr in_check = NewCall(NextId(), CelOperator::IN, + std::vector{key_expr, map_path}); + // Represents `input.map["key"]` to lookup the value. + Expr lookup_path = NewCall(NextId(), CelOperator::INDEX, + std::vector{map_path, key_expr}); + + if (value_field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + const Message& value_msg = + entry_ref->GetMessage(entry_msg, value_field); + std::vector val_predicates; + CEL_RETURN_IF_ERROR(Walk(value_msg, lookup_path, val_predicates)); + + if (!val_predicates.empty()) { + // Represents `"key" in input.map && (nested message fields check...)` + map_checks.push_back(std::move(in_check)); + map_checks.insert(map_checks.end(), + std::make_move_iterator(val_predicates.begin()), + std::make_move_iterator(val_predicates.end())); + } else { + // Represents `"key" in input.map` if nested message is empty. + map_checks.push_back(std::move(in_check)); + } + } else { + Expr value_expr = PrimitiveToExpr(entry_msg, entry_ref, value_field); + // Represents `input.map["key"] == value` + Expr eq_check = + ConstructEquality(std::move(lookup_path), std::move(value_expr)); + + // Represents `"key" in input.map && input.map["key"] == value` + map_checks.push_back(std::move(in_check)); + map_checks.push_back(std::move(eq_check)); + } + } + + predicates.push_back(LogicalAnd(map_checks)); + return absl::OkStatus(); + } + + // --------------------------------------------------------------------------- + // Repeated field predicate (extracted from Walk for readability) + // --------------------------------------------------------------------------- + + // Builds predicates for a repeated field: + // - Repeated Messages are mapped to a logical OR (||) of the generated + // predicates for each message. + // - Repeated Primitives are mapped either to: + // - `lhs in [values]` if a "match_path" option is specified. + // - `value in input.field` conjoined with && for each value otherwise. + absl::Status WalkRepeatedField(const Reflection* reflection, + const Message& message, + const FieldDescriptor* field, + const Expr& base_expr, int size, + std::vector& predicates) { + if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + std::vector message_asts; + message_asts.reserve(size); + for (int i = 0; i < size; ++i) { + const Message& sub_message = + reflection->GetRepeatedMessage(message, field, i); + std::vector sub_predicates; + Expr sub_base = NewSelect(NextId(), base_expr, field->name()); + CEL_RETURN_IF_ERROR(Walk(sub_message, sub_base, sub_predicates)); + message_asts.push_back(LogicalAnd(sub_predicates)); + } + // Represents alternate message predicates conjoined with OR: `msg_1 || + // msg_2 || ...` + predicates.push_back(LogicalOr(message_asts)); + return absl::OkStatus(); + } + + std::vector elements; + elements.reserve(size); + for (int i = 0; i < size; ++i) { + elements.push_back(NewListElement( + RepeatedPrimitiveToExpr(message, reflection, field, i))); + } + Expr literal_list = NewList(NextId(), std::move(elements)); + + std::string match_path_val = GetMatchPath(field); + if (!match_path_val.empty()) { + Expr lhs = ParseAndBuildPath(match_path_val); + // Represents `lhs in [values]` check (e.g. `dest.region in ["us-east", + // "us-west"]`). + predicates.push_back( + NewCall(NextId(), CelOperator::IN, + std::vector{std::move(lhs), std::move(literal_list)})); + return absl::OkStatus(); + } + + Expr map_path = NewSelect(NextId(), base_expr, field->name()); + std::vector element_checks; + element_checks.reserve(size); + for (int i = 0; i < size; ++i) { + Expr elem_expr = RepeatedPrimitiveToExpr(message, reflection, field, i); + // Represents `value in input.field` check. + Expr in_check = + NewCall(NextId(), CelOperator::IN, + std::vector{std::move(elem_expr), map_path}); + element_checks.push_back(std::move(in_check)); + } + // Represents `"val1" in input.list && "val2" in input.list && ...` + predicates.push_back(LogicalAnd(element_checks)); + + return absl::OkStatus(); + } + + // --------------------------------------------------------------------------- + // Recursive message walk + // --------------------------------------------------------------------------- + + absl::Status Walk(const Message& message, const Expr& base_expr, + std::vector& predicates) { + const Reflection* const reflection = message.GetReflection(); + std::vector fields; + reflection->ListFields(message, &fields); + + for (const auto* field : fields) { + if (field->is_map()) { + const int size = reflection->FieldSize(message, field); + if (size > 0) { + CEL_RETURN_IF_ERROR(WalkMapField(reflection, message, field, + base_expr, size, predicates)); + } + } else if (field->is_repeated()) { + const int size = reflection->FieldSize(message, field); + if (size > 0) { + CEL_RETURN_IF_ERROR(WalkRepeatedField(reflection, message, field, + base_expr, size, predicates)); + } + } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + const Message& sub_message = reflection->GetMessage(message, field); + Expr field_path = NewSelect(NextId(), base_expr, field->name()); + CEL_RETURN_IF_ERROR(Walk(sub_message, field_path, predicates)); + } else { + // Primitive field: base_expr.field == + Expr field_path = NewSelect(NextId(), base_expr, field->name()); + predicates.push_back( + ConstructEquality(std::move(field_path), + PrimitiveToExpr(message, reflection, field))); + } + } + return absl::OkStatus(); + } + + ExprId id_; + SourceInfo source_info_; +}; + +absl::StatusOr ProtoToPredicateAst(absl::string_view input_name, + const ::google::protobuf::Message& message) { + ProtoToPredicateBuilder builder; + return builder.Build(input_name, message); +} + +absl::StatusOr ProtoToPredicateAst( + absl::string_view input_name, + absl::Span messages) { + ProtoToPredicateBuilder builder; + return builder.Build(input_name, messages); +} + +} // namespace cel::tools diff --git a/tools/proto_to_predicate.h b/tools/proto_to_predicate.h new file mode 100644 index 000000000..ed01cb1e8 --- /dev/null +++ b/tools/proto_to_predicate.h @@ -0,0 +1,48 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ +#define THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "google/protobuf/message.h" + +namespace cel::tools { + +// Translates a Protocol Buffer message into a CEL AST representing a predicate. +// +// NOTE: The protocol message schemas used for policy definition should use +// `proto2` or `editions` (and not `proto3` implicit presence) to ensure correct +// behavior, as this library relies on field presence (via reflection) to +// identify which fields are explicitly set by the policy. +absl::StatusOr ProtoToPredicateAst(absl::string_view input_name, + const ::google::protobuf::Message& message); + +// Translates a list of Protocol Buffer messages into a CEL AST representing a +// conjoined or alternate predicate. +// +// NOTE: The protocol message schemas used for policy definition should use +// `proto2` or `editions` (and not `proto3` implicit presence) to ensure correct +// behavior, as this library relies on field presence (via reflection) to +// identify which fields are explicitly set by the policy. +absl::StatusOr ProtoToPredicateAst( + absl::string_view input_name, + absl::Span messages); + +} // namespace cel::tools + +#endif // THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ diff --git a/tools/proto_to_predicate_test.cc b/tools/proto_to_predicate_test.cc new file mode 100644 index 000000000..80ad140c7 --- /dev/null +++ b/tools/proto_to_predicate_test.cc @@ -0,0 +1,593 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/proto_to_predicate.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/value.h" +#include "env/config.h" +#include "env/env_runtime.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "eval/testutil/test_message.pb.h" +#include "extensions/protobuf/value.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "tools/cel_unparser.h" +#include "tools/testdata/test_policy.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/json/json.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace cel::tools { +namespace { + +using ::absl_testing::IsOk; +using ::google::api::expr::runtime::TestMessage; + +constexpr absl::string_view kEnvYaml = R"( +name: "test" +extensions: + - name: "bindings" + - name: "optional" +variables: + - name: "input" + type: "google.api.expr.runtime.TestMessage" +)"; + +TestMessage ParseTestMessage(absl::string_view textproto) { + TestMessage msg; + google::protobuf::TextFormat::ParseFromString(textproto, &msg); + return msg; +} + +absl::StatusOr EvaluatePredicate(const cel::Ast& ast, + const TestMessage& input) { + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + + CEL_ASSIGN_OR_RETURN(cel::Config config, + cel::EnvConfigFromYaml(std::string(kEnvYaml))); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::make_unique(ast))); + + google::protobuf::Arena arena; + cel::Activation activation; + CEL_ASSIGN_OR_RETURN( + cel::Value val, cel::extensions::ProtoMessageToValue( + input, descriptor_pool.get(), + google::protobuf::MessageFactory::generated_factory(), &arena)); + activation.InsertOrAssignValue("input", val); + + CEL_ASSIGN_OR_RETURN(cel::Value result, + program->Evaluate(&arena, activation)); + if (!result.IsBool()) { + return absl::InvalidArgumentError( + "Predicate evaluate result must be a boolean value."); + } + return result.GetBool(); +} + +struct TestCase { + std::string name; + std::vector input_textprotos; + std::string expected_unparsed; + std::string eval_textproto; + bool expected_eval_result = true; + // If true, skip the eval step of the test. This is useful for tests where + // the expected expression does not share the same type structure as the + // input proto, such as empty messages. + bool skip_eval = false; +}; + +class ProtoToPredicateTest : public ::testing::TestWithParam {}; + +TEST_P(ProtoToPredicateTest, ConformanceTests) { + const TestCase& param = GetParam(); + + std::vector input_messages; + input_messages.reserve(param.input_textprotos.size()); + for (const auto& proto_str : param.input_textprotos) { + input_messages.push_back(ParseTestMessage(proto_str)); + } + + std::vector ptr_messages; + ptr_messages.reserve(input_messages.size()); + for (const auto& msg : input_messages) { + ptr_messages.push_back(&msg); + } + + absl::StatusOr ast_or; + if (input_messages.size() == 1) { + ast_or = ProtoToPredicateAst("input", input_messages[0]); + } else { + ast_or = ProtoToPredicateAst("input", absl::MakeSpan(ptr_messages)); + } + + ASSERT_THAT(ast_or, IsOk()); + cel::Ast ast = std::move(*ast_or); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, google::api::expr::Unparse(parsed_expr)); + + EXPECT_EQ(unparsed, param.expected_unparsed); + + if (!param.skip_eval) { + TestMessage eval_msg = ParseTestMessage(param.eval_textproto); + ASSERT_OK_AND_ASSIGN(bool eval_result, EvaluatePredicate(ast, eval_msg)); + EXPECT_EQ(eval_result, param.expected_eval_result); + } +} + +INSTANTIATE_TEST_SUITE_P( + ProtoToPredicateSubCases, ProtoToPredicateTest, + testing::Values( + TestCase{ + .name = "EmptyMessageTest", + .input_textprotos = {""}, + .expected_unparsed = "true", + .eval_textproto = "", + }, + TestCase{ + .name = "EmptyMessagesListTest", + .input_textprotos = {}, + .expected_unparsed = "true", + .eval_textproto = "", + }, + TestCase{ + .name = "PrimitivesTest", + .input_textprotos = {R"pb( + int32_value: 42 string_value: "hello" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 && input.string_value == \"hello\"", + .eval_textproto = R"pb( + int32_value: 42 string_value: "hello" + )pb", + }, + TestCase{ + .name = "AllPrimitivesTest", + .input_textprotos = {R"pb( + int32_value: 42 + int64_value: 43 + uint32_value: 44 + uint64_value: 45 + float_value: 46.5 + double_value: 47.5 + bool_value: true + enum_value: TEST_ENUM_1 + string_value: "hello" + bytes_value: "world" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 && input.int64_value == 43 && " + "input.uint32_value == 44u && input.uint64_value == 45u && " + "input.float_value == 46.5 && input.double_value == 47.5 && " + "input.string_value == \"hello\" && " + "input.bytes_value == b\"world\" && " + "input.bool_value == true && " + "input.enum_value == 1", + .eval_textproto = R"pb( + int32_value: 42 + int64_value: 43 + uint32_value: 44 + uint64_value: 45 + float_value: 46.5 + double_value: 47.5 + bool_value: true + enum_value: TEST_ENUM_1 + string_value: "hello" + bytes_value: "world" + )pb", + }, + TestCase{ + .name = "NestedMessageTest", + .input_textprotos = {R"pb( + message_value: { int32_value: 42 } + )pb"}, + .expected_unparsed = "input.message_value.int32_value == 42", + .eval_textproto = R"pb( + message_value: { int32_value: 42 } + )pb", + }, + TestCase{ + .name = "RepeatedFieldTest", + .input_textprotos = {R"pb( + int32_list: [ 1, 2 ] + )pb"}, + .expected_unparsed = + "1 in input.int32_list && 2 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 1, 2 ] + )pb", + }, + TestCase{ + .name = "RepeatedFieldSingleElementTest", + .input_textprotos = {R"pb( + int32_list: [ 42 ] + )pb"}, + .expected_unparsed = "42 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 42 ] + )pb", + }, + TestCase{ + .name = "RepeatedFieldEmptyTest", + .input_textprotos = {R"pb( + int32_list: [] + )pb"}, + .expected_unparsed = "true", + .eval_textproto = R"pb( + int32_list: [] + )pb", + }, + TestCase{ + .name = "ListFieldEvalNegative", + .input_textprotos = {R"pb( + int32_list: [ 1, 2 ] + )pb"}, + .expected_unparsed = + "1 in input.int32_list && 2 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 1, 3 ] + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "SingleRepeatedFieldAllPrimitivesTest", + .input_textprotos = {R"pb( + int32_list: [ 42 ] + int64_list: [ 43 ] + uint32_list: [ 44 ] + uint64_list: [ 45 ] + float_list: [ 46.5 ] + double_list: [ 47.5 ] + bool_list: [ true ] + enum_list: [ TEST_ENUM_1 ] + string_list: [ "hello" ] + bytes_list: [ "world" ] + )pb"}, + .expected_unparsed = "42 in input.int32_list && " + "43 in input.int64_list && " + "44u in input.uint32_list && " + "45u in input.uint64_list && " + "46.5 in input.float_list && " + "47.5 in input.double_list && " + "\"hello\" in input.string_list && " + "b\"world\" in input.bytes_list && " + "true in input.bool_list && " + "1 in input.enum_list", + .eval_textproto = R"pb( + int32_list: [ 42 ] + int64_list: [ 43 ] + uint32_list: [ 44 ] + uint64_list: [ 45 ] + float_list: [ 46.5 ] + double_list: [ 47.5 ] + bool_list: [ true ] + enum_list: [ TEST_ENUM_1 ] + string_list: [ "hello" ] + bytes_list: [ "world" ] + )pb", + }, + TestCase{ + .name = "MultipleRepeatedFieldAllPrimitivesTest", + .input_textprotos = {R"pb( + int32_list: [ 42, 142 ] + int64_list: [ 43, 143 ] + uint32_list: [ 44, 144 ] + uint64_list: [ 45, 145 ] + float_list: [ 46.5, 146.5 ] + double_list: [ 47.5, 147.5 ] + bool_list: [ true, false ] + enum_list: [ TEST_ENUM_1, TEST_ENUM_2 ] + string_list: [ "hello", "universe" ] + bytes_list: [ "world", "space" ] + )pb"}, + .expected_unparsed = + "42 in input.int32_list && 142 in input.int32_list && " + "43 in input.int64_list && 143 in input.int64_list && " + "44u in input.uint32_list && 144u in input.uint32_list && " + "45u in input.uint64_list && 145u in input.uint64_list && " + "46.5 in input.float_list && 146.5 in input.float_list && " + "47.5 in input.double_list && 147.5 in input.double_list && " + "\"hello\" in input.string_list && \"universe\" in " + "input.string_list && " + "b\"world\" in input.bytes_list && b\"space\" in " + "input.bytes_list && " + "true in input.bool_list && false in input.bool_list && " + "1 in input.enum_list && 2 in input.enum_list", + .eval_textproto = R"pb( + int32_list: [ 42, 142 ] + int64_list: [ 43, 143 ] + uint32_list: [ 44, 144 ] + uint64_list: [ 45, 145 ] + float_list: [ 46.5, 146.5 ] + double_list: [ 47.5, 147.5 ] + bool_list: [ true, false ] + enum_list: [ TEST_ENUM_1, TEST_ENUM_2 ] + string_list: [ "hello", "universe" ] + bytes_list: [ "world", "space" ] + )pb", + }, + TestCase{ + .name = "MapFieldTest", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb", + }, + TestCase{ + .name = "MapFieldEvalNegativeVal", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 3 } + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "MapFieldEvalNegativeNoKey", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "MapFieldIntKeySortingTest", + .input_textprotos = {R"pb( + int32_int32_map: { key: 10 value: 100 } + int32_int32_map: { key: 5 value: 50 } + int32_int32_map: { key: 8 value: 80 } + )pb"}, + .expected_unparsed = "5 in input.int32_int32_map && " + "input.int32_int32_map[5] == 50 && " + "8 in input.int32_int32_map && " + "input.int32_int32_map[8] == 80 && " + "10 in input.int32_int32_map && " + "input.int32_int32_map[10] == 100", + .eval_textproto = R"pb( + int32_int32_map: { key: 5 value: 50 } + int32_int32_map: { key: 8 value: 80 } + int32_int32_map: { key: 10 value: 100 } + )pb", + }, + TestCase{ + .name = "MultipleMessagesTest", + .input_textprotos = {R"pb( + int32_value: 42 + )pb", + R"pb( + int32_value: 41 string_value: "hello" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 || input.int32_value == 41 && " + "input.string_value == \"hello\"", + .eval_textproto = R"pb( + int32_value: 41 string_value: "hello" + )pb", + }, + TestCase{ + .name = "RepeatedMessageFieldTest", + .input_textprotos = {R"pb( + message_list: + [ { int32_value: 42 } + , { int32_value: 43 }] + )pb"}, + .expected_unparsed = "input.message_list.int32_value == 42 || " + "input.message_list.int32_value == 43", + .skip_eval = true, + }, + TestCase{ + .name = "RepeatedMessageSingleElementTest", + .input_textprotos = {R"pb( + message_list: + [ { int32_value: 42 }] + )pb"}, + .expected_unparsed = "input.message_list.int32_value == 42", + .skip_eval = true, + })); + +struct PolicyTestCase { + std::string name; + std::string json_input; + std::string expected_unparsed; +}; + +class PolicyJsonTest : public ::testing::TestWithParam {}; + +TEST_P(PolicyJsonTest, Conformance) { + const PolicyTestCase& param = GetParam(); + + cel::cpp::tools::Policy policy; + google::protobuf::json::ParseOptions options; + options.ignore_unknown_fields = true; + auto status = + google::protobuf::json::JsonStringToMessage(param.json_input, &policy, options); + ASSERT_THAT(status, IsOk()) << "Failed to parse JSON: " << param.json_input; + + absl::StatusOr ast_or; + std::vector ptr_messages; + ptr_messages.reserve(policy.destinations_size()); + for (const auto& dest : policy.destinations()) { + ptr_messages.push_back(&dest); + } + + if (ptr_messages.empty()) { + auto parsed_expr_or = google::api::expr::parser::Parse("false"); + ASSERT_THAT(parsed_expr_or, IsOk()); + auto ast_ptr_or = cel::CreateAstFromParsedExpr(*parsed_expr_or); + ASSERT_THAT(ast_ptr_or, IsOk()); + ast_or = std::move(**ast_ptr_or); + } else if (ptr_messages.size() == 1) { + ast_or = ProtoToPredicateAst("dest", *ptr_messages[0]); + } else { + ast_or = ProtoToPredicateAst("dest", absl::MakeSpan(ptr_messages)); + } + + ASSERT_THAT(ast_or, IsOk()); + cel::Ast ast = std::move(*ast_or); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, google::api::expr::Unparse(parsed_expr)); + + EXPECT_EQ(unparsed, param.expected_unparsed); +} + +INSTANTIATE_TEST_SUITE_P( + PolicyJsonSubCases, PolicyJsonTest, + testing::Values( + PolicyTestCase{ + .name = "SimpleMatch", + .json_input = + R"({ "destinations": [ { "agent": { "id": "agent-007" } } ] })", + .expected_unparsed = "dest.agent.name == \"agent-007\"", + }, + PolicyTestCase{ + .name = "MultipleFields", + .json_input = + R"({ "destinations": [ { + "tool": { + "name": "admin_tool", + "annotations": { + "read_only_hint": false + } + } + } + ] })", + .expected_unparsed = + "dest.tool.name == \"admin_tool\" && " + "dest.tool.annotations.read_only_hint == false", + }, + PolicyTestCase{ + .name = "RepeatedMessages", + .json_input = + R"({ "destinations": [ + { "agent": { "id": "worker-1" } }, + { "agent": { "id": "worker-2" } }, + ] })", + .expected_unparsed = "dest.agent.name == \"worker-1\" || " + "dest.agent.name == \"worker-2\"", + }, + PolicyTestCase{ + .name = "RepeatedPrimitiveArraySingleElement", + .json_input = + R"({ "destinations": [ { + "tool": { + "role_members": { + "admin": { + "principals": ["alice"] + } + } + } + } ] })", + .expected_unparsed = + "\"admin\" in dest.tool.role_members && " + "\"alice\" in dest.tool.role_members[\"admin\"].principals", + }, + PolicyTestCase{ + .name = "RepeatedArrayEmpty", + .json_input = R"({ "destinations": [ { "tool": { } } ] })", + .expected_unparsed = "true", + }, + PolicyTestCase{ + .name = "MapEquality", + .json_input = + R"({ "destinations": [ + { "tool": { + "name": "shell", + "labels": { + "cluster": "us-central1", + "project": "dev" + } + } + } ] })", + .expected_unparsed = + "dest.tool.name == \"shell\" && \"cluster\" in " + "dest.tool.labels && dest.tool.labels[\"cluster\"] == " + "\"us-central1\" && \"project\" in dest.tool.labels && " + "dest.tool.labels[\"project\"] == \"dev\"", + }, + PolicyTestCase{ + .name = "NestedMapEquality", + .json_input = + R"({ "destinations": [ + { "tool": { + "role_members": { + "admin": { + "all_users": true + } + } + } } + ] })", + .expected_unparsed = + "\"admin\" in dest.tool.role_members && " + "dest.tool.role_members[\"admin\"].all_users == true", + }, + PolicyTestCase{ + .name = "EmptyPolicy", + .json_input = "{}", + .expected_unparsed = "false", + })); + +} // namespace +} // namespace cel::tools diff --git a/tools/testdata/BUILD b/tools/testdata/BUILD index 493f0ff2f..c88c9c478 100644 --- a/tools/testdata/BUILD +++ b/tools/testdata/BUILD @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -load( - "@com_github_google_flatbuffers//:build_defs.bzl", - "flatbuffer_library_public", -) +load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_library_public") +load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@rules_cc//cc:cc_library.bzl", "cc_library") licenses(["notice"]) @@ -46,3 +45,15 @@ cc_library( linkstatic = True, deps = ["@com_github_google_flatbuffers//:runtime_cc"], ) + +proto_library( + name = "test_policy_proto", + srcs = ["test_policy.proto"], + visibility = ["//tools:__subpackages__"], +) + +cc_proto_library( + name = "test_policy_cc_proto", + visibility = ["//tools:__subpackages__"], + deps = [":test_policy_proto"], +) diff --git a/tools/testdata/test_policy.proto b/tools/testdata/test_policy.proto new file mode 100644 index 000000000..b5d424c04 --- /dev/null +++ b/tools/testdata/test_policy.proto @@ -0,0 +1,73 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test schema representing client-configured policies. +// It is used by the `proto_to_predicate` tool to translate Protobuf policies +// into CEL predicates. +edition = "2023"; + +package cel.cpp.tools; + +option cc_enable_arenas = true; + +// Represents the targeted client agent. +message Agent { + string name = 1 [json_name = "id"]; +} + +// Specifies additional metadata tool annotations. +message ToolAnnotations { + bool read_only_hint = 1; +} + +// Represents a mapped nested message entry value inside map fields. +message Members { + repeated string principals = 1; + + repeated string regions = 2; + + bool all_users = 3; + + bool all_authenticated_users = 4; +} + +// Represents a metadata tool block. +message Tool { + // The name of the tool. + string name = 1; + + // Additional metadata annotations for the tool. + ToolAnnotations annotations = 2; + + // A string-to-string map, transpiled as conjoined existence and equality + // checks. + map labels = 3; + + // A map with string keys representing roles and Member instances as values. + map role_members = 4; +} + +// Represents a policy mapping destination block. +message Target { + oneof kind { + Agent agent = 1; + Tool tool = 2; + } +} + +// Represents the top-level policy containing multiple alternate destination +// rules. +message Policy { + repeated Target destinations = 1; +} From 9597d49ef616726bdc8b2a553fd89fd46ad8b493 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 15 Jun 2026 11:43:00 -0700 Subject: [PATCH 115/143] Add support for context types env.yaml Add support for declaring a protobuf context message whose top level fields are declared as variables in the CEL environment. PiperOrigin-RevId: 932580314 --- env/BUILD | 1 + env/config.h | 6 ++++++ env/env.cc | 5 +++++ env/env_test.cc | 19 +++++++++++++++++++ env/env_yaml.cc | 30 ++++++++++++++++++++++++++++++ env/env_yaml_test.cc | 41 +++++++++++++++++++++++++++++++++++++++++ 6 files changed, 102 insertions(+) diff --git a/env/BUILD b/env/BUILD index 3035e11ac..bd82e8ec6 100644 --- a/env/BUILD +++ b/env/BUILD @@ -130,6 +130,7 @@ cc_library( "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/env/config.h b/env/config.h index e427832ff..68e4a1dd9 100644 --- a/env/config.h +++ b/env/config.h @@ -32,6 +32,11 @@ class Config { void SetName(std::string name) { name_ = std::move(name); } std::string GetName() const { return name_; } + void SetContextType(std::string context_type) { + context_type_ = std::move(context_type); + } + std::string GetContextType() const { return context_type_; } + struct ContainerConfig { std::string name; std::vector abbreviations; @@ -150,6 +155,7 @@ class Config { private: std::string name_; + std::string context_type_; ContainerConfig container_config_; std::vector extension_configs_; StandardLibraryConfig standard_library_config_; diff --git a/env/env.cc b/env/env.cc index 6cd3a3cdc..22d24295e 100644 --- a/env/env.cc +++ b/env/env.cc @@ -138,6 +138,11 @@ absl::StatusOr> Env::NewCompilerBuilder() { for (const auto& abbr : config_.GetContainerConfig().abbreviations) { CEL_RETURN_IF_ERROR(container.AddAbbreviation(abbr)); } + + if (!config_.GetContextType().empty()) { + CEL_RETURN_IF_ERROR( + checker_builder.AddContextDeclaration(config_.GetContextType())); + } for (const auto& alias : config_.GetContainerConfig().aliases) { CEL_RETURN_IF_ERROR(container.AddAlias(alias.alias, alias.qualified_name)); } diff --git a/env/env_test.cc b/env/env_test.cc index b599aa569..fda87dfab 100644 --- a/env/env_test.cc +++ b/env/env_test.cc @@ -344,6 +344,25 @@ TEST(ContainerConfigTest, ContainerConfigWithAliases) { EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); } +TEST(ContextVariableConfigTest, Basic) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContextType("cel.expr.conformance.proto3.TestAllTypes"); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + + // Top-level fields of TestAllTypes like "single_int32" should resolve + // successfully. + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("single_int32 > 10")); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + ASSERT_OK_AND_ASSIGN(auto result_invalid, + compiler->Compile("non_existent_field > 10")); + EXPECT_THAT(result_invalid.GetIssues(), Not(IsEmpty())); +} + struct VariableConfigWithValueTestCase { Config::VariableConfig variable_config; std::string validate_type_expr; diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 1bbfe6b36..e7b8a7885 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -26,6 +26,7 @@ #include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/escaping.h" @@ -1245,6 +1246,34 @@ void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out, } out << YAML::EndSeq; } + +absl::Status ParseContextVariableConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node context_variable = root["context_variable"]; + if (!context_variable.IsDefined()) { + return absl::OkStatus(); + } + if (!context_variable.IsMap()) { + return YamlError(yaml, context_variable, + "Node 'context_variable' is not a map"); + } + + const YAML::Node type_name = context_variable["type_name"]; + const YAML::Node type = context_variable["type"]; + const YAML::Node* type_node = nullptr; + if (type.IsDefined() && type.IsScalar()) { + type_node = &type; + } else if (type_name.IsDefined() && type_name.IsScalar()) { + type_node = &type_name; + } else { + return YamlError(yaml, context_variable, + "Node 'context_variable' does not have a valid type"); + } + ABSL_DCHECK(type_node != nullptr); + config.SetContextType(GetString(yaml, *type_node)); + return absl::OkStatus(); +} + } // namespace absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { @@ -1263,6 +1292,7 @@ absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { CEL_RETURN_IF_ERROR(ParseContainerConfig(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseExtensionConfigs(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseStandardLibraryConfig(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseContextVariableConfig(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseVariableConfigs(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseFunctionConfigs(config, yaml, root)); return config; diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index a60048617..9c5b3f04f 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -216,6 +216,47 @@ TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { EXPECT_THAT(type_info.params[1].params, IsEmpty()); } +TEST(EnvYamlTest, ParseContextVariableConfig) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + context_variable: + type_name: "cel.expr.conformance.proto3.TestAllTypes" + )yaml")); + + EXPECT_EQ(config.GetContextType(), + "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(EnvYamlTest, ParseContextVariableConfigAlternativeSyntax) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + context_variable: + type: "cel.expr.conformance.proto3.TestAllTypes" + )yaml")); + + EXPECT_EQ(config.GetContextType(), + "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(EnvYamlTest, ParseContextVariableMalformedContextVariable) { + EXPECT_THAT(EnvConfigFromYaml(R"yaml( + context_variable: 123 + + )yaml"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Node 'context_variable' is not a map"))); +} + +TEST(EnvYamlTest, ParseContextVariableMalformedContextVariable2) { + EXPECT_THAT( + EnvConfigFromYaml(R"yaml( + context_variable: + type: + foo: bar + )yaml"), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("Node 'context_variable' does not have a valid type"))); +} + TEST(EnvYamlTest, ParseVariableConfigWithTypeParamsLegacySyntax) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( variables: From 01e6b6379d8100f6f6dcb6e41627037d1b2a7112 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 16 Jun 2026 06:44:51 -0700 Subject: [PATCH 116/143] Export policy compiler. Co-authored-by: Dmitri Plotnikov PiperOrigin-RevId: 933062225 --- MODULE.bazel | 14 + conformance/policy/BUILD | 78 ++ .../policy/policy_conformance_test.bzl | 46 + conformance/policy/policy_conformance_test.cc | 659 ++++++++++ internal/BUILD | 1 + internal/runfiles.cc | 15 +- internal/runfiles.h | 6 + policy/BUILD | 239 ++++ policy/cel_policy.cc | 273 +++++ policy/cel_policy.h | 320 +++++ policy/cel_policy_parse_context.cc | 49 + policy/cel_policy_parse_context.h | 65 + policy/cel_policy_parse_result.cc | 91 ++ policy/cel_policy_parse_result.h | 105 ++ policy/cel_policy_parser.h | 40 + policy/cel_policy_test.cc | 220 ++++ policy/cel_policy_validation_result.cc | 32 + policy/cel_policy_validation_result.h | 84 ++ policy/compiler.cc | 1058 +++++++++++++++++ policy/compiler.h | 50 + policy/compiler_test.cc | 946 +++++++++++++++ policy/internal/BUILD | 68 ++ policy/internal/issue_reporter.cc | 45 + policy/internal/issue_reporter.h | 57 + policy/internal/optimizer_expr_factory.cc | 373 ++++++ policy/internal/optimizer_expr_factory.h | 419 +++++++ .../internal/optimizer_expr_factory_test.cc | 570 +++++++++ policy/test_custom_yaml_policy_parser.cc | 188 +++ policy/test_util.cc | 221 ++++ policy/test_util.h | 33 + policy/testdata/BUILD | 19 + policy/testdata/cel_policy.yaml | 42 + policy/testdata/cel_policy_parser.baseline | 89 ++ policy/testdata/custom_policy_format.yaml | 29 + .../custom_policy_format_parser.baseline | 75 ++ .../custom_policy_format_with_errors.yaml | 33 + ..._policy_format_with_errors_parser.baseline | 16 + policy/testdata/nested_rule.yaml | 37 + policy/testdata/nested_rule_parser.baseline | 84 ++ policy/yaml_policy_parser.cc | 411 +++++++ policy/yaml_policy_parser.h | 135 +++ policy/yaml_policy_parser_test.cc | 305 +++++ 42 files changed, 7639 insertions(+), 1 deletion(-) create mode 100644 conformance/policy/BUILD create mode 100644 conformance/policy/policy_conformance_test.bzl create mode 100644 conformance/policy/policy_conformance_test.cc create mode 100644 policy/BUILD create mode 100644 policy/cel_policy.cc create mode 100644 policy/cel_policy.h create mode 100644 policy/cel_policy_parse_context.cc create mode 100644 policy/cel_policy_parse_context.h create mode 100644 policy/cel_policy_parse_result.cc create mode 100644 policy/cel_policy_parse_result.h create mode 100644 policy/cel_policy_parser.h create mode 100644 policy/cel_policy_test.cc create mode 100644 policy/cel_policy_validation_result.cc create mode 100644 policy/cel_policy_validation_result.h create mode 100644 policy/compiler.cc create mode 100644 policy/compiler.h create mode 100644 policy/compiler_test.cc create mode 100644 policy/internal/BUILD create mode 100644 policy/internal/issue_reporter.cc create mode 100644 policy/internal/issue_reporter.h create mode 100644 policy/internal/optimizer_expr_factory.cc create mode 100644 policy/internal/optimizer_expr_factory.h create mode 100644 policy/internal/optimizer_expr_factory_test.cc create mode 100644 policy/test_custom_yaml_policy_parser.cc create mode 100644 policy/test_util.cc create mode 100644 policy/test_util.h create mode 100644 policy/testdata/BUILD create mode 100644 policy/testdata/cel_policy.yaml create mode 100644 policy/testdata/cel_policy_parser.baseline create mode 100644 policy/testdata/custom_policy_format.yaml create mode 100644 policy/testdata/custom_policy_format_parser.baseline create mode 100644 policy/testdata/custom_policy_format_with_errors.yaml create mode 100644 policy/testdata/custom_policy_format_with_errors_parser.baseline create mode 100644 policy/testdata/nested_rule.yaml create mode 100644 policy/testdata/nested_rule_parser.baseline create mode 100644 policy/yaml_policy_parser.cc create mode 100644 policy/yaml_policy_parser.h create mode 100644 policy/yaml_policy_parser_test.cc diff --git a/MODULE.bazel b/MODULE.bazel index 43d0485d2..187d68164 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -31,6 +31,7 @@ bazel_dep( name = "rules_python", version = "1.6.3", ) +bazel_dep(name = "rules_license", version = "1.0.0") bazel_dep( name = "protobuf", version = "34.1", @@ -96,3 +97,16 @@ bazel_dep( name = "yaml-cpp", version = "0.9.0", ) + +_CEL_POLICY_TAG = "ebfb2361f47080af643c14cf4da4c2b551a68740" + +_CEL_POLICY_SHA = "ea69e9c6b7bd5bc37d358148aebd2fcca38bc7c45a23feb635de72338e0327c1" + +http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "cel_policy", + sha256 = _CEL_POLICY_SHA, + strip_prefix = "cel-policy-%s" % _CEL_POLICY_TAG, + url = "https://github.com/cel-expr/cel-policy/archive/%s.tar.gz" % _CEL_POLICY_TAG, +) diff --git a/conformance/policy/BUILD b/conformance/policy/BUILD new file mode 100644 index 000000000..29210e02d --- /dev/null +++ b/conformance/policy/BUILD @@ -0,0 +1,78 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load( + "//conformance/policy:policy_conformance_test.bzl", + "cel_policy_conformance_test", +) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "policy_conformance_test_lib", + testonly = True, + srcs = ["policy_conformance_test.cc"], + deps = [ + "//common:ast", + "//common:source", + "//common:value", + "//common/internal:value_conversion", + "//compiler", + "//env", + "//env:config", + "//env:env_runtime", + "//env:env_std_extensions", + "//env:env_yaml", + "//env:runtime_std_extensions", + "//extensions/protobuf:bind_proto_to_activation", + "//extensions/protobuf:enum_adapter", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing_descriptor_pool", + "//internal:testing_no_main", + "//policy:cel_policy", + "//policy:cel_policy_parser", + "//policy:cel_policy_validation_result", + "//policy:compiler", + "//policy:test_util", + "//policy:yaml_policy_parser", + "//runtime", + "//runtime:activation", + "//runtime:function_adapter", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:protobuf", + ], + alwayslink = True, +) + +cel_policy_conformance_test( + name = "policy_conformance_test", + example = "@cel_policy//conformance:testdata/nested_rule/policy.yaml", + skip_tests = [ + # TODO(b/506179116): Fix these. + # Need to add k8s custom yaml parser and mock runtime. + "k8s", + ], + test_files = [ + "@cel_policy//conformance:testdata", + ], +) diff --git a/conformance/policy/policy_conformance_test.bzl b/conformance/policy/policy_conformance_test.bzl new file mode 100644 index 000000000..0b4d1a4c6 --- /dev/null +++ b/conformance/policy/policy_conformance_test.bzl @@ -0,0 +1,46 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains build rules for generating policy conformance test targets. +""" + +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +def cel_policy_conformance_test(name, test_files, example, skip_tests = [], **kwargs): + """Generates a policy conformance test target. + + Args: + name: Name of the test target. + test_files: List of targets or files representing the test data. + example: A specific example file from test_files used for runfiles resolution. + skip_tests: List of test cases to skip. + testdata_dir: Path to testdata directory under runfiles. + **kwargs: Additional arguments passed to the underlying cc_test. + """ + args = ["--gunit_fail_if_no_test_linked"] + args.append("--testdata_example='$(rlocationpath {})'".format(example)) + + if skip_tests: + args.append("--skip_tests=" + ",".join(skip_tests)) + + cc_test( + name = name, + data = test_files + [example], + deps = [ + "//conformance/policy:policy_conformance_test_lib", + ], + args = args, + **kwargs + ) diff --git a/conformance/policy/policy_conformance_test.cc b/conformance/policy/policy_conformance_test.cc new file mode 100644 index 000000000..0d68f8abf --- /dev/null +++ b/conformance/policy/policy_conformance_test.cc @@ -0,0 +1,659 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +// NOLINTNEXTLINE(build/c++17) for OSS compatibility +#include + +#include "cel/expr/eval.pb.h" +#include "absl/flags/flag.h" +#include "absl/log/absl_check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/internal/value_conversion.h" +#include "common/source.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_runtime.h" +#include "env/env_std_extensions.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "extensions/protobuf/bind_proto_to_activation.h" +#include "extensions/protobuf/enum_adapter.h" +#include "internal/runfiles.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/compiler.h" +#include "policy/test_util.h" +#include "policy/yaml_policy_parser.h" +#include "runtime/activation.h" +#include "runtime/function_adapter.h" +#include "runtime/runtime.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +// Use a specific file to handle bazel runfiles resolution correctly. We find +// parent directory named 'testdata' to use as the root of the test cases. +ABSL_FLAG(std::string, testdata_example, "", + "Path to a specific example file."); +ABSL_FLAG(std::vector, skip_tests, {}, + "Comma-separated list of tests to skip."); + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::cel::expr::conformance::test::TestSuite; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::HasSubstr; + +// Implementations for extension functions referenced in conformance tests. +cel::Value LocationCode(const cel::StringValue& ip, + const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* factory, google::protobuf::Arena* arena) { + std::string ip_str = ip.ToString(); + if (ip_str == "10.0.0.1") return cel::StringValue(arena, "us"); + if (ip_str == "10.0.0.2") return cel::StringValue(arena, "de"); + return cel::StringValue(arena, "ir"); +} + +// TODO(uncreated-issue/92): This should be migrated to use the testrunner utility +// after adding support for reading the yaml specification for envs/tests. +class InputEvaluator { + public: + static absl::StatusOr> Create( + const std::shared_ptr& pool) { + cel::Env env; + env.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.mutable_runtime_options().enable_qualified_type_identifiers = + true; + + // Enable default extensions (optional, bindings) + cel::Config config; + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "optional", cel::Config::ExtensionConfig::kLatest)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "bindings", cel::Config::ExtensionConfig::kLatest)); + env.SetConfig(config); + env_runtime.SetConfig(config); + + auto compiler_builder_or = env.NewCompilerBuilder(); + CEL_ASSIGN_OR_RETURN(auto compiler_builder, std::move(compiler_builder_or)); + compiler_builder->GetParserBuilder().GetOptions().enable_optional_syntax = + true; + CEL_ASSIGN_OR_RETURN(auto compiler, compiler_builder->Build()); + + auto runtime_builder_or = env_runtime.CreateRuntimeBuilder(); + CEL_ASSIGN_OR_RETURN(auto runtime_builder, std::move(runtime_builder_or)); + + // Register conformance enums + for (const auto& enum_name : + {"cel.expr.conformance.proto2.GlobalEnum", + "cel.expr.conformance.proto3.GlobalEnum", + "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"}) { + auto* enum_desc = pool->FindEnumTypeByName(enum_name); + if (enum_desc != nullptr) { + CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtobufEnum( + runtime_builder.type_registry(), enum_desc)); + } + } + + CEL_ASSIGN_OR_RETURN(auto runtime, std::move(runtime_builder).Build()); + + return absl::WrapUnique( + new InputEvaluator(std::move(compiler), std::move(runtime))); + } + + absl::StatusOr Evaluate( + absl::string_view expr_str, google::protobuf::Arena* arena, + google::protobuf::MessageFactory* message_factory) const { + CEL_ASSIGN_OR_RETURN(auto validation_result, compiler_->Compile(expr_str)); + if (!validation_result.IsValid()) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to compile input expr: ", expr_str)); + } + CEL_ASSIGN_OR_RETURN(auto ast, validation_result.ReleaseAst()); + CEL_ASSIGN_OR_RETURN( + auto program, + runtime_->CreateProgram(std::make_unique(std::move(*ast)))); + cel::Activation activation; + EvaluateOptions options; + options.message_factory = message_factory; + return program->Evaluate(arena, activation, options); + } + + private: + InputEvaluator(std::unique_ptr compiler, + std::unique_ptr runtime) + : compiler_(std::move(compiler)), runtime_(std::move(runtime)) {} + + std::unique_ptr compiler_; + std::unique_ptr runtime_; +}; + +absl::StatusOr EvaluateInputValue( + const cel::expr::conformance::test::InputValue& input_val, + const InputEvaluator& evaluator, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, google::protobuf::Arena* arena) { + if (input_val.has_expr()) { + return evaluator.Evaluate(input_val.expr(), arena, message_factory); + } + if (input_val.has_value()) { + return cel::test::FromExprValue(input_val.value(), descriptor_pool, + message_factory, arena); + } + return absl::InvalidArgumentError("Empty InputValue"); +} + +class CelValueMatcherImpl + : public testing::MatcherInterface { + public: + CelValueMatcherImpl(cel::Value expected_val, + const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* message_factory, + google::protobuf::Arena* arena) + : expected_val_(std::move(expected_val)), + pool_(pool), + message_factory_(message_factory), + arena_(arena) {} + + bool MatchAndExplain(const cel::Value& actual_val, + testing::MatchResultListener* listener) const override { + cel::Value actual = actual_val; + if (actual.IsOptional() && !expected_val_.IsOptional()) { + auto opt_val = actual.AsOptional(); + if (opt_val->HasValue()) { + actual = opt_val->Value(); + } + } + cel::Value eq_result; + auto eq_status = actual.Equal(expected_val_, pool_, message_factory_, + arena_, &eq_result); + if (!eq_status.ok()) { + *listener << "equality check failed with status: " << eq_status; + return false; + } + if (!eq_result.IsTrue()) { + *listener << "expected: " << expected_val_.DebugString() + << "\nactual: " << actual.DebugString(); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os) const override { + *os << "is equal to " << expected_val_.DebugString(); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "is not equal to " << expected_val_.DebugString(); + } + + private: + cel::Value expected_val_; + const google::protobuf::DescriptorPool* pool_; + google::protobuf::MessageFactory* message_factory_; + google::protobuf::Arena* arena_; +}; + +absl::StatusOr> MakeExpectedValueMatcher( + const cel::expr::conformance::test::TestOutput& output, + const InputEvaluator& input_evaluator, const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* message_factory, google::protobuf::Arena* arena) { + cel::Value expected_val; + if (output.has_result_expr()) { + CEL_ASSIGN_OR_RETURN( + expected_val, + input_evaluator.Evaluate(output.result_expr(), arena, message_factory)); + } else if (output.has_result_value()) { + CEL_ASSIGN_OR_RETURN(expected_val, + cel::test::FromExprValue(output.result_value(), pool, + message_factory, arena)); + } else { + return absl::InvalidArgumentError("Unsupported output kind"); + } + return testing::Matcher( + new CelValueMatcherImpl(expected_val, pool, message_factory, arena)); +} + +bool ShouldRunTest(absl::string_view test_name, + const std::vector& skip_tests) { + for (const std::string& skip : skip_tests) { + if (absl::StartsWith(test_name, skip)) { + return false; + } + } + return true; +} + +absl::Status PopulateActivation( + const cel::expr::conformance::test::TestCase& test, + const InputEvaluator& input_evaluator, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, + absl::string_view context_msg_type_name, google::protobuf::Arena* arena, + Activation& activation) { + if (!test.has_input_context()) { + for (const auto& [var_name, input_val] : test.input()) { + CEL_ASSIGN_OR_RETURN( + auto val, + EvaluateInputValue(input_val, input_evaluator, descriptor_pool, + message_factory, arena)); + activation.InsertOrAssignValue(var_name, std::move(val)); + } + return absl::OkStatus(); + } + + const auto& input_context = test.input_context(); + const google::protobuf::Message* context_message = nullptr; + + if (input_context.has_context_message()) { + const google::protobuf::Any& any_msg = input_context.context_message(); + const google::protobuf::Descriptor* msg_descriptor = + descriptor_pool->FindMessageTypeByName(context_msg_type_name); + if (msg_descriptor == nullptr) { + return absl::NotFoundError(absl::StrCat( + "Failed to find message descriptor for: ", context_msg_type_name)); + } + const google::protobuf::Message* prototype = + message_factory->GetPrototype(msg_descriptor); + if (prototype == nullptr) { + return absl::NotFoundError( + absl::StrCat("Failed to get prototype for: ", context_msg_type_name)); + } + auto* buf = prototype->New(arena); + if (!any_msg.UnpackTo(buf)) { + return absl::InvalidArgumentError(absl::StrCat( + "Failed to unpack context message to ", context_msg_type_name)); + } + context_message = buf; + } else if (input_context.has_context_expr() && + !context_msg_type_name.empty()) { + CEL_ASSIGN_OR_RETURN(cel::Value evaluated_val, + input_evaluator.Evaluate(input_context.context_expr(), + arena, message_factory)); + + if (!evaluated_val.IsParsedMessage()) { + return absl::InvalidArgumentError( + absl::StrCat("Context expression did not evaluate to a message: ", + input_context.context_expr())); + } + if (evaluated_val.GetParsedMessage().GetDescriptor()->full_name() != + context_msg_type_name) { + return absl::InvalidArgumentError(absl::StrCat( + "Context expression evaluated to a message of type ", + evaluated_val.GetParsedMessage().GetDescriptor()->full_name(), + " which does not match the expected type ", context_msg_type_name)); + } + context_message = static_cast( + evaluated_val.GetParsedMessage().operator->()); + } + if (context_message == nullptr) { + return absl::InvalidArgumentError( + "Failed to resolve context message for test case"); + } + + return cel::extensions::BindProtoToActivation( + *context_message, + cel::extensions::BindProtoUnsetFieldBehavior::kBindDefaultValue, + descriptor_pool, message_factory, arena, &activation); +} + +class PolicyTestSuiteRunner { + public: + PolicyTestSuiteRunner(std::string suite_name, + std::unique_ptr compiler, + std::unique_ptr runtime, + std::shared_ptr policy_source, + CelPolicyValidationResult compile_result, + std::shared_ptr pool, + std::shared_ptr message_factory, + std::shared_ptr input_evaluator, + std::string context_msg_type_name, + bool expect_compile_fail = false) + : suite_name_(std::move(suite_name)), + compiler_(std::move(compiler)), + runtime_(std::move(runtime)), + policy_source_(std::move(policy_source)), + compile_result_(std::move(compile_result)), + pool_(std::move(pool)), + message_factory_(std::move(message_factory)), + input_evaluator_(std::move(input_evaluator)), + context_msg_type_name_(std::move(context_msg_type_name)), + expect_compile_fail_(expect_compile_fail) {} + + void RunTest(const cel::expr::conformance::test::TestCase& test, + absl::string_view full_test_name) { + const auto& output = test.output(); + + if (expect_compile_fail_) { + ASSERT_FALSE(compile_result_.IsValid()) + << "Expected compilation to fail in " << full_test_name; + ASSERT_TRUE(output.has_eval_error()) + << "Expected eval_error to be present in compile error test " + << full_test_name; + std::string err_msg = compile_result_.FormatIssues(); + for (const auto& expected_err : output.eval_error().errors()) { + EXPECT_THAT(err_msg, HasSubstr(expected_err.message())) + << "Did not find expected compile time error"; + } + return; + } + + // Compilation should have succeeded for evaluation tests + ASSERT_TRUE(compile_result_.IsValid()) + << "Compilation has validation errors in " << full_test_name << ": " + << compile_result_.FormatIssues(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime_->CreateProgram(std::make_unique( + *compile_result_.GetAst()))); + + // Parse Inputs and evaluate them + google::protobuf::Arena arena; + Activation activation; + ASSERT_THAT(PopulateActivation(test, *input_evaluator_, pool_.get(), + message_factory_.get(), + context_msg_type_name_, &arena, activation), + IsOk()); + + // Evaluate Policy + auto eval_result_or = program->Evaluate(&arena, activation); + ASSERT_THAT(eval_result_or.status(), IsOk()) + << "Evaluation failed in " << full_test_name; + cel::Value actual_val = *eval_result_or; + + ASSERT_OK_AND_ASSIGN( + auto matcher, + MakeExpectedValueMatcher(output, *input_evaluator_, pool_.get(), + message_factory_.get(), &arena)); + + // Apply matcher to the output of evaluation + EXPECT_THAT(actual_val, matcher) << "Test failed: " << full_test_name; + } + + private: + std::string suite_name_; + std::unique_ptr compiler_; + std::unique_ptr runtime_; + std::shared_ptr policy_source_; + CelPolicyValidationResult compile_result_; + std::shared_ptr pool_; + std::shared_ptr message_factory_; + std::shared_ptr input_evaluator_; + std::string context_msg_type_name_; + bool expect_compile_fail_; +}; + +class CelPolicyTest : public testing::Test { + public: + explicit CelPolicyTest(std::shared_ptr runner, + cel::expr::conformance::test::TestCase test_case, + std::string full_test_name, bool skip) + : runner_(std::move(runner)), + test_case_(std::move(test_case)), + full_test_name_(std::move(full_test_name)), + skip_(skip) {} + + void TestBody() override { + if (skip_) { + GTEST_SKIP() << "Skipping test: " << full_test_name_; + } + EXPECT_NO_FATAL_FAILURE(runner_->RunTest(test_case_, full_test_name_)); + } + + private: + std::shared_ptr runner_; + cel::expr::conformance::test::TestCase test_case_; + std::string full_test_name_; + bool skip_; +}; + + +absl::Status RegisterTestSuite( + const std::filesystem::path& dir_path, const std::string& suite_name, + const std::shared_ptr& input_evaluator, + const std::shared_ptr& pool, + const std::shared_ptr& message_factory, + const std::vector& skip_tests) { + // Check if the entire suite should be skipped (prefix match) + for (const auto& skip : skip_tests) { + if (suite_name == skip || + absl::StartsWith(suite_name, absl::StrCat(skip, "/"))) { + std::cout << "[ SKIPPED SUITE ] " << suite_name << std::endl; + return absl::OkStatus(); + } + } + + std::filesystem::path policy_path = dir_path / "policy.yaml"; + std::filesystem::path tests_path = dir_path / "tests.yaml"; + bool is_yaml = true; + if (!std::filesystem::exists(tests_path)) { + tests_path = dir_path / "tests.textproto"; + is_yaml = false; + } + std::filesystem::path config_path = dir_path / "config.yaml"; + + if (!std::filesystem::exists(policy_path) || + !std::filesystem::exists(tests_path)) { + // Not a valid test suite, assume it's a directory we don't care about. + return absl::OkStatus(); + } + + // Parse Environment Config + cel::Config config; + if (std::filesystem::exists(config_path)) { + std::string config_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(config_path.string(), &config_content)); + CEL_ASSIGN_OR_RETURN(config, cel::EnvConfigFromYaml(config_content)); + } + + // Enable default extensions (optional, bindings) in the config + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "optional", cel::Config::ExtensionConfig::kLatest)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "bindings", cel::Config::ExtensionConfig::kLatest)); + + // Set up compiler & runtime environments + cel::Env env; + env.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env); + env.SetConfig(config); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + env_runtime.mutable_runtime_options().enable_qualified_type_identifiers = + true; + + CEL_ASSIGN_OR_RETURN(auto compiler_builder, env.NewCompilerBuilder()); + compiler_builder->GetParserBuilder().GetOptions().enable_optional_syntax = + true; + + CEL_ASSIGN_OR_RETURN(auto compiler, compiler_builder->Build()); + + CEL_ASSIGN_OR_RETURN(auto runtime_builder, + env_runtime.CreateRuntimeBuilder()); + + // Register conformance enums + for (const auto& enum_name : + {"cel.expr.conformance.proto2.GlobalEnum", + "cel.expr.conformance.proto3.GlobalEnum", + "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"}) { + auto* enum_desc = pool->FindEnumTypeByName(enum_name); + if (enum_desc != nullptr) { + CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtobufEnum( + runtime_builder.type_registry(), enum_desc)); + } + } + + // Register locationCode in runtime + CEL_RETURN_IF_ERROR( + (cel::UnaryFunctionAdapter:: + RegisterGlobalOverload("locationCode", LocationCode, + runtime_builder.function_registry()))); + + CEL_ASSIGN_OR_RETURN(auto runtime, std::move(runtime_builder).Build()); + + // Parse Policy + std::string policy_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(policy_path.string(), &policy_content)); + CEL_ASSIGN_OR_RETURN(auto source, + cel::NewSource(policy_content, "policy.yaml")); + auto policy_source = std::make_shared(std::move(source)); + CEL_ASSIGN_OR_RETURN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + if (!parse_result.IsValid()) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse policy.yaml in ", suite_name, + "\nIssues:\n", parse_result.FormattedIssues())); + } + const CelPolicy* policy = parse_result.GetPolicy(); + + // Compile Policy (unexpected non-ok status represents a bug) + CEL_ASSIGN_OR_RETURN(CelPolicyValidationResult compile_result, + CompilePolicy(*compiler, *policy)); + + std::string tests_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(tests_path.string(), &tests_content)); + TestSuite test_suite; + if (is_yaml) { + CEL_ASSIGN_OR_RETURN(test_suite, + cel::test::ParsePolicyTestSuiteYaml(tests_content)); + } else { + if (!google::protobuf::TextFormat::ParseFromString(tests_content, &test_suite)) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse text proto in ", tests_path.string())); + } + } + + auto runner = std::make_shared( + suite_name, std::move(compiler), std::move(runtime), + std::move(policy_source), std::move(compile_result), pool, + message_factory, input_evaluator, config.GetContextType(), + /*expect_compile_fail=*/absl::StrContains(suite_name, "compile_errors")); + + for (const auto& section : test_suite.sections()) { + std::string section_name = section.name(); + for (const auto& test : section.tests()) { + std::string test_name = test.name(); + std::string full_test_name = + absl::StrCat(suite_name, "/", section_name, "/", test_name); + + bool skip = !ShouldRunTest(full_test_name, skip_tests); + + testing::RegisterTest( + suite_name.c_str(), + absl::StrCat(section_name, "/", test_name).c_str(), nullptr, + test_name.c_str(), __FILE__, __LINE__, + [runner, test, full_test_name, skip]() -> CelPolicyTest* { + return new CelPolicyTest(runner, test, full_test_name, skip); + }); + } + } + return absl::OkStatus(); +} + +void RegisterAllTests() { + // cel::google3-end + std::string testdata_example_flag = absl::GetFlag(FLAGS_testdata_example); + std::vector skip_tests = absl::GetFlag(FLAGS_skip_tests); + + std::string abs_testdata_example = + cel::internal::ResolveRunfilesPath(testdata_example_flag); + ABSL_CHECK(!abs_testdata_example.empty()) + << "Could not find testdata directory: " << testdata_example_flag; + + std::shared_ptr pool = + GetSharedTestingDescriptorPool(); + auto message_factory = + std::make_shared(pool.get()); + message_factory->SetDelegateToGeneratedFactory(true); + auto evaluator_or = InputEvaluator::Create(pool); + ABSL_CHECK_OK(evaluator_or.status()) << "Failed to create input evaluator"; + std::shared_ptr evaluator = std::move(evaluator_or.value()); + + std::filesystem::path testdata_path(abs_testdata_example); + ABSL_CHECK(std::filesystem::exists(testdata_path)) + << "Testdata path does not exist: " << testdata_path; + // walk up to find 'testdata' parent. A work around to portably + // get the expected directory from bazel. + while (!absl::EndsWith(testdata_path.string(), "testdata")) { + testdata_path = testdata_path.parent_path(); + ABSL_CHECK(testdata_path.string().size() > sizeof("testdata")) + << "could not resolve testdata directory"; + } + + for (const auto& entry : + std::filesystem::recursive_directory_iterator(testdata_path)) { + if (!entry.is_directory()) { + continue; + } + std::filesystem::path dir_path = entry.path(); + // Check if this directory has policy.yaml and tests.yaml (or + // tests.textproto) + if (std::filesystem::exists(dir_path / "policy.yaml") && + (std::filesystem::exists(dir_path / "tests.yaml") || + std::filesystem::exists(dir_path / "tests.textproto"))) { + std::string suite_name = absl::StrReplaceAll( + std::filesystem::relative(dir_path, testdata_path).string(), + {{"\\", "/"}}); + + ABSL_CHECK_OK(RegisterTestSuite(dir_path, suite_name, evaluator, pool, + message_factory, skip_tests)); + } + } +} + +} // namespace +} // namespace cel + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + cel::RegisterAllTests(); + return RUN_ALL_TESTS(); +} diff --git a/internal/BUILD b/internal/BUILD index 0ac5c4e46..6d0efab72 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -92,6 +92,7 @@ cc_library( hdrs = ["runfiles.h"], deps = [ "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@rules_cc//cc/runfiles", ], diff --git a/internal/runfiles.cc b/internal/runfiles.cc index 259e2e7ca..bffbfa9d1 100644 --- a/internal/runfiles.cc +++ b/internal/runfiles.cc @@ -14,11 +14,14 @@ #include "internal/runfiles.h" +#include +#include #include #include "rules_cc/cc/runfiles/runfiles.h" - #include "absl/log/absl_check.h" + +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -37,4 +40,14 @@ std::string ResolveRunfilesPath(absl::string_view path) { return runfiles->Rlocation(std::string(path)); } +absl::Status GetFileContents(absl::string_view path, std::string* out) { + std::ifstream file{std::string(path)}; + if (!file.is_open()) { + return absl::NotFoundError(absl::StrCat("Failed to open file: ", path)); + } + out->append((std::istreambuf_iterator(file)), + std::istreambuf_iterator()); + return absl::OkStatus(); +} + } // namespace cel::internal diff --git a/internal/runfiles.h b/internal/runfiles.h index 643c677b4..11fdcf337 100644 --- a/internal/runfiles.h +++ b/internal/runfiles.h @@ -11,12 +11,15 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// Utilities for working with bazel runfiles. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" namespace cel::internal { @@ -25,6 +28,9 @@ namespace cel::internal { // Intended for resolving test cases from cel-spec and cel-policy. std::string ResolveRunfilesPath(absl::string_view path); +// Read contents of a file at a resolved path to a string. +absl::Status GetFileContents(absl::string_view path, std::string* out); + } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ diff --git a/policy/BUILD b/policy/BUILD new file mode 100644 index 000000000..19195be2b --- /dev/null +++ b/policy/BUILD @@ -0,0 +1,239 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "cel_policy", + srcs = [ + "cel_policy.cc", + ], + hdrs = [ + "cel_policy.h", + ], + deps = [ + "//common:source", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "cel_policy_test", + srcs = ["cel_policy_test.cc"], + deps = [ + ":cel_policy", + "//common:source", + "//internal:testing", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "cel_policy_parser", + srcs = [ + "cel_policy_parse_context.cc", + "cel_policy_parse_result.cc", + ], + hdrs = [ + "cel_policy_parse_context.h", + "cel_policy_parse_result.h", + "cel_policy_parser.h", + ], + deps = [ + ":cel_policy", + "//common:source", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "yaml_policy_parser", + srcs = [ + "yaml_policy_parser.cc", + ], + hdrs = ["yaml_policy_parser.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + ":cel_policy", + ":cel_policy_parser", + "//common:source", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@yaml-cpp", + ], +) + +cc_library( + name = "cel_policy_validation_result", + srcs = [ + "cel_policy_validation_result.cc", + ], + hdrs = [ + "cel_policy_validation_result.h", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + "//common:ast", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "compiler", + srcs = ["compiler.cc"], + hdrs = ["compiler.h"], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":cel_policy_validation_result", + "//checker:type_check_issue", + "//checker:validation_result", + "//common:ast", + "//common:ast_rewrite", + "//common:constant", + "//common:container", + "//common:decl", + "//common:expr", + "//common:format_type_name", + "//common:navigable_ast", + "//common:source", + "//common:type", + "//common:type_kind", + "//compiler", + "//internal:status_macros", + "//policy/internal:issue_reporter", + "//policy/internal:optimizer_expr_factory", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "yaml_policy_parser_test", + srcs = [ + "test_custom_yaml_policy_parser.cc", + "yaml_policy_parser_test.cc", + ], + data = [ + "//policy/testdata:policy_testdata", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":yaml_policy_parser", + "//common:source", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@yaml-cpp", + ], +) + +cc_test( + name = "compiler_test", + srcs = ["compiler_test.cc"], + data = [ + "//policy/testdata:policy_testdata", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":cel_policy_validation_result", + ":compiler", + ":yaml_policy_parser", + "//common:ast", + "//common:decl", + "//common:navigable_ast", + "//common:source", + "//common:type", + "//common:value", + "//common:value_testing", + "//compiler", + "//compiler:compiler_factory", + "//compiler:optional", + "//compiler:standard_library", + "//extensions:bindings_ext", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "//runtime:optional_types", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "test_util", + testonly = True, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:struct_cc_proto", + "@yaml-cpp", + ], +) diff --git a/policy/cel_policy.cc b/policy/cel_policy.cc new file mode 100644 index 000000000..c2d97edeb --- /dev/null +++ b/policy/cel_policy.cc @@ -0,0 +1,273 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "common/source.h" + +namespace cel { + +namespace { + +std::string IdDebugString(CelPolicyElementId id) { + if (id == -1) { + return ""; + } + return absl::StrCat("#", id, "> "); +} + +std::string IndentBlock(absl::string_view text) { + if (text.empty()) { + return ""; + } + std::vector lines; + for (absl::string_view line : absl::StrSplit(text, '\n')) { + if (line.empty()) { + lines.push_back(""); + } else { + lines.push_back(absl::StrCat(" ", line)); + } + } + return absl::StrJoin(lines, "\n"); +} + +} // namespace + +void CelPolicySource::NoteSourcePosition(CelPolicyElementId id, + SourcePosition position) { + source_positions_[id] = position; +} + +std::optional CelPolicySource::GetSourcePosition( + CelPolicyElementId id) const { + auto it = source_positions_.find(id); + if (it == source_positions_.end()) { + return std::nullopt; + } + return it->second; +} + +std::optional CelPolicySource::GetSourceLocation( + CelPolicyElementId id) const { + auto it = source_positions_.find(id); + if (it == source_positions_.end()) { + return std::nullopt; + } + return policy_source_->GetLocation(it->second); +} + +std::string CelPolicySource::DebugString() const { + std::string result; + + // Sort the source elements in descending order of position + std::vector> sorted_positions; + for (const auto& pair : source_positions_) { + sorted_positions.push_back(pair); + } + std::sort(sorted_positions.begin(), sorted_positions.end(), + [](const auto& a, const auto& b) { + if (a.second == b.second) { + return a.first < b.first; + } + return a.second > b.second; + }); + + result = policy_source_->content().ToString(); + for (const auto& [id, position] : sorted_positions) { + result.insert(position, IdDebugString(id)); + } + return result; +} + +std::string ValueString::DebugString() const { + return absl::StrCat(IdDebugString(id_), "\"", value_, "\""); +} + +std::string Import::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "name: ", name_.DebugString()); + return result; +} + +std::string OutputBlock::DebugString() const { + std::string result; + absl::StrAppend(&result, "output: ", output_.DebugString()); + if (explanation_.has_value()) { + absl::StrAppend(&result, "\nexplanation: ", explanation_->DebugString()); + } + return result; +} + +Match::Match(const Match& other) + : id_(other.id_), condition_(other.condition_) { + if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else { + result_ = + std::make_unique(*std::get>(other.result_)); + } +} + +Match& Match::operator=(const Match& other) { + if (this != &other) { + id_ = other.id_; + condition_ = other.condition_; + if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else { + result_ = std::make_unique( + *std::get>(other.result_)); + } + } + return *this; +} + +std::string Match::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "match: {\n"); + if (condition_.has_value()) { + absl::StrAppend(&result, " condition: ", condition_->DebugString(), "\n"); + } + if (has_rule()) { + absl::StrAppend(&result, " result:\n", + IndentBlock(IndentBlock(rule().DebugString())), "\n"); + } else { + absl::StrAppend(&result, " result: {\n", + IndentBlock(IndentBlock(output_block().DebugString())), + "\n }\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string Variable::DebugString() const { + std::string result; + absl::StrAppend(&result, "variable: {\n"); + absl::StrAppend(&result, " name: ", name_.DebugString(), "\n"); + absl::StrAppend(&result, " expression: ", expression_.DebugString(), "\n"); + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + if (display_name_.has_value()) { + absl::StrAppend(&result, " display_name: ", display_name_->DebugString(), + "\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string Rule::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "rule: {\n"); + if (rule_id_.has_value()) { + absl::StrAppend(&result, " rule_id: ", rule_id_->DebugString(), "\n"); + } + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + for (const Variable& variable : variables_) { + absl::StrAppend(&result, IndentBlock(variable.DebugString()), "\n"); + } + for (const Match& match : matches_) { + absl::StrAppend(&result, IndentBlock(match.DebugString()), "\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string MetadataValueDebugString(std::any value) { + if (value.type() == typeid(std::monostate)) { + return "null"; + } + if (value.type() == typeid(ValueString)) { + return std::any_cast(value).DebugString(); + } + if (value.type() == typeid(bool)) { + return std::any_cast(value) ? "true" : "false"; + } + if (value.type() == typeid(int)) { + return absl::StrCat(std::any_cast(value)); + } + if (value.type() == typeid(std::string)) { + return std::any_cast(value); + } + return absl::StrCat("typeid: ", value.type().name()); +} + +std::string CelPolicy::DebugString() const { + std::string result; + absl::StrAppend(&result, "CelPolicy{\n"); + absl::StrAppend( + &result, + " ===========================================================\n"); + absl::StrAppend(&result, IndentBlock(IndentBlock(source_->DebugString())), + "\n"); + absl::StrAppend( + &result, + " ===========================================================\n"); + absl::StrAppend(&result, " name: ", name_.DebugString(), "\n"); + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + if (display_name_.has_value()) { + absl::StrAppend(&result, " display_name: ", display_name_->DebugString(), + "\n"); + } + if (!metadata_.empty()) { + std::vector sorted_keys; + for (const auto& [key, _] : metadata_) { + sorted_keys.push_back(key); + } + std::sort(sorted_keys.begin(), sorted_keys.end()); + + absl::StrAppend(&result, " metadata: {\n"); + for (const auto& key : sorted_keys) { + const auto& value = metadata_.at(key); + absl::StrAppend(&result, " ", key, ": ", + MetadataValueDebugString(value), "\n"); + } + absl::StrAppend(&result, " }\n"); + } + if (!imports_.empty()) { + absl::StrAppend(&result, " imports:\n"); + for (const Import& import : imports_) { + absl::StrAppend(&result, " ", import.DebugString(), "\n"); + } + } + absl::StrAppend(&result, IndentBlock(rule_.DebugString()), "\n"); + absl::StrAppend(&result, "}"); + return result; +} + +} // namespace cel diff --git a/policy/cel_policy.h b/policy/cel_policy.h new file mode 100644 index 000000000..af8f7c977 --- /dev/null +++ b/policy/cel_policy.h @@ -0,0 +1,320 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "common/source.h" + +namespace cel { + +using CelPolicyElementId = int32_t; + +class CelPolicySource { + public: + explicit CelPolicySource(cel::SourcePtr policy_source) + : policy_source_(std::move(policy_source)) {} + + const Source* absl_nonnull content() const { return policy_source_.get(); } + + void NoteSourcePosition(CelPolicyElementId id, SourcePosition position); + + std::optional GetSourcePosition(CelPolicyElementId id) const; + + std::optional GetSourceLocation(CelPolicyElementId id) const; + + std::string DebugString() const; + + private: + cel::SourcePtr policy_source_; + absl::flat_hash_map source_positions_; +}; + +class ValueString { + public: + ValueString() : id_(-1) {} + + explicit ValueString(CelPolicyElementId id, absl::string_view value) + : id_(id), value_(value) {} + + CelPolicyElementId id() const { return id_; } + absl::string_view value() const { return value_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_; + std::string value_; +}; + +class Import { + public: + Import(CelPolicyElementId id, ValueString name) + : id_(id), name_(std::move(name)) {} + CelPolicyElementId id() const { return id_; } + const ValueString& name() const { return name_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_; + ValueString name_; +}; + +// Defines a variable that can be used in CEL expressions within the policy. +// Variables are evaluated once and stored in the activation context. +class Variable { + public: + const ValueString& name() const { return name_; } + void set_name(ValueString name) { name_ = std::move(name); } + + const ValueString& expression() const { return expression_; } + void set_expression(ValueString expression) { + expression_ = std::move(expression); + } + + std::optional description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + + std::optional display_name() const { return display_name_; } + void set_display_name(ValueString display_name) { + display_name_ = std::move(display_name); + } + + std::string DebugString() const; + + private: + ValueString name_; + ValueString expression_; + std::optional description_; + std::optional display_name_; +}; + +class Rule; + +class OutputBlock { + public: + OutputBlock() = default; + OutputBlock(ValueString output, std::optional explanation) + : output_(std::move(output)), explanation_(std::move(explanation)) {} + + const ValueString& output() const { return output_; } + void set_output(ValueString output) { output_ = std::move(output); } + + const std::optional& explanation() const { return explanation_; } + void set_explanation(ValueString explanation) { + explanation_ = std::move(explanation); + } + + std::string DebugString() const; + + private: + ValueString output_; + std::optional explanation_; +}; + +// Defines a match condition and result. +// If the result is a Rule, it is considered a sub-rule and will be evaluated +// only if the match condition evaluates to true. +class Match { + public: + Match() = default; + Match(const Match& other); + Match& operator=(const Match& other); + + CelPolicyElementId id() const; + void set_id(CelPolicyElementId id); + + bool has_condition() const; + std::optional condition() const; + void set_condition(ValueString condition); + + bool has_output_block() const; + const OutputBlock& output_block() const; + OutputBlock& mutable_output_block(); + + bool has_rule() const; + const Rule& rule() const; + Rule& mutable_rule(); + + void set_result(OutputBlock result); + void set_result(std::unique_ptr result); + + std::string DebugString() const; + + private: + CelPolicyElementId id_ = -1; + std::optional condition_; + std::variant> result_; +}; + +// Rule is the body of the policy and contains a list of variables and matches. +// Variables are evaluated once and stored in the activation context. +// Matches are evaluated in order and the first match is returned. If the +// match contains a sub-rule, the sub-rule is evaluated only if the match +// condition evaluates to true. +class Rule { + public: + Rule() = default; + Rule(const Rule& other) = default; + + CelPolicyElementId id() const { return id_; } + void set_id(CelPolicyElementId id) { id_ = id; } + + const std::optional& rule_id() const { return rule_id_; } + void set_rule_id(ValueString rule_id) { rule_id_ = std::move(rule_id); } + + const std::optional& description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + + const std::vector& variables() const { return variables_; } + std::vector& mutable_variables() { return variables_; } + + const std::vector& matches() const { return matches_; } + std::vector& mutable_matches() { return matches_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_ = -1; + std::optional rule_id_; + std::optional description_; + std::vector variables_; + std::vector matches_; +}; + +// CelPolicy is the top-level policy object. +// It contains a source, name, description, display name, imports, and a rule. +// The source is the CEL policy source code. +// The name, description, and display name are metadata about the policy. +// The rule is the main body of the policy. +class CelPolicy { + public: + explicit CelPolicy(std::shared_ptr source) + : source_(std::move(source)) {} + + CelPolicy(const CelPolicy& other) = default; + CelPolicy& operator=(const CelPolicy& other) = default; + + const CelPolicySource* absl_nullable source() const { return source_.get(); } + const std::shared_ptr& source_ptr() const { return source_; } + + const ValueString& name() const { return name_; } + void set_name(ValueString name) { name_ = std::move(name); } + + std::optional description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + std::optional display_name() const { return display_name_; } + void set_display_name(ValueString display_name) { + display_name_ = std::move(display_name); + } + const absl::flat_hash_map& metadata() const { + return metadata_; + } + absl::flat_hash_map& mutable_metadata() { + return metadata_; + } + const std::vector& imports() const { return imports_; } + std::vector& mutable_imports() { return imports_; } + + const Rule& rule() const { return rule_; } + Rule& mutable_rule() { return rule_; } + + std::string DebugString() const; + + private: + std::shared_ptr source_; + ValueString name_; + std::optional description_; + std::optional display_name_; + absl::flat_hash_map metadata_; + std::vector imports_; + Rule rule_; +}; + +// Implementation details. + +inline CelPolicyElementId Match::id() const { return id_; } +inline void Match::set_id(CelPolicyElementId id) { id_ = id; } + +inline bool Match::has_condition() const { return condition_.has_value(); } + +inline std::optional Match::condition() const { + return condition_; +} + +inline void Match::set_condition(ValueString condition) { + condition_ = std::move(condition); +} + +inline bool Match::has_output_block() const { + return std::holds_alternative(result_); +} + +inline const OutputBlock& Match::output_block() const { + ABSL_DCHECK(std::holds_alternative(result_)); + return std::get(result_); +} + +inline OutputBlock& Match::mutable_output_block() { + if (!std::holds_alternative(result_)) { + result_ = OutputBlock(); + } + return std::get(result_); +} + +inline bool Match::has_rule() const { + return std::holds_alternative>(result_); +} + +inline const Rule& Match::rule() const { + ABSL_DCHECK(std::holds_alternative>(result_)); + return *std::get>(result_); +} + +inline Rule& Match::mutable_rule() { + ABSL_DCHECK(std::holds_alternative>(result_)); + return *std::get>(result_); +} + +inline void Match::set_result(OutputBlock result) { + result_ = std::move(result); +} + +inline void Match::set_result(std::unique_ptr result) { + result_ = std::move(result); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ diff --git a/policy/cel_policy_parse_context.cc b/policy/cel_policy_parse_context.cc new file mode 100644 index 000000000..66861d085 --- /dev/null +++ b/policy/cel_policy_parse_context.cc @@ -0,0 +1,49 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_parse_context.h" + +#include +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +CelPolicy& CelPolicyParseContext::policy() const { + ABSL_CHECK(policy_ != nullptr) + << "CelPolicyParseContext::policy() called after GetResult()"; + return *policy_; +} + +CelPolicyParseResult CelPolicyParseContext::GetResult() { + if (policy_ != nullptr && issues_.empty()) { + return CelPolicyParseResult(std::move(policy_source_), std::move(policy_), + std::move(issues_)); + } + policy_.reset(); + return CelPolicyParseResult(std::move(policy_source_), nullptr, + std::move(issues_)); +} + +void CelPolicyParseContext::ReportError(CelPolicyElementId element_id, + std::string_view message) { + issues_.push_back(CelPolicyIssue(element_id, std::string(message))); +} + +} // namespace cel diff --git a/policy/cel_policy_parse_context.h b/policy/cel_policy_parse_context.h new file mode 100644 index 000000000..6482fa1ae --- /dev/null +++ b/policy/cel_policy_parse_context.h @@ -0,0 +1,65 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +// A mutable context for parsing a CelPolicy. An instance of this class is +// created for each policy parse and is passed to the parser, which is meant to +// be stateless. +// +// Parsers call methods on this class to report issues and populate the policy +// being parsed. Call GetResult() to obtain the resulting CelPolicyParseResult, +// which takes ownership of the parsed policy. Do not use the context after +// calling GetResult(). +class CelPolicyParseContext { + public: + explicit CelPolicyParseContext(std::shared_ptr policy_source) + : policy_source_(std::move(policy_source)), + policy_(std::make_unique(policy_source_)) {} + + CelPolicySource& policy_source() const { return *policy_source_; } + + // Returns the policy being parsed. It should not be used after + // calling GetResult(). + CelPolicy& policy() const; + + // The context should not be used after calling GetResult(). + CelPolicyParseResult GetResult(); + + // Reports an error for the given element with the given error message. + void ReportError(CelPolicyElementId id, std::string_view message); + + CelPolicyElementId next_element_id() { return next_element_id_++; } + + private: + std::shared_ptr policy_source_; + CelPolicyElementId next_element_id_ = 0; + std::vector issues_; + std::unique_ptr policy_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ diff --git a/policy/cel_policy_parse_result.cc b/policy/cel_policy_parse_result.cc new file mode 100644 index 000000000..32d6431bb --- /dev/null +++ b/policy/cel_policy_parse_result.cc @@ -0,0 +1,91 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_parse_result.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel { +namespace { + +absl::string_view SeverityString(CelPolicyIssue::Severity severity) { + switch (severity) { + case CelPolicyIssue::Severity::kInformation: + return "INFORMATION"; + case CelPolicyIssue::Severity::kWarning: + return "WARNING"; + case CelPolicyIssue::Severity::kError: + return "ERROR"; + case CelPolicyIssue::Severity::kDeprecated: + return "DEPRECATED"; + default: + return "SEVERITY_UNSPECIFIED"; + } +} + +} // namespace + +std::string CelPolicyIssue::ToDisplayString( + const CelPolicySource* absl_nullable source) const { + SourceLocation location; + std::string description; + std::string snippet; + if (source != nullptr) { + if (relative_position_) { + std::optional base = + source->GetSourcePosition(element_id_); + if (element_id_ == -1) { + base.emplace(0); + } + if (base) { + location = source->content() + ->GetLocation(*base + *relative_position_) + .value_or(SourceLocation{}); + } + } else { + location = + source->GetSourceLocation(element_id_).value_or(SourceLocation{}); + } + description = std::string(source->content()->description()); + snippet = source->content()->DisplayErrorLocation(location); + } + + const int display_column = location.column >= 0 ? location.column + 1 : -1; + + return absl::StrFormat("%s: %s:%d:%d: %s%s", SeverityString(severity_), + description, location.line, display_column, message_, + snippet); +} + +std::string CelPolicyParseResult::FormattedIssues() const { + std::string formatted_issues; + for (const CelPolicyIssue& issue : issues_) { + if (!formatted_issues.empty()) { + absl::StrAppend(&formatted_issues, "\n"); + } + absl::StrAppend(&formatted_issues, issue.ToDisplayString(*policy_source_)); + } + return formatted_issues; +} + +} // namespace cel diff --git a/policy/cel_policy_parse_result.h b/policy/cel_policy_parse_result.h new file mode 100644 index 000000000..2bf80b1ce --- /dev/null +++ b/policy/cel_policy_parse_result.h @@ -0,0 +1,105 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel { + +class CelPolicyIssue { + public: + enum class Severity { kInformation, kDeprecated, kWarning, kError }; + + CelPolicyIssue(CelPolicyElementId element_id, absl::string_view message) + : element_id_(element_id), message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, Severity severity, + absl::string_view message) + : element_id_(element_id), severity_(severity), message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, + SourcePosition relative_position, absl::string_view message) + : element_id_(element_id), + relative_position_(relative_position), + message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, + SourcePosition relative_position, Severity severity, + absl::string_view message) + : element_id_(element_id), + relative_position_(relative_position), + severity_(severity), + message_(message) {} + + std::string ToDisplayString( + const CelPolicySource* absl_nullable source) const; + std::string ToDisplayString(const CelPolicySource& source) const { + return ToDisplayString(&source); + } + + Severity severity() const { return severity_; } + absl::string_view message() const { return message_; } + + private: + CelPolicyElementId element_id_; + std::optional relative_position_; + Severity severity_ = Severity::kError; + std::string message_; +}; + +class CelPolicyParseResult { + public: + explicit CelPolicyParseResult(std::shared_ptr policy_source, + std::unique_ptr policy, + std::vector issues) + : policy_source_(std::move(policy_source)), + policy_(std::move(policy)), + issues_(std::move(issues)) {} + + bool IsValid() const { return policy_ != nullptr; } + + const CelPolicy* absl_nullable GetPolicy() const { return policy_.get(); } + + absl::StatusOr> ReleasePolicy() { + if (policy_ == nullptr) { + return absl::FailedPreconditionError( + "CelPolicyParseResult is empty. Check for Issues."); + } + return std::move(policy_); + } + + absl::Span GetIssues() const { return issues_; } + + std::string FormattedIssues() const; + + private: + std::shared_ptr policy_source_; + absl_nullable std::unique_ptr policy_; + std::vector issues_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ diff --git a/policy/cel_policy_parser.h b/policy/cel_policy_parser.h new file mode 100644 index 000000000..0a11c9e68 --- /dev/null +++ b/policy/cel_policy_parser.h @@ -0,0 +1,40 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ + +#include "absl/status/status.h" +#include "policy/cel_policy_parse_context.h" + +namespace cel { + +// A policy parser for a given policy format. The type `T` parameter is the +// representation of the input file format, such as `` for YAML. +// +// Parsers are intended to be stateless: all state, including the resulting +// policy and any issues encountered, should be kept in the context passed to +// the `ParsePolicy` method. +template +class CelPolicyParser { + public: + virtual ~CelPolicyParser() = default; + + // Parses the input and populates a CelPolicy in the context. + virtual absl::Status ParsePolicy(CelPolicyParseContext& ctx) const = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ diff --git a/policy/cel_policy_test.cc b/policy/cel_policy_test.cc new file mode 100644 index 000000000..640247e7f --- /dev/null +++ b/policy/cel_policy_test.cc @@ -0,0 +1,220 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy.h" + +#include +#include +#include +#include + +#include "absl/strings/str_replace.h" +#include "common/source.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using testing::Field; +using testing::Optional; +using testing::SizeIs; + +TEST(CelPolicyBuilderTest, Build) { + CelPolicyElementId next_id = 1; + ASSERT_OK_AND_ASSIGN(SourcePtr source, NewSource("CEL\n policy\n source")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + CelPolicy policy(policy_source); + policy.set_name(ValueString(next_id++, "test_policy")); + policy.set_description(ValueString(next_id++, "test_description")); + policy.set_display_name(ValueString(next_id++, "test_display_name")); + ValueString import1_name = ValueString(next_id++, "test_import1"); + policy.mutable_imports().push_back(Import(next_id++, import1_name)); + ValueString import2_name = ValueString(next_id++, "test_import2"); + policy.mutable_imports().push_back(Import(next_id++, import2_name)); + + Rule& rule = policy.mutable_rule(); + rule.set_id(next_id++); + rule.set_rule_id(ValueString(next_id++, "test_rule_id")); + rule.set_description(ValueString(next_id++, "test_rule_description")); + + Variable variable; + variable.set_name(ValueString(next_id++, "test_variable")); + variable.set_expression(ValueString(next_id++, "test_expression")); + variable.set_description(ValueString(next_id++, "test_variable_description")); + variable.set_display_name( + ValueString(next_id++, "test_variable_display_name")); + + Match match1; + match1.set_id(next_id++); + match1.set_condition(ValueString(next_id++, "test_condition")); + CelPolicyElementId output_id = next_id++; + CelPolicyElementId explanation_id = next_id++; + match1.set_result( + OutputBlock(ValueString(output_id, "test_result"), + ValueString(explanation_id, "test_explanation"))); + + Match match2; + match2.set_id(next_id++); + match2.set_condition(ValueString(next_id++, "test_condition2")); + + auto sub_rule = std::make_unique(); + sub_rule->set_id(next_id++); + sub_rule->set_rule_id(ValueString(next_id++, "sub_rule_id")); + sub_rule->set_description(ValueString(next_id++, "sub_rule_description")); + Match sub_rule_match; + sub_rule_match.set_id(next_id++); + sub_rule_match.set_condition(ValueString(next_id++, "sub_rule_condition")); + sub_rule_match.set_result( + OutputBlock(ValueString(next_id++, "sub_rule_result"), std::nullopt)); + sub_rule->mutable_matches().push_back(sub_rule_match); + + match2.set_result(std::move(sub_rule)); + + rule.mutable_variables().push_back(variable); + rule.mutable_matches().push_back(match1); + rule.mutable_matches().push_back(match2); + + EXPECT_EQ(policy.name().value(), "test_policy"); + ASSERT_TRUE(policy.description().has_value()); + EXPECT_EQ(policy.description()->value(), "test_description"); + ASSERT_TRUE(policy.display_name().has_value()); + EXPECT_EQ(policy.display_name()->value(), "test_display_name"); + + ASSERT_THAT(policy.imports(), SizeIs(2)); + + EXPECT_EQ(policy.imports()[0].name().value(), "test_import1"); + EXPECT_EQ(policy.imports()[1].name().value(), "test_import2"); + ASSERT_TRUE(policy.rule().rule_id().has_value()); + EXPECT_EQ(policy.rule().rule_id()->value(), "test_rule_id"); + ASSERT_TRUE(policy.rule().description().has_value()); + EXPECT_EQ(policy.rule().description()->value(), "test_rule_description"); + + ASSERT_THAT(policy.rule().variables(), SizeIs(1)); + + EXPECT_EQ(policy.rule().variables()[0].name().value(), "test_variable"); + EXPECT_EQ(policy.rule().variables()[0].expression().value(), + "test_expression"); + ASSERT_TRUE(policy.rule().variables()[0].description().has_value()); + EXPECT_EQ(policy.rule().variables()[0].description()->value(), + "test_variable_description"); + ASSERT_TRUE(policy.rule().variables()[0].display_name().has_value()); + EXPECT_EQ(policy.rule().variables()[0].display_name()->value(), + "test_variable_display_name"); + + ASSERT_THAT(policy.rule().matches(), SizeIs(2)); + + EXPECT_EQ(policy.rule().matches()[0].condition().value().value(), + "test_condition"); + ASSERT_TRUE(policy.rule().matches()[0].has_output_block()); + EXPECT_EQ(policy.rule().matches()[0].output_block().output().value(), + "test_result"); + ASSERT_TRUE( + policy.rule().matches()[0].output_block().explanation().has_value()); + EXPECT_EQ(policy.rule().matches()[0].output_block().explanation()->value(), + "test_explanation"); + + EXPECT_EQ(policy.rule().matches()[1].condition().value().value(), + "test_condition2"); + ASSERT_TRUE(policy.rule().matches()[1].has_rule()); + ASSERT_TRUE(policy.rule().matches()[1].rule().rule_id().has_value()); + EXPECT_EQ(policy.rule().matches()[1].rule().rule_id()->value(), + "sub_rule_id"); + ASSERT_TRUE(policy.rule().matches()[1].rule().description().has_value()); + EXPECT_EQ(policy.rule().matches()[1].rule().description()->value(), + "sub_rule_description"); + ASSERT_THAT(policy.rule().matches()[1].rule().matches(), SizeIs(1)); + EXPECT_EQ(policy.rule() + .matches()[1] + .rule() + .matches()[0] + .condition() + .value() + .value(), + "sub_rule_condition"); + + std::string actual = policy.DebugString(); + EXPECT_EQ(actual, absl::StrReplaceAll(R"(CelPolicy{ + =========================================================== + CEL + policy + source + =========================================================== + name: #1> "test_policy" + description: #2> "test_description" + display_name: #3> "test_display_name" + imports: + #5> name: #4> "test_import1" + #7> name: #6> "test_import2" + #8> rule: { + rule_id: #9> "test_rule_id" + description: #10> "test_rule_description" + variable: { + name: #11> "test_variable" + expression: #12> "test_expression" + description: #13> "test_variable_description" + display_name: #14> "test_variable_display_name" + } + #15> match: { + condition: #16> "test_condition" + result: { + output: #17> "test_result" + explanation: #18> "test_explanation" + } + } + #19> match: { + condition: #20> "test_condition2" + result: + #21> rule: { + rule_id: #22> "sub_rule_id" + description: #23> "sub_rule_description" + #24> match: { + condition: #25> "sub_rule_condition" + result: { + output: #26> "sub_rule_result" + } + } + } + } + } + })", + {{"\n ", "\n"}})); +} + +TEST(CelPolicySourceTest, Build) { + std::string source = + "name: test_policy\n imports:\n - name: test_import\n"; + + ASSERT_OK_AND_ASSIGN(SourcePtr source_ptr, NewSource(source)); + CelPolicySource policy_source(std::move(source_ptr)); + policy_source.NoteSourcePosition(1, source.find("test_policy")); + policy_source.NoteSourcePosition(2, source.find("test_import")); + + EXPECT_THAT(policy_source.GetSourcePosition(1), Optional(6)); + EXPECT_THAT(policy_source.GetSourceLocation(1), + Optional(AllOf(Field(&SourceLocation::line, 1), + Field(&SourceLocation::column, 6)))); + EXPECT_THAT(policy_source.GetSourcePosition(2), Optional(44)); + EXPECT_THAT(policy_source.GetSourceLocation(2), + Optional(AllOf(Field(&SourceLocation::line, 3), + Field(&SourceLocation::column, 13)))); + EXPECT_EQ(policy_source.GetSourcePosition(3), std::nullopt); + EXPECT_EQ(policy_source.GetSourceLocation(3), std::nullopt); + EXPECT_EQ( + policy_source.DebugString(), + "name: #1> test_policy\n imports:\n - name: #2> test_import\n"); +} + +} // namespace +} // namespace cel diff --git a/policy/cel_policy_validation_result.cc b/policy/cel_policy_validation_result.cc new file mode 100644 index 000000000..e257f064c --- /dev/null +++ b/policy/cel_policy_validation_result.cc @@ -0,0 +1,32 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_validation_result.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +std::string CelPolicyValidationResult::FormatIssues() const { + return absl::StrJoin( + issues_, "\n", [this](std::string* out, const CelPolicyIssue& issue) { + absl::StrAppend(out, issue.ToDisplayString(source_.get())); + }); +} + +} // namespace cel diff --git a/policy/cel_policy_validation_result.h b/policy/cel_policy_validation_result.h new file mode 100644 index 000000000..bddb9a3ca --- /dev/null +++ b/policy/cel_policy_validation_result.h @@ -0,0 +1,84 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +// CelPolicyValidationResult holds the result of policy compilation. +// +// Policy compilation/validation errors are captured in issues. +class CelPolicyValidationResult { + public: + CelPolicyValidationResult( + std::unique_ptr ast, std::vector issues, + std::shared_ptr source = nullptr) + : ast_(std::move(ast)), + issues_(std::move(issues)), + source_(std::move(source)) {} + + explicit CelPolicyValidationResult( + std::vector issues, + std::shared_ptr source = nullptr) + : ast_(nullptr), issues_(std::move(issues)), source_(std::move(source)) {} + + // Returns true if validation succeeded and an AST is present. + bool IsValid() const { return ast_ != nullptr; } + + // Returns the AST if validation was successful. + const Ast* absl_nullable GetAst() const { return ast_.get(); } + + // Moves out and returns the AST. + absl::StatusOr> ReleaseAst() { + if (ast_ == nullptr) { + return absl::FailedPreconditionError( + "CelPolicyValidationResult is empty. Check for CelPolicyIssues."); + } + return std::move(ast_); + } + + // Returns the list of issues encountered during compilation. + absl::Span GetIssues() const { return issues_; } + + // Returns the contained policy source, if any. + const CelPolicySource* absl_nullable GetSource() const { + return source_.get(); + } + + // Returns a formatted error string of the compiled issues. + std::string FormatIssues() const; + + private: + absl_nullable std::unique_ptr ast_; + std::vector issues_; + std::shared_ptr source_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ diff --git a/policy/compiler.cc b/policy/compiler.cc new file mode 100644 index 000000000..7a892447c --- /dev/null +++ b/policy/compiler.cc @@ -0,0 +1,1058 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/compiler.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/ast_rewrite.h" +#include "common/constant.h" +#include "common/container.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/format_type_name.h" +#include "common/navigable_ast.h" +#include "common/source.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "compiler/compiler.h" +#include "internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/internal/issue_reporter.h" +#include "policy/internal/optimizer_expr_factory.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +constexpr absl::string_view kCelBlock = "cel.@block"; + +enum class RuleSemantics { + // TODO(b/506179116): will also need "aggregate" or similar concept. + kFirstMatch, + + kNotForUseWithExhaustiveSwitchStatements, +}; + +template +void AbslStringify(Sink& s, RuleSemantics semantics) { + switch (semantics) { + case RuleSemantics::kFirstMatch: + s.Append("first_match"); + return; + default: + s.Append(""); + return; + } +} + +struct EmbeddedAst { + CelPolicyElementId id; + std::unique_ptr ast; +}; + +struct CompiledVariable { + std::string ident; + EmbeddedAst ast; +}; + +struct CompiledOutputBlock { + EmbeddedAst output_ast; + cel::Type result_type; + std::optional explanation_ast; +}; + +struct CompiledRule; + +struct CompiledMatch { + using Production = + std::variant absl_nonnull, + CompiledOutputBlock>; + + CelPolicyElementId id; + std::optional condition; + Production production; +}; + +struct CompiledRule { + CelPolicyElementId id; + std::vector variables; + std::vector matches; + // Not set if cannot be determined. + std::optional result_type; +}; + +std::optional GetOutputType( + const CompiledMatch::Production& production) { + return std::visit( + [](const auto& production) -> std::optional { + if constexpr (std::is_same_v, + CompiledOutputBlock>) { + return production.result_type; + } else if constexpr (std::is_same_v, + std::unique_ptr>) { + return production->result_type; + } + return std::nullopt; + }, + production); +} + +// Internal representation of the compiled policy elements. +// +// This is used for checking the component expression before composing into the +// final AST based on the provided rule semantics. +class IntermediateCompiledPolicy { + public: + CompiledRule& mutable_root_rule() { return root_rule_; } + + const CompiledRule& root_rule() const { return root_rule_; } + + void set_name(absl::string_view name) { name_ = name; } + absl::string_view name() const { return name_; } + void set_display_name(absl::string_view display_name) { + display_name_ = display_name; + } + absl::string_view display_name() const { return display_name_; } + void set_description(absl::string_view description) { + description_ = description; + } + absl::string_view description() const { return description_; } + + void set_semantics(RuleSemantics semantics) { semantics_ = semantics; } + RuleSemantics semantics() const { return semantics_; } + + private: + std::string name_; + std::string display_name_; + std::string description_; + RuleSemantics semantics_ = RuleSemantics::kFirstMatch; + + CompiledRule root_rule_; +}; + +CelPolicyIssue::Severity MapSeverity(cel::TypeCheckIssue::Severity severity) { + switch (severity) { + case cel::TypeCheckIssue::Severity::kError: + return CelPolicyIssue::Severity::kError; + case cel::TypeCheckIssue::Severity::kWarning: + return CelPolicyIssue::Severity::kWarning; + case cel::TypeCheckIssue::Severity::kDeprecated: + return CelPolicyIssue::Severity::kDeprecated; + default: + return CelPolicyIssue::Severity::kError; + } +} + +bool IsWrapperOf(cel::TypeKind wrapper_kind, cel::TypeKind primitive_kind) { + switch (wrapper_kind) { + case cel::TypeKind::kBoolWrapper: + return primitive_kind == cel::TypeKind::kBool; + case cel::TypeKind::kIntWrapper: + return primitive_kind == cel::TypeKind::kInt; + case cel::TypeKind::kUintWrapper: + return primitive_kind == cel::TypeKind::kUint; + case cel::TypeKind::kDoubleWrapper: + return primitive_kind == cel::TypeKind::kDouble; + case cel::TypeKind::kStringWrapper: + return primitive_kind == cel::TypeKind::kString; + case cel::TypeKind::kBytesWrapper: + return primitive_kind == cel::TypeKind::kBytes; + default: + return false; + } +} + +cel::Type FilterSpecialTypes(cel::Type type) { + if (type.IsTypeParam()) { + // Free type param should not appear in the output type, but if it does, + // force it to dyn. + return DynType(); + } + if (type.IsEnum()) { + return IntType{}; + } + if (type.IsError()) { + return DynType(); + } + if (type.IsType()) { + // drop parameters so all type types are compatible. + return TypeType{}; + } + return type; +} + +// Returns true if `from` is assignable to `to`. +// +// Slightly adjusted from the standard routine to cover some edge cases around +// null and wrappers. +// +// TODO(b/522391716): try to standardize assignability checks. +bool OutputTypeIsAssignable(cel::Type from, cel::Type to) { + from = FilterSpecialTypes(from); + to = FilterSpecialTypes(to); + + // Any and dyn are assignable to/from everything. + if (from.kind() == cel::TypeKind::kAny || + from.kind() == cel::TypeKind::kDyn || to.kind() == cel::TypeKind::kAny || + to.kind() == cel::TypeKind::kDyn) { + return true; + } + + // Wrappers auto-unwrap. + if (IsWrapperOf(from.kind(), to.kind()) || + IsWrapperOf(to.kind(), from.kind())) { + return true; + } + + // Null is assignable to anything that is message-like. + if (from.kind() == cel::TypeKind::kNull) { + switch (to.kind()) { + case cel::TypeKind::kNull: + case cel::TypeKind::kStruct: + case cel::TypeKind::kOpaque: + case cel::TypeKind::kTimestamp: + case cel::TypeKind::kDuration: + case cel::TypeKind::kBytesWrapper: + case cel::TypeKind::kBoolWrapper: + case cel::TypeKind::kIntWrapper: + case cel::TypeKind::kUintWrapper: + case cel::TypeKind::kDoubleWrapper: + case cel::TypeKind::kStringWrapper: + return true; + default: + return false; + } + } + + if (from.kind() != to.kind()) { + return false; + } + + if (from.name() != to.name()) { + return false; + } + + if (from.GetParameters().size() != to.GetParameters().size()) { + return false; + } + + for (int i = 0; i < from.GetParameters().size(); ++i) { + if (!OutputTypeIsAssignable(from.GetParameters()[i], + to.GetParameters()[i])) { + return false; + } + } + + return true; +} + +bool OutputTypeIsCompatible(cel::Type from, cel::Type to) { + // We don't handle widening like in a self-contained CEL expression, but + // permit some cases where one type is more specific than the other. + return OutputTypeIsAssignable(from, to) || OutputTypeIsAssignable(to, from); +} + +bool HasErrors(const policy_internal::IssueReporter& issues) { + for (const auto& issue : issues.issues()) { + if (issue.severity() == CelPolicyIssue::Severity::kError) { + return true; + } + } + return false; +} + +// Note on lifetime safety: +// +// The output policy will contain references to types that are owned by the +// arena member of this class. This is safe as long as the policy compiler lives +// as long as the output policies. +class PolicyCompiler { + public: + explicit PolicyCompiler(policy_internal::IssueReporter* issues, + std::unique_ptr base_compiler) + : issues_(*issues), base_compiler_(std::move(base_compiler)) {} + + absl::string_view GetSourceDescription() const { + if (src_ == nullptr) { + return ""; + } + return src_->content()->description(); + } + + void AdaptTypeCheckIssues(CelPolicyElementId id, const ValidationResult& r) { + const Source* source = r.GetSource(); + + for (const auto& iss : r.GetIssues()) { + std::optional offset; + if (source != nullptr) { + offset = source->GetPosition(iss.location()); + } + if (offset.has_value()) { + issues_.ReportOffsetIssue(id, offset.value(), + MapSeverity(iss.severity()), iss.message()); + continue; + } + issues_.ReportIssue(id, MapSeverity(iss.severity()), iss.message()); + } + } + + absl::StatusOr CompileOutputBlock( + const cel::OutputBlock& output_block, const Compiler* env) { + CompiledOutputBlock output; + CEL_ASSIGN_OR_RETURN(auto output_validation, + env->Compile(output_block.output().value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(output_block.output().id(), output_validation); + + cel::Type result_type = DynType(); + if (output_validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, output_validation.ReleaseAst()); + auto root_expr_id = ast->root_expr().id(); + output.output_ast = + EmbeddedAst{output_block.output().id(), std::move(ast)}; + if (auto it = output_validation.GetResolvedTypeMap().find(root_expr_id); + it != output_validation.GetResolvedTypeMap().end()) { + result_type = it->second; + } + } + if (output_block.explanation().has_value()) { + CEL_ASSIGN_OR_RETURN(auto explanation_validation, + env->Compile(output_block.explanation()->value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(output_block.explanation()->id(), + explanation_validation); + if (explanation_validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, explanation_validation.ReleaseAst()); + if (ast->GetReturnType().primitive() != PrimitiveType::kString) { + issues_.ReportError(output_block.explanation()->id(), + "explanation must evaluate to string"); + } else { + output.explanation_ast = + EmbeddedAst{output_block.explanation()->id(), std::move(ast)}; + } + } + } + output.result_type = result_type; + return output; + } + + absl::Status CompileMatch(const Match& match, const Compiler* env, + CompiledRule* out) { + CompiledMatch c_match; + c_match.id = match.id(); + if (match.condition().has_value()) { + CEL_ASSIGN_OR_RETURN(auto validation, + env->Compile(match.condition()->value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(match.condition()->id(), validation); + if (validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, validation.ReleaseAst()); + if (ast->GetReturnType().primitive() != PrimitiveType::kBool) { + issues_.ReportError(match.condition()->id(), + "condition must evaluate to bool"); + } + c_match.condition = + EmbeddedAst{match.condition()->id(), std::move(ast)}; + } + } + + if (match.has_output_block()) { + CEL_ASSIGN_OR_RETURN(c_match.production, + CompileOutputBlock(match.output_block(), env)); + } else if (match.has_rule()) { + auto rule = std::make_unique(); + CEL_RETURN_IF_ERROR(CompileRule(match.rule(), env, rule.get())); + c_match.production = std::move(rule); + } else { + issues_.ReportError(match.id(), "match must specify an output or rule"); + } + out->matches.push_back(std::move(c_match)); + return absl::OkStatus(); + } + + absl::Status CompileRule(const Rule& rule, const cel::Compiler* env, + CompiledRule* out) { + out->id = rule.id(); + std::unique_ptr buf; + + absl::flat_hash_set seen_variables; + for (const auto& variable : rule.variables()) { + std::string name(variable.name().value()); + if (!seen_variables.insert(name).second) { + issues_.ReportError( + variable.expression().id(), + absl::StrCat("overlapping identifier for name 'variables.", name, + "'")); + continue; + } + std::string ident = absl::StrCat("variables.", name); + CEL_ASSIGN_OR_RETURN(auto validation, + env->Compile(variable.expression().value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(variable.expression().id(), validation); + if (!validation.IsValid()) { + continue; + } + CEL_ASSIGN_OR_RETURN(auto ast, validation.ReleaseAst()); + cel::Type result_type = DynType(); + + if (auto it = validation.GetResolvedTypeMap().find(ast->root_expr().id()); + it != validation.GetResolvedTypeMap().end()) { + result_type = it->second; + } + out->variables.push_back(CompiledVariable{ + ident, + EmbeddedAst{variable.expression().id(), std::move(ast)}, + }); + auto next = env->ToBuilder(); + auto status = next->GetCheckerBuilder().AddOrReplaceVariable( + MakeVariableDecl(ident, result_type)); + if (!status.ok()) { + issues_.ReportError(variable.expression().id(), status.message()); + continue; + } + CEL_ASSIGN_OR_RETURN(buf, next->Build()); + env = buf.get(); + } + + std::optional overall_type; + for (const auto& match : rule.matches()) { + CEL_RETURN_IF_ERROR(CompileMatch(match, env, out)); + if (!overall_type.has_value()) { + overall_type = GetOutputType(out->matches.back().production); + continue; + } + + if (std::optional match_type = + GetOutputType(out->matches.back().production); + match_type.has_value()) { + if (!OutputTypeIsCompatible(*match_type, *overall_type)) { + issues_.ReportError( + match.id(), + absl::StrCat("incompatible output types: block has output type ", + FormatTypeName(*match_type), + ", but previous outputs have type ", + FormatTypeName(*overall_type))); + } + } + } + + out->result_type = overall_type; + return absl::OkStatus(); + } + + absl::Status CompilePolicy(const CelPolicy& policy, + IntermediateCompiledPolicy* out) { + src_ = policy.source(); + out->set_semantics(RuleSemantics::kFirstMatch); + out->set_name(policy.name().value()); + out->set_display_name( + policy.display_name().value_or(ValueString{}).value()); + out->set_description(policy.description().value_or(ValueString{}).value()); + + return CompileRule(policy.rule(), base_compiler_.get(), + &out->mutable_root_rule()); + } + + private: + google::protobuf::Arena arena_; + const CelPolicySource* absl_nullable src_; + policy_internal::IssueReporter& issues_; + std::unique_ptr base_compiler_; +}; + +bool IsExhaustive(const CompiledRule& rule); + +class FirstMatchComposer { + public: + FirstMatchComposer(const IntermediateCompiledPolicy& icp, + const Compiler& compiler, + policy_internal::IssueReporter& issues) + : issues_(issues), icp_(icp), compiler_(compiler) {} + + absl::Status Compose(); + + bool success() const { return ast_ != nullptr; } + + std::unique_ptr ReleaseAst() { return std::move(ast_); } + + private: + using VariableScope = absl::flat_hash_map; + + std::optional ResolvePolicyVariable(absl::string_view reference); + + absl::flat_hash_map ResolveBlockIndexes(const Ast& ast); + + bool CheckMatchStructure(const CompiledRule& rule); + + // Returns true if already optional wrapped. + absl::StatusOr ComposeRule(const CompiledRule& rule, Expr& init, + Expr& insertion_expr); + + // returns true if already optional wrapped. + absl::StatusOr ComposeProduction( + const CompiledRule& rule, const CompiledMatch::Production& production, + Expr& init, Expr& insertion_expr); + + void MapVariables(Ast& ast); + + void ComposeRuleVariables(const CompiledRule& rule, Expr& init, + Expr& insertion_expr); + + policy_internal::IssueReporter& issues_; + OptimizerExprFactory factory_; + const IntermediateCompiledPolicy& icp_; + const Compiler& compiler_; + std::vector scopes_; + bool optionalize_ = false; + std::unique_ptr ast_; +}; + +absl::Status FirstMatchComposer::Compose() { + ABSL_DCHECK(icp_.semantics() == RuleSemantics::kFirstMatch); + + factory_.mutable_ast().mutable_root_expr() = factory_.NewCall( + "cel.@block", factory_.NewList(), factory_.NewUnspecified()); + auto& block_init_list = factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[0]; + auto& insertion_expr = factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[1]; + optionalize_ = !IsExhaustive(icp_.root_rule()); + if (!CheckMatchStructure(icp_.root_rule())) { + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN( + bool optional_wrapped, + ComposeRule(icp_.root_rule(), block_init_list, insertion_expr)); + + if (optional_wrapped != optionalize_) { + return absl::InternalError( + "composition failed to handle non-exhaustive rules"); + } + + CEL_ASSIGN_OR_RETURN(cel::ValidationResult result, + compiler_.GetTypeChecker().Check(factory_.ast())); + if (!result.IsValid()) { + for (const auto& iss : result.GetIssues()) { + issues_.ReportError(icp_.root_rule().id, iss.message()); + } + return absl::OkStatus(); + } + + CEL_ASSIGN_OR_RETURN(ast_, result.ReleaseAst()); + + return absl::OkStatus(); +} + +bool IsTriviallyTrueCondition(const CompiledMatch& match) { + if (!match.condition.has_value() || match.condition->ast == nullptr) { + return true; + } + const cel::Expr& expr = match.condition->ast->root_expr(); + if (expr.has_const_expr()) { + const cel::Constant& const_expr = expr.const_expr(); + if (const_expr.has_bool_value() && const_expr.bool_value()) { + return true; + } + } + return false; +} + +bool IsExhaustive(const CompiledRule& rule); + +bool IsExhaustive(const CompiledMatch& match) { + if (std::holds_alternative(match.production)) { + return true; + } + + const auto* nested_rule_ptr = + std::get_if>(&match.production); + ABSL_DCHECK(nested_rule_ptr != nullptr); + const CompiledRule& nested_rule = **nested_rule_ptr; + return IsExhaustive(nested_rule); +} + +bool IsExhaustive(const CompiledRule& rule) { + if (rule.matches.empty()) { + // Validation should fail, but generalization would be false. + return false; + } + bool has_default = false; + for (const auto& match : rule.matches) { + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + // If this isn't the last match in the rule, it should get flagged + // during validation since it means there are trivially unreachable + // matches. + has_default = true; + } + if (!IsTriviallyTrueCondition(match) && !IsExhaustive(match)) { + // There is a nested rule that might return an optional.none(). + return false; + } + } + // Otherwise, everything in this branch is exhaustive so we can defer + // wrapping. + return has_default; +} + +bool FirstMatchComposer::CheckMatchStructure(const CompiledRule& rule) { + if (rule.matches.empty()) { + issues_.ReportError(rule.id, "rule does not specify match conditions"); + return false; + } + + bool valid = true; + bool seen_trivially_true = false; + + for (const auto& match : rule.matches) { + if (seen_trivially_true) { + if (std::holds_alternative(match.production)) { + issues_.ReportError(match.id, "match creates unreachable outputs"); + } else if (std::holds_alternative>( + match.production)) { + issues_.ReportError(match.id, "rule creates unreachable outputs"); + } + valid = false; + } + + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + seen_trivially_true = true; + } + + if (auto* nested_rule = + std::get_if>(&match.production); + nested_rule != nullptr) { + ABSL_DCHECK(*nested_rule != nullptr); + if (!CheckMatchStructure(**nested_rule)) { + valid = false; + } + } + } + + return valid; +} + +std::optional FirstMatchComposer::ResolvePolicyVariable( + absl::string_view reference) { + for (auto scope_iter = scopes_.rbegin(); scope_iter != scopes_.rend(); + ++scope_iter) { + if (auto it = scope_iter->find(reference); it != scope_iter->end()) { + return it->second; + } + } + return std::nullopt; +} + +class IndexRewrite : public AstRewriterBase { + public: + explicit IndexRewrite(absl::flat_hash_map expr_id_to_index, + OptimizerExprFactory& factory) + : expr_id_to_index_(std::move(expr_id_to_index)), factory_(factory) {} + + bool PreVisitRewrite(Expr& e) override { + if (auto it = expr_id_to_index_.find(e.id()); + it != expr_id_to_index_.end()) { + e.mutable_ident_expr().set_name(absl::StrCat("@index", it->second)); + factory_.RecordReplacement(e.id(), e); + return true; + } + return false; + } + + private: + absl::flat_hash_map expr_id_to_index_; + OptimizerExprFactory& factory_; +}; + +absl::StatusOr FirstMatchComposer::ComposeRule(const CompiledRule& rule, + Expr& init, + Expr& insertion_expr) { + scopes_.emplace_back(); + auto pop_scope = absl::MakeCleanup([this]() { scopes_.pop_back(); }); + ComposeRuleVariables(rule, init, insertion_expr); + Expr* insertion_point = &insertion_expr; + const bool has_default = IsTriviallyTrueCondition(rule.matches.back()); + const bool needs_wrap = !IsExhaustive(rule); + size_t end = rule.matches.size() - (has_default ? 1 : 0); + for (size_t i = 0; i < end; i++) { + const auto& match = rule.matches[i]; + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + return absl::InternalError("detected unreachable match after validation"); + } + + Expr production; + CEL_ASSIGN_OR_RETURN( + bool is_wrapped, + ComposeProduction(rule, match.production, init, production)); + if (needs_wrap && !is_wrapped) { + production = factory_.NewCall("optional.of", std::move(production)); + } + + if (!IsTriviallyTrueCondition(match)) { + Ast condition = *match.condition->ast; + MapVariables(condition); + factory_.StartCopyContext(); + auto copy = factory_.Copy(condition.root_expr()); + auto source_info = factory_.RemapSourceInfo(condition.source_info()); + factory_.MergeSourceInfo(source_info); + *insertion_point = factory_.NewCall("_?_:_", std::move(copy)); + insertion_point->mutable_call_expr().mutable_args().push_back( + std::move(production)); + ABSL_DCHECK(!(!needs_wrap && is_wrapped)) + << "unexpected wrapping in exhaustive policy."; + insertion_point = &insertion_point->mutable_call_expr().add_args(); + continue; + } + + if (!is_wrapped) { + return absl::InternalError( + "composition failed. expected optional wrapped rule but got a plain " + "value"); + } + auto fn = needs_wrap ? "or" : "orValue"; + *insertion_point = factory_.NewMemberCall(fn, std::move(production)); + insertion_point = &insertion_point->mutable_call_expr().add_args(); + } + + if (has_default) { + const auto& match = rule.matches.back(); + Expr production; + CEL_ASSIGN_OR_RETURN( + bool is_wrapped, + ComposeProduction(rule, match.production, init, production)); + if (needs_wrap && !is_wrapped) { + production = factory_.NewCall("optional.of", std::move(production)); + } + *insertion_point = std::move(production); + ABSL_DCHECK(!(!needs_wrap && is_wrapped)) + << "unexpected wrapping in exhaustive policy."; + + return needs_wrap; + } + + // Otherwise, we fell through a non-exhaustive rule. + *insertion_point = factory_.NewCall("optional.none"); + return true; +} + +absl::StatusOr FirstMatchComposer::ComposeProduction( + const CompiledRule& rule, const CompiledMatch::Production& production, + Expr& init, Expr& insertion_expr) { + if (auto* nested_rule = + std::get_if>(&production); + nested_rule != nullptr) { + return ComposeRule(**nested_rule, init, insertion_expr); + } + auto* output = std::get_if(&production); + if (output == nullptr) { + return absl::InternalError("unexpected rule production type"); + } + const EmbeddedAst& output_ast = output->output_ast; + Ast ast = *output_ast.ast; + MapVariables(ast); + factory_.StartCopyContext(); + Expr to_insert = factory_.Copy(ast.root_expr()); + auto source_info = factory_.RemapSourceInfo(ast.source_info()); + factory_.MergeSourceInfo(source_info); + insertion_expr = std::move(to_insert); + + return false; +} + +absl::flat_hash_map FirstMatchComposer::ResolveBlockIndexes( + const Ast& ast) { + absl::flat_hash_map out; + for (auto it = ast.reference_map().begin(); it != ast.reference_map().end(); + it++) { + const Reference& ref = it->second; + if (!it->second.overload_id().empty()) { + continue; + } + if (!absl::StartsWith(ref.name(), "variable")) { + continue; + } + if (auto index = ResolvePolicyVariable(ref.name()); index.has_value()) { + out[it->first] = *index; + } + } + return out; +} + +void FirstMatchComposer::MapVariables(Ast& ast) { + absl::flat_hash_map edit_map = ResolveBlockIndexes(ast); + IndexRewrite rewriter(std::move(edit_map), factory_); + AstRewrite(ast.mutable_root_expr(), rewriter); +} + +void FirstMatchComposer::ComposeRuleVariables(const CompiledRule& rule, + Expr& init, + Expr& insertion_expr) { + for (const auto& variable : rule.variables) { + Ast ast = *variable.ast.ast; + MapVariables(ast); + factory_.StartCopyContext(); + auto insertion = factory_.Copy(ast.root_expr()); + // TODO(b/506179116): apply the position offsets here. + auto info = factory_.RemapSourceInfo(ast.source_info()); + ABSL_DCHECK(init.has_list_expr()); + int index = init.mutable_list_expr().elements().size(); + init.mutable_list_expr().mutable_elements().push_back( + factory_.NewListElement(std::move(insertion))); + scopes_.back()[variable.ident] = index; + } +} + +bool HasComprehensionParent(const NavigableAstNode& node) { + const NavigableAstNode* curr = &node; + while (curr != nullptr) { + if (curr->node_kind() == NodeKind::kComprehension) { + return true; + } + curr = curr->parent(); + } + return false; +} + +// Unnester implementation. +class Unnester { + public: + Unnester(Ast ast, int height, policy_internal::IssueReporter& issues) + : factory_(std::move(ast)), height_(height), issues_(issues) {} + + // Run the unnesting. + // The class cannot be reused after this is called. + absl::StatusOr Unnest() { + if (height_ > 0) { + CEL_RETURN_IF_ERROR(Slice()); + } + CEL_RETURN_IF_ERROR(Cleanup()); + return std::move(factory_.mutable_ast()); + } + + private: + // The core unnest routine. + absl::Status Slice(); + // Fixup the AST post-unnesting. + absl::Status Cleanup(); + + void ReportErrorAtId(int64_t id, absl::string_view message); + + OptimizerExprFactory factory_; + int height_; + policy_internal::IssueReporter& issues_; +}; + +class UnnestRewriter : public AstRewriterBase { + public: + explicit UnnestRewriter(OptimizerExprFactory& f, Expr& block_list_expr, + absl::Span cuts) + : factory_(f), cuts_(cuts), block_list_expr_(block_list_expr) {} + + bool PostVisitRewrite(Expr& expr) override { + using std::swap; + // Post order so we always see children before parents. + // No need to copy metadata since we're only moving exprs or minting + // new ones. + if (absl::c_contains(cuts_, expr.id())) { + size_t idx = block_list_expr_.list_expr().elements().size(); + Expr value = factory_.NewIdent(absl::StrCat("@index", idx)); + factory_.RecordReplacement(expr.id(), value, /*keep_metadata=*/true); + swap(value, expr); + block_list_expr_.mutable_list_expr().mutable_elements().push_back( + factory_.NewListElement(std::move(value))); + return true; + } + return false; + } + + private: + OptimizerExprFactory& factory_; + absl::Span cuts_; + Expr& block_list_expr_; +}; + +absl::Status Unnester::Slice() { + Expr& root = factory_.mutable_ast().mutable_root_expr(); + if (root.call_expr().function() != kCelBlock || + root.call_expr().args().size() != 2 || + !root.call_expr().args()[0].has_list_expr()) { + return absl::InternalError("malformed AST detected during unnesting"); + } + // Two passes, we identify the slice points (bottom up), then cut + // and paste the leaves into the block list. + NavigableAst nav_ast = NavigableAst::Build(factory_.ast().root_expr()); + + ABSL_DCHECK(nav_ast.IdsAreUnique()); + bool can_cut = true; + std::vector cuts; + for (const NavigableAstNode& node : nav_ast.Root().DescendantsPostorder()) { + // Subsequent cuts will be height_ + 1 in the block, indices. Within the + // error margin we specified. + if (node.height() % height_ == 0) { + if (HasComprehensionParent(node)) { + ReportErrorAtId( + node.expr()->id(), + absl::StrCat( + "cannot unnest AST due to comprehension. cannot accommodate " + "height limit of ", + height_)); + can_cut = false; + continue; + } + if (&node == &nav_ast.Root()) { + // If evenly divisible by height, don't cut since it will net a taller + // AST. + continue; + } + cuts.push_back(node.expr()->id()); + } + } + + if (!can_cut || cuts.empty()) { + return absl::OkStatus(); + } + + Expr& block_list_expr = root.mutable_call_expr().mutable_args()[0]; + Expr& insertion_expr = root.mutable_call_expr().mutable_args()[1]; + + UnnestRewriter rewriter(factory_, block_list_expr, cuts); + AstRewrite(insertion_expr, rewriter); + + return absl::OkStatus(); +} + +absl::Status Unnester::Cleanup() { + using std::swap; + + const auto& ast = factory_.ast(); + if (ast.root_expr().call_expr().function() != kCelBlock || + ast.root_expr().call_expr().args().size() != 2 || + !ast.root_expr().call_expr().args()[0].has_list_expr()) { + return absl::InternalError("malformed AST detected during unnesting"); + } + if (ast.root_expr().call_expr().args()[0].list_expr().elements().empty()) { + Expr value = std::move(factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[1]); + factory_.mutable_ast().mutable_root_expr() = std::move(value); + } + + return absl::OkStatus(); +} + +void Unnester::ReportErrorAtId(int64_t id, absl::string_view message) { + int32_t position = 0; + auto it = factory_.ast().source_info().positions().find(id); + if (it != factory_.ast().source_info().positions().end()) { + position = it->second; + } + issues_.ReportError(-1, position, message); +} +} // namespace + +// Compiles a CEL policy using the provided CEL compiler as a base environment. +absl::StatusOr CompilePolicy( + const Compiler& compiler, const CelPolicy& policy, + const CompilePolicyOptions& options) { + policy_internal::IssueReporter issues; + if (options.unnesting_height_limit != 0 && + options.unnesting_height_limit < 2) { + return absl::InvalidArgumentError( + "unnesting_height_limit must be at least 2"); + } + auto builder = compiler.ToBuilder(); + ExpressionContainer cont; + for (const auto& import : policy.imports()) { + auto status = cont.AddAbbreviation(import.name().value()); + if (!status.ok()) { + issues.ReportError( + import.name().id(), + absl::StrCat("'", import.name().value(), "': ", status.message())); + } + } + + builder->GetCheckerBuilder().SetExpressionContainer(cont); + CEL_ASSIGN_OR_RETURN(auto base_compiler, builder->Build()); + + PolicyCompiler policy_compiler(&issues, std::move(base_compiler)); + + IntermediateCompiledPolicy icp; + CEL_RETURN_IF_ERROR(policy_compiler.CompilePolicy(policy, &icp)); + + if (HasErrors(issues)) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + CEL_ASSIGN_OR_RETURN(base_compiler, builder->Build()); + switch (icp.semantics()) { + case RuleSemantics::kFirstMatch: { + FirstMatchComposer composer(icp, *base_compiler, issues); + CEL_RETURN_IF_ERROR(composer.Compose()); + if (!composer.success()) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + auto ast = composer.ReleaseAst(); + Unnester unnester(std::move(*ast), options.unnesting_height_limit, + issues); + CEL_ASSIGN_OR_RETURN(Ast unnested_ast, unnester.Unnest()); + + if (HasErrors(issues)) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + return CelPolicyValidationResult( + std::make_unique(std::move(unnested_ast)), {}, + policy.source_ptr()); + } + default: + return absl::UnimplementedError( + absl::StrCat("Unsupported RuleSemantics: ", icp.semantics())); + } +} + +} // namespace cel diff --git a/policy/compiler.h b/policy/compiler.h new file mode 100644 index 000000000..0187bd1a2 --- /dev/null +++ b/policy/compiler.h @@ -0,0 +1,50 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ + +#include "absl/status/statusor.h" +#include "compiler/compiler.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_validation_result.h" + +namespace cel { + +struct CompilePolicyOptions { + // If greater than 0, the compiler will attempt to unnest rule branches + // at the specified height. The overall height of the final AST may exceed + // this by a small, fixed margin. + // + // To avoid slicing comprehensions, subexpressions within comprehensions + // are not eligible for unnesting. If the height limit cannot be accommodated, + // an error with code InvalidArgument is returned. + // + // If the AST is converted to proto, even relatively low levels of nesting + // can cause problems in serialization/deserialization. This does not apply + // if the AST is used directly by the runtime. + int unnesting_height_limit = 0; +}; + +// Compiles a CEL policy using the provided CEL compiler as a base environment. +// +// TODO(b/506179116): Implementation in progress. Functionally complete, +// but errors are not consistent with other implementations. +absl::StatusOr CompilePolicy( + const Compiler& compiler, const CelPolicy& policy, + const CompilePolicyOptions& options = {}); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ diff --git a/policy/compiler_test.cc b/policy/compiler_test.cc new file mode 100644 index 000000000..8db494b45 --- /dev/null +++ b/policy/compiler_test.cc @@ -0,0 +1,946 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/compiler.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/decl.h" +#include "common/navigable_ast.h" +#include "common/source.h" +#include "common/type.h" +#include "common/types/message_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/optional.h" +#include "compiler/standard_library.h" +#include "extensions/bindings_ext.h" +#include "internal/runfiles.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/yaml_policy_parser.h" +#include "runtime/activation.h" +#include "runtime/optional_types.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::test::IntValueIs; +using ::cel::test::OptionalValueIs; +using ::cel::test::OptionalValueIsEmpty; +using ::cel::test::StringValueIs; +using ::cel::test::ValueMatcher; + +constexpr absl::string_view kTestPolicyFilePath = +"_main/policy/testdata/cel_policy.yaml"; + +absl::StatusOr> BuildTestCompiler() { + CompilerOptions opts; + opts.adapt_parser_errors = true; + opts.parser_options.enable_optional_syntax = true; + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool(), opts)); + + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCompilerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCompilerLibrary())); + + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("x", IntType()))); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("y", IntType()))); + + const google::protobuf::Descriptor* descriptor = + cel::internal::GetSharedTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes"); + if (descriptor == nullptr) { + return absl::InternalError("Failed to find TestAllTypes descriptor"); + } + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("spec", cel::MessageType(descriptor)))); + + return builder->Build(); +} + +absl::StatusOr> ParsePolicyFromYaml( + absl::string_view yaml_content) { + CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(yaml_content, "test.yaml")); + + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + CEL_ASSIGN_OR_RETURN(auto parse_result, + cel::ParseYamlCelPolicy(policy_source)); + + if (!parse_result.IsValid()) { + return absl::InvalidArgumentError("Invalid policy YAML structure"); + } + return parse_result.ReleasePolicy(); +} + +TEST(CompilerTest, SmokeTest) { + std::string contents; + std::string test_file = + cel::internal::ResolveRunfilesPath(kTestPolicyFilePath); + auto read_status = cel::internal::GetFileContents(test_file, &contents); + ASSERT_THAT(read_status, IsOk()); + + auto source_or = cel::NewSource(contents, "cel_policy.yaml"); + ASSERT_THAT(source_or.status(), IsOk()); + auto source = *std::move(source_or); + + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + auto parse_result_or = cel::ParseYamlCelPolicy(policy_source); + ASSERT_THAT(parse_result_or.status(), IsOk()); + auto parse_result = *std::move(parse_result_or); + + ASSERT_TRUE(parse_result.IsValid()); + const CelPolicy* policy = parse_result.GetPolicy(); + + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); +} + +TEST(CompilerTest, VariableOutOfScopeReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: variables.non_existent == 10 + output: '"error"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("undeclared reference")); +} + +TEST(CompilerTest, ConditionNotBoolReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: 10 + output: '"error"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("condition must evaluate to bool")); +} + +TEST(CompilerTest, InvalidOutputExpressionReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: undeclared_var +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("undeclared reference")); +} + +TEST(CompilerTest, UnreachableMatchAfterTriviallyTrueCondition) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"first"' + - condition: true + output: '"second"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match creates unreachable outputs")); +} + +TEST(CompilerTest, UnreachableMatchAfterUnconditionalExhaustiveSubRule) { + absl::string_view yaml = R"yaml( +name: dead_branch +rule: + match: + - rule: + match: + - output: 1 + - output: 2 +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match creates unreachable outputs")); +} + +TEST(CompilerTest, RuleWithoutMatchesReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("rule does not specify match conditions")); +} + +TEST(CompilerTest, ExhaustivePolicyCompiles) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + variables: + - name: test_var + expression: 10 + match: + - condition: variables.test_var > 15 + output: '"greater than 15"' + - condition: variables.test_var > 5 + output: '"greater than 5"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); + EXPECT_TRUE(result.GetAst()->is_checked()); +} + +TEST(CompilerTest, NonExhaustivePolicyCompiles) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + variables: + - name: test_var + expression: 10 + match: + - condition: variables.test_var > 5 + output: '"greater than 5"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); +} + +TEST(CompilerTest, PolicyReferencesEnvInput) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: spec.single_int32 > 10 + output: '"greater than 10"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); + EXPECT_TRUE(result.GetAst()->is_checked()); +} + +struct EvaluationTestCase { + std::string name; + std::string yaml_policy; + struct Input { + int64_t x; + int64_t y; + } input; + ValueMatcher expected_result_matcher; +}; + +class PolicyEvaluationTest : public testing::TestWithParam { +}; + +TEST_P(PolicyEvaluationTest, Evaluate) { + const auto& test_case = GetParam(); + + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(test_case.yaml_policy)); + ASSERT_OK_AND_ASSIGN(auto validation_result, + CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(validation_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, validation_result.ReleaseAst()); + + // Set up runtime + cel::RuntimeOptions opts; + opts.enable_qualified_type_identifiers = true; + ASSERT_OK_AND_ASSIGN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + ASSERT_THAT(cel::extensions::EnableOptionalTypes(rt_builder), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + std::move(rt_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + // Set up activation + cel::Activation activation; + activation.InsertOrAssignValue("x", cel::IntValue(test_case.input.x)); + activation.InsertOrAssignValue("y", cel::IntValue(test_case.input.y)); + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(cel::Value result, + program->Evaluate(&arena, activation)); + + EXPECT_THAT(result, test_case.expected_result_matcher); +} + +constexpr absl::string_view kEvalPolicyYaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: x > 10 && y > 10 + output: '"both greater than 10"' + - condition: x > 10 + output: '"x greater than 10"' + - condition: y > 10 + output: '"y greater than 10"' + - output: '"default"' +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + PolicyEvaluationTest, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "BothGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 15, .y = 15}, + .expected_result_matcher = StringValueIs("both greater than 10"), + }, + EvaluationTestCase{ + .name = "XGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 15, .y = 5}, + .expected_result_matcher = StringValueIs("x greater than 10"), + }, + EvaluationTestCase{ + .name = "YGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 5, .y = 15}, + .expected_result_matcher = StringValueIs("y greater than 10"), + }, + EvaluationTestCase{ + .name = "Default", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 5, .y = 5}, + .expected_result_matcher = StringValueIs("default"), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kNonExhaustivePolicyYaml = R"yaml( +name: nested_rule4 +rule: + match: + - condition: x > 0 + rule: + match: + - condition: x < 3 + output: 1 + - condition: x < 5 + output: 2 + - condition: x < 0 + rule: + match: + - condition: x > -2 + output: 3 + - condition: x > -4 + output: 4 + - output: 5 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + NonExhaustivePolicyEvaluation, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals0_FallthroughTopLevel", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }, + EvaluationTestCase{ + .name = "XEquals2_MatchesFirstNested", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 2, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEquals6_FallthroughNested", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 6, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1_MatchesMinus2", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(3)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus3_MatchesMinus4", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -3, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(4)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus5_MatchesDefault", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -5, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(5)), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kNestedVariablePolicyYaml = R"yaml( +name: nested_rule4 +rule: + variables: + - name: i + expression: "1" + - name: j + expression: "2" + match: + - condition: x > 0 + rule: + variables: + - name: k + expression: "3" + match: + - output: "variables.i + variables.j + variables.k" + - condition: x < 0 + rule: + variables: + - name: j + expression: "5" + - name: k + expression: "4" + match: + - output: "variables.i + variables.j + variables.k" + - output: "variables.i + variables.j" +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + NestedVariablePolicyEvaluation, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XGreaterThan0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(6), + }, + EvaluationTestCase{ + .name = "XLessThan0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = IntValueIs(10), + }, + EvaluationTestCase{ + .name = "XEquals0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = IntValueIs(3), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view + kOptionalChainingUnconditionalSubRuleOptionalParentYaml = R"yaml( +name: optional_chaining +rule: + match: + - rule: + id: r2 + match: + - condition: x > 0 + output: 1 + - output: 2 + condition: x < 0 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalSubRuleOptionalParent, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = std::string( + kOptionalChainingUnconditionalSubRuleOptionalParentYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = std::string( + kOptionalChainingUnconditionalSubRuleOptionalParentYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(2)), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kOptionalChainingUnconditionalSubRuleYaml = R"yaml( +name: optional_chaining +rule: + id: r1 + match: + - rule: + id: r2 + match: + - condition: x > 0 + output: 1 + - output: 2 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalSubRule, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalSubRuleYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(1), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalSubRuleYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = IntValueIs(2), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kOptionalChainingUnconditionalComplexYaml = R"yaml( +name: optional_chaining +rule: + match: + - condition: x > 0 + rule: + match: + - rule: + match: + - condition: x == 1 + output: 1 + - output: 2 + - rule: + match: + - condition: x == -1 + output: 3 + - condition: x == -2 + output: 4 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalComplex, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(3)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus2", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -2, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(4)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus3", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -3, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kUnconditionalExhaustiveSubRuleAsLastMatchYaml = + R"yaml( +name: exhaustive_unconditional_subrule +rule: + match: + - condition: x > 0 + output: 1 + - rule: + match: + - condition: y > 0 + output: 2 + - output: 3 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + UnconditionalExhaustiveSubRuleAsLastMatch, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(1), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals1", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 1}, + .expected_result_matcher = IntValueIs(2), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals0", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = IntValueIs(3), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml = + R"yaml( +name: non_exhaustive_unconditional_subrule +rule: + match: + - condition: x > 0 + output: 1 + - rule: + match: + - condition: y > 0 + output: 2 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + UnconditionalNonExhaustiveSubRuleAsLastMatch, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals1", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 1}, + .expected_result_matcher = OptionalValueIs(IntValueIs(2)), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals0", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +TEST(CompilerTest, ImportsAndAbbreviations) { + absl::string_view yaml = R"yaml( +name: imports_test +imports: + - name: cel.expr.conformance.proto3.TestAllTypes +rule: + match: + - condition: 'spec == TestAllTypes{single_int32: 10}' + output: '"matched"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + auto ast_or = CompilePolicy(*compiler, *policy); + ASSERT_THAT(ast_or, IsOk()); +} + +TEST(CompilerTest, MatchWithoutProductionReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match must specify an output or rule")); +} + +int GetAstHeight(const cel::Ast& ast) { + auto nav_ast = cel::NavigableAst::Build(ast.root_expr()); + return nav_ast.Root().height(); +} + +TEST(CompilerTest, UnnestHeightValidation) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"ok"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 1; + auto status_or = CompilePolicy(*compiler, *policy, options); + EXPECT_THAT(status_or.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + testing::HasSubstr( + "unnesting_height_limit must be at least 2"))); + + options.unnesting_height_limit = 2; + EXPECT_THAT(CompilePolicy(*compiler, *policy, options), IsOk()); +} + +constexpr absl::string_view kDeepPolicyYaml = R"yaml( +name: deep_policy +rule: + match: + - condition: x > 0 + rule: + match: + - condition: x > 1 + rule: + match: + - condition: x > 2 + rule: + match: + - condition: x > 3 + rule: + match: + - condition: x > 4 + rule: + match: + - condition: x > 5 + output: 6 + - output: 5 + - output: 4 + - output: 3 + - output: 2 + - output: 1 + - output: 0 +)yaml"; + +TEST(CompilerTest, UnnestHeightReduction) { + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(kDeepPolicyYaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + // Compile without unnesting + CompilePolicyOptions options_no_unnest; + options_no_unnest.unnesting_height_limit = 0; + ASSERT_OK_AND_ASSIGN(auto result_no_unnest, + CompilePolicy(*compiler, *policy, options_no_unnest)); + ASSERT_TRUE(result_no_unnest.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast_no_unnest, result_no_unnest.ReleaseAst()); + int height_no_unnest = GetAstHeight(*ast_no_unnest); + + CompilePolicyOptions options_unnest; + options_unnest.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result_unnest, + CompilePolicy(*compiler, *policy, options_unnest)); + ASSERT_TRUE(result_unnest.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast_unnest, result_unnest.ReleaseAst()); + int height_unnest = GetAstHeight(*ast_unnest); + + EXPECT_EQ(height_no_unnest, 8); + EXPECT_EQ(height_unnest, 5); + EXPECT_LT(height_unnest, height_no_unnest); +} + +TEST(CompilerTest, UnnestComprehensionFailure) { + absl::string_view yaml = R"yaml( +name: comprehension_policy +rule: + match: + - condition: x > 0 + rule: + match: + - condition: "[1, 2].all(i, i > x)" + output: 1 + - output: 2 + - output: 0 +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("cannot unnest AST due to comprehension")); +} + +struct UnnestEvaluationTestCase { + std::string name; + int64_t x; + ValueMatcher expected; +}; + +class UnnestedDeepPolicyEvaluationTest + : public testing::TestWithParam {}; + +TEST_P(UnnestedDeepPolicyEvaluationTest, Evaluate) { + const auto& tc = GetParam(); + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(kDeepPolicyYaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + + // Set up runtime + cel::RuntimeOptions opts; + opts.enable_qualified_type_identifiers = true; + ASSERT_OK_AND_ASSIGN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + ASSERT_THAT(cel::extensions::EnableOptionalTypes(rt_builder), IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + std::move(rt_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + cel::Activation activation; + activation.InsertOrAssignValue("x", cel::IntValue(tc.x)); + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(cel::Value res, program->Evaluate(&arena, activation)); + + EXPECT_THAT(res, tc.expected); +} + +INSTANTIATE_TEST_SUITE_P( + UnnestedDeepPolicyEvaluation, UnnestedDeepPolicyEvaluationTest, + testing::Values(UnnestEvaluationTestCase{"XEquals6", 6, IntValueIs(6)}, + UnnestEvaluationTestCase{"XEquals5", 5, IntValueIs(5)}, + UnnestEvaluationTestCase{"XEquals4", 4, IntValueIs(4)}, + UnnestEvaluationTestCase{"XEquals3", 3, IntValueIs(3)}, + UnnestEvaluationTestCase{"XEquals2", 2, IntValueIs(2)}, + UnnestEvaluationTestCase{"XEquals1", 1, IntValueIs(1)}, + UnnestEvaluationTestCase{"XEquals0", 0, IntValueIs(0)}, + UnnestEvaluationTestCase{"XEqualsMinus1", -1, + IntValueIs(0)}), + [](const testing::TestParamInfo< + UnnestedDeepPolicyEvaluationTest::ParamType>& info) { + return info.param.name; + }); + +TEST(CompilerTest, UnnestCleanupRunsWhenDisabled) { + // A policy without variables and without nesting. + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"ok"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 0; // Disabled + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + + // If cleanup ran, it should have optimized away the trivial `cel.@block`. + // So the root expression should NOT be a call to `cel.@block`. + // It should be just the constant `"ok"`. + auto nav_ast = cel::NavigableAst::Build(ast->root_expr()); + EXPECT_FALSE(nav_ast.Root().expr()->has_call_expr() && + nav_ast.Root().expr()->call_expr().function() == "cel.@block"); + EXPECT_TRUE(nav_ast.Root().expr()->has_const_expr()); +} +} // namespace +} // namespace cel diff --git a/policy/internal/BUILD b/policy/internal/BUILD new file mode 100644 index 000000000..30f43d431 --- /dev/null +++ b/policy/internal/BUILD @@ -0,0 +1,68 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "issue_reporter", + srcs = ["issue_reporter.cc"], + hdrs = ["issue_reporter.h"], + deps = [ + "//common:source", + "//policy:cel_policy", + "//policy:cel_policy_parser", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "optimizer_expr_factory", + srcs = ["optimizer_expr_factory.cc"], + hdrs = ["optimizer_expr_factory.h"], + deps = [ + "//common:ast", + "//common:ast_rewrite", + "//common:ast_traverse", + "//common:ast_visitor_base", + "//common:constant", + "//common:expr", + "//common:expr_factory", + "//common:source", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "optimizer_expr_factory_test", + srcs = ["optimizer_expr_factory_test.cc"], + deps = [ + ":optimizer_expr_factory", + "//common:ast", + "//common:ast_proto", + "//common:ast_rewrite", + "//common:decl", + "//common:expr", + "//common:expr_factory", + "//common:source", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//testutil:expr_printer", + "//tools:cel_unparser", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) diff --git a/policy/internal/issue_reporter.cc b/policy/internal/issue_reporter.cc new file mode 100644 index 000000000..944e687d6 --- /dev/null +++ b/policy/internal/issue_reporter.cc @@ -0,0 +1,45 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/internal/issue_reporter.h" + +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel::policy_internal { + +void IssueReporter::ReportIssue(CelPolicyElementId element, Severity severity, + absl::string_view message) { + issues_.push_back({element, severity, message}); +} + +void IssueReporter::ReportOffsetIssue(CelPolicyElementId element, + cel::SourcePosition relative_position, + Severity severity, + absl::string_view message) { + issues_.push_back({element, relative_position, severity, message}); +} + +void IssueReporter::ReportError(CelPolicyElementId element, + absl::string_view message) { + ReportIssue(element, Severity::kError, message); +} + +void IssueReporter::ReportError(CelPolicyElementId element, SourcePosition pos, + absl::string_view message) { + ReportOffsetIssue(element, pos, Severity::kError, message); +} + +} // namespace cel::policy_internal diff --git a/policy/internal/issue_reporter.h b/policy/internal/issue_reporter.h new file mode 100644 index 000000000..3f88806ef --- /dev/null +++ b/policy/internal/issue_reporter.h @@ -0,0 +1,57 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel::policy_internal { + +class IssueReporter { + private: + using Severity = CelPolicyIssue::Severity; + + public: + void ReportIssue(CelPolicyElementId element, Severity severity, + absl::string_view message); + + void ReportOffsetIssue(CelPolicyElementId element, + cel::SourcePosition relative_position, + Severity severity, absl::string_view message); + + void ReportError(CelPolicyElementId element, absl::string_view message); + void ReportError(CelPolicyElementId element, SourcePosition relative_pos, + absl::string_view message); + + std::vector ReleaseIssues() { + using std::swap; + std::vector out; + swap(out, issues_); + return out; + } + const std::vector& issues() const { return issues_; } + + private: + std::vector issues_; +}; + +} // namespace cel::policy_internal + +#endif // THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ diff --git a/policy/internal/optimizer_expr_factory.cc b/policy/internal/optimizer_expr_factory.cc new file mode 100644 index 000000000..6c89ae958 --- /dev/null +++ b/policy/internal/optimizer_expr_factory.cc @@ -0,0 +1,373 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/internal/optimizer_expr_factory.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/ast.h" +#include "common/ast_rewrite.h" +#include "common/ast_traverse.h" +#include "common/ast_visitor_base.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" + +namespace cel { + +namespace { + +class MaxIdVisitor final : public AstVisitorBase { + public: + ExprId max_id() const { return max_id_; } + + void PreVisitExpr(const Expr& expr) override { + max_id_ = std::max(max_id_, expr.id()); + } + + void PostVisitExpr(const Expr&) override {} + + void PostVisitStruct(const Expr&, const StructExpr& struct_expr) override { + for (const auto& field : struct_expr.fields()) { + max_id_ = std::max(max_id_, field.id()); + } + } + + void PostVisitMap(const Expr&, const MapExpr& map_expr) override { + for (const auto& entry : map_expr.entries()) { + max_id_ = std::max(max_id_, entry.id()); + } + } + + private: + ExprId max_id_ = 0; +}; + +ExprId GetMaxId(const Expr& expr) { + MaxIdVisitor visitor; + AstTraverse(expr, visitor); + return visitor.max_id(); +} + +ExprId GetMaxId(const Ast& ast) { + ExprId max_id = GetMaxId(ast.root_expr()); + for (const auto& [id, _] : ast.source_info().positions()) { + max_id = std::max(max_id, id); + } + for (const auto& [id, expr] : ast.source_info().macro_calls()) { + max_id = std::max(max_id, id); + max_id = std::max(max_id, GetMaxId(expr)); + } + return max_id; +} + +// Replaces nested macros in a macro_calls expr with reference nodes. +// +// The macro_calls map is used for retaining the original structure of the +// parsed expression before macro expansion. When a macro appears inside another +// macro, the parser will replace the inner macro expr node with an unspecified +// expr with the inner macro's ID in the macro_calls map to save space. +class MakeMacroCallRewrite final : public AstRewriterBase { + public: + explicit MakeMacroCallRewrite(const SourceInfo& source_info) + : source_info_(source_info) {} + + bool PreVisitRewrite(Expr& expr) override { + if (source_info_.macro_calls().find(expr.id()) != + source_info_.macro_calls().end()) { + ExprId id = expr.id(); + expr.mutable_kind() = UnspecifiedExpr(); + expr.set_id(id); + return true; + } + return false; + } + + private: + const SourceInfo& source_info_; +}; + +// Updates macro_calls map entries to reflect a replaced expression in the +// main AST. +class ReplaceMacroCallRewrite final : public AstRewriterBase { + public: + ReplaceMacroCallRewrite(ExprId old_id, const Expr& replacement, + const SourceInfo& source_info) + : old_id_(old_id), replacement_(replacement), source_info_(source_info) {} + + bool PreVisitRewrite(Expr& expr) override { + if (expr.id() == old_id_) { + expr = macro_replacement(); + return true; + } + return false; + } + + Expr macro_replacement() { + if (!macro_replacement_) { + macro_replacement_.emplace(replacement_); + MakeMacroCallRewrite hole_creator(source_info_); + AstRewrite(*macro_replacement_, hole_creator); + } + return *macro_replacement_; + } + + private: + ExprId old_id_; + const Expr& replacement_; + absl::optional macro_replacement_; + const SourceInfo& source_info_; +}; + +void ReplaceSubExpr(Expr& expr, ExprId old_id, const Expr& replacement, + const SourceInfo& source_info) { + ReplaceMacroCallRewrite rewriter(old_id, replacement, source_info); + AstRewrite(expr, rewriter); +} + +class IdRewriter : public AstRewriterBase { + using CopyIdFn = absl::AnyInvocable; + + public: + explicit IdRewriter(CopyIdFn copy_id) : copy_id_(std::move(copy_id)) {} + + // No structure changes just ids. + bool PreVisitRewrite(Expr& expr) override { + expr.set_id(copy_id_(expr.id())); + if (expr.has_struct_expr()) { + for (auto& field : expr.mutable_struct_expr().mutable_fields()) { + field.set_id(copy_id_(field.id())); + } + } else if (expr.has_map_expr()) { + for (auto& entry : expr.mutable_map_expr().mutable_entries()) { + entry.set_id(copy_id_(entry.id())); + } + } + return false; + } + + private: + CopyIdFn copy_id_; +}; + +} // namespace + +OptimizerExprFactory::OptimizerExprFactory(Ast basis) + : ast_(std::move(basis)), next_id_(GetMaxId(ast_) + 1) {} + +OptimizerExprFactory::OptimizerExprFactory() : next_id_(1) {} + +Expr OptimizerExprFactory::Copy(const Expr& expr) { + Expr copied = expr; + IdRewriter rewriter([this](ExprId id) { return CopyId(id); }); + AstRewrite(copied, rewriter); + return copied; +} + +ListExprElement OptimizerExprFactory::Copy(const ListExprElement& element) { + return NewListElement(Copy(element.expr()), element.optional()); +} + +StructExprField OptimizerExprFactory::Copy(const StructExprField& field) { + auto field_id = CopyId(field.id()); + auto field_value = Copy(field.value()); + return NewStructField(field_id, field.name(), std::move(field_value), + field.optional()); +} + +MapExprEntry OptimizerExprFactory::Copy(const MapExprEntry& entry) { + auto entry_id = CopyId(entry.id()); + auto entry_key = Copy(entry.key()); + auto entry_value = Copy(entry.value()); + return NewMapEntry(entry_id, std::move(entry_key), std::move(entry_value), + entry.optional()); +} + +ExprId OptimizerExprFactory::NextId() { return next_id_++; } + +ExprId OptimizerExprFactory::CopyId(ExprId id) { + if (id == 0) { + return 0; + } + auto it = renumbers_.find(id); + if (it != renumbers_.end()) { + return it->second; + } + ExprId new_id = NextId(); + renumbers_[id] = new_id; + return new_id; +} + +SourceInfo OptimizerExprFactory::RemapSourceInfo(const SourceInfo& info, + SourcePosition offset) { + SourceInfo out; + + for (const auto& [old_id, macro_expr] : info.macro_calls()) { + if (auto it = renumbers_.find(old_id); it != renumbers_.end()) { + ExprId new_id = it->second; + out.mutable_macro_calls()[new_id] = Copy(macro_expr); + } + } + + for (const auto& [old_id, new_id] : renumbers_) { + if (auto it = info.positions().find(old_id); it != info.positions().end()) { + out.mutable_positions()[new_id] = it->second + offset; + } + } + + return out; +} + +void OptimizerExprFactory::MergeSourceInfo(const SourceInfo& info) { + auto& target_info = ast_.mutable_source_info(); + + for (const auto& [id, pos] : info.positions()) { + auto [it, inserted] = target_info.mutable_positions().insert({id, pos}); + if (!inserted) { + issues_.push_back(Issue{id, "conflicting ID in positions merge"}); + } + } + + for (const auto& [id, expr] : info.macro_calls()) { + auto [it, inserted] = target_info.mutable_macro_calls().insert({id, expr}); + if (!inserted) { + issues_.push_back(Issue{id, "conflicting ID in macro calls merge"}); + } + } + + // TODO(b/506179116): need to add some check that we aren't + // introducing incompatible tags. Not possible in the policy compiler right + // now. + for (const auto& ext : info.extensions()) { + auto& target_exts = target_info.mutable_extensions(); + if (!absl::c_linear_search(target_exts, ext)) { + target_exts.push_back(ext); + } + } +} + +void OptimizerExprFactory::RecordReplacement(ExprId id, const Expr& replacement, + bool keep_metadata) { + auto& source_info = ast_.mutable_source_info(); + if (!keep_metadata) { + source_info.mutable_positions().erase(id); + source_info.mutable_macro_calls().erase(id); + } + + for (auto& [macro_id, macro_expr] : source_info.mutable_macro_calls()) { + ReplaceSubExpr(macro_expr, id, replacement, source_info); + } +} + +Expr OptimizerExprFactory::ReportError(absl::string_view message) { + ExprId id = NextId(); + issues_.push_back(Issue{id, std::string(message)}); + return NewUnspecified(id); +} + +Expr OptimizerExprFactory::ReportErrorAt(const Expr& expr, + absl::string_view message) { + issues_.push_back(Issue{expr.id(), std::string(message)}); + return NewUnspecified(NextId()); +} + +Expr OptimizerExprFactory::ReportErrorAtCopy(const Expr& expr, + absl::string_view message) { + issues_.push_back(Issue{CopyId(expr.id()), std::string(message)}); + return NewUnspecified(NextId()); +} + +Expr OptimizerExprFactory::NewUnspecified() { return NewUnspecified(NextId()); } + +Expr OptimizerExprFactory::NewNullConst() { return NewNullConst(NextId()); } + +Expr OptimizerExprFactory::NewBoolConst(bool value) { + return NewBoolConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewIntConst(int64_t value) { + return NewIntConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewUintConst(uint64_t value) { + return NewUintConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewDoubleConst(double value) { + return NewDoubleConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewBytesConst(std::string value) { + return NewBytesConst(NextId(), std::move(value)); +} + +Expr OptimizerExprFactory::NewBytesConst(absl::string_view value) { + return NewBytesConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewBytesConst(const char* value) { + return NewBytesConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewStringConst(std::string value) { + return NewStringConst(NextId(), std::move(value)); +} + +Expr OptimizerExprFactory::NewStringConst(absl::string_view value) { + return NewStringConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewStringConst(const char* value) { + return NewStringConst(NextId(), value); +} + +absl::flat_hash_map OptimizerExprFactory::ConsumeRenumbers() { + using std::swap; + absl::flat_hash_map out; + swap(out, renumbers_); + return out; +} + +void OptimizerExprFactory::StartCopyContext() { renumbers_.clear(); } + +const std::vector& OptimizerExprFactory::issues() + const { + return issues_; +} + +const Ast& OptimizerExprFactory::ast() const { return ast_; } + +Ast& OptimizerExprFactory::mutable_ast() { return ast_; } + +absl::string_view OptimizerExprFactory::AccuVarName() { + return ExprFactory::AccuVarName(); +} + +Expr OptimizerExprFactory::NewAccuIdent() { return NewAccuIdent(NextId()); } + +ExprId OptimizerExprFactory::CopyId(const Expr& expr) { + return CopyId(expr.id()); +} + +} // namespace cel diff --git a/policy/internal/optimizer_expr_factory.h b/policy/internal/optimizer_expr_factory.h new file mode 100644 index 000000000..6f63f1485 --- /dev/null +++ b/policy/internal/optimizer_expr_factory.h @@ -0,0 +1,419 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" + +namespace cel { + +class ParserMacroExprFactory; +class TestOptimizerExprFactory; + +// `OptimizerExprFactory` is a specialization of `ExprFactory` used for AST +// optimization. It provides utilities for correcting metadata for modified +// ASTs. +class OptimizerExprFactory : protected ExprFactory { + public: + struct Issue { + ExprId location = 0; + std::string message; + }; + + explicit OptimizerExprFactory(Ast basis); + OptimizerExprFactory(); + + protected: + using ExprFactory::IsArrayLike; + using ExprFactory::IsExprLike; + using ExprFactory::IsStringLike; + + template + struct IsRValue + : std::bool_constant< + std::disjunction_v, std::is_same>> {}; + + public: + // Consume the current set of renumberings. + absl::flat_hash_map ConsumeRenumbers(); + + // Starts a new copy context. The current set of renumberings are cleared. + void StartCopyContext(); + + const std::vector& issues() const; + + // Record that a node in the working AST was replaced. This is used to correct + // metadata referencing the old ID. + void RecordReplacement(ExprId id, const Expr& replacement, + bool keep_metadata = false); + + // Makes a copy of source metadata that is remapped to new expr Ids using + // current renumberings. This is suitable for merging into the main source + // info. + SourceInfo RemapSourceInfo(const SourceInfo& info, SourcePosition offset = 0); + + // Merge a remapped SourceInfo into the current one. + void MergeSourceInfo(const SourceInfo& info); + + const Ast& ast() const; + Ast& mutable_ast(); + + absl::string_view AccuVarName(); + + ABSL_MUST_USE_RESULT Expr Copy(const Expr& expr); + + ABSL_MUST_USE_RESULT ListExprElement Copy(const ListExprElement& element); + + ABSL_MUST_USE_RESULT StructExprField Copy(const StructExprField& field); + + ABSL_MUST_USE_RESULT MapExprEntry Copy(const MapExprEntry& entry); + + ABSL_MUST_USE_RESULT Expr NewUnspecified(); + + ABSL_MUST_USE_RESULT Expr NewNullConst(); + + ABSL_MUST_USE_RESULT Expr NewBoolConst(bool value); + + ABSL_MUST_USE_RESULT Expr NewIntConst(int64_t value); + + ABSL_MUST_USE_RESULT Expr NewUintConst(uint64_t value); + + ABSL_MUST_USE_RESULT Expr NewDoubleConst(double value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(std::string value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(absl::string_view value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(const char* absl_nullable value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(std::string value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(absl::string_view value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(const char* absl_nullable value); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewIdent(Name name); + + ABSL_MUST_USE_RESULT Expr NewAccuIdent(); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewSelect(Operand operand, Field field); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewPresenceTest(Operand operand, Field field); + + template < + typename Function, typename... Args, + typename = std::enable_if_t::value>, + typename = std::enable_if_t...>>> + ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args&&... args); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args args); + + template < + typename Function, typename Target, typename... Args, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t...>>> + ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, + Args&&... args); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, + Args args); + + using ExprFactory::NewListElement; + + template ...>>> + ABSL_MUST_USE_RESULT Expr NewList(Elements&&... elements); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewList(Elements elements); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT StructExprField NewStructField(Name name, Value value, + bool optional = false); + + template ::value>, + typename = std::enable_if_t< + std::conjunction_v...>>> + ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields&&... fields); + + template < + typename Name, typename Fields, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields fields); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT MapExprEntry NewMapEntry(Key key, Value value, + bool optional = false); + + template ...>>> + ABSL_MUST_USE_RESULT Expr NewMap(Entries&&... entries); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewMap(Entries entries); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewComprehension(IterVar iter_var, + IterRange iter_range, + AccuVar accu_var, + AccuInit accu_init, + LoopCondition loop_condition, + LoopStep loop_step, Result result); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewComprehension( + IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result); + + ABSL_MUST_USE_RESULT Expr ReportError(absl::string_view message); + + // Reports an error at the id in the optimized AST. + ABSL_MUST_USE_RESULT Expr ReportErrorAt(const Expr& expr, + absl::string_view message); + // Reports an error at the mapped id of the copy of expr in the optimized AST. + ABSL_MUST_USE_RESULT Expr ReportErrorAtCopy(const Expr& expr, + absl::string_view message); + + protected: + ABSL_MUST_USE_RESULT ExprId NextId(); + + ABSL_MUST_USE_RESULT ExprId CopyId(ExprId id); + + ABSL_MUST_USE_RESULT ExprId CopyId(const Expr& expr); + + using ExprFactory::AccuVarName; + using ExprFactory::NewAccuIdent; + using ExprFactory::NewBoolConst; + using ExprFactory::NewBytesConst; + using ExprFactory::NewCall; + using ExprFactory::NewComprehension; + using ExprFactory::NewConst; + using ExprFactory::NewDoubleConst; + using ExprFactory::NewIdent; + using ExprFactory::NewIntConst; + using ExprFactory::NewList; + using ExprFactory::NewMap; + using ExprFactory::NewMapEntry; + using ExprFactory::NewMemberCall; + using ExprFactory::NewNullConst; + using ExprFactory::NewPresenceTest; + using ExprFactory::NewSelect; + using ExprFactory::NewStringConst; + using ExprFactory::NewStruct; + using ExprFactory::NewStructField; + using ExprFactory::NewUintConst; + using ExprFactory::NewUnspecified; + + private: + Ast ast_; + absl::flat_hash_map renumbers_; + std::vector issues_; + + ExprId next_id_ = 1; +}; + +// Implementation details. + +template +Expr OptimizerExprFactory::NewIdent(Name name) { + return NewIdent(NextId(), std::move(name)); +} + +template +Expr OptimizerExprFactory::NewSelect(Operand operand, Field field) { + return NewSelect(NextId(), std::move(operand), std::move(field)); +} + +template +Expr OptimizerExprFactory::NewPresenceTest(Operand operand, Field field) { + return NewPresenceTest(NextId(), std::move(operand), std::move(field)); +} + +template +Expr OptimizerExprFactory::NewCall(Function function, Args&&... args) { + std::vector array; + array.reserve(sizeof...(Args)); + (array.push_back(std::forward(args)), ...); + return NewCall(NextId(), std::move(function), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewCall(Function function, Args args) { + return NewCall(NextId(), std::move(function), std::move(args)); +} + +template +Expr OptimizerExprFactory::NewMemberCall(Function function, Target target, + Args&&... args) { + std::vector array; + array.reserve(sizeof...(Args)); + (array.push_back(std::forward(args)), ...); + return NewMemberCall(NextId(), std::move(function), std::move(target), + std::move(array)); +} + +template +Expr OptimizerExprFactory::NewMemberCall(Function function, Target target, + Args args) { + return NewMemberCall(NextId(), std::move(function), std::move(target), + std::move(args)); +} + +template +Expr OptimizerExprFactory::NewList(Elements&&... elements) { + std::vector array; + array.reserve(sizeof...(Elements)); + (array.push_back(std::forward(elements)), ...); + return NewList(NextId(), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewList(Elements elements) { + return NewList(NextId(), std::move(elements)); +} + +template +StructExprField OptimizerExprFactory::NewStructField(Name name, Value value, + bool optional) { + return NewStructField(NextId(), std::move(name), std::move(value), optional); +} + +template +Expr OptimizerExprFactory::NewStruct(Name name, Fields&&... fields) { + std::vector array; + array.reserve(sizeof...(Fields)); + (array.push_back(std::forward(fields)), ...); + return NewStruct(NextId(), std::move(name), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewStruct(Name name, Fields fields) { + return NewStruct(NextId(), std::move(name), std::move(fields)); +} + +template +MapExprEntry OptimizerExprFactory::NewMapEntry(Key key, Value value, + bool optional) { + return NewMapEntry(NextId(), std::move(key), std::move(value), optional); +} + +template +Expr OptimizerExprFactory::NewMap(Entries&&... entries) { + std::vector array; + array.reserve(sizeof...(Entries)); + (array.push_back(std::forward(entries)), ...); + return NewMap(NextId(), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewMap(Entries entries) { + return NewMap(NextId(), std::move(entries)); +} + +template +Expr OptimizerExprFactory::NewComprehension(IterVar iter_var, + IterRange iter_range, + AccuVar accu_var, + AccuInit accu_init, + LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), std::move(iter_range), + std::move(accu_var), std::move(accu_init), + std::move(loop_condition), std::move(loop_step), + std::move(result)); +} + +template +Expr OptimizerExprFactory::NewComprehension( + IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), std::move(iter_var2), + std::move(iter_range), std::move(accu_var), + std::move(accu_init), std::move(loop_condition), + std::move(loop_step), std::move(result)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ diff --git a/policy/internal/optimizer_expr_factory_test.cc b/policy/internal/optimizer_expr_factory_test.cc new file mode 100644 index 000000000..1b14b5628 --- /dev/null +++ b/policy/internal/optimizer_expr_factory_test.cc @@ -0,0 +1,570 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/internal/optimizer_expr_factory.h" + +#include +#include +#include + +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/ast_rewrite.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "testutil/expr_printer.h" +#include "tools/cel_unparser.h" + +namespace cel { + +using ::testing::SizeIs; + +// Expose protected members of OptimizerExprFactory for use in tests +// +// These allow setting explicit IDs which is not safe for the optimizing +// factory. +class TestOptimizerExprFactory final : public OptimizerExprFactory { + public: + using OptimizerExprFactory::OptimizerExprFactory; + + using OptimizerExprFactory::NewBoolConst; + using OptimizerExprFactory::NewCall; + using OptimizerExprFactory::NewComprehension; + using OptimizerExprFactory::NewIdent; + using OptimizerExprFactory::NewList; + using OptimizerExprFactory::NewListElement; + using OptimizerExprFactory::NewMap; + using OptimizerExprFactory::NewMapEntry; + using OptimizerExprFactory::NewMemberCall; + using OptimizerExprFactory::NewSelect; + using OptimizerExprFactory::NewStruct; + using OptimizerExprFactory::NewStructField; + using OptimizerExprFactory::NewUnspecified; + using OptimizerExprFactory::NextId; +}; + +namespace { + +class ReplaceExprRewriter final : public AstRewriterBase { + public: + ReplaceExprRewriter(ExprId old_id, const Expr& replacement) + : old_id_(old_id), replacement_(replacement) {} + + bool PreVisitRewrite(Expr& expr) override { + if (expr.id() == old_id_) { + expr = replacement_; + return true; + } + return false; + } + + private: + ExprId old_id_; + const Expr& replacement_; +}; + +void ReplaceExprInTree(Expr& expr, ExprId old_id, const Expr& replacement) { + ReplaceExprRewriter rewriter(old_id, replacement); + AstRewrite(expr, rewriter); +} + +absl::StatusOr> CreateTestCompiler() { + CompilerOptions opts; + opts.parser_options.add_macro_calls = true; + CEL_ASSIGN_OR_RETURN( + auto builder, cel::NewCompilerBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("to_replace", cel::DynType()))); + return builder->Build(); +} + +TEST(OptimizerExprFactory, CopyUnspecified) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewUnspecified()), factory.NewUnspecified(2)); +} + +TEST(OptimizerExprFactory, CopyIdent) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewIdent("foo")), factory.NewIdent(2, "foo")); +} + +TEST(OptimizerExprFactory, CopyConst) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewBoolConst(true)), + factory.NewBoolConst(2, true)); +} + +TEST(OptimizerExprFactory, CopySelect) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewSelect(factory.NewIdent("foo"), "bar")), + factory.NewSelect(3, factory.NewIdent(4, "foo"), "bar")); +} + +TEST(OptimizerExprFactory, CopyCall) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_args; + copied_args.reserve(1); + copied_args.push_back(factory.NewIdent(6, "baz")); + EXPECT_EQ(factory.Copy(factory.NewMemberCall("bar", factory.NewIdent("foo"), + factory.NewIdent("baz"))), + factory.NewMemberCall(4, "bar", factory.NewIdent(5, "foo"), + absl::MakeSpan(copied_args))); +} + +TEST(OptimizerExprFactory, CopyList) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_elements; + copied_elements.reserve(1); + copied_elements.push_back(factory.NewListElement(factory.NewIdent(4, "foo"))); + EXPECT_EQ(factory.Copy(factory.NewList( + factory.NewListElement(factory.NewIdent("foo")))), + factory.NewList(3, absl::MakeSpan(copied_elements))); +} + +TEST(OptimizerExprFactory, CopyStruct) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_fields; + copied_fields.reserve(1); + copied_fields.push_back( + factory.NewStructField(5, "bar", factory.NewIdent(6, "baz"))); + EXPECT_EQ(factory.Copy(factory.NewStruct( + "foo", factory.NewStructField("bar", factory.NewIdent("baz")))), + factory.NewStruct(4, "foo", absl::MakeSpan(copied_fields))); +} + +TEST(OptimizerExprFactory, CopyMap) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_entries; + copied_entries.reserve(1); + copied_entries.push_back(factory.NewMapEntry(6, factory.NewIdent(7, "bar"), + factory.NewIdent(8, "baz"))); + EXPECT_EQ(factory.Copy(factory.NewMap(factory.NewMapEntry( + factory.NewIdent("bar"), factory.NewIdent("baz")))), + factory.NewMap(5, absl::MakeSpan(copied_entries))); +} + +TEST(OptimizerExprFactory, CopyComprehension) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ( + factory.Copy(factory.NewComprehension( + "foo", factory.NewList(), "bar", factory.NewBoolConst(true), + factory.NewIdent("baz"), factory.NewIdent("foo"), + factory.NewIdent("bar"))), + factory.NewComprehension( + 7, "foo", factory.NewList(8, std::vector()), "bar", + factory.NewBoolConst(9, true), factory.NewIdent(10, "baz"), + factory.NewIdent(11, "foo"), factory.NewIdent(12, "bar"))); +} + +TEST(OptimizerExprFactory, RemapSourceInfo) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + SourceInfo info; + info.mutable_positions()[1] = 42; // old ID 1 has position 42 + + SourceInfo remapped = factory.RemapSourceInfo(info, 10); + + // remapped should have ID 2 mapped to position 42 + 10 = 52 + auto it = remapped.positions().find(2); + ASSERT_NE(it, remapped.positions().end()); + EXPECT_EQ(it->second, 52); +} + +TEST(OptimizerExprFactory, RemapSourceInfoWithMacroCalls) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + SourceInfo info; + // old ID 1 has macro call with ID 3 + info.mutable_macro_calls()[1] = factory.NewIdent("bar"); + + SourceInfo remapped = factory.RemapSourceInfo(info, 10); + + // remapped should have ID 2 mapped to the copied macro call + // since "bar" has ID 3, Copy(bar) should map ID 3 to ID 4 + + auto it = remapped.macro_calls().find(2); + ASSERT_NE(it, remapped.macro_calls().end()); + + // The macro call should be an Ident with new ID 4 + EXPECT_EQ(it->second.id(), 4); + EXPECT_TRUE(it->second.has_ident_expr()); + EXPECT_EQ(it->second.ident_expr().name(), "bar"); +} + +TEST(OptimizerExprFactory, ReportError) { + TestOptimizerExprFactory factory{Ast()}; + Expr err_expr = factory.ReportError("something went wrong"); + + // err_expr should be unspecified with ID 1 + EXPECT_EQ(err_expr.id(), 1); + EXPECT_EQ(err_expr.kind_case(), ExprKindCase::kUnspecifiedExpr); + + // issues_ should have 1 entry with ID 1 and correct message + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "something went wrong"); +} + +TEST(OptimizerExprFactory, ReportErrorAt) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + Expr err_expr = factory.ReportErrorAtCopy(orig, "error on foo"); + + // err_expr should be unspecified with ID 3 (NextId) + EXPECT_EQ(err_expr.id(), 3); + EXPECT_EQ(err_expr.kind_case(), ExprKindCase::kUnspecifiedExpr); + + // issues_ should have 1 entry with mapped ID 2 and correct message + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 2); + EXPECT_EQ(factory.issues()[0].message, "error on foo"); +} + +TEST(OptimizerExprFactory, MergeSourceInfo) { + // Create a base AST with some source info + SourceInfo base_info; + base_info.set_syntax_version("cel1"); + base_info.set_location("test.cel"); + base_info.mutable_positions()[1] = 10; + + Ast base_ast(Expr(), std::move(base_info)); + + TestOptimizerExprFactory factory{std::move(base_ast)}; + + // Create a new source info to merge + SourceInfo new_info; + new_info.mutable_positions()[2] = 20; + + factory.MergeSourceInfo(new_info); + + // The merged source info should have both positions + const auto& merged_info = factory.ast().source_info(); + EXPECT_EQ(merged_info.syntax_version(), "cel1"); + EXPECT_EQ(merged_info.location(), "test.cel"); + + auto it1 = merged_info.positions().find(1); + ASSERT_NE(it1, merged_info.positions().end()); + EXPECT_EQ(it1->second, 10); + + auto it2 = merged_info.positions().find(2); + ASSERT_NE(it2, merged_info.positions().end()); + EXPECT_EQ(it2->second, 20); +} + +TEST(OptimizerExprFactory, MergeSourceInfoConflict) { + SourceInfo base_info; + base_info.mutable_positions()[1] = 10; + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory factory{std::move(base_ast)}; + + SourceInfo new_info; + new_info.mutable_positions()[1] = 20; // conflicting ID 1 + + factory.MergeSourceInfo(new_info); + + // Should report an error for the conflict + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "conflicting ID in positions merge"); +} + +TEST(OptimizerExprFactory, RecordReplacement) { + SourceInfo base_info; + base_info.mutable_positions()[1] = 10; + base_info.mutable_positions()[2] = 20; + + TestOptimizerExprFactory factory{Ast()}; + + // macro_calls[1] maps ID 1 to macro call "bar(foo)" (where "foo" has ID 1) + base_info.mutable_macro_calls()[1] = + factory.NewCall("bar", factory.NewIdent(1, "foo")); + + // macro_calls[2] maps ID 2 to macro call "baz(foo)" (where "foo" has ID 1) + base_info.mutable_macro_calls()[2] = + factory.NewCall("baz", factory.NewIdent(1, "foo")); + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory optimizer{std::move(base_ast)}; + + // Record the replacement of ID 1 by a new Ident "replacement" with ID 3 + optimizer.RecordReplacement(1, factory.NewIdent(3, "replacement")); + + const auto& result_info = optimizer.ast().source_info(); + + // 1. ID 1 should be erased from positions + EXPECT_EQ(result_info.positions().find(1), result_info.positions().end()); + EXPECT_NE(result_info.positions().find(2), result_info.positions().end()); + + // 2. ID 1 should be erased from macro_calls keys + EXPECT_EQ(result_info.macro_calls().find(1), result_info.macro_calls().end()); + + // 3. macro_calls[2] should still exist, but its argument referencing ID 1 + // should be replaced with the Ident "replacement" with ID 3 inline + auto it = result_info.macro_calls().find(2); + ASSERT_NE(it, result_info.macro_calls().end()); + + const Expr& macro_expr = it->second; + ASSERT_TRUE(macro_expr.has_call_expr()); + ASSERT_EQ(macro_expr.call_expr().args().size(), 1); + + const Expr& arg = macro_expr.call_expr().args()[0]; + EXPECT_EQ(arg.id(), 3); + EXPECT_TRUE(arg.has_ident_expr()); + EXPECT_EQ(arg.ident_expr().name(), "replacement"); +} + +class IdAdorner : public cel::test::ExpressionAdorner { + public: + std::string Adorn(const cel::Expr& e) const override { + return absl::StrCat("#", e.id()); + } + + std::string AdornStructField(const cel::StructExprField& e) const override { + return absl::StrCat("#", e.id()); + } + + std::string AdornMapEntry(const cel::MapExprEntry& e) const override { + return absl::StrCat("#", e.id()); + } +}; + +TEST(OptimizerExprFactory, UnparseCopiedMacroCall) { + // Arrange: create an template expression and one to inline. + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto basis_result, + compiler->Compile("[1].map(x, x + to_replace)")); + ASSERT_TRUE(basis_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto basis_ast, basis_result.ReleaseAst()); + + ASSERT_OK_AND_ASSIGN(auto copy_result, + compiler->Compile("[1].filter(x, x > 2).size()")); + ASSERT_TRUE(copy_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto copy_ast, copy_result.ReleaseAst()); + + // Locate the "to_replace" IdentExpr node in reference_map + ExprId to_replace_id = 0; + for (const auto& [id, ref] : basis_ast->reference_map()) { + if (ref.name() == "to_replace") { + to_replace_id = id; + break; + } + } + ASSERT_NE(to_replace_id, 0); + + // Act: implement the optimization. + TestOptimizerExprFactory factory{std::move(*basis_ast)}; + Expr copied_expr = factory.Copy(copy_ast->root_expr()); + SourceInfo remapped_info = factory.RemapSourceInfo(copy_ast->source_info()); + factory.MergeSourceInfo(remapped_info); + + ReplaceExprInTree(factory.mutable_ast().mutable_root_expr(), to_replace_id, + copied_expr); + factory.RecordReplacement(to_replace_id, copied_expr); + + // Test AST structure. + EXPECT_EQ( + cel::test::ExprPrinter(IdAdorner()).Print(factory.ast().root_expr()), + R"(__comprehension__( + // Variable + x, + // Target + [ + 1#2 + ]#1, + // Accumulator + @result, + // Init + []#8, + // LoopCondition + true#9, + // LoopStep + _+_( + @result#10, + [ + _+_( + x#5, + __comprehension__( + // Variable + x, + // Target + [ + 1#18 + ]#17, + // Accumulator + @result, + // Init + []#19, + // LoopCondition + true#20, + // LoopStep + _?_:_( + _>_( + x#23, + 2#24 + )#22, + _+_( + @result#26, + [ + x#28 + ]#27 + )#25, + @result#29 + )#21, + // Result + @result#30)#16.size()#15 + )#6 + ]#11 + )#12, + // Result + @result#13)#14)"); + + // Check that the structure is compatible with unparser. + cel::expr::ParsedExpr optimized_parsed; + auto status = AstToParsedExpr(factory.ast(), &optimized_parsed); + ASSERT_THAT(status, absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(std::string unparsed, + google::api::expr::Unparse(optimized_parsed)); + + EXPECT_EQ(unparsed, "[1].map(x, x + [1].filter(x, x > 2).size())"); + + const CallExpr& call_expr = factory.mutable_ast() + .mutable_source_info() + .mutable_macro_calls()[14] + .mutable_call_expr(); + ASSERT_THAT(call_expr.args(), SizeIs(2)); + ASSERT_THAT(call_expr.args()[1].call_expr().args(), SizeIs(2)); + EXPECT_EQ(call_expr.args()[1].call_expr().args()[1].id(), 15); + + EXPECT_EQ(call_expr.args()[1].call_expr().args()[1].call_expr().target().id(), + 16); + EXPECT_EQ(call_expr.args()[1] + .call_expr() + .args()[1] + .call_expr() + .target() + .kind_case(), + ExprKindCase::kUnspecifiedExpr); +} + +TEST(OptimizerExprFactory, CopyMultipleAstsWithConsumeRenumbers) { + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto ast1_result, compiler->Compile("[1]")); + ASSERT_TRUE(ast1_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast1, ast1_result.ReleaseAst()); + + ASSERT_OK_AND_ASSIGN(auto ast2_result, compiler->Compile("2")); + ASSERT_TRUE(ast2_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast2, ast2_result.ReleaseAst()); + + TestOptimizerExprFactory factory{Ast()}; + + Expr copied1 = factory.Copy(ast1->root_expr()); + auto renumbers1 = factory.ConsumeRenumbers(); + + Expr copied2 = factory.Copy(ast2->root_expr()); + auto renumbers2 = factory.ConsumeRenumbers(); + + EXPECT_EQ(renumbers1.size(), 2); + EXPECT_EQ(renumbers2.size(), 1); + + EXPECT_NE(copied1.id(), copied2.id()); + EXPECT_GT(copied2.id(), copied1.id()); +} + +TEST(OptimizerExprFactory, MaxIdVisitorExprKinds) { + ASSERT_OK_AND_ASSIGN(auto compiler, CreateTestCompiler()); + + // Expression that covers all the kinds. + ASSERT_OK_AND_ASSIGN(auto source, NewSource(R"cel( + Struct{field : 1} || + {'key' : 'value'} || [1].exists(x, x) || foo(bar))cel")); + ASSERT_OK_AND_ASSIGN(auto ast, compiler->GetParser().Parse(*source)); + + TestOptimizerExprFactory factory{std::move(*ast)}; + + EXPECT_EQ(factory.NextId(), 26); +} + +TEST(OptimizerExprFactory, CopyListElement) { + TestOptimizerExprFactory factory{Ast()}; + ListExprElement orig = factory.NewListElement(factory.NewIdent("foo")); + ListExprElement copied = factory.Copy(orig); + EXPECT_EQ(copied.expr(), factory.NewIdent(2, "foo")); +} + +TEST(OptimizerExprFactory, CopyStructField) { + TestOptimizerExprFactory factory{Ast()}; + StructExprField orig = factory.NewStructField("bar", factory.NewIdent("baz")); + StructExprField copied = factory.Copy(orig); + EXPECT_EQ(copied.id(), 3); + EXPECT_EQ(copied.name(), "bar"); + EXPECT_EQ(copied.value(), factory.NewIdent(4, "baz")); +} + +TEST(OptimizerExprFactory, CopyMapEntry) { + TestOptimizerExprFactory factory{Ast()}; + MapExprEntry orig = + factory.NewMapEntry(factory.NewIdent("bar"), factory.NewIdent("baz")); + MapExprEntry copied = factory.Copy(orig); + EXPECT_EQ(copied.id(), 4); + EXPECT_EQ(copied.key(), factory.NewIdent(5, "bar")); + EXPECT_EQ(copied.value(), factory.NewIdent(6, "baz")); +} + +TEST(OptimizerExprFactory, MergeSourceInfoMacroConflict) { + SourceInfo base_info; + base_info.mutable_macro_calls()[1] = Expr(); + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory factory{std::move(base_ast)}; + + SourceInfo new_info; + new_info.mutable_macro_calls()[1] = Expr(); + + factory.MergeSourceInfo(new_info); + + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "conflicting ID in macro calls merge"); +} + +} // namespace +} // namespace cel diff --git a/policy/test_custom_yaml_policy_parser.cc b/policy/test_custom_yaml_policy_parser.cc new file mode 100644 index 000000000..faced6952 --- /dev/null +++ b/policy/test_custom_yaml_policy_parser.cc @@ -0,0 +1,188 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_context.h" +#include "policy/cel_policy_parser.h" +#include "policy/yaml_policy_parser.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel::internal { + +// TestCustomYamlPolicyParser is used to support unit tests for custom tags +// and custom policy structures. It demonstrates the versatility of the +// cel::YamlPolicyParser framework API by implementing custom tag and block +// parsing without needing to modify the core parser. +class TestCustomYamlPolicyParser : public cel::YamlPolicyParser { + absl::StatusOr ParsePolicyTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node) const override { + if (tag_name.value() == "name" || tag_name.value() == "description" || + tag_name.value() == "imports") { + return cel::YamlPolicyParser::ParsePolicyTag(ctx, tag_name, node); + } + if (tag_name.value() == "purpose") { + std::optional purpose = + GetValueString(ctx, node, "Policy purpose is not a string"); + if (purpose.has_value()) { + ctx.policy().mutable_metadata()["purpose"] = *purpose; + } + return true; + } + if (tag_name.value() == "version") { + std::optional version = + GetValueString(ctx, node, "Policy version is not a string"); + if (!version.has_value()) { + return true; + } + int version_int; + if (!absl::SimpleAtoi(version->value(), &version_int)) { + ctx.ReportError(version->id(), + absl::StrCat("Policy version is not an integer: ", + version->value())); + return true; + } + ctx.policy().mutable_metadata()["version"] = version_int; + return true; + } + + if (tag_name.value() == "conditions") { + if (!node.IsSequence()) { + ctx.ReportError(tag_name.id(), "Policy 'conditions' is not a sequence"); + return true; + } + for (const YAML::Node& condition : node) { + // Track the number of existing matches before parsing. When ParseMatch + // evaluates an 'else' block, it recursively triggers parsing and adds + // internal inner matches directly to the rule's match vector. + // Inserting the outer match at begin() + size_before ensures that the + // primary outer 'if' condition is always evaluated before its nested + // 'else' fallbacks. + // + // Example: + // if: x > 0 + // then: "positive" + // else: "negative" + // + // The inner "negative" match is parsed and appended to rule.matches() + // by the inner recursive call, before the outer "x > 0" match finishes. + // Inserting at size_before places the "x > 0" match ahead of the inner + // one. + size_t size_before = ctx.policy().rule().matches().size(); + CEL_ASSIGN_OR_RETURN(Match match, + cel::YamlPolicyParser::ParseMatch( + ctx, condition, ctx.policy().mutable_rule())); + ctx.policy().mutable_rule().mutable_matches().insert( + ctx.policy().mutable_rule().mutable_matches().begin() + size_before, + std::move(match)); + } + + return true; + } + return false; + } + + absl::Status ParseThenBlock(CelPolicyParseContext& ctx, + const YAML::Node& value_node, + Match& match) const { + if (value_node.IsScalar()) { + std::optional val = GetValueString( + ctx, value_node, "Policy condition 'then' is not a string"); + if (val.has_value()) { + OutputBlock output; + output.set_output(*val); + match.set_result(output); + } + } else if (value_node.IsMap()) { + auto nested_rule = std::make_unique(); + CEL_ASSIGN_OR_RETURN( + Match nested_match, + cel::YamlPolicyParser::ParseMatch(ctx, value_node, *nested_rule)); + nested_rule->mutable_matches().insert( + nested_rule->mutable_matches().begin(), std::move(nested_match)); + match.set_result(std::move(nested_rule)); + } else { + ctx.ReportError(CollectMetadata(ctx, value_node), + "Bad syntax in 'if/then' block"); + } + return absl::OkStatus(); + } + + absl::Status ParseElseBlock(CelPolicyParseContext& ctx, + const YAML::Node& value_node, Rule& rule) const { + if (value_node.IsScalar()) { + std::optional val = GetValueString( + ctx, value_node, "Policy condition 'else' is not a string"); + if (val.has_value()) { + Match else_match; + else_match.set_id(CollectMetadata(ctx, value_node)); + OutputBlock output; + output.set_output(*val); + else_match.set_result(output); + rule.mutable_matches().push_back(std::move(else_match)); + } + } else if (value_node.IsMap()) { + size_t size_before = rule.matches().size(); + CEL_ASSIGN_OR_RETURN(Match match, cel::YamlPolicyParser::ParseMatch( + ctx, value_node, rule)); + rule.mutable_matches().insert( + rule.mutable_matches().begin() + size_before, std::move(match)); + } else { + ctx.ReportError(CollectMetadata(ctx, value_node), + "Bad syntax in 'if/then' block"); + } + return absl::OkStatus(); + } + + absl::StatusOr ParseMatchTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, Match& match, + Rule& rule) const override { + if (tag_name.value() == "if") { + std::optional condition = + GetValueString(ctx, node, "Policy 'if' condition is not a string"); + if (condition.has_value()) { + match.set_condition(*condition); + } + return true; + } + if (tag_name.value() == "then") { + CEL_RETURN_IF_ERROR(ParseThenBlock(ctx, node, match)); + return true; + } + if (tag_name.value() == "else") { + CEL_RETURN_IF_ERROR(ParseElseBlock(ctx, node, rule)); + return true; + } + return false; + } +}; + +const CelPolicyParser& GetTestCustomYamlPolicyParser() { + static const auto* const parser = new TestCustomYamlPolicyParser(); + return *parser; +} + +} // namespace cel::internal diff --git a/policy/test_util.cc b/policy/test_util.cc new file mode 100644 index 000000000..9fe1e43d1 --- /dev/null +++ b/policy/test_util.cc @@ -0,0 +1,221 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +#include "policy/test_util.h" + +#include +#include +#include +#include + +#include "cel/expr/eval.pb.h" +#include "cel/expr/value.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "internal/status_macros.h" +#include "yaml-cpp/yaml.h" + +namespace cel::test { + +namespace { + +absl::Status YamlToExprValue(const YAML::Node& node, + cel::expr::Value* proto) { + if (node.IsNull()) { + proto->set_null_value(google::protobuf::NULL_VALUE); + return absl::OkStatus(); + } + if (node.IsScalar()) { + // Try bool + try { + proto->set_bool_value(node.as()); + return absl::OkStatus(); + } catch (...) { + } + // Try int64 + try { + int64_t val; + if (YAML::convert::decode(node, val)) { + proto->set_int64_value(val); + return absl::OkStatus(); + } + } catch (...) { + } + // Try double + try { + double val; + if (YAML::convert::decode(node, val)) { + proto->set_double_value(val); + return absl::OkStatus(); + } + } catch (...) { + } + // Fallback to string + proto->set_string_value(node.as()); + return absl::OkStatus(); + } + if (node.IsSequence()) { + auto* list = proto->mutable_list_value(); + for (const auto& elem : node) { + CEL_RETURN_IF_ERROR(YamlToExprValue(elem, list->add_values())); + } + return absl::OkStatus(); + } + if (node.IsMap()) { + auto* map_val = proto->mutable_map_value(); + for (auto it = node.begin(); it != node.end(); ++it) { + auto* entry = map_val->add_entries(); + CEL_RETURN_IF_ERROR(YamlToExprValue(it->first, entry->mutable_key())); + CEL_RETURN_IF_ERROR(YamlToExprValue(it->second, entry->mutable_value())); + } + return absl::OkStatus(); + } + return absl::InvalidArgumentError("Unknown YAML node type"); +} + +absl::Status ParseInputValue( + const YAML::Node& node, + cel::expr::conformance::test::InputValue* input_val) { + if (node.IsMap() && node["expr"].IsDefined()) { + input_val->set_expr(node["expr"].as()); + return absl::OkStatus(); + } + if (node.IsMap() && node["value"].IsDefined()) { + return YamlToExprValue(node["value"], input_val->mutable_value()); + } + return YamlToExprValue(node, input_val->mutable_value()); +} + +absl::Status ParseTestOutput(const YAML::Node& node, + cel::expr::conformance::test::TestOutput* output) { + if (!node.IsDefined()) { + return absl::InvalidArgumentError("Missing output node"); + } + if (node.IsMap()) { + if (node["expr"].IsDefined()) { + output->set_result_expr(node["expr"].as()); + return absl::OkStatus(); + } + if (node["value"].IsDefined()) { + return YamlToExprValue(node["value"], output->mutable_result_value()); + } + if (node["error"].IsDefined()) { + auto* eval_error = output->mutable_eval_error(); + eval_error->add_errors()->set_message(node["error"].as()); + return absl::OkStatus(); + } + if (node["error_set"].IsDefined()) { + auto* eval_error = output->mutable_eval_error(); + for (const auto& err : node["error_set"]) { + eval_error->add_errors()->set_message(err.as()); + } + return absl::OkStatus(); + } + if (node["unknown"].IsDefined()) { + auto* unknown = output->mutable_unknown(); + for (const auto& expr_id_node : node["unknown"]) { + unknown->add_exprs(expr_id_node.as()); + } + return absl::OkStatus(); + } + } + return YamlToExprValue(node, output->mutable_result_value()); +} + +absl::StatusOr +ParsePolicyTestSuiteYamlImpl(absl::string_view yaml_content) { + YAML::Node tests_node; + try { + tests_node = YAML::Load(std::string(yaml_content)); + } catch (const std::exception& e) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse YAML: ", e.what())); + } + + cel::expr::conformance::test::TestSuite test_suite; + if (tests_node["description"].IsDefined()) { + test_suite.set_description(tests_node["description"].as()); + } + + YAML::Node sections = tests_node["sections"]; + if (!sections.IsDefined()) { + sections = tests_node["section"]; // support singular format + } + if (!sections.IsDefined()) { + return absl::InvalidArgumentError( + "Missing 'sections' or 'section' in tests YAML"); + } + + for (const auto& section_node : sections) { + auto* section = test_suite.add_sections(); + if (section_node["name"].IsDefined()) { + section->set_name(section_node["name"].as()); + } + if (section_node["description"].IsDefined()) { + section->set_description(section_node["description"].as()); + } + + YAML::Node tests = section_node["tests"]; + if (!tests.IsDefined()) { + tests = section_node["test"]; // support singular format + } + if (!tests.IsDefined()) { + continue; + } + + for (const auto& test_node : tests) { + auto* test_case = section->add_tests(); + if (test_node["name"].IsDefined()) { + test_case->set_name(test_node["name"].as()); + } + if (test_node["description"].IsDefined()) { + test_case->set_description(test_node["description"].as()); + } + if (test_node["context_expr"].IsDefined()) { + test_case->mutable_input_context()->set_context_expr( + test_node["context_expr"].as()); + } + + YAML::Node input_node = test_node["input"]; + if (input_node.IsDefined() && input_node.IsMap()) { + auto* input_map = test_case->mutable_input(); + for (auto it = input_node.begin(); it != input_node.end(); ++it) { + std::string var_name = it->first.as(); + cel::expr::conformance::test::InputValue input_val; + CEL_RETURN_IF_ERROR(ParseInputValue(it->second, &input_val)); + (*input_map)[var_name] = std::move(input_val); + } + } + + YAML::Node output_node = test_node["output"]; + if (output_node.IsDefined()) { + CEL_RETURN_IF_ERROR( + ParseTestOutput(output_node, test_case->mutable_output())); + } + } + } + + return test_suite; +} + +} // namespace + +absl::StatusOr +ParsePolicyTestSuiteYaml(absl::string_view yaml_content) { + try { + return ParsePolicyTestSuiteYamlImpl(yaml_content); + } catch (...) { + return absl::InvalidArgumentError("Failed to parse YAML"); + } +} + +} // namespace cel::test diff --git a/policy/test_util.h b/policy/test_util.h new file mode 100644 index 000000000..5fe306050 --- /dev/null +++ b/policy/test_util.h @@ -0,0 +1,33 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "cel/expr/conformance/test/suite.pb.h" + +namespace cel::test { + +// Parses a YAML content representing a policy test suite (tests.yaml) +// and adapts it to the cel.expr.conformance.test.TestSuite protobuf message. +// +// TODO(uncreated-issue/92): Move to the testrunner library. +absl::StatusOr +ParsePolicyTestSuiteYaml(absl::string_view yaml_content); + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ diff --git a/policy/testdata/BUILD b/policy/testdata/BUILD new file mode 100644 index 000000000..10a26fa0b --- /dev/null +++ b/policy/testdata/BUILD @@ -0,0 +1,19 @@ +package( + default_testonly = True, + default_visibility = ["//visibility:public"], +) + +filegroup( + name = "policy_testdata", + srcs = glob([ + "*.yaml", + "*.baseline", + ]), +) + +exports_files( + srcs = glob([ + "*.yaml", + "*.baseline", + ]), +) diff --git a/policy/testdata/cel_policy.yaml b/policy/testdata/cel_policy.yaml new file mode 100644 index 000000000..010ad8855 --- /dev/null +++ b/policy/testdata/cel_policy.yaml @@ -0,0 +1,42 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Environment: +# spec: TestAllTypes +name: cel_policy +description: A test policy for CEL +display_name: Cel Policy +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +- name: cel.expr.conformance.proto3.TestAllTypes.NestedEnum +rule: + id: test_rule + description: test rule description + variables: + - name: test_var + expression: > + TestAllTypes{single_int64: 10}.single_int64 + match: + - condition: > + spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + output: | + "invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + explanation: | + "invalid spec, spec is greater than 10" + - condition: > + spec.standalone_enum == NestedEnum.BAR + output: | + "invalid spec, reference to BAR is not allowed" + - condition: spec.single_int64 == variables.test_var + output: '"invalid spec: exactly matches test_var"' + explanation: '"the spec cannot have single_int64 set to a known bad value"' \ No newline at end of file diff --git a/policy/testdata/cel_policy_parser.baseline b/policy/testdata/cel_policy_parser.baseline new file mode 100644 index 000000000..7a6678bfe --- /dev/null +++ b/policy/testdata/cel_policy_parser.baseline @@ -0,0 +1,89 @@ +POLICY SOURCE: cel_policy.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2026 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # Environment: + # spec: TestAllTypes + #0> name: #1> cel_policy + #2> description: #3> A test policy for CEL + #4> display_name: #5> Cel Policy + #6> imports: + - #7> name: #8> cel.expr.conformance.proto3.TestAllTypes + - #9> name: #10> cel.expr.conformance.proto3.TestAllTypes.NestedEnum + #11> rule: + #13> #12> id: #14> test_rule + #15> description: #16> test rule description + #17> variables: + - #18> name: #19> test_var + #20> expression: #21> > + TestAllTypes{single_int64: 10}.single_int64 + #22> match: + - #24> #23> condition: #25> > + spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + #26> output: #27> | + "invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + #28> explanation: #29> | + "invalid spec, spec is greater than 10" + - #31> #30> condition: #32> > + spec.standalone_enum == NestedEnum.BAR + #33> output: #34> | + "invalid spec, reference to BAR is not allowed" + - #36> #35> condition: #37> spec.single_int64 == variables.test_var + #38> output: #39> '"invalid spec: exactly matches test_var"' + #40> explanation: #41> '"the spec cannot have single_int64 set to a known bad value"' + =========================================================== + name: #1> "cel_policy" + description: #3> "A test policy for CEL" + display_name: #5> "Cel Policy" + imports: + #7> name: #8> "cel.expr.conformance.proto3.TestAllTypes" + #9> name: #10> "cel.expr.conformance.proto3.TestAllTypes.NestedEnum" + #12> rule: { + rule_id: #14> "test_rule" + description: #16> "test rule description" + variable: { + name: #19> "test_var" + expression: #21> "TestAllTypes{single_int64: 10}.single_int64 + " + } + #23> match: { + condition: #25> "spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + " + result: { + output: #27> ""invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + " + explanation: #29> ""invalid spec, spec is greater than 10" + " + } + } + #30> match: { + condition: #32> "spec.standalone_enum == NestedEnum.BAR + " + result: { + output: #34> ""invalid spec, reference to BAR is not allowed" + " + } + } + #35> match: { + condition: #37> "spec.single_int64 == variables.test_var" + result: { + output: #39> ""invalid spec: exactly matches test_var"" + explanation: #41> ""the spec cannot have single_int64 set to a known bad value"" + } + } + } +} diff --git a/policy/testdata/custom_policy_format.yaml b/policy/testdata/custom_policy_format.yaml new file mode 100644 index 000000000..a67356906 --- /dev/null +++ b/policy/testdata/custom_policy_format.yaml @@ -0,0 +1,29 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: cel_policy_custom_tags +description: A custom policy format +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +purpose: test +version: 42 +conditions: +- if: spec.single_string == "none" + then: "'zero'" + else: + if: spec.single_string == "integer" + then: + if: spec.single_int32 > 0 + then: "'positive integer'" + else: "'negative integer'" + else: "'not an integer'" diff --git a/policy/testdata/custom_policy_format_parser.baseline b/policy/testdata/custom_policy_format_parser.baseline new file mode 100644 index 000000000..d5b1a2235 --- /dev/null +++ b/policy/testdata/custom_policy_format_parser.baseline @@ -0,0 +1,75 @@ +POLICY SOURCE: custom_policy_format.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2026 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + #0> name: #1> cel_policy_custom_tags + #2> description: #3> A custom policy format + #4> imports: + - #5> name: #6> cel.expr.conformance.proto3.TestAllTypes + #7> purpose: #8> test + #9> version: #10> 42 + #11> conditions: + - #13> #12> if: #14> spec.single_string == "none" + #15> then: #16> "'zero'" + #17> else: + #19> #18> if: #20> spec.single_string == "integer" + #21> then: + #23> #22> if: #24> spec.single_int32 > 0 + #25> then: #26> "'positive integer'" + #27> else: #29> #28> "'negative integer'" + #30> else: #32> #31> "'not an integer'" + + =========================================================== + name: #1> "cel_policy_custom_tags" + description: #3> "A custom policy format" + metadata: { + purpose: #8> "test" + version: 42 + } + imports: + #5> name: #6> "cel.expr.conformance.proto3.TestAllTypes" + rule: { + #12> match: { + condition: #14> "spec.single_string == "none"" + result: { + output: #16> "'zero'" + } + } + #18> match: { + condition: #20> "spec.single_string == "integer"" + result: + rule: { + #22> match: { + condition: #24> "spec.single_int32 > 0" + result: { + output: #26> "'positive integer'" + } + } + #29> match: { + result: { + output: #28> "'negative integer'" + } + } + } + } + #32> match: { + result: { + output: #31> "'not an integer'" + } + } + } +} diff --git a/policy/testdata/custom_policy_format_with_errors.yaml b/policy/testdata/custom_policy_format_with_errors.yaml new file mode 100644 index 000000000..594747c60 --- /dev/null +++ b/policy/testdata/custom_policy_format_with_errors.yaml @@ -0,0 +1,33 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: cel_policy_custom_tags +description: A custom policy format +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +purpose: + - testing +version: new +conditions: +- if: + spec.single_string: "none" + then: "'zero'" + else: "'not zero'" +- if: spec.single_string == "number" + then: + if: spec.single_int32 > 0 + then: "'positive integer'" + else: + - ignore +- else: "'negative integer'" + diff --git a/policy/testdata/custom_policy_format_with_errors_parser.baseline b/policy/testdata/custom_policy_format_with_errors_parser.baseline new file mode 100644 index 000000000..978d27bda --- /dev/null +++ b/policy/testdata/custom_policy_format_with_errors_parser.baseline @@ -0,0 +1,16 @@ +POLICY SOURCE: custom_policy_format_with_errors.yaml +-------------------------------------------------------------------- +-------------------------------------------------------------------- +PARSER ISSUES: +ERROR: custom_policy_format_with_errors.yaml:19:3: Policy purpose is not a string + | - testing + | ..^ +ERROR: custom_policy_format_with_errors.yaml:20:10: Policy version is not an integer: new + | version: new + | .........^ +ERROR: custom_policy_format_with_errors.yaml:23:5: Policy 'if' condition is not a string + | spec.single_string: "none" + | ....^ +ERROR: custom_policy_format_with_errors.yaml:31:7: Bad syntax in 'if/then' block + | - ignore + | ......^ diff --git a/policy/testdata/nested_rule.yaml b/policy/testdata/nested_rule.yaml new file mode 100644 index 000000000..2b07faa64 --- /dev/null +++ b/policy/testdata/nested_rule.yaml @@ -0,0 +1,37 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: nested_rule +rule: + variables: + - name: "permitted_regions" + expression: "['us', 'uk', 'es']" + match: + - rule: + id: "banned regions" + description: > + determine whether the resource origin is in the banned + list. If the region is also in the permitted list, the + ban has no effect. + variables: + - name: "banned_regions" + expression: "{'us': false, 'ru': false, 'ir': false}" + match: + - condition: | + resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + output: "{'banned': true}" + - condition: resource.origin in variables.permitted_regions + output: "{'banned': false}" + - output: "{'banned': true}" + explanation: "'resource is in the banned region ' + resource.origin" \ No newline at end of file diff --git a/policy/testdata/nested_rule_parser.baseline b/policy/testdata/nested_rule_parser.baseline new file mode 100644 index 000000000..128f81bda --- /dev/null +++ b/policy/testdata/nested_rule_parser.baseline @@ -0,0 +1,84 @@ +POLICY SOURCE: nested_rule.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2024 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + #0> name: #1> nested_rule + #2> rule: + #4> #3> variables: + - #5> name: #6> "permitted_regions" + #7> expression: #8> "['us', 'uk', 'es']" + #9> match: + - #11> #10> rule: + #13> #12> id: #14> "banned regions" + #15> description: #16> > + determine whether the resource origin is in the banned + list. If the region is also in the permitted list, the + ban has no effect. + #17> variables: + - #18> name: #19> "banned_regions" + #20> expression: #21> "{'us': false, 'ru': false, 'ir': false}" + #22> match: + - #24> #23> condition: #25> | + resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + #26> output: #27> "{'banned': true}" + - #29> #28> condition: #30> resource.origin in variables.permitted_regions + #31> output: #32> "{'banned': false}" + - #34> #33> output: #35> "{'banned': true}" + #36> explanation: #37> "'resource is in the banned region ' + resource.origin" + =========================================================== + name: #1> "nested_rule" + description: "nested_rule.yaml" + #3> rule: { + variable: { + name: #6> "permitted_regions" + expression: #8> "['us', 'uk', 'es']" + } + #10> match: { + result: + #12> rule: { + rule_id: #14> "banned regions" + description: #16> "determine whether the resource origin is in the banned list. If the region is also in the permitted list, the ban has no effect. + " + variable: { + name: #19> "banned_regions" + expression: #21> "{'us': false, 'ru': false, 'ir': false}" + } + #23> match: { + condition: #25> "resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + " + result: { + output: #27> "{'banned': true}" + } + } + } + } + #28> match: { + condition: #30> "resource.origin in variables.permitted_regions" + result: { + output: #32> "{'banned': false}" + } + } + #33> match: { + result: { + output: #35> "{'banned': true}" + explanation: #37> "'resource is in the banned region ' + resource.origin" + } + } + } +} diff --git a/policy/yaml_policy_parser.cc b/policy/yaml_policy_parser.cc new file mode 100644 index 000000000..c838cff33 --- /dev/null +++ b/policy/yaml_policy_parser.cc @@ -0,0 +1,411 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/yaml_policy_parser.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_context.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_parser.h" +#include "yaml-cpp/exceptions.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/node/parse.h" +#include "yaml-cpp/null.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel { + +CelPolicyElementId YamlPolicyParser::CollectMetadata( + CelPolicyParseContext& ctx, const YAML::Node& node) const { + CelPolicyElementId element_id = ctx.next_element_id(); + if (!node.Mark().is_null()) { + ctx.policy_source().NoteSourcePosition(element_id, node.Mark().pos); + } + return element_id; +} + +std::optional YamlPolicyParser::GetValueString( + CelPolicyParseContext& ctx, const YAML::Node& node, + std::string_view error_message) const { + if (!node.IsDefined()) { + // This should never happen since the YAML syntax has already been checked. + return std::nullopt; + } + + CelPolicyElementId id = CollectMetadata(ctx, node); + if (!node.IsScalar()) { + ctx.ReportError(id, error_message); + return std::nullopt; + } + + try { + return ValueString(id, node.as()); + } catch (YAML::Exception& e) { + // This should never happen since we already checked that the node is a + // scalar and all scalars can be converted to strings. + return std::nullopt; + } +} + +absl::Status YamlPolicyParser::ParsePolicy(CelPolicyParseContext& ctx) const { + const Source* source = ctx.policy_source().content(); + if (source == nullptr) { + return absl::OkStatus(); + } + + ctx.policy().set_description(ValueString(-1, source->description())); + std::string text = source->content().ToString(); + YAML::Node node; + try { + node = YAML::Load(text); + } catch (YAML::Exception& e) { + if (!e.mark.is_null()) { + ctx.policy_source().NoteSourcePosition(0, e.mark.pos); + } + ctx.ReportError(0, "Invalid CEL policy YAML syntax"); + return absl::OkStatus(); + } + + if (!node.IsMap()) { + ctx.ReportError(CollectMetadata(ctx, node), "Policy is not a map"); + return absl::OkStatus(); + } + + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, ParsePolicyTag(ctx, *key, value_node)); + if (!handled) { + ctx.ReportError( + key->id(), + absl::StrCat("Unrecognized top-level policy tag: ", key->value())); + } + } + + return absl::OkStatus(); +} + +absl::StatusOr YamlPolicyParser::ParsePolicyTag( + CelPolicyParseContext& ctx, const ValueString& tag_name, + const YAML::Node& node) const { + if (tag_name.value() == "imports") { + CEL_RETURN_IF_ERROR(ParseImports(ctx, node)); + return true; + } + if (tag_name.value() == "name") { + std::optional name = + GetValueString(ctx, node, "Policy 'name' is not a string"); + if (name.has_value()) { + ctx.policy().set_name(*name); + } + return true; + } + if (tag_name.value() == "description") { + std::optional description = + GetValueString(ctx, node, "Policy 'description' is not a string"); + if (description.has_value()) { + ctx.policy().set_description(*description); + } + return true; + } + if (tag_name.value() == "display_name") { + std::optional display_name = + GetValueString(ctx, node, "Policy 'display_name' is not a string"); + if (display_name.has_value()) { + ctx.policy().set_display_name(*display_name); + } + return true; + } + if (tag_name.value() == "rule") { + CEL_RETURN_IF_ERROR(ParseRule(ctx, node, ctx.policy().mutable_rule())); + return true; + } + return false; +} + +absl::Status YamlPolicyParser::ParseImports(CelPolicyParseContext& ctx, + const YAML::Node& node) const { + if (!node.IsSequence()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy 'imports' is not a sequence"); + return absl::OkStatus(); + } + + for (const YAML::Node& import : node) { + CelPolicyElementId import_id = CollectMetadata(ctx, import); + if (!import.IsMap()) { + ctx.ReportError(import_id, "Import is not a map"); + continue; + } + const YAML::Node& name_node = import["name"]; + if (!name_node.IsDefined()) { + ctx.ReportError(import_id, "No 'name' tag in import"); + continue; + } + std::optional import_name = + GetValueString(ctx, name_node, "Import name is not a string"); + if (import_name.has_value()) { + ctx.policy().mutable_imports().push_back(Import(import_id, *import_name)); + } + } + return absl::OkStatus(); +} + +absl::Status YamlPolicyParser::ParseRule(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const { + if (!node.IsMap()) { + ctx.ReportError(CollectMetadata(ctx, node), "Policy 'rule' is not a map"); + return absl::OkStatus(); + } + rule.set_id(CollectMetadata(ctx, node)); + + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy rule tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, + ParseRuleTag(ctx, *key, value_node, rule)); + if (!handled) { + ctx.ReportError(key->id(), absl::StrCat("Unrecognized policy rule tag: ", + key->value())); + } + } + return absl::OkStatus(); +} + +absl::StatusOr YamlPolicyParser::ParseRuleTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Rule& rule) const { + if (tag_name.value() == "id") { + std::optional rule_id = + GetValueString(ctx, node, "Policy rule 'id' is not a string"); + if (rule_id.has_value()) { + rule.set_rule_id(*rule_id); + } + return true; + } + if (tag_name.value() == "description") { + std::optional description = + GetValueString(ctx, node, "Policy rule 'description' is not a string"); + if (description.has_value()) { + rule.set_description(*description); + } + return true; + } + if (tag_name.value() == "variables") { + if (!node.IsSequence()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy rule 'variables' is not a sequence"); + return true; + } + for (const YAML::Node& variable_node : node) { + CEL_ASSIGN_OR_RETURN(Variable variable, + ParseVariable(ctx, variable_node, rule)); + rule.mutable_variables().push_back(std::move(variable)); + } + return true; + } + if (tag_name.value() == "match") { + if (!node.IsSequence()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy rule 'match' is not a sequence"); + return true; + } + for (const YAML::Node& match_node : node) { + CEL_ASSIGN_OR_RETURN(Match match, ParseMatch(ctx, match_node, rule)); + rule.mutable_matches().push_back(std::move(match)); + } + return true; + } + return false; +} + +absl::StatusOr YamlPolicyParser::ParseVariable( + CelPolicyParseContext& ctx, const YAML::Node& node, Rule& rule) const { + Variable variable; + if (!node.IsMap()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy rule 'variable' is not a map"); + return variable; + } + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy variable tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, + ParseVariableTag(ctx, *key, value_node, variable)); + if (!handled) { + ctx.ReportError( + key->id(), + absl::StrCat("Unrecognized policy variable tag: ", key->value())); + } + } + return variable; +} + +absl::StatusOr YamlPolicyParser::ParseVariableTag( + CelPolicyParseContext& ctx, const ValueString& tag_name, + const YAML::Node& node, Variable& variable) const { + if (tag_name.value() == "name") { + std::optional name = + GetValueString(ctx, node, "Policy variable 'name' is not a string"); + if (name.has_value()) { + variable.set_name(*name); + } + return true; + } + if (tag_name.value() == "expression") { + std::optional expression = GetValueString( + ctx, node, "Policy variable 'expression' is not a string"); + if (expression.has_value()) { + variable.set_expression(*expression); + } + return true; + } + return false; +} + +absl::StatusOr YamlPolicyParser::ParseMatch(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const { + Match match; + match.set_id(CollectMetadata(ctx, node)); + if (!node.IsMap()) { + ctx.ReportError(match.id(), "Policy rule 'match' is not a map"); + return match; + } + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy match tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, + ParseMatchTag(ctx, *key, value_node, match, rule)); + if (!handled) { + ctx.ReportError(key->id(), absl::StrCat("Unrecognized policy match tag: ", + key->value())); + } + } + + if (match.has_output_block()) { + if (match.output_block().output().value().empty() && + match.output_block().explanation().has_value()) { + ctx.ReportError(match.id(), "Match specifies explanation but no output"); + } + } + + return match; +} + +absl::StatusOr YamlPolicyParser::ParseMatchTag( + CelPolicyParseContext& ctx, const ValueString& tag_name, + const YAML::Node& node, Match& match, Rule& rule) const { + if (tag_name.value() == "condition") { + std::optional condition = + GetValueString(ctx, node, "Policy match 'condition' is not a string"); + if (condition.has_value()) { + match.set_condition(*condition); + } + return true; + } + if (tag_name.value() == "explanation") { + std::optional explanation = + GetValueString(ctx, node, "Policy match 'explanation' is not a string"); + if (explanation.has_value()) { + if (match.has_rule()) { + ctx.ReportError( + tag_name.id(), + "Cannot specify explanation when a nested rule is present"); + } else { + match.mutable_output_block().set_explanation(*explanation); + } + } + return true; + } + if (tag_name.value() == "output") { + std::optional output = + GetValueString(ctx, node, "Policy match 'output' is not a string"); + if (output.has_value()) { + if (match.has_rule()) { + ctx.ReportError(tag_name.id(), + "Cannot specify output when a nested rule is present"); + } else { + match.mutable_output_block().set_output(*output); + } + } + return true; + } + if (tag_name.value() == "rule") { + if (match.has_output_block()) { + ctx.ReportError(tag_name.id(), + "Cannot specify nested rule when output/explanation is " + "present"); + } + auto nested_rule = std::make_unique(); + CEL_RETURN_IF_ERROR(ParseRule(ctx, node, *nested_rule)); + match.set_result(std::move(nested_rule)); + return true; + } + return false; +} + +const CelPolicyParser& GetDefaultYamlPolicyParser() { + static const auto* const parser = new YamlPolicyParser(); + return *parser; +} + +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source) { + return ParseYamlCelPolicy(std::move(policy_source), + GetDefaultYamlPolicyParser()); +} + +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source, + const CelPolicyParser& parser) { + CelPolicyParseContext ctx(std::move(policy_source)); + CEL_RETURN_IF_ERROR(parser.ParsePolicy(ctx)); + return ctx.GetResult(); +} + +} // namespace cel diff --git a/policy/yaml_policy_parser.h b/policy/yaml_policy_parser.h new file mode 100644 index 000000000..469209333 --- /dev/null +++ b/policy/yaml_policy_parser.h @@ -0,0 +1,135 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_context.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_parser.h" +#include "yaml-cpp/node/node.h" + +namespace cel { + +// A parser for YAML-based CEL policies. +// +// To support additional or alternative YAML elements, subclass +// `YamlPolicyParser` and override specific parsing methods, `Parse*` +class YamlPolicyParser : public CelPolicyParser { + public: + std::optional GetValueString( + CelPolicyParseContext& ctx, const YAML::Node& node, + std::string_view error_message) const; + + absl::Status ParsePolicy(CelPolicyParseContext& ctx) const override; + + protected: + // Collects metadata (e.g. source position) for the given YAML node, stores it + // in the context, and returns an ID that can be used to refer to it. + virtual CelPolicyElementId CollectMetadata(CelPolicyParseContext& ctx, + const YAML::Node& node) const; + + // Parses a top-level tag in the policy YAML. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParsePolicyTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node) const; + + // Parses the imports section of the policy YAML. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::Status ParseImports(CelPolicyParseContext& ctx, + const YAML::Node& node) const; + + // Parses a rule element of the policy YAML, which may be the top-level rule + // or a sub-rule of a match. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::Status ParseRule(CelPolicyParseContext& ctx, + const YAML::Node& node, Rule& rule) const; + + // Parses a tag in a policy YAML rule. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseRuleTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Rule& rule) const; + + // Parses a variable element of the policy YAML. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseVariable(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const; + + // Parses a tag in a policy YAML variable. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseVariableTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Variable& variable) const; + + // Parses a match element of the policy YAML. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseMatch(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const; + + // Parses a tag in a policy YAML match. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseMatchTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Match& match, Rule& rule) const; +}; + +// Returns a default implementation of YamlPolicyParser. +const CelPolicyParser& GetDefaultYamlPolicyParser(); + +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source, + const CelPolicyParser& parser); + +// YAML CelPolicy parser that uses the default format as implemented by +// `YamlPolicyParser`. +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_ diff --git a/policy/yaml_policy_parser_test.cc b/policy/yaml_policy_parser_test.cc new file mode 100644 index 000000000..4e7dfc49c --- /dev/null +++ b/policy/yaml_policy_parser_test.cc @@ -0,0 +1,305 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/yaml_policy_parser.h" + +#include +#include +#include +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "internal/runfiles.h" +#include "internal/testing.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_parser.h" +#include "yaml-cpp/node/node.h" + +namespace cel { + +namespace internal { +const CelPolicyParser& GetTestCustomYamlPolicyParser(); +} // namespace internal + +namespace { + +using ::absl_testing::IsOk; +using ::testing::HasSubstr; +using ::testing::IsNull; + +constexpr absl::string_view kTestPolicyFilePath = +"_main/policy/testdata/"; + +constexpr absl::string_view kBaselineSeparator = + "--------------------------------------------------------------------\n"; + +struct YamlPolicyParserTestCase { + std::string policy_source_file; + std::string baseline_file; + const cel::CelPolicyParser& (*parser_factory)(); +}; + +using YamlPolicyParserTest = testing::TestWithParam; + +TEST_P(YamlPolicyParserTest, Parse) { + std::string contents; + std::string test_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, GetParam().policy_source_file)); + ASSERT_THAT(cel::internal::GetFileContents(test_file, &contents), IsOk()); + + std::string baseline; + std::string baseline_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, GetParam().baseline_file)); + ASSERT_THAT(cel::internal::GetFileContents(baseline_file, &baseline), IsOk()); + baseline = absl::StripAsciiWhitespace(baseline); + + std::ostringstream out; + out << "POLICY SOURCE: " << GetParam().policy_source_file << "\n"; + + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(contents, GetParam().policy_source_file)); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + + ASSERT_OK_AND_ASSIGN( + CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source, GetParam().parser_factory())); + + out << kBaselineSeparator; + if (parse_result.IsValid()) { + out << "PARSED POLICY:\n"; + out << parse_result.GetPolicy()->DebugString(); + } else { + ASSERT_THAT(parse_result.GetPolicy(), IsNull()); + out << kBaselineSeparator; + out << "PARSER ISSUES:\n"; + for (const auto& issue : parse_result.GetIssues()) { + out << issue.ToDisplayString(*policy_source) << "\n"; + } + } + + std::string actual(absl::StripAsciiWhitespace(out.str())); + if (actual != baseline) { + // Log the actual result to make it easier to copy/paste into the baseline + // file when updating the tests. + ABSL_LOG(INFO) << "Actual:\n" << actual; + EXPECT_EQ(actual, baseline); + } +} + +INSTANTIATE_TEST_SUITE_P( + Formats, YamlPolicyParserTest, + testing::ValuesIn({ + YamlPolicyParserTestCase{ + .policy_source_file = "cel_policy.yaml", + .baseline_file = "cel_policy_parser.baseline", + .parser_factory = GetDefaultYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "nested_rule.yaml", + .baseline_file = "nested_rule_parser.baseline", + .parser_factory = GetDefaultYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "custom_policy_format.yaml", + .baseline_file = "custom_policy_format_parser.baseline", + .parser_factory = internal::GetTestCustomYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "custom_policy_format_with_errors.yaml", + .baseline_file = "custom_policy_format_with_errors_parser.baseline", + .parser_factory = internal::GetTestCustomYamlPolicyParser, + }, + })); + +struct ParseTestCase { + std::string yaml; + std::string expected_error; +}; + +using YamlPolicyParseErrorTest = testing::TestWithParam; + +TEST_P(YamlPolicyParseErrorTest, YamlSyntaxError) { + const ParseTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(param.yaml, "test")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + ASSERT_OK_AND_ASSIGN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + EXPECT_THAT(parse_result.FormattedIssues(), HasSubstr(param.expected_error)); +} + +std::vector GetParseTestCases() { + return { + ParseTestCase{ + .yaml = R"yaml( ? [ John, Doe ]: age: 30 )yaml", + .expected_error = "1:22: Invalid CEL policy YAML syntax\n" + " | ? [ John, Doe ]: age: 30 \n" + " | .....................^", + }, + ParseTestCase{ + .yaml = R"yaml( invalid yaml )yaml", + .expected_error = "1:2: Policy is not a map\n" + " | invalid yaml \n" + " | .^", + }, + ParseTestCase{ + .yaml = R"yaml( + ? [1, 2, 3] + : "Prime numbers sequence" + )yaml", + .expected_error = "2:23: Policy tag is not a string\n" + " | ? [1, 2, 3]\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: N/A + )yaml", + .expected_error = "2:28: Policy 'imports' is not a sequence\n" + " | imports: N/A\n" + " | ...........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: + - cel.expr.conformance + )yaml", + .expected_error = "3:21: Import is not a map\n" + " | - cel.expr.conformance\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: + - name: + - cel.expr.conformance + )yaml", + .expected_error = "4:21: Import name is not a string\n" + " | - cel.expr.conformance\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: do something + )yaml", + .expected_error = "2:25: Policy 'rule' is not a map\n" + " | rule: do something\n" + " | ........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + id: + - 22 + )yaml", + .expected_error = "4:21: Policy rule 'id' is not a string\n" + " | - 22\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + no vars + )yaml", + .expected_error = "4:23: Policy rule 'variables' is not a sequence\n" + " | no vars\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: + foo: bar + )yaml", + .expected_error = "5:25: Policy variable 'name' is not a string\n" + " | foo: bar\n" + " | ........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: test_var + expression: + - 22 + )yaml", + .expected_error = + "6:23: Policy variable 'expression' is not a string\n" + " | - 22\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: '\u0041\u00a9\u20ac\U0001f680' + - '\u0041\u00a9\u20ac\U0001f680': name + )yaml", + .expected_error = + "5:23: Unrecognized policy variable tag: " + "\\u0041\\u00a9\\u20ac\\U0001f680\n" + " | - '\\u0041\\u00a9\\u20ac\\U0001f680': " + "name\n" + " | ......................^", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(YamlPolicyParseErrorTest, YamlPolicyParseErrorTest, + ::testing::ValuesIn(GetParseTestCases())); + +TEST(YamlPolicyParserTest, OffsetIssueFormatting) { + // TODO(b/506179116): will need to copy the go implementation in extracting + // the source string from the YAML document instead of the interpreted string + // value to fix up error locations in folded and block literals. + std::string contents; + std::string test_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, "cel_policy.yaml")); + ASSERT_THAT(cel::internal::GetFileContents(test_file, &contents), IsOk()); + + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(contents, "cel_policy.yaml")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + ASSERT_OK_AND_ASSIGN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + + ASSERT_TRUE(parse_result.IsValid()); + const CelPolicy* policy = parse_result.GetPolicy(); + + CelPolicyElementId name_id = policy->name().id(); + + CelPolicyIssue issue(name_id, 4, CelPolicyIssue::Severity::kError, + "Test error"); + + std::string formatted = issue.ToDisplayString(*policy_source); + + EXPECT_THAT(formatted, HasSubstr("ERROR: cel_policy.yaml:16:11: Test error")); + EXPECT_THAT(formatted, HasSubstr(" | name: cel_policy")); + EXPECT_THAT(formatted, HasSubstr(" | ..........^")); +} + +} // namespace +} // namespace cel From c8c187fea145b3113b3638d76c727c9052f13511 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Tue, 16 Jun 2026 10:10:54 -0700 Subject: [PATCH 117/143] Support overload lookup and filtering by signature as an alternative to overload_id PiperOrigin-RevId: 933154803 --- checker/BUILD | 2 + checker/internal/type_checker_builder_impl.cc | 4 +- checker/type_checker_builder.h | 2 +- checker/type_checker_builder_factory_test.cc | 17 +++-- checker/type_checker_subset_factory.cc | 20 ++++-- checker/type_checker_subset_factory_test.cc | 33 +++++++-- common/BUILD | 3 +- common/decl.cc | 68 +++++++++++-------- common/decl.h | 48 +++---------- common/decl_test.cc | 47 +++++++++++++ env/BUILD | 1 + env/config_test.cc | 55 +++++++++++++++ env/env.cc | 56 +++++++++++---- env/env_test.cc | 16 +++++ env/env_yaml_test.cc | 12 ++-- 15 files changed, 276 insertions(+), 108 deletions(-) diff --git a/checker/BUILD b/checker/BUILD index 27a1eb84e..10ed6e363 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -229,6 +229,8 @@ cc_library( hdrs = ["type_checker_subset_factory.h"], deps = [ ":type_checker_builder", + "//common:decl", + "//common/internal:signature", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 9b91fc926..f0332b999 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -187,8 +187,8 @@ std::optional FilterDecl(FunctionDecl decl, FunctionDecl filtered; std::string name = decl.release_name(); std::vector overloads = decl.release_overloads(); - for (const auto& ovl : overloads) { - if (subset.should_include_overload(name, ovl.id())) { + for (auto& ovl : overloads) { + if (subset.should_include_overload(name, ovl)) { absl::Status s = filtered.AddOverload(std::move(ovl)); if (!s.ok()) { // Should not be possible to construct the original decl in a way that diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index f145b8a98..c2d0cbf7b 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -52,7 +52,7 @@ struct CheckerLibrary { // Represents a declaration to only use a subset of a library. struct TypeCheckerSubset { using FunctionPredicate = absl::AnyInvocable; + absl::string_view function, const OverloadDecl& overload) const>; // The id of the library to subset. Only one subset can be applied per // library id. diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index 9c4775e7f..40406948d 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -235,8 +235,8 @@ TEST(TypeCheckerBuilderTest, AddLibraryIncludeSubset) { ASSERT_THAT( builder->AddLibrarySubset( {"testlib", - [](absl::string_view /*function*/, absl::string_view overload_id) { - return (overload_id == "add_int" || overload_id == "sub_int"); + [](absl::string_view /*function*/, const OverloadDecl& overload) { + return (overload.id() == "add_int" || overload.id() == "sub_int"); }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); @@ -274,9 +274,8 @@ TEST(TypeCheckerBuilderTest, AddLibraryExcludeSubset) { ASSERT_THAT( builder->AddLibrarySubset( {"testlib", - [](absl::string_view /*function*/, absl::string_view overload_id) { - return (overload_id != "add_int" && overload_id != "sub_int"); - ; + [](absl::string_view /*function*/, const OverloadDecl& overload) { + return (overload.id() != "add_int" && overload.id() != "sub_int"); }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); @@ -313,7 +312,7 @@ TEST(TypeCheckerBuilderTest, AddLibrarySubsetRemoveAllOvl) { ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); ASSERT_THAT(builder->AddLibrarySubset({"testlib", [](absl::string_view function, - absl::string_view /*overload_id*/) { + const OverloadDecl& /*overload*/) { return function != "add"; }}), IsOk()); @@ -352,12 +351,12 @@ TEST(TypeCheckerBuilderTest, AddLibraryOneSubsetPerLibraryId) { ASSERT_THAT( builder->AddLibrarySubset( {"testlib", [](absl::string_view function, - absl::string_view /*overload_id*/) { return true; }}), + const OverloadDecl& /*overload*/) { return true; }}), IsOk()); EXPECT_THAT( builder->AddLibrarySubset( {"testlib", [](absl::string_view function, - absl::string_view /*overload_id*/) { return true; }}), + const OverloadDecl& /*overload*/) { return true; }}), StatusIs(absl::StatusCode::kAlreadyExists)); } @@ -369,7 +368,7 @@ TEST(TypeCheckerBuilderTest, AddLibrarySubsetLibraryIdRequireds) { ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); EXPECT_THAT(builder->AddLibrarySubset({"", [](absl::string_view function, - absl::string_view /*overload_id*/) { + const OverloadDecl& /*overload*/) { return function == "add"; }}), StatusIs(absl::StatusCode::kInvalidArgument)); diff --git a/checker/type_checker_subset_factory.cc b/checker/type_checker_subset_factory.cc index 6a05ce220..e5335e220 100644 --- a/checker/type_checker_subset_factory.cc +++ b/checker/type_checker_subset_factory.cc @@ -21,14 +21,21 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/internal/signature.h" namespace cel { TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( absl::flat_hash_set overload_ids) { return [overload_ids = std::move(overload_ids)]( - absl::string_view /*function*/, absl::string_view overload_id) { - return overload_ids.contains(overload_id); + absl::string_view function, const OverloadDecl& overload) { + if (overload_ids.contains(overload.id())) { + return true; + } + auto signature = common_internal::MakeOverloadSignature( + function, overload.args(), overload.member()); + return signature.ok() && overload_ids.contains(*signature); }; } @@ -41,8 +48,13 @@ TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate( absl::flat_hash_set overload_ids) { return [overload_ids = std::move(overload_ids)]( - absl::string_view /*function*/, absl::string_view overload_id) { - return !overload_ids.contains(overload_id); + absl::string_view function, const OverloadDecl& overload) { + if (overload_ids.contains(overload.id())) { + return false; + } + auto signature = common_internal::MakeOverloadSignature( + function, overload.args(), overload.member()); + return !signature.ok() || !overload_ids.contains(*signature); }; } diff --git a/checker/type_checker_subset_factory_test.cc b/checker/type_checker_subset_factory_test.cc index fa38e1c0d..5b644ec7c 100644 --- a/checker/type_checker_subset_factory_test.cc +++ b/checker/type_checker_subset_factory_test.cc @@ -43,6 +43,8 @@ TEST(TypeCheckerSubsetFactoryTest, IncludeOverloadsByIdPredicate) { StandardOverloadIds::kEquals, StandardOverloadIds::kNotEquals, StandardOverloadIds::kNotStrictlyFalse, + "matches(string,string)", + "string.matches(string)", }; ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({ @@ -65,15 +67,19 @@ TEST(TypeCheckerSubsetFactoryTest, IncludeOverloadsByIdPredicate) { EXPECT_TRUE(r.IsValid()); + // Allowed by signature. + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); + EXPECT_TRUE(r.IsValid()); + + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("matches(r'foo.*', 'foobar')")); + EXPECT_TRUE(r.IsValid()); + // Not in allowlist. ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3")); EXPECT_FALSE(r.IsValid()); ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'")); EXPECT_FALSE(r.IsValid()); - - ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); - EXPECT_FALSE(r.IsValid()); } TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) { @@ -83,6 +89,8 @@ TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) { absl::string_view exclude_list[] = { StandardOverloadIds::kMatches, StandardOverloadIds::kMatchesMember, + "size(string)", + "string.size()", }; ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({ @@ -105,18 +113,35 @@ TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) { EXPECT_TRUE(r.IsValid()); - // Not in allowlist. + // Allowed. ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3")); EXPECT_TRUE(r.IsValid()); ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'")); EXPECT_TRUE(r.IsValid()); + // Excluded by ID. ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); EXPECT_FALSE(r.IsValid()); ASSERT_OK_AND_ASSIGN(r, compiler->Compile("matches(r'foo.*', 'foobar')")); EXPECT_FALSE(r.IsValid()); + + // Excluded by signature (top-level function). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("size('abc')")); + EXPECT_FALSE(r.IsValid()); + + // Allowed (other overloads of size). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("size([1, 2, 3])")); + EXPECT_TRUE(r.IsValid()); + + // Excluded by signature (member function). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc'.size()")); + EXPECT_FALSE(r.IsValid()); + + // Allowed (other overloads of size member). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("[1, 2, 3].size()")); + EXPECT_TRUE(r.IsValid()); } } // namespace diff --git a/common/BUILD b/common/BUILD index 93410306f..0bd3632dd 100644 --- a/common/BUILD +++ b/common/BUILD @@ -151,9 +151,8 @@ cc_library( "//internal:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/common/decl.cc b/common/decl.cc index b338bfd4f..d2d50964a 100644 --- a/common/decl.cc +++ b/common/decl.cc @@ -20,8 +20,8 @@ #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -109,43 +109,46 @@ bool SignaturesOverlap(const OverloadDecl& lhs, const OverloadDecl& rhs) { template void AddOverloadInternal(std::string_view function_name, std::vector& insertion_order, - OverloadDeclHashSet& overloads, Overload&& overload, - absl::Status& status) { + absl::flat_hash_map& by_id, + absl::flat_hash_map& by_signature, + Overload&& overload, absl::Status& status) { if (!status.ok()) { return; } - if (overload.id().empty()) { - OverloadDecl overload_decl = overload; - absl::StatusOr overload_id = - common_internal::MakeOverloadSignature( - function_name, overload_decl.args(), overload_decl.member()); - if (!overload_id.ok()) { - status = overload_id.status(); - return; - } - overload_decl.set_id(*overload_id); - AddOverloadInternal(function_name, insertion_order, overloads, - std::move(overload_decl), status); + absl::StatusOr signature = + common_internal::MakeOverloadSignature(function_name, overload.args(), + overload.member()); + if (!signature.ok()) { + status = signature.status(); return; } - if (auto it = overloads.find(overload.id()); it != overloads.end()) { + OverloadDecl mutable_overload = std::forward(overload); + + if (mutable_overload.id().empty()) { + mutable_overload.set_id(*signature); + } + + if (auto it = by_id.find(mutable_overload.id()); it != by_id.end()) { status = absl::AlreadyExistsError( - absl::StrCat("overload already exists: ", overload.id())); + absl::StrCat("overload exists: ", mutable_overload.id())); return; } - for (const auto& existing : overloads) { - if (SignaturesOverlap(overload, existing)) { + + for (const auto& existing : insertion_order) { + if (SignaturesOverlap(mutable_overload, existing)) { status = absl::InvalidArgumentError( absl::StrCat("overload signature collision: ", existing.id(), - " collides with ", overload.id())); + " collides with ", mutable_overload.id())); return; } } - const auto inserted = overloads.insert(std::forward(overload)); - ABSL_DCHECK(inserted.second); - insertion_order.push_back(*inserted.first); + + size_t index = insertion_order.size(); + by_id[mutable_overload.id()] = index; + by_signature[*signature] = index; + insertion_order.push_back(std::move(mutable_overload)); } void CollectTypeParams(absl::flat_hash_set& type_params, @@ -195,14 +198,25 @@ absl::flat_hash_set OverloadDecl::GetTypeParams() const { void FunctionDecl::AddOverloadImpl(const OverloadDecl& overload, absl::Status& status) { - AddOverloadInternal(name_, overloads_.insertion_order, overloads_.set, - overload, status); + AddOverloadInternal(name_, overloads_.insertion_order, overloads_.by_id, + overloads_.by_signature, overload, status); } void FunctionDecl::AddOverloadImpl(OverloadDecl&& overload, absl::Status& status) { - AddOverloadInternal(name_, overloads_.insertion_order, overloads_.set, - std::move(overload), status); + AddOverloadInternal(name_, overloads_.insertion_order, overloads_.by_id, + overloads_.by_signature, std::move(overload), status); +} + +const OverloadDecl* FunctionDecl::FindOverloadById(absl::string_view id) const { + if (auto it = overloads_.by_id.find(id); it != overloads_.by_id.end()) { + return &overloads_.insertion_order[it->second]; + } + if (auto it = overloads_.by_signature.find(id); + it != overloads_.by_signature.end()) { + return &overloads_.insertion_order[it->second]; + } + return nullptr; } } // namespace cel diff --git a/common/decl.h b/common/decl.h index 22ee8cbf0..7aea8c502 100644 --- a/common/decl.h +++ b/common/decl.h @@ -22,11 +22,10 @@ #include "absl/algorithm/container.h" #include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -264,39 +263,6 @@ OverloadDecl MakeMemberOverloadDecl(absl::string_view id, Type result, return overload_decl; } -struct OverloadDeclHash { - using is_transparent = void; - - size_t operator()(const OverloadDecl& overload_decl) const { - return (*this)(overload_decl.id()); - } - - size_t operator()(absl::string_view id) const { return absl::HashOf(id); } -}; - -struct OverloadDeclEqualTo { - using is_transparent = void; - - bool operator()(const OverloadDecl& lhs, const OverloadDecl& rhs) const { - return (*this)(lhs.id(), rhs.id()); - } - - bool operator()(const OverloadDecl& lhs, absl::string_view rhs) const { - return (*this)(lhs.id(), rhs); - } - - bool operator()(absl::string_view lhs, const OverloadDecl& rhs) const { - return (*this)(lhs, rhs.id()); - } - - bool operator()(absl::string_view lhs, absl::string_view rhs) const { - return lhs == rhs; - } -}; - -using OverloadDeclHashSet = - absl::flat_hash_set; - template absl::StatusOr MakeFunctionDecl(std::string name, Overloads&&... overloads); @@ -346,21 +312,27 @@ class FunctionDecl final { return overloads_.insertion_order; } + ABSL_MUST_USE_RESULT const OverloadDecl* FindOverloadById( + absl::string_view id) const; + std::vector release_overloads() { std::vector released = std::move(overloads_.insertion_order); overloads_.insertion_order.clear(); - overloads_.set.clear(); + overloads_.by_id.clear(); + overloads_.by_signature.clear(); return released; } private: struct Overloads { std::vector insertion_order; - OverloadDeclHashSet set; + absl::flat_hash_map by_id; + absl::flat_hash_map by_signature; void Reserve(size_t size) { insertion_order.reserve(size); - set.reserve(size); + by_id.reserve(size); + by_signature.reserve(size); } }; diff --git a/common/decl_test.cc b/common/decl_test.cc index 510cd5017..72e7f1b93 100644 --- a/common/decl_test.cc +++ b/common/decl_test.cc @@ -165,6 +165,53 @@ TEST(FunctionDecl, Overloads) { StatusIs(absl::StatusCode::kInvalidArgument)); } +TEST(FunctionDecl, AddOverloadInvalidSignature) { + FunctionDecl function_decl; + function_decl.set_name("foo"); + // Member overload must have at least one argument (the receiver). + // This should fail to add because signature generation fails. + EXPECT_THAT(function_decl.AddOverload(MakeMemberOverloadDecl(StringType{})), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(FunctionDecl, AddOverloadDuplicateId) { + ASSERT_OK_AND_ASSIGN( + auto function_decl, + MakeFunctionDecl("hello", + MakeOverloadDecl("foo", StringType{}, StringType{}))); + // Adding another overload with the same ID "foo" should fail. + EXPECT_THAT( + function_decl.AddOverload(MakeOverloadDecl("foo", IntType{}, IntType{})), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +TEST(FunctionDecl, FindOverload) { + ASSERT_OK_AND_ASSIGN( + auto function_decl, + MakeFunctionDecl( + "hello", MakeOverloadDecl("foo", StringType{}, StringType{}), + MakeMemberOverloadDecl("bar", StringType{}, StringType{}), + MakeOverloadDecl(IntType{}, IntType{}))); + + // Find by explicit ID + const OverloadDecl* overload = function_decl.FindOverloadById("foo"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "foo"); + + // Find by ID fallback to signature + overload = function_decl.FindOverloadById("hello(string)"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "foo"); + + // Find implicit overload (where ID == signature) + overload = function_decl.FindOverloadById("hello(int)"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "hello(int)"); + + // Non-existent + EXPECT_EQ(function_decl.FindOverloadById("non_existent"), nullptr); +} + TEST(FunctionDecl, OverloadId) { google::protobuf::Arena arena; const auto* descriptor = diff --git a/env/BUILD b/env/BUILD index bd82e8ec6..1816238a5 100644 --- a/env/BUILD +++ b/env/BUILD @@ -56,6 +56,7 @@ cc_library( "//common:container", "//common:decl", "//common:type", + "//common/internal:signature", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", diff --git a/env/config_test.cc b/env/config_test.cc index df0d6f875..8cfc3cf7f 100644 --- a/env/config_test.cc +++ b/env/config_test.cc @@ -88,6 +88,34 @@ INSTANTIATE_TEST_SUITE_P( StandardLibraryConfigTestCase{ .standard_library_config = {}, }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add(int,int)"}, + {"_+_", "add(list,list)"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", "add(int,int)"}, + {"_+_", "add(list,list)"}}, + }, + }, StandardLibraryConfigTestCase{ .standard_library_config = { @@ -106,6 +134,15 @@ INSTANTIATE_TEST_SUITE_P( .expected_error = "Cannot set both included and excluded functions.", }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add(int,int)"}}, + .excluded_functions = {{"_-_", ""}}, + }, + .expected_error = + "Cannot set both included and excluded functions.", + }, StandardLibraryConfigTestCase{ .standard_library_config = { @@ -114,6 +151,15 @@ INSTANTIATE_TEST_SUITE_P( .expected_error = "Cannot include function '_+_' and also its " "specific overload 'add_list'", }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", ""}, + {"_+_", "add(int,int)"}}, + }, + .expected_error = "Cannot include function '_+_' and also its " + "specific overload 'add(int,int)'", + }, StandardLibraryConfigTestCase{ .standard_library_config = { @@ -121,6 +167,15 @@ INSTANTIATE_TEST_SUITE_P( }, .expected_error = "Cannot exclude function '_+_' and also its " "specific overload 'add_list'", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", ""}, + {"_+_", "add(int,int)"}}, + }, + .expected_error = "Cannot exclude function '_+_' and also its " + "specific overload 'add(int,int)'", })); TEST(VariableConfigTest, VariableConfig) { diff --git a/env/env.cc b/env/env.cc index 22d24295e..4fa4e7398 100644 --- a/env/env.cc +++ b/env/env.cc @@ -26,6 +26,7 @@ #include "common/constant.h" #include "common/container.h" #include "common/decl.h" +#include "common/internal/signature.h" #include "common/type.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" @@ -57,21 +58,47 @@ bool ShouldIncludeMacro(const Config::StandardLibraryConfig& config, bool ShouldIncludeFunction(const Config::StandardLibraryConfig& config, absl::string_view function, - absl::string_view overload_id) { - if (config.excluded_functions.contains( - std::make_pair(std::string(function), std::string(overload_id))) || - config.excluded_functions.contains( - std::make_pair(std::string(function), ""))) { - return false; + const OverloadDecl& overload) { + if (config.excluded_functions.empty() && config.included_functions.empty()) { + return true; + } + + if (!config.excluded_functions.empty()) { + if (config.excluded_functions.contains(std::make_pair( + std::string(function), std::string(overload.id()))) || + config.excluded_functions.contains( + std::make_pair(std::string(function), ""))) { + return false; + } + absl::StatusOr signature = + common_internal::MakeOverloadSignature(function, overload.args(), + overload.member()); + if (signature.ok() && config.excluded_functions.contains(std::make_pair( + std::string(function), *std::move(signature)))) { + return false; + } } - if (!config.included_functions.empty() && - !config.included_functions.contains( - std::make_pair(std::string(function), "")) && - !config.included_functions.contains( - std::make_pair(std::string(function), std::string(overload_id)))) { + + if (!config.included_functions.empty()) { + if (config.included_functions.contains(std::make_pair( + std::string(function), std::string(overload.id()))) || + config.included_functions.contains( + std::make_pair(std::string(function), ""))) { + return true; + } + // Ok to call MakeOverloadSignature() again, because in practice either + // included or excluded functions may be specified, but not both. + absl::StatusOr signature = + common_internal::MakeOverloadSignature(function, overload.args(), + overload.member()); + if (signature.ok() && config.included_functions.contains(std::make_pair( + std::string(function), *std::move(signature)))) { + return true; + } return false; } - return true; + + return true; // Never reached } absl::StatusOr MakeStdlibSubset( @@ -87,9 +114,8 @@ absl::StatusOr MakeStdlibSubset( }; subset.should_include_overload = [&standard_library_config]( absl::string_view function, - absl::string_view overload_id) { - return ShouldIncludeFunction(standard_library_config, function, - overload_id); + const OverloadDecl& overload) { + return ShouldIncludeFunction(standard_library_config, function, overload); }; return subset; } diff --git a/env/env_test.cc b/env/env_test.cc index fda87dfab..00143a857 100644 --- a/env/env_test.cc +++ b/env/env_test.cc @@ -280,6 +280,15 @@ INSTANTIATE_TEST_SUITE_P( .expected_invalid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", "'hello' + 'world'"}, }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.excluded_functions = {{"_+_", "_+_(bytes,bytes)"}, + {"_+_", "_+_(list<~A>,list<~A>)"}, + {"_+_", "_+_(string,string)"}}}, + .expected_valid_expressions = {"1 + 2"}, + .expected_invalid_expressions = {"[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, StandardLibraryConfigTestCase{ .standard_library_config = {.excluded_functions = {{"_+_", "add_bytes"}, @@ -294,6 +303,13 @@ INSTANTIATE_TEST_SUITE_P( .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", "'hello' + 'world'"}, }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.included_functions = {{"_+_", "_+_(int,int)"}, + {"_+_", "_+_(list<~A>,list<~A>)"}}}, + .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]"}, + .expected_invalid_expressions = {"'hello' + 'world'"}, + }, StandardLibraryConfigTestCase{ .standard_library_config = {.included_functions = {{"_+_", "add_int64"}, diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index 9c5b3f04f..38f08e371 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -153,7 +153,7 @@ TEST(EnvYamlTest, ParseStdlibConfig_InclusionStyle) { - name: "_+_" overloads: - id: add_bytes - - id: add_list + - id: "_+_(list<~A>,list<~A>)" - name: "matches" - name: "timestamp" overloads: @@ -166,7 +166,7 @@ TEST(EnvYamlTest, ParseStdlibConfig_InclusionStyle) { EXPECT_THAT( stdlib_config.included_functions, UnorderedElementsAre(std::make_pair("_+_", "add_bytes"), - std::make_pair("_+_", "add_list"), + std::make_pair("_+_", "_+_(list<~A>,list<~A>)"), std::make_pair("matches", ""), std::make_pair("timestamp", "string_to_timestamp"))) << " Actual stdlib config: " << stdlib_config; @@ -1405,9 +1405,9 @@ std::vector GetExportTestCases() { .included_functions = { std::make_pair("timestamp", "string_to_timestamp"), - std::make_pair("_+_", "add_list"), + std::make_pair("_+_", "_+_(list<~A>,list<~A>)"), std::make_pair("matches", ""), - std::make_pair("_+_", "add_bytes"), + std::make_pair("_+_", "_+_(bytes,bytes)"), }, })); return config; @@ -1417,8 +1417,8 @@ std::vector GetExportTestCases() { include_functions: - name: "_+_" overloads: - - id: "add_bytes" - - id: "add_list" + - id: "_+_(bytes,bytes)" + - id: "_+_(list<~A>,list<~A>)" - name: "matches" - name: "timestamp" overloads: From 903210f28babe29c18710d89217c0ce468fef9ba Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 16 Jun 2026 13:19:17 -0700 Subject: [PATCH 118/143] Internal change PiperOrigin-RevId: 933263369 --- checker/BUILD | 1 + checker/validation_result.h | 1 + common/decl.h | 64 +++++++++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+) diff --git a/checker/BUILD b/checker/BUILD index 10ed6e363..efca3ff73 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -49,6 +49,7 @@ cc_library( deps = [ ":type_check_issue", "//common:ast", + "//common:decl", "//common:source", "//common:type", "@com_google_absl//absl/base:nullability", diff --git a/checker/validation_result.h b/checker/validation_result.h index f424e7f6f..7417e9969 100644 --- a/checker/validation_result.h +++ b/checker/validation_result.h @@ -28,6 +28,7 @@ #include "absl/types/span.h" #include "checker/type_check_issue.h" #include "common/ast.h" +#include "common/decl.h" #include "common/source.h" #include "common/type.h" diff --git a/common/decl.h b/common/decl.h index 7aea8c502..b15645236 100644 --- a/common/decl.h +++ b/common/decl.h @@ -377,6 +377,70 @@ bool TypeIsAssignable(const Type& to, const Type& from); } // namespace common_internal +struct VariableDeclEqualTo { + using is_transparent = void; + + bool operator()(const cel::VariableDecl& lhs, + const cel::VariableDecl& rhs) const { + return lhs.name() == rhs.name(); + } + + bool operator()(const cel::VariableDecl& lhs, std::string_view rhs) const { + return lhs.name() == rhs; + } + + bool operator()(std::string_view lhs, const cel::VariableDecl& rhs) const { + return lhs == rhs.name(); + } +}; + +struct VariableDeclHash { + using is_transparent = void; + + size_t operator()(const cel::VariableDecl& decl) const { + return (*this)(decl.name()); + } + + size_t operator()(std::string_view name) const { return absl::HashOf(name); } +}; + +using VariableDeclSet = absl::flat_hash_set; + +struct FunctionDeclEqualTo { + using is_transparent = void; + + bool operator()(const cel::FunctionDecl& lhs, + const cel::FunctionDecl& rhs) const { + return (*this)(lhs.name(), rhs.name()); + } + + bool operator()(const cel::FunctionDecl& lhs, std::string_view rhs) const { + return (*this)(lhs.name(), rhs); + } + + bool operator()(std::string_view lhs, const cel::FunctionDecl& rhs) const { + return (*this)(lhs, rhs.name()); + } + + bool operator()(std::string_view lhs, std::string_view rhs) const { + return lhs == rhs; + } +}; + +struct FunctionDeclHash { + using is_transparent = void; + + size_t operator()(const cel::FunctionDecl& decl) const { + return absl::HashOf(decl.name()); + } + + size_t operator()(std::string_view name) const { return absl::HashOf(name); } +}; + +using FunctionDeclSet = absl::flat_hash_set; + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_DECL_H_ From 19b54d719884e57b5e41be5c79001d0a13c22c3c Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 16 Jun 2026 15:59:45 -0700 Subject: [PATCH 119/143] Add policy conformance tests to postsubmit windows bazel test workflow. PiperOrigin-RevId: 933348031 --- .github/workflows/windows_bazel_test.yml | 2 +- .github/workflows/windows_bazel_test_post_merge.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/windows_bazel_test.yml b/.github/workflows/windows_bazel_test.yml index 4ac7f2eec..6d12e6861 100644 --- a/.github/workflows/windows_bazel_test.yml +++ b/.github/workflows/windows_bazel_test.yml @@ -25,4 +25,4 @@ jobs: # //... won't work. shell: bash run: | - bazelisk test --config=msvc conformance:all \ No newline at end of file + bazelisk test --config=msvc conformance:all conformance/policy:all \ No newline at end of file diff --git a/.github/workflows/windows_bazel_test_post_merge.yml b/.github/workflows/windows_bazel_test_post_merge.yml index 11801011e..569177fcc 100644 --- a/.github/workflows/windows_bazel_test_post_merge.yml +++ b/.github/workflows/windows_bazel_test_post_merge.yml @@ -9,5 +9,5 @@ jobs: trigger-test: # This prevents the workflow from running automatically when someone # pushes to their fork. - if: github.repository == 'google/cel-cpp' + if: github.repository == 'cel-expr/cel-cpp' uses: ./.github/workflows/windows_bazel_test.yml \ No newline at end of file From 9485d85c8d73d31ed0d21b8b21397c8d74b90584 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Tue, 16 Jun 2026 16:36:08 -0700 Subject: [PATCH 120/143] Update from google/cel-spec to cel-expr/cel-spec (Part 2) PiperOrigin-RevId: 933365198 --- README.md | 2 +- eval/README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 41b44388d..7c3c26be0 100644 --- a/README.md +++ b/README.md @@ -15,4 +15,4 @@ parser, and type checker. Released under the [Apache License](LICENSE). -[1]: https://github.com/google/cel-spec +[1]: https://github.com/cel-expr/cel-spec diff --git a/eval/README.md b/eval/README.md index ee6fd0798..32fa4bda4 100644 --- a/eval/README.md +++ b/eval/README.md @@ -3,4 +3,4 @@ A C++ implementation of a [Common Expression Language][1] evaluator. -[1]: https://github.com/google/cel-spec +[1]: https://github.com/cel-expr/cel-spec From 7dc7eacfbccf56c4d74bcb6e0f4dc2e63768c6bf Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 17 Jun 2026 01:08:33 -0700 Subject: [PATCH 121/143] No public description PiperOrigin-RevId: 933560360 --- .../descriptor_pool_type_introspector.cc | 18 +++++++++--------- .../descriptor_pool_type_introspector_test.cc | 8 ++++---- checker/internal/type_check_env.cc | 10 +++++----- checker/internal/type_checker_builder_impl.cc | 2 +- .../internal/type_checker_builder_impl_test.cc | 4 ++-- checker/internal/type_checker_impl.cc | 6 +++--- checker/internal/type_inference_context.cc | 4 ++-- 7 files changed, 26 insertions(+), 26 deletions(-) diff --git a/checker/internal/descriptor_pool_type_introspector.cc b/checker/internal/descriptor_pool_type_introspector.cc index da4f4430b..733e4a3cb 100644 --- a/checker/internal/descriptor_pool_type_introspector.cc +++ b/checker/internal/descriptor_pool_type_introspector.cc @@ -42,7 +42,7 @@ FindStructTypeFieldByNameDirectly( const google::protobuf::Descriptor* absl_nullable descriptor = descriptor_pool->FindMessageTypeByName(type); if (descriptor == nullptr) { - return absl::nullopt; + return std::nullopt; } const google::protobuf::FieldDescriptor* absl_nullable field = descriptor->FindFieldByName(name); @@ -54,7 +54,7 @@ FindStructTypeFieldByNameDirectly( if (field != nullptr) { return StructTypeField(MessageTypeField(field)); } - return absl::nullopt; + return std::nullopt; } // Standard implementation for listing fields. @@ -67,7 +67,7 @@ ListStructTypeFieldsDirectly( const google::protobuf::Descriptor* absl_nullable descriptor = descriptor_pool->FindMessageTypeByName(type); if (descriptor == nullptr) { - return absl::nullopt; + return std::nullopt; } std::vector extensions; @@ -100,7 +100,7 @@ DescriptorPoolTypeIntrospector::FindTypeImpl(absl::string_view name) const { if (enum_descriptor != nullptr) { return Type::Enum(enum_descriptor); } - return absl::nullopt; + return std::nullopt; } absl::StatusOr> @@ -112,7 +112,7 @@ DescriptorPoolTypeIntrospector::FindEnumConstantImpl( const google::protobuf::EnumValueDescriptor* absl_nullable enum_value_descriptor = enum_descriptor->FindValueByName(value); if (enum_value_descriptor == nullptr) { - return absl::nullopt; + return std::nullopt; } return EnumConstant{ .type = Type::Enum(enum_descriptor), @@ -121,7 +121,7 @@ DescriptorPoolTypeIntrospector::FindEnumConstantImpl( .number = enum_value_descriptor->number(), }; } - return absl::nullopt; + return std::nullopt; } absl::StatusOr> @@ -134,7 +134,7 @@ DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( const FieldTable* field_table = GetFieldTable(type); if (field_table == nullptr) { - return absl::nullopt; + return std::nullopt; } if (auto it = field_table->json_name_map.find(name); @@ -147,7 +147,7 @@ DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( return field_table->fields[it->second].field; } - return absl::nullopt; + return std::nullopt; } absl::StatusOr< @@ -160,7 +160,7 @@ DescriptorPoolTypeIntrospector::ListFieldsForStructTypeImpl( const FieldTable* field_table = GetFieldTable(type); if (field_table == nullptr) { - return absl::nullopt; + return std::nullopt; } std::vector fields; fields.reserve(field_table->non_extensions.size()); diff --git a/checker/internal/descriptor_pool_type_introspector_test.cc b/checker/internal/descriptor_pool_type_introspector_test.cc index 456798744..db766b347 100644 --- a/checker/internal/descriptor_pool_type_introspector_test.cc +++ b/checker/internal/descriptor_pool_type_introspector_test.cc @@ -47,7 +47,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, FindType) { "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"), IsOkAndHolds(Optional(Property(&Type::IsEnum, true)))); EXPECT_THAT(introspector.FindType("non.existent.Type"), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); } TEST(DescriptorPoolTypeIntrospectorTest, FindEnumConstant) { @@ -84,7 +84,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, auto field = introspector.FindStructTypeFieldByName( "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); - EXPECT_THAT(field, IsOkAndHolds(Eq(absl::nullopt))); + EXPECT_THAT(field, IsOkAndHolds(Eq(std::nullopt))); } TEST(DescriptorPoolTypeIntrospectorTest, FindExtension) { @@ -108,7 +108,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, FindStructTypeFieldByNameWithJsonOpt) { auto field = introspector.FindStructTypeFieldByName( "cel.expr.conformance.proto3.TestAllTypes", "single_int64"); - ASSERT_THAT(field, IsOkAndHolds(Eq(absl::nullopt))); + ASSERT_THAT(field, IsOkAndHolds(Eq(std::nullopt))); } TEST(DescriptorPoolTypeIntrospectorTest, @@ -168,7 +168,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructTypeNotFound) { internal::GetTestingDescriptorPool()); auto fields = introspector.ListFieldsForStructType( "cel.expr.conformance.proto3.SomeOtherType"); - EXPECT_THAT(fields, IsOkAndHolds(Eq(absl::nullopt))); + EXPECT_THAT(fields, IsOkAndHolds(Eq(std::nullopt))); } } // namespace diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index 47487220c..8dc83518d 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -58,7 +58,7 @@ absl::StatusOr> TypeCheckEnv::LookupTypeName( return type; } } - return absl::nullopt; + return std::nullopt; } absl::StatusOr> TypeCheckEnv::LookupEnumConstant( @@ -75,7 +75,7 @@ absl::StatusOr> TypeCheckEnv::LookupEnumConstant( return decl; } } - return absl::nullopt; + return std::nullopt; } absl::StatusOr> TypeCheckEnv::LookupTypeConstant( @@ -92,14 +92,14 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( return LookupEnumConstant(enum_name_candidate, value_name_candidate); } - return absl::nullopt; + return std::nullopt; } absl::StatusOr> TypeCheckEnv::LookupStructField( absl::string_view type_name, absl::string_view field_name) const { if (proto_type_mask_registry_ != nullptr && !proto_type_mask_registry_->FieldIsVisible(type_name, field_name)) { - return absl::nullopt; + return std::nullopt; } // Check the type providers in registration order. // Note: this doesn't allow for shadowing a type with a subset type of the @@ -113,7 +113,7 @@ absl::StatusOr> TypeCheckEnv::LookupStructField( return field; } } - return absl::nullopt; + return std::nullopt; } const VariableDecl* absl_nullable VariableScope::LookupLocalVariable( diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index f0332b999..4289fb528 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -198,7 +198,7 @@ std::optional FilterDecl(FunctionDecl decl, } } if (filtered.overloads().empty()) { - return absl::nullopt; + return std::nullopt; } filtered.set_name(std::move(name)); return filtered; diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc index 913e704ee..fa7f80960 100644 --- a/checker/internal/type_checker_builder_impl_test.cc +++ b/checker/internal/type_checker_builder_impl_test.cc @@ -349,7 +349,7 @@ TEST(ContextDeclsTest, CustomStructNotSupported) { if (name == "com.example.MyStruct") { return common_internal::MakeBasicStructType("com.example.MyStruct"); } - return absl::nullopt; + return std::nullopt; } }; @@ -370,7 +370,7 @@ TEST(ContextDeclsWithProtoTypeMaskTest, CustomStructNotSupported) { if (name == "com.example.MyStruct") { return common_internal::MakeBasicStructType("com.example.MyStruct"); } - return absl::nullopt; + return std::nullopt; } }; diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index f3a06a28d..2bc05dbf7 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -1098,11 +1098,11 @@ std::optional ResolveVisitor::CheckFieldType(int64_t id, auto field_info = env_->LookupStructField(struct_type.name(), field); if (!field_info.ok()) { status_.Update(field_info.status()); - return absl::nullopt; + return std::nullopt; } if (!field_info->has_value()) { ReportUndefinedField(id, field, struct_type.name()); - return absl::nullopt; + return std::nullopt; } auto type = field_info->value().GetType(); if (type.kind() == TypeKind::kEnum) { @@ -1134,7 +1134,7 @@ std::optional ResolveVisitor::CheckFieldType(int64_t id, "expression of type '", FormatTypeName(inference_context_->FinalizeType(operand_type)), "' cannot be the operand of a select operation"))); - return absl::nullopt; + return std::nullopt; } void ResolveVisitor::ResolveSelectOperation(const Expr& expr, diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 4681784af..4f738b804 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -149,7 +149,7 @@ std::optional WrapperToPrimitive(const Type& t) { case TypeKind::kUintWrapper: return UintType(); default: - return absl::nullopt; + return std::nullopt; } } @@ -579,7 +579,7 @@ TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, } if (!result_type.has_value() || matching_overloads.empty()) { - return absl::nullopt; + return std::nullopt; } return OverloadResolution{ .result_type = FullySubstitute(*result_type, /*free_to_dyn=*/false), From 8b7068abb4062074a135491ad8357139287084f9 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 17 Jun 2026 05:56:44 -0700 Subject: [PATCH 122/143] No public description PiperOrigin-RevId: 933675942 --- internal/message_equality_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/message_equality_test.cc b/internal/message_equality_test.cc index 318138d9b..092edd71b 100644 --- a/internal/message_equality_test.cc +++ b/internal/message_equality_test.cc @@ -399,7 +399,7 @@ absl::optional, PackTestAllTypesProto3Field(const google::protobuf::Message& message, const google::protobuf::FieldDescriptor* absl_nonnull field) { if (field->is_map()) { - return absl::nullopt; + return std::nullopt; } if (field->is_repeated() && field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { @@ -425,7 +425,7 @@ PackTestAllTypesProto3Field(const google::protobuf::Message& message, cel::to_address(packed), any_field)); return std::pair{packed, any_field}; } - return absl::nullopt; + return std::nullopt; } TEST_P(UnaryMessageFieldEqualsTest, Equals) { From 071fca088461260c8667b075d2c3c2a78d7b1fd8 Mon Sep 17 00:00:00 2001 From: Mike Kruskal Date: Wed, 17 Jun 2026 11:03:40 -0700 Subject: [PATCH 123/143] Add -fexceptions to the antlr gencode to avoid breaking users building with -fno-exceptions PiperOrigin-RevId: 933819478 --- bazel/antlr.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/bazel/antlr.bzl b/bazel/antlr.bzl index 2abbb6dbd..a4d28cdf8 100644 --- a/bazel/antlr.bzl +++ b/bazel/antlr.bzl @@ -55,6 +55,7 @@ def antlr_cc_library(name, src, package): generated, "@antlr4-cpp-runtime//:antlr4-cpp-runtime", ], + copts = ["-fexceptions"], linkstatic = 1, ) From e2ed5270e50f90a15e1008a18662cefbac303650 Mon Sep 17 00:00:00 2001 From: Justin King Date: Wed, 17 Jun 2026 14:08:48 -0700 Subject: [PATCH 124/143] Internal change PiperOrigin-RevId: 933914348 --- checker/internal/type_checker_impl.cc | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 2bc05dbf7..9c2806e89 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -199,6 +199,7 @@ class ResolveVisitor : public AstVisitorBase { struct AttributeResolution { const VariableDecl* decl; bool requires_disambiguation; + bool local; }; ResolveVisitor(NamespaceGenerator namespace_generator, @@ -1001,7 +1002,7 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, const VariableDecl* local_decl = LookupLocalIdentifier(name); if (local_decl != nullptr && !absl::StartsWith(name, ".")) { - attributes_[&expr] = {local_decl, false}; + attributes_[&expr] = {local_decl, false, /*local=*/true}; types_[&expr] = inference_context_->InstantiateTypeParams(local_decl->type()); return; @@ -1016,8 +1017,13 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, }); if (decl != nullptr) { - attributes_[&expr] = {decl, - /* requires_disambiguation= */ local_decl != nullptr}; + attributes_[&expr] = { + decl, + /* requires_disambiguation= */ local_decl != nullptr, + // There is some oddity here, `.` prefixed idents should never be local. + // So LookupLocalIdentifier above should never return a valid decl. + // Perhaps this is a refactor holdover? + /*local=*/false}; types_[&expr] = inference_context_->InstantiateTypeParams(decl->type()); return; } @@ -1072,9 +1078,14 @@ void ResolveVisitor::ResolveQualifiedIdentifier( root = &root->select_expr().operand(); } - attributes_[root] = {decl, - /* requires_disambiguation= */ decl != local_decl && - local_decl != nullptr}; + attributes_[root] = { + decl, + /* requires_disambiguation= */ decl != local_decl && + local_decl != nullptr, + // There is some oddity here, `.` prefixed idents should never be local. + // So LookupLocalIdentifier above should never return a valid decl. + // Perhaps this is a refactor holdover? + /*local=*/local_decl == decl}; types_[root] = inference_context_->InstantiateTypeParams(decl->type()); // fix-up select operations that were deferred. From 95e4b10bcc0ee5d3bd99650dc2a95e8fa2fee99a Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 17 Jun 2026 21:27:58 -0700 Subject: [PATCH 125/143] No public description PiperOrigin-RevId: 934095048 --- codelab/network_functions.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/codelab/network_functions.cc b/codelab/network_functions.cc index 64f199cb3..6cc1505a9 100644 --- a/codelab/network_functions.cc +++ b/codelab/network_functions.cc @@ -368,7 +368,7 @@ std::optional NetworkAddressRep::Unwrap( auto opaque = value.AsOpaque(); if (!opaque.has_value() || opaque->GetTypeId() != cel::TypeId()) { - return absl::nullopt; + return std::nullopt; } // Note: safety depends on: @@ -385,10 +385,10 @@ std::optional NetworkAddressRep::Parse( char ipv6[16]; auto version = ParseAddressImpl(str, &ipv4, ipv6); if (!version.ok()) { - return absl::nullopt; + return std::nullopt; } if (*version != IpVersion::kIPv4) { - return absl::nullopt; + return std::nullopt; } NetworkAddressRep rep; rep.version_ = *version; @@ -422,7 +422,7 @@ std::optional NetworkAddressMatcher::Parse( int dash_pos = str.find('-'); if (dash_pos == absl::string_view::npos) { // TODO(uncreated-issue/86): CIDR style addr/prefix-length - return absl::nullopt; + return std::nullopt; } absl::string_view min_str = str.substr(0, dash_pos); absl::string_view max_str = str.substr(dash_pos + 1); @@ -431,23 +431,23 @@ std::optional NetworkAddressMatcher::Parse( NetworkRangev6 v6; auto min_parse = ParseAddressImpl(min_str, &v4.min_incl, v6.min_incl); if (!min_parse.ok()) { - return absl::nullopt; + return std::nullopt; } auto max_parse = ParseAddressImpl(max_str, &v4.max_incl, v6.max_incl); if (!max_parse.ok()) { - return absl::nullopt; + return std::nullopt; } if (*min_parse != *max_parse) { - return absl::nullopt; + return std::nullopt; } NetworkAddressMatcher rep; if (*min_parse == IpVersion::kIPv4) { if (v4.min_incl > v4.max_incl) { - return absl::nullopt; + return std::nullopt; } rep.ranges_v4_.push_back(v4); } else if (*min_parse == IpVersion::kIPv6) { - return absl::nullopt; + return std::nullopt; } return rep; From 51faf9f1f2889bc1bbc582755ba37ed1d105f393 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 17 Jun 2026 22:43:18 -0700 Subject: [PATCH 126/143] No public description PiperOrigin-RevId: 934123553 --- extensions/select_optimization.cc | 12 ++++++------ extensions/select_optimization_test.cc | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index 42cad0f92..0cc64311a 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -158,11 +158,11 @@ std::optional GetSelectInstruction( absl::string_view field_name) { auto field_or = planner_context.type_reflector() .FindStructTypeFieldByName(runtime_type, field_name) - .value_or(absl::nullopt); + .value_or(std::nullopt); if (field_or.has_value()) { return SelectInstruction{field_or->number(), std::string(field_or->name())}; } - return absl::nullopt; + return std::nullopt; } absl::StatusOr SelectQualifierFromList(const ListExpr& list) { @@ -410,7 +410,7 @@ class RewriterImpl : public AstRewriterBase { std::optional rt_type = (checker_type.has_message_type()) ? GetRuntimeType(checker_type.message_type().type()) - : absl::nullopt; + : std::nullopt; if (rt_type.has_value() && (*rt_type).Is()) { const StructType& runtime_type = rt_type->GetStruct(); std::optional field_or = @@ -540,7 +540,7 @@ class RewriterImpl : public AstRewriterBase { std::optional GetRuntimeType(absl::string_view type_name) { return planner_context_.type_reflector().FindType(type_name).value_or( - absl::nullopt); + std::nullopt); } void SetProgressStatus(const absl::Status& status) { @@ -600,7 +600,7 @@ class OptimizedSelectImpl { absl::StatusOr> CheckForMarkedAttributes( ExecutionFrameBase& frame, const AttributeTrail& attribute_trail) { if (attribute_trail.empty()) { - return absl::nullopt; + return std::nullopt; } if (frame.unknown_processing_enabled() && @@ -624,7 +624,7 @@ absl::StatusOr> CheckForMarkedAttributes( attribute_trail.attribute()); } - return absl::nullopt; + return std::nullopt; } absl::StatusOr OptimizedSelectImpl::ApplySelect( diff --git a/extensions/select_optimization_test.cc b/extensions/select_optimization_test.cc index 9d4024098..c14c4d461 100644 --- a/extensions/select_optimization_test.cc +++ b/extensions/select_optimization_test.cc @@ -257,7 +257,7 @@ class MockAccessApis : public LegacyTypeInfoApis, public LegacyTypeAccessApis { std::optional< google::api::expr::runtime::LegacyTypeInfoApis::FieldDescription> FindFieldByName(absl::string_view field_name) const override { - return absl::nullopt; + return std::nullopt; } MOCK_METHOD(absl::StatusOr, GetField, From 80fc78aca1945cdf2653e6d888c5ff15058892c4 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 00:27:29 -0700 Subject: [PATCH 127/143] No public description PiperOrigin-RevId: 934169077 --- base/operators.cc | 32 ++++++++++++++++---------------- base/operators_test.cc | 32 ++++++++++++++++---------------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/base/operators.cc b/base/operators.cc index 805acc5a1..b7df40b27 100644 --- a/base/operators.cc +++ b/base/operators.cc @@ -179,13 +179,13 @@ CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_TERNARY_OPERATOR) absl::optional Operator::FindByName(absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(operators_by_name.cbegin(), operators_by_name.cend(), input, OperatorDataNameComparer{}); if (it == operators_by_name.cend() || (*it)->name != input) { - return absl::nullopt; + return std::nullopt; } return Operator(*it); } @@ -193,13 +193,13 @@ absl::optional Operator::FindByName(absl::string_view input) { absl::optional Operator::FindByDisplayName(absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(operators_by_display_name.cbegin(), operators_by_display_name.cend(), input, OperatorDataDisplayNameComparer{}); if (it == operators_by_name.cend() || (*it)->display_name != input) { - return absl::nullopt; + return std::nullopt; } return Operator(*it); } @@ -208,13 +208,13 @@ absl::optional UnaryOperator::FindByName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(unary_operators_by_name.cbegin(), unary_operators_by_name.cend(), input, OperatorDataNameComparer{}); if (it == unary_operators_by_name.cend() || (*it)->name != input) { - return absl::nullopt; + return std::nullopt; } return UnaryOperator(*it); } @@ -223,14 +223,14 @@ absl::optional UnaryOperator::FindByDisplayName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(unary_operators_by_display_name.cbegin(), unary_operators_by_display_name.cend(), input, OperatorDataDisplayNameComparer{}); if (it == unary_operators_by_display_name.cend() || (*it)->display_name != input) { - return absl::nullopt; + return std::nullopt; } return UnaryOperator(*it); } @@ -239,13 +239,13 @@ absl::optional BinaryOperator::FindByName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(binary_operators_by_name.cbegin(), binary_operators_by_name.cend(), input, OperatorDataNameComparer{}); if (it == binary_operators_by_name.cend() || (*it)->name != input) { - return absl::nullopt; + return std::nullopt; } return BinaryOperator(*it); } @@ -254,14 +254,14 @@ absl::optional BinaryOperator::FindByDisplayName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(binary_operators_by_display_name.cbegin(), binary_operators_by_display_name.cend(), input, OperatorDataDisplayNameComparer{}); if (it == binary_operators_by_display_name.cend() || (*it)->display_name != input) { - return absl::nullopt; + return std::nullopt; } return BinaryOperator(*it); } @@ -270,13 +270,13 @@ absl::optional TernaryOperator::FindByName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(ternary_operators_by_name.cbegin(), ternary_operators_by_name.cend(), input, OperatorDataNameComparer{}); if (it == ternary_operators_by_name.cend() || (*it)->name != input) { - return absl::nullopt; + return std::nullopt; } return TernaryOperator(*it); } @@ -285,14 +285,14 @@ absl::optional TernaryOperator::FindByDisplayName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(ternary_operators_by_display_name.cbegin(), ternary_operators_by_display_name.cend(), input, OperatorDataDisplayNameComparer{}); if (it == ternary_operators_by_display_name.cend() || (*it)->display_name != input) { - return absl::nullopt; + return std::nullopt; } return TernaryOperator(*it); } diff --git a/base/operators_test.cc b/base/operators_test.cc index fdf95e7ae..6049f76c8 100644 --- a/base/operators_test.cc +++ b/base/operators_test.cc @@ -130,55 +130,55 @@ CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_TERNARY_OPERATOR) TEST(Operator, FindByName) { EXPECT_THAT(Operator::FindByName("@in"), Optional(Eq(Operator::In()))); EXPECT_THAT(Operator::FindByName("_in_"), Optional(Eq(Operator::OldIn()))); - EXPECT_THAT(Operator::FindByName("in"), Eq(absl::nullopt)); - EXPECT_THAT(Operator::FindByName(""), Eq(absl::nullopt)); + EXPECT_THAT(Operator::FindByName("in"), Eq(std::nullopt)); + EXPECT_THAT(Operator::FindByName(""), Eq(std::nullopt)); } TEST(Operator, FindByDisplayName) { EXPECT_THAT(Operator::FindByDisplayName("-"), Optional(Eq(Operator::Subtract()))); - EXPECT_THAT(Operator::FindByDisplayName("@in"), Eq(absl::nullopt)); - EXPECT_THAT(Operator::FindByDisplayName(""), Eq(absl::nullopt)); + EXPECT_THAT(Operator::FindByDisplayName("@in"), Eq(std::nullopt)); + EXPECT_THAT(Operator::FindByDisplayName(""), Eq(std::nullopt)); } TEST(UnaryOperator, FindByName) { EXPECT_THAT(UnaryOperator::FindByName("-_"), Optional(Eq(Operator::Negate()))); - EXPECT_THAT(UnaryOperator::FindByName("_-_"), Eq(absl::nullopt)); - EXPECT_THAT(UnaryOperator::FindByName(""), Eq(absl::nullopt)); + EXPECT_THAT(UnaryOperator::FindByName("_-_"), Eq(std::nullopt)); + EXPECT_THAT(UnaryOperator::FindByName(""), Eq(std::nullopt)); } TEST(UnaryOperator, FindByDisplayName) { EXPECT_THAT(UnaryOperator::FindByDisplayName("-"), Optional(Eq(Operator::Negate()))); - EXPECT_THAT(UnaryOperator::FindByDisplayName("&&"), Eq(absl::nullopt)); - EXPECT_THAT(UnaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); + EXPECT_THAT(UnaryOperator::FindByDisplayName("&&"), Eq(std::nullopt)); + EXPECT_THAT(UnaryOperator::FindByDisplayName(""), Eq(std::nullopt)); } TEST(BinaryOperator, FindByName) { EXPECT_THAT(BinaryOperator::FindByName("_-_"), Optional(Eq(Operator::Subtract()))); - EXPECT_THAT(BinaryOperator::FindByName("-_"), Eq(absl::nullopt)); - EXPECT_THAT(BinaryOperator::FindByName(""), Eq(absl::nullopt)); + EXPECT_THAT(BinaryOperator::FindByName("-_"), Eq(std::nullopt)); + EXPECT_THAT(BinaryOperator::FindByName(""), Eq(std::nullopt)); } TEST(BinaryOperator, FindByDisplayName) { EXPECT_THAT(BinaryOperator::FindByDisplayName("-"), Optional(Eq(Operator::Subtract()))); - EXPECT_THAT(BinaryOperator::FindByDisplayName("!"), Eq(absl::nullopt)); - EXPECT_THAT(BinaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); + EXPECT_THAT(BinaryOperator::FindByDisplayName("!"), Eq(std::nullopt)); + EXPECT_THAT(BinaryOperator::FindByDisplayName(""), Eq(std::nullopt)); } TEST(TernaryOperator, FindByName) { EXPECT_THAT(TernaryOperator::FindByName("_?_:_"), Optional(Eq(TernaryOperator::Conditional()))); - EXPECT_THAT(TernaryOperator::FindByName("-_"), Eq(absl::nullopt)); - EXPECT_THAT(TernaryOperator::FindByName(""), Eq(absl::nullopt)); + EXPECT_THAT(TernaryOperator::FindByName("-_"), Eq(std::nullopt)); + EXPECT_THAT(TernaryOperator::FindByName(""), Eq(std::nullopt)); } TEST(TernaryOperator, FindByDisplayName) { - EXPECT_THAT(TernaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); - EXPECT_THAT(TernaryOperator::FindByDisplayName("!"), Eq(absl::nullopt)); + EXPECT_THAT(TernaryOperator::FindByDisplayName(""), Eq(std::nullopt)); + EXPECT_THAT(TernaryOperator::FindByDisplayName("!"), Eq(std::nullopt)); } TEST(Operator, SupportsAbslHash) { From 8515127e730d338066a0f39d13317b0b4cebc32a Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 03:35:24 -0700 Subject: [PATCH 128/143] No public description PiperOrigin-RevId: 934248225 --- tools/flatbuffers_backed_impl.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/flatbuffers_backed_impl.cc b/tools/flatbuffers_backed_impl.cc index 10c0b1cb8..2ee226859 100644 --- a/tools/flatbuffers_backed_impl.cc +++ b/tools/flatbuffers_backed_impl.cc @@ -127,7 +127,7 @@ class ObjectStringIndexedMapImpl : public CelMap { arena_, **it, schema_, object_, arena_)); } } - return absl::nullopt; + return std::nullopt; } absl::StatusOr ListKeys() const override { return &keys_; } @@ -188,7 +188,7 @@ absl::optional FlatBuffersMapImpl::operator[]( } auto field = keys_.fields->LookupByKey(cel_key.StringOrDie().value().data()); if (field == nullptr) { - return absl::nullopt; + return std::nullopt; } switch (field->type()->base_type()) { case reflection::Byte: @@ -323,15 +323,15 @@ absl::optional FlatBuffersMapImpl::operator[]( } default: // Unsupported vector base types - return absl::nullopt; + return std::nullopt; } break; } default: // Unsupported types: enums, unions, arrays - return absl::nullopt; + return std::nullopt; } - return absl::nullopt; + return std::nullopt; } const CelMap* CreateFlatBuffersBackedObject(const uint8_t* flatbuf, From 3244f4a55e713555fc0f40ecfea5e38b5e3ad492 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 03:35:53 -0700 Subject: [PATCH 129/143] No public description PiperOrigin-RevId: 934248421 --- testutil/test_macros.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/testutil/test_macros.cc b/testutil/test_macros.cc index 672439dc5..19e9a4844 100644 --- a/testutil/test_macros.cc +++ b/testutil/test_macros.cc @@ -40,7 +40,7 @@ bool IsCelNamespace(const Expr& target) { std::optional CelBlockMacroExpander(MacroExprFactory& factory, Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { - return absl::nullopt; + return std::nullopt; } Expr& bindings_arg = args[0]; if (!bindings_arg.has_list_expr()) { @@ -53,7 +53,7 @@ std::optional CelBlockMacroExpander(MacroExprFactory& factory, std::optional CelIndexMacroExpander(MacroExprFactory& factory, Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { - return absl::nullopt; + return std::nullopt; } Expr& index_arg = args[0]; if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { @@ -72,7 +72,7 @@ std::optional CelIterVarMacroExpander(MacroExprFactory& factory, Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { - return absl::nullopt; + return std::nullopt; } Expr& depth_arg = args[0]; if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || @@ -96,7 +96,7 @@ std::optional CelAccuVarMacroExpander(MacroExprFactory& factory, Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { - return absl::nullopt; + return std::nullopt; } Expr& depth_arg = args[0]; if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || From 3456ec81c08ec5ce3980c75b9522e917683e58ed Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 05:34:35 -0700 Subject: [PATCH 130/143] No public description PiperOrigin-RevId: 934294071 --- extensions/protobuf/bind_proto_to_activation_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions/protobuf/bind_proto_to_activation_test.cc b/extensions/protobuf/bind_proto_to_activation_test.cc index fd79508ac..680b4b353 100644 --- a/extensions/protobuf/bind_proto_to_activation_test.cc +++ b/extensions/protobuf/bind_proto_to_activation_test.cc @@ -76,10 +76,10 @@ TEST_F(BindProtoToActivationTest, BindProtoToActivationSkip) { EXPECT_THAT(activation.FindVariable("single_int32", descriptor_pool(), message_factory(), arena()), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(activation.FindVariable("single_sint32", descriptor_pool(), message_factory(), arena()), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); } TEST_F(BindProtoToActivationTest, BindProtoToActivationDefault) { From 774340d632ae13dc74c64f2eade30d8209686089 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 06:50:51 -0700 Subject: [PATCH 131/143] No public description PiperOrigin-RevId: 934323942 --- parser/macro_registry.cc | 4 ++-- parser/macro_registry_test.cc | 4 ++-- parser/parser.cc | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/parser/macro_registry.cc b/parser/macro_registry.cc index 3a816b10e..d36761e87 100644 --- a/parser/macro_registry.cc +++ b/parser/macro_registry.cc @@ -55,7 +55,7 @@ absl::optional MacroRegistry::FindMacro(absl::string_view name, bool receiver_style) const { // :: if (name.empty() || absl::StrContains(name, ':')) { - return absl::nullopt; + return std::nullopt; } // Try argument count specific key first. auto key = absl::StrCat(name, ":", arg_count, ":", @@ -68,7 +68,7 @@ absl::optional MacroRegistry::FindMacro(absl::string_view name, if (auto it = macros_.find(key); it != macros_.end()) { return it->second; } - return absl::nullopt; + return std::nullopt; } std::vector MacroRegistry::ListMacros() const { diff --git a/parser/macro_registry_test.cc b/parser/macro_registry_test.cc index 9e6da87a4..db8a99ab2 100644 --- a/parser/macro_registry_test.cc +++ b/parser/macro_registry_test.cc @@ -30,14 +30,14 @@ using ::testing::Ne; TEST(MacroRegistry, RegisterAndFind) { MacroRegistry macros; EXPECT_THAT(macros.RegisterMacro(HasMacro()), IsOk()); - EXPECT_THAT(macros.FindMacro("has", 1, false), Ne(absl::nullopt)); + EXPECT_THAT(macros.FindMacro("has", 1, false), Ne(std::nullopt)); } TEST(MacroRegistry, RegisterRollsback) { MacroRegistry macros; EXPECT_THAT(macros.RegisterMacros({HasMacro(), AllMacro(), AllMacro()}), StatusIs(absl::StatusCode::kAlreadyExists)); - EXPECT_THAT(macros.FindMacro("has", 1, false), Eq(absl::nullopt)); + EXPECT_THAT(macros.FindMacro("has", 1, false), Eq(std::nullopt)); } } // namespace diff --git a/parser/parser.cc b/parser/parser.cc index a858337a4..24b4ca079 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -1468,11 +1468,11 @@ Expr ParserVisitor::GlobalCallOrMacroImpl(int64_t expr_id, } } factory_.BeginMacro(factory_.GetSourceRange(expr_id)); - auto expr = macro->Expand(factory_, absl::nullopt, absl::MakeSpan(args)); + auto expr = macro->Expand(factory_, std::nullopt, absl::MakeSpan(args)); factory_.EndMacro(); if (expr) { if (add_macro_calls_) { - factory_.AddMacroCall(expr->id(), function, absl::nullopt, + factory_.AddMacroCall(expr->id(), function, std::nullopt, std::move(macro_args)); } // We did not end up using `expr_id`. Delete metadata. From 83ccadfe7e0d284956332f3b90804c23b6ef3caf Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 18 Jun 2026 15:33:26 -0700 Subject: [PATCH 132/143] Refactor local variable resolution in checker. No functional changes, but makes intent a little clearer w.r.t. checking if we need to preserve a leading '.' or not. PiperOrigin-RevId: 934579650 --- checker/internal/type_checker_impl.cc | 97 ++++++++++++++------------- 1 file changed, 52 insertions(+), 45 deletions(-) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 9c2806e89..bca187417 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -960,11 +960,10 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, const VariableDecl* absl_nullable ResolveVisitor::LookupLocalIdentifier( absl::string_view name) { - // Note: if we see a leading dot, this shouldn't resolve to a local variable, - // but we need to check whether we need to disambiguate against a global in - // the reference map. if (absl::StartsWith(name, ".")) { - name = name.substr(1); + // Should not happen for normally parsed CEL, but prevent lookup in case + // of hand-crafted ASTs. + return nullptr; } return current_scope_->LookupLocalVariable(name); } @@ -999,13 +998,15 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, absl::string_view name) { // Local variables (comprehension, bind) are simple identifiers so we can // skip generating the different namespace-qualified candidates. - const VariableDecl* local_decl = LookupLocalIdentifier(name); - - if (local_decl != nullptr && !absl::StartsWith(name, ".")) { - attributes_[&expr] = {local_decl, false, /*local=*/true}; - types_[&expr] = - inference_context_->InstantiateTypeParams(local_decl->type()); - return; + if (!absl::StartsWith(name, ".")) { + const VariableDecl* local_decl = LookupLocalIdentifier(name); + if (local_decl != nullptr) { + attributes_[&expr] = {local_decl, /*requires_disambiguation=*/false, + /*local=*/true}; + types_[&expr] = + inference_context_->InstantiateTypeParams(local_decl->type()); + return; + } } const VariableDecl* decl = nullptr; @@ -1016,14 +1017,13 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, return decl == nullptr; }); + bool requires_disambiguation = false; + if (absl::StartsWith(name, ".")) { + requires_disambiguation = LookupLocalIdentifier(name.substr(1)) != nullptr; + } + if (decl != nullptr) { - attributes_[&expr] = { - decl, - /* requires_disambiguation= */ local_decl != nullptr, - // There is some oddity here, `.` prefixed idents should never be local. - // So LookupLocalIdentifier above should never return a valid decl. - // Perhaps this is a refactor holdover? - /*local=*/false}; + attributes_[&expr] = {decl, requires_disambiguation, /*local=*/false}; types_[&expr] = inference_context_->InstantiateTypeParams(decl->type()); return; } @@ -1039,35 +1039,49 @@ void ResolveVisitor::ResolveQualifiedIdentifier( return; } + int matched_segment_index = -1; + const VariableDecl* decl = nullptr; + bool requires_disambiguation = false; + bool is_local = false; // Local variables (comprehension, bind) are simple identifiers so we can // skip generating the different namespace-qualified candidates. - const VariableDecl* local_decl = LookupLocalIdentifier(qualifiers[0]); - const VariableDecl* decl = nullptr; - - int matched_segment_index = -1; - - if (local_decl != nullptr && !absl::StartsWith(qualifiers[0], ".")) { - decl = local_decl; - matched_segment_index = 0; - } else { - namespace_generator_.GenerateCandidates( - qualifiers, [&decl, &matched_segment_index, this]( - absl::string_view candidate, int segment_index) { - decl = LookupGlobalIdentifier(candidate); - if (decl != nullptr) { - matched_segment_index = segment_index; - return false; - } - return true; - }); + if (!absl::StartsWith(qualifiers[0], ".")) { + const VariableDecl* local_decl = LookupLocalIdentifier(qualifiers[0]); + if (local_decl != nullptr) { + decl = local_decl; + matched_segment_index = 0; + is_local = true; + goto resolve_select_trail; + } } + namespace_generator_.GenerateCandidates( + qualifiers, [&decl, &matched_segment_index, this]( + absl::string_view candidate, int segment_index) { + decl = LookupGlobalIdentifier(candidate); + if (decl != nullptr) { + matched_segment_index = segment_index; + return false; + } + return true; + }); + if (decl == nullptr) { ReportMissingReference(expr, FormatCandidate(qualifiers)); types_[&expr] = ErrorType(); return; } + if (absl::StartsWith(qualifiers[0], ".")) { + const VariableDecl* local_decl = + LookupLocalIdentifier(qualifiers[0].substr(1)); + if (local_decl != nullptr) { + requires_disambiguation = true; + } + } + +resolve_select_trail: + const int num_select_opts = qualifiers.size() - matched_segment_index - 1; const Expr* root = &expr; @@ -1078,14 +1092,7 @@ void ResolveVisitor::ResolveQualifiedIdentifier( root = &root->select_expr().operand(); } - attributes_[root] = { - decl, - /* requires_disambiguation= */ decl != local_decl && - local_decl != nullptr, - // There is some oddity here, `.` prefixed idents should never be local. - // So LookupLocalIdentifier above should never return a valid decl. - // Perhaps this is a refactor holdover? - /*local=*/local_decl == decl}; + attributes_[root] = {decl, requires_disambiguation, is_local}; types_[root] = inference_context_->InstantiateTypeParams(decl->type()); // fix-up select operations that were deferred. From ef57455367567055465d8d86032c54fb9db27aff Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 19:41:49 -0700 Subject: [PATCH 133/143] No public description PiperOrigin-RevId: 934677950 --- extensions/bindings_ext.cc | 2 +- extensions/lists_functions.cc | 2 +- extensions/math_ext_macros.cc | 6 +++--- extensions/math_ext_test.cc | 4 ++-- extensions/proto_ext.cc | 12 ++++++------ 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/extensions/bindings_ext.cc b/extensions/bindings_ext.cc index c59f724bd..4823c077c 100644 --- a/extensions/bindings_ext.cc +++ b/extensions/bindings_ext.cc @@ -73,7 +73,7 @@ std::vector bindings_macros() { [](MacroExprFactory& factory, Expr& target, absl::Span args) -> absl::optional { if (!IsTargetNamespace(target)) { - return absl::nullopt; + return std::nullopt; } if (!args[0].has_ident_expr()) { return factory.ReportErrorAt( diff --git a/extensions/lists_functions.cc b/extensions/lists_functions.cc index 10bc717ed..bfe05d887 100644 --- a/extensions/lists_functions.cc +++ b/extensions/lists_functions.cc @@ -454,7 +454,7 @@ Macro ListSortByMacro() { MakeMapComprehension(factory, factory.Copy(sortby_input_ident), std::move(key_ident), std::move(key_expr)); if (!map_compr.has_value()) { - return absl::nullopt; + return std::nullopt; } // Build the call expression: diff --git a/extensions/math_ext_macros.cc b/extensions/math_ext_macros.cc index a66720a60..08b163132 100644 --- a/extensions/math_ext_macros.cc +++ b/extensions/math_ext_macros.cc @@ -72,7 +72,7 @@ absl::optional CheckInvalidArgs(MacroExprFactory &factory, } } - return absl::nullopt; + return std::nullopt; } bool IsListLiteralWithValidArgs(const Expr &arg) { @@ -99,7 +99,7 @@ std::vector math_macros() { [](MacroExprFactory &factory, Expr &target, absl::Span arguments) -> absl::optional { if (!IsTargetNamespace(target)) { - return absl::nullopt; + return std::nullopt; } switch (arguments.size()) { @@ -143,7 +143,7 @@ std::vector math_macros() { [](MacroExprFactory &factory, Expr &target, absl::Span arguments) -> absl::optional { if (!IsTargetNamespace(target)) { - return absl::nullopt; + return std::nullopt; } switch (arguments.size()) { diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc index ea9331970..ce05ae6ed 100644 --- a/extensions/math_ext_test.cc +++ b/extensions/math_ext_test.cc @@ -93,7 +93,7 @@ TestCase MinCase(CelValue v1, CelValue v2, CelValue result) { } TestCase MinCase(CelValue list, CelValue result) { - return TestCase{kMathMin, list, absl::nullopt, result}; + return TestCase{kMathMin, list, std::nullopt, result}; } TestCase MaxCase(CelValue v1, CelValue v2, CelValue result) { @@ -101,7 +101,7 @@ TestCase MaxCase(CelValue v1, CelValue v2, CelValue result) { } TestCase MaxCase(CelValue list, CelValue result) { - return TestCase{kMathMax, list, absl::nullopt, result}; + return TestCase{kMathMax, list, std::nullopt, result}; } struct MacroTestCase { diff --git a/extensions/proto_ext.cc b/extensions/proto_ext.cc index f38039002..48618f7ae 100644 --- a/extensions/proto_ext.cc +++ b/extensions/proto_ext.cc @@ -45,11 +45,11 @@ absl::optional ValidateExtensionIdentifier(const Expr& expr) { absl::Overload( [](const SelectExpr& select_expr) -> absl::optional { if (select_expr.test_only()) { - return absl::nullopt; + return std::nullopt; } auto op_name = ValidateExtensionIdentifier(select_expr.operand()); if (!op_name.has_value()) { - return absl::nullopt; + return std::nullopt; } return absl::StrCat(*op_name, ".", select_expr.field()); }, @@ -57,7 +57,7 @@ absl::optional ValidateExtensionIdentifier(const Expr& expr) { return ident_expr.name(); }, [](const auto&) -> absl::optional { - return absl::nullopt; + return std::nullopt; }), expr.kind()); } @@ -68,7 +68,7 @@ absl::optional GetExtensionFieldName(const Expr& expr) { select_expr) { return ValidateExtensionIdentifier(expr); } - return absl::nullopt; + return std::nullopt; } bool IsExtensionCall(const Expr& target) { @@ -95,7 +95,7 @@ std::vector proto_macros() { [](MacroExprFactory& factory, Expr& target, absl::Span arguments) -> absl::optional { if (!IsExtensionCall(target)) { - return absl::nullopt; + return std::nullopt; } auto extFieldName = GetExtensionFieldName(arguments[1]); if (!extFieldName.has_value()) { @@ -109,7 +109,7 @@ std::vector proto_macros() { [](MacroExprFactory& factory, Expr& target, absl::Span arguments) -> absl::optional { if (!IsExtensionCall(target)) { - return absl::nullopt; + return std::nullopt; } auto extFieldName = GetExtensionFieldName(arguments[1]); if (!extFieldName.has_value()) { From 7641fac87020aae83f71cca82de1d7e759672b32 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 23:58:02 -0700 Subject: [PATCH 134/143] No public description PiperOrigin-RevId: 934760042 --- eval/eval/attribute_utility.cc | 2 +- eval/eval/container_access_step.cc | 2 +- eval/eval/function_step.cc | 4 ++-- eval/eval/select_step.cc | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/eval/eval/attribute_utility.cc b/eval/eval/attribute_utility.cc index 117516caf..1e044627e 100644 --- a/eval/eval/attribute_utility.cc +++ b/eval/eval/attribute_utility.cc @@ -123,7 +123,7 @@ absl::optional AttributeUtility::MergeUnknowns( } if (!result_set.has_value()) { - return absl::nullopt; + return std::nullopt; } return UnknownValue(cel::Unknown(result_set->unknown_attributes(), diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index fda51e34f..4cf4ebf4d 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -55,7 +55,7 @@ absl::optional CelNumberFromValue(const Value& value) { case ValueKind::kDouble: return Number::FromDouble(value.GetDouble().NativeValue()); default: - return absl::nullopt; + return std::nullopt; } } diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index fcf429378..12c5af8a7 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -291,7 +291,7 @@ absl::StatusOr ResolveStatic( return overload; } } - return absl::nullopt; + return std::nullopt; } absl::StatusOr ResolveLazy( @@ -299,7 +299,7 @@ absl::StatusOr ResolveLazy( bool receiver_style, absl::Span providers, const ExecutionFrameBase& frame) { - ResolveResult result = absl::nullopt; + ResolveResult result = std::nullopt; std::vector arg_types(input_args.size()); diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 420f3ac31..b815f5d87 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -72,7 +72,7 @@ absl::optional CheckForMarkedAttributes(const AttributeTrail& trail, return cel::ErrorValue(std::move(result).status()); } - return absl::nullopt; + return std::nullopt; } void TestOnlySelect(const StructValue& msg, const std::string& field, From aea9b2adcddea6595548b65ad4e8a70d9e4de04c Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 19 Jun 2026 03:50:29 -0700 Subject: [PATCH 135/143] No public description PiperOrigin-RevId: 934844207 --- eval/compiler/flat_expr_builder.cc | 2 +- eval/compiler/flat_expr_builder_extensions.cc | 6 +++--- eval/compiler/qualified_reference_resolver.cc | 10 +++++----- eval/compiler/regex_precompilation_optimization.cc | 4 ++-- eval/compiler/resolver.cc | 4 ++-- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index fc6d87b16..aa9a8858c 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -1078,7 +1078,7 @@ class FlatExprVisitor : public cel::AstVisitor { // eligible for recursion, or nullopt if it is not. std::optional RecursionEligible() { if (!PlanRecursiveProgram() || program_builder_.current() == nullptr) { - return absl::nullopt; + return std::nullopt; } return program_builder_.current()->RecursiveDependencyDepth(); } diff --git a/eval/compiler/flat_expr_builder_extensions.cc b/eval/compiler/flat_expr_builder_extensions.cc index e51b64023..ee106ff4a 100644 --- a/eval/compiler/flat_expr_builder_extensions.cc +++ b/eval/compiler/flat_expr_builder_extensions.cc @@ -102,15 +102,15 @@ std::optional Subexpression::RecursiveDependencyDepth() const { auto* tree = absl::get_if(&program_); int depth = 0; if (tree == nullptr) { - return absl::nullopt; + return std::nullopt; } for (const auto& element : *tree) { auto* subexpression = absl::get_if(&element); if (subexpression == nullptr) { - return absl::nullopt; + return std::nullopt; } if (!(*subexpression)->IsRecursive()) { - return absl::nullopt; + return std::nullopt; } depth = std::max(depth, (*subexpression)->recursive_program().depth); } diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 67c14d9b2..158e492be 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -99,7 +99,7 @@ std::optional BestOverloadMatch(const Resolver& resolver, return *name; } } - return absl::nullopt; + return std::nullopt; } // Rewriter visitor for resolving references. @@ -267,22 +267,22 @@ class ReferenceResolver : public cel::AstRewriterBase { if (rewritten_reference_.find(expr.id()) != rewritten_reference_.end()) { // The target expr matches a reference (resolved to an ident decl). // This should not be treated as a function qualifier. - return absl::nullopt; + return std::nullopt; } if (expr.has_ident_expr()) { return expr.ident_expr().name(); } else if (expr.has_select_expr()) { if (expr.select_expr().test_only()) { - return absl::nullopt; + return std::nullopt; } maybe_parent_namespace = ToNamespace(expr.select_expr().operand()); if (!maybe_parent_namespace.has_value()) { - return absl::nullopt; + return std::nullopt; } return absl::StrCat(*maybe_parent_namespace, ".", expr.select_expr().field()); } else { - return absl::nullopt; + return std::nullopt; } } diff --git a/eval/compiler/regex_precompilation_optimization.cc b/eval/compiler/regex_precompilation_optimization.cc index 455796131..38ef842b9 100644 --- a/eval/compiler/regex_precompilation_optimization.cc +++ b/eval/compiler/regex_precompilation_optimization.cc @@ -178,7 +178,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { if (subexpression == nullptr || subexpression->IsFlattened()) { // Already modified, can't recover the input pattern. - return absl::nullopt; + return std::nullopt; } std::optional constant; if (subexpression->IsRecursive()) { @@ -206,7 +206,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { return Cast(*constant).ToString(); } - return absl::nullopt; + return std::nullopt; } absl::Status RewritePlan( diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index 17f60eaad..cca72964a 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -128,7 +128,7 @@ std::optional Resolver::FindConstant(absl::string_view name, return TypeValue(**type_value); } } - return absl::nullopt; + return std::nullopt; } std::vector Resolver::FindOverloads( @@ -216,7 +216,7 @@ Resolver::FindType(absl::string_view name, int64_t expr_id) const { return std::make_pair(std::move(qualified_name), std::move(*maybe_type)); } } - return absl::nullopt; + return std::nullopt; } } // namespace google::api::expr::runtime From e54b8bd8322454b485852227512a6b458a26f4a8 Mon Sep 17 00:00:00 2001 From: Justin King Date: Fri, 19 Jun 2026 11:29:23 -0700 Subject: [PATCH 136/143] Actually reject non-simple variable names in comprehensions PiperOrigin-RevId: 934996645 --- extensions/comprehensions_v2_macros.cc | 43 +++++++++++++++----------- parser/macro.cc | 22 ++++++++----- parser/parser_test.cc | 21 +++++++++++++ 3 files changed, 60 insertions(+), 26 deletions(-) diff --git a/extensions/comprehensions_v2_macros.cc b/extensions/comprehensions_v2_macros.cc index 134fb80ff..a054626f9 100644 --- a/extensions/comprehensions_v2_macros.cc +++ b/extensions/comprehensions_v2_macros.cc @@ -14,12 +14,14 @@ #include "extensions/comprehensions_v2_macros.h" +#include #include #include #include "absl/base/no_destructor.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -38,16 +40,21 @@ namespace { using ::google::api::expr::common::CelOperator; +bool IsSimpleIdentifier(const Expr& expr) { + return expr.has_ident_expr() && !expr.ident_expr().name().empty() && + !absl::StartsWith(expr.ident_expr().name(), "."); +} + absl::optional ExpandAllMacro2(MacroExprFactory& factory, Expr& target, absl::Span args) { if (args.size() != 3) { return factory.ReportError("all() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "all() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "all() second variable name must be a simple identifier"); } @@ -89,11 +96,11 @@ absl::optional ExpandExistsMacro2(MacroExprFactory& factory, Expr& target, if (args.size() != 3) { return factory.ReportError("exists() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "exists() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "exists() second variable name must be a simple identifier"); } @@ -138,11 +145,11 @@ absl::optional ExpandExistsOneMacro2(MacroExprFactory& factory, if (args.size() != 3) { return factory.ReportError("existsOne() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "existsOne() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "existsOne() second variable name must be a simple identifier"); @@ -190,12 +197,12 @@ absl::optional ExpandTransformList3Macro(MacroExprFactory& factory, if (args.size() != 3) { return factory.ReportError("transformList() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformList() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformList() second variable name must be a simple identifier"); @@ -239,12 +246,12 @@ absl::optional ExpandTransformList4Macro(MacroExprFactory& factory, if (args.size() != 4) { return factory.ReportError("transformList() requires 4 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformList() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformList() second variable name must be a simple identifier"); @@ -290,12 +297,12 @@ absl::optional ExpandTransformMap3Macro(MacroExprFactory& factory, if (args.size() != 3) { return factory.ReportError("transformMap() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformMap() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformMap() second variable name must be a simple identifier"); @@ -338,12 +345,12 @@ absl::optional ExpandTransformMap4Macro(MacroExprFactory& factory, if (args.size() != 4) { return factory.ReportError("transformMap() requires 4 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformMap() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformMap() second variable name must be a simple identifier"); @@ -388,12 +395,12 @@ absl::optional ExpandTransformMapEntry3Macro(MacroExprFactory& factory, if (args.size() != 3) { return factory.ReportError("transformMapEntry() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformMapEntry() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformMapEntry() second variable name must be a simple identifier"); @@ -438,12 +445,12 @@ absl::optional ExpandTransformMapEntry4Macro(MacroExprFactory& factory, if (args.size() != 4) { return factory.ReportError("transformMapEntry() requires 4 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformMapEntry() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformMapEntry() second variable name must be a simple identifier"); diff --git a/parser/macro.cc b/parser/macro.cc index 8f8c9e596..815b07401 100644 --- a/parser/macro.cc +++ b/parser/macro.cc @@ -25,6 +25,7 @@ #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -40,6 +41,11 @@ namespace { using google::api::expr::common::CelOperator; +bool IsSimpleIdentifier(const Expr& expr) { + return expr.has_ident_expr() && !expr.ident_expr().name().empty() && + !absl::StartsWith(expr.ident_expr().name(), "."); +} + inline MacroExpander ToMacroExpander(GlobalMacroExpander expander) { ABSL_DCHECK(expander); return [expander = std::move(expander)]( @@ -87,7 +93,7 @@ absl::optional ExpandAllMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("all() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "all() variable name must be a simple identifier"); } @@ -119,7 +125,7 @@ absl::optional ExpandExistsMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("exists() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "exists() variable name must be a simple identifier"); } @@ -153,7 +159,7 @@ absl::optional ExpandExistsOneMacro(MacroExprFactory& factory, if (args.size() != 2) { return factory.ReportError("exists_one() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "exists_one() variable name must be a simple identifier"); } @@ -192,7 +198,7 @@ absl::optional ExpandMap2Macro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("map() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } @@ -225,7 +231,7 @@ absl::optional ExpandMap3Macro(MacroExprFactory& factory, Expr& target, if (args.size() != 3) { return factory.ReportError("map() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } @@ -260,7 +266,7 @@ absl::optional ExpandFilterMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("filter() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "filter() variable name must be a simple identifier"); } @@ -298,7 +304,7 @@ absl::optional ExpandOptMapMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("optMap() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "optMap() variable name must be a simple identifier"); } @@ -337,7 +343,7 @@ absl::optional ExpandOptFlatMapMacro(MacroExprFactory& factory, if (args.size() != 2) { return factory.ReportError("optFlatMap() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "optFlatMap() variable name must be a simple identifier"); } diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 1add80f84..35f11b413 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -631,6 +631,27 @@ std::vector test_cases = { "ERROR: :1:7: all() variable name must be a simple identifier\n" " | 1.all(2, 3)\n" " | ......^"}, + {"[].all(.x, x)", "", + "ERROR: :1:9: all() variable name must be a simple identifier\n" + " | [].all(.x, x)\n" + " | ........^"}, + {"[].exists(.x, x)", "", + "ERROR: :1:12: exists() variable name must be a simple identifier\n" + " | [].exists(.x, x)\n" + " | ...........^"}, + {"[].exists_one(.x, x)", "", + "ERROR: :1:16: exists_one() variable name must be a simple " + "identifier\n" + " | [].exists_one(.x, x)\n" + " | ...............^"}, + {"[].map(.x, x, x)", "", + "ERROR: :1:9: map() variable name must be a simple identifier\n" + " | [].map(.x, x, x)\n" + " | ........^"}, + {"[].filter(.x, x)", "", + "ERROR: :1:12: filter() variable name must be a simple identifier\n" + " | [].filter(.x, x)\n" + " | ...........^"}, {"x[\"a\"].single_int32 == 23", "_==_(\n" " _[_](\n" From 16074dcc55fddb53294aacd34afddd3db4a3138a Mon Sep 17 00:00:00 2001 From: Clayton Knittel Date: Sun, 21 Jun 2026 00:27:48 -0700 Subject: [PATCH 137/143] No public description PiperOrigin-RevId: 935529310 --- eval/public/message_wrapper_test.cc | 3 +-- .../structs/proto_message_type_adapter_test.cc | 4 +++- eval/public/testing/matchers.cc | 4 +--- internal/proto_matchers.h | 13 ++++++------- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/eval/public/message_wrapper_test.cc b/eval/public/message_wrapper_test.cc index ff0e691ab..15e5e88da 100644 --- a/eval/public/message_wrapper_test.cc +++ b/eval/public/message_wrapper_test.cc @@ -18,7 +18,6 @@ #include "eval/public/structs/trivial_legacy_type_info.h" #include "eval/testutil/test_message.pb.h" -#include "internal/casts.h" #include "internal/testing.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" @@ -60,7 +59,7 @@ TEST(MessageWrapperBuilder, Builder) { static_cast(&test_message)); auto mutable_message = - cel::internal::down_cast(builder.message_ptr()); + google::protobuf::DownCastMessage(builder.message_ptr()); mutable_message->set_int64_value(20); mutable_message->set_double_value(12.3); diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index 32608bc3f..270fd3ce1 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -36,6 +36,7 @@ #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" namespace google::api::expr::runtime { namespace { @@ -725,7 +726,8 @@ TEST(ProtoMesssageTypeAdapter, TypeInfoMutator) { ASSERT_OK_AND_ASSIGN(MessageWrapper::Builder builder, api->NewInstance(manager)); - EXPECT_NE(dynamic_cast(builder.message_ptr()), nullptr); + EXPECT_NE(google::protobuf::DynamicCastMessage(builder.message_ptr()), + nullptr); } TEST(ProtoMesssageTypeAdapter, TypeInfoAccesor) { diff --git a/eval/public/testing/matchers.cc b/eval/public/testing/matchers.cc index f79071fce..4f728c730 100644 --- a/eval/public/testing/matchers.cc +++ b/eval/public/testing/matchers.cc @@ -7,7 +7,6 @@ #include "absl/strings/string_view.h" #include "eval/public/cel_value.h" #include "eval/public/set_util.h" -#include "internal/casts.h" #include "internal/testing.h" #include "google/protobuf/message.h" @@ -76,8 +75,7 @@ class CelValueMatcherImpl CelValue::MessageWrapper arg; return v.GetValue(&arg) && arg.HasFullProto() && underlying_type_matcher_.Matches( - cel::internal::down_cast( - arg.message_ptr())); + google::protobuf::DownCastMessage(arg.message_ptr())); } void DescribeTo(std::ostream* os) const override { diff --git a/internal/proto_matchers.h b/internal/proto_matchers.h index 76d844036..02250634b 100644 --- a/internal/proto_matchers.h +++ b/internal/proto_matchers.h @@ -21,7 +21,6 @@ #include "absl/log/absl_check.h" #include "absl/memory/memory.h" -#include "internal/casts.h" #include "internal/testing.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" @@ -43,13 +42,13 @@ class TextProtoMatcher { bool MatchAndExplain(const google::protobuf::MessageLite& p, ::testing::MatchResultListener* listener) const { - return MatchAndExplain(cel::internal::down_cast(p), + return MatchAndExplain(google::protobuf::DownCastMessage(p), listener); } bool MatchAndExplain(const google::protobuf::MessageLite* p, ::testing::MatchResultListener* listener) const { - return MatchAndExplain(cel::internal::down_cast(p), + return MatchAndExplain(google::protobuf::DownCastMessage(p), listener); } @@ -58,7 +57,7 @@ class TextProtoMatcher { auto message = absl::WrapUnique(p.New()); ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(expected_, message.get())); return google::protobuf::util::MessageDifferencer::Equals( - *message, cel::internal::down_cast(p)); + *message, google::protobuf::DownCastMessage(p)); } bool MatchAndExplain(const google::protobuf::Message* p, @@ -66,7 +65,7 @@ class TextProtoMatcher { auto message = absl::WrapUnique(p->New()); ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(expected_, message.get())); return google::protobuf::util::MessageDifferencer::Equals( - *message, cel::internal::down_cast(*p)); + *message, google::protobuf::DownCastMessage(*p)); } inline void DescribeTo(::std::ostream* os) const { *os << expected_; } @@ -93,13 +92,13 @@ class ProtoMatcher { bool MatchAndExplain(const google::protobuf::MessageLite& p, ::testing::MatchResultListener* listener) const { - return MatchAndExplain(cel::internal::down_cast(p), + return MatchAndExplain(google::protobuf::DownCastMessage(p), listener); } bool MatchAndExplain(const google::protobuf::MessageLite* p, ::testing::MatchResultListener* listener) const { - return MatchAndExplain(cel::internal::down_cast(p), + return MatchAndExplain(google::protobuf::DownCastMessage(p), listener); } From a32e1418a8385c4b9c1ebe3b5072aaa8fbdf23ce Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 22 Jun 2026 11:42:04 -0700 Subject: [PATCH 138/143] When generating YAML from config, skip overload_id if identical to signature PiperOrigin-RevId: 936168430 --- env/env_yaml.cc | 21 +++++++++++++------ env/env_yaml_test.cc | 50 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 6 deletions(-) diff --git a/env/env_yaml.cc b/env/env_yaml.cc index e7b8a7885..d5e3ad059 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -725,6 +725,9 @@ absl::StatusOr ParseFunctionOverloadConfig( function_name, "\"")); } overload_config.is_member_function = parsed_signature.is_member; + if (overload_config.overload_id.empty()) { + overload_config.overload_id = signature; + } if (!parsed_signature.signature_type.has_function()) { return absl::InternalError(absl::StrCat( "Function overload signature has no function type: ", signature)); @@ -1101,11 +1104,8 @@ void EmitFunctionOverloadConfig( const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out, const EnvConfigToYamlOptions& options) { out << YAML::BeginMap; - if (!overload_config.overload_id.empty()) { - out << YAML::Key << "id"; - out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; - } bool signature_generated = false; + std::string signature_str; if (options.use_type_signatures) { bool param_type_spec_generated = true; std::vector params; @@ -1123,12 +1123,21 @@ void EmitFunctionOverloadConfig( common_internal::MakeOverloadSignature( function_name, params, overload_config.is_member_function); if (signature.ok()) { - out << YAML::Key << "signature"; - out << YAML::Value << YAML::DoubleQuoted << *signature; + signature_str = std::move(*signature); signature_generated = true; } } } + if (!overload_config.overload_id.empty()) { + if (!signature_generated || overload_config.overload_id != signature_str) { + out << YAML::Key << "id"; + out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; + } + } + if (signature_generated) { + out << YAML::Key << "signature"; + out << YAML::Value << YAML::DoubleQuoted << signature_str; + } if (!signature_generated) { if (overload_config.is_member_function) { out << YAML::Key << "target" << YAML::Value; diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index 38f08e371..c5bd1b787 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -545,12 +545,15 @@ std::vector GetParseFunctionTestCases() { .overload_configs = { Config::FunctionOverloadConfig{ + .overload_id = + "google.protobuf.StringValue.isEmpty()", .examples = {"''.isEmpty() // true"}, .is_member_function = true, .parameters = {{.name = "string_wrapper"}}, .return_type = {.name = "bool"}, }, Config::FunctionOverloadConfig{ + .overload_id = "list<~T>.isEmpty()", .examples = {"[].isEmpty() // true", "[1].isEmpty() // false"}, .is_member_function = true, @@ -635,6 +638,7 @@ std::vector GetParseFunctionTestCases() { .overload_configs = { Config::FunctionOverloadConfig{ + .overload_id = "contains(list<~T>, ~T)", .examples = {"contains([1, 2, 3], 2) // true"}, .is_member_function = false, .parameters = @@ -1740,6 +1744,45 @@ std::vector GetExportTestCases() { - type_name: "int" )yaml", }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .overload_configs = { + {.overload_id = "timestamp.foo(A<~B>)", + .is_member_function = true, + .parameters = {{.name = "timestamp"}, + {.name = "A", + .params = {{.name = "B", + .is_type_param = true}}}}, + .return_type = {.name = "int"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", + .expected_alt_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "timestamp.foo(A<~B>)" + target: + type_name: "timestamp" + args: + - type_name: "A" + params: + - type_name: "B" + is_type_param: true + return: + type_name: "int" + )yaml", + }, }; }; @@ -1888,6 +1931,13 @@ std::vector GetSignatureRoundTripTestCases() { signature: "foo(timestamp,A<~B>)" return: "list" )yaml", + R"yaml( + functions: + - name: "foo" + overloads: + - signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", }; } From 76ae0b3c1768d93a10270f904101de338867bdb1 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 22 Jun 2026 12:33:00 -0700 Subject: [PATCH 139/143] Move signature.h and signature.cc from common/internal to common This is done to make it visible to cel/python PiperOrigin-RevId: 936194518 --- checker/BUILD | 2 +- checker/type_checker_subset_factory.cc | 10 +++---- common/BUILD | 38 ++++++++++++++++++++++++- common/decl.cc | 5 ++-- common/internal/BUILD | 36 ----------------------- common/{internal => }/signature.cc | 6 ++-- common/{internal => }/signature.h | 10 +++---- common/{internal => }/signature_test.cc | 33 ++++++++++----------- env/BUILD | 4 +-- env/env.cc | 8 ++---- env/env_yaml.cc | 25 +++++++--------- 11 files changed, 83 insertions(+), 94 deletions(-) rename common/{internal => }/signature.cc (99%) rename common/{internal => }/signature.h (93%) rename common/{internal => }/signature_test.cc (96%) diff --git a/checker/BUILD b/checker/BUILD index efca3ff73..7f3ccfef7 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -231,7 +231,7 @@ cc_library( deps = [ ":type_checker_builder", "//common:decl", - "//common/internal:signature", + "//common:signature", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", diff --git a/checker/type_checker_subset_factory.cc b/checker/type_checker_subset_factory.cc index e5335e220..1b146c5a5 100644 --- a/checker/type_checker_subset_factory.cc +++ b/checker/type_checker_subset_factory.cc @@ -22,7 +22,7 @@ #include "absl/types/span.h" #include "checker/type_checker_builder.h" #include "common/decl.h" -#include "common/internal/signature.h" +#include "common/signature.h" namespace cel { @@ -33,8 +33,8 @@ TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( if (overload_ids.contains(overload.id())) { return true; } - auto signature = common_internal::MakeOverloadSignature( - function, overload.args(), overload.member()); + auto signature = + MakeOverloadSignature(function, overload.args(), overload.member()); return signature.ok() && overload_ids.contains(*signature); }; } @@ -52,8 +52,8 @@ TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate( if (overload_ids.contains(overload.id())) { return false; } - auto signature = common_internal::MakeOverloadSignature( - function, overload.args(), overload.member()); + auto signature = + MakeOverloadSignature(function, overload.args(), overload.member()); return !signature.ok() || !overload_ids.contains(*signature); }; } diff --git a/common/BUILD b/common/BUILD index 0bd3632dd..0426c0827 100644 --- a/common/BUILD +++ b/common/BUILD @@ -79,6 +79,42 @@ cc_test( ], ) +cc_library( + name = "signature", + srcs = ["signature.cc"], + hdrs = ["signature.h"], + deps = [ + ":ast", + ":type", + ":type_spec_resolver", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "signature_test", + srcs = ["signature_test.cc"], + deps = [ + ":ast", + ":signature", + ":type", + ":type_kind", + ":type_spec_resolver", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "expr", srcs = ["expr.cc"], @@ -145,9 +181,9 @@ cc_library( hdrs = ["decl.h"], deps = [ ":constant", + ":signature", ":type", ":type_kind", - "//common/internal:signature", "//internal:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", diff --git a/common/decl.cc b/common/decl.cc index d2d50964a..858e6fb49 100644 --- a/common/decl.cc +++ b/common/decl.cc @@ -26,7 +26,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "common/internal/signature.h" +#include "common/signature.h" #include "common/type.h" #include "common/type_kind.h" @@ -117,8 +117,7 @@ void AddOverloadInternal(std::string_view function_name, } absl::StatusOr signature = - common_internal::MakeOverloadSignature(function_name, overload.args(), - overload.member()); + MakeOverloadSignature(function_name, overload.args(), overload.member()); if (!signature.ok()) { status = signature.status(); return; diff --git a/common/internal/BUILD b/common/internal/BUILD index b07faf229..3be350754 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -135,39 +135,3 @@ cc_library( "@com_google_protobuf//src/google/protobuf/io", ], ) - -cc_library( - name = "signature", - srcs = ["signature.cc"], - hdrs = ["signature.h"], - deps = [ - "//common:ast", - "//common:type", - "//common:type_spec_resolver", - "//internal:status_macros", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_protobuf//:protobuf", - ], -) - -cc_test( - name = "signature_test", - srcs = ["signature_test.cc"], - deps = [ - ":signature", - "//common:ast", - "//common:type", - "//common:type_kind", - "//common:type_spec_resolver", - "//internal:testing", - "//internal:testing_descriptor_pool", - "@com_google_absl//absl/base:no_destructor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:status_matchers", - "@com_google_absl//absl/status:statusor", - "@com_google_protobuf//:protobuf", - ], -) diff --git a/common/internal/signature.cc b/common/signature.cc similarity index 99% rename from common/internal/signature.cc rename to common/signature.cc index fe315bb04..e497e780d 100644 --- a/common/internal/signature.cc +++ b/common/signature.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "common/internal/signature.h" +#include "common/signature.h" #include #include @@ -34,7 +34,7 @@ #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" -namespace cel::common_internal { +namespace cel { // Signature generator helper functions. namespace { @@ -637,4 +637,4 @@ absl::StatusOr ParseType(std::string_view signature, google::protobuf::Are return cel::ConvertTypeSpecToType(type_spec, arena, pool); } -} // namespace cel::common_internal +} // namespace cel diff --git a/common/internal/signature.h b/common/signature.h similarity index 93% rename from common/internal/signature.h rename to common/signature.h index 8a44fbd5c..777f03439 100644 --- a/common/internal/signature.h +++ b/common/signature.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ +#ifndef THIRD_PARTY_CEL_CPP_COMMON_SIGNATURE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_SIGNATURE_H_ #include #include @@ -25,7 +25,7 @@ #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" -namespace cel::common_internal { +namespace cel { // Generates a signature for a `cel::Type`, which is a string representation of // the type. @@ -96,6 +96,6 @@ struct ParsedFunctionOverload { absl::StatusOr ParseFunctionSignature( std::string_view signature); -} // namespace cel::common_internal +} // namespace cel -#endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ +#endif // THIRD_PARTY_CEL_CPP_COMMON_SIGNATURE_H_ diff --git a/common/internal/signature_test.cc b/common/signature_test.cc similarity index 96% rename from common/internal/signature_test.cc rename to common/signature_test.cc index 17b628d88..ea51eb566 100644 --- a/common/internal/signature_test.cc +++ b/common/signature_test.cc @@ -1,4 +1,4 @@ -#include "common/internal/signature.h" +#include "common/signature.h" // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -30,7 +30,7 @@ #include "internal/testing_descriptor_pool.h" #include "google/protobuf/arena.h" -namespace cel::common_internal { +namespace cel { namespace { using ::absl_testing::IsOkAndHolds; @@ -77,8 +77,7 @@ using TypeSignatureTest = testing::TestWithParam; TEST_P(TypeSignatureTest, TypeSignature) { const auto& param = GetParam(); - absl::StatusOr signature = - common_internal::MakeTypeSpecSignature(param.type); + absl::StatusOr signature = MakeTypeSpecSignature(param.type); if (!param.expected_error.empty()) { EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(param.expected_error))); @@ -257,25 +256,24 @@ INSTANTIATE_TEST_SUITE_P(TypeSignatureTest, TypeSignatureTest, ValuesIn(GetTypeSignatureTestCases())); TEST(TypeSignatureTest, UnsupportedTypes) { - EXPECT_THAT(common_internal::MakeTypeSignature(UnknownType{}), + EXPECT_THAT(MakeTypeSignature(UnknownType{}), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Unsupported Type kind: *unknown*"))); - EXPECT_THAT(common_internal::MakeTypeSignature(ErrorType{}), + EXPECT_THAT(MakeTypeSignature(ErrorType{}), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Unsupported type in signature: *error*"))); - EXPECT_THAT(common_internal::MakeTypeSpecSignature( - TypeSpec(static_cast(999))), + EXPECT_THAT(MakeTypeSpecSignature(TypeSpec(static_cast(999))), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Unsupported primitive type"))); - EXPECT_THAT(common_internal::MakeTypeSpecSignature( - TypeSpec(static_cast(999))), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Unsupported well-known type"))); + EXPECT_THAT( + MakeTypeSpecSignature(TypeSpec(static_cast(999))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported well-known type"))); - EXPECT_THAT(common_internal::MakeTypeSpecSignature(TypeSpec( + EXPECT_THAT(MakeTypeSpecSignature(TypeSpec( PrimitiveTypeWrapper(static_cast(999)))), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Unsupported wrapper type"))); @@ -308,8 +306,7 @@ TEST_P(OverloadSignatureTest, OverloadSignature) { const auto& param = GetParam(); absl::StatusOr signature = - common_internal::MakeOverloadSignature(param.function_name, param.args, - param.is_member); + MakeOverloadSignature(param.function_name, param.args, param.is_member); if (!param.expected_error.empty()) { EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(param.expected_error))); @@ -433,8 +430,8 @@ std::vector GetOverloadSignatureTestCases() { } TEST(OverloadSignatureTest, MemberFunctionNoReceiverError) { - auto signature = common_internal::MakeOverloadSignature( - "hello", std::vector{}, true); + auto signature = + MakeOverloadSignature("hello", std::vector{}, true); EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Member function with no receiver"))); @@ -784,4 +781,4 @@ TEST(OverloadSignatureTest, ArgumentTypeVector) { } } // namespace -} // namespace cel::common_internal +} // namespace cel diff --git a/env/BUILD b/env/BUILD index 1816238a5..0c17d6305 100644 --- a/env/BUILD +++ b/env/BUILD @@ -55,8 +55,8 @@ cc_library( "//common:constant", "//common:container", "//common:decl", + "//common:signature", "//common:type", - "//common/internal:signature", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", @@ -124,7 +124,7 @@ cc_library( ":config", "//common:ast", "//common:constant", - "//common/internal:signature", + "//common:signature", "//internal:status_macros", "//internal:strings", "@com_google_absl//absl/algorithm:container", diff --git a/env/env.cc b/env/env.cc index 4fa4e7398..85c5139da 100644 --- a/env/env.cc +++ b/env/env.cc @@ -26,7 +26,7 @@ #include "common/constant.h" #include "common/container.h" #include "common/decl.h" -#include "common/internal/signature.h" +#include "common/signature.h" #include "common/type.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" @@ -71,8 +71,7 @@ bool ShouldIncludeFunction(const Config::StandardLibraryConfig& config, return false; } absl::StatusOr signature = - common_internal::MakeOverloadSignature(function, overload.args(), - overload.member()); + MakeOverloadSignature(function, overload.args(), overload.member()); if (signature.ok() && config.excluded_functions.contains(std::make_pair( std::string(function), *std::move(signature)))) { return false; @@ -89,8 +88,7 @@ bool ShouldIncludeFunction(const Config::StandardLibraryConfig& config, // Ok to call MakeOverloadSignature() again, because in practice either // included or excluded functions may be specified, but not both. absl::StatusOr signature = - common_internal::MakeOverloadSignature(function, overload.args(), - overload.member()); + MakeOverloadSignature(function, overload.args(), overload.member()); if (signature.ok() && config.included_functions.contains(std::make_pair( std::string(function), *std::move(signature)))) { return true; diff --git a/env/env_yaml.cc b/env/env_yaml.cc index d5e3ad059..281cf3ff1 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -38,7 +38,7 @@ #include "absl/time/time.h" #include "common/ast.h" #include "common/constant.h" -#include "common/internal/signature.h" +#include "common/signature.h" #include "env/config.h" #include "env/type_info.h" #include "internal/status_macros.h" @@ -434,8 +434,7 @@ absl::StatusOr ParseTypeInfo(const YAML::Node& node, if (!type.IsScalar()) { return YamlError(yaml, type, "Node 'type' is not a string"); } - CEL_ASSIGN_OR_RETURN(auto type_spec, - common_internal::ParseTypeSpec(GetString(yaml, type))); + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSpec(GetString(yaml, type))); CEL_ASSIGN_OR_RETURN(auto type_config, TypeSpecToTypeInfo(type_spec)); return type_config; } @@ -714,9 +713,8 @@ absl::StatusOr ParseFunctionOverloadConfig( } std::string signature = GetString(yaml, signature_node); - CEL_ASSIGN_OR_RETURN( - common_internal::ParsedFunctionOverload parsed_signature, - common_internal::ParseFunctionSignature(signature)); + CEL_ASSIGN_OR_RETURN(ParsedFunctionOverload parsed_signature, + ParseFunctionSignature(signature)); if (parsed_signature.function_name != function_name) { return YamlError(yaml, signature_node, absl::StrCat("Function overload name \"", @@ -767,8 +765,8 @@ absl::StatusOr ParseFunctionOverloadConfig( const YAML::Node return_type = overload["return"]; if (return_type.IsDefined()) { if (return_type.IsScalar()) { - CEL_ASSIGN_OR_RETURN(auto type_spec, common_internal::ParseTypeSpec( - GetString(yaml, return_type))); + CEL_ASSIGN_OR_RETURN(auto type_spec, + ParseTypeSpec(GetString(yaml, return_type))); CEL_ASSIGN_OR_RETURN(overload_config.return_type, TypeSpecToTypeInfo(type_spec)); } else if (return_type.IsMap()) { @@ -993,8 +991,7 @@ void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out, if (options.use_type_signatures) { absl::StatusOr type_spec = TypeInfoToTypeSpec(type_info); if (type_spec.ok()) { - absl::StatusOr signature = - common_internal::MakeTypeSpecSignature(*type_spec); + absl::StatusOr signature = MakeTypeSpecSignature(*type_spec); if (signature.ok()) { out << YAML::Key << "type"; out << YAML::Value << YAML::DoubleQuoted << *signature; @@ -1119,9 +1116,8 @@ void EmitFunctionOverloadConfig( params.push_back(std::move(*type_spec)); } if (param_type_spec_generated) { - absl::StatusOr signature = - common_internal::MakeOverloadSignature( - function_name, params, overload_config.is_member_function); + absl::StatusOr signature = MakeOverloadSignature( + function_name, params, overload_config.is_member_function); if (signature.ok()) { signature_str = std::move(*signature); signature_generated = true; @@ -1177,8 +1173,7 @@ void EmitFunctionOverloadConfig( absl::StatusOr type_spec = TypeInfoToTypeSpec(overload_config.return_type); if (type_spec.ok()) { - absl::StatusOr signature = - common_internal::MakeTypeSpecSignature(*type_spec); + absl::StatusOr signature = MakeTypeSpecSignature(*type_spec); if (signature.ok()) { out << YAML::Key << "return"; out << YAML::Value << YAML::DoubleQuoted << *signature; From 452235ff20435bd4eaa753e62322cb09b0f365a9 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 23 Jun 2026 20:52:20 -0700 Subject: [PATCH 140/143] No public description PiperOrigin-RevId: 937065106 --- eval/public/structs/cel_proto_wrap_util.cc | 4 ++-- eval/public/structs/cel_proto_wrapper.cc | 2 +- eval/public/structs/legacy_type_provider.cc | 14 +++++++------- eval/public/structs/legacy_type_provider_test.cc | 14 +++++++------- eval/public/structs/proto_message_type_adapter.cc | 4 ++-- .../structs/proto_message_type_adapter_test.cc | 2 +- .../structs/protobuf_descriptor_type_provider.cc | 4 ++-- .../structs/trivial_legacy_type_info_test.cc | 4 ++-- 8 files changed, 24 insertions(+), 24 deletions(-) diff --git a/eval/public/structs/cel_proto_wrap_util.cc b/eval/public/structs/cel_proto_wrap_util.cc index d0f80171f..7bfe81fe6 100644 --- a/eval/public/structs/cel_proto_wrap_util.cc +++ b/eval/public/structs/cel_proto_wrap_util.cc @@ -691,7 +691,7 @@ class ValueFromMessageMaker { return CreateWellknownTypeValue(message, factory, arena); // WELLKNOWNTYPE_FIELDMASK has no special CelValue type default: - return absl::nullopt; + return std::nullopt; } } @@ -716,7 +716,7 @@ absl::optional DynamicMap::operator[](CelValue key) const { auto it = values_->fields().find(std::string(str_key.value())); if (it == values_->fields().end()) { - return absl::nullopt; + return std::nullopt; } return ValueManager(factory_, arena_).ValueFromMessage(&it->second); diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index a1dc83ade..6fad6aee3 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -53,7 +53,7 @@ absl::optional CelProtoWrapper::MaybeWrapValue( if (msg != nullptr) { return InternalWrapMessage(msg); } else { - return absl::nullopt; + return std::nullopt; } } diff --git a/eval/public/structs/legacy_type_provider.cc b/eval/public/structs/legacy_type_provider.cc index f87ab9645..f8db92298 100644 --- a/eval/public/structs/legacy_type_provider.cc +++ b/eval/public/structs/legacy_type_provider.cc @@ -63,7 +63,7 @@ class LegacyStructValueBuilder final : public cel::StructValueBuilder { CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetField( name, legacy_value, memory_manager_, builder_)) .With(cel::ErrorValueReturn()); - return absl::nullopt; + return std::nullopt; } absl::StatusOr> SetFieldByNumber( @@ -76,7 +76,7 @@ class LegacyStructValueBuilder final : public cel::StructValueBuilder { CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetFieldByNumber( number, legacy_value, memory_manager_, builder_)) .With(cel::ErrorValueReturn()); - return absl::nullopt; + return std::nullopt; } absl::StatusOr Build() && override { @@ -116,7 +116,7 @@ class LegacyValueBuilder final : public cel::ValueBuilder { CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetField( name, legacy_value, memory_manager_, builder_)) .With(cel::ErrorValueReturn()); - return absl::nullopt; + return std::nullopt; } absl::StatusOr> SetFieldByNumber( @@ -129,7 +129,7 @@ class LegacyValueBuilder final : public cel::ValueBuilder { CEL_RETURN_IF_ERROR(adapter_.mutation_apis()->SetFieldByNumber( number, legacy_value, memory_manager_, builder_)) .With(cel::ErrorValueReturn()); - return absl::nullopt; + return std::nullopt; } absl::StatusOr Build() && override { @@ -187,7 +187,7 @@ absl::StatusOr> LegacyTypeProvider::FindTypeImpl( return cel::common_internal::MakeBasicStructType( (*type_info)->GetTypename(MessageWrapper())); } - return absl::nullopt; + return std::nullopt; } absl::StatusOr> @@ -206,13 +206,13 @@ LegacyTypeProvider::FindStructTypeFieldByNameImpl( const auto* mutation_apis = (*type_info)->GetMutationApis(MessageWrapper()); if (mutation_apis == nullptr || !mutation_apis->DefinesField(name)) { - return absl::nullopt; + return std::nullopt; } return cel::common_internal::BasicStructTypeField(name, 0, cel::DynType{}); } } - return absl::nullopt; + return std::nullopt; } } // namespace google::api::expr::runtime diff --git a/eval/public/structs/legacy_type_provider_test.cc b/eval/public/structs/legacy_type_provider_test.cc index 160ac49f3..8de2aba01 100644 --- a/eval/public/structs/legacy_type_provider_test.cc +++ b/eval/public/structs/legacy_type_provider_test.cc @@ -28,7 +28,7 @@ class LegacyTypeProviderTestEmpty : public LegacyTypeProvider { public: absl::optional ProvideLegacyType( absl::string_view name) const override { - return absl::nullopt; + return std::nullopt; } }; @@ -60,14 +60,14 @@ class LegacyTypeProviderTestImpl : public LegacyTypeProvider { if (name == "test") { return LegacyTypeAdapter(nullptr, nullptr); } - return absl::nullopt; + return std::nullopt; } absl::optional ProvideLegacyTypeInfo( absl::string_view name) const override { if (name == "test") { return test_type_info_; } - return absl::nullopt; + return std::nullopt; } private: @@ -76,8 +76,8 @@ class LegacyTypeProviderTestImpl : public LegacyTypeProvider { TEST(LegacyTypeProviderTest, EmptyTypeProviderHasProvideTypeInfo) { LegacyTypeProviderTestEmpty provider; - EXPECT_EQ(provider.ProvideLegacyType("test"), absl::nullopt); - EXPECT_EQ(provider.ProvideLegacyTypeInfo("test"), absl::nullopt); + EXPECT_EQ(provider.ProvideLegacyType("test"), std::nullopt); + EXPECT_EQ(provider.ProvideLegacyTypeInfo("test"), std::nullopt); } TEST(LegacyTypeProviderTest, NonEmptyTypeProviderProvidesSomeTypes) { @@ -85,8 +85,8 @@ TEST(LegacyTypeProviderTest, NonEmptyTypeProviderProvidesSomeTypes) { LegacyTypeProviderTestImpl provider(&test_type_info); EXPECT_TRUE(provider.ProvideLegacyType("test").has_value()); EXPECT_TRUE(provider.ProvideLegacyTypeInfo("test").has_value()); - EXPECT_EQ(provider.ProvideLegacyType("other"), absl::nullopt); - EXPECT_EQ(provider.ProvideLegacyTypeInfo("other"), absl::nullopt); + EXPECT_EQ(provider.ProvideLegacyType("other"), std::nullopt); + EXPECT_EQ(provider.ProvideLegacyTypeInfo("other"), std::nullopt); } } // namespace diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index 6a3417ba3..8c140c0c7 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -472,14 +472,14 @@ const LegacyTypeAccessApis* ProtoMessageTypeAdapter::GetAccessApis( absl::optional ProtoMessageTypeAdapter::FindFieldByName(absl::string_view field_name) const { if (descriptor_ == nullptr) { - return absl::nullopt; + return std::nullopt; } const google::protobuf::FieldDescriptor* field_descriptor = descriptor_->FindFieldByName(field_name); if (field_descriptor == nullptr) { - return absl::nullopt; + return std::nullopt; } return LegacyTypeInfoApis::FieldDescription{field_descriptor->number(), diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index 270fd3ce1..e28d76102 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -710,7 +710,7 @@ TEST(ProtoMesssageTypeAdapter, FindFieldNotFound) { "google.api.expr.runtime.TestMessage"), google::protobuf::MessageFactory::generated_factory()); - EXPECT_EQ(adapter.FindFieldByName("foo_not_a_field"), absl::nullopt); + EXPECT_EQ(adapter.FindFieldByName("foo_not_a_field"), std::nullopt); } TEST(ProtoMesssageTypeAdapter, TypeInfoMutator) { diff --git a/eval/public/structs/protobuf_descriptor_type_provider.cc b/eval/public/structs/protobuf_descriptor_type_provider.cc index 68b39c643..b5746523e 100644 --- a/eval/public/structs/protobuf_descriptor_type_provider.cc +++ b/eval/public/structs/protobuf_descriptor_type_provider.cc @@ -27,7 +27,7 @@ absl::optional ProtobufDescriptorProvider::ProvideLegacyType( absl::string_view name) const { const ProtoMessageTypeAdapter* result = GetTypeAdapter(name); if (result == nullptr) { - return absl::nullopt; + return std::nullopt; } // ProtoMessageTypeAdapter provides apis for both access and mutation. return LegacyTypeAdapter(result, result); @@ -38,7 +38,7 @@ ProtobufDescriptorProvider::ProvideLegacyTypeInfo( absl::string_view name) const { const ProtoMessageTypeAdapter* result = GetTypeAdapter(name); if (result == nullptr) { - return absl::nullopt; + return std::nullopt; } return result; } diff --git a/eval/public/structs/trivial_legacy_type_info_test.cc b/eval/public/structs/trivial_legacy_type_info_test.cc index 9b4840373..9cc6e4916 100644 --- a/eval/public/structs/trivial_legacy_type_info_test.cc +++ b/eval/public/structs/trivial_legacy_type_info_test.cc @@ -56,9 +56,9 @@ TEST(TrivialTypeInfo, FindFieldByName) { TrivialTypeInfo info; MessageWrapper wrapper; - EXPECT_EQ(info.FindFieldByName("foo"), absl::nullopt); + EXPECT_EQ(info.FindFieldByName("foo"), std::nullopt); EXPECT_EQ(TrivialTypeInfo::GetInstance()->FindFieldByName("foo"), - absl::nullopt); + std::nullopt); } } // namespace From b548c7de3cba492909928acec9fbc3542bb02788 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 23 Jun 2026 23:58:24 -0700 Subject: [PATCH 141/143] No public description PiperOrigin-RevId: 937135148 --- common/source.cc | 14 ++++---- common/source_test.cc | 28 ++++++++-------- common/type.cc | 8 ++--- common/type_introspector.cc | 20 ++++++------ common/type_proto.cc | 2 +- common/type_reflector_test.cc | 60 +++++++++++++++++------------------ 6 files changed, 66 insertions(+), 66 deletions(-) diff --git a/common/source.cc b/common/source.cc index 8c32ad6ba..5fa4cca0e 100644 --- a/common/source.cc +++ b/common/source.cc @@ -483,26 +483,26 @@ absl::optional Source::GetLocation( return SourceLocation{line_and_offset->first, position - line_and_offset->second}; } - return absl::nullopt; + return std::nullopt; } absl::optional Source::GetPosition( const SourceLocation& location) const { if (ABSL_PREDICT_FALSE(location.line < 1 || location.column < 0)) { - return absl::nullopt; + return std::nullopt; } if (auto position = FindLinePosition(location.line); ABSL_PREDICT_TRUE(position.has_value())) { return *position + location.column; } - return absl::nullopt; + return std::nullopt; } absl::optional Source::Snippet(int32_t line) const { auto content = this->content(); auto start = FindLinePosition(line); if (ABSL_PREDICT_FALSE(!start.has_value() || content.empty())) { - return absl::nullopt; + return std::nullopt; } auto end = FindLinePosition(line + 1); if (end.has_value()) { @@ -554,7 +554,7 @@ std::string Source::DisplayErrorLocation(SourceLocation location) const { absl::optional Source::FindLinePosition(int32_t line) const { if (ABSL_PREDICT_FALSE(line < 1)) { - return absl::nullopt; + return std::nullopt; } if (line == 1) { return SourcePosition{0}; @@ -563,13 +563,13 @@ absl::optional Source::FindLinePosition(int32_t line) const { if (ABSL_PREDICT_TRUE(line <= static_cast(line_offsets.size()))) { return line_offsets[static_cast(line - 2)]; } - return absl::nullopt; + return std::nullopt; } absl::optional> Source::FindLine( SourcePosition position) const { if (ABSL_PREDICT_FALSE(position < 0)) { - return absl::nullopt; + return std::nullopt; } int32_t line = 1; const auto line_offsets = this->line_offsets(); diff --git a/common/source_test.cc b/common/source_test.cc index 2a3b78893..30a2ce9b0 100644 --- a/common/source_test.cc +++ b/common/source_test.cc @@ -81,37 +81,37 @@ TEST(StringSource, PositionAndLocation) { Optional(Eq(SourceLocation{int32_t{1}, int32_t{2}}))); EXPECT_THAT(source->GetLocation(*end), Optional(Eq(SourceLocation{int32_t{3}, int32_t{2}}))); - EXPECT_THAT(source->GetLocation(-1), Eq(absl::nullopt)); + EXPECT_THAT(source->GetLocation(-1), Eq(std::nullopt)); EXPECT_THAT(source->content().ToString(*start, *end), Eq("d &&\n\t b.c.arg(10) &&\n\t ")); EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{0}, int32_t{0}}), - Eq(absl::nullopt)); + Eq(std::nullopt)); EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{1}, int32_t{-1}}), - Eq(absl::nullopt)); + Eq(std::nullopt)); EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{4}, int32_t{0}}), - Eq(absl::nullopt)); + Eq(std::nullopt)); } TEST(StringSource, SnippetSingle) { ASSERT_OK_AND_ASSIGN(auto source, NewSource("hello, world", "one-line-test")); EXPECT_THAT(source->Snippet(1), Optional(Eq("hello, world"))); - EXPECT_THAT(source->Snippet(2), Eq(absl::nullopt)); + EXPECT_THAT(source->Snippet(2), Eq(std::nullopt)); } TEST(StringSource, SnippetMulti) { ASSERT_OK_AND_ASSIGN(auto source, NewSource("hello\nworld\nmy\nbub\n", "four-line-test")); - EXPECT_THAT(source->Snippet(0), Eq(absl::nullopt)); + EXPECT_THAT(source->Snippet(0), Eq(std::nullopt)); EXPECT_THAT(source->Snippet(1), Optional(Eq("hello"))); EXPECT_THAT(source->Snippet(2), Optional(Eq("world"))); EXPECT_THAT(source->Snippet(3), Optional(Eq("my"))); EXPECT_THAT(source->Snippet(4), Optional(Eq("bub"))); EXPECT_THAT(source->Snippet(5), Optional(Eq(""))); - EXPECT_THAT(source->Snippet(6), Eq(absl::nullopt)); + EXPECT_THAT(source->Snippet(6), Eq(std::nullopt)); } TEST(CordSource, Description) { @@ -150,17 +150,17 @@ TEST(CordSource, PositionAndLocation) { Optional(Eq(SourceLocation{int32_t{1}, int32_t{2}}))); EXPECT_THAT(source->GetLocation(*end), Optional(Eq(SourceLocation{int32_t{3}, int32_t{2}}))); - EXPECT_THAT(source->GetLocation(-1), Eq(absl::nullopt)); + EXPECT_THAT(source->GetLocation(-1), Eq(std::nullopt)); EXPECT_THAT(source->content().ToString(*start, *end), Eq("d &&\n\t b.c.arg(10) &&\n\t ")); EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{0}, int32_t{0}}), - Eq(absl::nullopt)); + Eq(std::nullopt)); EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{1}, int32_t{-1}}), - Eq(absl::nullopt)); + Eq(std::nullopt)); EXPECT_THAT(source->GetPosition(SourceLocation{int32_t{4}, int32_t{0}}), - Eq(absl::nullopt)); + Eq(std::nullopt)); } TEST(CordSource, SnippetSingle) { @@ -168,7 +168,7 @@ TEST(CordSource, SnippetSingle) { NewSource(absl::Cord("hello, world"), "one-line-test")); EXPECT_THAT(source->Snippet(1), Optional(Eq("hello, world"))); - EXPECT_THAT(source->Snippet(2), Eq(absl::nullopt)); + EXPECT_THAT(source->Snippet(2), Eq(std::nullopt)); } TEST(CordSource, SnippetMulti) { @@ -176,13 +176,13 @@ TEST(CordSource, SnippetMulti) { auto source, NewSource(absl::Cord("hello\nworld\nmy\nbub\n"), "four-line-test")); - EXPECT_THAT(source->Snippet(0), Eq(absl::nullopt)); + EXPECT_THAT(source->Snippet(0), Eq(std::nullopt)); EXPECT_THAT(source->Snippet(1), Optional(Eq("hello"))); EXPECT_THAT(source->Snippet(2), Optional(Eq("world"))); EXPECT_THAT(source->Snippet(3), Optional(Eq("my"))); EXPECT_THAT(source->Snippet(4), Optional(Eq("bub"))); EXPECT_THAT(source->Snippet(5), Optional(Eq(""))); - EXPECT_THAT(source->Snippet(6), Eq(absl::nullopt)); + EXPECT_THAT(source->Snippet(6), Eq(std::nullopt)); } TEST(Source, DisplayErrorLocationBasic) { diff --git a/common/type.cc b/common/type.cc index 684c5ba09..9ea85954c 100644 --- a/common/type.cc +++ b/common/type.cc @@ -158,7 +158,7 @@ absl::optional GetOrNullopt(const common_internal::TypeVariant& variant) { if (const auto* alt = absl::get_if(&variant); alt != nullptr) { return *alt; } - return absl::nullopt; + return std::nullopt; } } // namespace @@ -243,7 +243,7 @@ absl::optional Type::AsOptional() const { if (auto maybe_opaque = AsOpaque(); maybe_opaque.has_value()) { return maybe_opaque->AsOptional(); } - return absl::nullopt; + return std::nullopt; } absl::optional Type::AsString() const { @@ -263,7 +263,7 @@ absl::optional Type::AsStruct() const { if (const auto* alt = absl::get_if(&variant_); alt != nullptr) { return *alt; } - return absl::nullopt; + return std::nullopt; } absl::optional Type::AsTimestamp() const { @@ -603,7 +603,7 @@ absl::optional StructTypeField::AsMessage() const { alternative != nullptr) { return *alternative; } - return absl::nullopt; + return std::nullopt; } StructTypeField::operator MessageTypeField() const { diff --git a/common/type_introspector.cc b/common/type_introspector.cc index 26f53685e..3846ab58b 100644 --- a/common/type_introspector.cc +++ b/common/type_introspector.cc @@ -104,7 +104,7 @@ struct WellKnownType { auto it = std::lower_bound(fields_by_name.begin(), fields_by_name.end(), name, FieldNameComparer{}); if (it == fields_by_name.end() || it->name() != name) { - return absl::nullopt; + return std::nullopt; } return *it; } @@ -114,7 +114,7 @@ struct WellKnownType { auto it = std::lower_bound(fields_by_number.begin(), fields_by_number.end(), number, FieldNumberComparer{}); if (it == fields_by_number.end() || it->number() != number) { - return absl::nullopt; + return std::nullopt; } return *it; } @@ -216,25 +216,25 @@ const WellKnownTypesMap& GetWellKnownTypesMap() { absl::StatusOr> TypeIntrospector::FindTypeImpl( absl::string_view) const { - return absl::nullopt; + return std::nullopt; } absl::StatusOr> TypeIntrospector::FindEnumConstantImpl(absl::string_view, absl::string_view) const { - return absl::nullopt; + return std::nullopt; } absl::StatusOr> TypeIntrospector::FindStructTypeFieldByNameImpl(absl::string_view, absl::string_view) const { - return absl::nullopt; + return std::nullopt; } absl::StatusOr< absl::optional>> TypeIntrospector::ListFieldsForStructTypeImpl(absl::string_view) const { - return absl::nullopt; + return std::nullopt; } absl::optional FindWellKnownType(absl::string_view name) { @@ -242,7 +242,7 @@ absl::optional FindWellKnownType(absl::string_view name) { if (auto it = well_known_types.find(name); it != well_known_types.end()) { return it->second.type; } - return absl::nullopt; + return std::nullopt; } absl::optional FindWellKnownTypeEnumConstant( @@ -251,7 +251,7 @@ absl::optional FindWellKnownTypeEnumConstant( return TypeIntrospector::EnumConstant{ IntType{}, "google.protobuf.NullValue", "NULL_VALUE", 0}; } - return absl::nullopt; + return std::nullopt; } absl::optional FindWellKnownTypeFieldByName( @@ -260,7 +260,7 @@ absl::optional FindWellKnownTypeFieldByName( if (auto it = well_known_types.find(type); it != well_known_types.end()) { return it->second.FieldByName(name); } - return absl::nullopt; + return std::nullopt; } absl::optional> @@ -268,7 +268,7 @@ ListFieldsForWellKnownType(absl::string_view type) { const auto& well_known_types = GetWellKnownTypesMap(); auto it = well_known_types.find(type); if (it == well_known_types.end()) { - return absl::nullopt; + return std::nullopt; } // The fields are not normally gettable. return {}; diff --git a/common/type_proto.cc b/common/type_proto.cc index 66c16689d..b6b66f73a 100644 --- a/common/type_proto.cc +++ b/common/type_proto.cc @@ -71,7 +71,7 @@ absl::optional MaybeWellKnownType(absl::string_view type_name) { return it->second; } - return absl::nullopt; + return std::nullopt; } absl::Status TypeToProtoInternal(const cel::Type& type, diff --git a/common/type_reflector_test.cc b/common/type_reflector_test.cc index f2ff2c322..d9c855e4b 100644 --- a/common/type_reflector_test.cc +++ b/common/type_reflector_test.cc @@ -210,7 +210,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_BoolValue) { internal::GetTestingMessageFactory(), "google.protobuf.BoolValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -218,7 +218,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_BoolValue) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, BoolValue(true)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -236,7 +236,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Int32Value) { internal::GetTestingMessageFactory(), "google.protobuf.Int32Value"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -248,7 +248,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Int32Value) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -270,7 +270,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Int64Value) { internal::GetTestingMessageFactory(), "google.protobuf.Int64Value"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -278,7 +278,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Int64Value) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -296,7 +296,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_UInt32Value) { internal::GetTestingMessageFactory(), "google.protobuf.UInt32Value"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", UintValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", UintValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -308,7 +308,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_UInt32Value) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kOutOfRange))))); EXPECT_THAT(builder->SetFieldByNumber(1, UintValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, UintValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -330,7 +330,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_UInt64Value) { internal::GetTestingMessageFactory(), "google.protobuf.UInt64Value"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", UintValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", UintValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -338,7 +338,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_UInt64Value) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, UintValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, UintValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -356,7 +356,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_FloatValue) { internal::GetTestingMessageFactory(), "google.protobuf.FloatValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", DoubleValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", DoubleValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -364,7 +364,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_FloatValue) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, DoubleValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, DoubleValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -382,7 +382,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_DoubleValue) { internal::GetTestingMessageFactory(), "google.protobuf.DoubleValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", DoubleValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", DoubleValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -390,7 +390,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_DoubleValue) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, DoubleValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, DoubleValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -408,7 +408,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_StringValue) { internal::GetTestingMessageFactory(), "google.protobuf.StringValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", StringValue("foo")), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", StringValue("foo")), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -416,7 +416,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_StringValue) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, StringValue("foo")), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, StringValue("foo")), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -434,7 +434,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_BytesValue) { internal::GetTestingMessageFactory(), "google.protobuf.BytesValue"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("value", BytesValue("foo")), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", BytesValue("foo")), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -442,7 +442,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_BytesValue) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, BytesValue("foo")), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, BytesValue("foo")), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -460,7 +460,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Duration) { internal::GetTestingMessageFactory(), "google.protobuf.Duration"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("seconds", IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -468,7 +468,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Duration) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByName("nanos", IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName( "nanos", IntValue(std::numeric_limits::max())), IsOkAndHolds(Optional( @@ -477,7 +477,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Duration) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(3, IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -485,7 +485,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Duration) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber( 2, IntValue(std::numeric_limits::max())), IsOkAndHolds(Optional( @@ -505,7 +505,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Timestamp) { internal::GetTestingMessageFactory(), "google.protobuf.Timestamp"); ASSERT_THAT(builder, NotNull()); EXPECT_THAT(builder->SetFieldByName("seconds", IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -513,7 +513,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Timestamp) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByName("nanos", IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName( "nanos", IntValue(std::numeric_limits::max())), IsOkAndHolds(Optional( @@ -522,7 +522,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Timestamp) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(1, IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(3, IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -530,7 +530,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Timestamp) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(2, IntValue(1)), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber( 2, IntValue(std::numeric_limits::max())), IsOkAndHolds(Optional( @@ -552,7 +552,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Any) { EXPECT_THAT(builder->SetFieldByName( "type_url", StringValue("type.googleapis.com/google.protobuf.BoolValue")), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("does_not_exist", IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -560,14 +560,14 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Any) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByName("value", BytesValue()), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByName("value", BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT( builder->SetFieldByNumber( 1, StringValue("type.googleapis.com/google.protobuf.BoolValue")), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(3, IntValue(1)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kNotFound))))); @@ -575,7 +575,7 @@ TEST_F(TypeReflectorTest, NewValueBuilder_Any) { IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); EXPECT_THAT(builder->SetFieldByNumber(2, BytesValue()), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(builder->SetFieldByNumber(2, BoolValue(true)), IsOkAndHolds(Optional( ErrorValueIs(StatusIs(absl::StatusCode::kInvalidArgument))))); From 7f3e2da8e20d11e5037e3edd95541c09410bff52 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 24 Jun 2026 00:57:41 -0700 Subject: [PATCH 142/143] No public description PiperOrigin-RevId: 937159464 --- extensions/protobuf/internal/qualify.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/protobuf/internal/qualify.cc b/extensions/protobuf/internal/qualify.cc index dba4f44ae..37ad30011 100644 --- a/extensions/protobuf/internal/qualify.cc +++ b/extensions/protobuf/internal/qualify.cc @@ -145,7 +145,7 @@ absl::StatusOr> LookupMapValu bool found = cel::extensions::protobuf_internal::LookupMapValue( *reflection, *message, *field_desc, proto_key, &value_ref); if (!found) { - return absl::nullopt; + return std::nullopt; } return value_ref; } From 43960d06c38367520733fc2054688ec5fe447b28 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 24 Jun 2026 07:01:10 -0700 Subject: [PATCH 143/143] No public description PiperOrigin-RevId: 937312837 --- common/types/opaque_type.cc | 2 +- common/types/struct_type.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/common/types/opaque_type.cc b/common/types/opaque_type.cc index 002319d1d..9c58e8289 100644 --- a/common/types/opaque_type.cc +++ b/common/types/opaque_type.cc @@ -98,7 +98,7 @@ absl::optional OpaqueType::AsOptional() const { if (IsOptional()) { return OptionalType(absl::in_place, *this); } - return absl::nullopt; + return std::nullopt; } OptionalType OpaqueType::GetOptional() const { diff --git a/common/types/struct_type.cc b/common/types/struct_type.cc index a1be1f786..69f531a2f 100644 --- a/common/types/struct_type.cc +++ b/common/types/struct_type.cc @@ -61,7 +61,7 @@ absl::optional StructType::AsMessage() const { if (const auto* alt = absl::get_if(&variant_); alt != nullptr) { return *alt; } - return absl::nullopt; + return std::nullopt; } MessageType StructType::GetMessage() const {