Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ rewritten to their query representation
- fix the parsing of DATE_ADD/DATE_DIFF with an uncapitalized datetime field argument.
- fix that wrong offset is set when parsing Ion timestamp with time zone into Datum.
- Reimplemented DATE_ADD with interval plus arithmetic.
- Fix ORDER BY statement does not recognize alias from SELECT statement

### Changed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,26 @@ internal object PErrors {
)
}

/**
* @param location see [PError.location]
* @param id see [PError.VAR_REF_AMBIGUOUS]
* @param candidates see [PError.VAR_REF_AMBIGUOUS]
* @return an error representing [PError.VAR_REF_AMBIGUOUS]
*/
internal fun varRefAmbiguous(
location: SourceLocation?,
id: org.partiql.ast.Identifier?,
candidates: List<String?>?
): PError {
return PError(
PError.VAR_REF_AMBIGUOUS,
Severity.ERROR(),
PErrorKind.SEMANTIC(),
location,
mapOf("ID" to id, "CANDIDATES" to candidates)
)
}

/**
* @param path see [PError.INVALID_EXCLUDE_PATH]
* @return an error representing [PError.INVALID_EXCLUDE_PATH]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import org.partiql.planner.PartiQLPlannerPass
import org.partiql.planner.internal.transforms.AstToPlan
import org.partiql.planner.internal.transforms.NormalizeFromSource
import org.partiql.planner.internal.transforms.NormalizeGroupBy
import org.partiql.planner.internal.transforms.OrderByAliasSupport
import org.partiql.planner.internal.transforms.PlanTransform
import org.partiql.planner.internal.typer.PlanTyper
import org.partiql.spi.Context
import org.partiql.spi.catalog.Session
import org.partiql.spi.errors.PError
import org.partiql.spi.errors.PErrorKind
import org.partiql.spi.errors.PErrorListener
import org.partiql.spi.errors.PRuntimeException
import org.partiql.spi.types.PType

Expand All @@ -36,7 +38,7 @@ internal class SqlPlanner(
val env = Env(session, ctx.errorListener)

// 1. Normalize
val ast = statement.normalize()
val ast = statement.normalize(ctx.errorListener)

// 2. AST to Rel/Rex
val root = AstToPlan.apply(ast, env)
Expand Down Expand Up @@ -64,11 +66,12 @@ internal class SqlPlanner(
/**
* AST normalization
*/
private fun Statement.normalize(): Statement {
private fun Statement.normalize(listener: PErrorListener): Statement {
// could be a fold, but this is nice for setting breakpoints
var ast = this
ast = NormalizeFromSource.apply(ast)
ast = NormalizeGroupBy.apply(ast)
ast = OrderByAliasSupport(listener).apply(ast)
return ast
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package org.partiql.planner.internal.transforms

import org.partiql.ast.Ast.exprQuerySet
import org.partiql.ast.Ast.orderBy
import org.partiql.ast.Ast.sort
import org.partiql.ast.AstNode
import org.partiql.ast.AstRewriter
import org.partiql.ast.OrderBy
import org.partiql.ast.QueryBody
import org.partiql.ast.SelectItem
import org.partiql.ast.Statement
import org.partiql.ast.With
import org.partiql.ast.expr.Expr
import org.partiql.ast.expr.ExprQuerySet
import org.partiql.ast.expr.ExprVarRef
import org.partiql.planner.internal.PErrors
import org.partiql.spi.errors.PErrorListener
import kotlin.collections.MutableMap

/**
* Replaces ORDER BY aliases with their corresponding SELECT expressions using stack-based scope tracking.
*/
internal class OrderByAliasSupport(val listener: PErrorListener) : AstPass {
override fun apply(statement: Statement): Statement {
return Visitor(listener).visitStatement(statement, ArrayDeque()) as Statement
}

/**
* Maintains alias scope stack for nested queries and resolves ORDER BY aliases.
*/
private class Visitor(val listener: PErrorListener) :
AstRewriter<ArrayDeque<MutableMap<String, MutableList<Expr>>>>() {
/**
* Pushes new alias scope, processes query, then pops scope.
*/
override fun visitExprQuerySet(
node: ExprQuerySet,
ctx: ArrayDeque<MutableMap<String, MutableList<Expr>>>
): AstNode {
// Push new scope
ctx.addLast(mutableMapOf())

// Visit all statements that may have SELECT or ORDER BY
val body = node.body.let { visitQueryBody(it, ctx) as QueryBody }
val orderBy = node.orderBy?.let {
if (body !is QueryBody.SetOp) {
// Skip alias replacement if the query body is set operations
visitOrderBy(it, ctx) as OrderBy?
} else {
node.orderBy
}
}
val with = node.with?.let { visitWith(it, ctx) as With? }
val transformed = if (body !== node.body || orderBy !== node.orderBy || with !== node.with
) {
exprQuerySet(body, orderBy, node.limit, node.offset, with)
} else {
node
}

// Pop scope
ctx.removeLast()
return transformed
}

/**
* Collects SELECT aliases into current scope map.
*/
override fun visitSelectItem(
node: SelectItem,
ctx: ArrayDeque<MutableMap<String, MutableList<Expr>>>
): AstNode {
if (node is SelectItem.Expr) {
node.asAlias?.let { alias ->
ctx.last().getOrPut(alias.text) { mutableListOf() }.add(node.expr)
}
}

return node
}

/**
* Replaces ORDER BY aliases with their SELECT expressions.
*/
override fun visitOrderBy(node: OrderBy, ctx: ArrayDeque<MutableMap<String, MutableList<Expr>>>): AstNode {
val aliasMap = ctx.last()
if (aliasMap.isEmpty()) return node

val transformedSorts = node.sorts.map { sort ->
val transformedExpr = resolveExpr(sort.expr, aliasMap)
if (transformedExpr != sort.expr) {
sort(
expr = transformedExpr,
order = sort.order,
nulls = sort.nulls
)
} else {
sort
}
}
return orderBy(transformedSorts)
}

/**
* Resolves variable references to their aliased expressions.
* Regular identifiers use case-insensitive matching, delimited use case-sensitive.
*/
private fun resolveExpr(expr: Expr, aliasMap: Map<String, List<Expr>>): Expr {
return when (expr) {
is ExprVarRef -> {
val identifier = expr.identifier.identifier
val orderByName = identifier.text

val candidates = if (identifier.isRegular) {
aliasMap.filterKeys { it.equals(orderByName, ignoreCase = true) }.values.flatten()
} else {
aliasMap[orderByName]
}

if (candidates == null) {
expr
} else if (candidates.size == 1) {
candidates[0]
} else {
if (candidates.size > 1) {
val candidateNames = candidates.mapNotNull {
val ref = it
if (ref is ExprVarRef) {
ref.identifier.identifier.text
} else {
null
}
}
listener.report(PErrors.varRefAmbiguous(null, expr.identifier, candidateNames))
}
expr
}
}
else -> expr
}
}
}
}
Loading
Loading