Cloud TPU v5p 训练

Cloud TPU v5p 是 Google Cloud 的第五代 Cloud TPU 和 v4 TPU 的继任者。v5p 针对大规模训练进行了优化,是开发基础 LLM、扩散模型和生成式 AI 的领先平台。概括来讲,v5p 的性能高达 v4 的 2 倍,同时还在 Pod 中封装了 2 倍的 TPU(最大切片为 6k,而 v4 中为 3k),从而使 Pod 级的性能高达 v4 的 4 倍。它还以更高的时钟频率运行(1.75Ghz 与 1.05Ghz),增加了 SparseCore 以用于大规模嵌入,并将高带宽内存 (HBM) 容量增加到了原来的三倍。

Cloud TPU v5p 概念

如果您刚开始接触 Cloud TPU,请查看 TPU 文档首页

Cloud TPU 系统架构页面介绍了所有 Cloud TPU 版本的 Cloud TPU 概念(例如切片、主机和 TensorCore)及 Cloud TPU 系统架构。

每个 Cloud TPU 版本都需要特定的加速器类型来进行训练或推理。v5p 配置中介绍了这些加速器类型。

管理 TPU 资源

如需了解如何管理已加入队列的资源,请参阅管理 TPU已加入队列的资源用户指南,了解可用于管理 TPU 虚拟机的所有命令。

框架设置

本部分介绍使用 JAX 或 PyTorch 与 TPU v5p 进行模型训练的一般设置过程。

JAX 设置

如果您的切片形状超过 4 个条状标签,则一个切片中会有多个虚拟机。在这种情况下,您需要使用 --worker=all 标志,通过一条命令在所有 TPU 虚拟机上运行安装:

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='pip install "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

您可以运行以下命令来检查设备数量(此处显示的输出是使用 v5p-32 切片生成的)。此代码通过检查 JAX 能否看到 Cloud TPU TensorCore 并能够运行基本操作来测试所有内容是否已正确安装:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

输出将类似于以下内容:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() 显示给定切片中的条状标签总数。jax.local_device_count() 表示此切片中单个虚拟机可访问的芯片数量。

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

输出将类似于以下内容:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

使用 --node=all 在所有多切片工作器上运行该命令。

gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

试用本文档中的 JAX 教程,开始使用 JAX 进行 v5p 训练。

设置 PyTorch

PJRT 运行时是 v5p 唯一支持的运行时,PyTorch 2.1 及更高版本使用 PJRT 作为所有 TPU 版本的默认运行时。本部分介绍如何开始在具有 PyTorch/XLA 2.2.0 的 v5p Pod 上使用所有工作器的 PJRT。

安装依赖项

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='
sudo apt-get update
sudo apt-get install libopenblas-dev -y
pip3 install numpy
pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
'

使用带有 PJRT 的 Python 脚本验证安装,以显示可用的 TPU 设备(此处显示的输出是使用 v5p-32 切片生成的)。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} --zone ${ZONE} --worker=all \
--command='
PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"
'
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']

使用 --node=all 在所有多切片工作器上运行该命令。

gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \
--command='
PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"
'

试用本文档中的 PyTorch 教程,开始使用 PyTorch 进行 v5p 训练。

监控和分析

Cloud TPU v5p 支持使用与上一代 Cloud TPU 相同的方法进行监控和分析。您可以阅读使用 Cloud TPU 工具分析模型,详细了解如何分析和监控 Cloud TPU 虚拟机,以详细了解如何监控。

训练教程

本部分重点介绍单个切片的训练教程。您可以通过在 SSH 命令中添加 --node=all 标志来实现这些教程,使其适合多切片训练。如需了解详情和最佳实践,请参阅多切片简介

JAX 教程

Train Diffusion 2.1

本教程介绍如何使用 Cloud TPU v5p 上的 Pokémon 数据集通过 HuggingFace 训练稳定扩散模型。

