Repository: spark Updated Branches: refs/heads/branch-1.2 17b7cc733 -> 576fc54e5
[SPARK-6055] [PySpark] fix incorrect DataType.__eq__ (for 1.2) The __eq__ of DataType is not correct, and the class cache is not used correctly (a created class cannot be found again by its dataType), so lots of classes get created (saved in _cached_cls) and are never released. Also, all equal DataTypes have the same hash code, so many objects in a dict share the same hash code, resulting in heavy hash collisions that make accessing this dict very slow (depending on the CPython implementation). This PR also improves the performance of inferSchema (avoiding unnecessary object conversion). Author: Davies Liu <[email protected]> Closes #4809 from davies/leak2 and squashes the following commits: 65c222f [Davies Liu] Update sql.py 9b4dadc [Davies Liu] fix __eq__ of singleton b576107 [Davies Liu] fix tests 6c2909a [Davies Liu] fix incorrect DataType.__eq__ Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/576fc54e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/576fc54e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/576fc54e Branch: refs/heads/branch-1.2 Commit: 576fc54e5c154fc28af1a732a6bea452d0a5cabb Parents: 17b7cc7 Author: Davies Liu <[email protected]> Authored: Fri Feb 27 20:04:16 2015 -0800 Committer: Josh Rosen <[email protected]> Committed: Fri Feb 27 20:04:16 2015 -0800 ---------------------------------------------------------------------- python/pyspark/sql.py | 67 ++++++++++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 23 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/576fc54e/python/pyspark/sql.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index aa5af1b..4410925 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -36,6 +36,7 @@ import keyword import warnings import json import re +import weakref from 
array import array from operator import itemgetter from itertools import imap @@ -68,8 +69,7 @@ class DataType(object): return hash(str(self)) def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.__dict__ == other.__dict__) + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other): return not self.__eq__(other) @@ -105,10 +105,6 @@ class PrimitiveType(DataType): __metaclass__ = PrimitiveTypeSingleton - def __eq__(self, other): - # because they should be the same object - return self is other - class NullType(PrimitiveType): @@ -251,9 +247,9 @@ class ArrayType(DataType): :param elementType: the data type of elements. :param containsNull: indicates whether the list contains None values. - >>> ArrayType(StringType) == ArrayType(StringType, True) + >>> ArrayType(StringType()) == ArrayType(StringType(), True) True - >>> ArrayType(StringType, False) == ArrayType(StringType) + >>> ArrayType(StringType(), False) == ArrayType(StringType()) False """ self.elementType = elementType @@ -298,11 +294,11 @@ class MapType(DataType): :param valueContainsNull: indicates whether values contains null values. - >>> (MapType(StringType, IntegerType) - ... == MapType(StringType, IntegerType, True)) + >>> (MapType(StringType(), IntegerType()) + ... == MapType(StringType(), IntegerType(), True)) True - >>> (MapType(StringType, IntegerType, False) - ... == MapType(StringType, FloatType)) + >>> (MapType(StringType(), IntegerType(), False) + ... == MapType(StringType(), FloatType())) False """ self.keyType = keyType @@ -351,11 +347,11 @@ class StructField(DataType): to simple type that can be serialized to JSON automatically - >>> (StructField("f1", StringType, True) - ... == StructField("f1", StringType, True)) + >>> (StructField("f1", StringType(), True) + ... == StructField("f1", StringType(), True)) True - >>> (StructField("f1", StringType, True) - ... 
== StructField("f2", StringType, True)) + >>> (StructField("f1", StringType(), True) + ... == StructField("f2", StringType(), True)) False """ self.name = name @@ -393,13 +389,13 @@ class StructType(DataType): def __init__(self, fields): """Creates a StructType - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True)]) + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) >>> struct1 == struct2 True - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True), - ... [StructField("f2", IntegerType, False)]]) + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True), + ... StructField("f2", IntegerType(), False)]) >>> struct1 == struct2 False """ @@ -499,6 +495,10 @@ _all_complex_types = dict((v.typeName(), v) def _parse_datatype_json_string(json_string): """Parses the given data type JSON string. + + >>> import pickle + >>> LongType() == pickle.loads(pickle.dumps(LongType())) + True >>> def check_datatype(datatype): ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) ... 
python_datatype = _parse_datatype_json_string(scala_datatype.json()) @@ -781,8 +781,25 @@ def _merge_type(a, b): return a +def _need_converter(dataType): + if isinstance(dataType, StructType): + return True + elif isinstance(dataType, ArrayType): + return _need_converter(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_converter(dataType.keyType) or _need_converter(dataType.valueType) + elif isinstance(dataType, NullType): + return True + else: + return False + + def _create_converter(dataType): """Create an converter to drop the names of fields in obj """ + + if not _need_converter(dataType): + return lambda x: x + if isinstance(dataType, ArrayType): conv = _create_converter(dataType.elementType) return lambda row: map(conv, row) @@ -800,6 +817,7 @@ def _create_converter(dataType): # dataType must be StructType names = [f.name for f in dataType.fields] converters = [_create_converter(f.dataType) for f in dataType.fields] + convert_fields = any(_need_converter(f.dataType) for f in dataType.fields) def convert_struct(obj): if obj is None: @@ -822,7 +840,10 @@ def _create_converter(dataType): else: raise ValueError("Unexpected obj: %s" % obj) - return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) + if convert_fields: + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) + else: + return tuple([d.get(name) for name in names]) return convert_struct @@ -1039,7 +1060,7 @@ def _verify_type(obj, dataType): _verify_type(v, f.dataType) -_cached_cls = {} +_cached_cls = weakref.WeakValueDictionary() def _restore_object(dataType, obj): --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
