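BigQuery to Cloud Storage TFRecords template

The BigQuery to Cloud Storage TFRecords template is a pipeline that reads data from a BigQuery query and writes it to a Cloud Storage bucket in TFRecord format. You can specify the training, testing, and validation percentage splits. By default, the split is 1 (100%) for the training set and 0 (0%) for the testing and validation sets. When you set the dataset split, the training, testing, and validation values must add up to 1, or 100% (for example, 0.6 + 0.2 + 0.2). Dataflow automatically determines the optimal number of shards for each output dataset.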
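
The split rule can be illustrated with a small standalone sketch. It is not the template's own code: the class name SplitSketch, its method names, and the rounding tolerance EPSILON are assumptions made for illustration. It checks that three fractions add up to 1 and then maps a uniform random draw to one of the three datasets, which is roughly how a random partition by fractions can work.

```java
import java.util.Random;

/** Hypothetical helper illustrating the split rule; not part of the template. */
final class SplitSketch {

  /** Assumed tolerance for floating-point rounding when summing the fractions. */
  static final double EPSILON = 1e-6;

  /** Returns true if the fractions are non-negative and add up to 1 within EPSILON. */
  static boolean isValidSplit(double train, double test, double validation) {
    return train >= 0
        && test >= 0
        && validation >= 0
        && Math.abs(train + test + validation - 1.0) <= EPSILON;
  }

  /**
   * Maps a uniform draw d in [0, 1) to a partition index:
   * 0 = training, 1 = testing, 2 = validation.
   */
  static int partitionFor(double d, double train, double test) {
    if (d < train) {
      return 0;
    } else if (d < train + test) {
      return 1;
    }
    return 2;
  }

  public static void main(String[] args) {
    double train = 0.6;
    double test = 0.2;
    double validation = 0.2;
    if (!isValidSplit(train, test, validation)) {
      throw new IllegalArgumentException("Split fractions must add up to 1");
    }
    Random rand = new Random(100); // fixed seed, so repeated runs give the same counts
    int[] counts = new int[3];
    for (int i = 0; i < 10_000; i++) {
      counts[partitionFor(rand.nextDouble(), train, test)]++;
    }
    System.out.printf("train=%d test=%d validation=%d%n", counts[0], counts[1], counts[2]);
  }
}
```

Because each record is assigned by an independent random draw, the resulting dataset sizes only approximate the requested fractions; they are not exact.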

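Pipeline requirements

  • The BigQuery dataset and table must exist.
  • The output Cloud Storage bucket must exist before pipeline execution. The training, testing, and validation subdirectories don't need to preexist; they are generated automatically.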
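
As a rough illustration of how the subdirectories relate to the output path, the sketch below joins a top-level prefix with the three subdirectory names that appear in the template source (train/, test/, and val/). The class name OutputPathSketch and the method name concat are hypothetical; the joining logic follows the same idea as the concatURI helper in the template source.

```java
/** Hypothetical helper showing how output prefixes can be derived; not part of the template. */
final class OutputPathSketch {

  /** Joins a directory and a subdirectory name, adding a slash only when dir lacks one. */
  static String concat(String dir, String folder) {
    return dir.endsWith("/") ? dir + folder : dir + "/" + folder;
  }

  public static void main(String[] args) {
    String outputDirectory = "gs://mybucket/output";
    // Subdirectory names taken from the constants in the template source.
    for (String split : new String[] {"train/", "test/", "val/"}) {
      System.out.println(concat(outputDirectory, split));
    }
  }
}
```

With or without a trailing slash on the output directory, the resulting prefix contains a single slash between the directory and the subdirectory name.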

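Template parameters

Required parameters

  • readQuery: The BigQuery SQL query that extracts data from the source. For example, select * from dataset1.sample_table
  • outputDirectory: The top-level Cloud Storage path prefix to use when writing the training, testing, and validation TFRecord files. Subdirectories for the resulting training, testing, and validation TFRecord files are generated automatically from outputDirectory. For example, gs://mybucket/output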

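Optional parameters

  • readIdColumn: The name of the BigQuery column that stores the unique identifier of the row.
  • invalidOutputPath: The Cloud Storage path where BigQuery rows that cannot be converted to target entities are written. For example, gs://your-bucket/your-path
  • outputSuffix: The file suffix for the training, testing, and validation TFRecord files that are written. The default value is .tfrecord
  • trainingPercentage: The percentage of query data allocated to training TFRecord files. The default value is 1, or 100%.
  • testingPercentage: The percentage of query data allocated to testing TFRecord files. The default value is 0, or 0%.
  • validationPercentage: The percentage of query data allocated to validation TFRecord files. The default value is 0, or 0%.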

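Run the template

  1. Go to the Dataflow Create job from template page.
  2. In the Job name field, enter a unique job name.
  3. Optional: For Regional endpoint, select a value from the drop-down menu. The default region is us-central1.

    For a list of regions where you can run a Dataflow job, see Dataflow locations.

  4. From the Dataflow template drop-down menu, select the BigQuery to TFRecords template.
  5. In the provided parameter fields, enter your parameter values.
  6. Click Run job.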

In your shell or terminal, run the template:

gcloud dataflow jobs run JOB_NAME \
    --gcs-location gs://dataflow-templates-REGION_NAME/VERSION/Cloud_BigQuery_to_GCS_TensorFlow_Records \
    --region REGION_NAME \
    --parameters \
readQuery=READ_QUERY,\
outputDirectory=OUTPUT_DIRECTORY,\
trainingPercentage=TRAINING_PERCENTAGE,\
testingPercentage=TESTING_PERCENTAGE,\
validationPercentage=VALIDATION_PERCENTAGE,\
outputSuffix=OUTPUT_FILENAME_SUFFIX

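Replace the following:

  • JOB_NAME: a unique job name of your choice
  • VERSION: the version of the template that you want to use
  • REGION_NAME: the region where you want to deploy your Dataflow job, for example us-central1
  • READ_QUERY: the BigQuery query to run
  • OUTPUT_DIRECTORY: the Cloud Storage path prefix for the output datasets
  • TRAINING_PERCENTAGE: the share of the data to allocate to the training dataset, as a decimal fraction
  • TESTING_PERCENTAGE: the share of the data to allocate to the testing dataset, as a decimal fraction
  • VALIDATION_PERCENTAGE: the share of the data to allocate to the validation dataset, as a decimal fraction
  • OUTPUT_FILENAME_SUFFIX: the preferred file suffix for the output TensorFlow Record files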
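
When the job is launched from a program rather than typed by hand, the same command can be assembled as an argument list. The following is a minimal sketch under stated assumptions: the class GcloudCommandSketch and its methods are hypothetical, the flags are the ones shown in the command above, and parameter values that themselves contain commas are not handled.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;

/** Hypothetical helper that assembles the gcloud arguments shown above. */
final class GcloudCommandSketch {

  /** Joins parameters as key=value pairs separated by commas, in insertion order. */
  static String joinParameters(Map<String, String> parameters) {
    StringJoiner joiner = new StringJoiner(",");
    parameters.forEach((key, value) -> joiner.add(key + "=" + value));
    return joiner.toString();
  }

  /** Builds the argument list for "gcloud dataflow jobs run". */
  static List<String> buildCommand(
      String jobName, String region, String version, Map<String, String> parameters) {
    List<String> command = new ArrayList<>();
    command.add("gcloud");
    command.add("dataflow");
    command.add("jobs");
    command.add("run");
    command.add(jobName);
    command.add("--gcs-location");
    command.add(
        "gs://dataflow-templates-" + region + "/" + version
            + "/Cloud_BigQuery_to_GCS_TensorFlow_Records");
    command.add("--region");
    command.add(region);
    command.add("--parameters");
    command.add(joinParameters(parameters));
    return command;
  }

  public static void main(String[] args) {
    Map<String, String> parameters = new LinkedHashMap<>();
    parameters.put("readQuery", "select * from dataset1.sample_table");
    parameters.put("outputDirectory", "gs://mybucket/output");
    parameters.put("trainingPercentage", "0.6");
    parameters.put("testingPercentage", "0.2");
    parameters.put("validationPercentage", "0.2");
    // JOB_NAME and VERSION are left as placeholders, as in the command above.
    List<String> command = buildCommand("JOB_NAME", "us-central1", "VERSION", parameters);
    // Printed for inspection only; the joined string is not shell-quoted.
    System.out.println(String.join(" ", command));
  }
}
```

Passing the list to java.lang.ProcessBuilder keeps each argument separate, so the spaces in the query do not need shell quoting.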

To run the template using the REST API, send an HTTP POST request. For more information on the API and its authorization scopes, see projects.templates.launch.

POST https://dataflow.googleapis.com/v1b3/projects/PROJECT_ID/locations/LOCATION/templates:launch?gcsPath=gs://dataflow-templates-LOCATION/VERSION/Cloud_BigQuery_to_GCS_TensorFlow_Records
{
   "jobName": "JOB_NAME",
   "parameters": {
       "readQuery":"READ_QUERY",
       "outputDirectory":"OUTPUT_DIRECTORY",
       "trainingPercentage":"TRAINING_PERCENTAGE",
       "testingPercentage":"TESTING_PERCENTAGE",
       "validationPercentage":"VALIDATION_PERCENTAGE",
       "outputSuffix":"OUTPUT_FILENAME_SUFFIX"
   },
   "environment": { "zone": "us-central1-f" }
}

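Replace the following:

  • PROJECT_ID: the ID of the Google Cloud project where you want to run the Dataflow job
  • JOB_NAME: a unique job name of your choice
  • VERSION: the version of the template that you want to use
  • LOCATION: the region where you want to deploy your Dataflow job, for example us-central1
  • READ_QUERY: the BigQuery query to run
  • OUTPUT_DIRECTORY: the Cloud Storage path prefix for the output datasets
  • TRAINING_PERCENTAGE: the share of the data to allocate to the training dataset, as a decimal fraction
  • TESTING_PERCENTAGE: the share of the data to allocate to the testing dataset, as a decimal fraction
  • VALIDATION_PERCENTAGE: the share of the data to allocate to the validation dataset, as a decimal fraction
  • OUTPUT_FILENAME_SUFFIX: the preferred file suffix for the output TensorFlow Record files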
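
The request URL and JSON body can also be built programmatically before being sent with an HTTP client of your choice. The sketch below only constructs the two strings. The class LaunchRequestSketch and its methods are hypothetical, the optional environment block is omitted, authorization is out of scope, and the escape helper is a simplification that handles only backslashes and double quotes.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.StringJoiner;

/** Hypothetical helper that builds the launch URL and request body shown above. */
final class LaunchRequestSketch {

  /** Builds the templates:launch URL for this template. */
  static String launchUrl(String projectId, String location, String version) {
    return "https://dataflow.googleapis.com/v1b3/projects/" + projectId
        + "/locations/" + location
        + "/templates:launch?gcsPath=gs://dataflow-templates-" + location + "/" + version
        + "/Cloud_BigQuery_to_GCS_TensorFlow_Records";
  }

  /** Simplified JSON string escaping: backslashes and double quotes only. */
  static String escape(String s) {
    return s.replace("\\", "\\\\").replace("\"", "\\\"");
  }

  /** Builds a JSON body with jobName and parameters, keeping insertion order. */
  static String launchBody(String jobName, Map<String, String> parameters) {
    StringJoiner params = new StringJoiner(",", "{", "}");
    parameters.forEach(
        (key, value) -> params.add("\"" + escape(key) + "\":\"" + escape(value) + "\""));
    return "{\"jobName\":\"" + escape(jobName) + "\",\"parameters\":" + params + "}";
  }

  public static void main(String[] args) {
    Map<String, String> parameters = new LinkedHashMap<>();
    parameters.put("readQuery", "select * from dataset1.sample_table");
    parameters.put("outputDirectory", "gs://mybucket/output");
    parameters.put("trainingPercentage", "0.6");
    parameters.put("testingPercentage", "0.2");
    parameters.put("validationPercentage", "0.2");
    // PROJECT_ID, VERSION, and JOB_NAME are left as placeholders, as in the request above.
    System.out.println(launchUrl("PROJECT_ID", "us-central1", "VERSION"));
    System.out.println(launchBody("JOB_NAME", parameters));
  }
}
```

All parameter values are sent as JSON strings, including the split fractions, which matches the quoting used in the request body above.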

Template source code (Java)

/*
 * Copyright (C) 2019 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package com.google.cloud.teleport.templates;

import com.google.api.services.bigquery.model.TableFieldSchema;
import com.google.cloud.teleport.metadata.Template;
import com.google.cloud.teleport.metadata.TemplateCategory;
import com.google.cloud.teleport.metadata.TemplateParameter;
import com.google.cloud.teleport.templates.BigQueryToTFRecord.Options;
import com.google.cloud.teleport.templates.common.BigQueryConverters.BigQueryReadOptions;
import com.google.protobuf.ByteString;
import java.util.Iterator;
import java.util.Random;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.io.TFRecordIO;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
import org.apache.beam.sdk.io.gcp.bigquery.SchemaAndRecord;
import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.transforms.Partition;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.tensorflow.example.Example;
import org.tensorflow.example.Feature;
import org.tensorflow.example.Features;

/**
 * Dataflow template which reads BigQuery data and writes it to GCS as a set of TFRecords. The
 * source is a SQL query.
 *
 * <p>Check out <a
 * href="https://github.com/GoogleCloudPlatform/DataflowTemplates/blob/main/v1/README_Cloud_BigQuery_to_GCS_TensorFlow_Records.md">README</a>
 * for instructions on how to use or modify this template.
 */
@Template(
    name = "Cloud_BigQuery_to_GCS_TensorFlow_Records",
    category = TemplateCategory.BATCH,
    displayName = "BigQuery to TensorFlow Records",
    description =
        "The BigQuery to Cloud Storage TFRecords template is a pipeline that reads data from a BigQuery query and writes it to a Cloud Storage bucket in TFRecord format. "
            + "You can specify the training, testing, and validation percentage splits. "
            + "By default, the split is 1 or 100% for the training set and 0 or 0% for testing and validation sets. "
            + "When setting the dataset split, the sum of training, testing, and validation needs to add up to 1 or 100% (for example, 0.6+0.2+0.2). "
            + "Dataflow automatically determines the optimal number of shards for each output dataset.",
    optionsClass = Options.class,
    optionsOrder = {BigQueryReadOptions.class, Options.class},
    documentation =
        "https://cloud.google.com/dataflow/docs/guides/templates/provided/bigquery-to-tfrecords",
    contactInformation = "https://cloud.google.com/support",
    requirements = {
      "The BigQuery dataset and table must exist.",
      "The output Cloud Storage bucket must exist before pipeline execution. Training, testing, and validation subdirectories don't need to preexist and are autogenerated."
    })
public class BigQueryToTFRecord {

  private static final String TRAIN = "train/";
  private static final String TEST = "test/";
  private static final String VAL = "val/";

  /**
   * The {@link BigQueryToTFRecord#buildFeatureFromIterator(Class, Object, Feature.Builder)} method
   * handles {@link GenericData.Array} values that are passed into the {@link
   * BigQueryToTFRecord#buildFeature} method, adding each array element to the TensorFlow feature.
   */
  private static void buildFeatureFromIterator(
      Class<?> fieldType, Object field, Feature.Builder feature) {
    ByteString byteString;
    GenericData.Array f = (GenericData.Array) field;
    if (fieldType == Long.class) {
      Iterator<Long> longIterator = f.iterator();
      while (longIterator.hasNext()) {
        Long longValue = longIterator.next();
        feature.getInt64ListBuilder().addValue(longValue);
      }
    } else if (fieldType == double.class) {
      Iterator<Double> doubleIterator = f.iterator();
      while (doubleIterator.hasNext()) {
        double doubleValue = doubleIterator.next();
        feature.getFloatListBuilder().addValue((float) doubleValue);
      }
    } else if (fieldType == String.class) {
      Iterator<Utf8> stringIterator = f.iterator();
      while (stringIterator.hasNext()) {
        String stringValue = stringIterator.next().toString();
        byteString = ByteString.copyFromUtf8(stringValue);
        feature.getBytesListBuilder().addValue(byteString);
      }
    } else if (fieldType == boolean.class) {
      Iterator<Boolean> booleanIterator = f.iterator();
      while (booleanIterator.hasNext()) {
        Boolean boolValue = booleanIterator.next();
        int boolAsInt = boolValue ? 1 : 0;
        feature.getInt64ListBuilder().addValue(boolAsInt);
      }
    }
  }

  /**
   * The {@link BigQueryToTFRecord#buildFeature} method takes in an individual field and type
   * corresponding to a column value from a SchemaAndRecord Object returned from a BigQueryIO.read()
   * step. The method builds a TensorFlow Feature based on the type of the object, for example
   * STRING, TIME, or INTEGER.
   */
  private static Feature buildFeature(Object field, String type) {
    Feature.Builder feature = Feature.newBuilder();
    ByteString byteString;

    switch (type) {
      case "STRING":
      case "TIME":
      case "DATE":
        if (field instanceof GenericData.Array) {
          buildFeatureFromIterator(String.class, field, feature);
        } else {
          byteString = ByteString.copyFromUtf8(field.toString());
          feature.getBytesListBuilder().addValue(byteString);
        }
        break;
      case "BYTES":
        byteString = ByteString.copyFrom((byte[]) field);
        feature.getBytesListBuilder().addValue(byteString);
        break;
      case "INTEGER":
      case "INT64":
      case "TIMESTAMP":
        if (field instanceof GenericData.Array) {
          buildFeatureFromIterator(Long.class, field, feature);
        } else {
          feature.getInt64ListBuilder().addValue((long) field);
        }
        break;
      case "FLOAT":
      case "FLOAT64":
        if (field instanceof GenericData.Array) {
          buildFeatureFromIterator(double.class, field, feature);
        } else {
          feature.getFloatListBuilder().addValue((float) (double) field);
        }
        break;
      case "BOOLEAN":
      case "BOOL":
        if (field instanceof GenericData.Array) {
          buildFeatureFromIterator(boolean.class, field, feature);
        } else {
          int boolAsInt = (boolean) field ? 1 : 0;
          feature.getInt64ListBuilder().addValue(boolAsInt);
        }
        break;
      default:
        throw new RuntimeException("Unsupported type: " + type);
    }
    return feature.build();
  }

  /**
   * The {@link BigQueryToTFRecord#record2Example(SchemaAndRecord)} method takes in a
   * SchemaAndRecord Object returned from a BigQueryIO.read() step and builds a serialized
   * TensorFlow Example from the record. Fields with null values are skipped.
   */
  @VisibleForTesting
  protected static byte[] record2Example(SchemaAndRecord schemaAndRecord) {
    Example.Builder example = Example.newBuilder();
    Features.Builder features = example.getFeaturesBuilder();
    GenericRecord record = schemaAndRecord.getRecord();
    for (TableFieldSchema field : schemaAndRecord.getTableSchema().getFields()) {
      Object fieldValue = record.get(field.getName());
      if (fieldValue != null) {
        Feature feature = buildFeature(fieldValue, field.getType());
        features.putFeature(field.getName(), feature);
      }
    }
    return example.build().toByteArray();
  }

  /**
   * The {@link BigQueryToTFRecord#concatURI} method takes in a Cloud Storage URI and a
   * subdirectory name and safely concatenates them. The resulting String is used as a sink for
   * TFRecords.
   */
  private static String concatURI(String dir, String folder) {
    if (dir.endsWith("/")) {
      return dir + folder;
    } else {
      return dir + "/" + folder;
    }
  }

  /**
   * The {@link BigQueryToTFRecord#applyTrainTestValSplit} method transforms the PCollection by
   * randomly partitioning it into PCollections for each dataset.
   */
  static PCollectionList<byte[]> applyTrainTestValSplit(
      PCollection<byte[]> input,
      ValueProvider<Float> trainingPercentage,
      ValueProvider<Float> testingPercentage,
      ValueProvider<Float> validationPercentage,
      Random rand) {
    return input.apply(
        Partition.of(
            3,
            (Partition.PartitionFn<byte[]>)
                (number, numPartitions) -> {
                  Float train = trainingPercentage.get();
                  Float test = testingPercentage.get();
                  Float validation = validationPercentage.get();
                  Double d = rand.nextDouble();
                  if (train + test + validation != 1) {
                    throw new RuntimeException(
                        String.format(
                            "Train %.2f, Test %.2f, Validation"
                                + " %.2f percentages must add up to 100 percent",
                            train, test, validation));
                  }
                  if (d < train) {
                    return 0;
                  } else if (d >= train && d < train + test) {
                    return 1;
                  } else {
                    return 2;
                  }
                }));
  }

  /** Run the pipeline. */
  public static void main(String[] args) {
    Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class);
    run(options);
  }

  /**
   * Starts the pipeline with the specified options. This method does not wait until the pipeline
   * is finished before returning. Invoke {@code result.waitUntilFinish()} on the result object to
   * block until the pipeline is finished running if blocking programmatic execution is required.
   *
   * @param options The execution options.
   * @return The pipeline result.
   */
  public static PipelineResult run(Options options) {
    Random rand = new Random(100); // set random seed
    Pipeline pipeline = Pipeline.create(options);

    PCollection<byte[]> bigQueryToExamples =
        pipeline
            .apply(
                "RecordToExample",
                BigQueryIO.read(BigQueryToTFRecord::record2Example)
                    .fromQuery(options.getReadQuery())
                    .withCoder(ByteArrayCoder.of())
                    .withTemplateCompatibility()
                    .withoutValidation()
                    .usingStandardSql()
                    .withMethod(BigQueryIO.TypedRead.Method.DIRECT_READ)
                // Enable BigQuery Storage API
                )
            .apply("ReshuffleResults", Reshuffle.viaRandomKey());

    PCollectionList<byte[]> partitionedExamples =
        applyTrainTestValSplit(
            bigQueryToExamples,
            options.getTrainingPercentage(),
            options.getTestingPercentage(),
            options.getValidationPercentage(),
            rand);

    partitionedExamples
        .get(0)
        .apply(
            "WriteTFTrainingRecord",
            FileIO.<byte[]>write()
                .via(TFRecordIO.sink())
                .to(
                    ValueProvider.NestedValueProvider.of(
                        options.getOutputDirectory(), dir -> concatURI(dir, TRAIN)))
                .withNumShards(0)
                .withSuffix(options.getOutputSuffix()));

    partitionedExamples
        .get(1)
        .apply(
            "WriteTFTestingRecord",
            FileIO.<byte[]>write()
                .via(TFRecordIO.sink())
                .to(
                    ValueProvider.NestedValueProvider.of(
                        options.getOutputDirectory(), dir -> concatURI(dir, TEST)))
                .withNumShards(0)
                .withSuffix(options.getOutputSuffix()));

    partitionedExamples
        .get(2)
        .apply(
            "WriteTFValidationRecord",
            FileIO.<byte[]>write()
                .via(TFRecordIO.sink())
                .to(
                    ValueProvider.NestedValueProvider.of(
                        options.getOutputDirectory(), dir -> concatURI(dir, VAL)))
                .withNumShards(0)
                .withSuffix(options.getOutputSuffix()));

    return pipeline.run();
  }

  /** Define command line arguments. */
  public interface Options extends BigQueryReadOptions {

    @TemplateParameter.GcsWriteFolder(
        order = 1,
        groupName = "Target",
        description = "Output Cloud Storage directory.",
        helpText =
            "The top-level Cloud Storage path prefix to use when writing the training, testing, and validation TFRecord files. Subdirectories for resulting training, testing, and validation TFRecord files are automatically generated from `outputDirectory`.",
        example = "gs://mybucket/output")
    ValueProvider<String> getOutputDirectory();

    void setOutputDirectory(ValueProvider<String> outputDirectory);

    @TemplateParameter.Text(
        order = 2,
        groupName = "Target",
        optional = true,
        regexes = {"^[A-Za-z_0-9.]*"},
        description = "The output suffix for TFRecord files",
        helpText =
            "The file suffix for the training, testing, and validation TFRecord files that are written. The default value is `.tfrecord`.")
    @Default.String(".tfrecord")
    ValueProvider<String> getOutputSuffix();

    void setOutputSuffix(ValueProvider<String> outputSuffix);

    @TemplateParameter.Float(
        order = 3,
        optional = true,
        description = "Percentage of data to be in the training set ",
        helpText =
            "The percentage of query data allocated to training TFRecord files. The default value is `1`, or `100%`.")
    @Default.Float(1)
    ValueProvider<Float> getTrainingPercentage();

    void setTrainingPercentage(ValueProvider<Float> trainingPercentage);

    @TemplateParameter.Float(
        order = 4,
        optional = true,
        description = "Percentage of data to be in the testing set ",
        helpText =
            "The percentage of query data allocated to testing TFRecord files. The default value is `0`, or `0%`.")
    @Default.Float(0)
    ValueProvider<Float> getTestingPercentage();

    void setTestingPercentage(ValueProvider<Float> testingPercentage);

    @TemplateParameter.Float(
        order = 5,
        optional = true,
        description = "Percentage of data to be in the validation set ",
        helpText =
            "The percentage of query data allocated to validation TFRecord files. The default value is `0`, or `0%`.")
    @Default.Float(0)
    ValueProvider<Float> getValidationPercentage();

    void setValidationPercentage(ValueProvider<Float> validationPercentage);
  }
}
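
The type handling in buildFeature above can be summarized as a lookup from BigQuery column type to the TensorFlow feature list that receives the value. The sketch below is a hypothetical summary (the class FeatureKindSketch and its method are not part of the template); the returned names are the field names of the TensorFlow Feature message.

```java
/** Hypothetical summary of the type mapping in buildFeature; not part of the template. */
final class FeatureKindSketch {

  /** Returns the TensorFlow Feature list that a BigQuery column type is written to. */
  static String featureListFor(String bigQueryType) {
    switch (bigQueryType) {
      case "STRING":
      case "TIME":
      case "DATE":
      case "BYTES":
        return "bytes_list";
      case "INTEGER":
      case "INT64":
      case "TIMESTAMP":
      case "BOOLEAN": // booleans are stored as 1 or 0
      case "BOOL":
        return "int64_list";
      case "FLOAT":
      case "FLOAT64": // values are narrowed to 32-bit floats
        return "float_list";
      default:
        throw new IllegalArgumentException("Unsupported type: " + bigQueryType);
    }
  }

  public static void main(String[] args) {
    for (String type : new String[] {"STRING", "INT64", "FLOAT64", "BOOL"}) {
      System.out.println(type + " -> " + featureListFor(type));
    }
  }
}
```

Any column type outside this list reaches the default branch of buildFeature, which throws an exception, so a query may need to cast such columns to a supported type first.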

What's next