Stable Diffusion 模型是一种潜在的文本到图像模型,可根据任何文本输入生成逼真的图像。如需了解详情,请参阅以下资源:

设置

  1. 创建环境变量:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
    

    命令标志说明

    变量 说明
    PROJECT_ID Google Cloud 项目名称
    ACCELERATOR_TYPE 请参阅您的 TPU 版本的 TPU 版本页面。
    ZONE 如需了解受支持的可用区,请参阅 TPU 区域和可用区文档。
    RUNTIME_VERSION 对于 v5p,请为 RUNTIME_VERSION 使用 v2-alpha-tpuv5。
    SERVICE_ACCOUNT 这是您服务帐号的地址,您可以在 Google Cloud 控制台 -> IAM -> 服务帐号中找到。例如:tpu-service-account@myprojectID.iam.gserviceaccount.com
    TPU_NAME 由用户指定的 TPU 文本 ID(在分配已排队的资源请求时创建)。
    QUEUED_RESOURCE_ID 已加入队列的资源请求的文本 ID,由用户分配。如需了解已加入队列的资源,请参阅已加入队列的资源文档。
    QUOTA_TYPE 可以是 reservedbest-effort。如果二者均未指定,则 QUOTA_TYPE 默认为 on-demand。 如需了解 Cloud TPU 支持的不同类型的配额,请参阅quotas
    VALID_UNTIL_DURATION 请求的有效时长。如需了解不同的有效时长,请参阅 已加入队列的资源
  2. 创建 TPU 资源

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --valid-until-duration ${VALID_UNTIL_DURATION} \
    --service-account ${SERVICE_ACCOUNT} \
    --${QUOTA_TYPE}
    

    加入队列的资源处于 ACTIVE 状态后,您就可以通过 SSH 连接到 TPU 虚拟机。运行以下命令来检查已加入队列的资源的状态:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
    --project ${PROJECT_ID} --zone ${ZONE}
    

    当已加入队列的资源处于 ACTIVE 状态时,输出将类似于以下内容:

    state: ACTIVE
    
  3. 安装 JAX 及其依赖项。

    # compatible with v5p: only jax version 0.4.19 and later \
    # jax 0.4.19 requires py 3.10 \
    
    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
    --command='pip install "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. 下载 HuggingFace 代码库和安装要求。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='git clone https://github.com/huggingface/diffusers.git && cd diffusers && pip install . && pip install tensorflow clu && pip install -U -r examples/text_to_image/requirements_flax.txt'
    
  5. 训练模型

    使用 4GB 的预映射缓冲区训练模型。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='export PATH=$PATH:$HOME/.local/bin && cd diffusers/examples/text_to_image && JAX_PLATFORMS=tpu,cpu python3 train_text_to_image_flax.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-1 --dataset_name=lambdalabs/pokemon-blip-captions --resolution=256 --center_crop --random_flip --train_batch_size=1 --mixed_precision=bf16 --max_train_steps=150 --learning_rate=1e-05 --max_grad_norm=1 --output_dir=sd-pokemon-model --from_pt'
    

清理

在会话结束时删除 TPU 和已加入队列的资源请求,或移除处于“FAILED”状态的已加入队列的资源请求。如需删除已加入队列的资源,请按照 2 个步骤先删除切片,然后删除已加入队列的资源请求:

   gcloud compute tpus tpu-vm delete ${TPU_NAME} --project=${PROJECT_ID}
   --zone=${ZONE} --quiet
   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID}
   --project ${PROJECT_ID} --zone ${ZONE} --quiet

或者,使用 --force 一步删除 Slice 和已加入队列的资源请求:

# With --force
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID}
--project ${PROJECT_ID} --zone ${ZONE} --quiet --force

基准测试结果

Stable Diffusion 训练脚本在 v5p-8、v5p-32 和 v5p-128 上运行。下表显示了吞吐量。

v5p-8

v5p-32

v5p-128

