Plantilla TFRecords de BigQuery a Cloud Storage

La plantilla de datos de BigQuery a Cloud Storage en TFRecord es una canalización que lee datos de una consulta de BigQuery y los escribe en un bucket de Cloud Storage en formato TFRecord. Puedes especificar las divisiones de porcentaje de entrenamiento, prueba y validación. De forma predeterminada, la división es 1 o 100% para el conjunto de entrenamiento y 0 o 0% para los conjuntos de prueba y validación. Cuando configuras la división del conjunto de datos, la suma del entrenamiento, la prueba y la validación debe ser hasta 1 o 100% (por ejemplo, 0.6 + 0.2 + 0.2). Dataflow determina de forma automática la cantidad óptima de fragmentos para cada conjunto de datos de salida.

Requisitos de la canalización

  • El conjunto de datos y la tabla de BigQuery deben existir.
  • El bucket de Cloud Storage de salida debe existir antes de la ejecución de la canalización. Los subdirectorios de entrenamiento, prueba y validación no necesitan preexistir y se generan automáticamente.

Parámetros de la plantilla

Parámetros obligatorios

  • readQuery: Una consulta de BigQuery SQL que extrae datos de la fuente. Por ejemplo, select * from dataset1.sample_table
  • outputDirectory: El prefijo de ruta de acceso de Cloud Storage de nivel superior que se usará cuando se escriban los archivos TFRecord de entrenamiento, prueba y validación. Los subdirectorios de los archivos de TFRecord para la capacitación, la prueba y la validación resultantes se generan automáticamente a partir de outputDirectory. Por ejemplo, gs://mybucket/output.

Parámetros opcionales

  • readIdColumn: Es el nombre de la columna de BigQuery que almacena el identificador único de la fila.
  • invalidOutputPath: Es la ruta de acceso de Cloud Storage donde se escriben las filas de BigQuery que no se pueden convertir en entidades de destino. Por ejemplo, gs://your-bucket/your-path
  • outputSuffix: Es el sufijo del archivo para los archivos TFRecord de entrenamiento, prueba y validación que se escriben. El valor predeterminado es .tfrecord.
  • trainingPercentage: El porcentaje de datos de consulta asignados al entrenamiento de archivos TFRecord. El valor predeterminado es 1 o 100%.
  • testingPercentage: El porcentaje de datos de consulta asignados para probar archivos TFRecord. El valor predeterminado es 0 o 0%.
  • validationPercentage: El porcentaje de datos de consulta asignados a los archivos TFRecord de validación. El valor predeterminado es 0 o 0%.

Ejecuta la plantilla

  1. Ve a la página Crear un trabajo a partir de una plantilla de Dataflow.
  2. Ir a Crear un trabajo a partir de una plantilla
  3. En el campo Nombre del trabajo, ingresa un nombre de trabajo único.
  4. Opcional: Para Extremo regional, selecciona un valor del menú desplegable. La región predeterminada es us-central1.

    Para obtener una lista de regiones en las que puedes ejecutar un trabajo de Dataflow, consulta Ubicaciones de Dataflow.

  5. En el menú desplegable Plantilla de Dataflow, selecciona the BigQuery to TFRecords template.
  6. En los campos de parámetros proporcionados, ingresa los valores de tus parámetros.
  7. Haga clic en Ejecutar trabajo.

En tu shell o terminal, ejecuta la plantilla:

gcloud dataflow jobs run JOB_NAME \
    --gcs-location gs://dataflow-templates-REGION_NAME/VERSION/Cloud_BigQuery_to_GCS_TensorFlow_Records \
    --region REGION_NAME \
    --parameters \
readQuery=READ_QUERY,\
outputDirectory=OUTPUT_DIRECTORY,\
trainingPercentage=TRAINING_PERCENTAGE,\
testingPercentage=TESTING_PERCENTAGE,\
validationPercentage=VALIDATION_PERCENTAGE,\
outputSuffix=OUTPUT_FILENAME_SUFFIX

