ntjohnson1 commented on code in PR #1544:
URL: https://github.com/apache/datafusion-python/pull/1544#discussion_r3250913441


##########
python/tests/test_pickle_expr.py:
##########
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""In-process pickle round-trip tests for :class:`Expr`.
+
+Built-in functions and Python scalar UDFs travel with the pickled
+expression and do not need worker-side pre-registration. The worker
+context (:mod:`datafusion.ipc`) is only consulted for UDFs imported
+via the FFI capsule protocol.
+"""
+
+from __future__ import annotations
+
+import pickle
+
+import pyarrow as pa
+import pytest
+from datafusion import Expr, SessionContext, col, lit, udf
+from datafusion.ipc import (
+    clear_worker_ctx,
+    set_worker_ctx,
+)
+
+
+@pytest.fixture(autouse=True)
+def _reset_worker_ctx():
+    """Ensure every test starts with no worker context installed."""
+    clear_worker_ctx()
+    yield
+    clear_worker_ctx()
+
+
+def _double_udf():
+    return udf(
+        lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]),
+        [pa.int64()],
+        pa.int64(),
+        volatility="immutable",
+        name="double",
+    )
+
+
+class TestProtoRoundTrip:
+    def test_builtin_round_trip(self):
+        e = col("a") + lit(1)
+        blob = pickle.dumps(e)
+        decoded = pickle.loads(blob)  # noqa: S301
+        assert decoded.canonical_name() == e.canonical_name()
+
+    def test_to_bytes_from_bytes(self):
+        e = col("x") * lit(7)
+        blob = e.to_bytes()
+        assert isinstance(blob, bytes)
+        decoded = Expr.from_bytes(blob)
+        assert decoded.canonical_name() == e.canonical_name()
+
+    def test_explicit_ctx_used(self, ctx):
+        e = col("a") + lit(1)
+        decoded = Expr.from_bytes(e.to_bytes(), ctx=ctx)
+        assert decoded.canonical_name() == e.canonical_name()
+
+
+class TestUDFCodec:
+    """Python scalar UDFs ride inside the proto blob via the Rust codec.
+
+    No worker context needed on the receiver — the cloudpickled callable is
+    embedded in ``fun_definition`` and reconstructed automatically.
+    """
+
+    def test_udf_self_contained_blob(self):
+        e = _double_udf()(col("a"))
+        blob = pickle.dumps(e)
+        # The codec inlines the callable, so the blob is much bigger than a

Review Comment:
   I think this is testing the thing I was asking about, but I haven't thought deeply enough about whether it actually does. I know cloudpickle says it can serialize lambdas, but what if I instead had
   
   ```python
   import pyarrow as pa
   from datafusion import udf
   from foo import double


   def _double_udf():
       return udf(
           double,
           [pa.int64()],
           pa.int64(),
           volatility="immutable",
           name="double",
       )
   ```
   
   Would I still be able to deserialize this on a remote worker whose Python environment does not have `foo` installed?
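The question hinges on whether the callable is pickled by reference or by value. A minimal sketch of the by-reference case using the standard library's `pickle`, assuming a hypothetical importable module `foo_demo` that stands in for `foo`: a module-level function is stored as just its module name and qualified name, so loading fails wherever that module cannot be imported. cloudpickle handles functions it can import by name the same way unless their module has been registered with `cloudpickle.register_pickle_by_value` (cloudpickle 2.0 and later); lambdas and functions defined in `__main__` are pickled by value instead.

```python
import importlib
import pickle
import shutil
import sys
import tempfile
from pathlib import Path

# Hypothetical stand-in for `foo`: write a tiny module into a temp dir.
src_dir = Path(tempfile.mkdtemp())
(src_dir / "foo_demo.py").write_text("def double(x):\n    return x * 2\n")
sys.path.insert(0, str(src_dir))
importlib.invalidate_caches()
foo_demo = importlib.import_module("foo_demo")

# An importable function is pickled by reference: the payload holds
# only the strings "foo_demo" and "double", not the function's code.
blob = pickle.dumps(foo_demo.double)

# Simulate a remote worker where `foo_demo` is not installed.
sys.path.remove(str(src_dir))
del sys.modules["foo_demo"]
shutil.rmtree(src_dir)

try:
    pickle.loads(blob)  # noqa: S301
    loaded_without_module = True
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    loaded_without_module = False

print("loaded without module:", loaded_without_module)
```

With cloudpickle, calling `cloudpickle.register_pickle_by_value(foo)` on the sender before pickling would embed the function's code instead, at the cost of also tying the payload to the sender's Python version.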



##########
crates/core/src/codec.rs:
##########
@@ -284,3 +365,282 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
         self.inner.try_decode_udwf(name, buf)
     }
 }
+
+// =============================================================================
+// Shared Python scalar UDF encode / decode helpers
+//
+// Both `PythonLogicalCodec` and `PythonPhysicalCodec` consult these on
+// every `try_encode_udf` / `try_decode_udf` call. Same wire format on
+// both layers — a Python `ScalarUDF` referenced inside a `LogicalPlan`
+// or an `ExecutionPlan` round-trips identically.
+// =============================================================================
+
+/// Encode a Python scalar UDF inline if `node` is one. Returns
+/// `Ok(true)` when the payload (`DFPYUDF` family prefix, version byte,
+/// cloudpickled tuple) was written and the caller should skip its
+/// inner codec. Returns `Ok(false)` for any non-Python UDF, signalling
+/// the caller to delegate to its `inner`.
+pub(crate) fn try_encode_python_scalar_udf(node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<bool> {
+    let Some(py_udf) = node
+        .inner()
+        .as_any()
+        .downcast_ref::<PythonFunctionScalarUDF>()
+    else {
+        return Ok(false);
+    };
+
+    Python::attach(|py| -> Result<bool> {
+        let bytes = encode_python_scalar_udf(py, py_udf)
+            .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
+        write_wire_header(buf, PY_SCALAR_UDF_FAMILY);
+        buf.extend_from_slice(&bytes);
+        Ok(true)
+    })
+}
+
+/// Decode an inline Python scalar UDF payload. Returns `Ok(None)`
+/// when `buf` does not carry the `DFPYUDF` family prefix, signalling
+/// the caller to delegate to its `inner` codec (and eventually the
+/// `FunctionRegistry`).
+pub(crate) fn try_decode_python_scalar_udf(buf: &[u8]) -> Result<Option<Arc<ScalarUDF>>> {
+    let Some(payload) = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF")? else {
+        return Ok(None);
+    };
+
+    Python::attach(|py| -> Result<Option<Arc<ScalarUDF>>> {
+        let udf = decode_python_scalar_udf(py, payload)
+            .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
+        Ok(Some(Arc::new(ScalarUDF::new_from_impl(udf))))
+    })
+}
+
+/// Build the cloudpickle payload for a `PythonFunctionScalarUDF`.

Review Comment:
   Maybe this is captured more clearly somewhere else, but it feels like some nuance of the dependency on cloudpickle isn't fully communicated here. I didn't do much of a deep dive on it.
   
   1. cloudpickle only works between identical Python versions (I'm not sure whether it detects a mismatch with a nice error). So your wire header might want to record the source Python version, both to give a clearer error on mismatch and to advertise the limitation that remote workers must run the same Python version as the sender.
   
   2. cloudpickle seems to support serialization by reference (more like dill) and by value (super cool). The former needs the function's module installed in the receiving environment so that deserialization can import it, whereas the latter tries to capture everything the function needs in the payload itself (this is where I didn't deep dive much). Those are fairly different mental models to support.
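A minimal sketch of point 1, assuming a header laid out as the `DFPYUDF` family prefix, a format-version byte, then the producer's Python major and minor version. Only the `DFPYUDF` prefix comes from the PR; `write_header`, `strip_header`, `FORMAT_VERSION`, and `PythonVersionMismatch` are hypothetical names, and the real codec is written in Rust. Returning `None` for bytes without the prefix mirrors the "delegate to `inner`" contract in the doc comments above.

```python
from __future__ import annotations

import sys

# Family prefix from the PR's doc comment; everything else is hypothetical.
FAMILY = b"DFPYUDF"
FORMAT_VERSION = 1  # hypothetical wire-format version byte


class PythonVersionMismatch(RuntimeError):
    """Raised when a payload was produced by a different Python version."""


def write_header(body: bytes) -> bytes:
    """Prefix a cloudpickled body with family, format and Python version."""
    major, minor = sys.version_info[:2]
    # Layout: FAMILY | format version | Python major | Python minor | body
    return FAMILY + bytes([FORMAT_VERSION, major, minor]) + body


def strip_header(buf: bytes) -> bytes | None:
    """Return the body, or None if `buf` was not written by this codec."""
    if not buf.startswith(FAMILY):
        return None  # not ours: the caller delegates to its inner codec
    rest = buf[len(FAMILY):]
    if len(rest) < 3:
        raise ValueError("truncated Python UDF header")
    fmt, major, minor = rest[0], rest[1], rest[2]
    if fmt != FORMAT_VERSION:
        raise ValueError(f"unsupported Python UDF wire format {fmt}")
    local = sys.version_info[:2]
    if (major, minor) != local:
        # Fail before unpickling, with an error that names both versions.
        raise PythonVersionMismatch(
            f"UDF was pickled on Python {major}.{minor}, "
            f"but this worker runs Python {local[0]}.{local[1]}"
        )
    return rest[3:]
```

Recording only major and minor is an assumption of this sketch; cloudpickle's own documentation promises compatibility only between the exact same Python version, so a stricter check could also include `sys.version_info.micro`.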



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
