From 4838687066bcfd77dc6f5a68cbbb39591ac168fa Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 21:52:47 +0000 Subject: [PATCH] Add bigframes.bigquery.load_data function Implemented bigframes.bigquery.load_data to execute LOAD DATA SQL statements. Supports partitioning, clustering, schema specification, and other options. Includes logic to generate temporary destination table name if not provided. Added DDL generation helper in bigframes.core.sql. Exposed function in bigframes.bigquery package. Added unit tests in tests/unit/bigquery/test_io.py. --- bigframes/bigquery/__init__.py | 5 + bigframes/bigquery/_operations/io.py | 116 +++++++++++++++++++ bigframes/core/sql/__init__.py | 93 ++++++++++++++- tests/unit/bigquery/test_io.py | 163 +++++++++++++++++++++++++++ 4 files changed, 376 insertions(+), 1 deletion(-) create mode 100644 bigframes/bigquery/_operations/io.py create mode 100644 tests/unit/bigquery/test_io.py diff --git a/bigframes/bigquery/__init__.py b/bigframes/bigquery/__init__.py index 7a7a01a8fc..7781ee1576 100644 --- a/bigframes/bigquery/__init__.py +++ b/bigframes/bigquery/__init__.py @@ -43,6 +43,7 @@ st_regionstats, st_simplify, ) +from bigframes.bigquery._operations.io import load_data from bigframes.bigquery._operations.json import ( json_extract, json_extract_array, @@ -85,6 +86,8 @@ st_length, st_regionstats, st_simplify, + # io ops + load_data, # json ops json_extract, json_extract_array, @@ -135,6 +138,8 @@ "st_length", "st_regionstats", "st_simplify", + # io ops + "load_data", # json ops "json_extract", "json_extract_array", diff --git a/bigframes/bigquery/_operations/io.py b/bigframes/bigquery/_operations/io.py new file mode 100644 index 0000000000..5503f943e8 --- /dev/null +++ b/bigframes/bigquery/_operations/io.py @@ -0,0 +1,116 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import typing +from typing import Any, List, Optional + +import google.cloud.bigquery as bigquery + +import bigframes.core.sql + +if typing.TYPE_CHECKING: + import bigframes.dataframe as dataframe + import bigframes.session + +_PLACEHOLDER_SCHEMA = [ + bigquery.SchemaField("bf_load_placeholder", "INT64"), +] + + +def load_data( + uris: str | List[str], + format: str, + destination_table: Optional[str] = None, + *, + schema: Optional[List[bigquery.SchemaField]] = None, + cluster_by: Optional[List[str]] = None, + partition_by: Optional[str] = None, + options: Optional[dict[str, Any]] = None, + load_options: Optional[dict[str, Any]] = None, + connection: Optional[str] = None, + hive_partition_columns: Optional[List[bigquery.SchemaField]] = None, + overwrite: bool = False, + session: Optional[bigframes.session.Session] = None, +) -> dataframe.DataFrame: + """ + Loads data from external files into a BigQuery table using the `LOAD DATA` statement. 
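+
+    **Examples:**
+
+    A minimal usage sketch; the bucket, project, dataset, and table names
+    below are placeholders::
+
+        import bigframes.bigquery as bbq
+
+        df = bbq.load_data(
+            "gs://my-bucket/data/*.parquet",
+            format="PARQUET",
+            destination_table="my-project.my_dataset.loaded_table",
+            overwrite=True,
+        )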
+ + Args: + uris (str | List[str]): + The fully qualified URIs for the external data locations (e.g., 'gs://bucket/path/file.csv'). + format (str): + The format of the external data (e.g., 'CSV', 'PARQUET', 'AVRO', 'JSON'). + destination_table (str, optional): + The name of the destination table. If not specified, a temporary table will be created. + schema (List[google.cloud.bigquery.SchemaField], optional): + The schema of the destination table. If not provided, schema auto-detection will be used. + cluster_by (List[str], optional): + A list of columns to cluster the table by. + partition_by (str, optional): + The partition expression for the table. + options (dict[str, Any], optional): + Table options (e.g., {'description': 'my table'}). + load_options (dict[str, Any], optional): + Options for loading data (e.g., {'skip_leading_rows': 1}). + connection (str, optional): + The connection name to use for reading external data. + hive_partition_columns (List[google.cloud.bigquery.SchemaField], optional): + The external partitioning columns. If set to an empty list, partitioning is inferred. + overwrite (bool, default False): + If True, overwrites the destination table. If False, appends to it. + session (bigframes.session.Session, optional): + The session to use. If not provided, the default session is used. + + Returns: + bigframes.dataframe.DataFrame: A DataFrame representing the loaded table. + """ + import bigframes.pandas as bpd + + if session is None: + session = bpd.get_global_session() + + if isinstance(uris, str): + uris = [uris] + + if destination_table is None: + # Create a temporary table name + # We need to access the storage manager from the session + # This is internal API usage, but requested by the user + table_ref = session._storage_manager.create_temp_table(_PLACEHOLDER_SCHEMA) + destination_table = f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}" + # Since we created a placeholder table, we must overwrite it + overwrite = True + + sql = bigframes.core.sql.load_data_ddl( + destination_table=destination_table, + uris=uris, + format=format, + schema_fields=schema, + cluster_by=cluster_by, + partition_by=partition_by, + table_options=options, + load_options=load_options, + connection=connection, + hive_partition_columns=hive_partition_columns, + overwrite=overwrite, + ) + + # Execute the LOAD DATA statement + session.read_gbq_query(sql) + + # Return a DataFrame pointing to the destination table + # We use session.read_gbq to ensure it uses the same session + return session.read_gbq(destination_table) diff --git a/bigframes/core/sql/__init__.py b/bigframes/core/sql/__init__.py index ccd2a16ddc..608db1d1c3 100644 --- a/bigframes/core/sql/__init__.py +++ b/bigframes/core/sql/__init__.py @@ -21,7 +21,7 @@ import decimal import json import math -from typing import cast, Collection, Iterable, Mapping, Optional, TYPE_CHECKING, Union +from typing import Any, cast, Collection, Iterable, Mapping, Optional, TYPE_CHECKING, Union import shapely.geometry.base # type: ignore @@ -246,3 +246,94 @@ def create_vector_search_sql( distance, FROM VECTOR_SEARCH({args_str}) """ + + +def _field_type_to_sql(field: bigquery.SchemaField) -> str: + if field.field_type in ("RECORD", "STRUCT"): + sub_defs = [] + for sub in field.fields: + sub_type = _field_type_to_sql(sub) + sub_def = f"{sub.name} {sub_type}" + if sub.mode == "REQUIRED": + sub_def += " NOT NULL" + sub_defs.append(sub_def) + type_str = f"STRUCT<{', '.join(sub_defs)}>" + else: + type_str = field.field_type + + if 
field.mode == "REPEATED": + return f"ARRAY<{type_str}>" + return type_str + + +def schema_field_to_sql(field: bigquery.SchemaField) -> str: + """Convert a BigQuery SchemaField to a SQL DDL column definition.""" + type_sql = _field_type_to_sql(field) + sql = f"{field.name} {type_sql}" + if field.mode == "REQUIRED": + sql += " NOT NULL" + if field.description: + sql += f" OPTIONS(description={simple_literal(field.description)})" + return sql + + +def load_data_ddl( + destination_table: str, + uris: list[str], + format: str, + *, + schema_fields: list[bigquery.SchemaField] | None = None, + cluster_by: list[str] | None = None, + partition_by: str | None = None, + table_options: dict[str, Any] | None = None, + load_options: dict[str, Any] | None = None, + connection: str | None = None, + hive_partition_columns: list[bigquery.SchemaField] | None = None, + overwrite: bool = False, +) -> str: + """Construct a LOAD DATA DDL statement.""" + action = "OVERWRITE" if overwrite else "INTO" + + query = f"LOAD DATA {action} {googlesql.identifier(destination_table)}\n" + + if schema_fields: + columns_sql = ",\n".join(schema_field_to_sql(field) for field in schema_fields) + query += f"(\n{columns_sql}\n)\n" + + if partition_by: + query += f"PARTITION BY {partition_by}\n" + + if cluster_by: + query += f"CLUSTER BY {', '.join(cluster_by)}\n" + + if table_options: + opts_list = [] + for k, v in table_options.items(): + opts_list.append(f"{k}={simple_literal(v)}") + query += f"OPTIONS({', '.join(opts_list)})\n" + + files_opts = {} + if load_options: + files_opts.update(load_options) + + files_opts["uris"] = uris + files_opts["format"] = format + + files_opts_list = [] + for k, v in files_opts.items(): + files_opts_list.append(f"{k}={simple_literal(v)}") + + query += f"FROM FILES({', '.join(files_opts_list)})\n" + + if hive_partition_columns: + cols_sql = ",\n".join( + schema_field_to_sql(field) for field in hive_partition_columns + ) + query += f"WITH PARTITION COLUMNS (\n{cols_sql}\n)\n" + elif hive_partition_columns is not None: + query += "WITH PARTITION COLUMNS\n" + + if connection: + query += f"WITH CONNECTION {connection}\n" + + return query diff --git a/tests/unit/bigquery/test_io.py b/tests/unit/bigquery/test_io.py new file mode 100644 index 0000000000..50fdfb506c --- /dev/null +++ b/tests/unit/bigquery/test_io.py @@ -0,0 +1,163 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
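+
+"""Unit tests for bigframes.bigquery.load_data, run against a mocked session
+so that only the generated LOAD DATA SQL is inspected."""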
+ +import pytest +from unittest.mock import MagicMock + +import google.cloud.bigquery as bigquery +import bigframes.bigquery as bbq +import bigframes.session + +@pytest.fixture +def mock_session(): + session = MagicMock(spec=bigframes.session.Session) + session._storage_manager = MagicMock() + return session + +def test_load_data_minimal(mock_session): + # Setup + uris = ["gs://my-bucket/file.csv"] + format = "CSV" + destination_table = "my_project.my_dataset.my_table" + + # Execution + bbq.load_data(uris, format, destination_table, session=mock_session) + + # Verification + mock_session.read_gbq_query.assert_called_once() + sql = mock_session.read_gbq_query.call_args[0][0] + assert "LOAD DATA INTO `my_project.my_dataset.my_table`" in sql + assert "FROM FILES" in sql + assert "format='CSV'" in sql + assert "uris=['gs://my-bucket/file.csv']" in sql + + mock_session.read_gbq.assert_called_once_with(destination_table) + +def test_load_data_single_uri(mock_session): + # Setup + uris = "gs://my-bucket/file.csv" + format = "CSV" + destination_table = "t" + + # Execution + bbq.load_data(uris, format, destination_table, session=mock_session) + + # Verification + sql = mock_session.read_gbq_query.call_args[0][0] + assert "uris=['gs://my-bucket/file.csv']" in sql + +def test_load_data_temp_table(mock_session): + # Setup + uris = "gs://my-bucket/file.csv" + format = "CSV" + + # Mock return of create_temp_table + mock_session._storage_manager.create_temp_table.return_value = bigquery.TableReference.from_string("p.d.t") + + # Execution + bbq.load_data(uris, format, session=mock_session) + + # Verification + mock_session._storage_manager.create_temp_table.assert_called_once() + + mock_session.read_gbq_query.assert_called_once() + sql = mock_session.read_gbq_query.call_args[0][0] + # Should use OVERWRITE for temp table we just created + assert "LOAD DATA OVERWRITE `p.d.t`" in sql + + mock_session.read_gbq.assert_called_once_with("p.d.t") + +def test_load_data_all_options(mock_session): + # Setup + uris = ["gs://file.parquet"] + format = "PARQUET" + destination_table = "dest" + schema = [ + bigquery.SchemaField("col1", "INT64", mode="REQUIRED", description="my col"), + bigquery.SchemaField("col2", "STRING") + ] + cluster_by = ["col1"] + partition_by = "col1" + options = {"description": "desc"} + load_options = {"ignore_unknown_values": True} + connection = "my_conn" + hive_partition_columns = [bigquery.SchemaField("pcol", "STRING")] + overwrite = True + + # Execution + bbq.load_data( + uris, format, destination_table, + schema=schema, + cluster_by=cluster_by, + partition_by=partition_by, + options=options, + load_options=load_options, + connection=connection, + hive_partition_columns=hive_partition_columns, + overwrite=overwrite, + session=mock_session + ) + + # Verification + sql = mock_session.read_gbq_query.call_args[0][0] + # Normalize newlines for easier assertion or check parts + assert "LOAD DATA OVERWRITE `dest`" in sql + assert "col1 INT64 NOT NULL OPTIONS(description='my col')" in sql + assert "col2 STRING" in sql + assert "PARTITION BY col1" in sql + assert "CLUSTER BY col1" in sql + assert "OPTIONS(description='desc')" in sql + assert "FROM FILES" in sql + assert "ignore_unknown_values=True" in sql + assert "WITH PARTITION COLUMNS" in sql + assert "pcol STRING" in sql + assert "WITH CONNECTION my_conn" in sql + +def test_load_data_hive_partition_inference(mock_session): + # Setup + uris = ["gs://file.parquet"] + format = "PARQUET" + destination_table = "dest" + + # Execution + bbq.load_data( + 
uris, format, destination_table, + hive_partition_columns=[], # Empty list -> Inference + session=mock_session + ) + + # Verification + sql = mock_session.read_gbq_query.call_args[0][0] + assert "WITH PARTITION COLUMNS" in sql + assert "WITH PARTITION COLUMNS (" not in sql + +def test_nested_schema_generation(mock_session): + # Setup + uris = "gs://file.json" + format = "JSON" + destination_table = "dest" + schema = [ + bigquery.SchemaField("nested", "STRUCT", fields=[ + bigquery.SchemaField("sub", "INT64") + ]), + bigquery.SchemaField("arr", "INT64", mode="REPEATED") + ] + + # Execution + bbq.load_data(uris, format, destination_table, schema=schema, session=mock_session) + + # Verification + sql = mock_session.read_gbq_query.call_args[0][0] + assert "nested STRUCT" in sql + assert "arr ARRAY" in sql
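+
+# A sketch of one more case worth covering: load_options such as
+# skip_leading_rows (advertised in the docstring) are expected to be forwarded
+# into the FROM FILES(...) clause; the literal rendering assumed below follows
+# the same simple_literal formatting the assertions above already rely on.
+def test_load_data_load_options(mock_session):
+    # Setup
+    uris = "gs://my-bucket/file.csv"
+    format = "CSV"
+    destination_table = "dest"
+
+    # Execution
+    bbq.load_data(
+        uris, format, destination_table,
+        load_options={"skip_leading_rows": 1},
+        session=mock_session
+    )
+
+    # Verification
+    sql = mock_session.read_gbq_query.call_args[0][0]
+    assert "FROM FILES" in sql
+    assert "skip_leading_rows=1" in sql
+    assert "uris=['gs://my-bucket/file.csv']" in sql
+    assert "format='CSV'" in sql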