@@ -16,14 +16,14 @@
import datetime
import decimal

from google.cloud import bigquery
from google.cloud.bigquery import enums
from google.cloud.bigquery_storage_v1 import types as gapic_types
from google.cloud.bigquery_storage_v1.writer import AppendRowsStream
import pandas as pd

import pyarrow as pa

from google.cloud import bigquery
from google.cloud.bigquery_storage_v1 import types as gapic_types
from google.cloud.bigquery_storage_v1.writer import AppendRowsStream

TABLE_LENGTH = 100_000

BQ_SCHEMA = [
@@ -100,7 +100,10 @@ def make_table(project_id, dataset_id, bq_client):


def create_stream(bqstorage_write_client, table):
stream_name = f"projects/{table.project}/datasets/{table.dataset_id}/tables/{table.table_id}/_default"
stream_name = (
f"projects/{table.project}/datasets/{table.dataset_id}/"
f"tables/{table.table_id}/_default"
)
request_template = gapic_types.AppendRowsRequest()
request_template.write_stream = stream_name
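
The rest of create_stream is collapsed in this diff. For context, an Arrow request template is normally completed by attaching the serialized writer schema before the stream is opened; the sketch below is not the sample's verbatim code and assumes PYARROW_SCHEMA is the module-level pyarrow schema.

```python
# Sketch (assumed, not the sample's verbatim code): attach the Arrow writer
# schema to the template, then open the append stream with it.
request_template.arrow_rows.writer_schema.serialized_schema = (
    PYARROW_SCHEMA.serialize().to_pybytes()
)
append_rows_stream = AppendRowsStream(bqstorage_write_client, request_template)
```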

@@ -160,18 +163,64 @@ def generate_pyarrow_table(num_rows=TABLE_LENGTH):


def generate_write_requests(pyarrow_table):
# Determine max_chunksize of the record batches. Because max size of
# AppendRowsRequest is 10 MB, we need to split the table if it's too big.
# See: https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1#appendrowsrequest
max_request_bytes = 10 * 2**20 # 10 MB
chunk_num = int(pyarrow_table.nbytes / max_request_bytes) + 1
chunk_size = int(pyarrow_table.num_rows / chunk_num)

# Construct request(s).
for batch in pyarrow_table.to_batches(max_chunksize=chunk_size):
# Maximum size for a single AppendRowsRequest is 10 MB.
# To be safe, we'll aim for a soft limit of 7 MB.
max_request_bytes = 7 * 1024 * 1024 # 7 MB

def _create_request(batches):
"""Helper to create an AppendRowsRequest from a list of batches."""
combined_table = pa.Table.from_batches(batches)
request = gapic_types.AppendRowsRequest()
request.arrow_rows.rows.serialized_record_batch = batch.serialize().to_pybytes()
yield request
request.arrow_rows.rows.serialized_record_batch = (
combined_table.combine_chunks().to_batches()[0].serialize().to_pybytes()
)
return request

batches = pyarrow_table.to_batches()

current_batches = []
current_size = 0

while batches:
batch = batches.pop()
batch_size = batch.nbytes

if current_size + batch_size > max_request_bytes:
if batch.num_rows > 1:
# Split the batch into two sub-batches of roughly equal size.
mid = batch.num_rows // 2
batch_left = batch.slice(offset=0, length=mid)
batch_right = batch.slice(offset=mid)

# Push both halves back onto the stack and continue popping.
batches.append(batch_right)
batches.append(batch_left)
continue

# The batch is a single row and is still larger than max_request_bytes.
else:
# If current_batches is empty, raise an error: even a single row exceeds the limit.
if len(current_batches) == 0:
raise ValueError(
f"A single PyArrow batch of one row is larger than the maximum request size "
f"(batch size: {batch_size} > max request size: {max_request_bytes}). Cannot proceed."
)
# Otherwise, flush the accumulated batches as a request, reset the accumulators, and retry this batch.
else:
yield _create_request(current_batches)

current_batches = []
current_size = 0
batches.append(batch)

# Otherwise, the batch fits; add it to current_batches.
else:
current_batches.append(batch)
current_size += batch_size

# Flush remaining batches
if current_batches:
yield _create_request(current_batches)


def verify_result(client, table, futures):
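
Downstream of this generator (also collapsed in the diff), each request is sent on the open AppendRowsStream and the returned futures are kept for verify_result. A rough consumption sketch, with append_rows_stream and pyarrow_table assumed from the sample's conventions:

```python
# Sketch: consume the generator, send each request, and collect the futures
# so verify_result() can check how many requests were issued.
futures = []
for request in generate_write_requests(pyarrow_table):
    futures.append(append_rows_stream.send(request))

# Wait for every append to be acknowledged before querying the table.
for future in futures:
    future.result()
```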
@@ -181,14 +230,13 @@ def verify_result(client, table, futures):
assert bq_table.schema == BQ_SCHEMA

# Verify table size.
query = client.query(f"SELECT COUNT(1) FROM `{bq_table}`;")
query = client.query(f"SELECT DISTINCT int64_col FROM `{bq_table}`;")
query_result = query.result().to_dataframe()

# There might be extra rows due to retries.
assert query_result.iloc[0, 0] >= TABLE_LENGTH
assert len(query_result) == TABLE_LENGTH

# Verify that table was split into multiple requests.
assert len(futures) == 2
assert len(futures) == 3


def main(project_id, dataset):
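
A note on the verify_result change above: appends on the _default stream may be retried, so the destination table can contain duplicate rows. Selecting DISTINCT int64_col and checking its length confirms that exactly TABLE_LENGTH unique rows arrived. An equivalent check that keeps the aggregation server-side could look like this sketch (not part of the sample):

```python
# Sketch: count distinct values in BigQuery rather than materializing them
# client-side; `client` is a bigquery.Client and `bq_table` the verified table.
query = client.query(f"SELECT COUNT(DISTINCT int64_col) FROM `{bq_table}`;")
assert query.result().to_dataframe().iloc[0, 0] == TABLE_LENGTH
```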
@@ -0,0 +1,82 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import pyarrow as pa
import pytest

from . import append_rows_with_arrow


def create_table_with_batches(num_batches, rows_per_batch):
# Generate a small table to get a valid batch
small_table = append_rows_with_arrow.generate_pyarrow_table(rows_per_batch)
# Ensure we get exactly one batch for the small table
batches = small_table.to_batches()
assert len(batches) == 1
batch = batches[0]

# Replicate the batch
all_batches = [batch] * num_batches
return pa.Table.from_batches(all_batches)


# Test generate_write_requests with different numbers of batches in the input table.
# The total number of rows in the generated table is always 1,000,000.
@pytest.mark.parametrize(
"num_batches, rows_per_batch",
[
(1, 1000000),
(10, 100000),
(100, 10000),
(1000, 1000),
(10000, 100),
(100000, 10),
(1000000, 1),
],
)
def test_generate_write_requests_varying_batches(num_batches, rows_per_batch):
"""Test generate_write_requests with different numbers of batches in the input table."""
# Create a table whose to_batches() returns `num_batches` record batches.
table = create_table_with_batches(num_batches, rows_per_batch)

# Verify our setup is correct
assert len(table.to_batches()) == num_batches

# Generate requests
start_time = time.perf_counter()
requests = list(append_rows_with_arrow.generate_write_requests(table))
end_time = time.perf_counter()
print(
f"\nTime used to generate requests for {num_batches} batches: {end_time - start_time:.4f} seconds"
)

# Batches are expected to be aggregated into requests of up to ~7 MB each.
# Since the total row count is constant, the number of requests should be deterministic.
assert len(requests) == 26

# Verify total rows in requests matches total rows in table
total_rows_processed = 0
for request in requests:
# Deserialize the batch from the request to count rows
serialized_batch = request.arrow_rows.rows.serialized_record_batch
# We need a schema to read the batch. The schema is PYARROW_SCHEMA.
batch = pa.ipc.read_record_batch(
serialized_batch, append_rows_with_arrow.PYARROW_SCHEMA
)
total_rows_processed += batch.num_rows

expected_rows = num_batches * rows_per_batch
assert total_rows_processed == expected_rows
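
The row-count verification relies on the same serialize/read round trip the Write API payload uses: a RecordBatch serialized to bytes can be reconstructed with pa.ipc.read_record_batch as long as the writer schema is known. A minimal, self-contained version of that pattern:

```python
import pyarrow as pa

# Minimal round trip of the serialize/read_record_batch pattern used above
# (the one-column schema here is illustrative, not the sample's PYARROW_SCHEMA).
batch = pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=["int64_col"])

payload = batch.serialize().to_pybytes()                    # bytes placed on the request
restored = pa.ipc.read_record_batch(payload, batch.schema)  # schema is required

assert restored.num_rows == batch.num_rows
```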