This is an automated email from the ASF dual-hosted git repository.
yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f2672fcf3cf [SPARK-43923][CONNECT][FOLLOWUP] Correct the message
abbreviation
f2672fcf3cf is described below
commit f2672fcf3cf3019791a0afcf7eff28f86503fcbc
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Oct 26 22:11:11 2023 +0800
[SPARK-43923][CONNECT][FOLLOWUP] Correct the message abbreviation
### What changes were proposed in this pull request?
1. Truncate raw bytes (udf/udtf/local relation) with `MAX_BYTES_SIZE`;
2. Pass `maxStringSize` through when abbreviating nested messages;
3. Minor optimization to avoid temporary array creation.
### Why are the changes needed?
1. `maxStringSize` is specified in only one place, with the value
`MAX_STATEMENT_TEXT_SIZE = 65535`. As its name suggests, it is meant for truncating
SQL statements, which are always strings, so it should not affect raw bytes;
2. According to the implementation of `Message.toString`:
https://github.com/protocolbuffers/protobuf/blob/main/java/core/src/main/java/com/google/protobuf/TextFormat.java#L567-L574
the value of a bytes field can be either a `ByteString` or a `byte[]`, so the
two branches should behave consistently;
3. `maxStringSize` currently only affects top-level string fields; it should
also be applied to nested messages.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #43535 from zhengruifeng/connect_abbreviate_fix.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
---
.../spark/sql/connect/common/ProtoUtils.scala | 25 ++++++++++++----------
1 file changed, 14 insertions(+), 11 deletions(-)
diff --git
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
index c7bf3f93bd0..4d1be169ae1 100644
---
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
+++
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
@@ -43,9 +43,12 @@ private[connect] object ProtoUtils {
case (field: FieldDescriptor, byteString: ByteString)
if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING &&
byteString != null =>
val size = byteString.size
- if (size > maxStringSize) {
- val prefix = Array.tabulate(maxStringSize)(byteString.byteAt)
- builder.setField(field, createByteString(prefix, size))
+ if (size > MAX_BYTES_SIZE) {
+ builder.setField(
+ field,
+ byteString
+ .substring(0, MAX_BYTES_SIZE)
+ .concat(createTruncatedByteString(size)))
} else {
builder.setField(field, byteString)
}
@@ -54,8 +57,11 @@ private[connect] object ProtoUtils {
if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING &&
byteArray != null =>
val size = byteArray.size
if (size > MAX_BYTES_SIZE) {
- val prefix = byteArray.take(MAX_BYTES_SIZE)
- builder.setField(field, createByteString(prefix, size))
+ builder.setField(
+ field,
+ ByteString
+ .copyFrom(byteArray, 0, MAX_BYTES_SIZE)
+ .concat(createTruncatedByteString(size)))
} else {
builder.setField(field, byteArray)
}
@@ -63,7 +69,7 @@ private[connect] object ProtoUtils {
// TODO(SPARK-43117): should also support 1, repeated msg; 2, map<xxx,
msg>
case (field: FieldDescriptor, msg: Message)
if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && msg !=
null =>
- builder.setField(field, abbreviate(msg))
+ builder.setField(field, abbreviate(msg, maxStringSize))
case (field: FieldDescriptor, value: Any) => builder.setField(field,
value)
}
@@ -71,11 +77,8 @@ private[connect] object ProtoUtils {
builder.build()
}
- private def createByteString(prefix: Array[Byte], size: Int): ByteString = {
- ByteString.copyFrom(
- List(
- ByteString.copyFrom(prefix),
-
ByteString.copyFromUtf8(s"[truncated(size=${format.format(size)})]")).asJava)
+ private def createTruncatedByteString(size: Int): ByteString = {
+ ByteString.copyFromUtf8(s"[truncated(size=${format.format(size)})]")
}
private def createString(prefix: String, size: Int): String = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]