This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 976064dc6c Add Snowpark operator and decorator (#42457)
976064dc6c is described below

commit 976064dc6ce95d3b5cead1a7d2fcad4971c61b9a
Author: Jianzhun Du <[email protected]>
AuthorDate: Wed Oct 2 18:55:57 2024 -0700

    Add Snowpark operator and decorator (#42457)
---
 airflow/providers/snowflake/decorators/__init__.py |  16 ++
 airflow/providers/snowflake/decorators/snowpark.py | 124 ++++++++++++
 airflow/providers/snowflake/hooks/snowflake.py     |  22 ++
 airflow/providers/snowflake/operators/snowpark.py  | 133 +++++++++++++
 airflow/providers/snowflake/provider.yaml          |   7 +
 airflow/providers/snowflake/utils/snowpark.py      |  44 ++++
 .../decorators/index.rst                           |  25 +++
 .../decorators/snowpark.rst                        |  70 +++++++
 docs/apache-airflow-providers-snowflake/index.rst  |   1 +
 .../operators/snowpark.rst                         |  74 +++++++
 docs/spelling_wordlist.txt                         |   3 +
 generated/provider_dependencies.json               |   1 +
 tests/providers/snowflake/decorators/__init__.py   |  16 ++
 .../snowflake/decorators/test_snowpark.py          | 221 +++++++++++++++++++++
 tests/providers/snowflake/hooks/test_snowflake.py  |  27 +++
 .../providers/snowflake/operators/test_snowpark.py | 181 +++++++++++++++++
 tests/providers/snowflake/utils/test_snowpark.py   |  36 ++++
 .../snowflake/example_snowpark_decorator.py        |  85 ++++++++
 .../snowflake/example_snowpark_operator.py         |  94 +++++++++
 19 files changed, 1180 insertions(+)

diff --git a/airflow/providers/snowflake/decorators/__init__.py 
b/airflow/providers/snowflake/decorators/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/snowflake/decorators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/snowflake/decorators/snowpark.py 
b/airflow/providers/snowflake/decorators/snowpark.py
new file mode 100644
index 0000000000..406d817e9d
--- /dev/null
+++ b/airflow/providers/snowflake/decorators/snowpark.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, Sequence
+
+from airflow.decorators.base import DecoratedOperator, task_decorator_factory
+from airflow.providers.snowflake.operators.snowpark import SnowparkOperator
+from airflow.providers.snowflake.utils.snowpark import 
inject_session_into_op_kwargs
+
+if TYPE_CHECKING:
+    from airflow.decorators.base import TaskDecorator
+
+
+class _SnowparkDecoratedOperator(DecoratedOperator, SnowparkOperator):
+    """
+    Wraps a Python callable that contains Snowpark code and captures 
args/kwargs when called for execution.
+
+    :param snowflake_conn_id: Reference to
+        :ref:`Snowflake connection id<howto/connection:snowflake>`
+    :param python_callable: A reference to an object that is callable
+    :param op_args: a list of positional arguments that will get unpacked when
+        calling your callable
+    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
+        in your function
+    :param warehouse: name of warehouse (will overwrite any warehouse
+        defined in the connection's extra JSON)
+    :param database: name of database (will overwrite database defined
+        in connection)
+    :param schema: name of schema (will overwrite schema defined in
+        connection)
+    :param role: name of role (will overwrite any role defined in
+        connection's extra JSON)
+    :param authenticator: authenticator for Snowflake.
+        'snowflake' (default) to use the internal Snowflake authenticator
+        'externalbrowser' to authenticate using your web browser and
+        Okta, ADFS or any other SAML 2.0-compliant identity provider
+        (IdP) that has been defined for your account
+        'https://<your_okta_account_name>.okta.com' to authenticate
+        through native Okta.
+    :param session_parameters: You can set session-level parameters at
+        the time you connect to Snowflake
+    :param multiple_outputs: If set to True, the decorated function's return 
value will be unrolled to
+        multiple XCom values. Dict will unroll to XCom values with its keys as 
XCom keys. Defaults to False.
+    """
+
+    custom_operator_name = "@task.snowpark"
+
+    def __init__(
+        self,
+        *,
+        snowflake_conn_id: str = "snowflake_default",
+        python_callable: Callable,
+        op_args: Sequence | None = None,
+        op_kwargs: dict | None = None,
+        warehouse: str | None = None,
+        database: str | None = None,
+        role: str | None = None,
+        schema: str | None = None,
+        authenticator: str | None = None,
+        session_parameters: dict | None = None,
+        **kwargs,
+    ) -> None:
+        kwargs_to_upstream = {
+            "python_callable": python_callable,
+            "op_args": op_args,
+            "op_kwargs": op_kwargs,
+        }
+        super().__init__(
+            kwargs_to_upstream=kwargs_to_upstream,
+            snowflake_conn_id=snowflake_conn_id,
+            python_callable=python_callable,
+            op_args=op_args,
+            # airflow.decorators.base.DecoratedOperator checks if the 
functions are bindable, so we have to
+            # add an artificial value to pass the validation if there is a 
keyword argument named `session`
+            # in the signature of the python callable. The real value is 
determined at runtime.
+            op_kwargs=inject_session_into_op_kwargs(python_callable, 
op_kwargs, None)
+            if op_kwargs is not None
+            else op_kwargs,
+            warehouse=warehouse,
+            database=database,
+            role=role,
+            schema=schema,
+            authenticator=authenticator,
+            session_parameters=session_parameters,
+            **kwargs,
+        )
+
+
+def snowpark_task(
+    python_callable: Callable | None = None,
+    multiple_outputs: bool | None = None,
+    **kwargs,
+) -> TaskDecorator:
+    """
+    Wrap a function that contains Snowpark code into an Airflow operator.
+
+    Accepts kwargs for operator kwarg. Can be reused in a single DAG.
+
+    :param python_callable: Function to decorate
+    :param multiple_outputs: If set to True, the decorated function's return 
value will be unrolled to
+        multiple XCom values. Dict will unroll to XCom values with its keys as 
XCom keys. Defaults to False.
+    """
+    return task_decorator_factory(
+        python_callable=python_callable,
+        multiple_outputs=multiple_outputs,
+        decorated_operator_class=_SnowparkDecoratedOperator,
+        **kwargs,
+    )
diff --git a/airflow/providers/snowflake/hooks/snowflake.py 
b/airflow/providers/snowflake/hooks/snowflake.py
index 4b4143fdd1..0f81f2e384 100644
--- a/airflow/providers/snowflake/hooks/snowflake.py
+++ b/airflow/providers/snowflake/hooks/snowflake.py
@@ -321,6 +321,28 @@ class SnowflakeHook(DbApiHook):
                 engine_kwargs["connect_args"][key] = conn_params[key]
         return create_engine(self._conn_params_to_sqlalchemy_uri(conn_params), 
**engine_kwargs)
 
+    def get_snowpark_session(self):
+        """
+        Get a Snowpark session object.
+
+        :return: the created session.
+        """
+        from snowflake.snowpark import Session
+
+        from airflow import __version__ as airflow_version
+        from airflow.providers.snowflake import __version__ as provider_version
+
+        conn_config = self._get_conn_params
+        session = Session.builder.configs(conn_config).create()
+        # add query tag for observability
+        session.update_query_tag(
+            {
+                "airflow_version": airflow_version,
+                "airflow_provider_version": provider_version,
+            }
+        )
+        return session
+
     def set_autocommit(self, conn, autocommit: Any) -> None:
         conn.autocommit(autocommit)
         conn.autocommit_mode = autocommit
diff --git a/airflow/providers/snowflake/operators/snowpark.py 
b/airflow/providers/snowflake/operators/snowpark.py
new file mode 100644
index 0000000000..1635eebaa3
--- /dev/null
+++ b/airflow/providers/snowflake/operators/snowpark.py
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Any, Callable, Collection, Mapping, Sequence
+
+from airflow.operators.python import PythonOperator, get_current_context
+from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
+from airflow.providers.snowflake.utils.snowpark import 
inject_session_into_op_kwargs
+
+
+class SnowparkOperator(PythonOperator):
+    """
+    Executes a Python function with Snowpark Python code.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:SnowparkOperator`
+
+    :param snowflake_conn_id: Reference to
+        :ref:`Snowflake connection id<howto/connection:snowflake>`
+    :param python_callable: A reference to an object that is callable
+    :param op_args: a list of positional arguments that will get unpacked when
+        calling your callable
+    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
+        in your function
+    :param templates_dict: a dictionary where the values are templates that
+        will get templated by the Airflow engine sometime between
+        ``__init__`` and ``execute`` takes place and are made available
+        in your callable's context after the template has been applied. 
(templated)
+    :param templates_exts: a list of file extensions to resolve while
+        processing templated fields, for example ``['.sql', '.hql']``
+    :param show_return_value_in_logs: a bool value whether to show return_value
+        logs. Defaults to True, which allows return value log output.
+        It can be set to False to prevent log output of return value when you 
return huge data
+        such as when transmitting a large amount of XCom data to TaskAPI.
+    :param warehouse: name of warehouse (will overwrite any warehouse
+        defined in the connection's extra JSON)
+    :param database: name of database (will overwrite database defined
+        in connection)
+    :param schema: name of schema (will overwrite schema defined in
+        connection)
+    :param role: name of role (will overwrite any role defined in
+        connection's extra JSON)
+    :param authenticator: authenticator for Snowflake.
+        'snowflake' (default) to use the internal Snowflake authenticator
+        'externalbrowser' to authenticate using your web browser and
+        Okta, ADFS or any other SAML 2.0-compliant identity provider
+        (IdP) that has been defined for your account
+        'https://<your_okta_account_name>.okta.com' to authenticate
+        through native Okta.
+    :param session_parameters: You can set session-level parameters at
+        the time you connect to Snowflake
+    """
+
+    def __init__(
+        self,
+        *,
+        snowflake_conn_id: str = "snowflake_default",
+        python_callable: Callable,
+        op_args: Collection[Any] | None = None,
+        op_kwargs: Mapping[str, Any] | None = None,
+        templates_dict: dict[str, Any] | None = None,
+        templates_exts: Sequence[str] | None = None,
+        show_return_value_in_logs: bool = True,
+        warehouse: str | None = None,
+        database: str | None = None,
+        schema: str | None = None,
+        role: str | None = None,
+        authenticator: str | None = None,
+        session_parameters: dict | None = None,
+        **kwargs,
+    ):
+        super().__init__(
+            python_callable=python_callable,
+            op_args=op_args,
+            op_kwargs=op_kwargs,
+            templates_dict=templates_dict,
+            templates_exts=templates_exts,
+            show_return_value_in_logs=show_return_value_in_logs,
+            **kwargs,
+        )
+        self.snowflake_conn_id = snowflake_conn_id
+        self.warehouse = warehouse
+        self.database = database
+        self.schema = schema
+        self.role = role
+        self.authenticator = authenticator
+        self.session_parameters = session_parameters
+
+    def execute_callable(self):
+        hook = SnowflakeHook(
+            snowflake_conn_id=self.snowflake_conn_id,
+            warehouse=self.warehouse,
+            database=self.database,
+            role=self.role,
+            schema=self.schema,
+            authenticator=self.authenticator,
+            session_parameters=self.session_parameters,
+        )
+        session = hook.get_snowpark_session()
+        context = get_current_context()
+        session.update_query_tag(
+            {
+                "dag_id": context["dag_run"].dag_id,
+                "dag_run_id": context["dag_run"].run_id,
+                "task_id": context["task_instance"].task_id,
+                "operator": self.__class__.__name__,
+            }
+        )
+        try:
+            # inject session object if the function has "session" keyword as 
an argument
+            self.op_kwargs = inject_session_into_op_kwargs(
+                self.python_callable, dict(self.op_kwargs), session
+            )
+            return super().execute_callable()
+        finally:
+            session.close()
diff --git a/airflow/providers/snowflake/provider.yaml 
b/airflow/providers/snowflake/provider.yaml
index 067f673d70..47de902ff6 100644
--- a/airflow/providers/snowflake/provider.yaml
+++ b/airflow/providers/snowflake/provider.yaml
@@ -90,12 +90,14 @@ dependencies:
   - pyarrow>=14.0.1
   - snowflake-connector-python>=3.7.1
   - snowflake-sqlalchemy>=1.4.0
+  - snowflake-snowpark-python>=1.17.0;python_version<"3.12"
 
 integrations:
   - integration-name: Snowflake
     external-doc-url: https://snowflake.com/
     how-to-guide:
       - /docs/apache-airflow-providers-snowflake/operators/snowflake.rst
+      - /docs/apache-airflow-providers-snowflake/operators/snowpark.rst
     logo: /integration-logos/snowflake/Snowflake.png
     tags: [service]
 
@@ -103,6 +105,11 @@ operators:
   - integration-name: Snowflake
     python-modules:
       - airflow.providers.snowflake.operators.snowflake
+      - airflow.providers.snowflake.operators.snowpark
+
+task-decorators:
+  - class-name: airflow.providers.snowflake.decorators.snowpark.snowpark_task
+    name: snowpark
 
 hooks:
   - integration-name: Snowflake
diff --git a/airflow/providers/snowflake/utils/snowpark.py 
b/airflow/providers/snowflake/utils/snowpark.py
new file mode 100644
index 0000000000..a6617bb920
--- /dev/null
+++ b/airflow/providers/snowflake/utils/snowpark.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import inspect
+from typing import TYPE_CHECKING, Callable
+
+if TYPE_CHECKING:
+    from snowflake.snowpark import Session
+
+
+def inject_session_into_op_kwargs(
+    python_callable: Callable, op_kwargs: dict, session: Session | None
+) -> dict:
+    """
+    Inject Snowpark session into operator kwargs based on signature of python 
callable.
+
+    If there is a keyword argument named `session` in the signature of the 
python callable,
+    a Snowpark session object will be injected into kwargs.
+
+    :param python_callable: Python callable
+    :param op_kwargs: Operator kwargs
+    :param session: Snowpark session
+    """
+    signature = inspect.signature(python_callable)
+    if "session" in signature.parameters:
+        return {**op_kwargs, "session": session}
+    else:
+        return op_kwargs
diff --git a/docs/apache-airflow-providers-snowflake/decorators/index.rst 
b/docs/apache-airflow-providers-snowflake/decorators/index.rst
new file mode 100644
index 0000000000..7871e3553b
--- /dev/null
+++ b/docs/apache-airflow-providers-snowflake/decorators/index.rst
@@ -0,0 +1,25 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+Snowflake decorators
+====================
+
+.. toctree::
+    :maxdepth: 1
+    :glob:
+
+    *
diff --git a/docs/apache-airflow-providers-snowflake/decorators/snowpark.rst 
b/docs/apache-airflow-providers-snowflake/decorators/snowpark.rst
new file mode 100644
index 0000000000..09be01e3ef
--- /dev/null
+++ b/docs/apache-airflow-providers-snowflake/decorators/snowpark.rst
@@ -0,0 +1,70 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/decorators:snowpark:
+
+``@task.snowpark``
+==================
+
+Use the :func:`@task.snowpark 
<airflow.providers.snowflake.decorators.snowpark.snowpark_task>` to run
+`Snowpark Python 
<https://docs.snowflake.com/en/developer-guide/snowpark/python/index.html>`__ 
code in a `Snowflake <https://docs.snowflake.com/en/>`__ database.
+
+.. warning::
+
+    - Snowpark does not support Python 3.12 yet.
+    - Currently, this decorator does not support `Snowpark pandas API 
<https://docs.snowflake.com/en/developer-guide/snowpark/python/pandas-on-snowflake>`__
 because conflicting pandas version is used in Airflow.
+      Consider using Snowpark pandas API with other Snowpark decorators or 
operators.
+
+Prerequisite Tasks
+^^^^^^^^^^^^^^^^^^
+
+To use this decorator, you must do a few things:
+
+  * Install provider package via **pip**.
+
+    .. code-block:: bash
+
+      pip install 'apache-airflow-providers-snowflake'
+
+    Detailed information is available for :doc:`Installation 
<apache-airflow:installation/index>`.
+
+  * :doc:`Setup a Snowflake Connection </connections/snowflake>`.
+
+Using the Operator
+^^^^^^^^^^^^^^^^^^
+
+Use the ``snowflake_conn_id`` argument to specify connection used. If not 
specified, ``snowflake_default`` will be used.
+
+An example usage of the ``@task.snowpark`` is as follows:
+
+.. exampleinclude:: 
/../../tests/system/providers/snowflake/example_snowpark_decorator.py
+    :language: python
+    :start-after: [START howto_decorator_snowpark]
+    :end-before: [END howto_decorator_snowpark]
+
+
+As the example demonstrates, there are two ways to use the Snowpark session 
object in your Python function:
+
+  * Pass the Snowpark session object to the function as a keyword argument 
named ``session``. The Snowpark session will be automatically injected into the 
function, allowing you to use it as you normally would.
+
+  * Use `get_active_session 
<https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/1.3.0/api/snowflake.snowpark.context.get_active_session>`__
+    function from Snowpark to retrieve the Snowpark session object inside the 
function.
+
+.. note::
+
+  Parameters that can be passed onto the decorators will be given priority 
over the parameters already given
+  in the Airflow connection metadata (such as ``schema``, ``role``, 
``database`` and so forth).
diff --git a/docs/apache-airflow-providers-snowflake/index.rst 
b/docs/apache-airflow-providers-snowflake/index.rst
index 5b9a8a5133..b00ea39c52 100644
--- a/docs/apache-airflow-providers-snowflake/index.rst
+++ b/docs/apache-airflow-providers-snowflake/index.rst
@@ -36,6 +36,7 @@
 
     Connection Types <connections/snowflake>
     Operators <operators/index>
+    Decorators <decorators/index>
 
 .. toctree::
     :hidden:
diff --git a/docs/apache-airflow-providers-snowflake/operators/snowpark.rst 
b/docs/apache-airflow-providers-snowflake/operators/snowpark.rst
new file mode 100644
index 0000000000..755fa6c529
--- /dev/null
+++ b/docs/apache-airflow-providers-snowflake/operators/snowpark.rst
@@ -0,0 +1,74 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/operator:SnowparkOperator:
+
+SnowparkOperator
+================
+
+Use the :class:`SnowparkOperator 
<airflow.providers.snowflake.operators.snowpark.SnowparkOperator>` to run
+`Snowpark Python 
<https://docs.snowflake.com/en/developer-guide/snowpark/python/index.html>`__ 
code in a `Snowflake <https://docs.snowflake.com/en/>`__ database.
+
+.. warning::
+
+    - Snowpark does not support Python 3.12 yet.
+    - Currently, this operator does not support `Snowpark pandas API 
<https://docs.snowflake.com/en/developer-guide/snowpark/python/pandas-on-snowflake>`__
 because conflicting pandas version is used in Airflow.
+      Consider using Snowpark pandas API with other Snowpark decorators or 
operators.
+
+.. tip::
+
+    The :doc:`@task.snowpark </decorators/snowpark>` decorator is recommended 
over the ``SnowparkOperator`` to run Snowpark Python code.
+
+Prerequisite Tasks
+^^^^^^^^^^^^^^^^^^
+
+To use this operator, you must do a few things:
+
+  * Install provider package via **pip**.
+
+    .. code-block:: bash
+
+      pip install 'apache-airflow-providers-snowflake'
+
+    Detailed information is available for :doc:`Installation 
<apache-airflow:installation/index>`.
+
+  * :doc:`Setup a Snowflake Connection </connections/snowflake>`.
+
+Using the Operator
+^^^^^^^^^^^^^^^^^^
+
+Use the ``snowflake_conn_id`` argument to specify connection used. If not 
specified, ``snowflake_default`` will be used.
+
+An example usage of the ``SnowparkOperator`` is as follows:
+
+.. exampleinclude:: 
/../../tests/system/providers/snowflake/example_snowpark_operator.py
+    :language: python
+    :start-after: [START howto_operator_snowpark]
+    :end-before: [END howto_operator_snowpark]
+
+
+As the example demonstrates, there are two ways to use the Snowpark session 
object in your Python function:
+
+  * Pass the Snowpark session object to the function as a keyword argument 
named ``session``. The Snowpark session will be automatically injected into the 
function, allowing you to use it as you normally would.
+
+  * Use `get_active_session 
<https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/1.3.0/api/snowflake.snowpark.context.get_active_session>`__
+    function from Snowpark to retrieve the Snowpark session object inside the 
function.
+
+.. note::
+
+  Parameters that can be passed onto the operators will be given priority over 
the parameters already given
+  in the Airflow connection metadata (such as ``schema``, ``role``, 
``database`` and so forth).
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index eb6e612e09..cf6856838d 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1485,6 +1485,9 @@ SlackResponse
 slas
 smtp
 SnowflakeHook
+Snowpark
+snowpark
+SnowparkOperator
 somecollection
 somedatabase
 sortable
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index 2a81933c6c..10631afb9b 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -1230,6 +1230,7 @@
       "pandas>=2.1.2,<2.2;python_version>=\"3.9\"",
       "pyarrow>=14.0.1",
       "snowflake-connector-python>=3.7.1",
+      "snowflake-snowpark-python>=1.17.0;python_version<\"3.12\"",
       "snowflake-sqlalchemy>=1.4.0"
     ],
     "devel-deps": [],
diff --git a/tests/providers/snowflake/decorators/__init__.py 
b/tests/providers/snowflake/decorators/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/snowflake/decorators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/snowflake/decorators/test_snowpark.py 
b/tests/providers/snowflake/decorators/test_snowpark.py
new file mode 100644
index 0000000000..b14b6bd5c0
--- /dev/null
+++ b/tests/providers/snowflake/decorators/test_snowpark.py
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING
+from unittest import mock
+
+import pytest
+
+from airflow.decorators import task
+from airflow.utils import timezone
+
+if TYPE_CHECKING:
+    from snowflake.snowpark import Session
+
# Fixed logical date so task-instance runs are deterministic across test runs.
DEFAULT_DATE = timezone.datetime(2024, 9, 1)
TEST_DAG_ID = "test_snowpark_decorator"
TASK_ID = "snowpark_task"
# Default Snowflake connection id shipped with the provider.
CONN_ID = "snowflake_default"
+
+
@pytest.mark.db_test
@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Snowpark Python doesn't support Python 3.12 yet")
class TestSnowparkDecorator:
    """Tests for the ``@task.snowpark`` decorator.

    ``SnowflakeHook`` is patched at the *operator* module level (the decorator
    is built on the operator), so each test asserts against the mocked
    ``get_snowpark_session`` return value rather than a real Snowpark session.
    """

    @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
    def test_snowpark_decorator_no_param(self, mock_snowflake_hook, dag_maker):
        """Session is injected iff the callable declares a ``session`` parameter; return value goes to XCom."""
        number = 11

        @task.snowpark(
            task_id=f"{TASK_ID}_1",
            snowflake_conn_id=CONN_ID,
            warehouse="test_warehouse",
            database="test_database",
            schema="test_schema",
            role="test_role",
            authenticator="externalbrowser",
        )
        def func1(session: Session):
            # `session` parameter present -> the decorator injects the hook's Snowpark session.
            assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value
            return number

        @task.snowpark(
            task_id=f"{TASK_ID}_2",
            snowflake_conn_id=CONN_ID,
            warehouse="test_warehouse",
            database="test_database",
            schema="test_schema",
            role="test_role",
            authenticator="externalbrowser",
        )
        def func2():
            # No `session` parameter -> nothing is injected; the task still runs.
            return number

        with dag_maker(dag_id=TEST_DAG_ID):
            rets = [func1(), func2()]

        dr = dag_maker.create_dagrun()
        for ret in rets:
            ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        for ti in dr.get_task_instances():
            assert ti.xcom_pull() == number
        # One hook instantiation and one Snowpark session per task.
        assert mock_snowflake_hook.call_count == 2
        assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 2

    @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
    def test_snowpark_decorator_with_param(self, mock_snowflake_hook, dag_maker):
        """Injection works regardless of where ``session`` appears in the signature, alongside user kwargs."""
        number = 11

        @task.snowpark(
            task_id=f"{TASK_ID}_1",
            snowflake_conn_id=CONN_ID,
            warehouse="test_warehouse",
            database="test_database",
            schema="test_schema",
            role="test_role",
            authenticator="externalbrowser",
        )
        def func1(session: Session, number: int):
            # `session` first, user argument second.
            assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value
            return number

        @task.snowpark(
            task_id=f"{TASK_ID}_2",
            snowflake_conn_id=CONN_ID,
            warehouse="test_warehouse",
            database="test_database",
            schema="test_schema",
            role="test_role",
            authenticator="externalbrowser",
        )
        def func2(number: int, session: Session):
            # `session` last still gets injected.
            assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value
            return number

        @task.snowpark(
            task_id=f"{TASK_ID}_3",
            snowflake_conn_id=CONN_ID,
            warehouse="test_warehouse",
            database="test_database",
            schema="test_schema",
            role="test_role",
            authenticator="externalbrowser",
        )
        def func3(number: int):
            # No `session` parameter at all.
            return number

        with dag_maker(dag_id=TEST_DAG_ID):
            rets = [func1(number=number), func2(number=number), func3(number=number)]

        dr = dag_maker.create_dagrun()
        for ret in rets:
            ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        for ti in dr.get_task_instances():
            assert ti.xcom_pull() == number
        assert mock_snowflake_hook.call_count == 3
        assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 3

    @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
    def test_snowpark_decorator_no_return(self, mock_snowflake_hook, dag_maker):
        """A callable that returns nothing pushes ``None`` to XCom and still opens a session."""

        @task.snowpark(
            task_id=TASK_ID,
            snowflake_conn_id=CONN_ID,
            warehouse="test_warehouse",
            database="test_database",
            schema="test_schema",
            role="test_role",
            authenticator="externalbrowser",
        )
        def func(session: Session):
            assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value

        with dag_maker(dag_id=TEST_DAG_ID):
            ret = func()

        dr = dag_maker.create_dagrun()
        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        for ti in dr.get_task_instances():
            assert ti.xcom_pull() is None
        mock_snowflake_hook.assert_called_once()
        mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once()

    @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
    def test_snowpark_decorator_multiple_output(self, mock_snowflake_hook, dag_maker):
        """With ``multiple_outputs=True`` each dict key is pushed as its own XCom entry."""

        @task.snowpark(
            task_id=TASK_ID,
            snowflake_conn_id=CONN_ID,
            warehouse="test_warehouse",
            database="test_database",
            schema="test_schema",
            role="test_role",
            authenticator="externalbrowser",
            multiple_outputs=True,
        )
        def func(session: Session):
            assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value
            return {"a": 1, "b": "2"}

        with dag_maker(dag_id=TEST_DAG_ID):
            ret = func()

        dr = dag_maker.create_dagrun()
        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        ti = dr.get_task_instances()[0]
        # Each key is retrievable individually, and the whole dict via the default key.
        assert ti.xcom_pull(key="a") == 1
        assert ti.xcom_pull(key="b") == "2"
        assert ti.xcom_pull() == {"a": 1, "b": "2"}
        mock_snowflake_hook.assert_called_once()
        mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once()

    @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
    def test_snowpark_decorator_session_tag(self, mock_snowflake_hook, dag_maker):
        """The task tags the Snowpark session (via ``update_query_tag``) with DAG/run/task metadata."""
        mock_session = mock_snowflake_hook.return_value.get_snowpark_session.return_value
        mock_session.query_tag = {}

        # Mock the update_query_tag function to combine with another dict
        def update_query_tag(new_tags):
            mock_session.query_tag.update(new_tags)

        mock_session.update_query_tag = mock.Mock(side_effect=update_query_tag)

        @task.snowpark(
            task_id=TASK_ID,
            snowflake_conn_id=CONN_ID,
            warehouse="test_warehouse",
            database="test_database",
            schema="test_schema",
            role="test_role",
            authenticator="externalbrowser",
        )
        def func(session: Session):
            # Return the accumulated tag dict so it lands in XCom for inspection.
            return session.query_tag

        with dag_maker(dag_id=TEST_DAG_ID):
            ret = func()

        dr = dag_maker.create_dagrun()
        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        ti = dr.get_task_instances()[0]
        query_tag = ti.xcom_pull()
        # The decorated task runs as a _SnowparkDecoratedOperator under the hood.
        assert query_tag == {
            "dag_id": TEST_DAG_ID,
            "dag_run_id": dr.run_id,
            "task_id": TASK_ID,
            "operator": "_SnowparkDecoratedOperator",
        }
diff --git a/tests/providers/snowflake/hooks/test_snowflake.py 
b/tests/providers/snowflake/hooks/test_snowflake.py
index 16e10db048..9ef0c4d2a5 100644
--- a/tests/providers/snowflake/hooks/test_snowflake.py
+++ b/tests/providers/snowflake/hooks/test_snowflake.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import json
+import sys
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any
 from unittest import mock
@@ -611,3 +612,29 @@ class TestPytestSnowflakeHook:
             hook_with_schema_param = 
SnowflakeHook(snowflake_conn_id="test_conn", schema="my_schema")
             assert hook_with_schema_param.get_openlineage_default_schema() == 
"my_schema"
             mock_get_first.assert_not_called()
+
+    @pytest.mark.skipif(sys.version_info >= (3, 12), reason="Snowpark Python 
doesn't support Python 3.12 yet")
+    @mock.patch("snowflake.snowpark.Session.builder")
+    def test_get_snowpark_session(self, mock_session_builder):
+        from airflow import __version__ as airflow_version
+        from airflow.providers.snowflake import __version__ as provider_version
+
+        mock_session = mock.MagicMock()
+        mock_session_builder.configs.return_value.create.return_value = 
mock_session
+
+        with mock.patch.dict(
+            "os.environ", 
AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri()
+        ):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            session = hook.get_snowpark_session()
+            assert session == mock_session
+
+            
mock_session_builder.configs.assert_called_once_with(hook._get_conn_params)
+
+            # Verify that update_query_tag was called with the expected tag 
dictionary
+            mock_session.update_query_tag.assert_called_once_with(
+                {
+                    "airflow_version": airflow_version,
+                    "airflow_provider_version": provider_version,
+                }
+            )
diff --git a/tests/providers/snowflake/operators/test_snowpark.py 
b/tests/providers/snowflake/operators/test_snowpark.py
new file mode 100644
index 0000000000..b39bf3c105
--- /dev/null
+++ b/tests/providers/snowflake/operators/test_snowpark.py
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING
+from unittest import mock
+
+import pytest
+
+from airflow.providers.snowflake.operators.snowpark import SnowparkOperator
+from airflow.utils import timezone
+
+if TYPE_CHECKING:
+    from snowflake.snowpark import Session
+
# Fixed logical date so task-instance runs are deterministic across test runs.
DEFAULT_DATE = timezone.datetime(2024, 9, 1)
TEST_DAG_ID = "test_snowpark_operator"
TASK_ID = "snowpark_task"
# Default Snowflake connection id shipped with the provider.
CONN_ID = "snowflake_default"
+
+
@pytest.mark.db_test
@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Snowpark Python doesn't support Python 3.12 yet")
class TestSnowparkOperator:
    """Tests for ``SnowparkOperator``.

    ``SnowflakeHook`` is patched at the operator module level, so each test
    asserts against the mocked ``get_snowpark_session`` return value rather
    than a real Snowpark session.
    """

    @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
    def test_snowpark_operator_no_param(self, mock_snowflake_hook, dag_maker):
        """Session is injected iff the callable declares a ``session`` parameter; return value goes to XCom."""
        number = 11

        with dag_maker(dag_id=TEST_DAG_ID) as dag:

            def func1(session: Session):
                # `session` parameter present -> the operator injects the hook's Snowpark session.
                assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value
                return number

            def func2():
                # No `session` parameter -> nothing is injected.
                return number

            operators = [
                SnowparkOperator(
                    task_id=f"{TASK_ID}_{i}",
                    snowflake_conn_id=CONN_ID,
                    python_callable=func,
                    warehouse="test_warehouse",
                    database="test_database",
                    schema="test_schema",
                    role="test_role",
                    authenticator="externalbrowser",
                    dag=dag,
                )
                for i, func in enumerate([func1, func2])
            ]

        dr = dag_maker.create_dagrun()
        for operator in operators:
            operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        for ti in dr.get_task_instances():
            assert ti.xcom_pull() == number
        # One hook instantiation and one Snowpark session per task.
        assert mock_snowflake_hook.call_count == 2
        assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 2

    @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
    def test_snowpark_operator_with_param(self, mock_snowflake_hook, dag_maker):
        """Injection works regardless of where ``session`` appears in the signature, alongside ``op_kwargs``."""
        number = 11

        with dag_maker(dag_id=TEST_DAG_ID) as dag:

            def func1(session: Session, number: int):
                # `session` first, user argument second.
                assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value
                return number

            def func2(number: int, session: Session):
                # `session` last still gets injected.
                assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value
                return number

            def func3(number: int):
                # No `session` parameter at all.
                return number

            operators = [
                SnowparkOperator(
                    task_id=f"{TASK_ID}_{i}",
                    snowflake_conn_id=CONN_ID,
                    python_callable=func,
                    op_kwargs={"number": number},
                    warehouse="test_warehouse",
                    database="test_database",
                    schema="test_schema",
                    role="test_role",
                    authenticator="externalbrowser",
                    dag=dag,
                )
                for i, func in enumerate([func1, func2, func3])
            ]

        dr = dag_maker.create_dagrun()
        for operator in operators:
            operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        for ti in dr.get_task_instances():
            assert ti.xcom_pull() == number
        assert mock_snowflake_hook.call_count == 3
        assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 3

    @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
    def test_snowpark_operator_no_return(self, mock_snowflake_hook, dag_maker):
        """A callable that returns nothing pushes ``None`` to XCom and still opens a session."""
        with dag_maker(dag_id=TEST_DAG_ID) as dag:

            def func(session: Session):
                assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value

            operator = SnowparkOperator(
                task_id=TASK_ID,
                snowflake_conn_id=CONN_ID,
                python_callable=func,
                warehouse="test_warehouse",
                database="test_database",
                schema="test_schema",
                role="test_role",
                authenticator="externalbrowser",
                dag=dag,
            )

        dr = dag_maker.create_dagrun()
        operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        for ti in dr.get_task_instances():
            assert ti.xcom_pull() is None
        mock_snowflake_hook.assert_called_once()
        mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once()

    @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
    def test_snowpark_operator_session_tag(self, mock_snowflake_hook, dag_maker):
        """The operator tags the Snowpark session (via ``update_query_tag``) with DAG/run/task metadata."""
        mock_session = mock_snowflake_hook.return_value.get_snowpark_session.return_value
        mock_session.query_tag = {}

        # Mock the update_query_tag function to combine with another dict
        def update_query_tag(new_tags):
            mock_session.query_tag.update(new_tags)

        mock_session.update_query_tag = mock.Mock(side_effect=update_query_tag)

        with dag_maker(dag_id=TEST_DAG_ID) as dag:

            def func(session: Session):
                # Return the accumulated tag dict so it lands in XCom for inspection.
                return session.query_tag

            operator = SnowparkOperator(
                task_id=TASK_ID,
                snowflake_conn_id=CONN_ID,
                python_callable=func,
                warehouse="test_warehouse",
                database="test_database",
                schema="test_schema",
                role="test_role",
                authenticator="externalbrowser",
                dag=dag,
            )

        dr = dag_maker.create_dagrun()
        operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        ti = dr.get_task_instances()[0]
        query_tag = ti.xcom_pull()
        assert query_tag == {
            "dag_id": TEST_DAG_ID,
            "dag_run_id": dr.run_id,
            "task_id": TASK_ID,
            "operator": "SnowparkOperator",
        }
diff --git a/tests/providers/snowflake/utils/test_snowpark.py 
b/tests/providers/snowflake/utils/test_snowpark.py
new file mode 100644
index 0000000000..c0c8b507ef
--- /dev/null
+++ b/tests/providers/snowflake/utils/test_snowpark.py
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.snowflake.utils.snowpark import 
inject_session_into_op_kwargs
+
+
@pytest.mark.parametrize(
    "callable_, should_inject",
    [
        (lambda x: x, False),
        (lambda: 1, False),
        (lambda session: 1, True),
        (lambda session, x: x, True),
        (lambda x, session: 2 * x, True),
    ],
)
def test_inject_session_into_op_kwargs(callable_, should_inject):
    """A ``session`` key is added to op_kwargs iff the callable declares a ``session`` parameter."""
    op_kwargs = inject_session_into_op_kwargs(callable_, {}, None)
    if should_inject:
        assert "session" in op_kwargs
    else:
        assert "session" not in op_kwargs
diff --git a/tests/system/providers/snowflake/example_snowpark_decorator.py 
b/tests/system/providers/snowflake/example_snowpark_decorator.py
new file mode 100644
index 0000000000..1a303b1fdf
--- /dev/null
+++ b/tests/system/providers/snowflake/example_snowpark_decorator.py
@@ -0,0 +1,85 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
"""
Example use of Snowflake Snowpark Python related decorators.
"""

from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from snowflake.snowpark import Session

from airflow import DAG
from airflow.decorators import task

SNOWFLAKE_CONN_ID = "snowflake_default"
DAG_ID = "example_snowpark"

with DAG(
    DAG_ID,
    start_date=datetime(2024, 1, 1),
    schedule="@once",
    # Every @task.snowpark task in this DAG inherits this connection id.
    default_args={"snowflake_conn_id": SNOWFLAKE_CONN_ID},
    tags=["example"],
    catchup=False,
) as dag:
    # [START howto_decorator_snowpark]
    @task.snowpark
    def setup_data(session: Session):
        # The Snowpark session object is injected as an argument
        data = [
            (1, 0, 5, "Product 1", "prod-1", 1, 10),
            (2, 1, 5, "Product 1A", "prod-1-A", 1, 20),
            (3, 1, 5, "Product 1B", "prod-1-B", 1, 30),
            (4, 0, 10, "Product 2", "prod-2", 2, 40),
            (5, 4, 10, "Product 2A", "prod-2-A", 2, 50),
            (6, 4, 10, "Product 2B", "prod-2-B", 2, 60),
            (7, 0, 20, "Product 3", "prod-3", 3, 70),
            (8, 7, 20, "Product 3A", "prod-3-A", 3, 80),
            (9, 7, 20, "Product 3B", "prod-3-B", 3, 90),
            (10, 0, 50, "Product 4", "prod-4", 4, 100),
            (11, 10, 50, "Product 4A", "prod-4-A", 4, 100),
            (12, 10, 50, "Product 4B", "prod-4-B", 4, 100),
        ]
        columns = ["id", "parent_id", "category_id", "name", "serial_number", "key", "3rd"]
        df = session.create_dataframe(data, schema=columns)
        table_name = "sample_product_data"
        df.write.save_as_table(table_name, mode="overwrite")
        # Returned value is pushed to XCom and consumed by the downstream task.
        return table_name

    table_name = setup_data()  # type: ignore[call-arg]

    @task.snowpark
    def check_num_rows(table_name: str):
        # Alternatively, retrieve the Snowpark session object using `get_active_session`
        from snowflake.snowpark.context import get_active_session

        session = get_active_session()
        df = session.table(table_name)
        assert df.count() == 12

    check_num_rows(table_name)
    # [END howto_decorator_snowpark]

from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
diff --git a/tests/system/providers/snowflake/example_snowpark_operator.py 
b/tests/system/providers/snowflake/example_snowpark_operator.py
new file mode 100644
index 0000000000..090a0f53a4
--- /dev/null
+++ b/tests/system/providers/snowflake/example_snowpark_operator.py
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
"""
Example use of Snowflake Snowpark Python related operators.
"""

from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from snowflake.snowpark import Session

from airflow import DAG
from airflow.providers.snowflake.operators.snowpark import SnowparkOperator

SNOWFLAKE_CONN_ID = "snowflake_default"
DAG_ID = "example_snowpark"

with DAG(
    DAG_ID,
    start_date=datetime(2024, 1, 1),
    schedule="@once",
    # Every SnowparkOperator task in this DAG inherits this connection id.
    default_args={"snowflake_conn_id": SNOWFLAKE_CONN_ID},
    tags=["example"],
    catchup=False,
) as dag:
    # [START howto_operator_snowpark]
    def setup_data(session: Session):
        # The Snowpark session object is injected as an argument
        data = [
            (1, 0, 5, "Product 1", "prod-1", 1, 10),
            (2, 1, 5, "Product 1A", "prod-1-A", 1, 20),
            (3, 1, 5, "Product 1B", "prod-1-B", 1, 30),
            (4, 0, 10, "Product 2", "prod-2", 2, 40),
            (5, 4, 10, "Product 2A", "prod-2-A", 2, 50),
            (6, 4, 10, "Product 2B", "prod-2-B", 2, 60),
            (7, 0, 20, "Product 3", "prod-3", 3, 70),
            (8, 7, 20, "Product 3A", "prod-3-A", 3, 80),
            (9, 7, 20, "Product 3B", "prod-3-B", 3, 90),
            (10, 0, 50, "Product 4", "prod-4", 4, 100),
            (11, 10, 50, "Product 4A", "prod-4-A", 4, 100),
            (12, 10, 50, "Product 4B", "prod-4-B", 4, 100),
        ]
        columns = ["id", "parent_id", "category_id", "name", "serial_number", "key", "3rd"]
        df = session.create_dataframe(data, schema=columns)
        table_name = "sample_product_data"
        df.write.save_as_table(table_name, mode="overwrite")
        # Returned value is pushed to XCom and pulled by the downstream task.
        return table_name

    setup_data_operator = SnowparkOperator(
        task_id="setup_data",
        python_callable=setup_data,
        dag=dag,
    )

    def check_num_rows(table_name: str):
        # Alternatively, retrieve the Snowpark session object using `get_active_session`
        from snowflake.snowpark.context import get_active_session

        session = get_active_session()
        df = session.table(table_name)
        assert df.count() == 12

    check_num_rows_operator = SnowparkOperator(
        task_id="check_num_rows",
        python_callable=check_num_rows,
        # Pull the table name produced by the upstream task via a templated XCom expression.
        op_kwargs={"table_name": "{{ task_instance.xcom_pull(task_ids='setup_data') }}"},
        dag=dag,
    )

    setup_data_operator >> check_num_rows_operator
    # [END howto_operator_snowpark]

from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)

Reply via email to