Imbruced commented on code in PR #2593:
URL: https://github.com/apache/sedona/pull/2593#discussion_r2812235471


##########
python/sedona/spark/worker/worker.py:
##########
@@ -0,0 +1,304 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import importlib
+import os
+import sys
+import time
+
+import sedonadb
+from pyspark import TaskContext, shuffle, SparkFiles
+from pyspark.errors import PySparkRuntimeError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.resource import ResourceInformation
+from pyspark.serializers import (
+    read_int,
+    UTF8Deserializer,
+    read_bool,
+    read_long,
+    CPickleSerializer,
+    write_int,
+    write_long,
+    SpecialLengths,
+)
+
+from sedona.spark.worker.serde import SedonaDBSerializer
+from sedona.spark.worker.udf_info import UDFInfo
+
+
+def apply_iterator(db, iterator, udf_info: UDFInfo, cast_to_wkb: bool = False):
+    i = 0
+    for df in iterator:
+        i += 1
+        table_name = f"output_table_{i}"
+        df.to_view(table_name)
+
+        function_call_sql = udf_info.get_function_call_sql(
+            table_name, cast_to_wkb=cast_to_wkb
+        )
+
+        df_out = db.sql(function_call_sql)
+
+        df_out.to_view(f"view_{i}")
+        at = df_out.to_arrow_table()
+        batches = at.combine_chunks().to_batches()
+
+        yield from batches
+
+
+def check_python_version(utf_serde: UTF8Deserializer, infile) -> str:
+    version = utf_serde.loads(infile)
+
+    python_major, python_minor = sys.version_info[:2]
+
+    if version != f"{python_major}.{python_minor}":
+        raise PySparkRuntimeError(
+            error_class="PYTHON_VERSION_MISMATCH",
+            message_parameters={
+                "worker_version": str(sys.version_info[:2]),
+                "driver_version": str(version),
+            },
+        )
+
+    return version
+
+
+def check_barrier_flag(infile):
+    is_barrier = read_bool(infile)
+    bound_port = read_int(infile)
+    secret = UTF8Deserializer().loads(infile)
+
+    if is_barrier:
+        raise PySparkRuntimeError(
+            error_class="BARRIER_MODE_NOT_SUPPORTED",
+            message_parameters={
+                "worker_version": str(sys.version_info[:2]),
+                "message": "Barrier mode is not supported by SedonaDB vectorized functions.",
+            },
+        )
+
+    return is_barrier
+
+
+def assign_task_context(utf_serde: UTF8Deserializer, infile):
+    stage_id = read_int(infile)
+    partition_id = read_int(infile)
+    attempt_number = read_long(infile)
+    task_attempt_id = read_int(infile)
+    cpus = read_int(infile)
+
+    task_context = TaskContext._getOrCreate()
+    task_context._stage_id = stage_id
+    task_context._partition_id = partition_id
+    task_context._attempt_number = attempt_number
+    task_context._task_attempt_id = task_attempt_id
+    task_context._cpus = cpus
+
+    for r in range(read_int(infile)):
+        key = utf_serde.loads(infile)
+        name = utf_serde.loads(infile)
+        addresses = []
+        task_context._resources = {}
+        for a in range(read_int(infile)):
+            addresses.append(utf_serde.loads(infile))
+        task_context._resources[key] = ResourceInformation(name, addresses)
+
+    task_context._localProperties = {}
+    for i in range(read_int(infile)):
+        k = utf_serde.loads(infile)
+        v = utf_serde.loads(infile)
+        task_context._localProperties[k] = v
+
+    return task_context
+
+
+def resolve_python_path(utf_serde: UTF8Deserializer, infile):
+    def add_path(path: str):
+        # worker can be used, so do not add path multiple times
+        if path not in sys.path:
+            # overwrite system packages
+            sys.path.insert(1, path)
+
+    spark_files_dir = utf_serde.loads(infile)
+
+    SparkFiles._root_directory = spark_files_dir
+    SparkFiles._is_running_on_worker = True
+
+    add_path(spark_files_dir)  # *.py files that were added will be copied here
+    num_python_includes = read_int(infile)
+    for _ in range(num_python_includes):
+        filename = utf_serde.loads(infile)
+        add_path(os.path.join(spark_files_dir, filename))
+
+    importlib.invalidate_caches()
+
+
+def check_broadcast_variables(infile):
+    needs_broadcast_decryption_server = read_bool(infile)
+    num_broadcast_variables = read_int(infile)
+
+    if needs_broadcast_decryption_server or num_broadcast_variables > 0:
+        raise PySparkRuntimeError(
+            error_class="BROADCAST_VARS_NOT_SUPPORTED",
+            message_parameters={
+                "worker_version": str(sys.version_info[:2]),
+                "message": "Broadcast variables are not supported by SedonaDB vectorized functions.",
+            },
+        )
+
+
+def get_runner_conf(utf_serde: UTF8Deserializer, infile):
+    runner_conf = {}
+    num_conf = read_int(infile)
+    for i in range(num_conf):
+        k = utf_serde.loads(infile)
+        v = utf_serde.loads(infile)
+        runner_conf[k] = v
+    return runner_conf
+
+
+def read_command(serializer, infile):
+    command = serializer._read_with_length(infile)
+    return command
+
+
+def read_udf(infile, pickle_ser) -> UDFInfo:
+    num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for i in range(num_arg)]
+
+    function = None
+    return_type = None
+
+    for i in range(read_int(infile)):
+        function, return_type = read_command(pickle_ser, infile)
+
+    sedona_db_udf_expression = function()
+
+    return UDFInfo(
+        arg_offsets=arg_offsets,
+        function=sedona_db_udf_expression,
+        return_type=return_type,
+        name=sedona_db_udf_expression._name,
+        geom_offsets=[0],
+    )
+
+
+def register_sedona_db_udf(infile, pickle_ser) -> UDFInfo:
+    num_udfs = read_int(infile)
+
+    udf = None
+    for _ in range(num_udfs):
+        udf = read_udf(infile, pickle_ser)
+
+    return udf

Review Comment:
   Yes, so far it supports only one level of function nesting. I would like to
extend this functionality in follow-up PRs to avoid overwhelming reviewers.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to