
Commit e0cbf48

Implement PyArrow Dataset TableProvider (#9)
* Implement PyArrow Dataset TableProvider and register_dataset context functions.
* Add dataset filter test.
* Change match on booleans to if/else.
* Update Dataset TableProvider for changes to the TableProvider trait in DataFusion 10.0.0.
* Fixes to build with DataFusion 10.0.0.
* Improved DatasetExec physical plan printing.
* Added nested filter test.
1 parent a1e1e97 commit e0cbf48
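
Taken together, these changes let a pyarrow.dataset.Dataset be registered with the Python SessionContext and queried with SQL, with translatable filters pushed down into the scan. A minimal usage sketch based on the tests added below; it assumes the Python entry point is datafusion.SessionContext (the tests themselves use a ctx fixture):

# Usage sketch (not part of the commit); assumes datafusion.SessionContext
# is the context class provided by the package.
import pyarrow as pa
import pyarrow.dataset as ds
from datafusion import SessionContext

ctx = SessionContext()

# Wrap an in-memory RecordBatch in a pyarrow Dataset and register it as table "t".
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
    names=["a", "b"],
)
ctx.register_dataset("t", ds.dataset([batch]))

# The WHERE clause can be translated to a pyarrow filter expression,
# so it is pushed down into the DatasetExec scan.
result = ctx.sql("SELECT a + b FROM t WHERE b > 5").collect()
assert result[0].column(0) == pa.array([9])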

File tree

10 files changed (+695, -1 lines changed)

Cargo.lock

Lines changed: 2 additions & 0 deletions
(Generated file; diff not rendered.)

Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,8 @@ datafusion-expr = { version = "^10.0.0" }
 datafusion-common = { version = "^10.0.0", features = ["pyarrow"] }
 uuid = { version = "0.8", features = ["v4"] }
 mimalloc = { version = "*", optional = true, default-features = false }
+async-trait = "0.1"
+futures = "0.3"
 
 [lib]
 name = "datafusion_python"

datafusion/tests/test_context.py

Lines changed: 70 additions & 0 deletions
@@ -16,6 +16,9 @@
 # under the License.
 
 import pyarrow as pa
+import pyarrow.dataset as ds
+
+from datafusion import column, literal
 
 
 def test_register_record_batches(ctx):
@@ -72,3 +75,70 @@ def test_deregister_table(ctx, database):
 
     ctx.deregister_table("csv")
     assert public.names() == {"csv1", "csv2"}
+
+def test_register_dataset(ctx):
+    # create a RecordBatch and register it as a pyarrow.dataset.Dataset
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    dataset = ds.dataset([batch])
+    ctx.register_dataset("t", dataset)
+
+    assert ctx.tables() == {"t"}
+
+    result = ctx.sql("SELECT a+b, a-b FROM t").collect()
+
+    assert result[0].column(0) == pa.array([5, 7, 9])
+    assert result[0].column(1) == pa.array([-3, -3, -3])
+
+def test_dataset_filter(ctx, capfd):
+    # create a RecordBatch and register it as a pyarrow.dataset.Dataset
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    dataset = ds.dataset([batch])
+    ctx.register_dataset("t", dataset)
+
+    assert ctx.tables() == {"t"}
+    df = ctx.sql("SELECT a+b, a-b FROM t WHERE a BETWEEN 2 and 3 AND b > 5")
+
+    # Make sure the filter was pushed down in Physical Plan
+    df.explain()
+    captured = capfd.readouterr()
+    assert "filter_expr=(((2 <= a) and (a <= 3)) and (b > 5))" in captured.out
+
+    result = df.collect()
+
+    assert result[0].column(0) == pa.array([9])
+    assert result[0].column(1) == pa.array([-3])
+
+
+def test_dataset_filter_nested_data(ctx):
+    # create Arrow StructArrays to test nested data types
+    data = pa.StructArray.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    batch = pa.RecordBatch.from_arrays(
+        [data],
+        names=["nested_data"],
+    )
+    dataset = ds.dataset([batch])
+    ctx.register_dataset("t", dataset)
+
+    assert ctx.tables() == {"t"}
+
+    df = ctx.table("t")
+
+    # This filter will not be pushed down to DatasetExec since it isn't supported
+    df = df.select(
+        column("nested_data")["a"] + column("nested_data")["b"],
+        column("nested_data")["a"] - column("nested_data")["b"],
+    ).filter(column("nested_data")["b"] > literal(5))
+
+    result = df.collect()
+
+    assert result[0].column(0) == pa.array([9])
+    assert result[0].column(1) == pa.array([-3])

datafusion/tests/test_sql.py

Lines changed: 12 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 import numpy as np
 import pyarrow as pa
+import pyarrow.dataset as ds
 import pytest
 
 from datafusion import udf
@@ -121,6 +122,17 @@ def test_register_parquet_partitioned(ctx, tmp_path):
     rd = result.to_pydict()
     assert dict(zip(rd["grp"], rd["cnt"])) == {"a": 3, "b": 1}
 
+def test_register_dataset(ctx, tmp_path):
+    path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
+    dataset = ds.dataset(path, format="parquet")
+
+    ctx.register_dataset("t", dataset)
+    assert ctx.tables() == {"t"}
+
+    result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect()
+    result = pa.Table.from_batches(result)
+    assert result.to_pydict() == {"cnt": [100]}
+
 
 def test_execute(ctx, tmp_path):
     data = [1, 1, 2, 2, 3, 11, 12]

src/context.rs

Lines changed: 13 additions & 0 deletions
@@ -25,12 +25,14 @@ use pyo3::prelude::*;
 
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::datasource::TableProvider;
 use datafusion::datasource::MemTable;
 use datafusion::execution::context::SessionContext;
 use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
 
 use crate::catalog::{PyCatalog, PyTable};
 use crate::dataframe::PyDataFrame;
+use crate::dataset::Dataset;
 use crate::errors::DataFusionError;
 use crate::udf::PyScalarUDF;
 use crate::utils::wait_for_future;
@@ -173,6 +175,17 @@ impl PySessionContext {
         Ok(())
     }
 
+    // Registers a PyArrow.Dataset
+    fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
+        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
+
+        self.ctx
+            .register_table(name, table)
+            .map_err(DataFusionError::from)?;
+
+        Ok(())
+    }
+
     fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> {
         self.ctx.register_udf(udf.function);
         Ok(())

src/dataset.rs

Lines changed: 118 additions & 0 deletions
New file — full contents:

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use pyo3::exceptions::PyValueError;
/// Implements a Datafusion TableProvider that delegates to a PyArrow Dataset
/// This allows us to use PyArrow Datasets as Datafusion tables while pushing down projections and filters
use pyo3::prelude::*;
use pyo3::types::PyType;

use std::any::Any;
use std::sync::Arc;

use async_trait::async_trait;

use datafusion::arrow::datatypes::SchemaRef;
use datafusion::datasource::datasource::TableProviderFilterPushDown;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::{DataFusionError, Result as DFResult};
use datafusion::execution::context::SessionState;
use datafusion::logical_plan::*;
use datafusion::physical_plan::ExecutionPlan;

use crate::dataset_exec::DatasetExec;
use crate::pyarrow_filter_expression::PyArrowFilterExpression;

// Wraps a pyarrow.dataset.Dataset class and implements a Datafusion TableProvider around it
#[derive(Debug, Clone)]
pub(crate) struct Dataset {
    dataset: PyObject,
}

impl Dataset {
    // Creates a Python PyArrow.Dataset
    pub fn new(dataset: &PyAny, py: Python) -> PyResult<Self> {
        // Ensure that we were passed an instance of pyarrow.dataset.Dataset
        let ds = PyModule::import(py, "pyarrow.dataset")?;
        let ds_type: &PyType = ds.getattr("Dataset")?.downcast()?;
        if dataset.is_instance(ds_type)? {
            Ok(Dataset {
                dataset: dataset.into(),
            })
        } else {
            Err(PyValueError::new_err(
                "dataset argument must be a pyarrow.dataset.Dataset object",
            ))
        }
    }
}

#[async_trait]
impl TableProvider for Dataset {
    /// Returns the table provider as [`Any`](std::any::Any) so that it can be
    /// downcast to a specific implementation.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Get a reference to the schema for this table
    fn schema(&self) -> SchemaRef {
        Python::with_gil(|py| {
            let dataset = self.dataset.as_ref(py);
            // This can panic but since we checked that self.dataset is a pyarrow.dataset.Dataset it should never
            Arc::new(dataset.getattr("schema").unwrap().extract().unwrap())
        })
    }

    /// Get the type of this table for metadata/catalog purposes.
    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Create an ExecutionPlan that will scan the table.
    /// The table provider will be usually responsible of grouping
    /// the source data into partitions that can be efficiently
    /// parallelized or distributed.
    async fn scan(
        &self,
        _ctx: &SessionState,
        projection: &Option<Vec<usize>>,
        filters: &[Expr],
        // limit can be used to reduce the amount scanned
        // from the datasource as a performance optimization.
        // If set, it contains the amount of rows needed by the `LogicalPlan`,
        // The datasource should return *at least* this number of rows if available.
        _limit: Option<usize>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        Python::with_gil(|py| {
            let plan: Arc<dyn ExecutionPlan> = Arc::new(
                DatasetExec::new(py, self.dataset.as_ref(py), projection.clone(), filters)
                    .map_err(|err| DataFusionError::External(Box::new(err)))?,
            );
            Ok(plan)
        })
    }

    /// Tests whether the table provider can make use of a filter expression
    /// to optimise data retrieval.
    fn supports_filter_pushdown(&self, filter: &Expr) -> DFResult<TableProviderFilterPushDown> {
        match PyArrowFilterExpression::try_from(filter) {
            Ok(_) => Ok(TableProviderFilterPushDown::Exact),
            _ => Ok(TableProviderFilterPushDown::Unsupported),
        }
    }
}
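
As supports_filter_pushdown above shows, a predicate is pushed down exactly when PyArrowFilterExpression can translate it; otherwise DataFusion keeps the filter above the scan. The pushdown can be observed from Python through the physical plan, as the new test_dataset_filter does. A short sketch continuing the example near the top of this page (same ctx and table "t"; illustrative only):

# Continues the earlier sketch; hypothetical query over table "t".
df = ctx.sql("SELECT a + b, a - b FROM t WHERE a BETWEEN 2 AND 3 AND b > 5")
df.explain()  # the DatasetExec line of the printed physical plan includes a filter_expr=... entry
print(df.collect()[0].column(0))  # expected: [9]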
