lidavidm commented on code in PR #49549:
URL: https://github.com/apache/arrow/pull/49549#discussion_r3015044841


##########
cpp/src/arrow/flight/serialization_internal.cc:
##########
@@ -612,6 +618,232 @@ Status ToProto(const CloseSessionResult& result, 
pb::CloseSessionResult* pb_resu
   return Status::OK();
 }
 
+namespace {
+using google::protobuf::internal::WireFormatLite;
+using google::protobuf::io::ArrayOutputStream;
+using google::protobuf::io::CodedInputStream;
+using google::protobuf::io::CodedOutputStream;
+static constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();  // int32_t maximum, held as int64_t for size checks
+const uint8_t kPaddingBytes[8] = {0, 0, 0, 0, 0, 0, 0, 0};  // eight zero bytes; backing data for body padding buffers
+
+// Add the Protobuf header bytes needed for `ipc_msg` to `*header_size` (accumulated, not reset) and set `*metadata_size`.
+arrow::Status IpcMessageHeaderSize(const arrow::ipc::IpcPayload& ipc_msg, bool 
has_body,
+                                   size_t* header_size, int32_t* 
metadata_size) {
+  DCHECK_LE(ipc_msg.metadata->size(), kInt32Max);  // only asserted; an oversized metadata is not reported as a Status
+  *metadata_size = static_cast<int32_t>(ipc_msg.metadata->size());
+
+  // 1 byte for metadata tag, followed by the length-delimited metadata itself
+  *header_size += 1 + WireFormatLite::LengthDelimitedSize(*metadata_size);
+
+  // 2 bytes for body tag
+  if (has_body) {
+    // We write the body tag and length in the header but not the actual body data, so body_length is subtracted back out
+    *header_size += 2 + 
WireFormatLite::LengthDelimitedSize(ipc_msg.body_length) -
+                    ipc_msg.body_length;
+  }
+
+  return arrow::Status::OK();  // no failure path is visible in this function
+}
+// Read one length-prefixed field from `input` and hand back its payload in `*out` as a slice of `source_data`.
+bool ReadBytesZeroCopy(const std::shared_ptr<Buffer>& source_data,
+                       CodedInputStream* input, std::shared_ptr<Buffer>* out) {
+  uint32_t length;
+  if (!input->ReadVarint32(&length)) {
+    return false;  // the varint length prefix could not be read
+  }
+  auto buf =  // slice offset is the stream position; assumes `input` reads over `source_data` from its start -- TODO confirm
+      SliceBuffer(source_data, input->CurrentPosition(), 
static_cast<int64_t>(length));
+  *out = buf;  // NOTE(review): assigned before Skip() is checked, so *out is overwritten even when false is returned
+  return input->Skip(static_cast<int>(length));  // advance past the payload; Skip()'s result is the return value
+}
+
+}  // namespace
+// Serialize `msg` into a buffer list: one newly allocated Protobuf header buffer first, then the IPC body buffers and padding.
+arrow::Result<arrow::BufferVector> SerializePayloadToBuffers(const 
FlightPayload& msg) {
+  // Size of the IPC body (protobuf: data_body)
+  size_t body_size = 0;
+  // Size of the Protobuf "header" (everything except for the body)
+  size_t header_size = 0;
+  // Size of IPC header metadata (protobuf: data_header)
+  int32_t metadata_size = 0;
+
+  // Account for the descriptor field if present (the bytes are written further below)
+  int32_t descriptor_size = 0;
+  if (msg.descriptor != nullptr) {
+    DCHECK_LE(msg.descriptor->size(), kInt32Max);
+    descriptor_size = static_cast<int32_t>(msg.descriptor->size());
+    header_size += 1 + WireFormatLite::LengthDelimitedSize(descriptor_size);
+  }
+
+  // Account for the app metadata field, only when it is present and non-empty
+  int32_t app_metadata_size = 0;
+  if (msg.app_metadata && msg.app_metadata->size() > 0) {
+    DCHECK_LE(msg.app_metadata->size(), kInt32Max);
+    app_metadata_size = static_cast<int32_t>(msg.app_metadata->size());
+    header_size += 1 + WireFormatLite::LengthDelimitedSize(app_metadata_size);
+  }
+
+  const arrow::ipc::IpcPayload& ipc_msg = msg.ipc_message;
+  // A message type of NONE means there is no IPC data in this payload (metadata-only).
+  bool has_ipc = ipc_msg.type != ipc::MessageType::NONE;
+  bool has_body = has_ipc ? ipc::Message::HasBody(ipc_msg.type) : false;
+
+  if (has_ipc) {
+    DCHECK(has_body || ipc_msg.body_length == 0);  // a message type without a body must report body_length == 0
+    ARROW_RETURN_NOT_OK(
+        IpcMessageHeaderSize(ipc_msg, has_body, &header_size, &metadata_size));
+    body_size = static_cast<size_t>(ipc_msg.body_length);
+  }
+
+  // TODO(wesm): messages over 2GB unlikely to be yet supported
+  // Validated in WritePayload since returning error here causes gRPC to fail 
an assertion
+  DCHECK_LE(body_size, kInt32Max);
+
+  // Allocate and initialize buffers
+  arrow::BufferVector buffers;
+  ARROW_ASSIGN_OR_RAISE(auto header_buf, arrow::AllocateBuffer(header_size));
+
+  // Force the header_stream to be destructed, which actually flushes
+  // the data into the slice. header_buf is only moved into the output after this scope ends.
+  {
+    ArrayOutputStream 
header_writer(const_cast<uint8_t*>(header_buf->mutable_data()),
+                                    static_cast<int>(header_size));  // NOTE(review): const_cast on mutable_data() looks redundant -- confirm
+    CodedOutputStream header_stream(&header_writer);
+
+    // Write descriptor: tag, length, then the descriptor bytes
+    if (msg.descriptor != nullptr) {
+      WireFormatLite::WriteTag(pb::FlightData::kFlightDescriptorFieldNumber,
+                               WireFormatLite::WIRETYPE_LENGTH_DELIMITED, 
&header_stream);
+      header_stream.WriteVarint32(descriptor_size);
+      header_stream.WriteRawMaybeAliased(msg.descriptor->data(),
+                                         
static_cast<int>(msg.descriptor->size()));
+    }
+
+    // Write the IPC metadata as the data_header field
+    if (has_ipc) {
+      WireFormatLite::WriteTag(pb::FlightData::kDataHeaderFieldNumber,
+                               WireFormatLite::WIRETYPE_LENGTH_DELIMITED, 
&header_stream);
+      header_stream.WriteVarint32(metadata_size);
+      header_stream.WriteRawMaybeAliased(ipc_msg.metadata->data(),
+                                         
static_cast<int>(ipc_msg.metadata->size()));
+    }
+
+    // Write app metadata (skipped when absent or empty, matching the size accounting above)
+    if (app_metadata_size > 0) {
+      WireFormatLite::WriteTag(pb::FlightData::kAppMetadataFieldNumber,
+                               WireFormatLite::WIRETYPE_LENGTH_DELIMITED, 
&header_stream);
+      header_stream.WriteVarint32(app_metadata_size);
+      header_stream.WriteRawMaybeAliased(msg.app_metadata->data(),
+                                         
static_cast<int>(msg.app_metadata->size()));
+    }
+
+    if (has_body) {
+      // Write body tag and length only; the body bytes travel in the buffers appended below
+      WireFormatLite::WriteTag(pb::FlightData::kDataBodyFieldNumber,
+                               WireFormatLite::WIRETYPE_LENGTH_DELIMITED, 
&header_stream);
+      header_stream.WriteVarint32(static_cast<uint32_t>(body_size));
+
+      // Enqueue body buffers for writing, without copying
+      for (const auto& buffer : ipc_msg.body_buffers) {
+        // Buffer may be null when the row length is zero, or when all
+        // entries are invalid. Null and zero-length buffers are skipped entirely.
+        if (!buffer || buffer->size() == 0) continue;
+        buffers.push_back(buffer);
+
+        // Append zero padding (a Buffer over kPaddingBytes) when the size is not a multiple of 8
+        const auto remainder = static_cast<int>(
+            bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
+        if (remainder) {
+          buffers.push_back(std::make_shared<arrow::Buffer>(kPaddingBytes, 
remainder));
+        }
+      }
+    }
+
+    DCHECK_EQ(static_cast<int>(header_size), header_stream.ByteCount());  // precomputed size must equal the bytes written
+  }
+  // Once header is written we add it as the first buffer in the output vector.
+  buffers.insert(buffers.begin(), std::move(header_buf));
+
+  return buffers;
+}
+
+// Read internal::FlightData from arrow::Buffer containing FlightData
+// protobuf without copying
+arrow::Result<arrow::flight::internal::FlightData> DeserializeFlightData(
+    const std::shared_ptr<arrow::Buffer>& buffer) {
+  if (!buffer) {
+    return Status::Invalid("No payload");
+  }

Review Comment:
   Ah right, I missed that usage site.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to