Skip to content

Commit 53a685d

Browse files
committed
Integrate datafusion-distributed with datafusion-python
1 parent a4cf887 commit 53a685d

19 files changed

Lines changed: 1186 additions & 6 deletions

Cargo.lock

Lines changed: 321 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ resolver = "3"
3232

3333
[workspace.dependencies]
3434
tokio = { version = "1.52" }
35+
tonic = { version = "0.14", features = ["transport"] }
3536
pyo3 = { version = "0.28" }
3637
pyo3-async-runtimes = { version = "0.28" }
3738
pyo3-log = "0.13.3"
@@ -50,6 +51,7 @@ datafusion-functions-aggregate = { version = "54" }
5051
datafusion-functions-window = { version = "54" }
5152
datafusion-spark = { version = "54" }
5253
datafusion-expr = { version = "54" }
54+
datafusion-distributed = { version = "2" }
5355
prost = "0.14.3"
5456
serde_json = "1"
5557
uuid = { version = "1.23" }

crates/core/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ tokio = { workspace = true, features = [
4040
"rt-multi-thread",
4141
"sync",
4242
] }
43+
tonic = { workspace = true }
4344
pyo3 = { workspace = true, features = [
4445
"extension-module",
4546
"generate-import-lib",
@@ -54,6 +55,7 @@ datafusion-substrait = { workspace = true, optional = true }
5455
datafusion-proto = { workspace = true }
5556
datafusion-ffi = { workspace = true }
5657
datafusion-spark = { workspace = true }
58+
datafusion-distributed = { workspace = true }
5759
prost = { workspace = true } # keep in line with `datafusion-substrait`
5860
serde_json = { workspace = true }
5961
uuid = { workspace = true, features = ["v4"] }

crates/core/src/context.rs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,11 @@ use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, Unboun
4343
use datafusion::execution::options::{ArrowReadOptions, ReadOptions};
4444
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
4545
use datafusion::execution::session_state::SessionStateBuilder;
46-
use datafusion::execution::{FunctionRegistry, TaskContextProvider};
46+
use datafusion::execution::{FunctionRegistry, SessionState, TaskContextProvider};
4747
use datafusion::prelude::{
4848
AvroReadOptions, CsvReadOptions, DataFrame, JsonReadOptions, ParquetReadOptions,
4949
};
50+
use datafusion_distributed::{DistributedConfig, DistributedExt, SessionStateBuilderExt};
5051
use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
5152
use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
5253
use datafusion_ffi::config::extension_options::FFI_ExtensionOptions;
@@ -78,6 +79,7 @@ use crate::common::data_type::PyScalarValue;
7879
use crate::common::df_schema::PyDFSchema;
7980
use crate::dataframe::PyDataFrame;
8081
use crate::dataset::Dataset;
82+
use crate::distributed_worker_resolver::PyWorkerResolver;
8183
use crate::errors::{
8284
PyDataFusionError, PyDataFusionResult, from_datafusion_error, py_datafusion_err,
8385
};
@@ -219,6 +221,15 @@ impl PySessionConfig {
219221

220222
Ok(Self::from(config))
221223
}
224+
225+
#[pyo3(signature = (worker_resolver))]
226+
fn with_distributed(&self, worker_resolver: PyWorkerResolver) -> Self {
227+
let config = self
228+
.config
229+
.clone()
230+
.with_distributed_worker_resolver(worker_resolver);
231+
Self::from(config)
232+
}
222233
}
223234

224235
/// Runtime options for a SessionContext
@@ -392,13 +403,20 @@ impl PySessionContext {
392403
} else {
393404
RuntimeEnvBuilder::default()
394405
};
406+
let distributed = DistributedConfig::from_config_options(config.options()).is_ok();
407+
395408
let runtime = Arc::new(runtime_env_builder.build()?);
396-
let session_state = SessionStateBuilder::new()
409+
let mut builder = SessionStateBuilder::new()
397410
.with_config(config)
398411
.with_runtime_env(runtime)
399412
.with_default_features()
400-
.with_analyzer_rule(Arc::new(crate::analyzer::ResolveLambdaVariables::new()))
401-
.build();
413+
.with_analyzer_rule(Arc::new(crate::analyzer::ResolveLambdaVariables::new()));
414+
415+
if distributed {
416+
builder = builder.with_distributed_planner();
417+
}
418+
419+
let session_state = builder.build();
402420
let ctx = Arc::new(SessionContext::new_with_state(session_state));
403421
Ok(PySessionContext {
404422
ctx,
@@ -1430,6 +1448,14 @@ impl PySessionContext {
14301448
}
14311449

14321450
impl PySessionContext {
1451+
pub(crate) fn from_session_state(session_state: SessionState) -> Self {
1452+
Self {
1453+
ctx: Arc::new(SessionContext::new_with_state(session_state)),
1454+
logical_codec: Arc::new(PythonLogicalCodec::default()),
1455+
physical_codec: Arc::new(PythonPhysicalCodec::default()),
1456+
}
1457+
}
1458+
14331459
/// Forwards `name` to the wrapped context's `table` method and awaits the
/// resulting `DataFrame`, returning any DataFusion error unchanged.
async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
14341460
self.ctx.table(name).await
14351461
}
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::collections::HashMap;
19+
use std::net::SocketAddr;
20+
21+
use async_trait::async_trait;
22+
use datafusion::common::{DataFusionError, Result as DataFusionResult};
23+
use datafusion::execution::{SessionState, SessionStateBuilder};
24+
use datafusion_distributed::{Worker, WorkerQueryContext, WorkerSessionBuilder};
25+
use datafusion_python_util::wait_for_future;
26+
use pyo3::Borrowed;
27+
use pyo3::exceptions::{PyRuntimeError, PyTypeError};
28+
use pyo3::prelude::*;
29+
use tonic::transport::Server;
30+
31+
use crate::context::PySessionContext;
32+
use crate::errors::{PyDataFusionError, PyDataFusionResult};
33+
34+
#[pyclass(
35+
from_py_object,
36+
frozen,
37+
name = "Worker",
38+
module = "datafusion",
39+
subclass
40+
)]
41+
#[derive(Clone)]
42+
pub struct PyWorker {
43+
worker: Worker,
44+
}
45+
46+
#[pymethods]
47+
impl PyWorker {
48+
#[new]
49+
fn new() -> Self {
50+
Self {
51+
worker: Worker::default(),
52+
}
53+
}
54+
55+
#[staticmethod]
56+
fn from_session_builder(session_builder: PyWorkerSessionBuilder) -> Self {
57+
Self {
58+
worker: Worker::from_session_builder(session_builder),
59+
}
60+
}
61+
62+
fn with_version(&self, version: String) -> Self {
63+
Self {
64+
worker: self.worker.clone().with_version(version),
65+
}
66+
}
67+
68+
fn with_max_message_size(&self, size: usize) -> Self {
69+
Self {
70+
worker: self.worker.clone().with_max_message_size(size),
71+
}
72+
}
73+
74+
#[pyo3(signature = (host = "127.0.0.1", port = 50051))]
75+
fn serve(&self, py: Python<'_>, host: &str, port: u16) -> PyDataFusionResult<()> {
76+
let addr = parse_socket_addr(host, port)?;
77+
let worker = self.worker.clone();
78+
wait_for_future(py, serve_worker(worker, addr))?.map_err(PyDataFusionError::from)
79+
}
80+
81+
#[pyo3(signature = (host = "127.0.0.1", port = 50051))]
82+
fn serve_async<'py>(
83+
&self,
84+
py: Python<'py>,
85+
host: &str,
86+
port: u16,
87+
) -> PyResult<Bound<'py, PyAny>> {
88+
let addr = parse_socket_addr(host, port)?;
89+
let worker = self.worker.clone();
90+
pyo3_async_runtimes::tokio::future_into_py(py, async move {
91+
serve_worker(worker, addr)
92+
.await
93+
.map_err(PyDataFusionError::from)?;
94+
Ok(())
95+
})
96+
}
97+
}
98+
99+
#[pyclass(name = "WorkerQueryContext", module = "datafusion", subclass)]
100+
pub struct PyWorkerQueryContext {
101+
builder: Option<SessionStateBuilder>,
102+
headers: HashMap<String, String>,
103+
}
104+
105+
impl PyWorkerQueryContext {
106+
fn new(ctx: WorkerQueryContext) -> Self {
107+
let headers = ctx
108+
.headers
109+
.iter()
110+
.map(|(name, value)| {
111+
(
112+
name.as_str().to_owned(),
113+
value.to_str().unwrap_or_default().to_owned(),
114+
)
115+
})
116+
.collect();
117+
118+
Self {
119+
builder: Some(ctx.builder),
120+
headers,
121+
}
122+
}
123+
}
124+
125+
#[pymethods]
126+
impl PyWorkerQueryContext {
127+
fn session_context(mut slf: PyRefMut<'_, Self>) -> PyResult<PySessionContext> {
128+
let builder = slf.builder.take().ok_or_else(|| {
129+
PyRuntimeError::new_err("WorkerQueryContext.session_context() can only be called once")
130+
})?;
131+
Ok(PySessionContext::from_session_state(builder.build()))
132+
}
133+
134+
#[getter]
135+
fn headers(&self) -> HashMap<String, String> {
136+
self.headers.clone()
137+
}
138+
}
139+
140+
pub(crate) struct PyWorkerSessionBuilder {
141+
callback: Py<PyAny>,
142+
}
143+
144+
impl FromPyObject<'_, '_> for PyWorkerSessionBuilder {
145+
type Error = PyErr;
146+
147+
fn extract(obj: Borrowed<'_, '_, PyAny>) -> Result<Self, Self::Error> {
148+
if !obj.is_callable() {
149+
return Err(PyTypeError::new_err(
150+
"Expected worker session builder to be callable",
151+
));
152+
}
153+
154+
Ok(Self {
155+
callback: obj.to_owned().unbind(),
156+
})
157+
}
158+
}
159+
160+
#[async_trait]
161+
impl WorkerSessionBuilder for PyWorkerSessionBuilder {
162+
async fn build_session_state(
163+
&self,
164+
ctx: WorkerQueryContext,
165+
) -> Result<SessionState, DataFusionError> {
166+
Python::attach(|py| -> PyResult<SessionState> {
167+
let ctx = Py::new(py, PyWorkerQueryContext::new(ctx))?;
168+
let result = self.callback.call1(py, (ctx,))?;
169+
let session_context = extract_session_context(result.bind(py))?;
170+
Ok(session_context.ctx.state())
171+
})
172+
.map_err(|error| DataFusionError::External(Box::new(error)))
173+
}
174+
}
175+
176+
fn extract_session_context(obj: &Bound<'_, PyAny>) -> PyResult<PySessionContext> {
177+
if let Ok(session_context) = obj.extract::<PySessionContext>() {
178+
return Ok(session_context);
179+
}
180+
181+
if let Ok(ctx_attr) = obj.getattr("ctx")
182+
&& let Ok(session_context) = ctx_attr.extract::<PySessionContext>()
183+
{
184+
return Ok(session_context);
185+
}
186+
187+
Err(PyTypeError::new_err(
188+
"WorkerSessionBuilder.build_session_state() must return a datafusion.SessionContext",
189+
))
190+
}
191+
192+
fn parse_socket_addr(host: &str, port: u16) -> PyDataFusionResult<SocketAddr> {
193+
format!("{host}:{port}").parse().map_err(|error| {
194+
PyDataFusionError::Common(format!(
195+
"invalid worker bind address {host}:{port}: {error}"
196+
))
197+
})
198+
}
199+
200+
async fn serve_worker(worker: Worker, addr: SocketAddr) -> DataFusionResult<()> {
201+
Server::builder()
202+
.add_service(worker.into_worker_server())
203+
.serve(addr)
204+
.await
205+
.map_err(|error| DataFusionError::External(Box::new(error)))
206+
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::common::DataFusionError;
19+
use datafusion_distributed::WorkerResolver;
20+
use pyo3::Borrowed;
21+
use pyo3::exceptions::{PyTypeError, PyValueError};
22+
use pyo3::prelude::*;
23+
use pyo3::types::PyString;
24+
use url::Url;
25+
26+
pub(crate) struct PyWorkerResolver {
27+
get_urls: Py<PyAny>,
28+
}
29+
30+
impl FromPyObject<'_, '_> for PyWorkerResolver {
31+
type Error = PyErr;
32+
33+
fn extract(obj: Borrowed<'_, '_, PyAny>) -> Result<Self, Self::Error> {
34+
let get_urls = obj.getattr("get_urls")?;
35+
if !get_urls.is_callable() {
36+
return Err(PyTypeError::new_err(
37+
"Expected worker_resolver.get_urls to be callable",
38+
));
39+
}
40+
41+
Ok(Self {
42+
get_urls: get_urls.unbind(),
43+
})
44+
}
45+
}
46+
47+
struct WorkerUrls(Vec<Url>);
48+
49+
impl FromPyObject<'_, '_> for WorkerUrls {
50+
type Error = PyErr;
51+
52+
fn extract(obj: Borrowed<'_, '_, PyAny>) -> Result<Self, Self::Error> {
53+
if obj.is_instance_of::<PyString>() {
54+
return Err(PyTypeError::new_err(
55+
"WorkerResolver.get_urls() must return an iterable of URL strings, not a string",
56+
));
57+
}
58+
59+
let mut parsed_urls = Vec::new();
60+
for url in obj.try_iter()? {
61+
let url = url?;
62+
let url = url.extract::<String>()?;
63+
let parsed_url = Url::parse(&url).map_err(|error| {
64+
PyValueError::new_err(format!(
65+
"WorkerResolver.get_urls() returned invalid URL {url:?}: {error}"
66+
))
67+
})?;
68+
parsed_urls.push(parsed_url);
69+
}
70+
71+
Ok(Self(parsed_urls))
72+
}
73+
}
74+
75+
impl WorkerResolver for PyWorkerResolver {
76+
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError> {
77+
Python::attach(|py| -> PyResult<Vec<Url>> {
78+
let urls = self.get_urls.call0(py)?;
79+
let urls = urls.extract::<WorkerUrls>(py)?;
80+
Ok(urls.0)
81+
})
82+
.map_err(|error| DataFusionError::External(Box::new(error)))
83+
}
84+
}

0 commit comments

Comments
 (0)