Skip to content

Commit

Permalink
Support for cel.@block during policy composition (#1056)
Browse files Browse the repository at this point in the history
* Runtime support for cel.@block
* Additional checks to prevent bad index specification
* Support for constant lists and extended validations
* Support for cel.@block during policy composition
  • Loading branch information
TristonianJones authored Nov 3, 2024
1 parent f9db1d6 commit 3f12eca
Show file tree
Hide file tree
Showing 11 changed files with 757 additions and 96 deletions.
8 changes: 7 additions & 1 deletion cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,14 +459,20 @@ func (e *Env) ParseSource(src Source) (*Ast, *Issues) {

// Program generates an evaluable instance of the Ast within the environment (Env).
func (e *Env) Program(ast *Ast, opts ...ProgramOption) (Program, error) {
return e.PlanProgram(ast.NativeRep(), opts...)
}

// PlanProgram generates an evaluable instance of the AST in the go-native representation within
// the environment (Env).
func (e *Env) PlanProgram(a *celast.AST, opts ...ProgramOption) (Program, error) {
optSet := e.progOpts
if len(opts) != 0 {
mergedOpts := []ProgramOption{}
mergedOpts = append(mergedOpts, e.progOpts...)
mergedOpts = append(mergedOpts, opts...)
optSet = mergedOpts
}
return newProgram(e, ast, optSet)
return newProgram(e, a, optSet)
}

// CELTypeAdapter returns the `types.Adapter` configured for the environment.
Expand Down
10 changes: 10 additions & 0 deletions cel/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,16 @@ type OptimizerContext struct {
*Issues
}

// ExtendEnv auguments the context's environment with the additional options.
func (opt *OptimizerContext) ExtendEnv(opts ...EnvOption) error {
e, err := opt.Env.Extend(opts...)
if err != nil {
return err
}
opt.Env = e
return nil
}

// ASTOptimizer applies an optimization over an AST and returns the optimized result.
type ASTOptimizer interface {
// Optimize optimizes a type-checked AST within an Environment and accumulates any issues.
Expand Down
7 changes: 4 additions & 3 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"sync"

"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
Expand Down Expand Up @@ -151,7 +152,7 @@ func (p *prog) clone() *prog {
// ProgramOption values.
//
// If the program cannot be configured the prog will be nil, with a non-nil error response.
func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
func newProgram(e *Env, a *ast.AST, opts []ProgramOption) (Program, error) {
// Build the dispatcher, interpreter, and default program value.
disp := interpreter.NewDispatcher()

Expand Down Expand Up @@ -255,9 +256,9 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
return p.initInterpretable(a, decorators)
}

func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
func (p *prog) initInterpretable(a *ast.AST, decs []interpreter.InterpretableDecorator) (*prog, error) {
// When the AST has been exprAST it contains metadata that can be used to speed up program execution.
interpretable, err := p.interpreter.NewInterpretable(a.impl, decs...)
interpretable, err := p.interpreter.NewInterpretable(a, decs...)
if err != nil {
return nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions conformance/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ _ALL_TESTS = [
"@dev_cel_expr//tests/simple:testdata/timestamps.textproto",
"@dev_cel_expr//tests/simple:testdata/unknowns.textproto",
"@dev_cel_expr//tests/simple:testdata/wrappers.textproto",
"@dev_cel_expr//tests/simple:testdata/block_ext.textproto",
]

_TESTS_TO_SKIP = [
Expand Down Expand Up @@ -68,6 +69,7 @@ go_test(
deps = [
"//cel:go_default_library",
"//common:go_default_library",
"//common/ast:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//ext:go_default_library",
Expand Down
88 changes: 88 additions & 0 deletions conformance/conformance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/ext"
Expand Down Expand Up @@ -89,6 +90,7 @@ func init() {
ext.Math(),
ext.Protos(),
ext.Strings(),
cel.Lib(celBlockLib{}),
}

var err error
Expand Down Expand Up @@ -279,3 +281,89 @@ func TestConformance(t *testing.T) {
}
}
}

type celBlockLib struct{}

func (celBlockLib) LibraryName() string {
return "cel.lib.ext.cel.block.conformance"
}

func (celBlockLib) CompileOptions() []cel.EnvOption {
// Simulate indexed arguments which would normally have strong types associated
// with the values as part of a static optimization pass
maxIndices := 30
indexOpts := make([]cel.EnvOption, maxIndices)
for i := 0; i < maxIndices; i++ {
indexOpts[i] = cel.Variable(fmt.Sprintf("@index%d", i), cel.DynType)
}
return append([]cel.EnvOption{
cel.Macros(
// cel.block([args], expr)
cel.ReceiverMacro("block", 2, celBlock),
// cel.index(int)
cel.ReceiverMacro("index", 1, celIndex),
// cel.iterVar(int, int)
cel.ReceiverMacro("iterVar", 2, celCompreVar("cel.iterVar", "@it")),
// cel.accuVar(int, int)
cel.ReceiverMacro("accuVar", 2, celCompreVar("cel.accuVar", "@ac")),
),
}, indexOpts...)
}

func (celBlockLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}

func celBlock(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !isCELNamespace(target) {
return nil, nil
}
bindings := args[0]
if bindings.Kind() != ast.ListKind {
return bindings, mef.NewError(bindings.ID(), "cel.block requires the first arg to be a list literal")
}
return mef.NewCall("cel.@block", args...), nil
}

func celIndex(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !isCELNamespace(target) {
return nil, nil
}
index := args[0]
if !isNonNegativeInt(index) {
return index, mef.NewError(index.ID(), "cel.index requires a single non-negative int constant arg")
}
indexVal := index.AsLiteral().(types.Int)
return mef.NewIdent(fmt.Sprintf("@index%d", indexVal)), nil
}

func celCompreVar(funcName, varPrefix string) cel.MacroFactory {
return func(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !isCELNamespace(target) {
return nil, nil
}
depth := args[0]
if !isNonNegativeInt(depth) {
return depth, mef.NewError(depth.ID(), fmt.Sprintf("%s requires two non-negative int constant args", funcName))
}
unique := args[1]
if !isNonNegativeInt(unique) {
return unique, mef.NewError(unique.ID(), fmt.Sprintf("%s requires two non-negative int constant args", funcName))
}
depthVal := depth.AsLiteral().(types.Int)
uniqueVal := unique.AsLiteral().(types.Int)
return mef.NewIdent(fmt.Sprintf("%s:%d:%d", varPrefix, depthVal, uniqueVal)), nil
}
}

func isCELNamespace(target ast.Expr) bool {
return target.Kind() == ast.IdentKind && target.AsIdent() == "cel"
}

func isNonNegativeInt(expr ast.Expr) bool {
if expr.Kind() != ast.LiteralKind {
return false
}
val := expr.AsLiteral()
return val.Type() == cel.IntType && val.(types.Int) >= 0
}
Loading

0 comments on commit 3f12eca

Please sign in to comment.