diff --git a/sql/xsql/src/main/scala/org/apache/spark/sql/xsql/execution/datasources/mysql/MysqlSpecialStrategy.scala b/sql/xsql/src/main/scala/org/apache/spark/sql/xsql/execution/datasources/mysql/MysqlSpecialStrategy.scala index 23be7de..fbfc6ff 100644 --- a/sql/xsql/src/main/scala/org/apache/spark/sql/xsql/execution/datasources/mysql/MysqlSpecialStrategy.scala +++ b/sql/xsql/src/main/scala/org/apache/spark/sql/xsql/execution/datasources/mysql/MysqlSpecialStrategy.scala @@ -474,25 +474,29 @@ class TransmitOriginalQuery(session: SparkSession) extends Rule[LogicalPlan] { "rand", "cast", "like") - plan.expressions.forall { - case a: Alias => - if (a.child.isInstanceOf[AttributeReference] || a.child.isInstanceOf[ScalarSubquery] + + @inline def functionNotPushdown(plan: LogicalPlan): Boolean = { + plan.expressions.exists { + case a: Alias => + if (a.child.isInstanceOf[AttributeReference] || a.child.isInstanceOf[ScalarSubquery] || a.child.isInstanceOf[BinaryArithmetic]) { - true - } else { - val func = if (a.child.isInstanceOf[AggregateExpression]) { - a.child.asInstanceOf[AggregateExpression].aggregateFunction - } else { - a.child - } - if (!pushdownFunctions.contains(func.prettyName)) { false } else { - true + val func = if (a.child.isInstanceOf[AggregateExpression]) { + a.child.asInstanceOf[AggregateExpression].aggregateFunction + } else { + a.child + } + if (pushdownFunctions.contains(func.prettyName)) { + false + } else { + true + } } - } - case _ => true + case _ => false + } } + plan.find(functionNotPushdown).isEmpty } /**