This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 108bbaedb9a3 [SPARK-52470][ML][PYTHON][FOLLOW-UP] Further fix GRPC
import
108bbaedb9a3 is described below
commit 108bbaedb9a329bdfa48c49b5b3b5007f59a7d44
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Jun 24 15:51:06 2025 +0800
[SPARK-52470][ML][PYTHON][FOLLOW-UP] Further fix GRPC import
### What changes were proposed in this pull request?
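Further fix the GRPC import: in `try_remote_call`, move the `pyspark.sql.connect` imports (and the session lookup) inside the `is_remote()` branch, so they are only executed in Spark Connect mode.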
### Why are the changes needed?
To fix the failing job at
https://github.com/apache/spark/actions/runs/15801182350/job/44539663472
PySpark Classic should not depend on PySpark Connect.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
PR builder with
```
default: '{"PYSPARK_IMAGE_TO_TEST": "python-311-classic-only",
"PYTHON_TO_TEST": "python3.11"}'
```
see
https://github.com/zhengruifeng/spark/actions/runs/15840250226/job/44652627469
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #51258 from zhengruifeng/fix_grpc_import.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/ml/util.py | 73 +++++++++++++++++++++++------------------------
1 file changed, 36 insertions(+), 37 deletions(-)
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index f9a532de10f9..5f23826c73cc 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -278,47 +278,46 @@ def try_remote_call(f: FuncT) -> FuncT:
@functools.wraps(f)
def wrapped(self: "JavaWrapper", name: str, *args: Any) -> Any:
- import pyspark.sql.connect.proto as pb2
- from pyspark.sql.connect.session import SparkSession
+ if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+ from pyspark.errors.exceptions.connect import SparkException
+ import pyspark.sql.connect.proto as pb2
+ from pyspark.sql.connect.session import SparkSession
- session = SparkSession.getActiveSession()
+ session = SparkSession.getActiveSession()
- def remote_call() -> Any:
- from pyspark.ml.connect.util import _extract_id_methods
- from pyspark.ml.connect.serialize import serialize, deserialize
- from pyspark.ml.wrapper import JavaModel
+ def remote_call() -> Any:
+ from pyspark.ml.connect.util import _extract_id_methods
+ from pyspark.ml.connect.serialize import serialize, deserialize
+ from pyspark.ml.wrapper import JavaModel
- assert session is not None
- if self._java_obj == ML_CONNECT_HELPER_ID:
- obj_id = ML_CONNECT_HELPER_ID
- else:
- if isinstance(self, JavaModel):
- assert isinstance(self._java_obj, RemoteModelRef)
- obj_id = self._java_obj.ref_id
+ assert session is not None
+ if self._java_obj == ML_CONNECT_HELPER_ID:
+ obj_id = ML_CONNECT_HELPER_ID
else:
- # model summary
- obj_id = self._java_obj # type: ignore
- methods, obj_ref = _extract_id_methods(obj_id)
- methods.append(pb2.Fetch.Method(method=name,
args=serialize(session.client, *args)))
- command = pb2.Command()
- command.ml_command.fetch.CopyFrom(
- pb2.Fetch(obj_ref=pb2.ObjectRef(id=obj_ref), methods=methods)
- )
- (_, properties, _) = session.client.execute_command(command)
- ml_command_result = properties["ml_command_result"]
- if ml_command_result.HasField("summary"):
- summary = ml_command_result.summary
- return summary
- elif ml_command_result.HasField("operator_info"):
- model_info = deserialize(properties)
- # get a new model ref id from the existing model,
- # it is up to the caller to build the model
- return model_info.obj_ref.id
- else:
- return deserialize(properties)
-
- if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
- from pyspark.errors.exceptions.connect import SparkException
+ if isinstance(self, JavaModel):
+ assert isinstance(self._java_obj, RemoteModelRef)
+ obj_id = self._java_obj.ref_id
+ else:
+ # model summary
+ obj_id = self._java_obj # type: ignore
+ methods, obj_ref = _extract_id_methods(obj_id)
+ methods.append(pb2.Fetch.Method(method=name,
args=serialize(session.client, *args)))
+ command = pb2.Command()
+ command.ml_command.fetch.CopyFrom(
+ pb2.Fetch(obj_ref=pb2.ObjectRef(id=obj_ref),
methods=methods)
+ )
+ (_, properties, _) = session.client.execute_command(command)
+ ml_command_result = properties["ml_command_result"]
+ if ml_command_result.HasField("summary"):
+ summary = ml_command_result.summary
+ return summary
+ elif ml_command_result.HasField("operator_info"):
+ model_info = deserialize(properties)
+ # get a new model ref id from the existing model,
+ # it is up to the caller to build the model
+ return model_info.obj_ref.id
+ else:
+ return deserialize(properties)
try:
return remote_call()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]