
Commit 0c3a6a1

Authored by nils-braun, rajagurunath, and Gurunath LankupalliVenugopal
Multiple schemas allowed (#205)
* ML model improvement: adding "SHOW MODELS and DESCRIBE MODEL"
  Author: rajagurunath <[email protected]>
  Date: Mon May 24 02:37:40 2021 +0530
* Fix typo
* ML model improvement: added EXPORT MODEL
* ML model improvement: refactoring for PR
* ML model improvement: adding statements in notebook
* ML model improvement: also test the non-happy path
* ML model improvement: added mlflow and <With> in SQL for extra params
* Added test cases for EXPORT MODEL
* Added ML documentation about the following:
  1. SHOW MODELS
  2. DESCRIBE MODEL
  3. EXPORT MODEL
* Refactored based on PR
* Added support only for sklearn-compatible models
* Excluded mlflow part from code coverage
* Install mlflow in test cluster
* Added test for non-sklearn-compatible model
* Added: initial draft of referencing multiple schemas
* Added schema DDLs:
  1. CREATE SCHEMA
  2. USE SCHEMA
  3. DROP SCHEMA
  4. Added test cases
* Use compound identifiers for models, tables, experiments, views
* Split the compound identifiers - without using the schema so far
* Added a schema_name parameter to most functions and actually use the schema
* Pass on the schemas to Java
* Some simplifications and tests
* Some cleanup, documentation and more tests (and fixed a bug in aggregation)
* Remove unneeded import

Co-authored-by: gurunath <[email protected]>
Co-authored-by: Gurunath LankupalliVenugopal <[email protected]>
1 parent 3815a49 commit 0c3a6a1
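
For orientation, the end-to-end workflow this commit enables looks roughly as follows. This is a hedged sketch, not the project's documented example: Context.sql, create_table and create_schema are confirmed by the diffs below, while the exact USE SCHEMA / DROP SCHEMA spellings are inferred from the SwitchSchemaPlugin and DropSchemaPlugin registered in this commit.

    import pandas as pd
    import dask.dataframe as dd
    from dask_sql import Context

    c = Context()

    # Register a table in a freshly created schema, then query it via a
    # compound identifier - both patterns are introduced by this commit.
    c.sql("CREATE SCHEMA my_schema")
    c.sql("USE SCHEMA my_schema")  # spelling inferred from SwitchSchemaPlugin
    df = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=1)
    c.create_table("my_table", df)

    result = c.sql("SELECT SUM(a) FROM my_schema.my_table").compute()
    print(result)

    c.sql("DROP SCHEMA my_schema")  # spelling inferred from DropSchemaPlugin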


41 files changed: +776, −268 lines

dask_sql/context.py

Lines changed: 158 additions & 79 deletions
Large diffs are not rendered by default.

dask_sql/datacontainer.py

Lines changed: 17 additions & 1 deletion

@@ -1,9 +1,15 @@
-from typing import Dict, List, Tuple, Union
+from collections import namedtuple
+from typing import Any, Callable, Dict, List, Tuple, Union
 
 import dask.dataframe as dd
+import pandas as pd
 
 ColumnType = Union[str, int]
 
+FunctionDescription = namedtuple(
+    "FunctionDescription", ["name", "parameters", "return_type", "aggregation"]
+)
+
 
 class ColumnContainer:
     # Forward declaration
@@ -173,3 +179,13 @@ def assign(self) -> dd.DataFrame:
             }
         )
         return df[self.column_container.columns]
+
+
+class SchemaContainer:
+    def __init__(self, name: str):
+        self.__name__ = name
+        self.tables: Dict[str, DataContainer] = {}
+        self.experiments: Dict[str, pd.DataFrame] = {}
+        self.models: Dict[str, Tuple[Any, List[str]]] = {}
+        self.functions: Dict[str, Callable] = {}
+        self.function_lists: List[FunctionDescription] = []
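
The new SchemaContainer is the central data structure of this change: everything that used to live flat on the Context (tables, experiments, models, functions) now lives on one container per schema name. A minimal sketch of populating one, using only attributes visible in this diff; the schema name "root" as the default is an assumption, and the DataContainer/ColumnContainer wrapping mirrors what the plugins below do.

    import pandas as pd
    import dask.dataframe as dd
    from dask_sql.datacontainer import ColumnContainer, DataContainer, SchemaContainer

    schema = SchemaContainer("root")  # "root" as default schema name is an assumption

    # Tables are stored as DataContainer objects keyed by table name
    df = dd.from_pandas(pd.DataFrame({"a": [1, 2]}), npartitions=1)
    schema.tables["my_table"] = DataContainer(df, ColumnContainer(df.columns))

    # Models are stored as (estimator, training columns) tuples, matching the
    # Dict[str, Tuple[Any, List[str]]] annotation above.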

dask_sql/integrations/fugue.py

Lines changed: 5 additions & 1 deletion

@@ -113,7 +113,11 @@ def median(df:pd.DataFrame) -> pd.DataFrame:
     _global, _local = get_caller_global_local_vars()
 
     dag = FugueSQLWorkflow()
-    dfs = {} if ctx is None else {k: dag.df(v.df) for k, v in ctx.tables.items()}
+    dfs = (
+        {}
+        if ctx is None
+        else {k: dag.df(v.df) for k, v in ctx.schema[ctx.schema_name].tables.items()}
+    )
     result = dag._sql(sql, _global, _local, **dfs)
     dag.run(DaskSQLExecutionEngine(conf=fugue_conf))
 
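
The pattern above, ctx.schema[ctx.schema_name].tables, is how code outside the SQL layer now reaches the active schema's tables; the flat ctx.tables mapping is gone. A one-line helper illustrating the lookup (hypothetical, for illustration only):

    def active_tables(ctx):
        # Tables of whichever schema is currently selected (e.g. via USE SCHEMA)
        return ctx.schema[ctx.schema_name].tables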

dask_sql/java.py

Lines changed: 3 additions & 0 deletions

@@ -90,6 +90,9 @@ def _set_or_check_java_home():
 DaskScalarFunction = com.dask.sql.schema.DaskScalarFunction
 DaskSchema = com.dask.sql.schema.DaskSchema
 RelationalAlgebraGenerator = com.dask.sql.application.RelationalAlgebraGenerator
+RelationalAlgebraGeneratorBuilder = (
+    com.dask.sql.application.RelationalAlgebraGeneratorBuilder
+)
 SqlTypeName = org.apache.calcite.sql.type.SqlTypeName
 ValidationException = org.apache.calcite.tools.ValidationException
 SqlParseException = org.apache.calcite.sql.parser.SqlParseException

dask_sql/physical/rel/custom/__init__.py

Lines changed: 10 additions & 4 deletions

@@ -2,30 +2,36 @@
 from .columns import ShowColumnsPlugin
 from .create_experiment import CreateExperimentPlugin
 from .create_model import CreateModelPlugin
+from .create_schema import CreateSchemaPlugin
 from .create_table import CreateTablePlugin
 from .create_table_as import CreateTableAsPlugin
 from .describe_model import ShowModelParamsPlugin
 from .drop_model import DropModelPlugin
+from .drop_schema import DropSchemaPlugin
 from .drop_table import DropTablePlugin
 from .export_model import ExportModelPlugin
 from .predict import PredictModelPlugin
 from .schemas import ShowSchemasPlugin
 from .show_models import ShowModelsPlugin
+from .switch_schema import SwitchSchemaPlugin
 from .tables import ShowTablesPlugin
 
 __all__ = [
     AnalyzeTablePlugin,
+    CreateExperimentPlugin,
     CreateModelPlugin,
+    CreateSchemaPlugin,
     CreateTableAsPlugin,
     CreateTablePlugin,
     DropModelPlugin,
+    DropSchemaPlugin,
     DropTablePlugin,
+    ExportModelPlugin,
     PredictModelPlugin,
     ShowColumnsPlugin,
+    ShowModelParamsPlugin,
+    ShowModelsPlugin,
     ShowSchemasPlugin,
     ShowTablesPlugin,
-    ShowModelsPlugin,
-    ShowModelParamsPlugin,
-    ExportModelPlugin,
-    CreateExperimentPlugin,
+    SwitchSchemaPlugin,
 ]

dask_sql/physical/rel/custom/analyze.py

Lines changed: 2 additions & 3 deletions

@@ -4,7 +4,6 @@
 from dask_sql.datacontainer import ColumnContainer, DataContainer
 from dask_sql.mappings import python_to_sql_type
 from dask_sql.physical.rel.base import BaseRelPlugin
-from dask_sql.utils import get_table_from_compound_identifier
 
 
 class AnalyzeTablePlugin(BaseRelPlugin):
@@ -28,8 +27,8 @@ class AnalyzeTablePlugin(BaseRelPlugin):
     def convert(
         self, sql: "org.apache.calcite.sql.SqlNode", context: "dask_sql.Context"
     ) -> DataContainer:
-        components = list(map(str, sql.getTableName().names))
-        dc = get_table_from_compound_identifier(context, components)
+        schema_name, name = context.fqn(sql.getTableName())
+        dc = context.schema[schema_name].tables[name]
         columns = list(map(str, sql.getColumnList()))
 
         if not columns:
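
This two-line replacement for get_table_from_compound_identifier recurs in almost every plugin in this commit. The new context.fqn helper (defined in dask_sql/context.py, whose large diff is not rendered above) resolves an optionally schema-qualified identifier into a (schema_name, name) pair; conceptually it behaves like this sketch, where the name and fallback logic are assumptions based on how the result is used here:

    def fqn_sketch(context, components):
        # components comes from a compound identifier, e.g.
        # ["my_schema", "my_table"] for "my_schema.my_table", or
        # ["my_table"] for an unqualified name, which falls back
        # to the currently selected schema.
        if len(components) == 2:
            return components[0], components[1]
        return context.schema_name, components[0]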

dask_sql/physical/rel/custom/columns.py

Lines changed: 2 additions & 3 deletions

@@ -4,7 +4,6 @@
 from dask_sql.datacontainer import ColumnContainer, DataContainer
 from dask_sql.mappings import python_to_sql_type
 from dask_sql.physical.rel.base import BaseRelPlugin
-from dask_sql.utils import get_table_from_compound_identifier
 
 
 class ShowColumnsPlugin(BaseRelPlugin):
@@ -22,8 +21,8 @@ class ShowColumnsPlugin(BaseRelPlugin):
     def convert(
         self, sql: "org.apache.calcite.sql.SqlNode", context: "dask_sql.Context"
     ) -> DataContainer:
-        components = list(map(str, sql.getTable().names))
-        dc = get_table_from_compound_identifier(context, components)
+        schema_name, name = context.fqn(sql.getTable())
+        dc = context.schema[schema_name].tables[name]
 
         cols = dc.column_container.columns
         dtypes = list(map(lambda x: str(python_to_sql_type(x)).lower(), dc.df.dtypes))

dask_sql/physical/rel/custom/create_experiment.py

Lines changed: 7 additions & 3 deletions

@@ -98,10 +98,10 @@ def convert(
         self, sql: "org.apache.calcite.sql.SqlNode", context: "dask_sql.Context"
     ) -> DataContainer:
         select = sql.getSelect()
-        experiment_name = str(sql.getExperimentName())
+        schema_name, experiment_name = context.fqn(sql.getExperimentName())
         kwargs = convert_sql_kwargs(sql.getKwargs())
 
-        if experiment_name in context.experiments:
+        if experiment_name in context.schema[schema_name].experiments:
             if sql.getIfNotExists():
                 return
             elif not sql.getReplace():
@@ -175,6 +175,7 @@ def convert(
                 experiment_name,
                 ParallelPostFit(estimator=search.best_estimator_),
                 X.columns,
+                schema_name=schema_name,
             )
 
         if automl_class:
@@ -198,9 +199,12 @@ def convert(
                 experiment_name,
                 ParallelPostFit(estimator=automl.fitted_pipeline_),
                 X.columns,
+                schema_name=schema_name,
             )
 
-        context.register_experiment(experiment_name, experiment_results=df)
+        context.register_experiment(
+            experiment_name, experiment_results=df, schema_name=schema_name
+        )
         cc = ColumnContainer(df.columns)
         dc = DataContainer(dd.from_pandas(df, npartitions=1), cc)
         return dc

dask_sql/physical/rel/custom/create_model.py

Lines changed: 3 additions & 3 deletions

@@ -105,10 +105,10 @@ def convert(
         self, sql: "org.apache.calcite.sql.SqlNode", context: "dask_sql.Context"
     ) -> DataContainer:
         select = sql.getSelect()
-        model_name = str(sql.getModelName())
+        schema_name, model_name = context.fqn(sql.getModelName())
        kwargs = convert_sql_kwargs(sql.getKwargs())
 
-        if model_name in context.models:
+        if model_name in context.schema[schema_name].models:
             if sql.getIfNotExists():
                 return
             elif not sql.getReplace():
@@ -162,4 +162,4 @@ def convert(
         y = None
 
         model.fit(X, y, **fit_kwargs)
-        context.register_model(model_name, model, X.columns)
+        context.register_model(model_name, model, X.columns, schema_name=schema_name)
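
The same registration is also available programmatically. Below is a hedged sketch using the register_model signature as extended in this hunk; the sklearn estimator is only an example, consistent with the commit note that only sklearn-compatible models are supported.

    import pandas as pd
    from dask_sql import Context
    from sklearn.linear_model import LogisticRegression

    c = Context()
    c.create_schema("my_schema")  # API confirmed by the CreateSchemaPlugin docstring below

    X = pd.DataFrame({"a": [0.0, 1.0, 2.0, 3.0], "b": [1.0, 0.0, 1.0, 0.0]})
    y = pd.Series([0, 0, 1, 1])

    # The training columns are stored next to the model (see SchemaContainer),
    # so a later PREDICT can select exactly these columns again.
    c.register_model(
        "my_model", LogisticRegression().fit(X, y), X.columns, schema_name="my_schema"
    )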

dask_sql/physical/rel/custom/create_schema.py

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+import logging
+
+from dask_sql.datacontainer import DataContainer
+from dask_sql.physical.rel.base import BaseRelPlugin
+
+logger = logging.getLogger(__name__)
+
+
+class CreateSchemaPlugin(BaseRelPlugin):
+    """
+    Create a schema with the given name
+    and register it at the context.
+    The SQL call looks like
+
+        CREATE SCHEMA <schema-name>
+
+    Using this SQL is equivalent to just doing
+
+        context.create_schema(<schema-name>)
+
+    but can also be used without writing a single line of code.
+    Nothing is returned.
+    """
+
+    class_name = "com.dask.sql.parser.SqlCreateSchema"
+
+    def convert(
+        self, sql: "org.apache.calcite.sql.SqlNode", context: "dask_sql.Context"
+    ):
+        schema_name = str(sql.getSchemaName())
+
+        if schema_name in context.schema:
+            if sql.getIfNotExists():
+                return
+            elif not sql.getReplace():
+                raise RuntimeError(
+                    f"A Schema with the name {schema_name} is already present."
+                )
+
+        context.create_schema(schema_name)
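
Given the getIfNotExists()/getReplace() checks in convert, the statement presumably accepts the usual modifiers; a hedged usage sketch follows (the exact grammar lives in com.dask.sql.parser.SqlCreateSchema on the Java side, which is not part of this rendered diff):

    from dask_sql import Context

    c = Context()
    c.sql("CREATE SCHEMA my_schema")
    c.sql("CREATE SCHEMA IF NOT EXISTS my_schema")  # no-op: getIfNotExists() returns early
    c.sql("CREATE OR REPLACE SCHEMA my_schema")     # allowed: getReplace() skips the error
    # A bare re-creation would raise:
    # RuntimeError: A Schema with the name my_schema is already present.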