训练步骤

150

150

150

全局批量大小

32

64

64

吞吐量(样本数/秒)

12.10

8 月 18 日

19 月 10 日

MaxText

本教程介绍如何使用 Cloud TPU 上的合成数据集训练 MaxText 模型。

MaxText 是一种经过良好测试的高性能、可任意伸缩的开源 LLM,以纯 Python/JAX 为目标 Cloud TPU 编写而成。MaxText 为研究人员和开发者提供方便易用且可适应的工具,以推进自然语言处理 (NLP) 的研发前沿。

在运行本教程之前,您需要设置 Cloud TPU 环境

  1. 设置环境变量

    export PROJECT_ID=your_project_ID
    export TPU_NAME=your_tpu_name # user defined TPU name
    export ACCELERATOR_TYPE=v5p-256
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export RUN_NAME=your_experiment_run_name # user defined name for this run
    export GCS_BUCKET_NAME=your_bucket_name # Output cloud folder. Should start with gs://
    export MAXTEXT_OUTPUT_PATH=${GCS_BUCKET_NAME}/your_experiment_output_path
    export NUM_SLICES=1 # Update the value to a number >1 for Multislice.
    

    命令标志说明

    变量 说明
    PROJECT_ID Google Cloud 项目名称
    TPU_NAME 用户定义的 TPU 名称。
    ACCELERATOR_TYPE 请参阅您的 TPU 版本的 TPU 版本页面。
    ZONE 如需了解受支持的可用区,请参阅 TPU 区域和可用区文档。
    RUNTIME_VERSION 对于 v5p,请使用 v2-alpha-tpuv5 作为运行时版本。
    RUN_NAME 用户提供的实验运行名称。

    建议针对多切片进行可选设置:

    export NETWORK_NAME=your_network_name
    export FIREWALL_RULE_NAME=your_firewall_rule_name
    

    如果您正在运行多切片工作负载并希望获得最佳网络性能,请考虑创建最大传输单元 (MTU) 为 8896 字节的专用网络,并配置适当的防火墙规则。此步骤虽然是可选的,但可以显著提高性能,尤其是在通过数据中心网络 (DCN) 纵向扩容切片数时。请注意,创建网络需要项目的 compute.networks.create 权限。以下示例展示了如何创建专用网络和防火墙规则。

    创建专用网络:

    gcloud compute networks create ${NETWORK_NAME} \
    --mtu=8896 \
    --project=${PROJECT_ID} \
    --subnet-mode=auto \
    --bgp-routing-mode=regional
    

    创建防火墙规则:

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT_ID}
    
  2. 克隆 MaxText 代码库

    git clone https://github.com/google/maxtext.git
    
  3. 训练模型

    以下部分介绍了训练 MaxText 的两个选项。

    选项 1

    如果您希望使用脚本来管理整个工作流(从预配 Cloud TPU 和安装依赖项到运行模型和拆解资源),则可以使用 multihost_job.py

    cd maxtext && python3 multihost_job.py --PROJECT=${PROJECT_ID} --ZONE=${ZONE} \
    --NUM_SLICES=${NUM_SLICES} --TPU_TYPE=${ACCELERATOR_TYPE} \
    --VERSION=${RUNTIME_VERSION} --RUN_NAME=${RUN_NAME} #user defined run name \
    --BUCKET_NAME=${GCS_BUCKET_NAME} \ #used to store logs and configs
    --COMMAND="bash setup.sh && bash MaxText/configs/experimental/64b.sh RUN_NAME=${RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"
    

    启动脚本后,您应该会在日志中看到类似于以下内容的消息。输出消息中引用了日志位置。 TPU 预配完成后,点击第一个链接可访问所有工作器的日志。

    ------------------------------------
    
    multihost_job finished running, TPUs are starting up to run your job remotely.
    
    Logs for your job are displayed here:
    https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22_log%22%2529;?project=PROJECT_ID
    
    To see the output of a single host, you may edit the slice and worker
    number in the `log_file_path` property here:
    
    https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22RUN_NAME_log%22%2529%20AND%0Alabels.%22agent.googleapis.com%2Flog_file_path%22%3D%20%22%2FRUN_NAME%2Fmain_command_log_slice_0_worker_0%22;?project=PROJECT_ID
    
    When your job is finished, the main command log is in your Cloud Storage
    bucket:
    https://console.cloud.google.com/storage/browser/YOUR_BUCKET_NAME/RUN_NAME?project=PROJECT_ID
    
    View the status of the created TPUs using:
    gcloud compute tpus queued-resources list --filter=RUN_NAME
    --zone=ZONE --project=PROJECT_ID
    