Reemplaza lo siguiente:

  • JOB_NAME: Es el nombre del trabajo que elijas
  • VERSION: Es la versión de la plantilla que deseas usar.

    Puedes usar los siguientes valores:

    • latest para usar la última versión de la plantilla, que está disponible en la carpeta superior non-dated en el bucket gs://dataflow-templates-REGION_NAME/latest/
    • el nombre de la versión, como 2023-09-12-00_RC00, para usar una versión específica de la plantilla, que se puede encontrar anidada en la carpeta superior con fecha correspondiente en el bucket gs://dataflow-templates-REGION_NAME/
  • REGION_NAME: La región en la que deseas implementar tu trabajo de Dataflow, por ejemplo, us-central1
  • READ_QUERY: Es la consulta de BigQuery que se ejecutará.
  • OUTPUT_DIRECTORY: Es el prefijo de la ruta de acceso de Cloud Storage para los conjuntos de datos de salida
  • TRAINING_PERCENTAGE: Es la división de porcentaje decimal para el conjunto de datos de entrenamiento
  • TESTING_PERCENTAGE: Es la división de porcentaje decimal para el conjunto de datos de prueba
  • VALIDATION_PERCENTAGE: Es la división de porcentaje decimal para el conjunto de datos de validación
  • OUTPUT_FILENAME_SUFFIX: Es el sufijo de archivo de registro de salida de TensorFlow preferido.

Para ejecutar la plantilla con la API de REST, envía una solicitud POST HTTP. Para obtener más información de la API y sus permisos de autorización, consulta projects.templates.launch.

POST https://dataflow.googleapis.com/v1b3/projects/PROJECT_ID/locations/LOCATION/templates:launch?gcsPath=gs://dataflow-templates-LOCATION/VERSION/Cloud_BigQuery_to_GCS_TensorFlow_Records
{
   "jobName": "JOB_NAME",
   "parameters": {
       "readQuery":"READ_QUERY",
       "outputDirectory":"OUTPUT_DIRECTORY",
       "trainingPercentage":"TRAINING_PERCENTAGE",
       "testingPercentage":"TESTING_PERCENTAGE",
       "validationPercentage":"VALIDATION_PERCENTAGE",
       "outputSuffix":"OUTPUT_FILENAME_SUFFIX"
   },
   "environment": { "zone": "us-central1-f" }
}

Reemplaza lo siguiente:

  • PROJECT_ID: El ID del proyecto de Google Cloud en el que deseas ejecutar el trabajo de Dataflow.
  • JOB_NAME: Es el nombre del trabajo que elijas
  • VERSION: Es la versión de la plantilla que deseas usar.

    Puedes usar los siguientes valores:

    • latest para usar la última versión de la plantilla, que está disponible en la carpeta superior non-dated en el bucket gs://dataflow-templates-REGION_NAME/latest/
    • el nombre de la versión, como 2023-09-12-00_RC00, para usar una versión específica de la plantilla, que se puede encontrar anidada en la carpeta superior con fecha correspondiente en el bucket gs://dataflow-templates-REGION_NAME/
  • LOCATION: La región en la que deseas implementar tu trabajo de Dataflow, por ejemplo, us-central1
  • READ_QUERY: Es la consulta de BigQuery que se ejecutará.
  • OUTPUT_DIRECTORY: Es el prefijo de la ruta de acceso de Cloud Storage para los conjuntos de datos de salida
  • TRAINING_PERCENTAGE: Es la división de porcentaje decimal para el conjunto de datos de entrenamiento
  • TESTING_PERCENTAGE: Es la división de porcentaje decimal para el conjunto de datos de prueba
  • VALIDATION_PERCENTAGE: Es la división de porcentaje decimal para el conjunto de datos de validación
  • OUTPUT_FILENAME_SUFFIX: Es el sufijo de archivo de registro de salida de TensorFlow preferido.
Java
/*
 * Copyright (C) 2019 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.google.cloud.teleport.templates;

import com.google.api.services.bigquery.model.TableFieldSchema;
import com.google.cloud.teleport.metadata.Template;
import com.google.cloud.teleport.metadata.TemplateCategory;
import com.google.cloud.teleport.metadata.TemplateParameter;
import com.google.cloud.teleport.templates.BigQueryToTFRecord.Options;
import com.google.cloud.teleport.templates.common.BigQueryConverters.BigQueryReadOptions;
import com.google.protobuf.ByteString;
import java.util.Iterator;
import java.util.Random;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.io.TFRecordIO;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
import org.apache.beam.sdk.io.gcp.bigquery.SchemaAndRecord;
import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.transforms.Partition;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.tensorflow.proto.Example;
import org.tensorflow.proto.Feature;
import org.tensorflow.proto.Features;

/**
 * Dataflow template which reads BigQuery data and writes it to GCS as a set of TFRecords. The
 * source is a SQL query.
 *
 * <p>Check out <a
 * href="https://github.com/GoogleCloudPlatform/DataflowTemplates/blob/main/v1/README_Cloud_BigQuery_to_GCS_TensorFlow_Records.md">README</a>
 * for instructions on how to use or modify this template.
 */
