Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion env/runtime_std_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,11 @@ void RegisterStandardExtensions(EnvRuntime& env_runtime) {
"cel.lib.ext.strings", "strings", version,
[version](RuntimeBuilder& runtime_builder,
const RuntimeOptions& runtime_options) -> absl::Status {
cel::extensions::StringsExtensionOptions strings_options;
strings_options.version = version;
return cel::extensions::RegisterStringsFunctions(
runtime_builder.function_registry(), runtime_options, version);
runtime_builder.function_registry(), runtime_options,
strings_options);
});
}

Expand Down
57 changes: 32 additions & 25 deletions extensions/formatting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,19 @@

#include "extensions/formatting.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/container/btree_map.h"
#include "absl/memory/memory.h"
#include "absl/numeric/bits.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -64,7 +63,7 @@ absl::StatusOr<absl::string_view> FormatString(
std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND);

absl::StatusOr<std::pair<int64_t, std::optional<int>>> ParsePrecision(
absl::string_view format) {
absl::string_view format, int max_precision) {
if (format.empty() || format[0] != '.') return std::pair{0, std::nullopt};

int64_t i = 1;
Expand All @@ -80,9 +79,9 @@ absl::StatusOr<std::pair<int64_t, std::optional<int>>> ParsePrecision(
return absl::InvalidArgumentError(
"unable to convert precision specifier to integer");
}
if (precision > kMaxPrecision) {
if (precision > max_precision) {
return absl::InvalidArgumentError(
absl::StrCat("precision specifier exceeds maximum of ", kMaxPrecision));
absl::StrCat("precision specifier exceeds maximum of ", max_precision));
}
return std::pair{i, precision};
}
Expand Down Expand Up @@ -444,12 +443,13 @@ absl::StatusOr<absl::string_view> FormatScientific(
}

absl::StatusOr<std::pair<int64_t, absl::string_view>> ParseAndFormatClause(
absl::string_view format, const Value& value,
absl::string_view format, const Value& value, int max_precision,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena,
std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) {
CEL_ASSIGN_OR_RETURN(auto precision_pair, ParsePrecision(format));
CEL_ASSIGN_OR_RETURN(auto precision_pair,
ParsePrecision(format, max_precision));
auto [read, precision] = precision_pair;
switch (format[read]) {
case 's': {
Expand Down Expand Up @@ -494,7 +494,7 @@ absl::StatusOr<std::pair<int64_t, absl::string_view>> ParseAndFormatClause(
}

absl::StatusOr<Value> Format(
const StringValue& format_value, const ListValue& args,
const StringValue& format_value, const ListValue& args, int max_precision,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) {
Expand All @@ -512,43 +512,50 @@ absl::StatusOr<Value> Format(
}
++i;
if (i >= format.size()) {
return absl::InvalidArgumentError("unexpected end of format string");
return ErrorValue(
absl::InvalidArgumentError("unexpected end of format string"));
}
if (format[i] == '%') {
result.push_back('%');
continue;
}
if (arg_index >= args_size) {
return absl::InvalidArgumentError(
absl::StrFormat("index %d out of range", arg_index));
return ErrorValue(absl::InvalidArgumentError(
absl::StrFormat("index %d out of range", arg_index)));
}
CEL_ASSIGN_OR_RETURN(auto value, args.Get(arg_index++, descriptor_pool,
message_factory, arena));
CEL_ASSIGN_OR_RETURN(
auto clause,
ParseAndFormatClause(format.substr(i), value, descriptor_pool,
message_factory, arena, clause_scratch));
absl::StrAppend(&result, clause.second);
i += clause.first;

auto clause = ParseAndFormatClause(format.substr(i), value, max_precision,
descriptor_pool, message_factory, arena,
clause_scratch);
if (!clause.ok()) {
return ErrorValue(std::move(clause).status());
}
absl::StrAppend(&result, clause->second);
i += clause->first;
}
return StringValue(arena, std::move(result));
return StringValue::From(std::move(result), arena);
}

} // namespace

absl::Status RegisterStringFormattingFunctions(FunctionRegistry& registry,
const RuntimeOptions& options) {
const RuntimeOptions& options,
int max_precision) {
max_precision = std::clamp(max_precision, 0, kMaxPrecision);
CEL_RETURN_IF_ERROR(registry.Register(
BinaryFunctionAdapter<absl::StatusOr<Value>, StringValue, ListValue>::
CreateDescriptor("format", /*receiver_style=*/true),
BinaryFunctionAdapter<absl::StatusOr<Value>, StringValue, ListValue>::
WrapFunction(
[](const StringValue& format, const ListValue& args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) {
return Format(format, args, descriptor_pool, message_factory,
arena);
[max_precision](
const StringValue& format, const ListValue& args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) {
return Format(format, args, max_precision, descriptor_pool,
message_factory, arena);
})));
return absl::OkStatus();
}
Expand Down
3 changes: 2 additions & 1 deletion extensions/formatting.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ namespace cel::extensions {

// Register extension functions for string formatting.
absl::Status RegisterStringFormattingFunctions(FunctionRegistry& registry,
const RuntimeOptions& options);
const RuntimeOptions& options,
int max_precision = 1000);

} // namespace cel::extensions

Expand Down
24 changes: 24 additions & 0 deletions extensions/formatting_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,30 @@ TEST_P(StringFormatLimitsTest, FormatLimits) {
}
}

