This is an automated email from the ASF dual-hosted git repository.
milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git
The following commit(s) were added to refs/heads/main by this push:
new bd45acbf0 feat: (remote) shuffle reader cleanup (#1503)
bd45acbf0 is described below
commit bd45acbf090764b2324d272bca320b3fc3d160bd
Author: Marko Milenković <[email protected]>
AuthorDate: Fri Mar 13 11:54:42 2026 +0000
feat: (remote) shuffle reader cleanup (#1503)
* shuffle (remote) reader cleanup
* fix review comments
* minor
* skip index check if sort shuffle is disabled
---
ballista/core/src/client.rs | 61 +++--
ballista/core/src/config.rs | 40 +--
.../core/src/execution_plans/distributed_query.rs | 8 +-
.../core/src/execution_plans/shuffle_reader.rs | 285 ++++++++-------------
ballista/core/src/extension.rs | 38 ++-
ballista/core/src/utils.rs | 28 +-
examples/examples/standalone-substrait.rs | 2 -
7 files changed, 224 insertions(+), 238 deletions(-)
diff --git a/ballista/core/src/client.rs b/ballista/core/src/client.rs
index c01d431ca..fc771fd14 100644
--- a/ballista/core/src/client.rs
+++ b/ballista/core/src/client.rs
@@ -17,17 +17,11 @@
//! Client API for sending requests to executors.
-use std::collections::HashMap;
-use std::sync::Arc;
-
-use std::{
- convert::{TryFrom, TryInto},
- task::{Context, Poll},
-};
-
use crate::error::{BallistaError, Result as BResult};
+use crate::extension::BallistaConfigGrpcEndpoint;
+use crate::serde::protobuf;
use crate::serde::scheduler::{Action, PartitionId};
-
+use crate::utils::create_grpc_client_endpoint;
use arrow_flight;
use arrow_flight::Ticket;
use arrow_flight::utils::flight_data_to_arrow_batch;
@@ -43,21 +37,23 @@ use datafusion::arrow::{
};
use datafusion::error::DataFusionError;
use datafusion::error::Result;
-
-use crate::extension::BallistaConfigGrpcEndpoint;
-use crate::serde::protobuf;
-
-use crate::utils::create_grpc_client_endpoint;
-
use datafusion::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
use futures::{Stream, StreamExt};
use log::{debug, warn};
use prost::Message;
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::{
+ convert::{TryFrom, TryInto},
+ task::{Context, Poll},
+};
use tonic::{Code, Streaming};
/// Client for interacting with Ballista executors.
#[derive(Clone)]
pub struct BallistaClient {
+ host: String,
+ port: u16,
flight_client: FlightServiceClient<tonic::transport::channel::Channel>,
}
@@ -109,7 +105,11 @@ impl BallistaClient {
debug!("BallistaClient connected OK: {flight_client:?}");
- Ok(Self { flight_client })
+ Ok(Self {
+ flight_client,
+ host: host.to_string(),
+ port,
+ })
}
/// Retrieves a partition from an executor.
@@ -117,13 +117,42 @@ impl BallistaClient {
/// Depending on the value of the `flight_transport` parameter, this
method will utilize either
/// the Arrow Flight protocol for compatibility, or a more efficient
block-based transfer mechanism.
/// The block-based transfer is optimized for performance and reduces
computational overhead on the server.
+ ///
+ /// This method is to be used for direct connection to the executor
holding the required shuffle partition.
pub async fn fetch_partition(
&mut self,
executor_id: &str,
partition_id: &PartitionId,
path: &str,
+ flight_transport: bool,
+ ) -> BResult<SendableRecordBatchStream> {
+ let host = self.host.to_owned();
+ let port = self.port;
+ self.fetch_partition_proxied(
+ executor_id,
+ partition_id,
+ &host,
+ port,
+ path,
+ flight_transport,
+ )
+ .await
+ }
+
+ /// Retrieves a partition from an executor.
+ ///
+ /// Depending on the value of the `flight_transport` parameter, this
method will utilize either
+ /// the Arrow Flight protocol for compatibility, or a more efficient
block-based transfer mechanism.
+ /// The block-based transfer is optimized for performance and reduces
computational overhead on the server.
+ ///
+ /// This method should be used if the request may be proxied.
+ pub async fn fetch_partition_proxied(
+ &mut self,
+ executor_id: &str,
+ partition_id: &PartitionId,
host: &str,
port: u16,
+ path: &str,
flight_transport: bool,
) -> BResult<SendableRecordBatchStream> {
let action = Action::FetchPartition {
diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs
index ca151558c..2fde0a2b2 100644
--- a/ballista/core/src/config.rs
+++ b/ballista/core/src/config.rs
@@ -18,23 +18,19 @@
//! Ballista configuration
-use std::result;
-use std::{collections::HashMap, fmt::Display};
-
use crate::error::{BallistaError, Result};
-
use datafusion::{
arrow::datatypes::DataType, common::config_err, config::ConfigExtension,
};
+use std::result;
+use std::{collections::HashMap, fmt::Display};
/// Configuration key for setting the job name displayed in the web UI.
pub const BALLISTA_JOB_NAME: &str = "ballista.job.name";
/// Configuration key for standalone processing parallelism.
pub const BALLISTA_STANDALONE_PARALLELISM: &str =
"ballista.standalone.parallelism";
-
/// Configuration key for disabling default cache extension node.
pub const BALLISTA_CACHE_NOOP: &str = "ballista.cache.noop";
-
/// Configuration key for maximum concurrent shuffle read requests.
pub const BALLISTA_SHUFFLE_READER_MAX_REQUESTS: &str =
"ballista.shuffle.max_concurrent_read_requests";
@@ -44,7 +40,6 @@ pub const BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ: &str =
/// Configuration key to prefer Flight protocol for remote shuffle reads.
pub const BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT: &str =
"ballista.shuffle.remote_read_prefer_flight";
-
/// max message size for gRPC clients
pub const BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE: &str =
"ballista.grpc_client_max_message_size";
@@ -82,6 +77,8 @@ pub const BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE: &str =
"ballista.shuffle.sort_based.batch_size";
/// Should client employ pull or push job tracking strategy
pub const BALLISTA_CLIENT_PULL: &str = "ballista.client.pull";
+/// Should the client use a TLS connection
+pub const BALLISTA_CLIENT_USE_TLS: &str = "ballista.client.use_tls";
/// Result type for configuration parsing operations.
pub type ParseResult<T> = result::Result<T, String>;
@@ -162,6 +159,10 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String,
ConfigEntry>> = LazyLock::new(||
ConfigEntry::new(BALLISTA_CLIENT_PULL.to_string(),
"Should client employ pull or push job tracking. In
pull mode client will make a request to server in the loop, until job finishes.
Pull mode is kept for legacy clients.".to_string(),
DataType::Boolean,
+ Some(false.to_string())),
+ ConfigEntry::new(BALLISTA_CLIENT_USE_TLS.to_string(),
+ "Should connection between client, scheduler, and
executors use TLS.".to_string(),
+ DataType::Boolean,
Some(false.to_string()))
];
entries
@@ -274,11 +275,6 @@ impl BallistaConfig {
&self.settings
}
- /// Returns the maximum message size for gRPC clients in bytes.
- pub fn default_grpc_client_max_message_size(&self) -> usize {
- self.get_usize_setting(BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE)
- }
-
/// Returns the standalone processing parallelism level.
pub fn default_standalone_parallelism(&self) -> usize {
self.get_usize_setting(BALLISTA_STANDALONE_PARALLELISM)
@@ -290,25 +286,30 @@ impl BallistaConfig {
}
/// Returns the gRPC client connection timeout in seconds.
- pub fn default_grpc_client_connect_timeout_seconds(&self) -> usize {
+ pub fn grpc_client_connect_timeout_seconds(&self) -> usize {
self.get_usize_setting(BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS)
}
/// Returns the gRPC client request timeout in seconds.
- pub fn default_grpc_client_timeout_seconds(&self) -> usize {
+ pub fn grpc_client_timeout_seconds(&self) -> usize {
self.get_usize_setting(BALLISTA_GRPC_CLIENT_TIMEOUT_SECONDS)
}
/// Returns the TCP keep-alive interval for gRPC clients in seconds.
- pub fn default_grpc_client_tcp_keepalive_seconds(&self) -> usize {
+ pub fn grpc_client_tcp_keepalive_seconds(&self) -> usize {
self.get_usize_setting(BALLISTA_GRPC_CLIENT_TCP_KEEPALIVE_SECONDS)
}
/// Returns the HTTP/2 keep-alive interval for gRPC clients in seconds.
- pub fn default_grpc_client_http2_keepalive_interval_seconds(&self) ->
usize {
+ pub fn grpc_client_http2_keepalive_interval_seconds(&self) -> usize {
self.get_usize_setting(BALLISTA_GRPC_CLIENT_HTTP2_KEEPALIVE_INTERVAL_SECONDS)
}
+ /// Returns the maximum message size for gRPC clients in bytes.
+ pub fn grpc_client_max_message_size(&self) -> usize {
+ self.get_usize_setting(BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE)
+ }
+
/// Returns whether the default cache node extension is disabled.
pub fn cache_noop(&self) -> bool {
self.get_bool_setting(BALLISTA_CACHE_NOOP)
@@ -373,6 +374,11 @@ impl BallistaConfig {
self.get_bool_setting(BALLISTA_CLIENT_PULL)
}
+    /// Should the client use TLS to communicate with the Ballista cluster
+ pub fn client_use_tls(&self) -> bool {
+ self.get_bool_setting(BALLISTA_CLIENT_USE_TLS)
+ }
+
fn get_usize_setting(&self, key: &str) -> usize {
if let Some(v) = self.settings.get(key) {
// infallible because we validate all configs in the constructor
@@ -539,7 +545,7 @@ mod tests {
#[test]
fn default_config() -> Result<()> {
let config = BallistaConfig::default();
- assert_eq!(16777216, config.default_grpc_client_max_message_size());
+ assert_eq!(16777216, config.grpc_client_max_message_size());
Ok(())
}
}
diff --git a/ballista/core/src/execution_plans/distributed_query.rs
b/ballista/core/src/execution_plans/distributed_query.rs
index 5f1ad258d..533ee3541 100644
--- a/ballista/core/src/execution_plans/distributed_query.rs
+++ b/ballista/core/src/execution_plans/distributed_query.rs
@@ -253,7 +253,7 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for
DistributedQueryExec<T> {
self.scheduler_url.clone(),
self.session_id.clone(),
query,
- self.config.default_grpc_client_max_message_size(),
+ self.config.grpc_client_max_message_size(),
GrpcClientConfig::from(&self.config),
Arc::new(self.metrics.clone()),
partition,
@@ -280,7 +280,7 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for
DistributedQueryExec<T> {
execute_query_push(
self.scheduler_url.clone(),
query,
- self.config.default_grpc_client_max_message_size(),
+ self.config.grpc_client_max_message_size(),
GrpcClientConfig::from(&self.config),
Arc::new(self.metrics.clone()),
partition,
@@ -717,12 +717,12 @@ async fn fetch_partition(
.await
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
ballista_client
- .fetch_partition(
+ .fetch_partition_proxied(
&metadata.id,
&partition_id.into(),
- &location.path,
host,
port,
+ &location.path,
flight_transport,
)
.await
diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs
b/ballista/core/src/execution_plans/shuffle_reader.rs
index c3676a93b..bc81b2150 100644
--- a/ballista/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/core/src/execution_plans/shuffle_reader.rs
@@ -15,49 +15,46 @@
// specific language governing permissions and limitations
// under the License.
-use async_trait::async_trait;
-use datafusion::arrow::ipc::reader::StreamReader;
-use datafusion::common::stats::Precision;
-use datafusion::physical_plan::coalesce::{LimitedBatchCoalescer,
PushBatchStatus};
-use std::any::Any;
-use std::collections::HashMap;
-use std::fmt::Debug;
-use std::fs::File;
-use std::io::BufReader;
-use std::pin::Pin;
-use std::result;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-
use crate::client::BallistaClient;
+use crate::error::BallistaError;
use crate::execution_plans::sort_shuffle::{
get_index_path, is_sort_shuffle_output, stream_sort_shuffle_partition,
};
use crate::extension::{BallistaConfigGrpcEndpoint, SessionConfigExt};
use crate::serde::scheduler::{PartitionLocation, PartitionStats};
-
+use crate::utils::GrpcClientConfig;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::error::ArrowError;
+use datafusion::arrow::ipc::reader::StreamReader;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::runtime::SpawnedTask;
-
+use datafusion::common::stats::Precision;
use datafusion::error::{DataFusionError, Result};
+use datafusion::execution::context::TaskContext;
+use datafusion::physical_plan::coalesce::{LimitedBatchCoalescer,
PushBatchStatus};
use datafusion::physical_plan::metrics::{
BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
};
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan,
Partitioning,
PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
};
+use datafusion::prelude::SessionConfig;
use futures::{Stream, StreamExt, TryStreamExt, ready};
-
-use crate::error::BallistaError;
-use datafusion::execution::context::TaskContext;
-use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use itertools::Itertools;
use log::{debug, error, trace};
use rand::prelude::SliceRandom;
use rand::rng;
+use std::any::Any;
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::fs::File;
+use std::io::BufReader;
+use std::pin::Pin;
+use std::result;
+use std::sync::Arc;
+use std::task::{Context, Poll};
use tokio::sync::{Semaphore, mpsc};
use tokio_stream::wrappers::ReceiverStream;
@@ -162,17 +159,9 @@ impl ExecutionPlan for ShuffleReaderExec {
debug!("ShuffleReaderExec::execute({task_id})");
let config = context.session_config();
-
- let max_request_num =
- config.ballista_shuffle_reader_maximum_concurrent_requests();
- let max_message_size = config.ballista_grpc_client_max_message_size();
- let force_remote_read =
config.ballista_shuffle_reader_force_remote_read();
- let prefer_flight =
config.ballista_shuffle_reader_remote_prefer_flight();
let batch_size = config.batch_size();
- let customize_endpoint =
config.ballista_override_create_grpc_client_endpoint();
- let use_tls = config.ballista_use_tls();
- if force_remote_read {
+ if config.ballista_shuffle_reader_force_remote_read() {
debug!(
"All shuffle partitions will be read as remote partitions! To
disable this behavior set: `{}=false`",
crate::config::BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ
@@ -180,7 +169,9 @@ impl ExecutionPlan for ShuffleReaderExec {
}
log::debug!(
- "ShuffleReaderExec::execute({task_id}) max_request_num:
{max_request_num}, max_message_size: {max_message_size}"
+ "ShuffleReaderExec::execute({task_id}) max_request_num: {},
max_message_size: {}",
+ config.ballista_shuffle_reader_maximum_concurrent_requests(),
+ config.ballista_grpc_client_max_message_size()
);
let mut partition_locations = HashMap::new();
for p in &self.partition[partition] {
@@ -198,15 +189,7 @@ impl ExecutionPlan for ShuffleReaderExec {
.collect();
// Shuffle partitions for evenly send fetching partition requests to
avoid hot executors within multiple tasks
partition_locations.shuffle(&mut rng());
- let response_receiver = send_fetch_partitions(
- partition_locations,
- max_request_num,
- max_message_size,
- force_remote_read,
- prefer_flight,
- customize_endpoint,
- use_tls,
- );
+ let response_receiver = send_fetch_partitions(partition_locations,
config);
let input_stream = Box::pin(RecordBatchStreamAdapter::new(
self.schema.clone(),
@@ -405,19 +388,19 @@ fn local_remote_read_split(
fn send_fetch_partitions(
partition_locations: Vec<PartitionLocation>,
- max_request_num: usize,
- max_message_size: usize,
- force_remote_read: bool,
- flight_transport: bool,
- customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
- use_tls: bool,
+ config: &SessionConfig,
) -> AbortableReceiverStream {
+ let max_request_num =
config.ballista_shuffle_reader_maximum_concurrent_requests();
+ let sort_shuffle_enabled = config.ballista_sort_shuffle_enabled();
+
let (response_sender, response_receiver) = mpsc::channel(max_request_num);
let semaphore = Arc::new(Semaphore::new(max_request_num));
let mut spawned_tasks: Vec<SpawnedTask<()>> = vec![];
- let (local_locations, remote_locations): (Vec<_>, Vec<_>) =
- local_remote_read_split(partition_locations, force_remote_read);
+ let (local_locations, remote_locations): (Vec<_>, Vec<_>) =
local_remote_read_split(
+ partition_locations,
+ config.ballista_shuffle_reader_force_remote_read(),
+ );
debug!(
"local shuffle file counts:{}, remote shuffle file count:{}.",
@@ -427,46 +410,53 @@ fn send_fetch_partitions(
// keep local shuffle files reading in serial order for memory control.
let response_sender_c = response_sender.clone();
- let customize_endpoint_c = customize_endpoint.clone();
- spawned_tasks.push(SpawnedTask::spawn(async move {
- for p in local_locations {
- let r = PartitionReaderEnum::Local
- .fetch_partition(
- &p,
- max_message_size,
- flight_transport,
- customize_endpoint_c.clone(),
- use_tls,
- )
- .await;
- if let Err(e) = response_sender_c.send(r).await {
- error!("Fail to send response event to the channel due to
{e}");
+
+ //
+ // fetching local partitions (read from file)
+ //
+
+ spawned_tasks.push(SpawnedTask::spawn_blocking({
+ move || {
+ for p in local_locations {
+ let r = fetch_partition_local(&p, sort_shuffle_enabled);
+ if let Err(e) = response_sender_c.blocking_send(r) {
+ error!("Fail to send response event to the channel due to
{e}");
+ }
}
}
}));
+ //
+ // fetching remote partitions (uses grpc flight protocol)
+ //
+ let grpc_config: Arc<GrpcClientConfig> =
Arc::new((&config.ballista_config()).into());
+ let customize_endpoint =
config.ballista_override_create_grpc_client_endpoint();
+ let prefer_flight = config.ballista_shuffle_reader_remote_prefer_flight();
+
for p in remote_locations.into_iter() {
let semaphore = semaphore.clone();
let response_sender = response_sender.clone();
- let customize_endpoint_c = customize_endpoint.clone();
- spawned_tasks.push(SpawnedTask::spawn(async move {
- // Block if exceeds max request number.
- let permit = semaphore.acquire_owned().await.unwrap();
- let r = PartitionReaderEnum::FlightRemote
- .fetch_partition(
+
+ spawned_tasks.push(SpawnedTask::spawn({
+ let customize_endpoint = customize_endpoint.clone();
+ let grpc_config = grpc_config.clone();
+ async move {
+ // Block if exceeds max request number.
+ let permit = semaphore.acquire_owned().await.unwrap();
+ let r = fetch_partition_remote(
&p,
- max_message_size,
- flight_transport,
- customize_endpoint_c,
- use_tls,
+ grpc_config,
+ prefer_flight,
+ customize_endpoint,
)
.await;
- // Block if the channel buffer is full.
- if let Err(e) = response_sender.send(r).await {
- error!("Fail to send response event to the channel due to
{e}");
+ // Block if the channel buffer is full.
+ if let Err(e) = response_sender.send(r).await {
+ error!("Fail to send response event to the channel due to
{e}");
+ }
+ // Increase semaphore by dropping existing permits.
+ drop(permit);
}
- // Increase semaphore by dropping existing permits.
- drop(permit);
}));
}
@@ -477,112 +467,68 @@ fn check_is_local_location(location: &PartitionLocation)
-> bool {
std::path::Path::new(location.path.as_str()).exists()
}
-/// Partition reader Trait, different partition reader can have
-#[async_trait]
-trait PartitionReader: Send + Sync + Clone {
- // Read partition data from PartitionLocation
- async fn fetch_partition(
- &self,
- location: &PartitionLocation,
- max_message_size: usize,
- flight_transport: bool,
- customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
- use_tls: bool,
- ) -> result::Result<SendableRecordBatchStream, BallistaError>;
-}
-
-#[derive(Clone)]
-enum PartitionReaderEnum {
- Local,
- FlightRemote,
- #[allow(dead_code)]
- ObjectStoreRemote,
-}
+async fn new_ballista_client(
+ host: &str,
+ port: u16,
+ config: &GrpcClientConfig,
+ customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
+) -> result::Result<BallistaClient, BallistaError> {
+ let max_message_size = config.max_message_size;
+ let use_tls = config.use_tls;
-#[async_trait]
-impl PartitionReader for PartitionReaderEnum {
- // Notice return `BallistaError::FetchFailed` will let scheduler
re-schedule the task.
- async fn fetch_partition(
- &self,
- location: &PartitionLocation,
- max_message_size: usize,
- flight_transport: bool,
- customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
- use_tls: bool,
- ) -> result::Result<SendableRecordBatchStream, BallistaError> {
- match self {
- PartitionReaderEnum::FlightRemote => {
- fetch_partition_remote(
- location,
- max_message_size,
- flight_transport,
- customize_endpoint,
- use_tls,
- )
- .await
- }
- PartitionReaderEnum::Local =>
fetch_partition_local(location).await,
- PartitionReaderEnum::ObjectStoreRemote => {
- fetch_partition_object_store(location).await
- }
- }
- }
+ BallistaClient::try_new(host, port, max_message_size, use_tls,
customize_endpoint)
+ .await
}
async fn fetch_partition_remote(
location: &PartitionLocation,
- max_message_size: usize,
- flight_transport: bool,
+ config: Arc<GrpcClientConfig>,
+ prefer_flight: bool,
customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
- use_tls: bool,
) -> result::Result<SendableRecordBatchStream, BallistaError> {
let metadata = &location.executor_meta;
let partition_id = &location.partition_id;
- // TODO for shuffle client connections, we should avoid creating new
connections again and again.
- // And we should also avoid to keep alive too many connections for long
time.
let host = metadata.host.as_str();
let port = metadata.port;
- let mut ballista_client = BallistaClient::try_new(
- host,
- port,
- max_message_size,
- use_tls,
- customize_endpoint,
- )
- .await
- .map_err(|error| match error {
- // map grpc connection error to partition fetch error.
- BallistaError::GrpcConnectionError(msg) => BallistaError::FetchFailed(
- metadata.id.clone(),
- partition_id.stage_id,
- partition_id.partition_id,
- msg,
- ),
- other => other,
- })?;
+
+ // TODO for shuffle client connections, we should avoid creating new
connections again and again.
+ // And we should also avoid to keep alive too many connections for long
time.
+ let mut ballista_client =
+ new_ballista_client(host, port, &config, customize_endpoint)
+ .await
+ .map_err(|error| match error {
+ // map grpc connection error to partition fetch error.
+ BallistaError::GrpcConnectionError(msg) =>
BallistaError::FetchFailed(
+ metadata.id.clone(),
+ partition_id.stage_id,
+ partition_id.partition_id,
+ msg,
+ ),
+ other => other,
+ })?;
ballista_client
- .fetch_partition(
- &metadata.id,
- partition_id,
- &location.path,
- host,
- port,
- flight_transport,
- )
+ .fetch_partition(&metadata.id, partition_id, &location.path,
prefer_flight)
.await
}
-async fn fetch_partition_local(
+fn fetch_partition_local(
location: &PartitionLocation,
+ sort_shuffle_enabled: bool,
) -> result::Result<SendableRecordBatchStream, BallistaError> {
let path = &location.path;
let metadata = &location.executor_meta;
let partition_id = &location.partition_id;
let data_path = std::path::Path::new(path);
+    // TODO: we check if the file is there, then we open it later;
+ // replace this check with open, and check for error
+ //
// Check if this is a sort-based shuffle output (has index file)
- if is_sort_shuffle_output(data_path) {
+ if sort_shuffle_enabled && is_sort_shuffle_output(data_path) {
+        // note: in some cases sort shuffle is not going to be used
+        // even if it's enabled; thus we need to check whether the
+        // sort shuffle index file exists
debug!(
"Reading sort-based shuffle for partition {} from {:?}",
partition_id.partition_id, data_path
@@ -622,7 +568,8 @@ fn fetch_partition_local_inner(
let file = File::open(path).map_err(|e| {
BallistaError::General(format!("Failed to open partition file at
{path}: {e:?}"))
})?;
- let file = BufReader::new(file);
+ // TODO: make this configurable
+ let file = BufReader::with_capacity(256 * 1024, file);
// Safety: setting `skip_validation` requires `unsafe`, user assures data
is valid
let reader = unsafe {
StreamReader::try_new(file, None)
@@ -637,14 +584,6 @@ fn fetch_partition_local_inner(
Ok(reader)
}
-async fn fetch_partition_object_store(
- _location: &PartitionLocation,
-) -> result::Result<SendableRecordBatchStream, BallistaError> {
- Err(BallistaError::NotImplemented(
- "Should not use ObjectStorePartitionReader".to_string(),
- ))
-}
-
struct CoalescedShuffleReaderStream {
schema: SchemaRef,
input: SendableRecordBatchStream,
@@ -1120,16 +1059,10 @@ mod tests {
partition_num,
file_path.to_str().unwrap().to_string(),
);
+ let config = SessionConfig::new_with_ballista()
+
.with_ballista_shuffle_reader_maximum_concurrent_requests(max_request_num);
- let response_receiver = send_fetch_partitions(
- partition_locations,
- max_request_num,
- 4 * 1024 * 1024,
- false,
- true,
- None,
- false,
- );
+ let response_receiver = send_fetch_partitions(partition_locations,
&config);
let stream = RecordBatchStreamAdapter::new(
Arc::new(schema),
diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs
index 2c777097c..f8292645d 100644
--- a/ballista/core/src/extension.rs
+++ b/ballista/core/src/extension.rs
@@ -16,7 +16,7 @@
// under the License.
use crate::config::{
- BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME,
+ BALLISTA_CLIENT_USE_TLS, BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE,
BALLISTA_JOB_NAME,
BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ,
BALLISTA_SHUFFLE_READER_MAX_REQUESTS,
BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT,
BALLISTA_STANDALONE_PARALLELISM,
BallistaConfig,
@@ -233,6 +233,9 @@ pub trait SessionConfigExt {
/// Get whether to use TLS for executor connections
fn ballista_use_tls(&self) -> bool;
+
+    /// Is sort shuffle enabled
+ fn ballista_sort_shuffle_enabled(&self) -> bool;
}
/// [SessionConfigHelperExt] is set of [SessionConfig] extension methods
@@ -392,10 +395,8 @@ impl SessionConfigExt for SessionConfig {
self.options()
.extensions
.get::<BallistaConfig>()
- .map(|c| c.default_grpc_client_max_message_size())
- .unwrap_or_else(|| {
-
BallistaConfig::default().default_grpc_client_max_message_size()
- })
+ .map(|c| c.grpc_client_max_message_size())
+ .unwrap_or_else(||
BallistaConfig::default().grpc_client_max_message_size())
}
fn with_ballista_job_name(self, job_name: &str) -> Self {
@@ -435,6 +436,14 @@ impl SessionConfigExt for SessionConfig {
})
}
+ fn ballista_sort_shuffle_enabled(&self) -> bool {
+ self.options()
+ .extensions
+ .get::<BallistaConfig>()
+ .map(|c| c.shuffle_sort_based_enabled())
+ .unwrap_or_else(||
BallistaConfig::default().shuffle_sort_based_enabled())
+ }
+
fn with_ballista_shuffle_reader_maximum_concurrent_requests(
self,
max_requests: usize,
@@ -538,13 +547,20 @@ impl SessionConfigExt for SessionConfig {
}
fn with_ballista_use_tls(self, use_tls: bool) -> Self {
- self.with_extension(Arc::new(BallistaUseTls(use_tls)))
+ if self.options().extensions.get::<BallistaConfig>().is_some() {
+ self.set_bool(BALLISTA_CLIENT_USE_TLS, use_tls)
+ } else {
+ self.with_option_extension(BallistaConfig::default())
+ .set_bool(BALLISTA_CLIENT_USE_TLS, use_tls)
+ }
}
fn ballista_use_tls(&self) -> bool {
- self.get_extension::<BallistaUseTls>()
- .map(|ext| ext.0)
- .unwrap_or(false)
+ self.options()
+ .extensions
+ .get::<BallistaConfig>()
+ .map(|c| c.client_use_tls())
+ .unwrap_or_else(|| BallistaConfig::default().client_use_tls())
}
}
@@ -746,10 +762,6 @@ impl BallistaConfigGrpcEndpoint {
}
}
-/// Wrapper for cluster-wide TLS configuration
-#[derive(Clone, Copy)]
-pub struct BallistaUseTls(pub bool);
-
#[derive(Debug)]
struct BallistaCacheFactory;
diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs
index 0d6a3d833..f71c74d72 100644
--- a/ballista/core/src/utils.rs
+++ b/ballista/core/src/utils.rs
@@ -63,19 +63,23 @@ pub struct GrpcClientConfig {
pub tcp_keepalive_seconds: u64,
/// HTTP/2 keep-alive ping interval in seconds
pub http2_keepalive_interval_seconds: u64,
+    /// Should the client use TLS
+ pub use_tls: bool,
+    /// Maximum message size for gRPC clients in bytes.
+ pub max_message_size: usize,
}
impl From<&BallistaConfig> for GrpcClientConfig {
fn from(config: &BallistaConfig) -> Self {
Self {
- connect_timeout_seconds:
config.default_grpc_client_connect_timeout_seconds()
- as u64,
- timeout_seconds: config.default_grpc_client_timeout_seconds() as
u64,
- tcp_keepalive_seconds:
config.default_grpc_client_tcp_keepalive_seconds()
- as u64,
+ connect_timeout_seconds:
config.grpc_client_connect_timeout_seconds() as u64,
+ timeout_seconds: config.grpc_client_timeout_seconds() as u64,
+ tcp_keepalive_seconds: config.grpc_client_tcp_keepalive_seconds()
as u64,
http2_keepalive_interval_seconds: config
- .default_grpc_client_http2_keepalive_interval_seconds()
+ .grpc_client_http2_keepalive_interval_seconds()
as u64,
+ use_tls: config.client_use_tls(),
+ max_message_size: config.grpc_client_max_message_size(),
}
}
}
@@ -87,6 +91,8 @@ impl Default for GrpcClientConfig {
timeout_seconds: 20,
tcp_keepalive_seconds: 3600,
http2_keepalive_interval_seconds: 300,
+ use_tls: false,
+ max_message_size: 16 * 1024 * 1024,
}
}
}
@@ -312,19 +318,19 @@ mod tests {
// Verify the conversion picks up the right values
assert_eq!(
grpc_config.connect_timeout_seconds,
- ballista_config.default_grpc_client_connect_timeout_seconds() as
u64
+ ballista_config.grpc_client_connect_timeout_seconds() as u64
);
assert_eq!(
grpc_config.timeout_seconds,
- ballista_config.default_grpc_client_timeout_seconds() as u64
+ ballista_config.grpc_client_timeout_seconds() as u64
);
assert_eq!(
grpc_config.tcp_keepalive_seconds,
- ballista_config.default_grpc_client_tcp_keepalive_seconds() as u64
+ ballista_config.grpc_client_tcp_keepalive_seconds() as u64
);
assert_eq!(
grpc_config.http2_keepalive_interval_seconds,
-
ballista_config.default_grpc_client_http2_keepalive_interval_seconds() as u64
+ ballista_config.grpc_client_http2_keepalive_interval_seconds() as
u64
);
}
@@ -335,6 +341,8 @@ mod tests {
timeout_seconds: 30,
tcp_keepalive_seconds: 1800,
http2_keepalive_interval_seconds: 150,
+ use_tls: false,
+ max_message_size: 16 * 1024 * 1024,
};
let result = create_grpc_client_endpoint("http://localhost:50051",
Some(&config));
assert!(result.is_ok());
diff --git a/examples/examples/standalone-substrait.rs
b/examples/examples/standalone-substrait.rs
index 7e8b2036c..95caa8370 100644
--- a/examples/examples/standalone-substrait.rs
+++ b/examples/examples/standalone-substrait.rs
@@ -416,8 +416,6 @@ impl SubstraitSchedulerClient {
&metadata.id,
&partition_id.into(),
&location.path,
- host,
- port,
flight_transport,
)
.await
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]