diff --git a/xorbits_sql/core.py b/xorbits_sql/core.py
index 37a5b57..8022cc3 100644
--- a/xorbits_sql/core.py
+++ b/xorbits_sql/core.py
@@ -109,7 +109,7 @@ def execute(
     logger.debug("Logical Plan: %s", plan)

     now = time.time()
-    result = XorbitsExecutor(tables=tables_).execute(plan)
+    result = XorbitsExecutor(tables=tables_, schema=schema).execute(plan)
     logger.debug("Query finished: %f", time.time() - now)

diff --git a/xorbits_sql/executor.py b/xorbits_sql/executor.py
index 9bdc223..852aa99 100644
--- a/xorbits_sql/executor.py
+++ b/xorbits_sql/executor.py
@@ -16,11 +16,12 @@
 import operator
 from functools import lru_cache
+from typing import Any

 import pandas
 import xorbits
 import xorbits.pandas as pd
-from sqlglot import exp, planner
+from sqlglot import MappingSchema, exp, planner
 from xoscar.utils import TypeDispatcher, classproperty

 from .errors import ExecuteError, UnsupportedError
@@ -36,23 +37,39 @@
     exp.Variance: "var",
 }

+_SQLGLOT_TYPE_TO_DTYPE = {
+    "float": "float32",
+    "double": "float64",
+    "int": "int32",
+    "tinyint": "int8",
+    "smallint": "int16",
+    "bigint": "int64",
+}
+

 class XorbitsExecutor:
-    def __init__(self, tables: Tables | None = None):
+    def __init__(
+        self, tables: Tables | None = None, schema: MappingSchema | None = None
+    ):
         self.tables = tables or Tables()
+        self.schema = schema

     @classproperty
     @lru_cache(1)
     def _exp_visitors(cls) -> TypeDispatcher:
         dispatcher = TypeDispatcher()
+        for func in exp.ALL_FUNCTIONS:
+            dispatcher.register(func, cls._func)
         dispatcher.register(exp.Alias, cls._alias)
         dispatcher.register(exp.Binary, cls._func)
         dispatcher.register(exp.Boolean, cls._boolean)
+        dispatcher.register(exp.Cast, cls._cast)
         dispatcher.register(exp.Column, cls._column)
         dispatcher.register(exp.Literal, cls._literal)
+        dispatcher.register(exp.Null, cls._null)
+        dispatcher.register(exp.Unary, cls._func)
         dispatcher.register(exp.Ordered, cls._ordered)
-        for func in exp.ALL_FUNCTIONS:
-            dispatcher.register(func, cls._func)
+        dispatcher.register(exp.Paren, cls._paren)
         return dispatcher

     @classmethod
@@ -74,7 +91,7 @@ def _literal(literal: exp.Literal, context: dict[str, pd.DataFrame]):
         elif literal.is_int:
             return int(literal.this)
         elif literal.is_star:
-            return ...
+            return slice(None)
         else:
             return float(literal.this)

@@ -82,14 +99,67 @@
     def _boolean(boolean: exp.Boolean, context: dict[str, pd.DataFrame]):
         return True if boolean.this else False

+    @staticmethod
+    def _null(null: exp.Null, context: dict[str, pd.DataFrame]):
+        return None
+
+    @classmethod
+    def _cast(
+        cls,
+        cast: exp.Cast,
+        context: dict[str, pd.DataFrame],
+    ):
+        this = cls._visit_exp(cast.this, context)
+        to = getattr(exp.DataType.Type, str(cast.to))
+
+        if to == exp.DataType.Type.DATE:
+            if pandas.api.types.is_scalar(this):
+                return pandas.to_datetime(this).to_pydatetime().date()
+            else:
+                return pd.to_datetime(this).dt.date
+        elif to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP):
+            return pd.to_datetime(this)
+        elif to == exp.DataType.Type.BOOLEAN:
+            if pandas.api.types.is_scalar(this):
+                return bool(this)
+            else:
+                return this.astype(bool)
+        elif to in exp.DataType.TEXT_TYPES:
+            if pandas.api.types.is_scalar(this):
+                return str(this)
+            else:
+                # TODO: convert to arrow string when it's default in pandas
+                return this.astype(str)
+        elif str(cast.to) in _SQLGLOT_TYPE_TO_DTYPE:
+            pd_type = _SQLGLOT_TYPE_TO_DTYPE[str(cast.to)]
+            if pandas.api.types.is_scalar(this):
+                return pandas.Series([this], dtype=pd_type)[0]
+            else:
+                return this.astype(pd_type)
+        else:
+            raise NotImplementedError(f"Casting {cast.this} to '{to}' not implemented.")
+
     @staticmethod
     def _column(column: exp.Column, context: dict[str, pd.DataFrame]) -> pd.Series:
         return context[column.table][column.name]

     @classmethod
-    def _alias(cls, alias: exp.Alias, context: dict[str, pd.DataFrame]) -> pd.Series:
+    def _alias(
+        cls,
+        alias: exp.Alias,
+        context: dict[str, pd.DataFrame],
+    ) -> pd.Series:
         return cls._visit_exp(alias.this, context).rename(alias.output_name)

+    @classmethod
+    def _paren(
+        cls,
+        paren: exp.Paren,
+        context: dict[str, pd.DataFrame],
+    ):
+        return cls._visit_exp(paren.this, context)
+
     @classproperty
     @lru_cache(1)
     def _operator_visitors(cls) -> TypeDispatcher:
@@ -100,10 +170,14 @@
         dispatcher.register(exp.Div, operator.truediv)
         dispatcher.register(exp.GT, operator.gt)
         dispatcher.register(exp.GTE, operator.ge)
+        dispatcher.register(exp.Is, cls._is)
         dispatcher.register(exp.LT, operator.lt)
         dispatcher.register(exp.LTE, operator.le)
         dispatcher.register(exp.Mul, operator.mul)
         dispatcher.register(exp.NEQ, operator.ne)
+        dispatcher.register(exp.Not, operator.neg)
+        dispatcher.register(exp.Or, operator.or_)
+        dispatcher.register(exp.Like, cls._like)
         dispatcher.register(exp.Sub, operator.sub)
         return dispatcher

@@ -121,6 +195,18 @@ def _func(cls, func: exp.Expression, context: dict[str, pd.DataFrame]) -> pd.Series:
         )
         return func(*values)

+    @classmethod
+    def _like(cls, left: pd.Series, right: str):
+        r = right.replace("_", ".").replace("%", ".*")
+        return left.str.contains(r, regex=True, na=True)
+
+    @classmethod
+    def _is(cls, left: pd.Series, right: Any):
+        if right is None:
+            return left.isnull()
+        else:
+            return left == right
+
     def execute(self, plan: planner.Plan) -> pd.DataFrame:
         finished = set()
         queue = set(plan.leaves)
@@ -192,13 +278,33 @@ def scan(
         return {step.name: self._project_and_filter(step, context, df)}

     @staticmethod
-    def _scan_csv(step: planner.Scan) -> dict[str, pd.DataFrame]:
+    def _schema_to_dtype(schema: dict[str, str]) -> dict[str, str]:
+        result = dict()
+        for name, type_name in schema.items():
+            try:
+                result[name] = _SQLGLOT_TYPE_TO_DTYPE[type_name.lower()]
+            except KeyError:
+                continue
+        return result
+
+    def _scan_csv(self, step: planner.Scan) -> dict[str, pd.DataFrame]:
         alias = step.source.alias
         source: exp.ReadCSV = step.source.this
         args = source.expressions
         filename = source.name
-        df = pd.read_csv(filename, **{arg.name: arg for arg in args})
+
+        delimiter = ","
+        args = iter(arg.name for arg in args)
+        for k, v in zip(args, args):
+            if k == "delimiter":
+                delimiter = v
+
+        dtype = None
+        if self.schema and alias in self.schema.mapping:
+            dtype = self._schema_to_dtype(self.schema.mapping[alias])
+
+        df = pd.read_csv(filename, sep=delimiter, dtype=dtype)
         return {alias: df}

     def _project_and_filter(
@@ -219,14 +325,15 @@
     def aggregate(
         self, step: planner.Aggregate, context: dict[str, pd.DataFrame]
     ) -> dict[str, pd.DataFrame]:
-        dfs = list(context.values())
-        assert len(dfs) == 1
-        df = dfs[0]
+        df = context[step.source]
         group_by = [self._visit_exp(g, context) for g in step.group.values()]

         if step.operands:
             for op in step.operands:
-                df[op.alias_or_name] = self._visit_exp(op, context)
+                if isinstance(op.this, exp.Star):
+                    df[op.alias_or_name] = 1
+                else:
+                    df[op.alias_or_name] = self._visit_exp(op, context)

         aggregations = dict()
         names = list(step.group)
@@ -244,11 +351,20 @@
                 column=agg.this.alias_or_name, aggfunc=aggfunc
             )

-        result = df.groupby(group_by).agg(**aggregations).reset_index()
-        result.columns = names
+        if aggregations:
+            if step.group:
+                result = df.groupby(group_by).agg(**aggregations).reset_index()
+            else:
+                result = df.agg(**aggregations).reset_index(drop=True)
+            result.columns = names
+        else:
+            assert len(group_by) == len(names)
+            result = pd.DataFrame(dict(zip(names, group_by))).drop_duplicates()

         if step.projections or step.condition:
-            result = self._project_and_filter(step, {step.name: result}, result)
+            result = self._project_and_filter(
+                step, {step.name: result, **{name: result for name in context}}, result
+            )

         if isinstance(step.limit, int):
             result = result.iloc[: step.limit]
@@ -261,19 +377,19 @@ def join(
         source = step.name
         source_df = context[source]
         source_context = {source: source_df}
-        column_slices = {source: slice(0, source_df.shape[1])}
+        column_slices = {source: slice(0, len(source_df.dtypes))}

         df = None
         for name, join in step.joins.items():
             df = context[name]
             join_context = {name: df}
             start = max(r.stop for r in column_slices.values())
-            column_slices[name] = slice(start, df.shape[1] + start)
+            column_slices[name] = slice(start, len(df.dtypes) + start)

             if join.get("source_key"):
-                df = self._hash_join(join, source_context, join_context)
+                df = self._hash_join(join, source_df, source_context, df, join_context)
             else:
-                df = self._nested_loop_join(join, source_context, join_context)
+                df = self._nested_loop_join(join, source_df, df)

             condition = self._visit_exp(join["condition"], {name: df})
             if condition is not True:
@@ -283,6 +399,7 @@
                 name: df.iloc[:, column_slice]
                 for name, column_slice in column_slices.items()
             }
+            source_df = df

         if not step.condition and not step.projections:
             return source_context
@@ -292,48 +409,48 @@
         if step.projections:
             return {step.name: sink}
         else:
-            return source_context
+            return {name: sink for name in source_context}

     def _nested_loop_join(
         self,
         join: dict,
-        source_context: dict[str, pd.DataFrame],
-        join_context: dict[str, pd.DataFrame],
+        source_df: pd.DataFrame,
+        join_df: pd.DataFrame,
     ) -> pd.DataFrame:
         def func(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
             if pandas.__version__ >= "1.2.0":
-                return left.merge(right, on="cross")
+                return left.merge(right, how="cross")
             else:
                 left["_on"] = 1
                 right["_on"] = 1
                 result = left.merge(right, on="_on")
                 return result[left.dtypes.index.tolist() + right.dtypes.index.tolist()]

-        source_df = next(iter(source_context.values()))
-        join_df = next(iter(join_context.values()))
-        return source_df.cartisan_chunk(join_df, func)
+        return source_df.cartesian_chunk(join_df, func)

     def _hash_join(
         self,
         join: dict,
+        source_df: pd.DataFrame,
         source_context: dict[str, pd.DataFrame],
+        join_df: pd.DataFrame,
         join_context: dict[str, pd.DataFrame],
     ) -> pd.DataFrame:
         cols = []
-        source_df = next(iter(source_context.values()))
+        source_df = pd.DataFrame({c: source_df[c] for c in source_df.dtypes.index})
         cols.extend(source_df.dtypes.index.tolist())
         left_ons = []
         for i, source_key in enumerate(join["source_key"]):
-            col_name = f"_on_{i}"
+            col_name = f"__on_{i}"
             left_ons.append(col_name)
             source_df[col_name] = self._visit_exp(source_key, source_context)

-        join_df = next(iter(join_context.values()))
+        join_df = pd.DataFrame({c: join_df[c] for c in join_df.dtypes.index})
         cols.extend(join_df.dtypes.index.tolist())
         right_ons = []
         for i, join_key in enumerate(join["join_key"]):
-            col_name = f"_on_{i}"
+            col_name = f"__on_{i}"
             right_ons.append(col_name)
             join_df[col_name] = self._visit_exp(join_key, join_context)

@@ -344,9 +461,12 @@
             how = "right"

         result = source_df.merge(join_df, how=how, left_on=left_ons, right_on=right_ons)
-        result = result[
-            [col for col in result.dtypes.index if not col.startswith("_on_")]
+        ilocs = [
+            i
+            for i, col in enumerate(result.dtypes.index)
+            if not col.startswith("__on_")
         ]
+        result = result.iloc[:, ilocs]
         result.columns = cols
         return result
@@ -361,11 +481,10 @@ def _ordered(cls, ordered: exp.Ordered, context: dict[str, pd.DataFrame]):
     def sort(
         self, step: planner.Sort, context: dict[str, pd.DataFrame]
     ) -> dict[str, pd.DataFrame]:
-        assert len(context) == 1
-        df = next(iter(context.values()))
+        df = context[step.name]
+        df = pd.DataFrame({n: df[n] for n in df.dtypes.index})
         for projection in step.projections:
             df[projection.alias_or_name] = self._visit_exp(projection, context)
-        slc = slice(df.shape[1] - len(step.projections), df.shape[1])

         sort_context = {"": df, **context}
@@ -374,7 +493,7 @@
         ascendings = []
         na_position = None
         for i, (s, descending, cur_na_position) in enumerate(sort):
-            sort_col = f"_s_{i}"
+            sort_col = f"__s_{i}"
             sort_cols.append(sort_col)
             ascendings.append(not descending)
             if na_position is None:
@@ -383,14 +502,13 @@
                 raise NotImplementedError("nulls_first must be same for all sort keys")
             df[sort_col] = s

-        df = df.sort_values(
-            by=sort_cols, ascending=ascendings, na_position=na_position
-        ).iloc[:, slc]
+        df = df.sort_values(by=sort_cols, ascending=ascendings, na_position=na_position)
+        df = df[[p.alias_or_name for p in step.projections]]

         if isinstance(step.limit, int):
             df = df.iloc[: step.limit]

-        return {step.name: df}
+        return {step.name: df.reset_index(drop=True)}

     def set_operation(
         self, step: planner.SetOperation, context: dict[str, pd.DataFrame]
diff --git a/xorbits_sql/tests/test_execute.py b/xorbits_sql/tests/test_execute.py
index 20a22f8..97e6672 100644
--- a/xorbits_sql/tests/test_execute.py
+++ b/xorbits_sql/tests/test_execute.py
@@ -111,6 +111,6 @@ def test_sort(prepare_data):
     expected = raw_df.sort_values(by="c", ascending=False)
     expected["b"] *= 5
-    expected = expected.iloc[:10]
+    expected = expected.iloc[:10].reset_index(drop=True)
     result = execute(sql, tables={"t1": xpd.DataFrame(raw_df)}).fetch()
     pd.testing.assert_frame_equal(result, expected)
diff --git a/xorbits_sql/tests/test_tpc_h.py b/xorbits_sql/tests/test_tpc_h.py
new file mode 100644
index 0000000..3d8aaa0
--- /dev/null
+++ b/xorbits_sql/tests/test_tpc_h.py
@@ -0,0 +1,64 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import duckdb
+import pandas as pd
+import pytest
+from sqlglot import exp, parse_one
+
+from .. import execute
+from .helpers import FILE_DIR, TPCH_SCHEMA, load_sql
+
+DIR = FILE_DIR + "/tpc-h/"
+
+
+@pytest.fixture
+def prepare_data():
+    conn = duckdb.connect()
+
+    for table, columns in TPCH_SCHEMA.items():
+        conn.execute(
+            f"""
+            CREATE VIEW {table} AS
+            SELECT *
+            FROM READ_CSV('{DIR}{table}.csv', delim='|', header=True, columns={columns})
+            """
+        )
+
+    sqls = [(sql, expected) for _, sql, expected in load_sql("tpc-h/tpc-h.sql")]
+
+    try:
+        yield conn, sqls
+    finally:
+        conn.close()
+
+
+def _to_csv(expression: exp.Expression) -> exp.Expression:
+    if isinstance(expression, exp.Table) and expression.name not in ("revenue",):
+        return parse_one(
+            f"READ_CSV('{DIR}{expression.name}.csv', 'delimiter', '|') AS {expression.alias_or_name}"
+        )
+    return expression
+
+
+def test_execute_tpc_h(prepare_data):
+    conn, sqls = prepare_data
+    for sql, _ in sqls[:6]:
+        expected = conn.execute(sql).fetchdf()
+        result = execute(
+            parse_one(sql, dialect="duckdb").transform(_to_csv).sql(pretty=True),
+            TPCH_SCHEMA,
+            dialect="duckdb",
+        ).fetch()
+        pd.testing.assert_frame_equal(result, expected)
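
A minimal usage sketch of the patched entry point. The call shapes are taken from the two test files above; the table name, file path, and column types in the second call are illustrative assumptions, not part of the patch.

    import xorbits.pandas as xpd

    from xorbits_sql import execute

    # In-memory table, mirroring test_execute.py.
    result = execute(
        "SELECT a, SUM(b) AS s FROM t1 GROUP BY a",
        tables={"t1": xpd.DataFrame({"a": [1, 1, 2], "b": [3.0, 4.0, 5.0]})},
    ).fetch()

    # CSV scan with a schema, mirroring test_tpc_h.py: the mapping is passed
    # through to XorbitsExecutor.schema, and _scan_csv converts the SQL types
    # to pandas dtypes via _SQLGLOT_TYPE_TO_DTYPE (e.g. "bigint" -> "int64")
    # before read_csv. 'orders.csv' and its columns are hypothetical.
    result = execute(
        "SELECT o_custkey FROM READ_CSV('orders.csv', 'delimiter', '|') AS orders",
        {"orders": {"o_custkey": "bigint", "o_totalprice": "double"}},
        dialect="duckdb",
    ).fetch()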