Jefffrey commented on code in PR #21710:
URL: https://github.com/apache/datafusion/pull/21710#discussion_r3162595398
##########
datafusion/spark/src/function/math/ceil.rs:
##########
@@ -18,26 +18,33 @@
use std::sync::Arc;
use arrow::array::{ArrowNativeTypeOp, AsArray, Decimal128Array};
-use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type,
Int64Type};
-use datafusion_common::utils::take_function_args;
-use datafusion_common::{Result, ScalarValue, exec_err};
+use arrow::datatypes::{
+ DataType, Decimal128Type, Float32Type, Float64Type, Int8Type, Int16Type,
Int32Type,
+ Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
+};
+use datafusion_common::types::{NativeType, logical_int32};
+use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err};
use datafusion_expr::{
- ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+ Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignature,
+ TypeSignatureClass, Volatility,
};
+use super::scale::get_scale;
+
/// Spark-compatible `ceil` expression
/// <https://spark.apache.org/docs/latest/api/sql/index.html#ceil>
///
/// Differences with DataFusion ceil:
-/// - Spark's ceil returns Int64 for float inputs; DataFusion preserves
+/// - Spark's 1-arg `ceil` returns Int64 for float inputs; DataFusion
preserves
/// the input type (Float32→Float32, Float64→Float64)
-/// - Spark's ceil on Decimal128(p, s) returns Decimal128(p−s+1, 0), reducing
scale
+/// - Spark's `ceil` on Decimal128(p, s) returns Decimal128(p−s+1, 0),
reducing scale
/// to 0; DataFusion preserves the original precision and scale
/// - Spark only supports Decimal128; DataFusion also supports
Decimal32/64/256
/// - Spark does not check for decimal overflow; DataFusion errors on overflow
///
-/// 2-argument ceil(value, scale) is not yet implemented
-/// <https://github.com/apache/datafusion/issues/21560>
+/// Two-argument form `ceil(expr, scale)` returns the same type as `expr` for
+/// integer and floating-point inputs (regardless of `scale`). Decimal inputs
+/// are not yet supported in the 2-arg form (see TODO in execution path).
Review Comment:
I'd prefer that the docstring not mention that decimals are unsupported on the
2-arg path; that should just be a TODO item in the code, for clear visibility.
##########
datafusion/spark/src/function/math/ceil.rs:
##########
@@ -69,7 +90,12 @@ impl ScalarUDFImpl for SparkCeil {
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ let has_scale = arg_types.len() == 2;
+
match &arg_types[0] {
+ // 2-arg decimal is not yet supported; report input type so the
planner
+ // does not reject the call before we surface a proper error at
execution.
Review Comment:
If this is not supported, we should return an error at planning time rather
than at execution time.
##########
datafusion/spark/src/function/math/scale.rs:
##########
@@ -0,0 +1,65 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Shared helpers for scale-taking math functions (`round`, `ceil`, ...).
+
+use datafusion_common::{Result, ScalarValue, exec_err};
+use datafusion_expr::ColumnarValue;
+
+/// Extract the `scale` (decimal places) argument from `args[1]`.
+///
+/// - Returns `Some(0)` when the function is invoked with a single argument
+/// (no scale provided), which matches Spark's default scale of `0`.
+/// - Returns `Some(value)` for any non-NULL signed/unsigned integer scalar.
+/// - Returns `None` when the scale argument is NULL — Spark returns NULL for
+/// `round(expr, NULL)` / `ceil(expr, NULL)`.
+/// - Returns an error for unsupported types or out-of-range integers.
+pub(crate) fn get_scale(fn_name: &str, args: &[ColumnarValue]) ->
Result<Option<i32>> {
+ if args.len() < 2 {
+ return Ok(Some(0));
+ }
+
+ match &args[1] {
+ ColumnarValue::Scalar(ScalarValue::Int8(Some(v))) =>
Ok(Some(i32::from(*v))),
+ ColumnarValue::Scalar(ScalarValue::Int16(Some(v))) =>
Ok(Some(i32::from(*v))),
+ ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => Ok(Some(*v)),
+ ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => i32::try_from(*v)
+ .map(Some)
+ .map_err(|_| out_of_range_err(fn_name, *v)),
+ ColumnarValue::Scalar(ScalarValue::UInt8(Some(v))) =>
Ok(Some(i32::from(*v))),
+ ColumnarValue::Scalar(ScalarValue::UInt16(Some(v))) =>
Ok(Some(i32::from(*v))),
+ ColumnarValue::Scalar(ScalarValue::UInt32(Some(v))) =>
i32::try_from(*v)
+ .map(Some)
+ .map_err(|_| out_of_range_err(fn_name, *v)),
+ ColumnarValue::Scalar(ScalarValue::UInt64(Some(v))) =>
i32::try_from(*v)
+ .map(Some)
+ .map_err(|_| out_of_range_err(fn_name, *v)),
+ ColumnarValue::Scalar(sv) if sv.is_null() => Ok(None),
+ other => exec_err!(
+ "Unsupported type for {fn_name} scale: {}",
+ other.data_type()
+ ),
+ }
+}
+
+fn out_of_range_err<T: std::fmt::Display>(
Review Comment:
We can use `exec_datafusion_err` like so:
https://github.com/apache/datafusion/blob/5fda21683ed73f89b06431fac94777d6b1540b02/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs#L289-L291
##########
datafusion/spark/src/function/math/ceil.rs:
##########
@@ -121,11 +168,74 @@ fn decimal128_ceil_precision(precision: u8, scale: i8) ->
u8 {
((precision as i64) - (scale as i64) + 1).clamp(1, 38) as u8
}
-fn spark_ceil_scalar(value: &ScalarValue) -> Result<ColumnarValue> {
+/// Compute the ceiling of an integer at the given decimal `scale`.
+///
+/// - `scale >= 0`: integers have no fractional part; returns `value`
unchanged.
+/// - `scale < 0`: rounds *up* to the nearest multiple of `10^(-scale)`.
+/// For positive values this rounds away from zero; for negative values it
+/// rounds toward zero (ceiling = toward +∞).
+///
+/// If `10^(-scale)` overflows `i64`, returns `0` (matching Spark / `round`).
+fn ceil_integer(value: i64, scale: i32) -> i64 {
+ if scale >= 0 {
+ return value;
+ }
+ let abs_scale = (-scale) as u32;
+ let Some(factor) = 10_i64.checked_pow(abs_scale) else {
+ return 0;
+ };
+ let remainder = value % factor;
+ if remainder > 0 {
+ // Positive remainder: bump up to the next multiple.
+ value.wrapping_sub(remainder).wrapping_add(factor)
+ } else if remainder < 0 {
+ // Negative remainder (negative value not on boundary): truncating
toward
+ // zero already gives the ceiling.
+ value.wrapping_sub(remainder)
+ } else {
+ value
+ }
+}
+
+fn spark_ceil_scalar(
+ value: &ScalarValue,
+ scale: i32,
+ has_scale: bool,
+) -> Result<ColumnarValue> {
let result = match value {
+ // Floats: 2-arg form preserves the input float type for *every* scale,
+ // so the runtime type matches what `return_type` advertised.
+ ScalarValue::Float32(v) if has_scale => {
+ ScalarValue::Float32(v.map(|x| ceil_float(x, scale)))
+ }
+ ScalarValue::Float64(v) if has_scale => {
+ ScalarValue::Float64(v.map(|x| ceil_float(x, scale)))
+ }
+ // 1-arg float: Spark returns Int64.
ScalarValue::Float32(v) => ScalarValue::Int64(v.map(|x| x.ceil() as
i64)),
ScalarValue::Float64(v) => ScalarValue::Int64(v.map(|x| x.ceil() as
i64)),
+ // Integers: the 1-arg path was a plain `cast_to(Int64)`. The 2-arg
path
+ // additionally applies a (possibly negative) scale before casting.
+ v if v.data_type().is_integer() && has_scale => {
+ match v.cast_to(&DataType::Int64)? {
+ ScalarValue::Int64(opt) => {
+ ScalarValue::Int64(opt.map(|x| ceil_integer(x, scale)))
+ }
+ other => {
+ return exec_err!(
+ "Internal error: integer cast_to(Int64) yielded {:?}",
+ other.data_type()
+ );
+ }
+ }
+ }
v if v.data_type().is_integer() => v.cast_to(&DataType::Int64)?,
+ // Decimal128 with positive scale (1-arg only).
+ ScalarValue::Decimal128(_, _, _) if has_scale => {
+ return not_impl_err!(
+ "2-argument ceil is not yet supported for decimal inputs"
Review Comment:
If we aren't supporting this yet, we should keep the original issue open
rather than marking this PR as closing it.
##########
datafusion/spark/src/function/math/ceil.rs:
##########
@@ -168,137 +336,18 @@ fn spark_ceil_array(input: &Arc<dyn
arrow::array::Array>) -> Result<ColumnarValu
Ok(ColumnarValue::Array(result))
}
-#[cfg(test)]
-mod tests {
- use super::*;
- use arrow::array::{Decimal128Array, Float32Array, Float64Array,
Int64Array};
- use datafusion_common::ScalarValue;
-
- #[test]
- fn test_ceil_float64() {
- let input = Float64Array::from(vec![
- Some(125.2345),
- Some(15.0001),
- Some(0.1),
- Some(-0.9),
- Some(-1.1),
- Some(123.0),
- None,
- ]);
- let args = vec![ColumnarValue::Array(Arc::new(input))];
- let result = spark_ceil(&args).unwrap();
- let result = match result {
- ColumnarValue::Array(arr) => arr,
- _ => panic!("Expected array"),
- };
- let result = result.as_primitive::<Int64Type>();
- assert_eq!(
- result,
- &Int64Array::from(vec![
- Some(126),
- Some(16),
- Some(1),
- Some(0),
- Some(-1),
- Some(123),
- None,
- ])
- );
- }
-
- #[test]
- fn test_ceil_float32() {
- let input = Float32Array::from(vec![
- Some(125.2345f32),
- Some(15.0001f32),
- Some(0.1f32),
- Some(-0.9f32),
- Some(-1.1f32),
- Some(123.0f32),
- None,
- ]);
- let args = vec![ColumnarValue::Array(Arc::new(input))];
- let result = spark_ceil(&args).unwrap();
- let result = match result {
- ColumnarValue::Array(arr) => arr,
- _ => panic!("Expected array"),
- };
- let result = result.as_primitive::<Int64Type>();
- assert_eq!(
- result,
- &Int64Array::from(vec![
- Some(126),
- Some(16),
- Some(1),
- Some(0),
- Some(-1),
- Some(123),
- None,
- ])
- );
- }
-
- #[test]
- fn test_ceil_int64() {
- let input = Int64Array::from(vec![Some(1), Some(-1), None]);
- let args = vec![ColumnarValue::Array(Arc::new(input))];
- let result = spark_ceil(&args).unwrap();
- let result = match result {
- ColumnarValue::Array(arr) => arr,
- _ => panic!("Expected array"),
- };
- let result = result.as_primitive::<Int64Type>();
- assert_eq!(result, &Int64Array::from(vec![Some(1), Some(-1), None]));
- }
-
- #[test]
- fn test_ceil_decimal128() {
- // Decimal128(10, 2): 150 = 1.50, -150 = -1.50, 100 = 1.00
- let return_type = DataType::Decimal128(9, 0);
- let input = Decimal128Array::from(vec![Some(150), Some(-150),
Some(100), None])
- .with_data_type(DataType::Decimal128(10, 2));
- let args = vec![ColumnarValue::Array(Arc::new(input))];
- let result = spark_ceil(&args).unwrap();
- let result = match result {
- ColumnarValue::Array(arr) => arr,
- _ => panic!("Expected array"),
- };
- let result = result.as_primitive::<Decimal128Type>();
- let expected = Decimal128Array::from(vec![Some(2), Some(-1), Some(1),
None])
- .with_data_type(return_type);
- assert_eq!(result, &expected);
- }
-
- #[test]
- fn test_ceil_float64_scalar() {
- let input = ScalarValue::Float64(Some(-1.1));
- let args = vec![ColumnarValue::Scalar(input)];
- let result = match spark_ceil(&args).unwrap() {
- ColumnarValue::Scalar(v) => v,
- _ => panic!("Expected scalar"),
- };
- assert_eq!(result, ScalarValue::Int64(Some(-1)));
- }
-
- #[test]
- fn test_ceil_float32_scalar() {
- let input = ScalarValue::Float32(Some(125.2345f32));
- let args = vec![ColumnarValue::Scalar(input)];
- let result = match spark_ceil(&args).unwrap() {
- ColumnarValue::Scalar(v) => v,
- _ => panic!("Expected scalar"),
- };
- assert_eq!(result, ScalarValue::Int64(Some(126)));
- }
-
- #[test]
- fn test_ceil_int64_scalar() {
- let input = ScalarValue::Int64(Some(48));
- let args = vec![ColumnarValue::Scalar(input)];
- let result = match spark_ceil(&args).unwrap() {
- ColumnarValue::Scalar(v) => v,
- _ => panic!("Expected scalar"),
- };
- assert_eq!(result, ScalarValue::Int64(Some(48)));
+fn ceil_float<T: num_traits::Float>(value: T, scale: i32) -> T {
+ if scale >= 0 {
+ let factor = T::from(10.0f64.powi(scale)).unwrap_or_else(T::infinity);
+ if factor.is_infinite() {
+ return value;
Review Comment:
Does this match Spark's behaviour?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]