From 3028315fd385d57ee4e6dbb52a03c1733055c0c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Sat, 3 Jan 2026 16:39:43 +0000 Subject: [PATCH 1/4] wip --- .../org/apache/arrow/flight/ArrowMessage.java | 16 +- .../arrow/flight/grpc/GetReadableBuffer.java | 334 ++++++- .../flight/grpc/TestGetReadableBuffer.java | 888 ++++++++++++++++++ 3 files changed, 1207 insertions(+), 31 deletions(-) create mode 100644 flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestGetReadableBuffer.java diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index ab4eab3048..6187a59672 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -312,8 +312,12 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s case APP_METADATA_TAG: { int size = readRawVarint32(stream); - appMetadata = allocator.buffer(size); - GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, ENABLE_ZERO_COPY_READ); + if (ENABLE_ZERO_COPY_READ) { + appMetadata = GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, size); + } else { + appMetadata = allocator.buffer(size); + GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, false); + } break; } case BODY_TAG: @@ -323,8 +327,12 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s body = null; } int size = readRawVarint32(stream); - body = allocator.buffer(size); - GetReadableBuffer.readIntoBuffer(stream, body, size, ENABLE_ZERO_COPY_READ); + if (ENABLE_ZERO_COPY_READ) { + body = GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, size); + } else { + body = allocator.buffer(size); + GetReadableBuffer.readIntoBuffer(stream, body, size, false); + } break; default: diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java index 45c32a86c6..f3d3b9de02 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java @@ -18,50 +18,95 @@ import com.google.common.base.Throwables; import com.google.common.io.ByteStreams; +import io.grpc.Detachable; +import io.grpc.HasByteBuffer; import io.grpc.internal.ReadableBuffer; import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Field; +import java.nio.ByteBuffer; import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.ForeignAllocation; +import org.apache.arrow.memory.util.MemoryUtil; /** - * Enable access to ReadableBuffer directly to copy data from a BufferInputStream into a target - * ByteBuffer/ByteBuf. + * Utility class for efficiently reading data from gRPC InputStreams into Arrow buffers. * - *

This could be solved by BufferInputStream exposing Drainable. + *

This class provides two implementations for zero-copy reads: + * + *

+ * + *

When neither fast path is available, falls back to copying via a byte array. */ public class GetReadableBuffer { + /** + * System property to control whether the HasByteBuffer API is used. Default is {@code true}. Set + * to {@code false} to use the legacy reflection-based implementation. + */ + public static final String HASBYTEBUFFER_API_PROPERTY = + "arrow.flight.grpc.enable_hasbytebuffer_api"; + + private static final boolean USE_HASBYTEBUFFER_API; + + // Legacy reflection-based fields private static final Field READABLE_BUFFER; private static final Class BUFFER_INPUT_STREAM; static { + // Determine which implementation to use based on system property (default: true) + USE_HASBYTEBUFFER_API = + !"false".equalsIgnoreCase(System.getProperty(HASBYTEBUFFER_API_PROPERTY, "true")); + + // Initialize legacy reflection-based implementation (used as fallback or when explicitly + // enabled) Field tmpField = null; Class tmpClazz = null; - try { - Class clazz = Class.forName("io.grpc.internal.ReadableBuffers$BufferInputStream"); - - Field f = clazz.getDeclaredField("buffer"); - f.setAccessible(true); - // don't set until we've gotten past all exception cases. - tmpField = f; - tmpClazz = clazz; - } catch (Exception e) { - new RuntimeException("Failed to initialize GetReadableBuffer, falling back to slow path", e) - .printStackTrace(); + if (!USE_HASBYTEBUFFER_API) { + try { + Class clazz = Class.forName("io.grpc.internal.ReadableBuffers$BufferInputStream"); + Field f = clazz.getDeclaredField("buffer"); + f.setAccessible(true); + tmpField = f; + tmpClazz = clazz; + } catch (Exception e) { + new RuntimeException( + "Failed to initialize GetReadableBuffer reflection, falling back to slow path", e) + .printStackTrace(); + } } READABLE_BUFFER = tmpField; BUFFER_INPUT_STREAM = tmpClazz; } + private GetReadableBuffer() {} + + /** + * Returns whether the HasByteBuffer API is enabled. + * + * @return true if the HasByteBuffer API is enabled, false if using legacy reflection + */ + public static boolean isHasByteBufferApiEnabled() { + return USE_HASBYTEBUFFER_API; + } + /** - * Extracts the ReadableBuffer for the given input stream. + * Extracts the ReadableBuffer for the given input stream using reflection. * * @param is Must be an instance of io.grpc.internal.ReadableBuffers$BufferInputStream or null * will be returned. + * @deprecated This method uses gRPC internal APIs via reflection. Prefer using {@link + * #readIntoBuffer} which uses the public HasByteBuffer API by default. */ + @Deprecated public static ReadableBuffer getReadableBuffer(InputStream is) { - if (BUFFER_INPUT_STREAM == null || !is.getClass().equals(BUFFER_INPUT_STREAM)) { return null; } @@ -76,24 +121,259 @@ public static ReadableBuffer getReadableBuffer(InputStream is) { /** * Helper method to read a gRPC-provided InputStream into an ArrowBuf. * - * @param stream The stream to read from. Should be an instance of {@link #BUFFER_INPUT_STREAM}. + *

When fastPath is enabled, this method attempts to use zero-copy reads: + * + *

+ * + *

Falls back to copying via a byte array if zero-copy is not available. + * + * @param stream The stream to read from. * @param buf The buffer to read into. * @param size The number of bytes to read. - * @param fastPath Whether to enable the fast path (i.e. detect whether the stream is a {@link - * #BUFFER_INPUT_STREAM}). - * @throws IOException if there is an error reading form the stream + * @param fastPath Whether to enable the fast path (zero-copy reads). + * @throws IOException if there is an error reading from the stream */ public static void readIntoBuffer( final InputStream stream, final ArrowBuf buf, final int size, final boolean fastPath) throws IOException { - ReadableBuffer readableBuffer = fastPath ? getReadableBuffer(stream) : null; - if (readableBuffer != null) { - readableBuffer.readBytes(buf.nioBuffer(0, size)); - } else { - byte[] heapBytes = new byte[size]; - ByteStreams.readFully(stream, heapBytes); - buf.writeBytes(heapBytes); + if (size == 0) { + buf.writerIndex(0); + return; } + + if (fastPath) { + if (USE_HASBYTEBUFFER_API) { + // New implementation using public HasByteBuffer API + if (stream instanceof HasByteBuffer) { + HasByteBuffer hasByteBuffer = (HasByteBuffer) stream; + if (hasByteBuffer.byteBufferSupported()) { + readUsingHasByteBuffer(stream, hasByteBuffer, buf, size); + return; + } + } + } else { + // Legacy implementation using reflection + ReadableBuffer readableBuffer = getReadableBuffer(stream); + if (readableBuffer != null) { + readableBuffer.readBytes(buf.nioBuffer(0, size)); + buf.writerIndex(size); + return; + } + } + } + + // Slow path: copy via byte array + byte[] heapBytes = new byte[size]; + ByteStreams.readFully(stream, heapBytes); + buf.writeBytes(heapBytes); + buf.writerIndex(size); + } + + /** + * Reads data from a stream using the HasByteBuffer zero-copy API. + * + *

This method copies data from gRPC's ByteBuffers into the provided ArrowBuf. While it avoids + * intermediate byte[] allocations, it still performs a memory copy. + * + * @param stream The underlying InputStream (for skip operations) + * @param hasByteBuffer The HasByteBuffer interface of the stream + * @param buf The ArrowBuf to write into + * @param size The number of bytes to read + * @throws IOException if there is an error reading from the stream + */ + private static void readUsingHasByteBuffer( + final InputStream stream, final HasByteBuffer hasByteBuffer, final ArrowBuf buf, int size) + throws IOException { + int offset = 0; + int remaining = size; + + while (remaining > 0) { + ByteBuffer byteBuffer = hasByteBuffer.getByteBuffer(); + if (byteBuffer == null) { + throw new IOException( + "Unexpected end of stream: expected " + size + " bytes, got " + offset); + } + + int available = byteBuffer.remaining(); + int toCopy = Math.min(remaining, available); + + // Copy data from the ByteBuffer to the ArrowBuf. + // We use the ByteBuffer directly without duplicate() since HasByteBuffer javadoc states: + // "The returned buffer's content should not be modified, but the position, limit, and mark + // may be changed. Operations for changing the position, limit, and mark of the returned + // buffer does not affect the position, limit, and mark of this input stream." + int originalLimit = byteBuffer.limit(); + byteBuffer.limit(byteBuffer.position() + toCopy); + buf.setBytes(offset, byteBuffer); + byteBuffer.limit(originalLimit); + + // Advance the stream position + long skipped = stream.skip(toCopy); + if (skipped != toCopy) { + throw new IOException("Failed to skip bytes: expected " + toCopy + ", skipped " + skipped); + } + + offset += toCopy; + remaining -= toCopy; + } + buf.writerIndex(size); } + + /** + * Reads data from a gRPC stream with true zero-copy ownership transfer when possible. + * + *

This method attempts to achieve zero-copy by taking ownership of gRPC's underlying + * ByteBuffers using the {@link Detachable} interface. When successful, the returned ArrowBuf + * directly wraps the gRPC buffer's memory without any data copying. + * + *

Zero-copy ownership transfer is only possible when: + * + *

+ * + *

When zero-copy is not possible, this method falls back to allocating a new buffer and + * copying the data. + * + * @param allocator The allocator to use for buffer allocation (used for both zero-copy wrapping + * and fallback allocation) + * @param stream The gRPC InputStream to read from + * @param size The number of bytes to read + * @return An ArrowBuf containing the data. The caller is responsible for closing this buffer. + * @throws IOException if there is an error reading from the stream + */ + public static ArrowBuf readWithOwnershipTransfer( + final BufferAllocator allocator, final InputStream stream, final int size) + throws IOException { + if (size == 0) { + return allocator.getEmpty(); + } + + // Try zero-copy ownership transfer if the stream supports it + if (USE_HASBYTEBUFFER_API && stream instanceof HasByteBuffer && stream instanceof Detachable) { + HasByteBuffer hasByteBuffer = (HasByteBuffer) stream; + if (hasByteBuffer.byteBufferSupported()) { + ArrowBuf zeroCopyBuf = tryZeroCopyOwnershipTransfer(allocator, stream, hasByteBuffer, size); + if (zeroCopyBuf != null) { + return zeroCopyBuf; + } + } + } + + // Fall back to copy-based approach + ArrowBuf buf = allocator.buffer(size); + try { + readIntoBuffer(stream, buf, size, true); + return buf; + } catch (Exception e) { + buf.close(); + throw e; + } + } + + /** + * Attempts zero-copy ownership transfer from gRPC stream to ArrowBuf. + * + * @return ArrowBuf wrapping gRPC's memory if successful, null if zero-copy is not possible + */ + private static ArrowBuf tryZeroCopyOwnershipTransfer( + final BufferAllocator allocator, + final InputStream stream, + final HasByteBuffer hasByteBuffer, + final int size) + throws IOException { + // Check if mark is supported - we need it to reset the stream if zero-copy fails + if (!stream.markSupported()) { + return null; + } + + // Use mark() to prevent premature deallocation while we inspect the buffer + // According to gRPC docs: "skip() deallocates the last ByteBuffer, similar to read()" + // mark() prevents this deallocation + stream.mark(size); + + try { + ByteBuffer byteBuffer = hasByteBuffer.getByteBuffer(); + if (byteBuffer == null) { + // No need to reset - stream is already at end + return null; + } + + // Zero-copy only works with direct ByteBuffers (they have a memory address) + if (!byteBuffer.isDirect()) { + // Stream position hasn't changed, no need to reset + return null; + } + + // Check if this single buffer contains all the data we need + if (byteBuffer.remaining() < size) { + // Data is fragmented across multiple buffers, can't do zero-copy + // Stream position hasn't changed, no need to reset + return null; + } + + // Take ownership of the underlying buffers using Detachable.detach() + Detachable detachable = (Detachable) stream; + InputStream detachedStream = detachable.detach(); + + // Get the ByteBuffer from the detached stream + if (!(detachedStream instanceof HasByteBuffer)) { + // Detached stream doesn't support HasByteBuffer, fall back + // Note: original stream is now empty after detach(), can't fall back + closeQuietly(detachedStream); + throw new IOException("Detached stream does not support HasByteBuffer"); + } + + HasByteBuffer detachedHasByteBuffer = (HasByteBuffer) detachedStream; + ByteBuffer detachedBuffer = detachedHasByteBuffer.getByteBuffer(); + if (detachedBuffer == null || !detachedBuffer.isDirect()) { + closeQuietly(detachedStream); + throw new IOException("Detached buffer is null or not direct"); + } + + // Get the memory address of the ByteBuffer + long memoryAddress = + MemoryUtil.getByteBufferAddress(detachedBuffer) + detachedBuffer.position(); + + // Create a ForeignAllocation that will close the detached stream when released + final InputStream streamToClose = detachedStream; + ForeignAllocation allocation = + new ForeignAllocation(size, memoryAddress) { + @Override + protected void release0() { + closeQuietly(streamToClose); + } + }; + + // Wrap the foreign allocation in an ArrowBuf + ArrowBuf buf = allocator.wrapForeignAllocation(allocation); + buf.writerIndex(size); + return buf; + + } catch (Exception e) { + // Reset the stream position on failure (if possible) + try { + stream.reset(); + } catch (IOException resetEx) { + e.addSuppressed(resetEx); + } + throw e; + } + } + + private static void closeQuietly(InputStream stream) { + if (stream != null) { + try { + stream.close(); + } catch (IOException ignored) { + // Ignore close exceptions + } + } + } } diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestGetReadableBuffer.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestGetReadableBuffer.java new file mode 100644 index 0000000000..1d3e65909c --- /dev/null +++ b/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestGetReadableBuffer.java @@ -0,0 +1,888 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.flight.grpc; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.grpc.Detachable; +import io.grpc.HasByteBuffer; +import io.grpc.internal.ReadableBuffer; +import io.grpc.internal.ReadableBuffers; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIf; + +/** Tests for {@link GetReadableBuffer}. */ +public class TestGetReadableBuffer { + + private BufferAllocator allocator; + + @BeforeEach + public void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @AfterEach + public void tearDown() { + allocator.close(); + } + + /** Check if the HasByteBuffer API is enabled (new implementation). */ + static boolean isHasByteBufferApiEnabled() { + return GetReadableBuffer.isHasByteBufferApiEnabled(); + } + + /** Check if the legacy reflection-based implementation is enabled. */ + static boolean isLegacyReflectionEnabled() { + return !GetReadableBuffer.isHasByteBufferApiEnabled(); + } + + // --- Feature Flag Tests --- + + @Test + public void testFeatureFlag_isConsistentWithSystemProperty() { + String propertyValue = System.getProperty(GetReadableBuffer.HASBYTEBUFFER_API_PROPERTY, "true"); + boolean expectedEnabled = !"false".equalsIgnoreCase(propertyValue); + assertEquals(expectedEnabled, GetReadableBuffer.isHasByteBufferApiEnabled()); + } + + // --- Slow Path Tests (work with both implementations) --- + + @Test + public void testSlowPath_regularInputStream() throws IOException { + byte[] testData = {1, 2, 3, 4, 5, 6, 7, 8}; + InputStream stream = new ByteArrayInputStream(testData); + + try (ArrowBuf buf = allocator.buffer(testData.length)) { + GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, false); + + assertEquals(testData.length, buf.writerIndex()); + byte[] result = new byte[testData.length]; + buf.getBytes(0, result); + assertArrayEquals(testData, result); + } + } + + @Test + public void testSlowPath_fastPathDisabled() throws IOException { + byte[] testData = {10, 20, 30, 40}; + HasByteBufferInputStream stream = + new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true); + + try (ArrowBuf buf = allocator.buffer(testData.length)) { + // fastPath=false should force slow path even with HasByteBuffer stream + GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, false); + + assertEquals(testData.length, buf.writerIndex()); + byte[] result = new byte[testData.length]; + buf.getBytes(0, result); + assertArrayEquals(testData, result); + } + } + + @Test + public void testSlowPath_byteBufferNotSupported() throws IOException { + byte[] testData = {100, (byte) 200, 50, 75}; + HasByteBufferInputStream stream = + new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), false); + + try (ArrowBuf buf = allocator.buffer(testData.length)) { + GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, true); + + assertEquals(testData.length, buf.writerIndex()); + byte[] result = new byte[testData.length]; + buf.getBytes(0, result); + assertArrayEquals(testData, result); + } + } + + @Test + public void testSlowPath_emptyBuffer() throws IOException { + InputStream stream = new ByteArrayInputStream(new byte[0]); + + try (ArrowBuf buf = allocator.buffer(8)) { + GetReadableBuffer.readIntoBuffer(stream, buf, 0, false); + assertEquals(0, buf.writerIndex()); + } + } + + @Test + public void testDataIntegrity_writerIndexSet() throws IOException { + byte[] testData = {1, 2, 3, 4, 5, 6, 7, 8}; + InputStream stream = new ByteArrayInputStream(testData); + + try (ArrowBuf buf = allocator.buffer(16)) { + buf.writerIndex(4); + GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, false); + assertEquals(testData.length, buf.writerIndex()); + } + } + + // --- Fast Path Tests (only run when HasByteBuffer API is enabled) --- + + @Nested + @EnabledIf("org.apache.arrow.flight.grpc.TestGetReadableBuffer#isHasByteBufferApiEnabled") + class HasByteBufferApiTests { + + @Test + public void testFastPath_singleByteBuffer() throws IOException { + byte[] testData = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + HasByteBufferInputStream stream = + new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true); + + try (ArrowBuf buf = allocator.buffer(testData.length)) { + GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, true); + + assertEquals(testData.length, buf.writerIndex()); + byte[] result = new byte[testData.length]; + buf.getBytes(0, result); + assertArrayEquals(testData, result); + } + } + + @Test + public void testFastPath_multipleByteBuffers() throws IOException { + byte[] part1 = {1, 2, 3, 4}; + byte[] part2 = {5, 6, 7, 8}; + byte[] part3 = {9, 10}; + byte[] expected = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + + List buffers = new ArrayList<>(); + buffers.add(ByteBuffer.wrap(part1)); + buffers.add(ByteBuffer.wrap(part2)); + buffers.add(ByteBuffer.wrap(part3)); + HasByteBufferInputStream stream = new HasByteBufferInputStream(buffers, true); + + try (ArrowBuf buf = allocator.buffer(expected.length)) { + GetReadableBuffer.readIntoBuffer(stream, buf, expected.length, true); + + assertEquals(expected.length, buf.writerIndex()); + byte[] result = new byte[expected.length]; + buf.getBytes(0, result); + assertArrayEquals(expected, result); + } + } + + @Test + public void testFastPath_emptyBuffer() throws IOException { + HasByteBufferInputStream stream = new HasByteBufferInputStream(List.of(), true); + + try (ArrowBuf buf = allocator.buffer(8)) { + GetReadableBuffer.readIntoBuffer(stream, buf, 0, true); + assertEquals(0, buf.writerIndex()); + } + } + + @Test + public void testFastPath_partialByteBuffer() throws IOException { + byte[] part1 = {1, 2, 3}; + byte[] part2 = {4, 5, 6, 7, 8}; + byte[] expected = {1, 2, 3, 4, 5}; + + List buffers = new ArrayList<>(); + buffers.add(ByteBuffer.wrap(part1)); + buffers.add(ByteBuffer.wrap(part2)); + HasByteBufferInputStream stream = new HasByteBufferInputStream(buffers, true); + + try (ArrowBuf buf = allocator.buffer(expected.length)) { + GetReadableBuffer.readIntoBuffer(stream, buf, expected.length, true); + + assertEquals(expected.length, buf.writerIndex()); + byte[] result = new byte[expected.length]; + buf.getBytes(0, result); + assertArrayEquals(expected, result); + } + } + + @Test + public void testFastPath_largeData() throws IOException { + int size = 64 * 1024; + byte[] testData = new byte[size]; + for (int i = 0; i < size; i++) { + testData[i] = (byte) (i % 256); + } + + List buffers = new ArrayList<>(); + int chunkSize = 8 * 1024; + for (int offset = 0; offset < size; offset += chunkSize) { + int len = Math.min(chunkSize, size - offset); + byte[] chunk = new byte[len]; + System.arraycopy(testData, offset, chunk, 0, len); + buffers.add(ByteBuffer.wrap(chunk)); + } + HasByteBufferInputStream stream = new HasByteBufferInputStream(buffers, true); + + try (ArrowBuf buf = allocator.buffer(size)) { + GetReadableBuffer.readIntoBuffer(stream, buf, size, true); + + assertEquals(size, buf.writerIndex()); + byte[] result = new byte[size]; + buf.getBytes(0, result); + assertArrayEquals(testData, result); + } + } + + @Test + public void testErrorHandling_unexpectedEndOfStream() { + byte[] testData = {1, 2, 3}; + HasByteBufferInputStream stream = + new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true); + + try (ArrowBuf buf = allocator.buffer(10)) { + assertThrows( + IOException.class, () -> GetReadableBuffer.readIntoBuffer(stream, buf, 10, true)); + } + } + + @Test + public void testErrorHandling_skipFailure() { + byte[] testData = {1, 2, 3, 4, 5}; + HasByteBufferInputStream stream = + new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true, true); + + try (ArrowBuf buf = allocator.buffer(testData.length)) { + assertThrows( + IOException.class, + () -> GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, true)); + } + } + + @Test + public void testDataIntegrity_offsetWriting() throws IOException { + byte[] testData = {(byte) 0xDE, (byte) 0xAD, (byte) 0xBE, (byte) 0xEF}; + HasByteBufferInputStream stream = + new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true); + + try (ArrowBuf buf = allocator.buffer(16)) { + for (int i = 0; i < 16; i++) { + buf.setByte(i, 0xFF); + } + + GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, true); + + assertEquals((byte) 0xDE, buf.getByte(0)); + assertEquals((byte) 0xAD, buf.getByte(1)); + assertEquals((byte) 0xBE, buf.getByte(2)); + assertEquals((byte) 0xEF, buf.getByte(3)); + assertEquals(testData.length, buf.writerIndex()); + } + } + + /** + * Verifies the HasByteBuffer data transfer path using the actual gRPC {@link + * ReadableBuffers#openStream} implementation. + * + *

This test uses the real gRPC BufferInputStream class (via ReadableBuffers.openStream) to + * ensure our code works correctly with actual gRPC streams, not just custom test helpers. + */ + @Test + public void testHasByteBufferPath_dataTransferWithObjectTracking() throws IOException { + // Create wrapper objects as the conceptual "source" of our data + // Each ByteWrapper is a distinct Java object instance we can track + final ByteWrapper[] sourceObjects = { + new ByteWrapper((byte) 1), + new ByteWrapper((byte) 2), + new ByteWrapper((byte) 3), + new ByteWrapper((byte) 4), + new ByteWrapper((byte) 5), + new ByteWrapper((byte) 6), + new ByteWrapper((byte) 7), + new ByteWrapper((byte) 8) + }; + + // Extract byte values from wrapper objects into the backing array + final byte[] backingArrayRef = new byte[sourceObjects.length]; + for (int i = 0; i < sourceObjects.length; i++) { + backingArrayRef[i] = sourceObjects[i].getValue(); + } + + // Create ByteBuffer that wraps the backing array + final ByteBuffer byteBufferRef = ByteBuffer.wrap(backingArrayRef); + + // Verify initial object relationships + assertTrue(byteBufferRef.hasArray(), "ByteBuffer should be backed by an array"); + assertSame( + backingArrayRef, + byteBufferRef.array(), + "ByteBuffer.array() must return the exact same byte[] instance"); + + // Create stream using actual gRPC ReadableBuffers implementation + // ReadableBuffers.wrap() creates a ReadableBuffer from a ByteBuffer + // ReadableBuffers.openStream() creates a BufferInputStream that implements HasByteBuffer + ReadableBuffer readableBuffer = ReadableBuffers.wrap(byteBufferRef); + InputStream stream = ReadableBuffers.openStream(readableBuffer, true); + + // Verify the stream implements HasByteBuffer (required for the fast path) + assertTrue( + stream instanceof HasByteBuffer, "gRPC BufferInputStream should implement HasByteBuffer"); + assertTrue( + ((HasByteBuffer) stream).byteBufferSupported(), + "gRPC BufferInputStream should support byteBuffer"); + + try (ArrowBuf buf = allocator.buffer(backingArrayRef.length)) { + GetReadableBuffer.readIntoBuffer(stream, buf, backingArrayRef.length, true); + + // Verify data transfer + assertEquals(backingArrayRef.length, buf.writerIndex()); + byte[] result = new byte[backingArrayRef.length]; + buf.getBytes(0, result); + assertArrayEquals(backingArrayRef, result); + + // VERIFICATION: Check that source objects are preserved and data transferred correctly + + // 1. The source wrapper objects are unchanged and still accessible with same identity + for (int i = 0; i < sourceObjects.length; i++) { + assertSame( + sourceObjects[i], + sourceObjects[i], + "Source wrapper object at index " + i + " must retain identity"); + assertEquals( + sourceObjects[i].getValue(), + backingArrayRef[i], + "Wrapper value at index " + i + " must match backing array"); + } + + // 2. The original backing array reference is preserved in the original ByteBuffer + assertSame( + backingArrayRef, + byteBufferRef.array(), + "Original ByteBuffer's backing array must be the same instance"); + + // 3. Verify ArrowBuf received a copy (data independence) + byte[] originalValues = result.clone(); + backingArrayRef[0] = 99; + buf.getBytes(0, result); + assertArrayEquals(originalValues, result, "ArrowBuf data should be independent of source"); + + // 4. Verify modifying backing array doesn't affect wrapper objects + assertEquals( + (byte) 1, + sourceObjects[0].getValue(), + "Wrapper objects should be independent of backing array modifications"); + } + } + } + + /** Wrapper class that holds a byte value as a distinct Java object for reference tracking. */ + private static final class ByteWrapper { + private final byte value; + + ByteWrapper(byte value) { + this.value = value; + } + + byte getValue() { + return value; + } + } + + // --- Legacy Reflection Tests (only run when HasByteBuffer API is disabled) --- + + /** + * Tests for the legacy reflection-based implementation. These tests only run when the + * HasByteBuffer API is disabled via the system property {@code + * arrow.flight.grpc.enable_hasbytebuffer_api=false}. + */ + @Nested + @EnabledIf("org.apache.arrow.flight.grpc.TestGetReadableBuffer#isLegacyReflectionEnabled") + class LegacyReflectionTests { + + @Test + public void testFastPath_fallsBackToSlowPath_withRegularInputStream() throws IOException { + // When fastPath=true but stream is not BufferInputStream, should use slow path + byte[] testData = {1, 2, 3, 4, 5}; + InputStream stream = new ByteArrayInputStream(testData); + + try (ArrowBuf buf = allocator.buffer(testData.length)) { + GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, true); + + assertEquals(testData.length, buf.writerIndex()); + byte[] result = new byte[testData.length]; + buf.getBytes(0, result); + assertArrayEquals(testData, result); + } + } + + @Test + public void testFastPath_fallsBackToSlowPath_withHasByteBufferStream() throws IOException { + // When fastPath=true but HasByteBuffer API is disabled, should use slow path + byte[] testData = {10, 20, 30, 40}; + HasByteBufferInputStream stream = + new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true); + + try (ArrowBuf buf = allocator.buffer(testData.length)) { + GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, true); + + assertEquals(testData.length, buf.writerIndex()); + byte[] result = new byte[testData.length]; + buf.getBytes(0, result); + assertArrayEquals(testData, result); + } + } + + @Test + public void testFastPath_largeData_fallsBackToSlowPath() throws IOException { + // Verify large data works correctly when falling back to slow path + int size = 64 * 1024; + byte[] testData = new byte[size]; + for (int i = 0; i < size; i++) { + testData[i] = (byte) (i % 256); + } + InputStream stream = new ByteArrayInputStream(testData); + + try (ArrowBuf buf = allocator.buffer(size)) { + GetReadableBuffer.readIntoBuffer(stream, buf, size, true); + + assertEquals(size, buf.writerIndex()); + byte[] result = new byte[size]; + buf.getBytes(0, result); + assertArrayEquals(testData, result); + } + } + } + + // --- Zero-Copy Ownership Transfer Tests --- + + @Nested + @EnabledIf("org.apache.arrow.flight.grpc.TestGetReadableBuffer#isHasByteBufferApiEnabled") + class ZeroCopyOwnershipTransferTests { + + @Test + public void testReadWithOwnershipTransfer_emptyBuffer() throws IOException { + DetachableHasByteBufferInputStream stream = + new DetachableHasByteBufferInputStream(List.of(), true, true); + + try (ArrowBuf buf = GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, 0)) { + assertNotNull(buf); + assertEquals(0, buf.capacity()); + } + } + + @Test + public void testReadWithOwnershipTransfer_fallbackToRegularStream() throws IOException { + // Regular stream without Detachable should fall back to copy + byte[] testData = {1, 2, 3, 4, 5}; + HasByteBufferInputStream stream = + new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true); + + try (ArrowBuf buf = + GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, testData.length)) { + assertNotNull(buf); + assertEquals(testData.length, buf.writerIndex()); + byte[] result = new byte[testData.length]; + buf.getBytes(0, result); + assertArrayEquals(testData, result); + } + } + + @Test + public void testReadWithOwnershipTransfer_fallbackToHeapBuffer() throws IOException { + // Heap buffer (non-direct) should fall back to copy + byte[] testData = {1, 2, 3, 4, 5}; + ByteBuffer heapBuffer = ByteBuffer.wrap(testData); + DetachableHasByteBufferInputStream stream = + new DetachableHasByteBufferInputStream(List.of(heapBuffer), true, false); + + try (ArrowBuf buf = + GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, testData.length)) { + assertNotNull(buf); + assertEquals(testData.length, buf.writerIndex()); + byte[] result = new byte[testData.length]; + buf.getBytes(0, result); + assertArrayEquals(testData, result); + } + } + + @Test + public void testReadWithOwnershipTransfer_directBuffer() throws IOException { + // Direct buffer with Detachable should attempt zero-copy + byte[] testData = {1, 2, 3, 4, 5, 6, 7, 8}; + ByteBuffer directBuffer = ByteBuffer.allocateDirect(testData.length); + directBuffer.put(testData); + directBuffer.flip(); + + AtomicBoolean detachCalled = new AtomicBoolean(false); + AtomicBoolean streamClosed = new AtomicBoolean(false); + DetachableHasByteBufferInputStream stream = + new DetachableHasByteBufferInputStream(List.of(directBuffer), true, true) { + @Override + public InputStream detach() { + detachCalled.set(true); + return new DetachableHasByteBufferInputStream( + List.of(directBuffer.duplicate()), true, true) { + @Override + public void close() throws IOException { + streamClosed.set(true); + super.close(); + } + }; + } + }; + + try (ArrowBuf buf = + GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, testData.length)) { + assertNotNull(buf); + assertEquals(testData.length, buf.writerIndex()); + byte[] result = new byte[testData.length]; + buf.getBytes(0, result); + assertArrayEquals(testData, result); + + // Verify detach was called for direct buffer + assertTrue(detachCalled.get(), "detach() should be called for direct buffer"); + } + // After ArrowBuf is closed, the detached stream should be closed + assertTrue(streamClosed.get(), "Detached stream should be closed when ArrowBuf is released"); + } + + @Test + public void testReadWithOwnershipTransfer_fragmentedBuffers() throws IOException { + // Fragmented buffers (multiple small buffers) should fall back to copy + byte[] part1 = {1, 2, 3}; + byte[] part2 = {4, 5, 6}; + byte[] expected = {1, 2, 3, 4, 5, 6}; + + ByteBuffer directBuffer1 = ByteBuffer.allocateDirect(part1.length); + directBuffer1.put(part1); + directBuffer1.flip(); + + ByteBuffer directBuffer2 = ByteBuffer.allocateDirect(part2.length); + directBuffer2.put(part2); + directBuffer2.flip(); + + DetachableHasByteBufferInputStream stream = + new DetachableHasByteBufferInputStream(List.of(directBuffer1, directBuffer2), true, true); + + try (ArrowBuf buf = + GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, expected.length)) { + assertNotNull(buf); + assertEquals(expected.length, buf.writerIndex()); + byte[] result = new byte[expected.length]; + buf.getBytes(0, result); + assertArrayEquals(expected, result); + } + } + + @Test + public void testReadWithOwnershipTransfer_byteBufferNotSupported() throws IOException { + // Stream that doesn't support byteBuffer should fall back to copy + byte[] testData = {1, 2, 3, 4, 5}; + DetachableHasByteBufferInputStream stream = + new DetachableHasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), false, true); + + try (ArrowBuf buf = + GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, testData.length)) { + assertNotNull(buf); + assertEquals(testData.length, buf.writerIndex()); + byte[] result = new byte[testData.length]; + buf.getBytes(0, result); + assertArrayEquals(testData, result); + } + } + } + + /** + * Test helper class that implements both InputStream and HasByteBuffer. This allows testing the + * fast path without depending on gRPC internal classes. + */ + private static class HasByteBufferInputStream extends InputStream implements HasByteBuffer { + private final List buffers; + private final boolean byteBufferSupported; + private final boolean failOnSkip; + private int currentBufferIndex; + + HasByteBufferInputStream(List buffers, boolean byteBufferSupported) { + this(buffers, byteBufferSupported, false); + } + + HasByteBufferInputStream( + List buffers, boolean byteBufferSupported, boolean failOnSkip) { + this.buffers = new ArrayList<>(); + for (ByteBuffer bb : buffers) { + ByteBuffer copy = ByteBuffer.allocate(bb.remaining()); + copy.put(bb.duplicate()); + copy.flip(); + this.buffers.add(copy); + } + this.byteBufferSupported = byteBufferSupported; + this.failOnSkip = failOnSkip; + this.currentBufferIndex = 0; + } + + @Override + public boolean byteBufferSupported() { + return byteBufferSupported; + } + + @Override + public ByteBuffer getByteBuffer() { + while (currentBufferIndex < buffers.size() + && !buffers.get(currentBufferIndex).hasRemaining()) { + currentBufferIndex++; + } + + if (currentBufferIndex >= buffers.size()) { + return null; + } + + return buffers.get(currentBufferIndex).asReadOnlyBuffer(); + } + + @Override + public long skip(long n) throws IOException { + if (failOnSkip) { + throw new IOException("Simulated skip failure"); + } + + long skipped = 0; + while (skipped < n && currentBufferIndex < buffers.size()) { + ByteBuffer current = buffers.get(currentBufferIndex); + int toSkip = (int) Math.min(n - skipped, current.remaining()); + current.position(current.position() + toSkip); + skipped += toSkip; + + if (!current.hasRemaining()) { + currentBufferIndex++; + } + } + return skipped; + } + + @Override + public int read() throws IOException { + while (currentBufferIndex < buffers.size()) { + ByteBuffer current = buffers.get(currentBufferIndex); + if (current.hasRemaining()) { + return current.get() & 0xFF; + } + currentBufferIndex++; + } + return -1; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (len == 0) { + return 0; + } + + int totalRead = 0; + while (totalRead < len && currentBufferIndex < buffers.size()) { + ByteBuffer current = buffers.get(currentBufferIndex); + if (current.hasRemaining()) { + int toRead = Math.min(len - totalRead, current.remaining()); + current.get(b, off + totalRead, toRead); + totalRead += toRead; + } + if (!current.hasRemaining()) { + currentBufferIndex++; + } + } + return totalRead == 0 ? -1 : totalRead; + } + + @Override + public int available() { + int available = 0; + for (int i = currentBufferIndex; i < buffers.size(); i++) { + available += buffers.get(i).remaining(); + } + return available; + } + } + + /** + * Test helper class that implements InputStream, HasByteBuffer, and Detachable. This allows + * testing the zero-copy ownership transfer path. + */ + private static class DetachableHasByteBufferInputStream extends InputStream + implements HasByteBuffer, Detachable { + private final List buffers; + private final boolean byteBufferSupported; + private final boolean useDirect; + private int currentBufferIndex; + private int markBufferIndex; + private int[] markPositions; + + DetachableHasByteBufferInputStream( + List buffers, boolean byteBufferSupported, boolean useDirect) { + this.buffers = new ArrayList<>(); + for (ByteBuffer bb : buffers) { + ByteBuffer copy; + if (useDirect && bb.isDirect()) { + // Keep direct buffers as-is (duplicate to get independent position) + copy = bb.duplicate(); + } else if (useDirect) { + // Convert to direct buffer + copy = ByteBuffer.allocateDirect(bb.remaining()); + copy.put(bb.duplicate()); + copy.flip(); + } else { + // Use heap buffer + copy = ByteBuffer.allocate(bb.remaining()); + copy.put(bb.duplicate()); + copy.flip(); + } + this.buffers.add(copy); + } + this.byteBufferSupported = byteBufferSupported; + this.useDirect = useDirect; + this.currentBufferIndex = 0; + this.markBufferIndex = 0; + this.markPositions = null; + } + + @Override + public boolean byteBufferSupported() { + return byteBufferSupported; + } + + @Override + public ByteBuffer getByteBuffer() { + while (currentBufferIndex < buffers.size() + && !buffers.get(currentBufferIndex).hasRemaining()) { + currentBufferIndex++; + } + + if (currentBufferIndex >= buffers.size()) { + return null; + } + + // Return a duplicate so that position/limit changes don't affect the internal buffer + // This matches the HasByteBuffer contract + return buffers.get(currentBufferIndex).duplicate(); + } + + @Override + public InputStream detach() { + // Create a new stream with the remaining data + List remainingBuffers = new ArrayList<>(); + for (int i = currentBufferIndex; i < buffers.size(); i++) { + ByteBuffer bb = buffers.get(i); + if (bb.hasRemaining()) { + remainingBuffers.add(bb.duplicate()); + } + } + // Clear this stream's buffers + buffers.clear(); + currentBufferIndex = 0; + return new DetachableHasByteBufferInputStream( + remainingBuffers, byteBufferSupported, useDirect); + } + + @Override + public boolean markSupported() { + return true; + } + + @Override + public void mark(int readLimit) { + markBufferIndex = currentBufferIndex; + // Save positions of all buffers from current index onwards + markPositions = new int[buffers.size()]; + for (int i = 0; i < buffers.size(); i++) { + markPositions[i] = buffers.get(i).position(); + } + } + + @Override + public void reset() throws IOException { + if (markPositions == null) { + throw new IOException("Mark not set"); + } + currentBufferIndex = markBufferIndex; + // Restore positions of all buffers + for (int i = 0; i < buffers.size(); i++) { + buffers.get(i).position(markPositions[i]); + } + } + + @Override + public long skip(long n) throws IOException { + long skipped = 0; + while (skipped < n && currentBufferIndex < buffers.size()) { + ByteBuffer current = buffers.get(currentBufferIndex); + int toSkip = (int) Math.min(n - skipped, current.remaining()); + current.position(current.position() + toSkip); + skipped += toSkip; + + if (!current.hasRemaining()) { + currentBufferIndex++; + } + } + return skipped; + } + + @Override + public int read() throws IOException { + while (currentBufferIndex < buffers.size()) { + ByteBuffer current = buffers.get(currentBufferIndex); + if (current.hasRemaining()) { + return current.get() & 0xFF; + } + currentBufferIndex++; + } + return -1; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (len == 0) { + return 0; + } + + int totalRead = 0; + while (totalRead < len && currentBufferIndex < buffers.size()) { + ByteBuffer current = buffers.get(currentBufferIndex); + if (current.hasRemaining()) { + int toRead = Math.min(len - totalRead, current.remaining()); + current.get(b, off + totalRead, toRead); + totalRead += toRead; + } + if (!current.hasRemaining()) { + currentBufferIndex++; + } + } + return totalRead == 0 ? -1 : totalRead; + } + + @Override + public int available() { + int available = 0; + for (int i = currentBufferIndex; i < buffers.size(); i++) { + available += buffers.get(i).remaining(); + } + return available; + } + } +} From 167b0e66eaef051ed9378eddc117f734f757305c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Thu, 8 Jan 2026 23:50:25 +0000 Subject: [PATCH 2/4] Revert "wip" This reverts commit e109e2c8a72e5e4a940937e2af2bf7b8212cc825. --- .../org/apache/arrow/flight/ArrowMessage.java | 16 +- .../arrow/flight/grpc/GetReadableBuffer.java | 334 +------ .../flight/grpc/TestGetReadableBuffer.java | 888 ------------------ 3 files changed, 31 insertions(+), 1207 deletions(-) delete mode 100644 flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestGetReadableBuffer.java diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index 6187a59672..ab4eab3048 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -312,12 +312,8 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s case APP_METADATA_TAG: { int size = readRawVarint32(stream); - if (ENABLE_ZERO_COPY_READ) { - appMetadata = GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, size); - } else { - appMetadata = allocator.buffer(size); - GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, false); - } + appMetadata = allocator.buffer(size); + GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, ENABLE_ZERO_COPY_READ); break; } case BODY_TAG: @@ -327,12 +323,8 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s body = null; } int size = readRawVarint32(stream); - if (ENABLE_ZERO_COPY_READ) { - body = GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, size); - } else { - body = allocator.buffer(size); - GetReadableBuffer.readIntoBuffer(stream, body, size, false); - } + body = allocator.buffer(size); + GetReadableBuffer.readIntoBuffer(stream, body, size, ENABLE_ZERO_COPY_READ); break; default: diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java index f3d3b9de02..45c32a86c6 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java @@ -18,95 +18,50 @@ import com.google.common.base.Throwables; import com.google.common.io.ByteStreams; -import io.grpc.Detachable; -import io.grpc.HasByteBuffer; import io.grpc.internal.ReadableBuffer; import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Field; -import java.nio.ByteBuffer; import org.apache.arrow.memory.ArrowBuf; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.ForeignAllocation; -import org.apache.arrow.memory.util.MemoryUtil; /** - * Utility class for efficiently reading data from gRPC InputStreams into Arrow buffers. + * Enable access to ReadableBuffer directly to copy data from a BufferInputStream into a target + * ByteBuffer/ByteBuf. * - *

This class provides two implementations for zero-copy reads: - * - *

- * - *

When neither fast path is available, falls back to copying via a byte array. + *

This could be solved by BufferInputStream exposing Drainable. */ public class GetReadableBuffer { - /** - * System property to control whether the HasByteBuffer API is used. Default is {@code true}. Set - * to {@code false} to use the legacy reflection-based implementation. - */ - public static final String HASBYTEBUFFER_API_PROPERTY = - "arrow.flight.grpc.enable_hasbytebuffer_api"; - - private static final boolean USE_HASBYTEBUFFER_API; - - // Legacy reflection-based fields private static final Field READABLE_BUFFER; private static final Class BUFFER_INPUT_STREAM; static { - // Determine which implementation to use based on system property (default: true) - USE_HASBYTEBUFFER_API = - !"false".equalsIgnoreCase(System.getProperty(HASBYTEBUFFER_API_PROPERTY, "true")); - - // Initialize legacy reflection-based implementation (used as fallback or when explicitly - // enabled) Field tmpField = null; Class tmpClazz = null; - if (!USE_HASBYTEBUFFER_API) { - try { - Class clazz = Class.forName("io.grpc.internal.ReadableBuffers$BufferInputStream"); - Field f = clazz.getDeclaredField("buffer"); - f.setAccessible(true); - tmpField = f; - tmpClazz = clazz; - } catch (Exception e) { - new RuntimeException( - "Failed to initialize GetReadableBuffer reflection, falling back to slow path", e) - .printStackTrace(); - } + try { + Class clazz = Class.forName("io.grpc.internal.ReadableBuffers$BufferInputStream"); + + Field f = clazz.getDeclaredField("buffer"); + f.setAccessible(true); + // don't set until we've gotten past all exception cases. + tmpField = f; + tmpClazz = clazz; + } catch (Exception e) { + new RuntimeException("Failed to initialize GetReadableBuffer, falling back to slow path", e) + .printStackTrace(); } READABLE_BUFFER = tmpField; BUFFER_INPUT_STREAM = tmpClazz; } - private GetReadableBuffer() {} - - /** - * Returns whether the HasByteBuffer API is enabled. - * - * @return true if the HasByteBuffer API is enabled, false if using legacy reflection - */ - public static boolean isHasByteBufferApiEnabled() { - return USE_HASBYTEBUFFER_API; - } - /** - * Extracts the ReadableBuffer for the given input stream using reflection. + * Extracts the ReadableBuffer for the given input stream. * * @param is Must be an instance of io.grpc.internal.ReadableBuffers$BufferInputStream or null * will be returned. - * @deprecated This method uses gRPC internal APIs via reflection. Prefer using {@link - * #readIntoBuffer} which uses the public HasByteBuffer API by default. */ - @Deprecated public static ReadableBuffer getReadableBuffer(InputStream is) { + if (BUFFER_INPUT_STREAM == null || !is.getClass().equals(BUFFER_INPUT_STREAM)) { return null; } @@ -121,259 +76,24 @@ public static ReadableBuffer getReadableBuffer(InputStream is) { /** * Helper method to read a gRPC-provided InputStream into an ArrowBuf. * - *

When fastPath is enabled, this method attempts to use zero-copy reads: - * - *

- * - *

Falls back to copying via a byte array if zero-copy is not available. - * - * @param stream The stream to read from. + * @param stream The stream to read from. Should be an instance of {@link #BUFFER_INPUT_STREAM}. * @param buf The buffer to read into. * @param size The number of bytes to read. - * @param fastPath Whether to enable the fast path (zero-copy reads). - * @throws IOException if there is an error reading from the stream + * @param fastPath Whether to enable the fast path (i.e. detect whether the stream is a {@link + * #BUFFER_INPUT_STREAM}). + * @throws IOException if there is an error reading form the stream */ public static void readIntoBuffer( final InputStream stream, final ArrowBuf buf, final int size, final boolean fastPath) throws IOException { - if (size == 0) { - buf.writerIndex(0); - return; + ReadableBuffer readableBuffer = fastPath ? getReadableBuffer(stream) : null; + if (readableBuffer != null) { + readableBuffer.readBytes(buf.nioBuffer(0, size)); + } else { + byte[] heapBytes = new byte[size]; + ByteStreams.readFully(stream, heapBytes); + buf.writeBytes(heapBytes); } - - if (fastPath) { - if (USE_HASBYTEBUFFER_API) { - // New implementation using public HasByteBuffer API - if (stream instanceof HasByteBuffer) { - HasByteBuffer hasByteBuffer = (HasByteBuffer) stream; - if (hasByteBuffer.byteBufferSupported()) { - readUsingHasByteBuffer(stream, hasByteBuffer, buf, size); - return; - } - } - } else { - // Legacy implementation using reflection - ReadableBuffer readableBuffer = getReadableBuffer(stream); - if (readableBuffer != null) { - readableBuffer.readBytes(buf.nioBuffer(0, size)); - buf.writerIndex(size); - return; - } - } - } - - // Slow path: copy via byte array - byte[] heapBytes = new byte[size]; - ByteStreams.readFully(stream, heapBytes); - buf.writeBytes(heapBytes); - buf.writerIndex(size); - } - - /** - * Reads data from a stream using the HasByteBuffer zero-copy API. - * - *

This method copies data from gRPC's ByteBuffers into the provided ArrowBuf. While it avoids - * intermediate byte[] allocations, it still performs a memory copy. - * - * @param stream The underlying InputStream (for skip operations) - * @param hasByteBuffer The HasByteBuffer interface of the stream - * @param buf The ArrowBuf to write into - * @param size The number of bytes to read - * @throws IOException if there is an error reading from the stream - */ - private static void readUsingHasByteBuffer( - final InputStream stream, final HasByteBuffer hasByteBuffer, final ArrowBuf buf, int size) - throws IOException { - int offset = 0; - int remaining = size; - - while (remaining > 0) { - ByteBuffer byteBuffer = hasByteBuffer.getByteBuffer(); - if (byteBuffer == null) { - throw new IOException( - "Unexpected end of stream: expected " + size + " bytes, got " + offset); - } - - int available = byteBuffer.remaining(); - int toCopy = Math.min(remaining, available); - - // Copy data from the ByteBuffer to the ArrowBuf. - // We use the ByteBuffer directly without duplicate() since HasByteBuffer javadoc states: - // "The returned buffer's content should not be modified, but the position, limit, and mark - // may be changed. Operations for changing the position, limit, and mark of the returned - // buffer does not affect the position, limit, and mark of this input stream." - int originalLimit = byteBuffer.limit(); - byteBuffer.limit(byteBuffer.position() + toCopy); - buf.setBytes(offset, byteBuffer); - byteBuffer.limit(originalLimit); - - // Advance the stream position - long skipped = stream.skip(toCopy); - if (skipped != toCopy) { - throw new IOException("Failed to skip bytes: expected " + toCopy + ", skipped " + skipped); - } - - offset += toCopy; - remaining -= toCopy; - } - buf.writerIndex(size); } - - /** - * Reads data from a gRPC stream with true zero-copy ownership transfer when possible. - * - *

This method attempts to achieve zero-copy by taking ownership of gRPC's underlying - * ByteBuffers using the {@link Detachable} interface. When successful, the returned ArrowBuf - * directly wraps the gRPC buffer's memory without any data copying. - * - *

Zero-copy ownership transfer is only possible when: - * - *

- * - *

When zero-copy is not possible, this method falls back to allocating a new buffer and - * copying the data. - * - * @param allocator The allocator to use for buffer allocation (used for both zero-copy wrapping - * and fallback allocation) - * @param stream The gRPC InputStream to read from - * @param size The number of bytes to read - * @return An ArrowBuf containing the data. The caller is responsible for closing this buffer. - * @throws IOException if there is an error reading from the stream - */ - public static ArrowBuf readWithOwnershipTransfer( - final BufferAllocator allocator, final InputStream stream, final int size) - throws IOException { - if (size == 0) { - return allocator.getEmpty(); - } - - // Try zero-copy ownership transfer if the stream supports it - if (USE_HASBYTEBUFFER_API && stream instanceof HasByteBuffer && stream instanceof Detachable) { - HasByteBuffer hasByteBuffer = (HasByteBuffer) stream; - if (hasByteBuffer.byteBufferSupported()) { - ArrowBuf zeroCopyBuf = tryZeroCopyOwnershipTransfer(allocator, stream, hasByteBuffer, size); - if (zeroCopyBuf != null) { - return zeroCopyBuf; - } - } - } - - // Fall back to copy-based approach - ArrowBuf buf = allocator.buffer(size); - try { - readIntoBuffer(stream, buf, size, true); - return buf; - } catch (Exception e) { - buf.close(); - throw e; - } - } - - /** - * Attempts zero-copy ownership transfer from gRPC stream to ArrowBuf. - * - * @return ArrowBuf wrapping gRPC's memory if successful, null if zero-copy is not possible - */ - private static ArrowBuf tryZeroCopyOwnershipTransfer( - final BufferAllocator allocator, - final InputStream stream, - final HasByteBuffer hasByteBuffer, - final int size) - throws IOException { - // Check if mark is supported - we need it to reset the stream if zero-copy fails - if (!stream.markSupported()) { - return null; - } - - // Use mark() to prevent premature deallocation while we inspect the buffer - // According to gRPC docs: "skip() deallocates the last ByteBuffer, similar to read()" - // mark() prevents this deallocation - stream.mark(size); - - try { - ByteBuffer byteBuffer = hasByteBuffer.getByteBuffer(); - if (byteBuffer == null) { - // No need to reset - stream is already at end - return null; - } - - // Zero-copy only works with direct ByteBuffers (they have a memory address) - if (!byteBuffer.isDirect()) { - // Stream position hasn't changed, no need to reset - return null; - } - - // Check if this single buffer contains all the data we need - if (byteBuffer.remaining() < size) { - // Data is fragmented across multiple buffers, can't do zero-copy - // Stream position hasn't changed, no need to reset - return null; - } - - // Take ownership of the underlying buffers using Detachable.detach() - Detachable detachable = (Detachable) stream; - InputStream detachedStream = detachable.detach(); - - // Get the ByteBuffer from the detached stream - if (!(detachedStream instanceof HasByteBuffer)) { - // Detached stream doesn't support HasByteBuffer, fall back - // Note: original stream is now empty after detach(), can't fall back - closeQuietly(detachedStream); - throw new IOException("Detached stream does not support HasByteBuffer"); - } - - HasByteBuffer detachedHasByteBuffer = (HasByteBuffer) detachedStream; - ByteBuffer detachedBuffer = detachedHasByteBuffer.getByteBuffer(); - if (detachedBuffer == null || !detachedBuffer.isDirect()) { - closeQuietly(detachedStream); - throw new IOException("Detached buffer is null or not direct"); - } - - // Get the memory address of the ByteBuffer - long memoryAddress = - MemoryUtil.getByteBufferAddress(detachedBuffer) + detachedBuffer.position(); - - // Create a ForeignAllocation that will close the detached stream when released - final InputStream streamToClose = detachedStream; - ForeignAllocation allocation = - new ForeignAllocation(size, memoryAddress) { - @Override - protected void release0() { - closeQuietly(streamToClose); - } - }; - - // Wrap the foreign allocation in an ArrowBuf - ArrowBuf buf = allocator.wrapForeignAllocation(allocation); - buf.writerIndex(size); - return buf; - - } catch (Exception e) { - // Reset the stream position on failure (if possible) - try { - stream.reset(); - } catch (IOException resetEx) { - e.addSuppressed(resetEx); - } - throw e; - } - } - - private static void closeQuietly(InputStream stream) { - if (stream != null) { - try { - stream.close(); - } catch (IOException ignored) { - // Ignore close exceptions - } - } - } } diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestGetReadableBuffer.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestGetReadableBuffer.java deleted file mode 100644 index 1d3e65909c..0000000000 --- a/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestGetReadableBuffer.java +++ /dev/null @@ -1,888 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.flight.grpc; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import io.grpc.Detachable; -import io.grpc.HasByteBuffer; -import io.grpc.internal.ReadableBuffer; -import io.grpc.internal.ReadableBuffers; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; -import org.apache.arrow.memory.ArrowBuf; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Nested; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIf; - -/** Tests for {@link GetReadableBuffer}. */ -public class TestGetReadableBuffer { - - private BufferAllocator allocator; - - @BeforeEach - public void setUp() { - allocator = new RootAllocator(Long.MAX_VALUE); - } - - @AfterEach - public void tearDown() { - allocator.close(); - } - - /** Check if the HasByteBuffer API is enabled (new implementation). */ - static boolean isHasByteBufferApiEnabled() { - return GetReadableBuffer.isHasByteBufferApiEnabled(); - } - - /** Check if the legacy reflection-based implementation is enabled. */ - static boolean isLegacyReflectionEnabled() { - return !GetReadableBuffer.isHasByteBufferApiEnabled(); - } - - // --- Feature Flag Tests --- - - @Test - public void testFeatureFlag_isConsistentWithSystemProperty() { - String propertyValue = System.getProperty(GetReadableBuffer.HASBYTEBUFFER_API_PROPERTY, "true"); - boolean expectedEnabled = !"false".equalsIgnoreCase(propertyValue); - assertEquals(expectedEnabled, GetReadableBuffer.isHasByteBufferApiEnabled()); - } - - // --- Slow Path Tests (work with both implementations) --- - - @Test - public void testSlowPath_regularInputStream() throws IOException { - byte[] testData = {1, 2, 3, 4, 5, 6, 7, 8}; - InputStream stream = new ByteArrayInputStream(testData); - - try (ArrowBuf buf = allocator.buffer(testData.length)) { - GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, false); - - assertEquals(testData.length, buf.writerIndex()); - byte[] result = new byte[testData.length]; - buf.getBytes(0, result); - assertArrayEquals(testData, result); - } - } - - @Test - public void testSlowPath_fastPathDisabled() throws IOException { - byte[] testData = {10, 20, 30, 40}; - HasByteBufferInputStream stream = - new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true); - - try (ArrowBuf buf = allocator.buffer(testData.length)) { - // fastPath=false should force slow path even with HasByteBuffer stream - GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, false); - - assertEquals(testData.length, buf.writerIndex()); - byte[] result = new byte[testData.length]; - buf.getBytes(0, result); - assertArrayEquals(testData, result); - } - } - - @Test - public void testSlowPath_byteBufferNotSupported() throws IOException { - byte[] testData = {100, (byte) 200, 50, 75}; - HasByteBufferInputStream stream = - new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), false); - - try (ArrowBuf buf = allocator.buffer(testData.length)) { - GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, true); - - assertEquals(testData.length, buf.writerIndex()); - byte[] result = new byte[testData.length]; - buf.getBytes(0, result); - assertArrayEquals(testData, result); - } - } - - @Test - public void testSlowPath_emptyBuffer() throws IOException { - InputStream stream = new ByteArrayInputStream(new byte[0]); - - try (ArrowBuf buf = allocator.buffer(8)) { - GetReadableBuffer.readIntoBuffer(stream, buf, 0, false); - assertEquals(0, buf.writerIndex()); - } - } - - @Test - public void testDataIntegrity_writerIndexSet() throws IOException { - byte[] testData = {1, 2, 3, 4, 5, 6, 7, 8}; - InputStream stream = new ByteArrayInputStream(testData); - - try (ArrowBuf buf = allocator.buffer(16)) { - buf.writerIndex(4); - GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, false); - assertEquals(testData.length, buf.writerIndex()); - } - } - - // --- Fast Path Tests (only run when HasByteBuffer API is enabled) --- - - @Nested - @EnabledIf("org.apache.arrow.flight.grpc.TestGetReadableBuffer#isHasByteBufferApiEnabled") - class HasByteBufferApiTests { - - @Test - public void testFastPath_singleByteBuffer() throws IOException { - byte[] testData = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - HasByteBufferInputStream stream = - new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true); - - try (ArrowBuf buf = allocator.buffer(testData.length)) { - GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, true); - - assertEquals(testData.length, buf.writerIndex()); - byte[] result = new byte[testData.length]; - buf.getBytes(0, result); - assertArrayEquals(testData, result); - } - } - - @Test - public void testFastPath_multipleByteBuffers() throws IOException { - byte[] part1 = {1, 2, 3, 4}; - byte[] part2 = {5, 6, 7, 8}; - byte[] part3 = {9, 10}; - byte[] expected = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - - List buffers = new ArrayList<>(); - buffers.add(ByteBuffer.wrap(part1)); - buffers.add(ByteBuffer.wrap(part2)); - buffers.add(ByteBuffer.wrap(part3)); - HasByteBufferInputStream stream = new HasByteBufferInputStream(buffers, true); - - try (ArrowBuf buf = allocator.buffer(expected.length)) { - GetReadableBuffer.readIntoBuffer(stream, buf, expected.length, true); - - assertEquals(expected.length, buf.writerIndex()); - byte[] result = new byte[expected.length]; - buf.getBytes(0, result); - assertArrayEquals(expected, result); - } - } - - @Test - public void testFastPath_emptyBuffer() throws IOException { - HasByteBufferInputStream stream = new HasByteBufferInputStream(List.of(), true); - - try (ArrowBuf buf = allocator.buffer(8)) { - GetReadableBuffer.readIntoBuffer(stream, buf, 0, true); - assertEquals(0, buf.writerIndex()); - } - } - - @Test - public void testFastPath_partialByteBuffer() throws IOException { - byte[] part1 = {1, 2, 3}; - byte[] part2 = {4, 5, 6, 7, 8}; - byte[] expected = {1, 2, 3, 4, 5}; - - List buffers = new ArrayList<>(); - buffers.add(ByteBuffer.wrap(part1)); - buffers.add(ByteBuffer.wrap(part2)); - HasByteBufferInputStream stream = new HasByteBufferInputStream(buffers, true); - - try (ArrowBuf buf = allocator.buffer(expected.length)) { - GetReadableBuffer.readIntoBuffer(stream, buf, expected.length, true); - - assertEquals(expected.length, buf.writerIndex()); - byte[] result = new byte[expected.length]; - buf.getBytes(0, result); - assertArrayEquals(expected, result); - } - } - - @Test - public void testFastPath_largeData() throws IOException { - int size = 64 * 1024; - byte[] testData = new byte[size]; - for (int i = 0; i < size; i++) { - testData[i] = (byte) (i % 256); - } - - List buffers = new ArrayList<>(); - int chunkSize = 8 * 1024; - for (int offset = 0; offset < size; offset += chunkSize) { - int len = Math.min(chunkSize, size - offset); - byte[] chunk = new byte[len]; - System.arraycopy(testData, offset, chunk, 0, len); - buffers.add(ByteBuffer.wrap(chunk)); - } - HasByteBufferInputStream stream = new HasByteBufferInputStream(buffers, true); - - try (ArrowBuf buf = allocator.buffer(size)) { - GetReadableBuffer.readIntoBuffer(stream, buf, size, true); - - assertEquals(size, buf.writerIndex()); - byte[] result = new byte[size]; - buf.getBytes(0, result); - assertArrayEquals(testData, result); - } - } - - @Test - public void testErrorHandling_unexpectedEndOfStream() { - byte[] testData = {1, 2, 3}; - HasByteBufferInputStream stream = - new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true); - - try (ArrowBuf buf = allocator.buffer(10)) { - assertThrows( - IOException.class, () -> GetReadableBuffer.readIntoBuffer(stream, buf, 10, true)); - } - } - - @Test - public void testErrorHandling_skipFailure() { - byte[] testData = {1, 2, 3, 4, 5}; - HasByteBufferInputStream stream = - new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true, true); - - try (ArrowBuf buf = allocator.buffer(testData.length)) { - assertThrows( - IOException.class, - () -> GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, true)); - } - } - - @Test - public void testDataIntegrity_offsetWriting() throws IOException { - byte[] testData = {(byte) 0xDE, (byte) 0xAD, (byte) 0xBE, (byte) 0xEF}; - HasByteBufferInputStream stream = - new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true); - - try (ArrowBuf buf = allocator.buffer(16)) { - for (int i = 0; i < 16; i++) { - buf.setByte(i, 0xFF); - } - - GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, true); - - assertEquals((byte) 0xDE, buf.getByte(0)); - assertEquals((byte) 0xAD, buf.getByte(1)); - assertEquals((byte) 0xBE, buf.getByte(2)); - assertEquals((byte) 0xEF, buf.getByte(3)); - assertEquals(testData.length, buf.writerIndex()); - } - } - - /** - * Verifies the HasByteBuffer data transfer path using the actual gRPC {@link - * ReadableBuffers#openStream} implementation. - * - *

This test uses the real gRPC BufferInputStream class (via ReadableBuffers.openStream) to - * ensure our code works correctly with actual gRPC streams, not just custom test helpers. - */ - @Test - public void testHasByteBufferPath_dataTransferWithObjectTracking() throws IOException { - // Create wrapper objects as the conceptual "source" of our data - // Each ByteWrapper is a distinct Java object instance we can track - final ByteWrapper[] sourceObjects = { - new ByteWrapper((byte) 1), - new ByteWrapper((byte) 2), - new ByteWrapper((byte) 3), - new ByteWrapper((byte) 4), - new ByteWrapper((byte) 5), - new ByteWrapper((byte) 6), - new ByteWrapper((byte) 7), - new ByteWrapper((byte) 8) - }; - - // Extract byte values from wrapper objects into the backing array - final byte[] backingArrayRef = new byte[sourceObjects.length]; - for (int i = 0; i < sourceObjects.length; i++) { - backingArrayRef[i] = sourceObjects[i].getValue(); - } - - // Create ByteBuffer that wraps the backing array - final ByteBuffer byteBufferRef = ByteBuffer.wrap(backingArrayRef); - - // Verify initial object relationships - assertTrue(byteBufferRef.hasArray(), "ByteBuffer should be backed by an array"); - assertSame( - backingArrayRef, - byteBufferRef.array(), - "ByteBuffer.array() must return the exact same byte[] instance"); - - // Create stream using actual gRPC ReadableBuffers implementation - // ReadableBuffers.wrap() creates a ReadableBuffer from a ByteBuffer - // ReadableBuffers.openStream() creates a BufferInputStream that implements HasByteBuffer - ReadableBuffer readableBuffer = ReadableBuffers.wrap(byteBufferRef); - InputStream stream = ReadableBuffers.openStream(readableBuffer, true); - - // Verify the stream implements HasByteBuffer (required for the fast path) - assertTrue( - stream instanceof HasByteBuffer, "gRPC BufferInputStream should implement HasByteBuffer"); - assertTrue( - ((HasByteBuffer) stream).byteBufferSupported(), - "gRPC BufferInputStream should support byteBuffer"); - - try (ArrowBuf buf = allocator.buffer(backingArrayRef.length)) { - GetReadableBuffer.readIntoBuffer(stream, buf, backingArrayRef.length, true); - - // Verify data transfer - assertEquals(backingArrayRef.length, buf.writerIndex()); - byte[] result = new byte[backingArrayRef.length]; - buf.getBytes(0, result); - assertArrayEquals(backingArrayRef, result); - - // VERIFICATION: Check that source objects are preserved and data transferred correctly - - // 1. The source wrapper objects are unchanged and still accessible with same identity - for (int i = 0; i < sourceObjects.length; i++) { - assertSame( - sourceObjects[i], - sourceObjects[i], - "Source wrapper object at index " + i + " must retain identity"); - assertEquals( - sourceObjects[i].getValue(), - backingArrayRef[i], - "Wrapper value at index " + i + " must match backing array"); - } - - // 2. The original backing array reference is preserved in the original ByteBuffer - assertSame( - backingArrayRef, - byteBufferRef.array(), - "Original ByteBuffer's backing array must be the same instance"); - - // 3. Verify ArrowBuf received a copy (data independence) - byte[] originalValues = result.clone(); - backingArrayRef[0] = 99; - buf.getBytes(0, result); - assertArrayEquals(originalValues, result, "ArrowBuf data should be independent of source"); - - // 4. Verify modifying backing array doesn't affect wrapper objects - assertEquals( - (byte) 1, - sourceObjects[0].getValue(), - "Wrapper objects should be independent of backing array modifications"); - } - } - } - - /** Wrapper class that holds a byte value as a distinct Java object for reference tracking. */ - private static final class ByteWrapper { - private final byte value; - - ByteWrapper(byte value) { - this.value = value; - } - - byte getValue() { - return value; - } - } - - // --- Legacy Reflection Tests (only run when HasByteBuffer API is disabled) --- - - /** - * Tests for the legacy reflection-based implementation. These tests only run when the - * HasByteBuffer API is disabled via the system property {@code - * arrow.flight.grpc.enable_hasbytebuffer_api=false}. - */ - @Nested - @EnabledIf("org.apache.arrow.flight.grpc.TestGetReadableBuffer#isLegacyReflectionEnabled") - class LegacyReflectionTests { - - @Test - public void testFastPath_fallsBackToSlowPath_withRegularInputStream() throws IOException { - // When fastPath=true but stream is not BufferInputStream, should use slow path - byte[] testData = {1, 2, 3, 4, 5}; - InputStream stream = new ByteArrayInputStream(testData); - - try (ArrowBuf buf = allocator.buffer(testData.length)) { - GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, true); - - assertEquals(testData.length, buf.writerIndex()); - byte[] result = new byte[testData.length]; - buf.getBytes(0, result); - assertArrayEquals(testData, result); - } - } - - @Test - public void testFastPath_fallsBackToSlowPath_withHasByteBufferStream() throws IOException { - // When fastPath=true but HasByteBuffer API is disabled, should use slow path - byte[] testData = {10, 20, 30, 40}; - HasByteBufferInputStream stream = - new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true); - - try (ArrowBuf buf = allocator.buffer(testData.length)) { - GetReadableBuffer.readIntoBuffer(stream, buf, testData.length, true); - - assertEquals(testData.length, buf.writerIndex()); - byte[] result = new byte[testData.length]; - buf.getBytes(0, result); - assertArrayEquals(testData, result); - } - } - - @Test - public void testFastPath_largeData_fallsBackToSlowPath() throws IOException { - // Verify large data works correctly when falling back to slow path - int size = 64 * 1024; - byte[] testData = new byte[size]; - for (int i = 0; i < size; i++) { - testData[i] = (byte) (i % 256); - } - InputStream stream = new ByteArrayInputStream(testData); - - try (ArrowBuf buf = allocator.buffer(size)) { - GetReadableBuffer.readIntoBuffer(stream, buf, size, true); - - assertEquals(size, buf.writerIndex()); - byte[] result = new byte[size]; - buf.getBytes(0, result); - assertArrayEquals(testData, result); - } - } - } - - // --- Zero-Copy Ownership Transfer Tests --- - - @Nested - @EnabledIf("org.apache.arrow.flight.grpc.TestGetReadableBuffer#isHasByteBufferApiEnabled") - class ZeroCopyOwnershipTransferTests { - - @Test - public void testReadWithOwnershipTransfer_emptyBuffer() throws IOException { - DetachableHasByteBufferInputStream stream = - new DetachableHasByteBufferInputStream(List.of(), true, true); - - try (ArrowBuf buf = GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, 0)) { - assertNotNull(buf); - assertEquals(0, buf.capacity()); - } - } - - @Test - public void testReadWithOwnershipTransfer_fallbackToRegularStream() throws IOException { - // Regular stream without Detachable should fall back to copy - byte[] testData = {1, 2, 3, 4, 5}; - HasByteBufferInputStream stream = - new HasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), true); - - try (ArrowBuf buf = - GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, testData.length)) { - assertNotNull(buf); - assertEquals(testData.length, buf.writerIndex()); - byte[] result = new byte[testData.length]; - buf.getBytes(0, result); - assertArrayEquals(testData, result); - } - } - - @Test - public void testReadWithOwnershipTransfer_fallbackToHeapBuffer() throws IOException { - // Heap buffer (non-direct) should fall back to copy - byte[] testData = {1, 2, 3, 4, 5}; - ByteBuffer heapBuffer = ByteBuffer.wrap(testData); - DetachableHasByteBufferInputStream stream = - new DetachableHasByteBufferInputStream(List.of(heapBuffer), true, false); - - try (ArrowBuf buf = - GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, testData.length)) { - assertNotNull(buf); - assertEquals(testData.length, buf.writerIndex()); - byte[] result = new byte[testData.length]; - buf.getBytes(0, result); - assertArrayEquals(testData, result); - } - } - - @Test - public void testReadWithOwnershipTransfer_directBuffer() throws IOException { - // Direct buffer with Detachable should attempt zero-copy - byte[] testData = {1, 2, 3, 4, 5, 6, 7, 8}; - ByteBuffer directBuffer = ByteBuffer.allocateDirect(testData.length); - directBuffer.put(testData); - directBuffer.flip(); - - AtomicBoolean detachCalled = new AtomicBoolean(false); - AtomicBoolean streamClosed = new AtomicBoolean(false); - DetachableHasByteBufferInputStream stream = - new DetachableHasByteBufferInputStream(List.of(directBuffer), true, true) { - @Override - public InputStream detach() { - detachCalled.set(true); - return new DetachableHasByteBufferInputStream( - List.of(directBuffer.duplicate()), true, true) { - @Override - public void close() throws IOException { - streamClosed.set(true); - super.close(); - } - }; - } - }; - - try (ArrowBuf buf = - GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, testData.length)) { - assertNotNull(buf); - assertEquals(testData.length, buf.writerIndex()); - byte[] result = new byte[testData.length]; - buf.getBytes(0, result); - assertArrayEquals(testData, result); - - // Verify detach was called for direct buffer - assertTrue(detachCalled.get(), "detach() should be called for direct buffer"); - } - // After ArrowBuf is closed, the detached stream should be closed - assertTrue(streamClosed.get(), "Detached stream should be closed when ArrowBuf is released"); - } - - @Test - public void testReadWithOwnershipTransfer_fragmentedBuffers() throws IOException { - // Fragmented buffers (multiple small buffers) should fall back to copy - byte[] part1 = {1, 2, 3}; - byte[] part2 = {4, 5, 6}; - byte[] expected = {1, 2, 3, 4, 5, 6}; - - ByteBuffer directBuffer1 = ByteBuffer.allocateDirect(part1.length); - directBuffer1.put(part1); - directBuffer1.flip(); - - ByteBuffer directBuffer2 = ByteBuffer.allocateDirect(part2.length); - directBuffer2.put(part2); - directBuffer2.flip(); - - DetachableHasByteBufferInputStream stream = - new DetachableHasByteBufferInputStream(List.of(directBuffer1, directBuffer2), true, true); - - try (ArrowBuf buf = - GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, expected.length)) { - assertNotNull(buf); - assertEquals(expected.length, buf.writerIndex()); - byte[] result = new byte[expected.length]; - buf.getBytes(0, result); - assertArrayEquals(expected, result); - } - } - - @Test - public void testReadWithOwnershipTransfer_byteBufferNotSupported() throws IOException { - // Stream that doesn't support byteBuffer should fall back to copy - byte[] testData = {1, 2, 3, 4, 5}; - DetachableHasByteBufferInputStream stream = - new DetachableHasByteBufferInputStream(List.of(ByteBuffer.wrap(testData)), false, true); - - try (ArrowBuf buf = - GetReadableBuffer.readWithOwnershipTransfer(allocator, stream, testData.length)) { - assertNotNull(buf); - assertEquals(testData.length, buf.writerIndex()); - byte[] result = new byte[testData.length]; - buf.getBytes(0, result); - assertArrayEquals(testData, result); - } - } - } - - /** - * Test helper class that implements both InputStream and HasByteBuffer. This allows testing the - * fast path without depending on gRPC internal classes. - */ - private static class HasByteBufferInputStream extends InputStream implements HasByteBuffer { - private final List buffers; - private final boolean byteBufferSupported; - private final boolean failOnSkip; - private int currentBufferIndex; - - HasByteBufferInputStream(List buffers, boolean byteBufferSupported) { - this(buffers, byteBufferSupported, false); - } - - HasByteBufferInputStream( - List buffers, boolean byteBufferSupported, boolean failOnSkip) { - this.buffers = new ArrayList<>(); - for (ByteBuffer bb : buffers) { - ByteBuffer copy = ByteBuffer.allocate(bb.remaining()); - copy.put(bb.duplicate()); - copy.flip(); - this.buffers.add(copy); - } - this.byteBufferSupported = byteBufferSupported; - this.failOnSkip = failOnSkip; - this.currentBufferIndex = 0; - } - - @Override - public boolean byteBufferSupported() { - return byteBufferSupported; - } - - @Override - public ByteBuffer getByteBuffer() { - while (currentBufferIndex < buffers.size() - && !buffers.get(currentBufferIndex).hasRemaining()) { - currentBufferIndex++; - } - - if (currentBufferIndex >= buffers.size()) { - return null; - } - - return buffers.get(currentBufferIndex).asReadOnlyBuffer(); - } - - @Override - public long skip(long n) throws IOException { - if (failOnSkip) { - throw new IOException("Simulated skip failure"); - } - - long skipped = 0; - while (skipped < n && currentBufferIndex < buffers.size()) { - ByteBuffer current = buffers.get(currentBufferIndex); - int toSkip = (int) Math.min(n - skipped, current.remaining()); - current.position(current.position() + toSkip); - skipped += toSkip; - - if (!current.hasRemaining()) { - currentBufferIndex++; - } - } - return skipped; - } - - @Override - public int read() throws IOException { - while (currentBufferIndex < buffers.size()) { - ByteBuffer current = buffers.get(currentBufferIndex); - if (current.hasRemaining()) { - return current.get() & 0xFF; - } - currentBufferIndex++; - } - return -1; - } - - @Override - public int read(byte[] b, int off, int len) throws IOException { - if (len == 0) { - return 0; - } - - int totalRead = 0; - while (totalRead < len && currentBufferIndex < buffers.size()) { - ByteBuffer current = buffers.get(currentBufferIndex); - if (current.hasRemaining()) { - int toRead = Math.min(len - totalRead, current.remaining()); - current.get(b, off + totalRead, toRead); - totalRead += toRead; - } - if (!current.hasRemaining()) { - currentBufferIndex++; - } - } - return totalRead == 0 ? -1 : totalRead; - } - - @Override - public int available() { - int available = 0; - for (int i = currentBufferIndex; i < buffers.size(); i++) { - available += buffers.get(i).remaining(); - } - return available; - } - } - - /** - * Test helper class that implements InputStream, HasByteBuffer, and Detachable. This allows - * testing the zero-copy ownership transfer path. - */ - private static class DetachableHasByteBufferInputStream extends InputStream - implements HasByteBuffer, Detachable { - private final List buffers; - private final boolean byteBufferSupported; - private final boolean useDirect; - private int currentBufferIndex; - private int markBufferIndex; - private int[] markPositions; - - DetachableHasByteBufferInputStream( - List buffers, boolean byteBufferSupported, boolean useDirect) { - this.buffers = new ArrayList<>(); - for (ByteBuffer bb : buffers) { - ByteBuffer copy; - if (useDirect && bb.isDirect()) { - // Keep direct buffers as-is (duplicate to get independent position) - copy = bb.duplicate(); - } else if (useDirect) { - // Convert to direct buffer - copy = ByteBuffer.allocateDirect(bb.remaining()); - copy.put(bb.duplicate()); - copy.flip(); - } else { - // Use heap buffer - copy = ByteBuffer.allocate(bb.remaining()); - copy.put(bb.duplicate()); - copy.flip(); - } - this.buffers.add(copy); - } - this.byteBufferSupported = byteBufferSupported; - this.useDirect = useDirect; - this.currentBufferIndex = 0; - this.markBufferIndex = 0; - this.markPositions = null; - } - - @Override - public boolean byteBufferSupported() { - return byteBufferSupported; - } - - @Override - public ByteBuffer getByteBuffer() { - while (currentBufferIndex < buffers.size() - && !buffers.get(currentBufferIndex).hasRemaining()) { - currentBufferIndex++; - } - - if (currentBufferIndex >= buffers.size()) { - return null; - } - - // Return a duplicate so that position/limit changes don't affect the internal buffer - // This matches the HasByteBuffer contract - return buffers.get(currentBufferIndex).duplicate(); - } - - @Override - public InputStream detach() { - // Create a new stream with the remaining data - List remainingBuffers = new ArrayList<>(); - for (int i = currentBufferIndex; i < buffers.size(); i++) { - ByteBuffer bb = buffers.get(i); - if (bb.hasRemaining()) { - remainingBuffers.add(bb.duplicate()); - } - } - // Clear this stream's buffers - buffers.clear(); - currentBufferIndex = 0; - return new DetachableHasByteBufferInputStream( - remainingBuffers, byteBufferSupported, useDirect); - } - - @Override - public boolean markSupported() { - return true; - } - - @Override - public void mark(int readLimit) { - markBufferIndex = currentBufferIndex; - // Save positions of all buffers from current index onwards - markPositions = new int[buffers.size()]; - for (int i = 0; i < buffers.size(); i++) { - markPositions[i] = buffers.get(i).position(); - } - } - - @Override - public void reset() throws IOException { - if (markPositions == null) { - throw new IOException("Mark not set"); - } - currentBufferIndex = markBufferIndex; - // Restore positions of all buffers - for (int i = 0; i < buffers.size(); i++) { - buffers.get(i).position(markPositions[i]); - } - } - - @Override - public long skip(long n) throws IOException { - long skipped = 0; - while (skipped < n && currentBufferIndex < buffers.size()) { - ByteBuffer current = buffers.get(currentBufferIndex); - int toSkip = (int) Math.min(n - skipped, current.remaining()); - current.position(current.position() + toSkip); - skipped += toSkip; - - if (!current.hasRemaining()) { - currentBufferIndex++; - } - } - return skipped; - } - - @Override - public int read() throws IOException { - while (currentBufferIndex < buffers.size()) { - ByteBuffer current = buffers.get(currentBufferIndex); - if (current.hasRemaining()) { - return current.get() & 0xFF; - } - currentBufferIndex++; - } - return -1; - } - - @Override - public int read(byte[] b, int off, int len) throws IOException { - if (len == 0) { - return 0; - } - - int totalRead = 0; - while (totalRead < len && currentBufferIndex < buffers.size()) { - ByteBuffer current = buffers.get(currentBufferIndex); - if (current.hasRemaining()) { - int toRead = Math.min(len - totalRead, current.remaining()); - current.get(b, off + totalRead, toRead); - totalRead += toRead; - } - if (!current.hasRemaining()) { - currentBufferIndex++; - } - } - return totalRead == 0 ? -1 : totalRead; - } - - @Override - public int available() { - int available = 0; - for (int i = currentBufferIndex; i < buffers.size(); i++) { - available += buffers.get(i).remaining(); - } - return available; - } - } -} From 3b1d51c24a9df757f2dfdc2064dcd0f8c37e7985 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Sat, 10 Jan 2026 23:22:15 +0000 Subject: [PATCH 3/4] Remove grpc-java reflection for ownership transfers --- .../org/apache/arrow/flight/ArrowMessage.java | 130 +++++++++++- .../arrow/flight/grpc/GetReadableBuffer.java | 99 --------- .../flight/TestArrowMessageZeroCopy.java | 196 ++++++++++++++++++ 3 files changed, 321 insertions(+), 104 deletions(-) delete mode 100644 flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java create mode 100644 flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index ab4eab3048..2292a256ca 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -23,7 +23,9 @@ import com.google.protobuf.CodedInputStream; import com.google.protobuf.CodedOutputStream; import com.google.protobuf.WireFormat; +import io.grpc.Detachable; import io.grpc.Drainable; +import io.grpc.HasByteBuffer; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.protobuf.ProtoUtils; import io.netty.buffer.ByteBuf; @@ -41,11 +43,12 @@ import java.util.Collections; import java.util.List; import org.apache.arrow.flight.grpc.AddWritableBuffer; -import org.apache.arrow.flight.grpc.GetReadableBuffer; import org.apache.arrow.flight.impl.Flight.FlightData; import org.apache.arrow.flight.impl.Flight.FlightDescriptor; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.ForeignAllocation; +import org.apache.arrow.memory.util.MemoryUtil; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; @@ -55,10 +58,14 @@ import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.MetadataVersion; import org.apache.arrow.vector.types.pojo.Schema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** The in-memory representation of FlightData used to manage a stream of Arrow messages. */ class ArrowMessage implements AutoCloseable { + private static final Logger LOG = LoggerFactory.getLogger(ArrowMessage.class); + // If true, deserialize Arrow data by giving Arrow a reference to the underlying gRPC buffer // instead of copying the data. Defaults to true. public static final boolean ENABLE_ZERO_COPY_READ; @@ -312,8 +319,7 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s case APP_METADATA_TAG: { int size = readRawVarint32(stream); - appMetadata = allocator.buffer(size); - GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, ENABLE_ZERO_COPY_READ); + appMetadata = readBuffer(allocator, stream, size); break; } case BODY_TAG: @@ -323,8 +329,7 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s body = null; } int size = readRawVarint32(stream); - body = allocator.buffer(size); - GetReadableBuffer.readIntoBuffer(stream, body, size, ENABLE_ZERO_COPY_READ); + body = readBuffer(allocator, stream, size); break; default: @@ -377,6 +382,121 @@ private static int readRawVarint32(int firstByte, InputStream is) throws IOExcep return CodedInputStream.readRawVarint32(firstByte, is); } + /** + * Reads data from the stream into an ArrowBuf, without copying data when possible. + * + *

First attempts to transfer ownership of the gRPC buffer to Arrow via {@link + * #wrapGrpcBuffer}. This avoids any memory copy when the gRPC transport provides a direct + * ByteBuffer (e.g., Netty). + * + *

If not possible (e.g., heap buffer, fragmented data, or unsupported transport), falls back + * to allocating a new buffer and copying data into it. + * + * @param allocator The allocator to use for buffer allocation + * @param stream The input stream to read from + * @param size The number of bytes to read + * @return An ArrowBuf containing the data + * @throws IOException if there is an error reading from the stream + */ + private static ArrowBuf readBuffer(BufferAllocator allocator, InputStream stream, int size) + throws IOException { + if (ENABLE_ZERO_COPY_READ) { + ArrowBuf zeroCopyBuf = wrapGrpcBuffer(stream, allocator, size); + if (zeroCopyBuf != null) { + return zeroCopyBuf; + } + } + + // Fall back to allocating and copying + ArrowBuf buf = allocator.buffer(size); + byte[] heapBytes = new byte[size]; + ByteStreams.readFully(stream, heapBytes); + buf.writeBytes(heapBytes); + buf.writerIndex(size); + return buf; + } + + /** + * Attempts to wrap gRPC's buffer as an ArrowBuf without copying. + * + *

This method takes ownership of gRPC's underlying buffer via {@link Detachable#detach()} and + * wraps it as an ArrowBuf using {@link BufferAllocator#wrapForeignAllocation}. The gRPC buffer + * will be released when the ArrowBuf is closed. + * + * @param stream The gRPC-provided InputStream + * @param allocator The allocator to use for wrapping the foreign allocation + * @param size The number of bytes to wrap + * @return An ArrowBuf wrapping gRPC's buffer, or {@code null} if zero-copy is not possible + */ + static ArrowBuf wrapGrpcBuffer( + final InputStream stream, final BufferAllocator allocator, final int size) { + + if (!(stream instanceof Detachable) || !(stream instanceof HasByteBuffer)) { + return null; + } + + HasByteBuffer hasByteBuffer = (HasByteBuffer) stream; + if (!hasByteBuffer.byteBufferSupported()) { + return null; + } + + ByteBuffer peekBuffer = hasByteBuffer.getByteBuffer(); + if (peekBuffer == null) { + return null; + } + if (!peekBuffer.isDirect()) { + return null; + } + if (peekBuffer.remaining() < size) { + // Data is fragmented across multiple buffers; zero-copy not possible + return null; + } + + // Take ownership + Detachable detachable = (Detachable) stream; + InputStream detachedStream = detachable.detach(); + + // Get buffer from detached stream + HasByteBuffer detachedHasByteBuffer = (HasByteBuffer) detachedStream; + ByteBuffer detachedByteBuffer = detachedHasByteBuffer.getByteBuffer(); + + if (detachedByteBuffer == null || !detachedByteBuffer.isDirect()) { + closeQuietly(detachedStream); + return null; + } + + // Calculate memory address accounting for buffer position + long baseAddress = MemoryUtil.getByteBufferAddress(detachedByteBuffer); + long dataAddress = baseAddress + detachedByteBuffer.position(); + + // Create ForeignAllocation with proper cleanup + ForeignAllocation foreignAllocation = + new ForeignAllocation(size, dataAddress) { + @Override + protected void release0() { + closeQuietly(detachedStream); + } + }; + + try { + return allocator.wrapForeignAllocation(foreignAllocation); + } catch (Throwable t) { + // If it fails, clean up the detached stream and propagate + closeQuietly(detachedStream); + throw t; + } + } + + private static void closeQuietly(InputStream stream) { + if (stream != null) { + try { + stream.close(); + } catch (IOException e) { + LOG.debug("Error closing detached gRPC stream", e); + } + } + } + /** * Convert the ArrowMessage to an InputStream. * diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java deleted file mode 100644 index 45c32a86c6..0000000000 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.flight.grpc; - -import com.google.common.base.Throwables; -import com.google.common.io.ByteStreams; -import io.grpc.internal.ReadableBuffer; -import java.io.IOException; -import java.io.InputStream; -import java.lang.reflect.Field; -import org.apache.arrow.memory.ArrowBuf; - -/** - * Enable access to ReadableBuffer directly to copy data from a BufferInputStream into a target - * ByteBuffer/ByteBuf. - * - *

This could be solved by BufferInputStream exposing Drainable. - */ -public class GetReadableBuffer { - - private static final Field READABLE_BUFFER; - private static final Class BUFFER_INPUT_STREAM; - - static { - Field tmpField = null; - Class tmpClazz = null; - try { - Class clazz = Class.forName("io.grpc.internal.ReadableBuffers$BufferInputStream"); - - Field f = clazz.getDeclaredField("buffer"); - f.setAccessible(true); - // don't set until we've gotten past all exception cases. - tmpField = f; - tmpClazz = clazz; - } catch (Exception e) { - new RuntimeException("Failed to initialize GetReadableBuffer, falling back to slow path", e) - .printStackTrace(); - } - READABLE_BUFFER = tmpField; - BUFFER_INPUT_STREAM = tmpClazz; - } - - /** - * Extracts the ReadableBuffer for the given input stream. - * - * @param is Must be an instance of io.grpc.internal.ReadableBuffers$BufferInputStream or null - * will be returned. - */ - public static ReadableBuffer getReadableBuffer(InputStream is) { - - if (BUFFER_INPUT_STREAM == null || !is.getClass().equals(BUFFER_INPUT_STREAM)) { - return null; - } - - try { - return (ReadableBuffer) READABLE_BUFFER.get(is); - } catch (Exception ex) { - throw Throwables.propagate(ex); - } - } - - /** - * Helper method to read a gRPC-provided InputStream into an ArrowBuf. - * - * @param stream The stream to read from. Should be an instance of {@link #BUFFER_INPUT_STREAM}. - * @param buf The buffer to read into. - * @param size The number of bytes to read. - * @param fastPath Whether to enable the fast path (i.e. detect whether the stream is a {@link - * #BUFFER_INPUT_STREAM}). - * @throws IOException if there is an error reading form the stream - */ - public static void readIntoBuffer( - final InputStream stream, final ArrowBuf buf, final int size, final boolean fastPath) - throws IOException { - ReadableBuffer readableBuffer = fastPath ? getReadableBuffer(stream) : null; - if (readableBuffer != null) { - readableBuffer.readBytes(buf.nioBuffer(0, size)); - } else { - byte[] heapBytes = new byte[size]; - ByteStreams.readFully(stream, heapBytes); - buf.writeBytes(heapBytes); - } - buf.writerIndex(size); - } -} diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java new file mode 100644 index 0000000000..099b1cd3e5 --- /dev/null +++ b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.flight; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.grpc.Detachable; +import io.grpc.HasByteBuffer; +import io.grpc.internal.ReadableBuffer; +import io.grpc.internal.ReadableBuffers; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Random; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class TestArrowMessageZeroCopy { + + private BufferAllocator allocator; + + @BeforeEach + public void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @AfterEach + public void tearDown() { + allocator.close(); + } + + private static InputStream createGrpcStreamWithDirectBuffer(byte[] data) { + ByteBuffer directBuffer = ByteBuffer.allocateDirect(data.length); + directBuffer.put(data); + directBuffer.flip(); + ReadableBuffer readableBuffer = ReadableBuffers.wrap(directBuffer); + return ReadableBuffers.openStream(readableBuffer, true); + } + + @Test + public void testWrapGrpcBufferReturnsNullForRegularInputStream() throws IOException { + byte[] testData = new byte[] {1, 2, 3, 4, 5}; + InputStream stream = new ByteArrayInputStream(testData); + + // ByteArrayInputStream doesn't implement Detachable or HasByteBuffer + ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); + assertNull(result, "Should return null for streams not implementing required interfaces"); + } + + @Test + public void testWrapGrpcBufferSucceedsForRealGrpcDirectBuffer() throws IOException { + byte[] testData = new byte[] {11, 22, 33, 44, 55}; + InputStream stream = createGrpcStreamWithDirectBuffer(testData); + + assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable"); + assertInstanceOf( + HasByteBuffer.class, stream, "Real gRPC stream should implement HasByteBuffer"); + assertTrue( + ((HasByteBuffer) stream).byteBufferSupported(), + "Direct buffer stream should support ByteBuffer"); + assertTrue( + ((HasByteBuffer) stream).getByteBuffer().isDirect(), + "Should have direct ByteBuffer backing"); + + try (ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length)) { + assertNotNull(result, "Should succeed for real gRPC stream with direct buffer"); + assertEquals(testData.length, result.capacity()); + + // Check received data is the same + byte[] readData = new byte[testData.length]; + result.getBytes(0, readData); + assertArrayEquals(testData, readData); + } + } + + @Test + public void testWrapGrpcBufferReturnsNullForRealGrpcHeapByteBuffer() throws IOException { + byte[] testData = new byte[] {1, 2, 3, 4, 5}; + ByteBuffer heapBuffer = ByteBuffer.wrap(testData); + ReadableBuffer readableBuffer = ReadableBuffers.wrap(heapBuffer); + + InputStream stream = ReadableBuffers.openStream(readableBuffer, true); + + assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable"); + assertInstanceOf( + HasByteBuffer.class, stream, "Real gRPC stream should implement HasByteBuffer"); + assertTrue( + ((HasByteBuffer) stream).byteBufferSupported(), + "Heap ByteBuffer stream should support ByteBuffer"); + assertFalse( + ((HasByteBuffer) stream).getByteBuffer().isDirect(), "Should have heap ByteBuffer backing"); + + // Zero-copy should return null for heap buffer (not direct) + ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); + assertNull(result, "Should return null for real gRPC stream with heap buffer"); + } + + @Test + public void testWrapGrpcBufferReturnsNullForRealGrpcByteArrayStream() throws IOException { + byte[] testData = new byte[] {1, 2, 3, 4, 5}; + ReadableBuffer readableBuffer = ReadableBuffers.wrap(testData); + InputStream stream = ReadableBuffers.openStream(readableBuffer, true); + + // Verify the stream has the expected gRPC interfaces + assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable"); + assertInstanceOf( + HasByteBuffer.class, stream, "Real gRPC stream should implement HasByteBuffer"); + // Byte array backed streams don't support ByteBuffer access + assertFalse( + ((HasByteBuffer) stream).byteBufferSupported(), + "Byte array stream should not support ByteBuffer"); + + // Zero-copy should return null when byteBufferSupported() is false + ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); + assertNull(result, "Should return null for real gRPC stream backed by byte array"); + } + + @Test + public void testWrapGrpcBufferMemoryAccountingWithRealGrpcStream() throws IOException { + byte[] testData = new byte[1024]; + new Random(42).nextBytes(testData); + InputStream stream = createGrpcStreamWithDirectBuffer(testData); + + long memoryBefore = allocator.getAllocatedMemory(); + assertEquals(0, memoryBefore); + + ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); + assertNotNull(result, "Should succeed for real gRPC stream with direct buffer"); + + long memoryDuring = allocator.getAllocatedMemory(); + assertEquals(testData.length, memoryDuring); + + byte[] readData = new byte[testData.length]; + result.getBytes(0, readData); + assertArrayEquals(testData, readData); + + result.close(); + + long memoryAfter = allocator.getAllocatedMemory(); + assertEquals(0, memoryAfter); + } + + @Test + public void testWrapGrpcBufferReturnsNullForInsufficientDataWithRealGrpcStream() + throws IOException { + byte[] testData = new byte[] {1, 2, 3}; + InputStream stream = createGrpcStreamWithDirectBuffer(testData); + + // Request more data than available + ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, 10); + assertNull(result, "Should return null when buffer has insufficient data"); + } + + @Test + public void testWrapGrpcBufferLargeDataWithRealGrpcStream() throws IOException { + // Test with larger data (64KB) + byte[] testData = new byte[64 * 1024]; + new Random(42).nextBytes(testData); + InputStream stream = createGrpcStreamWithDirectBuffer(testData); + + try (ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length)) { + assertNotNull(result, "Should succeed for large data with real gRPC stream"); + assertEquals(testData.length, result.capacity()); + + // Verify data integrity + byte[] readData = new byte[testData.length]; + result.getBytes(0, readData); + assertArrayEquals(testData, readData); + } + } +} From 788633021f4df1b98b2b1f7a3d4ebe84bc6769c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Mon, 12 Jan 2026 14:19:32 +0000 Subject: [PATCH 4/4] remove redundant check --- .../org/apache/arrow/flight/ArrowMessage.java | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index 2292a256ca..366277f711 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -408,12 +408,12 @@ private static ArrowBuf readBuffer(BufferAllocator allocator, InputStream stream } // Fall back to allocating and copying - ArrowBuf buf = allocator.buffer(size); - byte[] heapBytes = new byte[size]; - ByteStreams.readFully(stream, heapBytes); - buf.writeBytes(heapBytes); - buf.writerIndex(size); - return buf; + ArrowBuf buf = allocator.buffer(size); + byte[] heapBytes = new byte[size]; + ByteStreams.readFully(stream, heapBytes); + buf.writeBytes(heapBytes); + buf.writerIndex(size); + return buf; } /** @@ -453,17 +453,10 @@ static ArrowBuf wrapGrpcBuffer( } // Take ownership - Detachable detachable = (Detachable) stream; - InputStream detachedStream = detachable.detach(); + InputStream detachedStream = ((Detachable) stream).detach(); // Get buffer from detached stream - HasByteBuffer detachedHasByteBuffer = (HasByteBuffer) detachedStream; - ByteBuffer detachedByteBuffer = detachedHasByteBuffer.getByteBuffer(); - - if (detachedByteBuffer == null || !detachedByteBuffer.isDirect()) { - closeQuietly(detachedStream); - return null; - } + ByteBuffer detachedByteBuffer = ((HasByteBuffer) detachedStream).getByteBuffer(); // Calculate memory address accounting for buffer position long baseAddress = MemoryUtil.getByteBufferAddress(detachedByteBuffer);