diff --git a/README.md b/README.md index 84deabd..8cdc9f1 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,8 @@ documentation is representative of the current pql api. - [`as`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/as-operator) - [`count`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/count-operator) - [`join`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/join-operator) +- [`let` statements](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/let-statement), + but only scalar expressions are supported. - [`project`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/project-operator) - [`extend`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/extend-operator) - [`sort`/`order`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/sort-operator) diff --git a/cmd/pql/main.go b/cmd/pql/main.go index bd86cd6..45e85c8 100644 --- a/cmd/pql/main.go +++ b/cmd/pql/main.go @@ -70,6 +70,7 @@ func run(ctx context.Context, output io.Writer, input io.Reader, logError func(e } var finalError error + letStatements := new(strings.Builder) for scanner.Scan() { sb.Write(scanner.Bytes()) sb.WriteByte('\n') @@ -80,7 +81,20 @@ func run(ctx context.Context, output io.Writer, input io.Reader, logError func(e } for _, stmt := range statements[:len(statements)-1] { - sql, err := pql.Compile(stmt) + // Valid let statements are appended to an ongoing prelude, which is prepended to every later statement. 
+ tokens := parser.Scan(stmt) + if len(tokens) > 0 && tokens[0].Kind == parser.TokenIdentifier && tokens[0].Value == "let" { + if _, err := pql.Compile(letStatements.String() + stmt + ";X"); err != nil { + logError(err) + finalError = errors.New("one or more statements could not be compiled") + } else { + letStatements.WriteString(stmt) + letStatements.WriteString(";\n") + } + continue + } + + sql, err := pql.Compile(letStatements.String() + stmt) if err != nil { logError(err) finalError = errors.New("one or more statements could not be compiled") diff --git a/parser/ast.go b/parser/ast.go index 5b3452a..a2f8e4e 100644 --- a/parser/ast.go +++ b/parser/ast.go @@ -72,12 +72,20 @@ func (id *QualifiedIdent) Span() Span { func (id *QualifiedIdent) expression() {} +type Statement interface { + Node + statement() +} + // TabularExpr is a query expression that produces a table. +// It implements [Statement]. type TabularExpr struct { Source TabularDataSource Operators []TabularOperator } +func (x *TabularExpr) statement() {} + func (x *TabularExpr) Span() Span { if x == nil { return nullSpan() @@ -547,6 +555,29 @@ func (idx *IndexExpr) Span() Span { func (idx *IndexExpr) expression() {} +// A LetStatement node represents a let statement, +// assigning an expression to a name. +// It implements [Statement]. +type LetStatement struct { + Keyword Span + Name *Ident + Assign Span + X Expr +} + +func (stmt *LetStatement) statement() {} + +func (stmt *LetStatement) Span() Span { + if stmt == nil { + return nullSpan() + } + xSpan := nullSpan() + if stmt.X != nil { + xSpan = stmt.X.Span() + } + return unionSpans(stmt.Keyword, stmt.Name.Span(), stmt.Assign, xSpan) +} + // Walk traverses an AST in depth-first order. // If the visit function returns true for a node, // the visit function will be called for its children. 
@@ -684,6 +715,11 @@ func Walk(n Node, visit func(n Node) bool) { stack = append(stack, n.Index) stack = append(stack, n.X) } + case *LetStatement: + if visit(n) { + stack = append(stack, n.X) + stack = append(stack, n.Name) + } default: panic(fmt.Errorf("unknown Node type %T", n)) } diff --git a/parser/parser.go b/parser/parser.go index 96b4259..8f34b33 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -21,40 +21,119 @@ type parser struct { splitKind TokenKind } -// Parse converts a Pipeline Query Language tabular expression +// Parse converts a Pipeline Query Language query // into an Abstract Syntax Tree (AST). -func Parse(query string) (*TabularExpr, error) { +func Parse(query string) ([]Statement, error) { p := &parser{ source: query, tokens: Scan(query), } - expr, err := p.tabularExpr() - if p.pos < len(p.tokens) { - trailingToken := p.tokens[p.pos] - if trailingToken.Kind == TokenError { - err = joinErrors(err, &parseError{ - source: p.source, - span: trailingToken.Span, - err: errors.New(trailingToken.Value), - }) + var result []Statement + var resultError error + for { + stmtParser := p.splitSemi() + + stmt, err := firstParse( + func() (Statement, error) { + stmt, err := stmtParser.letStatement() + if stmt == nil { + // Prevent returning a non-nil interface. + return nil, err + } + return stmt, err + }, + func() (Statement, error) { + expr, err := stmtParser.tabularExpr() + if expr == nil { + // Prevent returning a non-nil interface. + return nil, err + } + return expr, err + }, + ) + + if isNotFound(err) { + // We're okay with empty statements, we just ignore them. 
+ if stmtParser.pos < len(stmtParser.tokens) { + trailingToken := stmtParser.tokens[stmtParser.pos] + if trailingToken.Kind == TokenError { + resultError = joinErrors(err, &parseError{ + source: p.source, + span: trailingToken.Span, + err: errors.New(trailingToken.Value), + }) + } else { + resultError = joinErrors(err, &parseError{ + source: p.source, + span: trailingToken.Span, + err: errors.New("unrecognized token"), + }) + } + } } else { - err = joinErrors(err, &parseError{ - source: p.source, - span: trailingToken.Span, - err: errors.New("unrecognized token"), - }) + if stmt != nil { + result = append(result, stmt) + } + resultError = joinErrors(resultError, makeErrorOpaque(err)) + resultError = joinErrors(resultError, stmtParser.endSplit()) + } + + // Next token, if present, guaranteed to be a semicolon. + if _, ok := p.next(); !ok { + break } - } else if isNotFound(err) { - err = &parseError{ + } + + if resultError != nil { + return result, fmt.Errorf("parse pipeline query language: %w", resultError) + } + return result, nil +} + +func firstParse[T any](productions ...func() (T, error)) (T, error) { + for _, p := range productions[:len(productions)-1] { + x, err := p() + if !isNotFound(err) { + return x, err + } + } + return productions[len(productions)-1]() +} + +func (p *parser) letStatement() (*LetStatement, error) { + keyword, _ := p.next() + if keyword.Kind != TokenIdentifier || keyword.Value != "let" { + p.prev() + return nil, &parseError{ + source: p.source, + span: keyword.Span, + err: notFoundError{fmt.Errorf("expected 'let', got %s", formatToken(p.source, keyword))}, + } + } + + stmt := &LetStatement{ + Keyword: keyword.Span, + Assign: nullSpan(), + } + var err error + stmt.Name, err = p.ident() + if err != nil { + return stmt, makeErrorOpaque(err) + } + assign, _ := p.next() + if assign.Kind != TokenAssign { + return stmt, &parseError{ source: p.source, - span: indexSpan(len(query)), - err: errors.New("empty query"), + span: assign.Span, + err: 
fmt.Errorf("expected '=', got %s", formatToken(p.source, assign)), } } + stmt.Assign = assign.Span + stmt.X, err = p.expr() if err != nil { - return expr, fmt.Errorf("parse pipeline query language: %w", err) + return stmt, makeErrorOpaque(err) } - return expr, nil + return stmt, nil } func (p *parser) tabularExpr() (*TabularExpr, error) { @@ -294,27 +373,10 @@ func (p *parser) takeOperator(pipe, keyword Token) (*TakeOperator, error) { Pipe: pipe.Span, Keyword: keyword.Span, } - - tok, _ := p.next() - if tok.Kind != TokenNumber { - return op, &parseError{ - source: p.source, - span: tok.Span, - err: fmt.Errorf("expected integer, got %s", formatToken(p.source, tok)), - } - } - rowCount := &BasicLit{ - Kind: tok.Kind, - Value: tok.Value, - ValueSpan: tok.Span, - } - op.RowCount = rowCount - if !rowCount.IsInteger() { - return op, &parseError{ - source: p.source, - span: tok.Span, - err: fmt.Errorf("expected integer, got %s", formatToken(p.source, tok)), - } + var err error + op.RowCount, err = p.rowCount() + if err != nil { + return op, makeErrorOpaque(err) } return op, nil } @@ -326,30 +388,13 @@ func (p *parser) topOperator(pipe, keyword Token) (*TopOperator, error) { By: nullSpan(), } - tok, _ := p.next() - if tok.Kind != TokenNumber { - p.prev() - return op, &parseError{ - source: p.source, - span: tok.Span, - err: fmt.Errorf("expected integer, got %s", formatToken(p.source, tok)), - } - } - rowCount := &BasicLit{ - Kind: tok.Kind, - Value: tok.Value, - ValueSpan: tok.Span, - } - op.RowCount = rowCount - if !rowCount.IsInteger() { - return op, &parseError{ - source: p.source, - span: tok.Span, - err: fmt.Errorf("expected integer, got %s", formatToken(p.source, tok)), - } + var err error + op.RowCount, err = p.rowCount() + if err != nil { + return op, makeErrorOpaque(err) } - tok, _ = p.next() + tok, _ := p.next() if tok.Kind != TokenBy { p.prev() return op, &parseError{ @@ -360,11 +405,28 @@ func (p *parser) topOperator(pipe, keyword Token) (*TopOperator, error) { 
} op.By = tok.Span - var err error op.Col, err = p.sortTerm() return op, makeErrorOpaque(err) } +func (p *parser) rowCount() (Expr, error) { + x, err := p.expr() + if err != nil { + return x, err + } + if lit, ok := x.(*BasicLit); ok { + // Do basic check for common case of literals. + if !lit.IsInteger() { + return x, fmt.Errorf("expected integer, got %s", formatToken(p.source, Token{ + Kind: lit.Kind, + Span: lit.ValueSpan, + Value: lit.Value, + })) + } + } + return x, nil +} + func (p *parser) projectOperator(pipe, keyword Token) (*ProjectOperator, error) { op := &ProjectOperator{ Pipe: pipe.Span, @@ -1042,7 +1104,9 @@ func (p *parser) qualifiedIdent() (*QualifiedIdent, error) { // split advances the parser to right before the next token of the given kind, // and returns a new parser that reads the tokens that were skipped over. // It ignores tokens that are in parenthetical groups after the initial parse position. -// If no such token is found, skipTo advances to EOF. +// If no such token is found, split advances to EOF. +// +// For splitting by semicolon, see [*parser.splitSemi]. func (p *parser) split(search TokenKind) *parser { // stack is the list of expected closing parentheses/brackets. // When a closing parenthesis/bracket is encountered, @@ -1103,6 +1167,31 @@ loop: } } +// splitSemi advances the parser to right before the next semicolon, +// and returns a new parser that reads the tokens that were skipped over. +// If no semicolon is found, splitSemi advances to EOF. +func (p *parser) splitSemi() *parser { + start := p.pos + for { + tok, ok := p.next() + if !ok { + return &parser{ + source: p.source, + tokens: p.tokens[start:], + splitKind: TokenSemi, + } + } + if tok.Kind == TokenSemi { + p.prev() + return &parser{ + source: p.source, + tokens: p.tokens[start:p.pos], + splitKind: TokenSemi, + } + } + } +} + func (p *parser) endSplit() error { if p.splitKind == 0 { // This is a bug, but treating as an error instead of panicing. 
diff --git a/parser/parser_test.go b/parser/parser_test.go index cc1d453..794df5f 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -13,14 +13,13 @@ import ( var parserTests = []struct { name string query string - want *TabularExpr + want []Statement err bool }{ { name: "Empty", query: "", want: nil, - err: true, }, { name: "BadToken", @@ -31,19 +30,19 @@ var parserTests = []struct { { name: "OnlyTableName", query: "StormEvents", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", NameSpan: newSpan(0, 11), }, }, - }, + }}, }, { name: "OnlyQuotedTableName", query: "`StormEvents`", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -51,12 +50,12 @@ var parserTests = []struct { Quoted: true, }, }, - }, + }}, }, { name: "PipeCount", query: "StormEvents | count", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -69,12 +68,12 @@ var parserTests = []struct { Keyword: newSpan(14, 19), }, }, - }, + }}, }, { name: "DoublePipeCount", query: "StormEvents | count | count", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -91,12 +90,12 @@ var parserTests = []struct { Keyword: newSpan(22, 27), }, }, - }, + }}, }, { name: "WhereTrue", query: "StormEvents | where true", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -113,12 +112,12 @@ var parserTests = []struct { }).AsQualified(), }, }, - }, + }}, }, { name: "NegativeNumber", query: "StormEvents | where -42", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -140,12 +139,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "ZeroArgFunction", query: `StormEvents | where rand()`, - want: &TabularExpr{ + want: 
[]Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -166,13 +165,13 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "ZeroArgFunctionWithTrailingComma", query: `StormEvents | where rand(,)`, err: true, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -193,12 +192,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "OneArgFunction", query: "StormEvents | where not(false)", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -225,12 +224,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "TwoArgFunction", query: `StormEvents | where strcat("abc", "def")`, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -263,12 +262,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "TwoArgFunctionWithTrailingComma", query: `StormEvents | where strcat("abc", "def",)`, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -301,13 +300,13 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "TwoArgFunctionWithTwoTrailingCommas", query: `StormEvents | where strcat("abc", "def",,)`, err: true, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -340,13 +339,13 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "ExtraContentInCount", query: `StormEvents | count x | where true`, err: true, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -367,12 +366,12 @@ var parserTests = []struct { }).AsQualified(), }, }, - }, + }}, }, { name: "BinaryOp", query: "StormEvents | where DamageProperty > 0", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ 
-398,12 +397,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "ComparisonWithSamePrecedenceLHS", query: "foo | where x / y * z == 1", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "foo", @@ -445,12 +444,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "ParenthesizedExpr", query: "foo | where x / (y * z) == 1", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "foo", @@ -496,12 +495,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "OperatorPrecedence", query: "foo | where 2 + 3 * 4 + 5 == 19", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "foo", @@ -555,12 +554,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "In", query: `StormEvents | where State in ("GEORGIA", "MISSISSIPPI")`, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -594,12 +593,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "InAnd", query: `StormEvents | where State in ("GEORGIA", "MISSISSIPPI") and DamageProperty > 10000`, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -650,12 +649,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "InAndFlipped", query: `StormEvents | where DamageProperty > 10000 and State in ("GEORGIA", "MISSISSIPPI")`, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -706,12 +705,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "MapKey", query: `tab | where mapcol['strkey'] == 42`, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "tab", @@ -746,13 +745,13 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "MapKeyTrailingExpression", query: `tab | where mapcol['strkey' x] 
== 42`, err: true, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "tab", @@ -787,13 +786,13 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "BadArgument", query: "foo | where strcat('a', .bork, 'x', 'y')", err: true, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "foo", @@ -821,13 +820,13 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "BadParentheticalExpr", query: "foo | where (.bork) + 2", err: true, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "foo", @@ -853,12 +852,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "SortBy", query: "foo | sort by bar", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "foo", @@ -881,12 +880,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "OrderBy", query: "foo | order by bar", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "foo", @@ -909,12 +908,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "SortByTake", query: "foo | sort by bar | take 1", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "foo", @@ -946,12 +945,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "SortByMultiple", query: "StormEvents | sort by State asc, StartTime desc", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -986,12 +985,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "SortByNullsFirst", query: "foo | sort by bar nulls first", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "foo", @@ -1015,12 +1014,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "Take", query: "StormEvents | take 5", - want: &TabularExpr{ + want: 
[]Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1038,12 +1037,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "Limit", query: "StormEvents | limit 5", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1061,12 +1060,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "Project", query: "StormEvents | project EventId, State, EventType", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1102,13 +1101,13 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "ProjectError", query: "StormEvents | project EventId=1 State", err: true, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1138,12 +1137,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "ProjectExpr", query: "StormEvents | project TotalInjuries = InjuriesDirect + InjuriesIndirect", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1177,12 +1176,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "ExtendExpr", query: "StormEvents | extend TotalInjuries = InjuriesDirect + InjuriesIndirect", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1216,12 +1215,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "ExtendExprOnly", query: "StormEvents | extend InjuriesDirect + InjuriesIndirect", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1251,12 +1250,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "ExtendExprMultiple", query: "StormEvents | extend TotalInjuries = InjuriesDirect + InjuriesIndirect, Duration = EndTime - StartTime", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: 
&TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1309,13 +1308,13 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "ExtendError", query: "StormEvents | extend FooFooF=1 State", err: true, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1345,12 +1344,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "UniqueCombination", query: "StormEvents | summarize by State, EventType", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1380,12 +1379,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "MinAndMax", query: "StormEvents | summarize Min = min(Duration), Max = max(Duration)", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1439,12 +1438,12 @@ var parserTests = []struct { By: nullSpan(), }, }, - }, + }}, }, { name: "DistinctCount", query: "StormEvents | summarize TypesOfStorms=dcount(EventType) by State", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1488,13 +1487,13 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "ShortSummarize", query: "StormEvents | summarize", err: true, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1508,13 +1507,13 @@ var parserTests = []struct { By: nullSpan(), }, }, - }, + }}, }, { name: "SummarizeByTerminated", query: "StormEvents | summarize by", err: true, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1528,13 +1527,13 @@ var parserTests = []struct { By: newSpan(24, 26), }, }, - }, + }}, }, { name: "SummarizeRandomToken", query: "StormEvents | summarize and", err: true, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ 
-1548,12 +1547,12 @@ var parserTests = []struct { By: nullSpan(), }, }, - }, + }}, }, { name: "Top", query: "StormEvents | top 3 by InjuriesDirect", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "StormEvents", @@ -1580,12 +1579,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "Join", query: "X | join (Y) on Key", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "X", @@ -1619,12 +1618,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "JoinLeft", query: "X | join kind=leftouter (Y) on Key", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "X", @@ -1662,13 +1661,13 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "JoinBadFlavor", query: "X | join kind=salt (Y) on Key", err: true, - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "X", @@ -1706,12 +1705,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "JoinComplexRight", query: "X | join (Y | where z == 5) on Key", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "X", @@ -1764,12 +1763,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "JoinExplicitCondition", query: "X | join (Y) on $left.Key == $right.Key", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "X", @@ -1827,12 +1826,12 @@ var parserTests = []struct { }, }, }, - }, + }}, }, { name: "JoinAndCount", query: "X | join (Y) on Key | count", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "X", @@ -1870,12 +1869,12 @@ var parserTests = []struct { Keyword: newSpan(22, 27), }, }, - }, + }}, }, { name: "As", query: "X | as Y", - want: &TabularExpr{ + want: []Statement{&TabularExpr{ Source: &TableRef{ Table: &Ident{ Name: "X", @@ -1892,6 +1891,43 @@ var parserTests 
= []struct { }, }, }, + }}, + }, + { + name: "Let", + query: "let n = 10; Events | take n", + want: []Statement{ + &LetStatement{ + Keyword: newSpan(0, 3), + Name: &Ident{ + Name: "n", + NameSpan: newSpan(4, 5), + }, + Assign: newSpan(6, 7), + X: &BasicLit{ + Kind: TokenNumber, + ValueSpan: newSpan(8, 10), + Value: "10", + }, + }, + &TabularExpr{ + Source: &TableRef{ + Table: &Ident{ + Name: "Events", + NameSpan: newSpan(12, 18), + }, + }, + Operators: []TabularOperator{ + &TakeOperator{ + Pipe: newSpan(19, 20), + Keyword: newSpan(21, 25), + RowCount: (&Ident{ + Name: "n", + NameSpan: newSpan(26, 27), + }).AsQualified(), + }, + }, + }, }, }, } diff --git a/pql.go b/pql.go index d7add43..743eea3 100644 --- a/pql.go +++ b/pql.go @@ -32,10 +32,55 @@ type CompileOptions struct { // Compile converts the given Pipeline Query Language statement // into the equivalent SQL. func (opts *CompileOptions) Compile(source string) (string, error) { - expr, err := parser.Parse(source) + stmts, err := parser.Parse(source) if err != nil { return "", err } + var expr *parser.TabularExpr + scope := make(map[string]string) + if opts != nil { + for k, v := range opts.Parameters { + scope[k] = v + } + } + for _, stmt := range stmts { + switch stmt := stmt.(type) { + case *parser.TabularExpr: + if expr != nil { + return "", &compileError{ + source: source, + span: stmt.Span(), + err: fmt.Errorf("batch queries not supported"), + } + } + expr = stmt + case *parser.LetStatement: + if expr != nil { + // Skip let statements after the query: + // they should not be in scope. 
+ continue + } + ctx := &exprContext{ + source: source, + scope: scope, + mode: letExprMode, + } + sb := new(strings.Builder) + if err := writeExpressionMaybeParen(ctx, sb, stmt.X); err != nil { + return "", err + } + scope[stmt.Name.Name] = sb.String() + default: + return "", &compileError{ + source: source, + span: stmt.Span(), + err: fmt.Errorf("unhandled %T statement", stmt), + } + } + } + if expr == nil { + return "", fmt.Errorf("missing tabular queries") + } subqueries, err := splitQueries(nil, source, expr) if err != nil { @@ -47,9 +92,7 @@ func (opts *CompileOptions) Compile(source string) (string, error) { query := subqueries[len(subqueries)-1] ctx := &exprContext{ source: source, - } - if opts != nil { - ctx.scope = opts.Parameters + scope: scope, } if len(ctes) > 0 { sb.WriteString("WITH ") @@ -507,6 +550,7 @@ type exprMode int const ( defaultExprMode exprMode = iota joinExprMode + letExprMode ) type exprContext struct { @@ -539,6 +583,25 @@ func writeExpression(ctx *exprContext, sb *strings.Builder, x parser.Expr) error sb.WriteString(sql) return nil } + if ctx.mode == letExprMode { + return &compileError{ + source: ctx.source, + span: part.NameSpan, + err: fmt.Errorf("unknown identifier %s in let expression", part.Name), + } + } + } else if ctx.mode == letExprMode { + return &compileError{ + source: ctx.source, + span: part.NameSpan, + err: fmt.Errorf("quoted identifier not permitted in let expression"), + } + } + } else if ctx.mode == letExprMode { + return &compileError{ + source: ctx.source, + span: x.Span(), + err: fmt.Errorf("qualified identifier not permitted in let expression"), } } @@ -549,7 +612,7 @@ func writeExpression(ctx *exprContext, sb *strings.Builder, x parser.Expr) error if !part.Quoted && (part.Name == leftJoinTableAlias || part.Name == rightJoinTableAlias) && ctx.mode != joinExprMode { return &compileError{ source: ctx.source, - span: x.Parts[0].NameSpan, + span: part.NameSpan, err: fmt.Errorf("%s used in non-join context", 
part.Name), } } @@ -699,6 +762,8 @@ func writeExpression(ctx *exprContext, sb *strings.Builder, x parser.Expr) error return nil } +// writeExpressionMaybeParen writes an expression to sb, +// surrounding it with parentheses if sufficiently complex. func writeExpressionMaybeParen(ctx *exprContext, sb *strings.Builder, x parser.Expr) error { for { p, ok := x.(*parser.ParenExpr) diff --git a/testdata/Goldens/Let/input.pql b/testdata/Goldens/Let/input.pql new file mode 100644 index 0000000..c679031 --- /dev/null +++ b/testdata/Goldens/Let/input.pql @@ -0,0 +1,3 @@ +let n = 3; +StateCapitals +| top n by State asc diff --git a/testdata/Goldens/Let/output.csv b/testdata/Goldens/Let/output.csv new file mode 100644 index 0000000..0e045a2 --- /dev/null +++ b/testdata/Goldens/Let/output.csv @@ -0,0 +1,4 @@ +"State","StateCapital" +Alabama,Montgomery +Alaska,Juneau +Arizona,Phoenix diff --git a/testdata/Goldens/Let/output.sql b/testdata/Goldens/Let/output.sql new file mode 100644 index 0000000..8266626 --- /dev/null +++ b/testdata/Goldens/Let/output.sql @@ -0,0 +1 @@ +SELECT * FROM "StateCapitals" ORDER BY "State" ASC NULLS FIRST LIMIT 3;