Fokko commented on code in PR #117: URL: https://github.com/apache/iceberg-cpp/pull/117#discussion_r2161756094
########## src/iceberg/expression/literal.cc: ########## @@ -0,0 +1,358 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "iceberg/expression/literal.h" + +#include <cmath> +#include <concepts> +#include <sstream> + +#include "iceberg/exception.h" + +namespace iceberg { + +/// \brief LiteralCaster handles type casting operations for Literal. +/// This is an internal implementation class. +class LiteralCaster { + public: + /// Cast a Literal to the target type. + static Result<Literal> CastTo(const Literal& literal, + const std::shared_ptr<PrimitiveType>& target_type); + + /// Create a literal representing a value below the minimum for the given type. + static Literal BelowMinLiteral(std::shared_ptr<PrimitiveType> type); + + /// Create a literal representing a value above the maximum for the given type. + static Literal AboveMaxLiteral(std::shared_ptr<PrimitiveType> type); + + private: + /// Cast from Int type to target type. + static Result<Literal> CastFromInt(const Literal& literal, + const std::shared_ptr<PrimitiveType>& target_type); + + /// Cast from Long type to target type. + static Result<Literal> CastFromLong(const Literal& literal, + const std::shared_ptr<PrimitiveType>& target_type); + + /// Cast from Float type to target type. + static Result<Literal> CastFromFloat(const Literal& literal, + const std::shared_ptr<PrimitiveType>& target_type); +}; + +Literal LiteralCaster::BelowMinLiteral(std::shared_ptr<PrimitiveType> type) { + return Literal(Literal::BelowMin{}, std::move(type)); +} + +Literal LiteralCaster::AboveMaxLiteral(std::shared_ptr<PrimitiveType> type) { + return Literal(Literal::AboveMax{}, std::move(type)); +} + +Result<Literal> LiteralCaster::CastFromInt( + const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) { + auto int_val = std::get<int32_t>(literal.value_); + auto target_type_id = target_type->type_id(); + + switch (target_type_id) { + case TypeId::kLong: + return Literal::Long(static_cast<int64_t>(int_val)); + case TypeId::kFloat: + return Literal::Float(static_cast<float>(int_val)); + case TypeId::kDouble: + return Literal::Double(static_cast<double>(int_val)); + default: + return NotSupported("Cast from Int to {} is not implemented", + static_cast<int>(target_type_id)); + } +} + +Result<Literal> LiteralCaster::CastFromLong( + const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) { + auto long_val = std::get<int64_t>(literal.value_); + auto target_type_id = target_type->type_id(); + + switch (target_type_id) { + case TypeId::kInt: { + // Check for overflow + if (long_val >= std::numeric_limits<int32_t>::max()) { + return AboveMaxLiteral(target_type); + } + if (long_val <= std::numeric_limits<int32_t>::min()) { + return BelowMinLiteral(target_type); + } + return Literal::Int(static_cast<int32_t>(long_val)); + } + case TypeId::kFloat: + return Literal::Float(static_cast<float>(long_val)); + case TypeId::kDouble: + return Literal::Double(static_cast<double>(long_val)); + default: + return NotSupported("Cast from Long to {} is not supported", + static_cast<int>(target_type_id)); + } +} + +Result<Literal> LiteralCaster::CastFromFloat( + const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) { + auto float_val = std::get<float>(literal.value_); + auto target_type_id = target_type->type_id(); + + switch (target_type_id) { + case TypeId::kDouble: + return Literal::Double(static_cast<double>(float_val)); + default: + return NotSupported("Cast from Float to {} is not supported", + static_cast<int>(target_type_id)); + } +} + +// Constructor +Literal::Literal(Value value, std::shared_ptr<PrimitiveType> type) + : value_(std::move(value)), type_(std::move(type)) {} + +// Factory methods +Literal Literal::Boolean(bool value) { + return {Value{value}, std::make_shared<BooleanType>()}; +} + +Literal Literal::Int(int32_t value) { + return {Value{value}, std::make_shared<IntType>()}; +} + +Literal Literal::Long(int64_t value) { + return {Value{value}, std::make_shared<LongType>()}; +} + +Literal Literal::Float(float value) { + return {Value{value}, std::make_shared<FloatType>()}; +} + +Literal Literal::Double(double value) { + return {Value{value}, std::make_shared<DoubleType>()}; +} + +Literal Literal::String(std::string value) { + return {Value{std::move(value)}, std::make_shared<StringType>()}; +} + +Literal Literal::Binary(std::vector<uint8_t> value) { + return {Value{std::move(value)}, std::make_shared<BinaryType>()}; +} + +Result<Literal> Literal::Deserialize(std::span<const uint8_t> data, + std::shared_ptr<PrimitiveType> type) { + return NotImplemented("Deserialization of Literal is not implemented yet"); +} + +Result<std::vector<uint8_t>> Literal::Serialize() const { + return NotImplemented("Serialization of Literal is not implemented yet"); +} + +// Getters + +const std::shared_ptr<PrimitiveType>& Literal::type() const { return type_; } + +// Cast method +Result<Literal> Literal::CastTo(const std::shared_ptr<PrimitiveType>& target_type) const { + return LiteralCaster::CastTo(*this, target_type); +} + +// Template function for floating point comparison following Iceberg rules: +// -NaN < NaN, but all NaN values (qNaN, sNaN) are treated as equivalent within their sign +template <std::floating_point T> +std::partial_ordering iceberg_float_compare(T lhs, T rhs) { + bool lhs_is_nan = std::isnan(lhs); + bool rhs_is_nan = std::isnan(rhs); + + // If both are NaN, check their signs + if (lhs_is_nan && rhs_is_nan) { + bool lhs_is_negative = std::signbit(lhs); + bool rhs_is_negative = std::signbit(rhs); + + if (lhs_is_negative == rhs_is_negative) { + // Same sign NaN values are equivalent (no qNaN vs sNaN distinction) + return std::partial_ordering::equivalent; + } + // -NaN < NaN + return lhs_is_negative ? std::partial_ordering::less : std::partial_ordering::greater; + } + + // For non-NaN values, use standard strong ordering + return std::strong_order(lhs, rhs); +} + +// Three-way comparison operator +std::partial_ordering Literal::operator<=>(const Literal& other) const { + // If types are different, comparison is unordered + if (type_->type_id() != other.type_->type_id()) { + return std::partial_ordering::unordered; + } + + // If either value is AboveMax or BelowMin, comparison is unordered + if (IsAboveMax() || IsBelowMin() || other.IsAboveMax() || other.IsBelowMin()) { + return std::partial_ordering::unordered; + } + + // Same type comparison for normal values + switch (type_->type_id()) { + case TypeId::kBoolean: { + auto this_val = std::get<bool>(value_); + auto other_val = std::get<bool>(other.value_); + if (this_val == other_val) return std::partial_ordering::equivalent; + return this_val ? std::partial_ordering::greater : std::partial_ordering::less; + } + + case TypeId::kInt: { + auto this_val = std::get<int32_t>(value_); + auto other_val = std::get<int32_t>(other.value_); + return this_val <=> other_val; + } + + case TypeId::kLong: { + auto this_val = std::get<int64_t>(value_); + auto other_val = std::get<int64_t>(other.value_); + return this_val <=> other_val; + } + + case TypeId::kFloat: { + auto this_val = std::get<float>(value_); + auto other_val = std::get<float>(other.value_); + // Use strong_ordering for floating point as spec requests + return iceberg_float_compare(this_val, other_val); + } + + case TypeId::kDouble: { + auto this_val = std::get<double>(value_); + auto other_val = std::get<double>(other.value_); + // Use strong_ordering for floating point as spec requests + return iceberg_float_compare(this_val, other_val); + } + + case TypeId::kString: { + auto& this_val = std::get<std::string>(value_); + auto& other_val = std::get<std::string>(other.value_); + return this_val <=> other_val; + } + + case TypeId::kBinary: { + auto& this_val = std::get<std::vector<uint8_t>>(value_); + auto& other_val = std::get<std::vector<uint8_t>>(other.value_); + return this_val <=> other_val; + } + + default: + // For unsupported types, return unordered + return std::partial_ordering::unordered; + } +} + +std::string Literal::ToString() const { + if (std::holds_alternative<BelowMin>(value_)) { + return "BelowMin"; + } + if (std::holds_alternative<AboveMax>(value_)) { + return "AboveMax"; + } + + switch (type_->type_id()) { + case TypeId::kBoolean: { + return std::get<bool>(value_) ? "true" : "false"; + } + case TypeId::kInt: { + return std::to_string(std::get<int32_t>(value_)); + } + case TypeId::kLong: { + return std::to_string(std::get<int64_t>(value_)); + } + case TypeId::kFloat: { + return std::to_string(std::get<float>(value_)); + } + case TypeId::kDouble: { + return std::to_string(std::get<double>(value_)); + } + case TypeId::kString: { + return std::get<std::string>(value_); + } + case TypeId::kBinary: { + const auto& binary_data = std::get<std::vector<uint8_t>>(value_); + std::string result; + result.reserve(binary_data.size() * 2); // 2 chars per byte + for (const auto& byte : binary_data) { + result += std::format("{:02X}", byte); + } + return result; + } + case TypeId::kDecimal: + case TypeId::kUuid: + case TypeId::kFixed: + case TypeId::kDate: + case TypeId::kTime: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: { + throw IcebergError("Not implemented: ToString for " + type_->ToString()); Review Comment: Maybe something similar to Java, where you get the Object type with the pointer? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org For additional commands, e-mail: issues-h...@iceberg.apache.org