Jackie-Jiang commented on code in PR #14958: URL: https://github.com/apache/pinot/pull/14958#discussion_r1936746829
########## pinot-segment-local/src/main/java/org/apache/pinot/segment/local/recordtransformer/SpecialValueTransformer.java: ########## @@ -18,112 +18,119 @@ */ package org.apache.pinot.segment.local.recordtransformer; -import com.google.common.annotations.VisibleForTesting; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Set; +import javax.annotation.Nullable; import org.apache.pinot.spi.data.FieldSpec; -import org.apache.pinot.spi.data.FieldSpec.DataType; import org.apache.pinot.spi.data.Schema; import org.apache.pinot.spi.data.readers.GenericRow; import org.apache.pinot.spi.recordtransformer.RecordTransformer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * The {@code SpecialValueTransformer} class will transform special values according to the following rules: * <ul> - * <li>Negative zero (-0.0) should be converted to 0.0</li> - * <li>NaN should be converted to default null</li> + * <li> + * For FLOAT and DOUBLE: + * <ul> + * <li>Negative zero (-0.0) should be converted to 0.0</li> + * <li>NaN should be converted to default null</li> + * </ul> + * </li> + * <li> + * For BIG_DECIMAL: + * <ul> + * <li>Remove trailing zeros</li> + * </ul> + * </li> * </ul> * <p>NOTE: should put this after the {@link DataTypeTransformer} so that we already have the values complying * with the schema before handling special values and before {@link NullValueTransformer} so that it transforms * all the null values properly. 
*/ public class SpecialValueTransformer implements RecordTransformer { + private final static int NEGATIVE_ZERO_FLOAT_BITS = Float.floatToRawIntBits(-0.0f); + private final static long NEGATIVE_ZERO_DOUBLE_BITS = Double.doubleToLongBits(-0.0d); - private static final Logger LOGGER = LoggerFactory.getLogger(SpecialValueTransformer.class); - private final HashSet<String> _specialValuesKeySet = new HashSet<>(); - private int _negativeZeroConversionCount = 0; - private int _nanConversionCount = 0; + private final Set<String> _columnsToCheck = new HashSet<>(); public SpecialValueTransformer(Schema schema) { for (FieldSpec fieldSpec : schema.getAllFieldSpecs()) { - if (!fieldSpec.isVirtualColumn() && (fieldSpec.getDataType() == DataType.FLOAT - || fieldSpec.getDataType() == DataType.DOUBLE)) { - _specialValuesKeySet.add(fieldSpec.getName()); + if (!fieldSpec.isVirtualColumn()) { + FieldSpec.DataType dataType = fieldSpec.getDataType(); + if (dataType == FieldSpec.DataType.FLOAT || dataType == FieldSpec.DataType.DOUBLE + || dataType == FieldSpec.DataType.BIG_DECIMAL) { + _columnsToCheck.add(fieldSpec.getName()); + } } } } - private Object transformNegativeZero(Object value) { - if ((value instanceof Float) && (Float.floatToRawIntBits((float) value) == Float.floatToRawIntBits(-0.0f))) { - value = 0.0f; - _negativeZeroConversionCount++; - } else if ((value instanceof Double) && (Double.doubleToLongBits((double) value) == Double.doubleToLongBits( - -0.0d))) { - value = 0.0d; - _negativeZeroConversionCount++; - } - return value; - } - - private Object transformNaN(Object value) { - if ((value instanceof Float) && ((Float) value).isNaN()) { - value = null; - _nanConversionCount++; - } else if ((value instanceof Double) && ((Double) value).isNaN()) { - _nanConversionCount++; - value = null; - } - return value; - } - @Override public boolean isNoOp() { - return _specialValuesKeySet.isEmpty(); + return _columnsToCheck.isEmpty(); } @Override public GenericRow transform(GenericRow 
record) { - for (String element : _specialValuesKeySet) { - Object value = record.getValue(element); + for (String column : _columnsToCheck) { + Object value = record.getValue(column); if (value instanceof Object[]) { // Multi-valued column. Object[] values = (Object[]) value; - int numValues = values.length; - List<Object> negativeZeroNanSanitizedValues = new ArrayList<>(numValues); - for (Object o : values) { - Object zeroTransformedValue = transformNegativeZero(o); - Object nanTransformedValue = transformNaN(zeroTransformedValue); - if (nanTransformedValue != null) { - negativeZeroNanSanitizedValues.add(nanTransformedValue); + List<Object> transformedValues = new ArrayList<>(values.length); + boolean transformed = false; + for (Object v : values) { + Object transformedValue = transformValue(v); + if (transformedValue != v) { + transformed = true; + } + if (transformedValue != null) { + transformedValues.add(transformedValue); + } + if (transformed) { + record.putValue(column, !transformedValues.isEmpty() ? transformedValues.toArray() : null); } } - record.putValue(element, negativeZeroNanSanitizedValues.toArray()); - } else { + } else if (value != null) { // Single-valued column. 
- Object zeroTransformedValue = transformNegativeZero(value); - Object nanTransformedValue = transformNaN(zeroTransformedValue); - if (nanTransformedValue != value) { - record.putValue(element, nanTransformedValue); + Object transformedValue = transformValue(value); + if (transformedValue != value) { + record.putValue(column, transformedValue); } } } - if (_negativeZeroConversionCount > 0 || _nanConversionCount > 0) { - LOGGER.debug("Converted {} -0.0s to 0.0 and {} NaNs to null", _negativeZeroConversionCount, _nanConversionCount); - } return record; } - @VisibleForTesting - int getNegativeZeroConversionCount() { - return _negativeZeroConversionCount; - } - - @VisibleForTesting - int getNanConversionCount() { - return _nanConversionCount; + @Nullable + private Object transformValue(Object value) { + if (value instanceof Float) { + Float floatValue = (Float) value; + if (floatValue.isNaN()) { + return null; + } + if (Float.floatToRawIntBits(floatValue) == NEGATIVE_ZERO_FLOAT_BITS) { + return 0.0f; + } + } else if (value instanceof Double) { + Double doubleValue = (Double) value; + if (doubleValue.isNaN()) { + return null; + } + if (Double.doubleToRawLongBits(doubleValue) == NEGATIVE_ZERO_DOUBLE_BITS) { + return 0.0d; + } + } else if (value instanceof BigDecimal) { + BigDecimal bigDecimalValue = (BigDecimal) value; + BigDecimal stripped = bigDecimalValue.stripTrailingZeros(); + if (!stripped.equals(bigDecimalValue)) { + return stripped; + } Review Comment: I don't fully follow, can you elaborate? We need to strip trailing zeros because `10.100` and `10.1` can also mess up the index. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org