选项 2

如需在已预配的 Cloud TPU 上多次运行训练脚本,请使用 multihost_runner.py 脚本来使用资源。

  1. 设置变量以创建 TPU。

    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your_queued_resource_id
    export VALID_DURATION=1d
    export QUOTA_TYPE=quota_type
    
    --node-count ${NODE_COUNT} \
    --node-prefix ${NODE_PREFIX} # optional, the default is QUEUED_RESOURCE_ID
    
  2. 创建 TPU 资源。

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --valid-until-duration ${VALID_DURATION} \
    --service-account ${SERVICE_ACCOUNT} \
    --${QUOTA_TYPE}
    

    QueuedResource 处于 ACTIVE 状态后,您可以使用 SSH 连接到您的 TPU 虚拟机:

    使用 describe 命令查询已加入队列的资源的状态。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}
    --project ${PROJECT_ID} --zone ${ZONE}
    

    当已加入队列的资源处于 ACTIVE 状态时,输出将类似于以下内容:

     state: ACTIVE
    
  3. 使用 SSH 连接到您的 TPU

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  4. 安装依赖项

    export TPU_NAME=your_tpu_name
    export MAXTEXT_OUTPUT_PATH=output-path
    
    cd maxtext && python3 multihost_runner.py --TPU_PREFIX=${TPU_NAME} \
    --COMMAND='bash setup.sh'
    
  5. 使用各种配置脚本(如 32b.sh、64b.sh)运行模型。如果您从 TPU 虚拟机运行脚本,则需要添加标志 --INTERNAL_IP=true

    python3 multihost_runner.py --TPU_PREFIX=${TPU_NAME} \
    --COMMAND="bash MaxText/configs/experimental/64b.sh RUN_NAME=${RUN_NAME}
    OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"
    

清理

删除 TPU 和排队的资源

基准测试结果

MaxText 训练脚本的运行规模介于 320 亿到 11600 亿之间,且精确率为 bf16。运行结果如下表所示。

参数数量

加速器类型

TFLOP/芯片/秒

模型 flops 利用率

(MFU)

320 亿

v5p-128

3.28E+02

71.47%

640 亿

v5p-128

3.23E+02

70.31%

1280 亿

v5p-256

3.15E+02

68.68%

1280 亿

v5p-512

3.15E+02

68.53%

2560 亿

v5p-1024

3.16E+02

68.82%

5,120 亿

v5p-1024

2.94E+02

63.99%

1,0240 亿

5p-2048

2.49E+02

64.05%

1,0240 亿

5p-4096

2.97E+02

64.80%

1,1600 亿

5p-7680

2.95E+02

64.27%

1,1600 亿

v5p-12288

3.04E+02

66.23%

256B 参数模型已使用 bf16 和 int8 权重在 v5p-512 和 v5p-1024 上进行了测试。下表显示了这些测试的结果。

v5p-512

v5p-512

v5p-1024

v5p-1024

全局批量大小

(词元)

5.24 欧洲 + 05

5.24 欧洲 + 05

1.05E+06

1.05E+06

精确率

bf16

int8

bf16

int8

TFLOP/芯片/秒

307

408

308

414