@Template(
    name = "Cloud_BigQuery_to_GCS_TensorFlow_Records",
    category = TemplateCategory.BATCH,
    displayName = "BigQuery to TensorFlow Records",
    description =
        "The BigQuery to Cloud Storage TFRecords template is a pipeline that reads data from a BigQuery query and writes it to a Cloud Storage bucket in TFRecord format. "
            + "You can specify the training, testing, and validation percentage splits. "
            + "By default, the split is 1 or 100% for the training set and 0 or 0% for testing and validation sets. "
            + "When setting the dataset split, the sum of training, testing, and validation needs to add up to 1 or 100% (for example, 0.6+0.2+0.2). "
            + "Dataflow automatically determines the optimal number of shards for each output dataset.",
    optionsClass = Options.class,
    optionsOrder = {BigQueryReadOptions.class, Options.class},
    documentation =
        "https://cloud.google.com/dataflow/docs/guides/templates/provided/bigquery-to-tfrecords",
    contactInformation = "https://cloud.google.com/support",
    requirements = {
      "The BigQuery dataset and table must exist.",
      "The output Cloud Storage bucket must exist before pipeline execution. Training, testing, and validation subdirectories don't need to preexist and are autogenerated."
    })
public class BigQueryToTFRecord {

  /**
   * The {@link BigQueryToTFRecord#buildFeatureFromIterator(Class, Object, Feature.Builder)} method
   * handles {@link GenericData.Array} that are passed into the {@link
   * BigQueryToTFRecord#buildFeature} method creating a TensorFlow feature from the record.
   */
  private static final String TRAIN = "train/";

  private static final String TEST = "test/";
  private static final String VAL = "val/";

  private static void buildFeatureFromIterator(
      Class<?> fieldType, Object field, Feature.Builder feature) {
    ByteString byteString;
    GenericData.Array f = (GenericData.Array) field;
    if (fieldType == Long.class) {
      Iterator<Long> longIterator = f.iterator();
      while (longIterator.hasNext()) {
        Long longValue = longIterator.next();
        feature.getInt64ListBuilder().addValue(longValue);
      }
    } else if (fieldType == double.class) {
      Iterator<Double> doubleIterator = f.iterator();
      while (doubleIterator.hasNext()) {
        double doubleValue = doubleIterator.next();
        feature.getFloatListBuilder().addValue((float) doubleValue);
      }
    } else if (fieldType == String.class) {
      Iterator<Utf8> stringIterator = f.iterator();
      while (stringIterator.hasNext()) {
        String stringValue = stringIterator.next().toString();
        byteString = ByteString.copyFromUtf8(stringValue);
        feature.getBytesListBuilder().addValue(byteString);
      }
    } else if (fieldType == boolean.class) {
      Iterator<Boolean> booleanIterator = f.iterator();
      while (booleanIterator.hasNext()) {
        Boolean boolValue = booleanIterator.next();
        int boolAsInt = boolValue ? 1 : 0;
        feature.getInt64ListBuilder().addValue(boolAsInt);
      }
    }
  }

