This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9995e62 [SPARK-36885][PYTHON] Inline type hints for pyspark.sql.dataframe
9995e62 is described below
commit 9995e623f7f65fdd3b1dc3cd4e0140a7cf4bc4a0
Author: Takuya UESHIN <[email protected]>
AuthorDate: Tue Oct 12 09:17:14 2021 +0900
[SPARK-36885][PYTHON] Inline type hints for pyspark.sql.dataframe
### What changes were proposed in this pull request?
Inline type hints from `python/pyspark/sql/dataframe.pyi` to
`python/pyspark/sql/dataframe.py`.
### Why are the changes needed?
Currently, there is a type hint stub file, `python/pyspark/sql/dataframe.pyi`,
that shows the expected types for functions, but we can also take advantage of
static type checking within the functions by inlining the type hints.
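For illustration, a minimal before/after sketch of the pattern (simplified; the full change for each method, e.g. `limit`, is in the diff below):

    # Before: the implementation in dataframe.py carries no annotations;
    # the types live only in the dataframe.pyi stub:
    #     def limit(self, num: int) -> DataFrame: ...
    def limit(self, num):
        return DataFrame(self._jdf.limit(num), self.sql_ctx)

    # After: the hints are inlined into dataframe.py itself, so mypy can
    # also type-check the function body against the declared types:
    def limit(self, num: int) -> "DataFrame":
        return DataFrame(self._jdf.limit(num), self.sql_ctx)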
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
Closes #34225 from ueshin/issues/SPARK-36885/inline_typehints_dataframe.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/__init__.pyi | 2 +-
python/pyspark/sql/dataframe.py | 671 ++++++++++++++++++++++++++++----------
python/pyspark/sql/dataframe.pyi | 351 --------------------
python/pyspark/sql/observation.py | 2 +-
4 files changed, 504 insertions(+), 522 deletions(-)
diff --git a/python/pyspark/__init__.pyi b/python/pyspark/__init__.pyi
index f85319b..35df545 100644
--- a/python/pyspark/__init__.pyi
+++ b/python/pyspark/__init__.pyi
@@ -71,7 +71,7 @@ def since(version: Union[str, float]) -> Callable[[T], T]: ...
def copy_func(
f: F,
name: Optional[str] = ...,
- sinceversion: Optional[str] = ...,
+ sinceversion: Optional[Union[str, float]] = ...,
doc: Optional[str] = ...,
) -> F: ...
def keyword_only(func: F) -> F: ...
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 8d4c94f..339f8f8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -22,22 +22,41 @@ import warnings
from collections.abc import Iterable
from functools import reduce
from html import escape as html_escape
+from typing import (
+ Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union,
+ cast, overload, TYPE_CHECKING
+)
-from pyspark import copy_func, since, _NoValue
+from py4j.java_gateway import JavaObject # type: ignore[import]
+
+from pyspark import copy_func, since, _NoValue # type: ignore[attr-defined]
from pyspark.context import SparkContext
-from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket
+from pyspark.rdd import ( # type: ignore[attr-defined]
+ RDD, _load_from_socket, _local_iterator_from_socket
+)
from pyspark.serializers import BatchedSerializer, PickleSerializer, \
UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
-from pyspark.sql.types import _parse_datatype_json_string
-from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
+from pyspark.sql.types import _parse_datatype_json_string # type: ignore[attr-defined]
+from pyspark.sql.column import ( # type: ignore[attr-defined]
+ Column, _to_seq, _to_list, _to_java_column
+)
from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.streaming import DataStreamWriter
-from pyspark.sql.types import StructType, StructField, StringType, IntegerType
+from pyspark.sql.types import StructType, StructField, StringType, IntegerType, Row
from pyspark.sql.pandas.conversion import PandasConversionMixin
from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
+if TYPE_CHECKING:
+ from pyspark._typing import PrimitiveType
+ from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+ from pyspark.sql._typing import ColumnOrName, LiteralType, OptionalPrimitiveType
+ from pyspark.sql.context import SQLContext
+ from pyspark.sql.group import GroupedData
+ from pyspark.sql.observation import Observation
+
+
__all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]
@@ -68,42 +87,49 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
.. versionadded:: 1.3.0
"""
- def __init__(self, jdf, sql_ctx):
+ def __init__(self, jdf: JavaObject, sql_ctx: "SQLContext"):
self._jdf = jdf
self.sql_ctx = sql_ctx
- self._sc = sql_ctx and sql_ctx._sc
+ self._sc = cast(
+ SparkContext,
+ sql_ctx and sql_ctx._sc # type: ignore[attr-defined]
+ )
self.is_cached = False
- self._schema = None # initialized lazily
- self._lazy_rdd = None
+ # initialized lazily
+ self._schema: Optional[StructType] = None
+ self._lazy_rdd: Optional[RDD[Row]] = None
# Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice
# by __repr__ and _repr_html_ while eager evaluation opened.
self._support_repr_html = False
- @property
+ @property # type: ignore[misc]
@since(1.3)
- def rdd(self):
+ def rdd(self) -> "RDD[Row]":
"""Returns the content as an :class:`pyspark.RDD` of :class:`Row`.
"""
if self._lazy_rdd is None:
jrdd = self._jdf.javaToPython()
- self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
+ self._lazy_rdd = RDD(
+ jrdd, self.sql_ctx._sc, # type: ignore[attr-defined]
+ BatchedSerializer(PickleSerializer())
+ )
return self._lazy_rdd
- @property
+ @property # type: ignore[misc]
@since("1.3.1")
- def na(self):
+ def na(self) -> "DataFrameNaFunctions":
"""Returns a :class:`DataFrameNaFunctions` for handling missing values.
"""
return DataFrameNaFunctions(self)
- @property
+ @property # type: ignore[misc]
@since(1.4)
- def stat(self):
+ def stat(self) -> "DataFrameStatFunctions":
"""Returns a :class:`DataFrameStatFunctions` for statistic functions.
"""
return DataFrameStatFunctions(self)
- def toJSON(self, use_unicode=True):
+ def toJSON(self, use_unicode: bool = True) -> "RDD[str]":
"""Converts a :class:`DataFrame` into a :class:`RDD` of string.
Each row is turned into a JSON document as one element in the returned RDD.
@@ -118,7 +144,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
rdd = self._jdf.toJSON()
return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
- def registerTempTable(self, name):
+ def registerTempTable(self, name: str) -> None:
"""Registers this :class:`DataFrame` as a temporary table using the
given name.
The lifetime of this temporary table is tied to the
:class:`SparkSession`
@@ -145,7 +171,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
)
self._jdf.createOrReplaceTempView(name)
- def createTempView(self, name):
+ def createTempView(self, name: str) -> None:
"""Creates a local temporary view with this :class:`DataFrame`.
The lifetime of this temporary table is tied to the :class:`SparkSession`
@@ -171,7 +197,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
self._jdf.createTempView(name)
- def createOrReplaceTempView(self, name):
+ def createOrReplaceTempView(self, name: str) -> None:
"""Creates or replaces a local temporary view with this
:class:`DataFrame`.
The lifetime of this temporary table is tied to the
:class:`SparkSession`
@@ -193,7 +219,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
self._jdf.createOrReplaceTempView(name)
- def createGlobalTempView(self, name):
+ def createGlobalTempView(self, name: str) -> None:
"""Creates a global temporary view with this :class:`DataFrame`.
The lifetime of this temporary view is tied to this Spark application.
@@ -218,7 +244,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
self._jdf.createGlobalTempView(name)
- def createOrReplaceGlobalTempView(self, name):
+ def createOrReplaceGlobalTempView(self, name: str) -> None:
"""Creates or replaces a global temporary view using the given name.
The lifetime of this temporary view is tied to this Spark application.
@@ -240,7 +266,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
self._jdf.createOrReplaceGlobalTempView(name)
@property
- def write(self):
+ def write(self) -> DataFrameWriter:
"""
Interface for saving the content of the non-streaming :class:`DataFrame` out into external
storage.
@@ -254,7 +280,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return DataFrameWriter(self)
@property
- def writeStream(self):
+ def writeStream(self) -> DataStreamWriter:
"""
Interface for saving the content of the streaming :class:`DataFrame` out into external
storage.
@@ -272,7 +298,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return DataStreamWriter(self)
@property
- def schema(self):
+ def schema(self) -> StructType:
"""Returns the schema of this :class:`DataFrame` as a
:class:`pyspark.sql.types.StructType`.
.. versionadded:: 1.3.0
@@ -290,7 +316,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"Unable to parse datatype from schema. %s" % e) from e
return self._schema
- def printSchema(self):
+ def printSchema(self) -> None:
"""Prints out the schema in the tree format.
.. versionadded:: 1.3.0
@@ -305,7 +331,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
print(self._jdf.schema().treeString())
- def explain(self, extended=None, mode=None):
+ def explain(
+ self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None
+ ) -> None:
"""Prints the (logical and physical) plans to the console for
debugging purpose.
.. versionadded:: 1.3.0
@@ -390,13 +418,16 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
elif is_extended_case:
explain_mode = "extended" if extended else "simple"
elif is_mode_case:
- explain_mode = mode
+ explain_mode = cast(str, mode)
elif is_extended_as_mode:
- explain_mode = extended
+ explain_mode = cast(str, extended)
- print(self._sc._jvm.PythonSQLUtils.explainString(self._jdf.queryExecution(), explain_mode))
+ print(
+ self._sc._jvm # type: ignore[attr-defined]
+ .PythonSQLUtils.explainString(self._jdf.queryExecution(), explain_mode)
+ )
- def exceptAll(self, other):
+ def exceptAll(self, other: "DataFrame") -> "DataFrame":
"""Return a new :class:`DataFrame` containing rows in this
:class:`DataFrame` but
not in another :class:`DataFrame` while preserving duplicates.
@@ -425,14 +456,14 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx)
@since(1.3)
- def isLocal(self):
+ def isLocal(self) -> bool:
"""Returns ``True`` if the :func:`collect` and :func:`take` methods
can be run locally
(without any Spark executors).
"""
return self._jdf.isLocal()
@property
- def isStreaming(self):
+ def isStreaming(self) -> bool:
"""Returns ``True`` if this :class:`DataFrame` contains one or more
sources that
continuously return data as it arrives. A :class:`DataFrame` that
reads data from a
streaming source must be executed as a :class:`StreamingQuery` using
the :func:`start`
@@ -448,7 +479,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
return self._jdf.isStreaming()
- def show(self, n=20, truncate=True, vertical=False):
+ def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
"""Prints the first ``n`` rows to the console.
.. versionadded:: 1.3.0
@@ -509,26 +540,33 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
print(self._jdf.showString(n, int_truncate, vertical))
- def __repr__(self):
- if not self._support_repr_html and self.sql_ctx._conf.isReplEagerEvalEnabled():
+ def __repr__(self) -> str:
+ if (
+ not self._support_repr_html
+ and self.sql_ctx._conf.isReplEagerEvalEnabled() # type: ignore[attr-defined]
+ ):
vertical = False
return self._jdf.showString(
- self.sql_ctx._conf.replEagerEvalMaxNumRows(),
- self.sql_ctx._conf.replEagerEvalTruncate(), vertical)
+ self.sql_ctx._conf.replEagerEvalMaxNumRows(), # type: ignore[attr-defined]
+ self.sql_ctx._conf.replEagerEvalTruncate(), vertical) # type: ignore[attr-defined]
else:
return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in
self.dtypes))
- def _repr_html_(self):
+ def _repr_html_(self) -> Optional[str]:
"""Returns a :class:`DataFrame` with html code when you enabled eager
evaluation
by 'spark.sql.repl.eagerEval.enabled', this only called by REPL you are
using support eager evaluation with HTML.
"""
if not self._support_repr_html:
self._support_repr_html = True
- if self.sql_ctx._conf.isReplEagerEvalEnabled():
- max_num_rows = max(self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0)
+ if self.sql_ctx._conf.isReplEagerEvalEnabled(): # type: ignore[attr-defined]
+ max_num_rows = max(
+ self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0 # type: ignore[attr-defined]
+ )
sock_info = self._jdf.getRowsToPython(
- max_num_rows, self.sql_ctx._conf.replEagerEvalTruncate())
+ max_num_rows,
+ self.sql_ctx._conf.replEagerEvalTruncate() # type: ignore[attr-defined]
+ )
rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
head = rows[0]
row_data = rows[1:]
@@ -550,7 +588,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
else:
return None
- def checkpoint(self, eager=True):
+ def checkpoint(self, eager: bool = True) -> "DataFrame":
"""Returns a checkpointed version of this :class:`DataFrame`.
Checkpointing can be used to
truncate the logical plan of this :class:`DataFrame`, which is
especially useful in
iterative algorithms where the plan may grow exponentially. It will be
saved to files
@@ -570,7 +608,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.checkpoint(eager)
return DataFrame(jdf, self.sql_ctx)
- def localCheckpoint(self, eager=True):
+ def localCheckpoint(self, eager: bool = True) -> "DataFrame":
"""Returns a locally checkpointed version of this :class:`DataFrame`.
Checkpointing can be
used to truncate the logical plan of this :class:`DataFrame`, which is
especially useful in
iterative algorithms where the plan may grow exponentially. Local
checkpoints are
@@ -590,7 +628,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.localCheckpoint(eager)
return DataFrame(jdf, self.sql_ctx)
- def withWatermark(self, eventTime, delayThreshold):
+ def withWatermark(self, eventTime: str, delayThreshold: str) -> "DataFrame":
"""Defines an event time watermark for this :class:`DataFrame`. A watermark tracks a point
in time before which we assume no more late data is going to arrive.
@@ -634,7 +672,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.withWatermark(eventTime, delayThreshold)
return DataFrame(jdf, self.sql_ctx)
- def hint(self, name, *parameters):
+ def hint(
+ self, name: str, *parameters: Union["PrimitiveType", List["PrimitiveType"]]
+ ) -> "DataFrame":
"""Specifies some hint on the current :class:`DataFrame`.
.. versionadded:: 2.2.0
@@ -660,7 +700,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
+----+---+------+
"""
if len(parameters) == 1 and isinstance(parameters[0], list):
- parameters = parameters[0]
+ parameters = parameters[0] # type: ignore[assignment]
if not isinstance(name, str):
raise TypeError("name should be provided as str, got
{0}".format(type(name)))
@@ -675,7 +715,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.hint(name, self._jseq(parameters))
return DataFrame(jdf, self.sql_ctx)
- def count(self):
+ def count(self) -> int:
"""Returns the number of rows in this :class:`DataFrame`.
.. versionadded:: 1.3.0
@@ -687,7 +727,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
return int(self._jdf.count())
- def collect(self):
+ def collect(self) -> List[Row]:
"""Returns all the records as a list of :class:`Row`.
.. versionadded:: 1.3.0
@@ -701,7 +741,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
sock_info = self._jdf.collectToPython()
return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
- def toLocalIterator(self, prefetchPartitions=False):
+ def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
"""
Returns an iterator that contains all of the rows in this :class:`DataFrame`.
The iterator will consume as much memory as the largest partition in this
@@ -724,7 +764,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
sock_info = self._jdf.toPythonIterator(prefetchPartitions)
return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer()))
- def limit(self, num):
+ def limit(self, num: int) -> "DataFrame":
"""Limits the result count to the number specified.
.. versionadded:: 1.3.0
@@ -739,7 +779,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.limit(num)
return DataFrame(jdf, self.sql_ctx)
- def take(self, num):
+ def take(self, num: int) -> List[Row]:
"""Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
.. versionadded:: 1.3.0
@@ -751,7 +791,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
return self.limit(num).collect()
- def tail(self, num):
+ def tail(self, num: int) -> List[Row]:
"""
Returns the last ``num`` rows as a :class:`list` of :class:`Row`.
@@ -769,7 +809,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
sock_info = self._jdf.tailToPython(num)
return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
- def foreach(self, f):
+ def foreach(self, f: Callable[[Row], None]) -> None:
"""Applies the ``f`` function to all :class:`Row` of this
:class:`DataFrame`.
This is a shorthand for ``df.rdd.foreach()``.
@@ -784,7 +824,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
self.rdd.foreach(f)
- def foreachPartition(self, f):
+ def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None:
"""Applies the ``f`` function to each partition of this
:class:`DataFrame`.
This a shorthand for ``df.rdd.foreachPartition()``.
@@ -798,9 +838,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
... print(person.name)
>>> df.foreachPartition(f)
"""
- self.rdd.foreachPartition(f)
+ self.rdd.foreachPartition(f) # type: ignore[arg-type]
- def cache(self):
+ def cache(self) -> "DataFrame":
"""Persists the :class:`DataFrame` with the default storage level
(`MEMORY_AND_DISK`).
.. versionadded:: 1.3.0
@@ -813,7 +853,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
self._jdf.cache()
return self
- def persist(self, storageLevel=StorageLevel.MEMORY_AND_DISK_DESER):
+ def persist(
+ self,
+ storageLevel: StorageLevel = (
+ StorageLevel.MEMORY_AND_DISK_DESER # type: ignore[attr-defined]
+ )
+ ) -> "DataFrame":
"""Sets the storage level to persist the contents of the
:class:`DataFrame` across
operations after the first time it is computed. This can only be used
to assign
a new storage level if the :class:`DataFrame` does not have a storage
level set yet.
@@ -826,12 +871,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
The default storage level has changed to `MEMORY_AND_DISK_DESER` to match Scala in 3.0.
"""
self.is_cached = True
- javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+ javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) # type: ignore[attr-defined]
self._jdf.persist(javaStorageLevel)
return self
@property
- def storageLevel(self):
+ def storageLevel(self) -> StorageLevel:
"""Get the :class:`DataFrame`'s current storage level.
.. versionadded:: 2.1.0
@@ -853,7 +898,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
java_storage_level.replication())
return storage_level
- def unpersist(self, blocking=False):
+ def unpersist(self, blocking: bool = False) -> "DataFrame":
"""Marks the :class:`DataFrame` as non-persistent, and remove all
blocks for it from
memory and disk.
@@ -867,7 +912,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
self._jdf.unpersist(blocking)
return self
- def coalesce(self, numPartitions):
+ def coalesce(self, numPartitions: int) -> "DataFrame":
"""
Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions.
@@ -898,7 +943,17 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx)
- def repartition(self, numPartitions, *cols):
+ @overload
+ def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> "DataFrame":
+ ...
+
+ @overload
+ def repartition(self, *cols: "ColumnOrName") -> "DataFrame":
+ ...
+
+ def repartition( # type: ignore[misc]
+ self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
+ ) -> "DataFrame":
"""
Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The
resulting :class:`DataFrame` is hash partitioned.
@@ -967,7 +1022,17 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
else:
raise TypeError("numPartitions should be an int or Column")
- def repartitionByRange(self, numPartitions, *cols):
+ @overload
+ def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> "DataFrame":
+ ...
+
+ @overload
+ def repartitionByRange(self, *cols: "ColumnOrName") -> "DataFrame":
+ ...
+
+ def repartitionByRange( # type: ignore[misc]
+ self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
+ ) -> "DataFrame":
"""
Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The
resulting :class:`DataFrame` is range partitioned.
@@ -1017,7 +1082,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
if isinstance(numPartitions, int):
if len(cols) == 0:
- return ValueError("At least one partition-by expression must
be specified.")
+ raise ValueError("At least one partition-by expression must be
specified.")
else:
return DataFrame(
self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx)
@@ -1027,7 +1092,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
else:
raise TypeError("numPartitions should be an int, string or Column")
- def distinct(self):
+ def distinct(self) -> "DataFrame":
"""Returns a new :class:`DataFrame` containing the distinct rows in
this :class:`DataFrame`.
.. versionadded:: 1.3.0
@@ -1039,7 +1104,25 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
return DataFrame(self._jdf.distinct(), self.sql_ctx)
- def sample(self, withReplacement=None, fraction=None, seed=None):
+ @overload
+ def sample(self, fraction: float, seed: Optional[int] = ...) -> "DataFrame":
+ ...
+
+ @overload
+ def sample(
+ self,
+ withReplacement: Optional[bool],
+ fraction: float,
+ seed: Optional[int] = ...,
+ ) -> "DataFrame":
+ ...
+
+ def sample( # type: ignore[misc]
+ self,
+ withReplacement: Optional[Union[float, bool]] = None,
+ fraction: Optional[Union[int, float]] = None,
+ seed: Optional[int] = None
+ ) -> "DataFrame":
"""Returns a sampled subset of this :class:`DataFrame`.
.. versionadded:: 1.3.0
@@ -1105,7 +1188,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
if is_withReplacement_omitted_args:
if fraction is not None:
- seed = fraction
+ seed = cast(int, fraction)
fraction = withReplacement
withReplacement = None
@@ -1114,7 +1197,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.sample(*args)
return DataFrame(jdf, self.sql_ctx)
- def sampleBy(self, col, fractions, seed=None):
+ def sampleBy(
+ self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None
+ ) -> "DataFrame":
"""
Returns a stratified sample without replacement based on the
fraction given on each stratum.
@@ -1167,7 +1252,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
seed = seed if seed is not None else random.randint(0, sys.maxsize)
return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
- def randomSplit(self, weights, seed=None):
+ def randomSplit(self, weights: List[float], seed: Optional[int] = None) -> List["DataFrame"]:
"""Randomly splits this :class:`DataFrame` with the provided weights.
.. versionadded:: 1.4.0
@@ -1193,11 +1278,13 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
if w < 0.0:
raise ValueError("Weights must be positive. Found weight
value: %s" % w)
seed = seed if seed is not None else random.randint(0, sys.maxsize)
- rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), int(seed))
+ rdd_array = self._jdf.randomSplit(
+ _to_list(self.sql_ctx._sc, weights), int(seed) # type: ignore[attr-defined]
+ )
return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
@property
- def dtypes(self):
+ def dtypes(self) -> List[Tuple[str, str]]:
"""Returns all column names and their data types as a list.
.. versionadded:: 1.3.0
@@ -1210,7 +1297,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
@property
- def columns(self):
+ def columns(self) -> List[str]:
"""Returns all column names as a list.
.. versionadded:: 1.3.0
@@ -1222,7 +1309,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
return [f.name for f in self.schema.fields]
- def colRegex(self, colName):
+ def colRegex(self, colName: str) -> Column:
"""
Selects column based on the column name specified as a regex and returns it
as :class:`Column`.
@@ -1251,7 +1338,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jc = self._jdf.colRegex(colName)
return Column(jc)
- def alias(self, alias):
+ def alias(self, alias: str) -> "DataFrame":
"""Returns a new :class:`DataFrame` with an alias set.
.. versionadded:: 1.3.0
@@ -1274,7 +1361,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
assert isinstance(alias, str), "alias should be a string"
return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)
- def crossJoin(self, other):
+ def crossJoin(self, other: "DataFrame") -> "DataFrame":
"""Returns the cartesian product with another :class:`DataFrame`.
.. versionadded:: 2.1.0
@@ -1298,7 +1385,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.crossJoin(other._jdf)
return DataFrame(jdf, self.sql_ctx)
- def join(self, other, on=None, how=None):
+ def join(
+ self,
+ other: "DataFrame",
+ on: Optional[Union[str, List[str], Column, List[Column]]] = None,
+ how: Optional[str] = None
+ ) -> "DataFrame":
"""Joins with another :class:`DataFrame`, using the given join
expression.
.. versionadded:: 1.3.0
@@ -1342,14 +1434,14 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
if on is not None and not isinstance(on, list):
- on = [on]
+ on = [on] # type: ignore[assignment]
if on is not None:
if isinstance(on[0], str):
- on = self._jseq(on)
+ on = self._jseq(cast(List[str], on))
else:
assert isinstance(on[0], Column), "on should be Column or list of Column"
- on = reduce(lambda x, y: x.__and__(y), on)
+ on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on))
on = on._jc
if on is None and how is None:
@@ -1366,16 +1458,16 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
# TODO(SPARK-22947): Fix the DataFrame API.
def _joinAsOf(
self,
- other,
- leftAsOfColumn,
- rightAsOfColumn,
- on=None,
- how=None,
+ other: "DataFrame",
+ leftAsOfColumn: Union[str, Column],
+ rightAsOfColumn: Union[str, Column],
+ on: Optional[Union[str, List[str], Column, List[Column]]] = None,
+ how: Optional[str] = None,
*,
- tolerance=None,
- allowExactMatches=True,
- direction="backward",
- ):
+ tolerance: Optional[Column] = None,
+ allowExactMatches: bool = True,
+ direction: str = "backward",
+ ) -> "DataFrame":
"""
Perform an as-of join.
@@ -1448,20 +1540,20 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
if isinstance(leftAsOfColumn, str):
leftAsOfColumn = self[leftAsOfColumn]
- left_as_of_jcol = leftAsOfColumn._jc
+ left_as_of_jcol = cast(Column, leftAsOfColumn)._jc
if isinstance(rightAsOfColumn, str):
rightAsOfColumn = other[rightAsOfColumn]
- right_as_of_jcol = rightAsOfColumn._jc
+ right_as_of_jcol = cast(Column, rightAsOfColumn)._jc
if on is not None and not isinstance(on, list):
- on = [on]
+ on = [on] # type: ignore[assignment]
if on is not None:
if isinstance(on[0], str):
- on = self._jseq(on)
+ on = self._jseq(cast(List[str], on))
else:
assert isinstance(on[0], Column), "on should be Column or list of Column"
- on = reduce(lambda x, y: x.__and__(y), on)
+ on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on))
on = on._jc
if how is None:
@@ -1480,7 +1572,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
)
return DataFrame(jdf, self.sql_ctx)
- def sortWithinPartitions(self, *cols, **kwargs):
+ def sortWithinPartitions(
+ self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+ ) -> "DataFrame":
"""Returns a new :class:`DataFrame` with each partition sorted by the
specified column(s).
.. versionadded:: 1.6.0
@@ -1510,7 +1604,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
return DataFrame(jdf, self.sql_ctx)
- def sort(self, *cols, **kwargs):
+ def sort(
+ self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+ ) -> "DataFrame":
"""Returns a new :class:`DataFrame` sorted by the specified column(s).
.. versionadded:: 1.3.0
@@ -1548,15 +1644,19 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
orderBy = sort
- def _jseq(self, cols, converter=None):
+ def _jseq(
+ self,
+ cols: Sequence,
+ converter: Optional[Callable[..., Union["PrimitiveType", JavaObject]]] = None
+ ) -> JavaObject:
"""Return a JVM Seq of Columns from a list of Column or names"""
- return _to_seq(self.sql_ctx._sc, cols, converter)
+ return _to_seq(self.sql_ctx._sc, cols, converter) # type: ignore[attr-defined]
- def _jmap(self, jm):
+ def _jmap(self, jm: Dict) -> JavaObject:
"""Return a JVM Scala Map from a dict"""
- return _to_scala_map(self.sql_ctx._sc, jm)
+ return _to_scala_map(self.sql_ctx._sc, jm) # type: ignore[attr-defined]
- def _jcols(self, *cols):
+ def _jcols(self, *cols: "ColumnOrName") -> JavaObject:
"""Return a JVM Seq of Columns from a list of Column or column names
If `cols` has only one list in it, cols[0] will be used as the list.
@@ -1565,7 +1665,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
cols = cols[0]
return self._jseq(cols, _to_java_column)
- def _sort_cols(self, cols, kwargs):
+ def _sort_cols(
+ self, cols: Sequence[Union[str, Column, List[Union[str, Column]]]], kwargs: Dict[str, Any]
+ ) -> JavaObject:
""" Return a JVM Seq of Columns that describes the sort order
"""
if not cols:
@@ -1584,7 +1686,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
raise TypeError("ascending can only be boolean or list, but got
%s" % type(ascending))
return self._jseq(jcols)
- def describe(self, *cols):
+ def describe(self, *cols: Union[str, List[str]]) -> "DataFrame":
"""Computes basic statistics for numeric and string columns.
.. versionadded:: 1.3.1
@@ -1628,11 +1730,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
DataFrame.summary
"""
if len(cols) == 1 and isinstance(cols[0], list):
- cols = cols[0]
+ cols = cols[0] # type: ignore[assignment]
jdf = self._jdf.describe(self._jseq(cols))
return DataFrame(jdf, self.sql_ctx)
- def summary(self, *statistics):
+ def summary(self, *statistics: str) -> "DataFrame":
"""Computes specified statistics for numeric and string columns.
Available statistics are:
- count
- mean
@@ -1697,7 +1799,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.summary(self._jseq(statistics))
return DataFrame(jdf, self.sql_ctx)
- def head(self, n=None):
+ @overload
+ def head(self) -> Optional[Row]:
+ ...
+
+ @overload
+ def head(self, n: int) -> List[Row]:
+ ...
+
+ def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]:
"""Returns the first ``n`` rows.
.. versionadded:: 1.3.0
@@ -1729,7 +1839,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return rs[0] if rs else None
return self.take(n)
- def first(self):
+ def first(self) -> Optional[Row]:
"""Returns the first row as a :class:`Row`.
.. versionadded:: 1.3.0
@@ -1741,7 +1851,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
return self.head()
- def __getitem__(self, item):
+ @overload
+ def __getitem__(self, item: Union[int, str]) -> Column:
+ ...
+
+ @overload
+ def __getitem__(self, item: Union[Column, List, Tuple]) -> "DataFrame":
+ ...
+
+ def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Union[Column, "DataFrame"]:
"""Returns the column as a :class:`Column`.
.. versionadded:: 1.3.0
@@ -1770,7 +1888,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
else:
raise TypeError("unexpected item type: %s" % type(item))
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Column:
"""Returns the :class:`Column` denoted by ``name``.
.. versionadded:: 1.3.0
@@ -1786,7 +1904,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jc = self._jdf.apply(name)
return Column(jc)
- def select(self, *cols):
+ @overload
+ def select(self, *cols: "ColumnOrName") -> "DataFrame":
+ ...
+
+ @overload
+ def select(self, __cols: Union[List[Column], List[str]]) -> "DataFrame":
+ ...
+
+ def select(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc]
"""Projects a set of expressions and returns a new :class:`DataFrame`.
.. versionadded:: 1.3.0
@@ -1810,7 +1936,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.select(self._jcols(*cols))
return DataFrame(jdf, self.sql_ctx)
- def selectExpr(self, *expr):
+ @overload
+ def selectExpr(self, *expr: str) -> "DataFrame":
+ ...
+
+ @overload
+ def selectExpr(self, *expr: List[str]) -> "DataFrame":
+ ...
+
+ def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame":
"""Projects a set of SQL expressions and returns a new
:class:`DataFrame`.
This is a variant of :func:`select` that accepts SQL expressions.
@@ -1823,11 +1957,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
[Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)]
"""
if len(expr) == 1 and isinstance(expr[0], list):
- expr = expr[0]
+ expr = expr[0] # type: ignore[assignment]
jdf = self._jdf.selectExpr(self._jseq(expr))
return DataFrame(jdf, self.sql_ctx)
- def filter(self, condition):
+ def filter(self, condition: "ColumnOrName") -> "DataFrame":
"""Filters rows using the given condition.
:func:`where` is an alias for :func:`filter`.
@@ -1860,7 +1994,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
raise TypeError("condition should be string or Column")
return DataFrame(jdf, self.sql_ctx)
- def groupBy(self, *cols):
+ @overload
+ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":
+ ...
+
+ @overload
+ def groupBy(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+ ...
+
+ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc]
"""Groups the :class:`DataFrame` using the specified columns,
so we can run aggregation on them. See :class:`GroupedData`
for all the available aggregate functions.
@@ -1890,7 +2032,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self)
- def rollup(self, *cols):
+ @overload
+ def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
+ ...
+
+ @overload
+ def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+ ...
+
+ def rollup(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc]
"""
Create a multi-dimensional rollup for the current :class:`DataFrame` using
the specified columns, so we can run aggregation on them.
@@ -1914,7 +2064,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self)
- def cube(self, *cols):
+ @overload
+ def cube(self, *cols: "ColumnOrName") -> "GroupedData":
+ ...
+
+ @overload
+ def cube(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+ ...
+
+ def cube(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc]
"""
Create a multi-dimensional cube for the current :class:`DataFrame` using
the specified columns, so we can run aggregations on them.
@@ -1940,7 +2098,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
from pyspark.sql.group import GroupedData
return GroupedData(jgd, self)
- def agg(self, *exprs):
+ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
""" Aggregate on the entire :class:`DataFrame` without groups
(shorthand for ``df.groupBy().agg()``).
@@ -1954,10 +2112,10 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
>>> df.agg(F.min(df.age)).collect()
[Row(min(age)=2)]
"""
- return self.groupBy().agg(*exprs)
+ return self.groupBy().agg(*exprs) # type: ignore[arg-type]
@since(3.3)
- def observe(self, observation, *exprs):
+ def observe(self, observation: "Observation", *exprs: Column) -> "DataFrame":
"""Observe (named) metrics through an :class:`Observation` instance.
A user can retrieve the metrics by accessing `Observation.get`.
@@ -1996,7 +2154,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return observation._on(self, *exprs)
@since(2.0)
- def union(self, other):
+ def union(self, other: "DataFrame") -> "DataFrame":
""" Return a new :class:`DataFrame` containing union of rows in this
and another
:class:`DataFrame`.
@@ -2008,7 +2166,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return DataFrame(self._jdf.union(other._jdf), self.sql_ctx)
@since(1.3)
- def unionAll(self, other):
+ def unionAll(self, other: "DataFrame") -> "DataFrame":
""" Return a new :class:`DataFrame` containing union of rows in this
and another
:class:`DataFrame`.
@@ -2019,7 +2177,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
return self.union(other)
- def unionByName(self, other, allowMissingColumns=False):
+ def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> "DataFrame":
""" Returns a new :class:`DataFrame` containing union of rows in this
and another
:class:`DataFrame`.
@@ -2065,7 +2223,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return DataFrame(self._jdf.unionByName(other._jdf, allowMissingColumns), self.sql_ctx)
@since(1.3)
- def intersect(self, other):
+ def intersect(self, other: "DataFrame") -> "DataFrame":
""" Return a new :class:`DataFrame` containing rows only in
both this :class:`DataFrame` and another :class:`DataFrame`.
@@ -2073,7 +2231,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
- def intersectAll(self, other):
+ def intersectAll(self, other: "DataFrame") -> "DataFrame":
""" Return a new :class:`DataFrame` containing rows in both this
:class:`DataFrame`
and another :class:`DataFrame` while preserving duplicates.
@@ -2100,7 +2258,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx)
@since(1.3)
- def subtract(self, other):
+ def subtract(self, other: "DataFrame") -> "DataFrame":
""" Return a new :class:`DataFrame` containing rows in this
:class:`DataFrame`
but not in another :class:`DataFrame`.
@@ -2109,7 +2267,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
- def dropDuplicates(self, subset=None):
+ def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame":
"""Return a new :class:`DataFrame` with duplicate rows removed,
optionally only considering certain columns.
@@ -2155,7 +2313,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.dropDuplicates(self._jseq(subset))
return DataFrame(jdf, self.sql_ctx)
- def dropna(self, how='any', thresh=None, subset=None):
+ def dropna(
+ self,
+ how: str = 'any',
+ thresh: Optional[int] = None,
+ subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None
+ ) -> "DataFrame":
"""Returns a new :class:`DataFrame` omitting rows with null values.
:func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other.
@@ -2198,7 +2361,23 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx)
- def fillna(self, value, subset=None):
+ @overload
+ def fillna(
+ self,
+ value: "LiteralType",
+ subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ...,
+ ) -> "DataFrame":
+ ...
+
+ @overload
+ def fillna(self, value: Dict[str, "LiteralType"]) -> "DataFrame":
+ ...
+
+ def fillna(
+ self,
+ value: Union["LiteralType", Dict[str, "LiteralType"]],
+ subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None
+ ) -> "DataFrame":
"""Replace null values, alias for ``na.fill()``.
:func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other.
@@ -2269,7 +2448,49 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
- def replace(self, to_replace, value=_NoValue, subset=None):
+ @overload
+ def replace(
+ self,
+ to_replace: "LiteralType",
+ value: "OptionalPrimitiveType",
+ subset: Optional[List[str]] = ...,
+ ) -> "DataFrame":
+ ...
+
+ @overload
+ def replace(
+ self,
+ to_replace: List["LiteralType"],
+ value: List["OptionalPrimitiveType"],
+ subset: Optional[List[str]] = ...,
+ ) -> "DataFrame":
+ ...
+
+ @overload
+ def replace(
+ self,
+ to_replace: Dict["LiteralType", "OptionalPrimitiveType"],
+ subset: Optional[List[str]] = ...,
+ ) -> "DataFrame":
+ ...
+
+ @overload
+ def replace(
+ self,
+ to_replace: List["LiteralType"],
+ value: "OptionalPrimitiveType",
+ subset: Optional[List[str]] = ...,
+ ) -> "DataFrame":
+ ...
+
+ def replace( # type: ignore[misc]
+ self,
+ to_replace: Union[
+ "LiteralType", List["LiteralType"], Dict["LiteralType",
"OptionalPrimitiveType"]
+ ],
+ value: Optional[Union["OptionalPrimitiveType",
List["OptionalPrimitiveType"]]] = _NoValue,
+ subset: Optional[List[str]] = None
+ ) -> "DataFrame":
"""Returns a new :class:`DataFrame` replacing a value with another
value.
:func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other.
@@ -2348,7 +2569,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
raise TypeError("value argument is required when to_replace is
not a dictionary.")
# Helper functions
- def all_of(types):
+ def all_of(types: Union[Type, Tuple[Type, ...]]) -> Callable[[Iterable], bool]:
"""Given a type or tuple of types and a sequence of xs
check if each x is instance of type(s)
@@ -2357,7 +2578,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
>>> all_of(str)(["a", 1])
False
"""
- def all_of_(xs):
+ def all_of_(xs: Iterable) -> bool:
return all(isinstance(x, types) for x in xs)
return all_of_
@@ -2415,7 +2636,30 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return DataFrame(
self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
- def approxQuantile(self, col, probabilities, relativeError):
+ @overload
+ def approxQuantile(
+ self,
+ col: str,
+ probabilities: Union[List[float], Tuple[float]],
+ relativeError: float,
+ ) -> List[float]:
+ ...
+
+ @overload
+ def approxQuantile(
+ self,
+ col: Union[List[str], Tuple[str]],
+ probabilities: Union[List[float], Tuple[float]],
+ relativeError: float,
+ ) -> List[List[float]]:
+ ...
+
+ def approxQuantile(
+ self,
+ col: Union[str, List[str], Tuple[str]],
+ probabilities: Union[List[float], Tuple[float]],
+ relativeError: float
+ ) -> Union[List[float], List[List[float]]]:
"""
Calculates the approximate quantiles of numerical columns of a :class:`DataFrame`.
@@ -2474,7 +2718,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
if isinstance(col, tuple):
col = list(col)
elif isStr:
- col = [col]
+ col = [cast(str, col)]
for c in col:
if not isinstance(c, str):
@@ -2500,7 +2744,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jaq_list = [list(j) for j in jaq]
return jaq_list[0] if isStr else jaq_list
- def corr(self, col1, col2, method=None):
+ def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
"""
Calculates the correlation of two columns of a :class:`DataFrame` as a double value.
Currently only supports the Pearson Correlation Coefficient.
@@ -2528,7 +2772,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"coefficient is supported.")
return self._jdf.stat().corr(col1, col2, method)
- def cov(self, col1, col2):
+ def cov(self, col1: str, col2: str) -> float:
"""
Calculate the sample covariance for the given columns, specified by their names, as a
double value. :func:`DataFrame.cov` and :func:`DataFrameStatFunctions.cov` are aliases.
@@ -2548,7 +2792,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
raise TypeError("col2 should be a string.")
return self._jdf.stat().cov(col1, col2)
- def crosstab(self, col1, col2):
+ def crosstab(self, col1: str, col2: str) -> "DataFrame":
"""
Computes a pair-wise frequency table of the given columns. Also known as a contingency
table. The number of distinct values for each column should be less than 1e4. At most 1e6
@@ -2575,7 +2819,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
raise TypeError("col2 should be a string.")
return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
- def freqItems(self, cols, support=None):
+ def freqItems(
+ self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None
+ ) -> "DataFrame":
"""
Finding frequent items for columns, possibly with false positives. Using the
frequent element count algorithm described in
@@ -2607,7 +2853,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
support = 0.01
return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sql_ctx)
- def withColumn(self, colName, col):
+ def withColumn(self, colName: str, col: Column) -> "DataFrame":
"""
Returns a new :class:`DataFrame` by adding a column or replacing the
existing column that has the same name.
@@ -2641,7 +2887,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
raise TypeError("col should be Column")
return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx)
- def withColumnRenamed(self, existing, new):
+ def withColumnRenamed(self, existing: str, new: str) -> "DataFrame":
"""Returns a new :class:`DataFrame` by renaming an existing column.
This is a no-op if schema doesn't contain the given column name.
@@ -2661,7 +2907,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx)
- def withMetadata(self, columnName, metadata):
+ def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame":
Returns a new :class:`DataFrame` by updating an existing column with metadata.
.. versionadded:: 3.3.0
@@ -2681,12 +2927,20 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
if not isinstance(metadata, dict):
raise TypeError("metadata should be a dict")
- sc = SparkContext._active_spark_context
+ sc = SparkContext._active_spark_context # type: ignore[attr-defined]
jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson(
json.dumps(metadata))
return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sql_ctx)
- def drop(self, *cols):
+ @overload
+ def drop(self, cols: "ColumnOrName") -> "DataFrame":
+ ...
+
+ @overload
+ def drop(self, *cols: str) -> "DataFrame":
+ ...
+
+ def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc]
"""Returns a new :class:`DataFrame` that drops the specified column.
This is a no-op if schema doesn't contain the given column name(s).
@@ -2730,7 +2984,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return DataFrame(jdf, self.sql_ctx)
- def toDF(self, *cols):
+ def toDF(self, *cols: "ColumnOrName") -> "DataFrame":
Returns a new :class:`DataFrame` that with new specified column names
Parameters
@@ -2746,7 +3000,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.toDF(self._jseq(cols))
return DataFrame(jdf, self.sql_ctx)
- def transform(self, func):
+ def transform(self, func: Callable[["DataFrame"], "DataFrame"]) -> "DataFrame":
"""Returns a new :class:`DataFrame`. Concise syntax for chaining
custom transformations.
.. versionadded:: 3.0.0
@@ -2777,7 +3031,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"should have been DataFrame." %
type(result)
return result
- def sameSemantics(self, other):
+ def sameSemantics(self, other: "DataFrame") -> bool:
"""
Returns `True` when the logical query plans inside both :class:`DataFrame`\\s are equal and
therefore return same results.
@@ -2811,7 +3065,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
% type(other))
return self._jdf.sameSemantics(other._jdf)
- def semanticHash(self):
+ def semanticHash(self) -> int:
"""
Returns a hash code of the logical query plan against this :class:`DataFrame`.
@@ -2833,7 +3087,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
return self._jdf.semanticHash()
- def inputFiles(self):
+ def inputFiles(self) -> List[str]:
"""
Returns a best-effort snapshot of the files that compose this :class:`DataFrame`.
This method simply asks each constituent BaseRelation for its respective files and
@@ -2870,7 +3124,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
sinceversion=1.4,
doc=":func:`drop_duplicates` is an alias for :func:`dropDuplicates`.")
- def writeTo(self, table):
+ def writeTo(self, table: str) -> DataFrameWriterV2:
"""
Create a write configuration builder for v2 sources.
@@ -2889,7 +3143,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""
return DataFrameWriterV2(self, table)
- def to_pandas_on_spark(self, index_col=None):
+ def to_pandas_on_spark(
+ self, index_col: Optional[Union[str, List[str]]] = None
+ ) -> "PandasOnSparkDataFrame":
"""
Converts the existing DataFrame into a pandas-on-Spark DataFrame.
@@ -2935,17 +3191,20 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
c 3
"""
from pyspark.pandas.namespace import _get_index_map
- from pyspark.pandas.frame import DataFrame
+ from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
from pyspark.pandas.internal import InternalFrame
index_spark_columns, index_names = _get_index_map(self, index_col)
internal = InternalFrame(
- spark_frame=self, index_spark_columns=index_spark_columns, index_names=index_names
+ spark_frame=self, index_spark_columns=index_spark_columns,
+ index_names=index_names # type: ignore[arg-type]
)
- return DataFrame(internal)
+ return PandasOnSparkDataFrame(internal)
# Keep to_koalas for backward compatibility for now.
- def to_koalas(self, index_col=None):
+ def to_koalas(
+ self, index_col: Optional[Union[str, List[str]]] = None
+ ) -> "PandasOnSparkDataFrame":
warnings.warn(
"DataFrame.to_koalas is deprecated. Use
DataFrame.to_pandas_on_spark instead.",
FutureWarning,
@@ -2953,11 +3212,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return self.to_pandas_on_spark(index_col)
-def _to_scala_map(sc, jm):
+def _to_scala_map(sc: SparkContext, jm: Dict) -> JavaObject:
"""
Convert a dict into a JVM Map.
"""
- return sc._jvm.PythonUtils.toScalaMap(jm)
+ return sc._jvm.PythonUtils.toScalaMap(jm) # type: ignore[attr-defined]
class DataFrameNaFunctions(object):
@@ -2966,21 +3225,70 @@ class DataFrameNaFunctions(object):
.. versionadded:: 1.4
"""
- def __init__(self, df):
+ def __init__(self, df: DataFrame):
self.df = df
- def drop(self, how='any', thresh=None, subset=None):
+ def drop(
+ self, how: str = 'any', thresh: Optional[int] = None, subset: Optional[List[str]] = None
+ ) -> DataFrame:
return self.df.dropna(how=how, thresh=thresh, subset=subset)
drop.__doc__ = DataFrame.dropna.__doc__
- def fill(self, value, subset=None):
- return self.df.fillna(value=value, subset=subset)
+ @overload
+ def fill(
+ self, value: "LiteralType", subset: Optional[List[str]] = ...
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def fill(self, value: Dict[str, "LiteralType"]) -> DataFrame:
+ ...
+
+ def fill(
+ self,
+ value: Union["LiteralType", Dict[str, "LiteralType"]],
+ subset: Optional[List[str]] = None
+ ) -> DataFrame:
+ return self.df.fillna(value=value, subset=subset) # type: ignore[arg-type]
fill.__doc__ = DataFrame.fillna.__doc__
- def replace(self, to_replace, value=_NoValue, subset=None):
- return self.df.replace(to_replace, value, subset)
+ @overload
+ def replace(
+ self,
+ to_replace: List["LiteralType"],
+ value: List["OptionalPrimitiveType"],
+ subset: Optional[List[str]] = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def replace(
+ self,
+ to_replace: Dict["LiteralType", "OptionalPrimitiveType"],
+ subset: Optional[List[str]] = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def replace(
+ self,
+ to_replace: List["LiteralType"],
+ value: "OptionalPrimitiveType",
+ subset: Optional[List[str]] = ...,
+ ) -> DataFrame:
+ ...
+
+ def replace( # type: ignore[misc]
+ self,
+ to_replace: Union[
+ List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]
+ ],
+ value: Optional[Union["OptionalPrimitiveType",
List["OptionalPrimitiveType"]]] = _NoValue,
+ subset: Optional[List[str]] = None
+ ) -> DataFrame:
+ return self.df.replace(to_replace, value, subset) # type: ignore[arg-type]
replace.__doc__ = DataFrame.replace.__doc__
@@ -2991,41 +3299,66 @@ class DataFrameStatFunctions(object):
.. versionadded:: 1.4
"""
- def __init__(self, df):
+ def __init__(self, df: DataFrame):
self.df = df
- def approxQuantile(self, col, probabilities, relativeError):
+ @overload
+ def approxQuantile(
+ self,
+ col: str,
+ probabilities: Union[List[float], Tuple[float]],
+ relativeError: float,
+ ) -> List[float]:
+ ...
+
+ @overload
+ def approxQuantile(
+ self,
+ col: Union[List[str], Tuple[str]],
+ probabilities: Union[List[float], Tuple[float]],
+ relativeError: float,
+ ) -> List[List[float]]:
+ ...
+
+ def approxQuantile( # type: ignore[misc]
+ self,
+ col: Union[str, List[str], Tuple[str]],
+ probabilities: Union[List[float], Tuple[float]],
+ relativeError: float
+ ) -> Union[List[float], List[List[float]]]:
return self.df.approxQuantile(col, probabilities, relativeError)
approxQuantile.__doc__ = DataFrame.approxQuantile.__doc__
- def corr(self, col1, col2, method=None):
+ def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
return self.df.corr(col1, col2, method)
corr.__doc__ = DataFrame.corr.__doc__
- def cov(self, col1, col2):
+ def cov(self, col1: str, col2: str) -> float:
return self.df.cov(col1, col2)
cov.__doc__ = DataFrame.cov.__doc__
- def crosstab(self, col1, col2):
+ def crosstab(self, col1: str, col2: str) -> DataFrame:
return self.df.crosstab(col1, col2)
crosstab.__doc__ = DataFrame.crosstab.__doc__
- def freqItems(self, cols, support=None):
+ def freqItems(self, cols: List[str], support: Optional[float] = None) -> DataFrame:
return self.df.freqItems(cols, support)
freqItems.__doc__ = DataFrame.freqItems.__doc__
- def sampleBy(self, col, fractions, seed=None):
+ def sampleBy(
+ self, col: str, fractions: Dict[Any, float], seed: Optional[int] = None
+ ) -> DataFrame:
return self.df.sampleBy(col, fractions, seed)
sampleBy.__doc__ = DataFrame.sampleBy.__doc__
-def _test():
+def _test() -> None:
import doctest
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext, SparkSession
diff --git a/python/pyspark/sql/dataframe.pyi b/python/pyspark/sql/dataframe.pyi
deleted file mode 100644
index d903a79..0000000
--- a/python/pyspark/sql/dataframe.pyi
+++ /dev/null
@@ -1,351 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import (
- Any,
- Callable,
- Dict,
- Iterator,
- List,
- Optional,
- Tuple,
- Union,
-)
-
-from py4j.java_gateway import JavaObject # type: ignore[import]
-
-from pyspark.sql._typing import ColumnOrName, LiteralType, OptionalPrimitiveType
-from pyspark._typing import PrimitiveType
-from pyspark.sql.types import ( # noqa: F401
- StructType,
- StructField,
- StringType,
- IntegerType,
- Row,
-) # noqa: F401
-from pyspark.sql.context import SQLContext
-from pyspark.sql.group import GroupedData
-from pyspark.sql.observation import Observation
-from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
-from pyspark.sql.streaming import DataStreamWriter
-from pyspark.sql.column import Column
-from pyspark.rdd import RDD
-from pyspark.storagelevel import StorageLevel
-
-from pyspark.sql.pandas.conversion import PandasConversionMixin
-from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
-from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
-
-class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
- sql_ctx: SQLContext
- is_cached: bool
- def __init__(self, jdf: JavaObject, sql_ctx: SQLContext) -> None: ...
- @property
- def rdd(self) -> RDD[Row]: ...
- @property
- def na(self) -> DataFrameNaFunctions: ...
- @property
- def stat(self) -> DataFrameStatFunctions: ...
- def toJSON(self, use_unicode: bool = ...) -> RDD[str]: ...
- def registerTempTable(self, name: str) -> None: ...
- def createTempView(self, name: str) -> None: ...
- def createOrReplaceTempView(self, name: str) -> None: ...
- def createGlobalTempView(self, name: str) -> None: ...
- @property
- def write(self) -> DataFrameWriter: ...
- @property
- def writeStream(self) -> DataStreamWriter: ...
- @property
- def schema(self) -> StructType: ...
- def printSchema(self) -> None: ...
- def explain(
- self, extended: Optional[Union[bool, str]] = ..., mode: Optional[str] = ...
- ) -> None: ...
- def exceptAll(self, other: DataFrame) -> DataFrame: ...
- def isLocal(self) -> bool: ...
- @property
- def isStreaming(self) -> bool: ...
- def show(
- self, n: int = ..., truncate: Union[bool, int] = ..., vertical: bool = ...
- ) -> None: ...
- def checkpoint(self, eager: bool = ...) -> DataFrame: ...
- def localCheckpoint(self, eager: bool = ...) -> DataFrame: ...
- def withWatermark(
- self, eventTime: str, delayThreshold: str
- ) -> DataFrame: ...
- def hint(self, name: str, *parameters: Union[PrimitiveType, List[PrimitiveType]]) -> DataFrame: ...
- def count(self) -> int: ...
- def collect(self) -> List[Row]: ...
- def toLocalIterator(self, prefetchPartitions: bool = ...) -> Iterator[Row]: ...
- def limit(self, num: int) -> DataFrame: ...
- def take(self, num: int) -> List[Row]: ...
- def tail(self, num: int) -> List[Row]: ...
- def foreach(self, f: Callable[[Row], None]) -> None: ...
- def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None: ...
- def cache(self) -> DataFrame: ...
- def persist(self, storageLevel: StorageLevel = ...) -> DataFrame: ...
- @property
- def storageLevel(self) -> StorageLevel: ...
- def unpersist(self, blocking: bool = ...) -> DataFrame: ...
- def coalesce(self, numPartitions: int) -> DataFrame: ...
- @overload
- def repartition(self, numPartitions: int, *cols: ColumnOrName) -> DataFrame: ...
- @overload
- def repartition(self, *cols: ColumnOrName) -> DataFrame: ...
- @overload
- def repartitionByRange(
- self, numPartitions: int, *cols: ColumnOrName
- ) -> DataFrame: ...
- @overload
- def repartitionByRange(self, *cols: ColumnOrName) -> DataFrame: ...
- def distinct(self) -> DataFrame: ...
- @overload
- def sample(self, fraction: float, seed: Optional[int] = ...) -> DataFrame: ...
- @overload
- def sample(
- self,
- withReplacement: Optional[bool],
- fraction: float,
- seed: Optional[int] = ...,
- ) -> DataFrame: ...
- def sampleBy(
- self, col: ColumnOrName, fractions: Dict[Any, float], seed: Optional[int] = ...
- ) -> DataFrame: ...
- def randomSplit(
- self, weights: List[float], seed: Optional[int] = ...
- ) -> List[DataFrame]: ...
- @property
- def dtypes(self) -> List[Tuple[str, str]]: ...
- @property
- def columns(self) -> List[str]: ...
- def colRegex(self, colName: str) -> Column: ...
- def alias(self, alias: str) -> DataFrame: ...
- def crossJoin(self, other: DataFrame) -> DataFrame: ...
- def join(
- self,
- other: DataFrame,
- on: Optional[Union[str, List[str], Column, List[Column]]] = ...,
- how: Optional[str] = ...,
- ) -> DataFrame: ...
- def sortWithinPartitions(
- self,
- *cols: Union[str, Column, List[Union[str, Column]]],
- ascending: Union[bool, List[bool]] = ...
- ) -> DataFrame: ...
- def sort(
- self,
- *cols: Union[str, Column, List[Union[str, Column]]],
- ascending: Union[bool, List[bool]] = ...
- ) -> DataFrame: ...
- def orderBy(
- self,
- *cols: Union[str, Column, List[Union[str, Column]]],
- ascending: Union[bool, List[bool]] = ...
- ) -> DataFrame: ...
- def describe(self, *cols: Union[str, List[str]]) -> DataFrame: ...
- def summary(self, *statistics: str) -> DataFrame: ...
- @overload
- def head(self) -> Row: ...
- @overload
- def head(self, n: int) -> List[Row]: ...
- def first(self) -> Row: ...
- def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Column: ...
- def __getattr__(self, name: str) -> Column: ...
- @overload
- def select(self, *cols: ColumnOrName) -> DataFrame: ...
- @overload
- def select(self, __cols: Union[List[Column], List[str]]) -> DataFrame: ...
- @overload
- def selectExpr(self, *expr: str) -> DataFrame: ...
- @overload
- def selectExpr(self, *expr: List[str]) -> DataFrame: ...
- def filter(self, condition: ColumnOrName) -> DataFrame: ...
- @overload
- def groupBy(self, *cols: ColumnOrName) -> GroupedData: ...
- @overload
- def groupBy(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ...
- @overload
- def rollup(self, *cols: ColumnOrName) -> GroupedData: ...
- @overload
- def rollup(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ...
- @overload
- def cube(self, *cols: ColumnOrName) -> GroupedData: ...
- @overload
- def cube(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ...
- def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: ...
- def observe(self, observation: Observation, *exprs: Column) -> DataFrame: ...
- def union(self, other: DataFrame) -> DataFrame: ...
- def unionAll(self, other: DataFrame) -> DataFrame: ...
- def unionByName(
- self, other: DataFrame, allowMissingColumns: bool = ...
- ) -> DataFrame: ...
- def intersect(self, other: DataFrame) -> DataFrame: ...
- def intersectAll(self, other: DataFrame) -> DataFrame: ...
- def subtract(self, other: DataFrame) -> DataFrame: ...
- def dropDuplicates(self, subset: Optional[List[str]] = ...) -> DataFrame: ...
- def dropna(
- self,
- how: str = ...,
- thresh: Optional[int] = ...,
- subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ...,
- ) -> DataFrame: ...
- @overload
- def fillna(
- self,
- value: LiteralType,
- subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ...,
- ) -> DataFrame: ...
- @overload
- def fillna(self, value: Dict[str, LiteralType]) -> DataFrame: ...
- @overload
- def replace(
- self,
- to_replace: LiteralType,
- value: OptionalPrimitiveType,
- subset: Optional[List[str]] = ...,
- ) -> DataFrame: ...
- @overload
- def replace(
- self,
- to_replace: List[LiteralType],
- value: List[OptionalPrimitiveType],
- subset: Optional[List[str]] = ...,
- ) -> DataFrame: ...
- @overload
- def replace(
- self,
- to_replace: Dict[LiteralType, OptionalPrimitiveType],
- subset: Optional[List[str]] = ...,
- ) -> DataFrame: ...
- @overload
- def replace(
- self,
- to_replace: List[LiteralType],
- value: OptionalPrimitiveType,
- subset: Optional[List[str]] = ...,
- ) -> DataFrame: ...
- @overload
- def approxQuantile(
- self,
- col: str,
- probabilities: Union[List[float], Tuple[float]],
- relativeError: float,
- ) -> List[float]: ...
- @overload
- def approxQuantile(
- self,
- col: Union[List[str], Tuple[str]],
- probabilities: Union[List[float], Tuple[float]],
- relativeError: float,
- ) -> List[List[float]]: ...
- def corr(self, col1: str, col2: str, method: Optional[str] = ...) -> float: ...
- def cov(self, col1: str, col2: str) -> float: ...
- def crosstab(self, col1: str, col2: str) -> DataFrame: ...
- def freqItems(
- self, cols: Union[List[str], Tuple[str]], support: Optional[float] = ...
- ) -> DataFrame: ...
- def withColumn(self, colName: str, col: Column) -> DataFrame: ...
- def withColumnRenamed(self, existing: str, new: str) -> DataFrame: ...
- @overload
- def drop(self, cols: ColumnOrName) -> DataFrame: ...
- @overload
- def drop(self, *cols: str) -> DataFrame: ...
- def toDF(self, *cols: ColumnOrName) -> DataFrame: ...
- def transform(self, func: Callable[[DataFrame], DataFrame]) -> DataFrame: ...
- @overload
- def groupby(self, *cols: ColumnOrName) -> GroupedData: ...
- @overload
- def groupby(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ...
- def drop_duplicates(self, subset: Optional[List[str]] = ...) -> DataFrame: ...
- def where(self, condition: ColumnOrName) -> DataFrame: ...
- def sameSemantics(self, other: DataFrame) -> bool: ...
- def semanticHash(self) -> int: ...
- def inputFiles(self) -> List[str]: ...
- def writeTo(self, table: str) -> DataFrameWriterV2: ...
- def to_pandas_on_spark(self, index_col: Optional[Union[str, List[str]]] = None) -> PandasOnSparkDataFrame: ...
-
-class DataFrameNaFunctions:
- df: DataFrame
- def __init__(self, df: DataFrame) -> None: ...
- def drop(
- self,
- how: str = ...,
- thresh: Optional[int] = ...,
- subset: Optional[List[str]] = ...,
- ) -> DataFrame: ...
- @overload
- def fill(
- self, value: LiteralType, subset: Optional[List[str]] = ...
- ) -> DataFrame: ...
- @overload
- def fill(self, value: Dict[str, LiteralType]) -> DataFrame: ...
- @overload
- def replace(
- self,
- to_replace: LiteralType,
- value: OptionalPrimitiveType,
- subset: Optional[List[str]] = ...,
- ) -> DataFrame: ...
- @overload
- def replace(
- self,
- to_replace: List[LiteralType],
- value: List[OptionalPrimitiveType],
- subset: Optional[List[str]] = ...,
- ) -> DataFrame: ...
- @overload
- def replace(
- self,
- to_replace: Dict[LiteralType, OptionalPrimitiveType],
- subset: Optional[List[str]] = ...,
- ) -> DataFrame: ...
- @overload
- def replace(
- self,
- to_replace: List[LiteralType],
- value: OptionalPrimitiveType,
- subset: Optional[List[str]] = ...,
- ) -> DataFrame: ...
-
-class DataFrameStatFunctions:
- df: DataFrame
- def __init__(self, df: DataFrame) -> None: ...
- @overload
- def approxQuantile(
- self,
- col: str,
- probabilities: Union[List[float], Tuple[float]],
- relativeError: float,
- ) -> List[float]: ...
- @overload
- def approxQuantile(
- self,
- col: Union[List[str], Tuple[str]],
- probabilities: Union[List[float], Tuple[float]],
- relativeError: float,
- ) -> List[List[float]]: ...
- def corr(self, col1: str, col2: str, method: Optional[str] = ...) -> float: ...
- def cov(self, col1: str, col2: str) -> float: ...
- def crosstab(self, col1: str, col2: str) -> DataFrame: ...
- def freqItems(
- self, cols: List[str], support: Optional[float] = ...
- ) -> DataFrame: ...
- def sampleBy(
- self, col: str, fractions: Dict[Any, float], seed: Optional[int] = ...
- ) -> DataFrame: ...
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index 3e8a0d1..48d8176 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -99,7 +99,7 @@ class Observation:
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
assert self._jo is None, "an Observation can be used with a DataFrame only once"
- self._jvm = df._sc._jvm # type: ignore[assignment]
+ self._jvm = df._sc._jvm # type: ignore[assignment, attr-defined]
cls = self._jvm.org.apache.spark.sql.Observation # type: ignore[attr-defined]
self._jo = cls(self._name) if self._name is not None else cls()
observed_df = self._jo.on( # type: ignore[attr-defined]
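
For context on the Observation hunk above, a brief usage sketch (not part of
the patch), assuming a local session; the metric names are illustrative:

    from pyspark.sql import SparkSession, functions as F
    from pyspark.sql.observation import Observation

    spark = SparkSession.builder.getOrCreate()
    obs = Observation("stats")
    observed = spark.range(10).observe(
        obs, F.count("id").alias("rows"), F.max("id").alias("max_id")
    )
    observed.collect()  # metrics are filled in only after an action runs
    print(obs.get)      # e.g. {'rows': 10, 'max_id': 9}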