模型 flops 利用率

(MFU)

66.98%

88.85%

67.09%

90.23%

TensorFlow 教程

在单个主机 v5p 上训练 ResNet

本教程介绍如何使用虚构数据集在 v5p-8 TPU 上训练 ImageNet。如果要使用其他数据集,请参阅准备数据集

设置

  1. 创建环境变量:

    export PROJECT_ID=your-project-ID
    export ACCELERATOR_TYPE=v5p-8
    export ZONE=us-east1-c
    export RUNTIME_VERSION=tpu-vm-tf-2.16.1-pjrt
    export TPU_NAME=your-tpu-name
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export QUOTA_TYPE=quota-type
    

    在本教程中,使用 v5p-8 作为 ACCELERATOR_TYPE

  2. 创建 TPU 资源

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --${QUOTA_TYPE}
    

    加入队列的资源处于 ACTIVE 状态后,您就可以使用 SSH 连接到您的 TPU 虚拟机。如需检查已加入队列的资源的状态,请使用以下命令:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  3. 使用 SSH 连接到您的 TPU

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  4. 设置一些环境变量

    export MODELS_REPO=/usr/share/tpu/models
    export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
    export MODEL_DIR=gcp-directory-to-store-model
    export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
    export NEXT_PLUGGABLE_DEVICE_USE_C_API=true
    export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
    
  5. 切换到模型代码库目录和安装要求。

    cd ${MODELS_REPO} && git checkout r2.15.0
    pip install -r official/requirements.txt
    

训练模型

  1. 运行训练脚本。

    python3 official/vision/train.py \
      --tpu=local \
      --experiment=resnet_imagenet \
      --mode=train_and_eval \
      --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
      --model_dir=${MODEL_DIR} \
      --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
    

清理

删除 TPU 和排队的资源

在多主机 v5p 上训练 ResNet

本教程介绍如何使用虚构数据集在 v5p-16 或更高版本上训练 ImageNet。如果要使用其他数据集,请参阅准备数据集

  1. 创建环境变量:

    export PROJECT_ID=your_project_ID
    export TPU_NAME=your_tpu_name
    export ZONE=us-east1-c
    export ACCELERATOR_TYPE=v5p-16
    export RUNTIME_VERSION=tpu-vm-tf-2.16.1-pod-pjrt
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export QUOTA_TYPE=quota-type
    

    ACCELERATOR_TYPE 可以为 v5p-16 或更大。

  2. 创建 TPU 资源

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --${QUOTA_TYPE}
    

    加入队列的资源处于 ACTIVE 状态后,您就可以使用 SSH 连接到您的 TPU 虚拟机。

    使用 describe 命令查询已加入队列的资源的状态:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  3. 使用 SSH 连接到您的 TPU (工作器零)

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  4. 设置一些环境变量

    export TPU_NAME=your_tpu_name
    export MODELS_REPO=/usr/share/tpu/models
    export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
    export MODEL_DIR=gcp-directory-to-store-model
    export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
    export TPU_LOAD_LIBRARY=0
    
  5. 切换到模型代码库目录和安装要求。

    cd $MODELS_REPO && git checkout r2.15.0
    pip install -r official/requirements.txt
    

训练模型

  1. 运行训练脚本。

    python3 official/vision/train.py \
      --tpu=${TPU_NAME} \
      --experiment=resnet_imagenet \
      --mode=train_and_eval \
      --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
      --model_dir=${MODEL_DIR} \
      --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
    

清理

删除 TPU 和排队的资源

PyTorch/XLA

Llama 2

本教程将介绍如何使用 PyTorch/XLA 上具有通用和可扩缩并行处理的机器学习计算图 (GSPMD) 的 HuggingFace 代码库分支,在 v5p 上训练 Llama 2 7B 模型。