  /**
   * The {@link BigQueryToTFRecord#buildFeature} method takes in an individual field and type
   * corresponding to a column value from a SchemaAndRecord Object returned from a BigQueryIO.read()
   * step. The method builds a TensorFlow Feature based on the type of the object- ie: STRING, TIME,
   * INTEGER etc..
   */
  private static Feature buildFeature(Object field, String type) {
    Feature.Builder feature = Feature.newBuilder();
    ByteString byteString;

    switch (type) {
      case "STRING":
      case "TIME":
      case "DATE":
        if (field instanceof GenericData.Array) {
          buildFeatureFromIterator(String.class, field, feature);
        } else {
          byteString = ByteString.copyFromUtf8(field.toString());
          feature.getBytesListBuilder().addValue(byteString);
        }
        break;
      case "BYTES":
        byteString = ByteString.copyFrom((byte[]) field);
        feature.getBytesListBuilder().addValue(byteString);
        break;
      case "INTEGER":
      case "INT64":
      case "TIMESTAMP":
        if (field instanceof GenericData.Array) {
          buildFeatureFromIterator(Long.class, field, feature);
        } else {
          feature.getInt64ListBuilder().addValue((long) field);
        }
        break;
      case "FLOAT":
      case "FLOAT64":
        if (field instanceof GenericData.Array) {
          buildFeatureFromIterator(double.class, field, feature);
        } else {
          feature.getFloatListBuilder().addValue((float) (double) field);
        }
        break;
      case "BOOLEAN":
      case "BOOL":
        if (field instanceof GenericData.Array) {
          buildFeatureFromIterator(boolean.class, field, feature);
        } else {
          int boolAsInt = (boolean) field ? 1 : 0;
          feature.getInt64ListBuilder().addValue(boolAsInt);
        }
        break;
      default:
        throw new RuntimeException("Unsupported type: " + type);
    }
    return feature.build();
  }

  /**
   * The {@link BigQueryToTFRecord#record2Example(SchemaAndRecord)} method uses takes in a
   * SchemaAndRecord Object returned from a BigQueryIO.read() step and builds a TensorFlow Example
   * from the record.
   */
  @VisibleForTesting
  protected static byte[] record2Example(SchemaAndRecord schemaAndRecord) {
    Example.Builder example = Example.newBuilder();
    Features.Builder features = example.getFeaturesBuilder();
    GenericRecord record = schemaAndRecord.getRecord();
    for (TableFieldSchema field : schemaAndRecord.getTableSchema().getFields()) {
      Object fieldValue = record.get(field.getName());
      if (fieldValue != null) {
        Feature feature = buildFeature(fieldValue, field.getType());
        features.putFeature(field.getName(), feature);
      }
    }
    return example.build().toByteArray();
  }

  /**
   * The {@link BigQueryToTFRecord#concatURI} method uses takes in a Cloud Storage URI and a
   * subdirectory name and safely concatenates them. The resulting String is used as a sink for
   * TFRecords.
   */
  private static String concatURI(String dir, String folder) {
    if (dir.endsWith("/")) {
      return dir + folder;
    } else {
      return dir + "/" + folder;
    }
  }

  /**
   * The {@link BigQueryToTFRecord#applyTrainTestValSplit} method transforms the PCollection by
   * randomly partitioning it into PCollections for each dataset.
   */
  static PCollectionList<byte[]> applyTrainTestValSplit(
      PCollection<byte[]> input,
      ValueProvider<Float> trainingPercentage,
      ValueProvider<Float> testingPercentage,
      ValueProvider<Float> validationPercentage,
      Random rand) {
    return input.apply(
        Partition.of(
            3,
            (Partition.PartitionFn<byte[]>)
                (number, numPartitions) -> {
                  Float train = trainingPercentage.get();
                  Float test = testingPercentage.get();
                  Float validation = validationPercentage.get();
                  Double d = rand.nextDouble();
                  if (train + test + validation != 1) {
                    throw new RuntimeException(
                        String.format(
                            "Train %.2f, Test %.2f, Validation"
                                + " %.2f percentages must add up to 100 percent",
                            train, test, validation));
                  }
                  if (d < train) {
                    return 0;
                  } else if (d >= train && d < train + test) {
                    return 1;
                  } else {
                    return 2;
                  }
                }));
  }

