diff --git a/env/runtime_std_extensions.cc b/env/runtime_std_extensions.cc index 167e3b104..b866a5965 100644 --- a/env/runtime_std_extensions.cc +++ b/env/runtime_std_extensions.cc @@ -105,8 +105,11 @@ void RegisterStandardExtensions(EnvRuntime& env_runtime) { "cel.lib.ext.strings", "strings", version, [version](RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) -> absl::Status { + cel::extensions::StringsExtensionOptions strings_options; + strings_options.version = version; return cel::extensions::RegisterStringsFunctions( - runtime_builder.function_registry(), runtime_options, version); + runtime_builder.function_registry(), runtime_options, + strings_options); }); } diff --git a/extensions/formatting.cc b/extensions/formatting.cc index 6e58a7b86..8d9c4f783 100644 --- a/extensions/formatting.cc +++ b/extensions/formatting.cc @@ -14,20 +14,19 @@ #include "extensions/formatting.h" +#include #include #include #include #include #include #include -#include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/btree_map.h" -#include "absl/memory/memory.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -64,7 +63,7 @@ absl::StatusOr FormatString( std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); absl::StatusOr>> ParsePrecision( - absl::string_view format) { + absl::string_view format, int max_precision) { if (format.empty() || format[0] != '.') return std::pair{0, std::nullopt}; int64_t i = 1; @@ -80,9 +79,9 @@ absl::StatusOr>> ParsePrecision( return absl::InvalidArgumentError( "unable to convert precision specifier to integer"); } - if (precision > kMaxPrecision) { + if (precision > max_precision) { return absl::InvalidArgumentError( - absl::StrCat("precision specifier exceeds maximum of ", kMaxPrecision)); + absl::StrCat("precision specifier exceeds maximum of ", max_precision)); } return std::pair{i, precision}; } @@ -444,12 +443,13 @@ absl::StatusOr FormatScientific( } absl::StatusOr> ParseAndFormatClause( - absl::string_view format, const Value& value, + absl::string_view format, const Value& value, int max_precision, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { - CEL_ASSIGN_OR_RETURN(auto precision_pair, ParsePrecision(format)); + CEL_ASSIGN_OR_RETURN(auto precision_pair, + ParsePrecision(format, max_precision)); auto [read, precision] = precision_pair; switch (format[read]) { case 's': { @@ -494,7 +494,7 @@ absl::StatusOr> ParseAndFormatClause( } absl::StatusOr Format( - const StringValue& format_value, const ListValue& args, + const StringValue& format_value, const ListValue& args, int max_precision, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { @@ -512,43 +512,50 @@ absl::StatusOr Format( } ++i; if (i >= format.size()) { - return absl::InvalidArgumentError("unexpected end of format string"); + return ErrorValue( + absl::InvalidArgumentError("unexpected end of format string")); } if (format[i] == '%') { result.push_back('%'); continue; } if (arg_index >= args_size) { - return absl::InvalidArgumentError( - absl::StrFormat("index %d out of range", arg_index)); + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("index %d out of range", arg_index))); } CEL_ASSIGN_OR_RETURN(auto value, args.Get(arg_index++, descriptor_pool, message_factory, arena)); - CEL_ASSIGN_OR_RETURN( - auto clause, - ParseAndFormatClause(format.substr(i), value, descriptor_pool, - message_factory, arena, clause_scratch)); - absl::StrAppend(&result, clause.second); - i += clause.first; + + auto clause = ParseAndFormatClause(format.substr(i), value, max_precision, + descriptor_pool, message_factory, arena, + clause_scratch); + if (!clause.ok()) { + return ErrorValue(std::move(clause).status()); + } + absl::StrAppend(&result, clause->second); + i += clause->first; } - return StringValue(arena, std::move(result)); + return StringValue::From(std::move(result), arena); } } // namespace absl::Status RegisterStringFormattingFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { + const RuntimeOptions& options, + int max_precision) { + max_precision = std::clamp(max_precision, 0, kMaxPrecision); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, ListValue>:: CreateDescriptor("format", /*receiver_style=*/true), BinaryFunctionAdapter, StringValue, ListValue>:: WrapFunction( - [](const StringValue& format, const ListValue& args, - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, - google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) { - return Format(format, args, descriptor_pool, message_factory, - arena); + [max_precision]( + const StringValue& format, const ListValue& args, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return Format(format, args, max_precision, descriptor_pool, + message_factory, arena); }))); return absl::OkStatus(); } diff --git a/extensions/formatting.h b/extensions/formatting.h index bc2002006..f4fa2f2d8 100644 --- a/extensions/formatting.h +++ b/extensions/formatting.h @@ -23,7 +23,8 @@ namespace cel::extensions { // Register extension functions for string formatting. absl::Status RegisterStringFormattingFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); + const RuntimeOptions& options, + int max_precision = 1000); } // namespace cel::extensions diff --git a/extensions/formatting_test.cc b/extensions/formatting_test.cc index 824f14e45..ca7ef2ad2 100644 --- a/extensions/formatting_test.cc +++ b/extensions/formatting_test.cc @@ -96,6 +96,30 @@ TEST_P(StringFormatLimitsTest, FormatLimits) { } } +TEST(StringFormatLimitsTest, MaxPrecisionOption) { + google::protobuf::Arena arena; + const RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT(RegisterStringFormattingFunctions(builder.function_registry(), + options, /*max_precision=*/99), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("'%.100f'.format([1.123])", + "", ParserOptions{})); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + ASSERT_TRUE(value.Is()); + EXPECT_THAT(value.GetError().ToStatus().message(), + HasSubstr("precision specifier exceeds maximum of 99")); +} + INSTANTIATE_TEST_SUITE_P(StringFormatLimitsTest, StringFormatLimitsTest, ValuesIn({ "double('%.326f'.format([x])) == x", diff --git a/extensions/strings.cc b/extensions/strings.cc index ed6f27319..9762e5999 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -305,9 +305,10 @@ absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder, int version) { } // namespace -absl::Status RegisterStringsFunctions(FunctionRegistry& registry, - const RuntimeOptions& options, - int version) { +absl::Status RegisterStringsFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + const StringsExtensionOptions& extension_options) { + const int version = extension_options.version; CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, StringValue>:: CreateDescriptor("split", /*receiver_style=*/true), @@ -382,7 +383,8 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, return absl::OkStatus(); } - CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(registry, options)); + CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions( + registry, options, extension_options.max_precision)); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "strings.quote", &Quote, registry))); @@ -412,13 +414,16 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, absl::Status RegisterStringsFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, - const google::api::expr::runtime::InterpreterOptions& options) { + const google::api::expr::runtime::InterpreterOptions& options, + const StringsExtensionOptions& extension_options) { return RegisterStringsFunctions( registry->InternalGetRegistry(), - google::api::expr::runtime::ConvertToRuntimeOptions(options)); + google::api::expr::runtime::ConvertToRuntimeOptions(options), + extension_options); } -CheckerLibrary StringsCheckerLibrary(int version) { +CheckerLibrary StringsCheckerLibrary(const StringsExtensionOptions& options) { + const int version = options.version; return {"strings", [version](TypeCheckerBuilder& builder) { return RegisterStringsDecls(builder, version); }}; diff --git a/extensions/strings.h b/extensions/strings.h index 3cbc9f19f..3ec92d603 100644 --- a/extensions/strings.h +++ b/extensions/strings.h @@ -27,21 +27,45 @@ namespace cel::extensions { constexpr int kStringsExtensionLatestVersion = 4; +struct StringsExtensionOptions { + int version = kStringsExtensionLatestVersion; + + // Maximum precision allowed for floating point format specifiers in + // format() function. This is used for both fixed and scientific notations. + // Value must be in the range [0, 1000], otherwise clamped. + // + // Does not affect default precisions for %e and %f format specifiers. + int max_precision = 1000; +}; + // Register extension functions for strings. absl::Status RegisterStringsFunctions( FunctionRegistry& registry, const RuntimeOptions& options, - int version = kStringsExtensionLatestVersion); + const StringsExtensionOptions& extension_options = {}); absl::Status RegisterStringsFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, - const google::api::expr::runtime::InterpreterOptions& options); + const google::api::expr::runtime::InterpreterOptions& options, + const StringsExtensionOptions& extension_options = {}); CheckerLibrary StringsCheckerLibrary( - int version = kStringsExtensionLatestVersion); + const StringsExtensionOptions& extension_options = {}); + +inline CheckerLibrary StringsCheckerLibrary(int version) { + StringsExtensionOptions options; + options.version = version; + return StringsCheckerLibrary(options); +} inline CompilerLibrary StringsCompilerLibrary( - int version = kStringsExtensionLatestVersion) { - return CompilerLibrary::FromCheckerLibrary(StringsCheckerLibrary(version)); + const StringsExtensionOptions& options = {}) { + return CompilerLibrary::FromCheckerLibrary(StringsCheckerLibrary(options)); +} + +inline CompilerLibrary StringsCompilerLibrary(int version) { + StringsExtensionOptions options; + options.version = version; + return StringsCompilerLibrary(options); } } // namespace cel::extensions diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc index a5d56eaed..03ea19b29 100644 --- a/extensions/strings_test.cc +++ b/extensions/strings_test.cc @@ -50,6 +50,7 @@ namespace cel::extensions { namespace { using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::parser::ParserOptions; @@ -85,6 +86,48 @@ TEST(StringsCheckerLibrary, SmokeTest) { )~bool^equals)"); } +TEST(StringsExtTest, MaxPrecisionOption) { + StringsExtensionOptions extension_options; + extension_options.max_precision = 99; + + ASSERT_OK_AND_ASSIGN( + auto compiler_builder, + NewCompilerBuilder(internal::GetTestingDescriptorPool())); + + ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(StringsCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, compiler_builder->Build()); + + auto result = compiler->Compile("'abc %.100f'.format([2.0])", ""); + + ASSERT_THAT(result, IsOk()); + ASSERT_TRUE(result->IsValid()); + + RuntimeOptions opts; + ASSERT_OK_AND_ASSIGN( + auto runtime_builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_THAT(RegisterStringsFunctions(runtime_builder.function_registry(), + opts, extension_options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, + runtime->CreateProgram(*result->ReleaseAst())); + + google::protobuf::Arena arena; + cel::Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value.Is()); + EXPECT_THAT(value.GetError().ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("precision specifier exceeds maximum of 99"))); +} + using StringsExtFunctionsTest = testing::TestWithParam; TEST_P(StringsExtFunctionsTest, ParserAndCheckerTests) {