wumeibanfa commented on code in PR #52967:
URL: https://github.com/apache/doris/pull/52967#discussion_r2194976267
##########
be/src/util/frame_of_reference_coding.cpp:
##########
@@ -421,28 +421,147 @@ bool ForDecoder<T>::init() {
}
// todo(kks): improve this method by SIMD instructions
+
+template <typename T>
+template <typename U>
+void ForDecoder<T>::bit_unpack_optimize(const uint8_t* input, uint8_t in_num,
int bit_width,
+ T* output) {
+ U s = 0;
+ int valid_bit = 0; // How many valid bits
+ int need_bit = 0; // still need
+ T output_mask = ((static_cast<T>(1)) << bit_width) - 1;
+ int u_size = sizeof(U); // Size of U
+ size_t input_size = (in_num * bit_width + 7) >> 3; // input's size
+ int full_batch_size =
+ (input_size / u_size) * u_size; // Adjust input_size to a
multiple of u_size
+ int tail_count = input_size & (u_size - 1); // The remainder of input_size
modulo u_size.
+ // The number of bits in input to adjust to multiples of 8 and thus more
+ int more_bit = (input_size << 3) - (in_num * bit_width);
+
+ for (int i = 0; i < full_batch_size; i += u_size) {
+ s |= static_cast<U>(input[i]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 1]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 2]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 3]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 4]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 5]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 6]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 7]);
+
+ if (u_size == 16) {
+ s <<= 8;
+ s |= static_cast<U>(input[i + 8]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 9]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 10]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 11]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 12]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 13]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 14]);
+ s <<= 8;
+ s |= static_cast<U>(input[i + 15]);
+ }
+
+ // Determine what the valid bits are based on u_size
+ valid_bit = u_size * 8;
+
+ // If input_size is exactly a multiple of 8, then need to remove the
last more_bit in the last loop.
+ if (tail_count == 0 && i == full_batch_size - u_size) {
+ valid_bit -= more_bit;
+ s >>= more_bit;
+ }
+
+ if (need_bit) {
+ // The last time we take away the high bit_width - need_bit,
+ // we need to make up the rest of the need_bit from the width.
+ // Use valid_bit - need_bit to compute high need_bit bits of s
+ // perform an AND operation to ensure that only need_bit bits are
valid
+ *output |= ((s >> (valid_bit - need_bit)) & ((static_cast<U>(1) <<
need_bit) - 1));
+ output++;
+ valid_bit -= need_bit;
+ }
+
+ int num = valid_bit / bit_width; // How many outputs can be
processed at a time
Review Comment:
bit_width is not necessarily a power of two (2^i), so `/` and `%` must be used here rather than shift/mask tricks.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]