TEST(StringFormatLimitsTest, MaxPrecisionOption) {
google::protobuf::Arena arena;
const RuntimeOptions options;
ASSERT_OK_AND_ASSIGN(auto builder,
CreateStandardRuntimeBuilder(
internal::GetTestingDescriptorPool(), options));
ASSERT_THAT(RegisterStringFormattingFunctions(builder.function_registry(),
options, /*max_precision=*/99),
IsOk());

ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("'%.100f'.format([1.123])",
"<input>", ParserOptions{}));
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Program> program,
ProtobufRuntimeAdapter::CreateProgram(*runtime, expr));
Activation activation;

ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation));
ASSERT_TRUE(value.Is<ErrorValue>());
EXPECT_THAT(value.GetError().ToStatus().message(),
HasSubstr("precision specifier exceeds maximum of 99"));
}

INSTANTIATE_TEST_SUITE_P(StringFormatLimitsTest, StringFormatLimitsTest,
ValuesIn<std::string>({
"double('%.326f'.format([x])) == x",
Expand Down
19 changes: 12 additions & 7 deletions extensions/strings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,10 @@ absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder, int version) {

} // namespace

absl::Status RegisterStringsFunctions(FunctionRegistry& registry,
const RuntimeOptions& options,
int version) {
absl::Status RegisterStringsFunctions(
FunctionRegistry& registry, const RuntimeOptions& options,
const StringsExtensionOptions& extension_options) {
const int version = extension_options.version;
CEL_RETURN_IF_ERROR(registry.Register(
BinaryFunctionAdapter<absl::StatusOr<Value>, StringValue, StringValue>::
CreateDescriptor("split", /*receiver_style=*/true),
Expand Down Expand Up @@ -382,7 +383,8 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry,
return absl::OkStatus();
}

CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(registry, options));
CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(
registry, options, extension_options.max_precision));
CEL_RETURN_IF_ERROR(
(UnaryFunctionAdapter<StringValue, StringValue>::RegisterGlobalOverload(
"strings.quote", &Quote, registry)));
Expand Down Expand Up @@ -412,13 +414,16 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry,

absl::Status RegisterStringsFunctions(
google::api::expr::runtime::CelFunctionRegistry* registry,
const google::api::expr::runtime::InterpreterOptions& options) {
const google::api::expr::runtime::InterpreterOptions& options,
const StringsExtensionOptions& extension_options) {
return RegisterStringsFunctions(
registry->InternalGetRegistry(),
google::api::expr::runtime::ConvertToRuntimeOptions(options));
google::api::expr::runtime::ConvertToRuntimeOptions(options),
extension_options);
}

CheckerLibrary StringsCheckerLibrary(int version) {
CheckerLibrary StringsCheckerLibrary(const StringsExtensionOptions& options) {
const int version = options.version;
return {"strings", [version](TypeCheckerBuilder& builder) {
return RegisterStringsDecls(builder, version);
}};
Expand Down
34 changes: 29 additions & 5 deletions extensions/strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,45 @@ namespace cel::extensions {

constexpr int kStringsExtensionLatestVersion = 4;

struct StringsExtensionOptions {
int version = kStringsExtensionLatestVersion;

// Maximum precision allowed for floating point format specifiers in
// format() function. This is used for both fixed and scientific notations.
// Value must be in the range [0, 1000], otherwise clamped.
//
// Does not affect default precisions for %e and %f format specifiers.
int max_precision = 1000;
};

// Register extension functions for strings.
absl::Status RegisterStringsFunctions(
FunctionRegistry& registry, const RuntimeOptions& options,
int version = kStringsExtensionLatestVersion);
const StringsExtensionOptions& extension_options = {});

absl::Status RegisterStringsFunctions(
google::api::expr::runtime::CelFunctionRegistry* registry,
const google::api::expr::runtime::InterpreterOptions& options);
const google::api::expr::runtime::InterpreterOptions& options,
const StringsExtensionOptions& extension_options = {});

CheckerLibrary StringsCheckerLibrary(
int version = kStringsExtensionLatestVersion);
const StringsExtensionOptions& extension_options = {});

inline CheckerLibrary StringsCheckerLibrary(int version) {
StringsExtensionOptions options;
options.version = version;
return StringsCheckerLibrary(options);
}

inline CompilerLibrary StringsCompilerLibrary(
int version = kStringsExtensionLatestVersion) {
return CompilerLibrary::FromCheckerLibrary(StringsCheckerLibrary(version));
const StringsExtensionOptions& options = {}) {
return CompilerLibrary::FromCheckerLibrary(StringsCheckerLibrary(options));
}

inline CompilerLibrary StringsCompilerLibrary(int version) {
StringsExtensionOptions options;
options.version = version;
return StringsCompilerLibrary(options);
}

} // namespace cel::extensions
Expand Down
43 changes: 43 additions & 0 deletions extensions/strings_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ namespace cel::extensions {
namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::expr::ParsedExpr;
using ::google::api::expr::parser::Parse;
using ::google::api::expr::parser::ParserOptions;
Expand Down Expand Up @@ -85,6 +86,48 @@ TEST(StringsCheckerLibrary, SmokeTest) {
)~bool^equals)");
}

