From 19757ad5540392348fc880c25a448b05c8f6ca64 Mon Sep 17 00:00:00 2001 From: DeadlySurgeon Date: Mon, 20 Nov 2023 06:53:58 -0600 Subject: [PATCH] Resolve ctx mixup --- evaluator/evaluator.go | 137 +++++++++++++++++------------------- evaluator/evaluator_test.go | 12 ++-- evaluator/stdlib_core.go | 3 +- 3 files changed, 73 insertions(+), 79 deletions(-) diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index eb78177..e65fc08 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -22,7 +22,6 @@ var ( TRUE = &object.Boolean{Value: true} FALSE = &object.Boolean{Value: false} PRAGMAS = make(map[string]int) - CTX = context.Background() ) // The built-in functions / standard-library methods are stored here. @@ -30,7 +29,7 @@ var builtins = map[string]*object.Builtin{} // Eval is our core function for evaluating nodes. func Eval(node ast.Node, env *object.Environment) object.Object { - return EvalContext(CTX, node, env) + return EvalContext(context.Background(), node, env) } // EvalContext is our core function for evaluating nodes. @@ -51,9 +50,9 @@ func EvalContext(ctx context.Context, node ast.Node, env *object.Environment) ob //Statements case *ast.Program: - return evalProgram(node, env) + return evalProgram(ctx, node, env) case *ast.ExpressionStatement: - return Eval(node.Expression, env) + return EvalContext(ctx, node.Expression, env) //Expressions case *ast.IntegerLiteral: @@ -65,7 +64,7 @@ func EvalContext(ctx context.Context, node ast.Node, env *object.Environment) ob case *ast.NullLiteral: return NULL case *ast.PrefixExpression: - right := Eval(node.Right, env) + right := EvalContext(ctx, node.Right, env) if isError(right) { return right } @@ -73,11 +72,11 @@ func EvalContext(ctx context.Context, node ast.Node, env *object.Environment) ob case *ast.PostfixExpression: return evalPostfixExpression(env, node.Operator, node) case *ast.InfixExpression: - left := Eval(node.Left, env) + left := EvalContext(ctx, node.Left, env) if isError(left) { return left } - right := Eval(node.Right, env) + right := EvalContext(ctx, node.Right, env) if isError(right) { return right } @@ -91,30 +90,30 @@ func EvalContext(ctx context.Context, node ast.Node, env *object.Environment) ob return (res) case *ast.BlockStatement: - return evalBlockStatement(node, env) + return evalBlockStatement(ctx, node, env) case *ast.IfExpression: - return evalIfExpression(node, env) + return evalIfExpression(ctx, node, env) case *ast.TernaryExpression: - return evalTernaryExpression(node, env) + return evalTernaryExpression(ctx, node, env) case *ast.ForLoopExpression: - return evalForLoopExpression(node, env) + return evalForLoopExpression(ctx, node, env) case *ast.ForeachStatement: - return evalForeachExpression(node, env) + return evalForeachExpression(ctx, node, env) case *ast.ReturnStatement: - val := Eval(node.ReturnValue, env) + val := EvalContext(ctx, node.ReturnValue, env) if isError(val) { return val } return &object.ReturnValue{Value: val} case *ast.LetStatement: - val := Eval(node.Value, env) + val := EvalContext(ctx, node.Value, env) if isError(val) { return val } env.Set(node.Name.Value, val) return val case *ast.ConstStatement: - val := Eval(node.Value, env) + val := EvalContext(ctx, node.Value, env) if isError(val) { return val } @@ -134,7 +133,7 @@ func EvalContext(ctx context.Context, node ast.Node, env *object.Environment) ob env.Set(node.TokenLiteral(), &object.Function{Parameters: params, Env: env, Body: body, Defaults: defaults}) return NULL case *ast.ObjectCallExpression: - res := evalObjectCallExpression(node, env) + res := evalObjectCallExpression(ctx, node, env) if isError(res) { fmt.Fprintf(os.Stderr, "Error calling object-method %s\n", res.Inspect()) if PRAGMAS["strict"] == 1 { @@ -143,15 +142,15 @@ func EvalContext(ctx context.Context, node ast.Node, env *object.Environment) ob } return res case *ast.CallExpression: - function := Eval(node.Function, env) + function := EvalContext(ctx, node.Function, env) if isError(function) { return function } - args := evalExpression(node.Arguments, env) + args := evalExpression(ctx, node.Arguments, env) if len(args) == 1 && isError(args[0]) { return args[0] } - res := applyFunction(env, function, args) + res := applyFunction(ctx, env, function, args) if isError(res) { fmt.Fprintf(os.Stderr, "Error calling `%s` : %s\n", node.Function, res.Inspect()) if PRAGMAS["strict"] == 1 { @@ -162,7 +161,7 @@ func EvalContext(ctx context.Context, node ast.Node, env *object.Environment) ob return res case *ast.ArrayLiteral: - elements := evalExpression(node.Elements, env) + elements := evalExpression(ctx, node.Elements, env) if len(elements) == 1 && isError(elements[0]) { return elements[0] } @@ -174,30 +173,30 @@ func EvalContext(ctx context.Context, node ast.Node, env *object.Environment) ob case *ast.BacktickLiteral: return backTickOperation(node.Value) case *ast.IndexExpression: - left := Eval(node.Left, env) + left := EvalContext(ctx, node.Left, env) if isError(left) { return left } - index := Eval(node.Index, env) + index := EvalContext(ctx, node.Index, env) if isError(index) { return index } return evalIndexExpression(left, index) case *ast.AssignStatement: - return evalAssignStatement(node, env) + return evalAssignStatement(ctx, node, env) case *ast.HashLiteral: - return evalHashLiteral(node, env) + return evalHashLiteral(ctx, node, env) case *ast.SwitchExpression: - return evalSwitchStatement(node, env) + return evalSwitchStatement(ctx, node, env) } return nil } // eval block statement -func evalBlockStatement(block *ast.BlockStatement, env *object.Environment) object.Object { +func evalBlockStatement(ctx context.Context, block *ast.BlockStatement, env *object.Environment) object.Object { var result object.Object for _, statement := range block.Statements { - result = Eval(statement, env) + result = EvalContext(ctx, statement, env) if result != nil { rt := result.Type() if rt == object.RETURN_VALUE_OBJ || rt == object.ERROR_OBJ { @@ -613,7 +612,7 @@ func evalStringInfixExpression(operator string, left, right object.Object) objec // evalIfExpression handles an `if` expression, running the block // if the condition matches, and running any optional else block // otherwise. -func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Object { +func evalIfExpression(ctx context.Context, ie *ast.IfExpression, env *object.Environment) object.Object { // // Create an environment for handling regexps // @@ -624,14 +623,14 @@ func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Obje i++ } nEnv := object.NewTemporaryScope(env, permit) - condition := Eval(ie.Condition, nEnv) + condition := EvalContext(ctx, ie.Condition, nEnv) if isError(condition) { return condition } if isTruthy(condition) { - return Eval(ie.Consequence, nEnv) + return EvalContext(ctx, ie.Consequence, nEnv) } else if ie.Alternative != nil { - return Eval(ie.Alternative, nEnv) + return EvalContext(ctx, ie.Alternative, nEnv) } else { return NULL } @@ -641,21 +640,21 @@ func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Obje // is true we return the contents of evaluating the true-branch, otherwise // the false-branch. (Unlike an `if` statement we know that we always have // an alternative/false branch.) -func evalTernaryExpression(te *ast.TernaryExpression, env *object.Environment) object.Object { +func evalTernaryExpression(ctx context.Context, te *ast.TernaryExpression, env *object.Environment) object.Object { - condition := Eval(te.Condition, env) + condition := EvalContext(ctx, te.Condition, env) if isError(condition) { return condition } if isTruthy(condition) { - return Eval(te.IfTrue, env) + return EvalContext(ctx, te.IfTrue, env) } - return Eval(te.IfFalse, env) + return EvalContext(ctx, te.IfFalse, env) } -func evalAssignStatement(a *ast.AssignStatement, env *object.Environment) (val object.Object) { - evaluated := Eval(a.Value, env) +func evalAssignStatement(ctx context.Context, a *ast.AssignStatement, env *object.Environment) (val object.Object) { + evaluated := EvalContext(ctx, a.Value, env) if isError(evaluated) { return evaluated } @@ -754,10 +753,10 @@ func evalAssignStatement(a *ast.AssignStatement, env *object.Environment) (val o return evaluated } -func evalSwitchStatement(se *ast.SwitchExpression, env *object.Environment) object.Object { +func evalSwitchStatement(ctx context.Context, se *ast.SwitchExpression, env *object.Environment) object.Object { // Get the value. - obj := Eval(se.Value, env) + obj := EvalContext(ctx, se.Value, env) // Try all the choices for _, opt := range se.Choices { @@ -772,14 +771,14 @@ func evalSwitchStatement(se *ast.SwitchExpression, env *object.Environment) obje for _, val := range opt.Expr { // Get the value of the case - out := Eval(val, env) + out := EvalContext(ctx, val, env) // Is it a literal match? if obj.Type() == out.Type() && (obj.Inspect() == out.Inspect()) { // Evaluate the block and return the value - blockOut := evalBlockStatement(opt.Block, env) + blockOut := evalBlockStatement(ctx, opt.Block, env) return blockOut } @@ -790,7 +789,7 @@ func evalSwitchStatement(se *ast.SwitchExpression, env *object.Environment) obje if m == TRUE { // Evaluate the block and return the value - out := evalBlockStatement(opt.Block, env) + out := evalBlockStatement(ctx, opt.Block, env) return out } @@ -804,7 +803,7 @@ func evalSwitchStatement(se *ast.SwitchExpression, env *object.Environment) obje // skip default if opt.Default { - out := evalBlockStatement(opt.Block, env) + out := evalBlockStatement(ctx, opt.Block, env) return out } } @@ -812,15 +811,15 @@ func evalSwitchStatement(se *ast.SwitchExpression, env *object.Environment) obje return nil } -func evalForLoopExpression(fle *ast.ForLoopExpression, env *object.Environment) object.Object { +func evalForLoopExpression(ctx context.Context, fle *ast.ForLoopExpression, env *object.Environment) object.Object { rt := &object.Boolean{Value: true} for { - condition := Eval(fle.Condition, env) + condition := EvalContext(ctx, fle.Condition, env) if isError(condition) { return condition } if isTruthy(condition) { - rt := Eval(fle.Consequence, env) + rt := EvalContext(ctx, fle.Consequence, env) if !isError(rt) && (rt.Type() == object.RETURN_VALUE_OBJ || rt.Type() == object.ERROR_OBJ) { return rt } @@ -832,10 +831,10 @@ func evalForLoopExpression(fle *ast.ForLoopExpression, env *object.Environment) } // handle "for x [,y] in .." -func evalForeachExpression(fle *ast.ForeachStatement, env *object.Environment) object.Object { +func evalForeachExpression(ctx context.Context, fle *ast.ForeachStatement, env *object.Environment) object.Object { // expression - val := Eval(fle.Value, env) + val := EvalContext(ctx, fle.Value, env) helper, ok := val.(object.Iterable) if !ok { @@ -872,7 +871,7 @@ func evalForeachExpression(fle *ast.ForeachStatement, env *object.Environment) o } // Eval the block - rt := Eval(fle.Body, child) + rt := EvalContext(ctx, fle.Body, child) // // If we got an error/return then we handle it. @@ -901,10 +900,10 @@ func isTruthy(obj object.Object) bool { } } -func evalProgram(program *ast.Program, env *object.Environment) object.Object { +func evalProgram(ctx context.Context, program *ast.Program, env *object.Environment) object.Object { var result object.Object for _, statement := range program.Statements { - result = Eval(statement, env) + result = EvalContext(ctx, statement, env) switch result := result.(type) { case *object.ReturnValue: return result.Value @@ -940,10 +939,10 @@ func evalIdentifier(node *ast.Identifier, env *object.Environment) object.Object return newError("identifier not found: " + node.Value) } -func evalExpression(exps []ast.Expression, env *object.Environment) []object.Object { +func evalExpression(ctx context.Context, exps []ast.Expression, env *object.Environment) []object.Object { var result []object.Object for _, e := range exps { - evaluated := Eval(e, env) + evaluated := EvalContext(ctx, e, env) if isError(evaluated) { return []object.Object{evaluated} } @@ -1090,10 +1089,10 @@ func evalStringIndexExpression(input, index object.Object) object.Object { return &object.String{Value: string(ret)} } -func evalHashLiteral(node *ast.HashLiteral, env *object.Environment) object.Object { +func evalHashLiteral(ctx context.Context, node *ast.HashLiteral, env *object.Environment) object.Object { pairs := make(map[object.HashKey]object.HashPair) for keyNode, valueNode := range node.Pairs { - key := Eval(keyNode, env) + key := EvalContext(ctx, keyNode, env) if isError(key) { return key } @@ -1101,7 +1100,7 @@ func evalHashLiteral(node *ast.HashLiteral, env *object.Environment) object.Obje if !ok { return newError("unusable as hash key: %s", key.Type()) } - value := Eval(valueNode, env) + value := EvalContext(ctx, valueNode, env) if isError(value) { return value } @@ -1113,11 +1112,11 @@ func evalHashLiteral(node *ast.HashLiteral, env *object.Environment) object.Obje } -func applyFunction(env *object.Environment, fn object.Object, args []object.Object) object.Object { +func applyFunction(ctx context.Context, env *object.Environment, fn object.Object, args []object.Object) object.Object { switch fn := fn.(type) { case *object.Function: - extendEnv := extendFunctionEnv(fn, args) - evaluated := Eval(fn.Body, extendEnv) + extendEnv := extendFunctionEnv(ctx, fn, args) + evaluated := EvalContext(ctx, fn.Body, extendEnv) return upwrapReturnValue(evaluated) case *object.Builtin: return fn.Fn(env, args...) @@ -1127,12 +1126,12 @@ func applyFunction(env *object.Environment, fn object.Object, args []object.Obje } -func extendFunctionEnv(fn *object.Function, args []object.Object) *object.Environment { +func extendFunctionEnv(ctx context.Context, fn *object.Function, args []object.Object) *object.Environment { env := object.NewEnclosedEnvironment(fn.Env) // Set the defaults for key, val := range fn.Defaults { - env.Set(key, Eval(val, env)) + env.Set(key, EvalContext(ctx, val, env)) } for paramIdx, param := range fn.Parameters { if paramIdx < len(args) { @@ -1155,16 +1154,10 @@ func RegisterBuiltin(name string, fun object.BuiltinFunction) { builtins[name] = &object.Builtin{Fn: fun} } -// SetContext lets you configure a context, which is helpful if you wish to -// cause execution to timeout after a given period, for example. -func SetContext(ctx context.Context) { - CTX = ctx -} - // evalObjectCallExpression invokes methods against objects. -func evalObjectCallExpression(call *ast.ObjectCallExpression, env *object.Environment) object.Object { +func evalObjectCallExpression(ctx context.Context, call *ast.ObjectCallExpression, env *object.Environment) object.Object { - obj := Eval(call.Object, env) + obj := EvalContext(ctx, call.Object, env) if method, ok := call.Call.(*ast.CallExpression); ok { // @@ -1174,7 +1167,7 @@ func evalObjectCallExpression(call *ast.ObjectCallExpression, env *object.Enviro // We do this by forwarding the call to the appropriate // `invokeMethod` interface on the object. // - args := evalExpression(call.Call.(*ast.CallExpression).Arguments, env) + args := evalExpression(ctx, call.Call.(*ast.CallExpression).Arguments, env) ret := obj.InvokeMethod(method.Function.String(), *env, args...) if ret != nil { return ret @@ -1235,7 +1228,7 @@ func evalObjectCallExpression(call *ast.ObjectCallExpression, env *object.Enviro // // Extend our environment with the functional-args. // - extendEnv := extendFunctionEnv(fn.(*object.Function), args) + extendEnv := extendFunctionEnv(ctx, fn.(*object.Function), args) // // Now set "self" to be the implicit object, against @@ -1246,7 +1239,7 @@ func evalObjectCallExpression(call *ast.ObjectCallExpression, env *object.Enviro // // Finally invoke & return. // - evaluated := Eval(fn.(*object.Function).Body, extendEnv) + evaluated := EvalContext(ctx, fn.(*object.Function).Body, extendEnv) obj = upwrapReturnValue(evaluated) return obj } diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index 757efbc..9febf82 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -58,9 +58,8 @@ func testEval(input string) object.Object { ctx, cancel := context.WithTimeout(context.Background(), 5000*time.Millisecond) defer cancel() - SetContext(ctx) - return Eval(program, env) + return EvalContext(ctx, program, env) } func testDecimalObject(t *testing.T, obj object.Object, expected interface{}) bool { @@ -638,15 +637,16 @@ for ( true ) { i++; } ` - ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond) - defer cancel() l := lexer.New(input) p := parser.New(l) program := p.ParseProgram() env := object.NewEnvironment() - SetContext(ctx) - evaluated := Eval(program, env) + + ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond) + defer cancel() + + evaluated := EvalContext(ctx, program, env) errObj, ok := evaluated.(*object.Error) if !ok { diff --git a/evaluator/stdlib_core.go b/evaluator/stdlib_core.go index 21593ac..f59293d 100644 --- a/evaluator/stdlib_core.go +++ b/evaluator/stdlib_core.go @@ -1,6 +1,7 @@ package evaluator import ( + "context" "fmt" "os" "regexp" @@ -65,7 +66,7 @@ func evalFun(env *object.Environment, args ...object.Object) object.Object { program := p.ParseProgram() if len(p.Errors()) == 0 { // evaluate it, and return the output. - return (Eval(program, env)) + return EvalContext(context.Background(), program, env) } // Otherwise abort. We should have try { } catch