diff --git a/bigframes/bigquery/__init__.py b/bigframes/bigquery/__init__.py
index 7a7a01a8fc..7781ee1576 100644
--- a/bigframes/bigquery/__init__.py
+++ b/bigframes/bigquery/__init__.py
@@ -43,6 +43,7 @@
     st_regionstats,
     st_simplify,
 )
+from bigframes.bigquery._operations.io import load_data
 from bigframes.bigquery._operations.json import (
     json_extract,
     json_extract_array,
@@ -85,6 +86,8 @@
     st_length,
     st_regionstats,
     st_simplify,
+    # io ops
+    load_data,
     # json ops
     json_extract,
     json_extract_array,
@@ -135,6 +138,8 @@
     "st_length",
     "st_regionstats",
     "st_simplify",
+    # io ops
+    "load_data",
     # json ops
     "json_extract",
     "json_extract_array",
diff --git a/bigframes/bigquery/_operations/io.py b/bigframes/bigquery/_operations/io.py
new file mode 100644
index 0000000000..5503f943e8
--- /dev/null
+++ b/bigframes/bigquery/_operations/io.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import typing
+from typing import Any, List, Optional
+
+import google.cloud.bigquery as bigquery
+
+import bigframes.core.sql
+
+if typing.TYPE_CHECKING:
+    import bigframes.dataframe as dataframe
+    import bigframes.session
+
+_PLACEHOLDER_SCHEMA = [
+    bigquery.SchemaField("bf_load_placeholder", "INT64"),
+]
+
+
+def load_data(
+    uris: str | List[str],
+    format: str,
+    destination_table: Optional[str] = None,
+    *,
+    schema: Optional[List[bigquery.SchemaField]] = None,
+    cluster_by: Optional[List[str]] = None,
+    partition_by: Optional[str] = None,
+    options: Optional[dict[str, Any]] = None,
+    load_options: Optional[dict[str, Any]] = None,
+    connection: Optional[str] = None,
+    hive_partition_columns: Optional[List[bigquery.SchemaField]] = None,
+    overwrite: bool = False,
+    session: Optional[bigframes.session.Session] = None,
+) -> dataframe.DataFrame:
+    """
+    Loads data from external files into a BigQuery table using the `LOAD DATA` statement.
+
+    Args:
+        uris (str | List[str]):
+            The fully qualified URIs for the external data locations (e.g., 'gs://bucket/path/file.csv').
+        format (str):
+            The format of the external data (e.g., 'CSV', 'PARQUET', 'AVRO', 'JSON').
+        destination_table (str, optional):
+            The name of the destination table. If not specified, a temporary table will be created.
+        schema (List[google.cloud.bigquery.SchemaField], optional):
+            The schema of the destination table. If not provided, schema auto-detection will be used.
+        cluster_by (List[str], optional):
+            A list of columns to cluster the table by.
+        partition_by (str, optional):
+            The partition expression for the table.
+        options (dict[str, Any], optional):
+            Table options (e.g., {'description': 'my table'}).
+        load_options (dict[str, Any], optional):
+            Options for loading data (e.g., {'skip_leading_rows': 1}).
+        connection (str, optional):
+            The connection name to use for reading external data.
+        hive_partition_columns (List[google.cloud.bigquery.SchemaField], optional):
+            The external partitioning columns. If set to an empty list, partitioning is inferred.
+        overwrite (bool, default False):
+            If True, overwrites the destination table. If False, appends to it.
+        session (bigframes.session.Session, optional):
+            The session to use. If not provided, the default session is used.
+
+    Returns:
+        bigframes.dataframe.DataFrame: A DataFrame representing the loaded table.
+    """
+    import bigframes.pandas as bpd
+
+    if session is None:
+        session = bpd.get_global_session()
+
+    if isinstance(uris, str):
+        uris = [uris]
+
+    if destination_table is None:
+        # No destination was given, so allocate a temporary placeholder table
+        # through the session's storage manager (an internal API) and use it as
+        # the load target.
+        table_ref = session._storage_manager.create_temp_table(_PLACEHOLDER_SCHEMA)
+        destination_table = f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}"
+        # Since we created a placeholder table, we must overwrite it.
+        overwrite = True
+
+    sql = bigframes.core.sql.load_data_ddl(
+        destination_table=destination_table,
+        uris=uris,
+        format=format,
+        schema_fields=schema,
+        cluster_by=cluster_by,
+        partition_by=partition_by,
+        table_options=options,
+        load_options=load_options,
+        connection=connection,
+        hive_partition_columns=hive_partition_columns,
+        overwrite=overwrite,
+    )
+
+    # Execute the LOAD DATA statement.
+    session.read_gbq_query(sql)
+
+    # Return a DataFrame pointing to the destination table, read through the
+    # same session so the result is tied to it.
+    return session.read_gbq(destination_table)
diff --git a/bigframes/core/sql/__init__.py b/bigframes/core/sql/__init__.py
index ccd2a16ddc..608db1d1c3 100644
--- a/bigframes/core/sql/__init__.py
+++ b/bigframes/core/sql/__init__.py
@@ -21,7 +21,7 @@
 import decimal
 import json
 import math
-from typing import cast, Collection, Iterable, Mapping, Optional, TYPE_CHECKING, Union
+from typing import Any, cast, Collection, Iterable, Mapping, Optional, TYPE_CHECKING, Union
 
 import shapely.geometry.base  # type: ignore
 
@@ -246,3 +246,94 @@
         distance,
     FROM VECTOR_SEARCH({args_str})
     """
+
+
+def _field_type_to_sql(field: bigquery.SchemaField) -> str:
+    if field.field_type in ("RECORD", "STRUCT"):
+        sub_defs = []
+        for sub in field.fields:
+            sub_type = _field_type_to_sql(sub)
+            sub_def = f"{sub.name} {sub_type}"
+            if sub.mode == "REQUIRED":
+                sub_def += " NOT NULL"
+            sub_defs.append(sub_def)
+        type_str = f"STRUCT<{', '.join(sub_defs)}>"
+    else:
+        type_str = field.field_type
+
+    if field.mode == "REPEATED":
+        return f"ARRAY<{type_str}>"
+    return type_str
+
+
+def schema_field_to_sql(field: bigquery.SchemaField) -> str:
+    """Convert a BigQuery SchemaField to a SQL DDL column definition."""
+    type_sql = _field_type_to_sql(field)
+    sql = f"{field.name} {type_sql}"
+    if field.mode == "REQUIRED":
+        sql += " NOT NULL"
+    if field.description:
+        sql += f" OPTIONS(description={simple_literal(field.description)})"
+    return sql
+
+
+def load_data_ddl(
+    destination_table: str,
+    uris: list[str],
+    format: str,
+    *,
+    schema_fields: list[bigquery.SchemaField] | None = None,
+    cluster_by: list[str] | None = None,
+    partition_by: str | None = None,
+    table_options: dict[str, Any] | None = None,
+    load_options: dict[str, Any] | None = None,
+    connection: str | None = None,
+    hive_partition_columns: list[bigquery.SchemaField] | None = None,
+    overwrite: bool = False,
+) -> str:
+    """Construct a LOAD DATA DDL statement."""
+    action = "OVERWRITE" if overwrite else "INTO"
+
+    query = f"LOAD DATA {action} {googlesql.identifier(destination_table)}\n"
+
+    if schema_fields:
+        columns_sql = ",\n".join(schema_field_to_sql(field) for field in schema_fields)
",\n".join(schema_field_to_sql(field) for field in schema_fields) + query += f"(\n{columns_sql}\n)\n" + + if partition_by: + query += f"PARTITION BY {partition_by}\n" + + if cluster_by: + query += f"CLUSTER BY {', '.join(cluster_by)}\n" + + if table_options: + opts_list = [] + for k, v in table_options.items(): + opts_list.append(f"{k}={simple_literal(v)}") + query += f"OPTIONS({', '.join(opts_list)})\n" + + files_opts = {} + if load_options: + files_opts.update(load_options) + + files_opts["uris"] = uris + files_opts["format"] = format + + files_opts_list = [] + for k, v in files_opts.items(): + files_opts_list.append(f"{k}={simple_literal(v)}") + + query += f"FROM FILES({', '.join(files_opts_list)})\n" + + if hive_partition_columns: + cols_sql = ",\n".join( + schema_field_to_sql(field) for field in hive_partition_columns + ) + query += f"WITH PARTITION COLUMNS (\n{cols_sql}\n)\n" + elif hive_partition_columns is not None: + query += "WITH PARTITION COLUMNS\n" + + if connection: + query += f"WITH CONNECTION {connection}\n" + + return query diff --git a/tests/unit/bigquery/test_io.py b/tests/unit/bigquery/test_io.py new file mode 100644 index 0000000000..50fdfb506c --- /dev/null +++ b/tests/unit/bigquery/test_io.py @@ -0,0 +1,163 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import pytest
+from unittest.mock import MagicMock
+
+import google.cloud.bigquery as bigquery
+import bigframes.bigquery as bbq
+import bigframes.session
+
+@pytest.fixture
+def mock_session():
+    session = MagicMock(spec=bigframes.session.Session)
+    session._storage_manager = MagicMock()
+    return session
+
+def test_load_data_minimal(mock_session):
+    # Setup
+    uris = ["gs://my-bucket/file.csv"]
+    format = "CSV"
+    destination_table = "my_project.my_dataset.my_table"
+
+    # Execution
+    bbq.load_data(uris, format, destination_table, session=mock_session)
+
+    # Verification
+    mock_session.read_gbq_query.assert_called_once()
+    sql = mock_session.read_gbq_query.call_args[0][0]
+    assert "LOAD DATA INTO `my_project.my_dataset.my_table`" in sql
+    assert "FROM FILES" in sql
+    assert "format='CSV'" in sql
+    assert "uris=['gs://my-bucket/file.csv']" in sql
+
+    mock_session.read_gbq.assert_called_once_with(destination_table)
+
+def test_load_data_single_uri(mock_session):
+    # Setup
+    uris = "gs://my-bucket/file.csv"
+    format = "CSV"
+    destination_table = "t"
+
+    # Execution
+    bbq.load_data(uris, format, destination_table, session=mock_session)
+
+    # Verification
+    sql = mock_session.read_gbq_query.call_args[0][0]
+    assert "uris=['gs://my-bucket/file.csv']" in sql
+
+def test_load_data_temp_table(mock_session):
+    # Setup
+    uris = "gs://my-bucket/file.csv"
+    format = "CSV"
+
+    # Mock the return value of create_temp_table
+    mock_session._storage_manager.create_temp_table.return_value = bigquery.TableReference.from_string("p.d.t")
+
+    # Execution
+    bbq.load_data(uris, format, session=mock_session)
+
+    # Verification
+    mock_session._storage_manager.create_temp_table.assert_called_once()
+
+    mock_session.read_gbq_query.assert_called_once()
+    sql = mock_session.read_gbq_query.call_args[0][0]
+    # Should use OVERWRITE for the temp table we just created
+    assert "LOAD DATA OVERWRITE `p.d.t`" in sql
+
+    mock_session.read_gbq.assert_called_once_with("p.d.t")
+
+def test_load_data_all_options(mock_session):
+    # Setup
+    uris = ["gs://file.parquet"]
+    format = "PARQUET"
+    destination_table = "dest"
+    schema = [
+        bigquery.SchemaField("col1", "INT64", mode="REQUIRED", description="my col"),
+        bigquery.SchemaField("col2", "STRING")
+    ]
+    cluster_by = ["col1"]
+    partition_by = "col1"
+    options = {"description": "desc"}
+    load_options = {"ignore_unknown_values": True}
+    connection = "my_conn"
+    hive_partition_columns = [bigquery.SchemaField("pcol", "STRING")]
+    overwrite = True
+
+    # Execution
+    bbq.load_data(
+        uris, format, destination_table,
+        schema=schema,
+        cluster_by=cluster_by,
+        partition_by=partition_by,
+        options=options,
+        load_options=load_options,
+        connection=connection,
+        hive_partition_columns=hive_partition_columns,
+        overwrite=overwrite,
+        session=mock_session
+    )
+
+    # Verification
+    sql = mock_session.read_gbq_query.call_args[0][0]
+    # Check that each clause of the generated statement is present
+    assert "LOAD DATA OVERWRITE `dest`" in sql
+    assert "col1 INT64 NOT NULL OPTIONS(description='my col')" in sql
+    assert "col2 STRING" in sql
+    assert "PARTITION BY col1" in sql
+    assert "CLUSTER BY col1" in sql
+    assert "OPTIONS(description='desc')" in sql
+    assert "FROM FILES" in sql
+    assert "ignore_unknown_values=True" in sql
+    assert "WITH PARTITION COLUMNS" in sql
+    assert "pcol STRING" in sql
+    assert "WITH CONNECTION my_conn" in sql
+
+def test_load_data_hive_partition_inference(mock_session):
+    # Setup
+    uris = ["gs://file.parquet"]
+    format = "PARQUET"
+    destination_table = "dest"
+
+    # Execution
+    bbq.load_data(
+        uris, format, destination_table,
+        hive_partition_columns=[],  # Empty list -> inference
+        session=mock_session
+    )
+
+    # Verification
+    sql = mock_session.read_gbq_query.call_args[0][0]
+    assert "WITH PARTITION COLUMNS" in sql
+    assert "WITH PARTITION COLUMNS (" not in sql
+
+def test_nested_schema_generation(mock_session):
+    # Setup
+    uris = "gs://file.json"
+    format = "JSON"
+    destination_table = "dest"
+    schema = [
+        bigquery.SchemaField("nested", "STRUCT", fields=[
+            bigquery.SchemaField("sub", "INT64")
+        ]),
+        bigquery.SchemaField("arr", "INT64", mode="REPEATED")
+    ]
+
+    # Execution
+    bbq.load_data(uris, format, destination_table, schema=schema, session=mock_session)
+
+    # Verification
+    sql = mock_session.read_gbq_query.call_args[0][0]
+    assert "nested STRUCT" in sql
+    assert "arr ARRAY" in sql
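
For reference, a minimal usage sketch of the `bigframes.bigquery.load_data` API introduced above. The bucket, project, dataset, and table names are placeholders, and a configured BigQuery DataFrames session is assumed.

```python
import google.cloud.bigquery as bigquery

import bigframes.bigquery as bbq

# Load CSV files from GCS into an explicit destination table, appending to it.
# Generates roughly: LOAD DATA INTO `my_project.my_dataset.events` (...)
#                    FROM FILES(skip_leading_rows=1, uris=[...], format='CSV')
df = bbq.load_data(
    "gs://my-bucket/data/*.csv",  # a single URI or a list of URIs
    format="CSV",
    destination_table="my_project.my_dataset.events",
    schema=[
        bigquery.SchemaField("id", "INT64", mode="REQUIRED"),
        bigquery.SchemaField("name", "STRING"),
    ],
    load_options={"skip_leading_rows": 1},
)

# Omitting destination_table loads into a session-managed temporary table
# (always via LOAD DATA OVERWRITE) and returns a DataFrame over that table.
df_tmp = bbq.load_data("gs://my-bucket/data/*.parquet", format="PARQUET")
```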