  /** Run the pipeline. */
  public static void main(String[] args) {
    Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class);
    run(options);
  }

  /**
   * Runs the pipeline to completion with the specified options. This method does not wait until the
   * pipeline is finished before returning. Invoke {@code result.waitUntilFinish()} on the result
   * object to block until the pipeline is finished running if blocking programmatic execution is
   * required.
   *
   * @param options The execution options.
   * @return The pipeline result.
   */
  public static PipelineResult run(Options options) {
    Random rand = new Random(100); // set random seed
    Pipeline pipeline = Pipeline.create(options);

    PCollection<byte[]> bigQueryToExamples =
        pipeline
            .apply(
                "RecordToExample",
                BigQueryIO.read(BigQueryToTFRecord::record2Example)
                    .fromQuery(options.getReadQuery())
                    .withCoder(ByteArrayCoder.of())
                    .withTemplateCompatibility()
                    .withoutValidation()
                    .usingStandardSql()
                    .withMethod(BigQueryIO.TypedRead.Method.DIRECT_READ)
                // Enable BigQuery Storage API
                )
            .apply("ReshuffleResults", Reshuffle.viaRandomKey());

    PCollectionList<byte[]> partitionedExamples =
        applyTrainTestValSplit(
            bigQueryToExamples,
            options.getTrainingPercentage(),
            options.getTestingPercentage(),
            options.getValidationPercentage(),
            rand);

    partitionedExamples
        .get(0)
        .apply(
            "WriteTFTrainingRecord",
            FileIO.<byte[]>write()
                .via(TFRecordIO.sink())
                .to(
                    ValueProvider.NestedValueProvider.of(
                        options.getOutputDirectory(), dir -> concatURI(dir, TRAIN)))
                .withNumShards(0)
                .withSuffix(options.getOutputSuffix()));

    partitionedExamples
        .get(1)
        .apply(
            "WriteTFTestingRecord",
            FileIO.<byte[]>write()
                .via(TFRecordIO.sink())
                .to(
                    ValueProvider.NestedValueProvider.of(
                        options.getOutputDirectory(), dir -> concatURI(dir, TEST)))
                .withNumShards(0)
                .withSuffix(options.getOutputSuffix()));

    partitionedExamples
        .get(2)
        .apply(
            "WriteTFValidationRecord",
            FileIO.<byte[]>write()
                .via(TFRecordIO.sink())
                .to(
                    ValueProvider.NestedValueProvider.of(
                        options.getOutputDirectory(), dir -> concatURI(dir, VAL)))
                .withNumShards(0)
                .withSuffix(options.getOutputSuffix()));

    return pipeline.run();
  }

  /** Define command line arguments. */
  public interface Options extends BigQueryReadOptions {

    @TemplateParameter.GcsWriteFolder(
        order = 1,
        groupName = "Target",
        description = "Output Cloud Storage directory.",
        helpText =
            "The top-level Cloud Storage path prefix to use when writing the training, testing, and validation TFRecord files. Subdirectories for resulting training, testing, and validation TFRecord files are automatically generated from `outputDirectory`.",
        example = "gs://mybucket/output")
    ValueProvider<String> getOutputDirectory();

    void setOutputDirectory(ValueProvider<String> outputDirectory);

    @TemplateParameter.Text(
        order = 2,
        groupName = "Target",
        optional = true,
        regexes = {"^[A-Za-z_0-9.]*"},
        description = "The output suffix for TFRecord files",
        helpText =
            "The file suffix for the training, testing, and validation TFRecord files that are written. The default value is `.tfrecord`.")
    @Default.String(".tfrecord")
    ValueProvider<String> getOutputSuffix();

    void setOutputSuffix(ValueProvider<String> outputSuffix);

    @TemplateParameter.Float(
        order = 3,
        optional = true,
        description = "Percentage of data to be in the training set ",
        helpText =
            "The percentage of query data allocated to training TFRecord files. The default value is `1`, or `100%`.")
    @Default.Float(1)
    ValueProvider<Float> getTrainingPercentage();

    void setTrainingPercentage(ValueProvider<Float> trainingPercentage);

    @TemplateParameter.Float(
        order = 4,
        optional = true,
        description = "Percentage of data to be in the testing set ",
        helpText =
            "The percentage of query data allocated to testing TFRecord files. The default value is `0`, or `0%`.")
    @Default.Float(0)
    ValueProvider<Float> getTestingPercentage();

    void setTestingPercentage(ValueProvider<Float> testingPercentage);

    @TemplateParameter.Float(
        order = 5,
        optional = true,
        description = "Percentage of data to be in the validation set ",
        helpText =
            "The percentage of query data allocated to validation TFRecord files. The default value is `0`, or `0%`.")
    @Default.Float(0)
    ValueProvider<Float> getValidationPercentage();

    void setValidationPercentage(ValueProvider<Float> validationPercentage);
  }
}

¿Qué sigue?