TEST(StringsExtTest, MaxPrecisionOption) {
StringsExtensionOptions extension_options;
extension_options.max_precision = 99;

ASSERT_OK_AND_ASSIGN(
auto compiler_builder,
NewCompilerBuilder(internal::GetTestingDescriptorPool()));

ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), IsOk());
ASSERT_THAT(compiler_builder->AddLibrary(StringsCompilerLibrary()), IsOk());

ASSERT_OK_AND_ASSIGN(auto compiler, compiler_builder->Build());

auto result = compiler->Compile("'abc %.100f'.format([2.0])", "<input>");

ASSERT_THAT(result, IsOk());
ASSERT_TRUE(result->IsValid());

RuntimeOptions opts;
ASSERT_OK_AND_ASSIGN(
auto runtime_builder,
CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts));

ASSERT_THAT(RegisterStringsFunctions(runtime_builder.function_registry(),
opts, extension_options),
IsOk());

ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build());

ASSERT_OK_AND_ASSIGN(auto program,
runtime->CreateProgram(*result->ReleaseAst()));

google::protobuf::Arena arena;
cel::Activation activation;
ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation));

ASSERT_TRUE(value.Is<ErrorValue>());
EXPECT_THAT(value.GetError().ToStatus(),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("precision specifier exceeds maximum of 99")));
}

using StringsExtFunctionsTest = testing::TestWithParam<std::string>;

TEST_P(StringsExtFunctionsTest, ParserAndCheckerTests) {
Expand Down