设置

  1. 为项目 ID、加速器类型、地区、运行时版本和 TPU 名称创建变量。

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5p-8
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your_queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_DURATION=1d
    
  2. 创建 TPU 资源

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --valid-until-duration ${VALID_DURATION} \
    --service-account ${SERVICE_ACCOUNT} \
    --${QUOTA_TYPE}
    

    在您的 QueuedResource 处于 ACTIVE 状态后,您可以使用 SSH 连接到您的 TPU 虚拟机:

    使用 describe 命令查询已加入队列的资源的状态。

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project ${PROJECT_ID} \
    --zone ${ZONE}
    

    当已加入队列的资源处于 ACTIVE 状态时,输出将类似于以下内容:

     state: ACTIVE
    

  3. 安装 Pytorch/XLA 和所需的依赖项。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
    --project ${PROJECT_ID} \
    --zone  ${ZONE} \
    --worker=all \
    --command='
    sudo apt-get update
    sudo apt-get install libopenblas-dev -y
    pip3 install numpy
    pip3 install typing-extensions
    pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
    '
    
  4. 下载 HuggingFace 代码库和安装要求。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    git clone -b llama2-google-next-training https://github.com/pytorch-tpu/transformers.git
    cd transformers
    pip3 install git+file://$PWD
    pip3 install datasets accelerate evaluate scikit-learn'
    
  5. 下载 7B 模型配置。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command="curl https://huggingface.co/TheBloke/Llama-2-7B-fp16/raw/main/config.json --output ~/config.json"
    
  6. 训练模型

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    export PJRT_DEVICE=TPU
    export XLA_USE_BF16=1
    export XLA_IR_DEBUG=1
    export XLA_HLO_DEBUG=1
    
    export LIBTPU_INIT_ARGS="--xla_enable_async_collective_permute=true
    --xla_tpu_enable_async_collective_fusion_multiple_steps=true
    --xla_tpu_enable_async_collective_fusion=true
    --xla_tpu_overlap_compute_collective_tc=true
    --xla_enable_async_all_gather=true
    --xla_jf_spmd_threshold_for_windowed_einsum_mib=0"
    
    export PROFILE_EPOCH=0
    export PROFILE_STEP=3
    export PROFILE_DURATION_MS=20000
    export PROFILE_LOGDIR=/tmp/home/
    
    cd transformers
    python examples/pytorch/language-modeling/run_clm.py \
     --tokenizer_name hf-internal-testing/llama-tokenizer \
     --dataset_name wikitext \
     --dataset_config_name wikitext-2-raw-v1 \
     --per_device_train_batch_size 96 \
     --per_device_eval_batch_size 8 \
     --num_train_epochs 1 \
     --do_train \
     --output_dir /tmp/output \
     --overwrite_output_dir \
     --config_name ~/config.json \
     --save_strategy no \
     --logging_strategy no \
     --remove_unused_columns no \
     --optim adafactor \
     --torch_dtype bfloat16 \
     --dataloader_drop_last yes \
     --block_size 2048 \
     --spmd_2d_sharding 1 \
     --spmd_grad_chkpt
    '
    

如果在多切片环境中运行,则需要将标志 --spmd_dcn_parallelism 设置为切片的数量。

SPMD_USER_GUIDE 提供了一份更深入的用户指南,其中解释了 HF 脚本的所有不同环境变量和切换开关。请注意,LIBTPU_INIT_ARGS 将纳入 PyTorch/XLA 中,并在未来版本中默认启用。

清理

删除 TPU 和排队的资源

基准测试结果

下表列出了所有三种 Llama 2 模型大小的吞吐量。

v5p-8

v5p-128

v5p-128

模型大小

70 亿

130 亿

700 亿

全局批量大小

96

1024

128

分片网格形状

(4、1)

(64、1)

(16、4)

模型 flops 利用率

(MFU)

56.67%

55.80%

51.85%

支持与反馈

我们欢迎各种反馈!如需分享反馈或请求支持,请填写 Cloud TPU 支持或反馈表单