
Commit

Merge pull request #52
Add `let` statement.
This requires a breaking API change to parser.Parse so that it can return multiple statements,
but pql.Compile does not change.
The REPL also required some tweaking.

Fixes #40
zombiezen authored Jun 11, 2024
2 parents a015aeb + 60e4228 commit 73d23b4
Showing 9 changed files with 428 additions and 178 deletions.
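Since parser.Parse now returns every statement in the input rather than a single tabular expression, existing callers need a small adjustment. The sketch below is illustrative rather than taken from the commit; it assumes the parser package's import path (shown here as github.com/runreveal/pql/parser) and uses made-up query and table names.

```go
package main

import (
	"fmt"
	"log"

	"github.com/runreveal/pql/parser" // assumed import path
)

func main() {
	// Before this change, Parse returned a single *TabularExpr.
	// It now returns one Statement per semicolon-separated statement.
	stmts, err := parser.Parse("let limit = 10; StormEvents | take limit")
	if err != nil {
		log.Fatal(err)
	}
	for _, stmt := range stmts {
		fmt.Printf("%T spanning %v\n", stmt, stmt.Span())
	}
}
```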
2 changes: 2 additions & 0 deletions README.md
@@ -69,6 +69,8 @@ documentation is representative of the current pql api.
- [`as`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/as-operator)
- [`count`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/count-operator)
- [`join`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/join-operator)
- [`let` statements](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/let-statement),
but only scalar expressions are supported.
- [`project`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/project-operator)
- [`extend`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/extend-operator)
- [`sort`/`order`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/sort-operator)
16 changes: 15 additions & 1 deletion cmd/pql/main.go
@@ -70,6 +70,7 @@ func run(ctx context.Context, output io.Writer, input io.Reader, logError func(e
}

var finalError error
letStatements := new(strings.Builder)
for scanner.Scan() {
sb.Write(scanner.Bytes())
sb.WriteByte('\n')
@@ -80,7 +81,20 @@ func run(ctx context.Context, output io.Writer, input io.Reader, logError func(e
}

for _, stmt := range statements[:len(statements)-1] {
sql, err := pql.Compile(stmt)
// Valid let statements are prepended to an ongoing prelude.
tokens := parser.Scan(stmt)
if len(tokens) > 0 && tokens[0].Kind == parser.TokenIdentifier && tokens[0].Value == "let" {
if _, err := pql.Compile(letStatements.String() + stmt + ";X"); err != nil {
logError(err)
finalError = errors.New("one or more statements could not be compiled")
} else {
letStatements.WriteString(stmt)
letStatements.WriteString(";\n")
}
continue
}

sql, err := pql.Compile(letStatements.String() + stmt)
if err != nil {
logError(err)
finalError = errors.New("one or more statements could not be compiled")
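In other words, the REPL keeps a growing prelude of accepted `let` statements and prepends it to each subsequent query; the appended `;X` appears to act as a throwaway tabular statement so that the prelude on its own can be checked with pql.Compile. Below is a sketch of the effective compilation, with illustrative names and the same assumed imports as in the earlier sketch.

```go
// After the user enters "let limit = 5" and then "StormEvents | take limit",
// the REPL effectively compiles:
prelude := "let limit = 5;\n"
sql, err := pql.Compile(prelude + "StormEvents | take limit")
if err != nil {
	log.Fatal(err)
}
fmt.Println(sql)
```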
36 changes: 36 additions & 0 deletions parser/ast.go
@@ -72,12 +72,20 @@ func (id *QualifiedIdent) Span() Span {

func (id *QualifiedIdent) expression() {}

type Statement interface {
Node
statement()
}

// TabularExpr is a query expression that produces a table.
// It implements [Statement].
type TabularExpr struct {
Source TabularDataSource
Operators []TabularOperator
}

func (x *TabularExpr) statement() {}

func (x *TabularExpr) Span() Span {
if x == nil {
return nullSpan()
@@ -547,6 +555,29 @@ func (idx *IndexExpr) Span() Span {

func (idx *IndexExpr) expression() {}

// A LetStatement node represents a let statement,
// assigning an expression to a name.
// It implements [Statement].
type LetStatement struct {
Keyword Span
Name *Ident
Assign Span
X Expr
}

func (stmt *LetStatement) statement() {}

func (stmt *LetStatement) Span() Span {
if stmt == nil {
return nullSpan()
}
xSpan := nullSpan()
if stmt.X != nil {
xSpan = stmt.X.Span()
}
return unionSpans(stmt.Keyword, stmt.Name.Span(), stmt.Assign, xSpan)
}
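Downstream code can distinguish the new node from tabular expressions with a type switch. The following is a sketch against the API introduced in this commit; the query text and table name are made up, and imports are assumed as in the earlier sketch.

```go
stmts, err := parser.Parse("let x = 1 + 2; T | count")
if err != nil {
	log.Fatal(err)
}
for _, stmt := range stmts {
	switch s := stmt.(type) {
	case *parser.LetStatement:
		// Keyword covers "let", Assign covers "=", and X is the bound expression.
		fmt.Println("let statement at", s.Span())
	case *parser.TabularExpr:
		fmt.Println("tabular expression with", len(s.Operators), "operators")
	}
}
```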

// Walk traverses an AST in depth-first order.
// If the visit function returns true for a node,
// the visit function will be called for its children.
@@ -684,6 +715,11 @@ func Walk(n Node, visit func(n Node) bool) {
stack = append(stack, n.Index)
stack = append(stack, n.X)
}
case *LetStatement:
if visit(n) {
stack = append(stack, n.X)
stack = append(stack, n.Name)
}
default:
panic(fmt.Errorf("unknown Node type %T", n))
}
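With the new case in place, Walk descends through a let statement's name and bound expression like any other node. A sketch of collecting let statements from a parsed query (illustrative query text, imports assumed as above):

```go
stmts, err := parser.Parse("let x = 1; T | take x")
if err != nil {
	log.Fatal(err)
}
var lets []*parser.LetStatement
for _, stmt := range stmts {
	parser.Walk(stmt, func(n parser.Node) bool {
		if ls, ok := n.(*parser.LetStatement); ok {
			lets = append(lets, ls)
		}
		return true // keep visiting children
	})
}
fmt.Println("found", len(lets), "let statement(s)")
```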
223 changes: 156 additions & 67 deletions parser/parser.go
@@ -21,40 +21,119 @@ type parser struct {
splitKind TokenKind
}

// Parse converts a Pipeline Query Language tabular expression
// Parse converts a Pipeline Query Language query
// into an Abstract Syntax Tree (AST).
func Parse(query string) (*TabularExpr, error) {
func Parse(query string) ([]Statement, error) {
p := &parser{
source: query,
tokens: Scan(query),
}
expr, err := p.tabularExpr()
if p.pos < len(p.tokens) {
trailingToken := p.tokens[p.pos]
if trailingToken.Kind == TokenError {
err = joinErrors(err, &parseError{
source: p.source,
span: trailingToken.Span,
err: errors.New(trailingToken.Value),
})
var result []Statement
var resultError error
for {
stmtParser := p.splitSemi()

stmt, err := firstParse(
func() (Statement, error) {
stmt, err := stmtParser.letStatement()
if stmt == nil {
// Prevent returning a non-nil interface.
return nil, err
}
return stmt, err
},
func() (Statement, error) {
expr, err := stmtParser.tabularExpr()
if expr == nil {
// Prevent returning a non-nil interface.
return nil, err
}
return expr, err
},
)

if isNotFound(err) {
// We're okay with empty statements, we just ignore them.
if stmtParser.pos < len(stmtParser.tokens) {
trailingToken := stmtParser.tokens[stmtParser.pos]
if trailingToken.Kind == TokenError {
resultError = joinErrors(err, &parseError{
source: p.source,
span: trailingToken.Span,
err: errors.New(trailingToken.Value),
})
} else {
resultError = joinErrors(err, &parseError{
source: p.source,
span: trailingToken.Span,
err: errors.New("unrecognized token"),
})
}
}
} else {
err = joinErrors(err, &parseError{
source: p.source,
span: trailingToken.Span,
err: errors.New("unrecognized token"),
})
if stmt != nil {
result = append(result, stmt)
}
resultError = joinErrors(resultError, makeErrorOpaque(err))
resultError = joinErrors(resultError, stmtParser.endSplit())
}

// Next token, if present, guaranteed to be a semicolon.
if _, ok := p.next(); !ok {
break
}
} else if isNotFound(err) {
err = &parseError{
}

if resultError != nil {
return result, fmt.Errorf("parse pipeline query language: %w", resultError)
}
return result, nil
}

func firstParse[T any](productions ...func() (T, error)) (T, error) {
for _, p := range productions[:len(productions)-1] {
x, err := p()
if !isNotFound(err) {
return x, err
}
}
return productions[len(productions)-1]()
}

func (p *parser) letStatement() (*LetStatement, error) {
keyword, _ := p.next()
if keyword.Kind != TokenIdentifier || keyword.Value != "let" {
p.prev()
return nil, &parseError{
source: p.source,
span: keyword.Span,
err: notFoundError{fmt.Errorf("expected 'let', got %s", formatToken(p.source, keyword))},
}
}

stmt := &LetStatement{
Keyword: keyword.Span,
Assign: nullSpan(),
}
var err error
stmt.Name, err = p.ident()
if err != nil {
return stmt, makeErrorOpaque(err)
}
assign, _ := p.next()
if assign.Kind != TokenAssign {
return stmt, &parseError{
source: p.source,
span: indexSpan(len(query)),
err: errors.New("empty query"),
span: assign.Span,
err: fmt.Errorf("expected '=', got %s", formatToken(p.source, assign)),
}
}
stmt.Assign = assign.Span
stmt.X, err = p.expr()
if err != nil {
return expr, fmt.Errorf("parse pipeline query language: %w", err)
return stmt, makeErrorOpaque(err)
}
return expr, nil
return stmt, nil
}
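So a statement is parsed as a let statement only when it begins with the `let` identifier; otherwise firstParse falls through to the tabular-expression production. Once `let` has been seen, a missing name, `=`, or expression is reported as a hard parse error rather than triggering the fallback. A sketch of the observable behavior follows (the exact error text depends on the parser's formatting; imports assumed as above):

```go
// Well-formed: parses into a *LetStatement followed by a *TabularExpr.
if _, err := parser.Parse("let x = 2 + 3; T | take x"); err != nil {
	log.Fatal(err)
}

// Malformed: the '=' is missing, so Parse reports an error instead of
// silently treating the input as a tabular expression.
if _, err := parser.Parse("let x 5"); err != nil {
	fmt.Println("parse error:", err)
}
```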

func (p *parser) tabularExpr() (*TabularExpr, error) {
@@ -294,27 +373,10 @@ func (p *parser) takeOperator(pipe, keyword Token) (*TakeOperator, error) {
Pipe: pipe.Span,
Keyword: keyword.Span,
}

tok, _ := p.next()
if tok.Kind != TokenNumber {
return op, &parseError{
source: p.source,
span: tok.Span,
err: fmt.Errorf("expected integer, got %s", formatToken(p.source, tok)),
}
}
rowCount := &BasicLit{
Kind: tok.Kind,
Value: tok.Value,
ValueSpan: tok.Span,
}
op.RowCount = rowCount
if !rowCount.IsInteger() {
return op, &parseError{
source: p.source,
span: tok.Span,
err: fmt.Errorf("expected integer, got %s", formatToken(p.source, tok)),
}
var err error
op.RowCount, err = p.rowCount()
if err != nil {
return op, makeErrorOpaque(err)
}
return op, nil
}
@@ -326,30 +388,13 @@ func (p *parser) topOperator(pipe, keyword Token) (*TopOperator, error) {
By: nullSpan(),
}

tok, _ := p.next()
if tok.Kind != TokenNumber {
p.prev()
return op, &parseError{
source: p.source,
span: tok.Span,
err: fmt.Errorf("expected integer, got %s", formatToken(p.source, tok)),
}
}
rowCount := &BasicLit{
Kind: tok.Kind,
Value: tok.Value,
ValueSpan: tok.Span,
}
op.RowCount = rowCount
if !rowCount.IsInteger() {
return op, &parseError{
source: p.source,
span: tok.Span,
err: fmt.Errorf("expected integer, got %s", formatToken(p.source, tok)),
}
var err error
op.RowCount, err = p.rowCount()
if err != nil {
return op, makeErrorOpaque(err)
}

tok, _ = p.next()
tok, _ := p.next()
if tok.Kind != TokenBy {
p.prev()
return op, &parseError{
Expand All @@ -360,11 +405,28 @@ func (p *parser) topOperator(pipe, keyword Token) (*TopOperator, error) {
}
op.By = tok.Span

var err error
op.Col, err = p.sortTerm()
return op, makeErrorOpaque(err)
}

func (p *parser) rowCount() (Expr, error) {
x, err := p.expr()
if err != nil {
return x, err
}
if lit, ok := x.(*BasicLit); ok {
// Do basic check for common case of literals.
if !lit.IsInteger() {
return x, fmt.Errorf("expected integer, got %s", formatToken(p.source, Token{
Kind: lit.Kind,
Span: lit.ValueSpan,
Value: lit.Value,
}))
}
}
return x, nil
}
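take and top now share this helper, so the row count is parsed as a full expression and the integer check only applies when the expression is a literal; that is what lets a let-bound name appear as a row count. A sketch at the parser level (table name made up, imports assumed as above):

```go
// An integer literal and a let-bound name both parse...
if _, err := parser.Parse("let n = 5; T | take n"); err != nil {
	log.Fatal(err)
}
// ...while a non-integer literal is still rejected.
if _, err := parser.Parse("T | take 2.5"); err != nil {
	fmt.Println("rejected:", err)
}
```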

func (p *parser) projectOperator(pipe, keyword Token) (*ProjectOperator, error) {
op := &ProjectOperator{
Pipe: pipe.Span,
@@ -1042,7 +1104,9 @@ func (p *parser) qualifiedIdent() (*QualifiedIdent, error) {
// split advances the parser to right before the next token of the given kind,
// and returns a new parser that reads the tokens that were skipped over.
// It ignores tokens that are in parenthetical groups after the initial parse position.
// If no such token is found, skipTo advances to EOF.
// If no such token is found, split advances to EOF.
//
// For splitting by semicolon, see [*parser.splitSemi].
func (p *parser) split(search TokenKind) *parser {
// stack is the list of expected closing parentheses/brackets.
// When a closing parenthesis/bracket is encountered,
@@ -1103,6 +1167,31 @@ loop:
}
}

// splitSemi advances the parser to right before the next semicolon,
// and returns a new parser that reads the tokens that were skipped over.
// If no semicolon is found, splitSemi advances to EOF.
func (p *parser) splitSemi() *parser {
start := p.pos
for {
tok, ok := p.next()
if !ok {
return &parser{
source: p.source,
tokens: p.tokens[start:],
splitKind: TokenSemi,
}
}
if tok.Kind == TokenSemi {
p.prev()
return &parser{
source: p.source,
tokens: p.tokens[start:p.pos],
splitKind: TokenSemi,
}
}
}
}
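At the public API level, this splitting is what allows one input string to hold several semicolon-separated statements, and, per the comment in Parse above, empty statements are ignored. A sketch with illustrative query text (imports assumed as above):

```go
stmts, err := parser.Parse("let a = 1; ; T | take a")
if err != nil {
	log.Fatal(err)
}
// Two statements: the empty statement between the semicolons is skipped.
fmt.Println(len(stmts))
```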

func (p *parser) endSplit() error {
if p.splitKind == 0 {
// This is a bug, but treating as an error instead of panicking.
