Cloud TPU v5p 训练
Cloud TPU v5p 是 Google Cloud 的第五代 Cloud TPU,是 v4 TPU 的后继产品。v5p 针对大规模训练进行了优化,是开发基础 LLM、扩散模型和生成式 AI 的领先平台。概括地说,v5p 的性能最高可达到 v4 的 2 倍,同时还可将 2 倍多的 TPU 打包到一个 Pod 中(最大 slice 为 6,000,而 v4 为 3,000),从而在 Pod 级别将性能提高到最高 4 倍。它还采用更高的时钟频率(1.75Ghz 对比 1.05Ghz),添加了 SparseCore 以实现大规模嵌入,并将高带宽内存 (HBM) 容量提高了三倍。
Cloud TPU v5p 概念
如果您是 Cloud TPU 新手,请参阅 TPU 文档首页。
Cloud TPU 系统架构页面介绍了 Cloud TPU 的概念(例如 slice、主机和 TensorCore)以及所有 Cloud TPU 版本的 Cloud TPU 系统架构。
每个 Cloud TPU 版本都需要特定的加速器类型才能进行训练或推理。v5p 配置中介绍了这些加速器类型。
管理 TPU 资源
本文档中的所有命令都假定您要创建 TPU v5p 虚拟机。如需详细了解用于创建 TPU 虚拟机的命令,请参阅管理 TPU;如需管理队列化资源,请参阅队列化资源用户指南。为了更方便地运行命令,本文档中的代码示例使用以下环境变量:
export PROJECT_ID=your-project export ACCELERATOR_TYPE=v5p-8 export ZONE=us-east5-a export RUNTIME_VERSION=v2-alpha-tpuv5 export TPU_NAME=your-tpu-name
环境变量说明
框架设置
本部分介绍了使用 JAX 或 PyTorch 搭配 TPU v5p 进行模型训练的一般设置流程。
JAX 设置
如果您的 slice 形状大于 4 个条状标签,则一个 slice 中将包含多个虚拟机。在这种情况下,您需要使用 --worker=all
标志,通过单个命令在所有 TPU 虚拟机上运行安装:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
您可以运行以下命令来检查设备数量(此处显示的输出是使用 v5p-32 slice 生成的)。此代码通过检查 JAX 是否看到 Cloud TPU TensorCore 以及是否可以运行基本操作来测试是否已正确安装所有组件:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
输出将如下所示:
SSH: Attempting to connect to worker 0... SSH: Attempting to connect to worker 1... SSH: Attempting to connect to worker 2... SSH: Attempting to connect to worker 3... 16 4 16 4 16 4 16 4
jax.device_count()
显示给定 slice 中的条状标签总数。jax.local_device_count()
表示此 slice 中单个虚拟机可以访问的芯片数量。
# Check the number of chips in the given slice by summing the count of chips # from all VMs through the # jax.local_device_count() API call. gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
输出将如下所示:
SSH: Attempting to connect to worker 0... SSH: Attempting to connect to worker 1... SSH: Attempting to connect to worker 2... SSH: Attempting to connect to worker 3... [16. 16. 16. 16.] [16. 16. 16. 16.] [16. 16. 16. 16.] [16. 16. 16. 16.]
使用 --node=all
在所有 Multislice 工作器上运行该命令。
gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \ --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
尝试学习本文档中的 JAX 教程,开始使用 JAX 进行 v5p 训练。
PyTorch 设置
PJRT 运行时是 v5p 的唯一受支持的运行时,而 PyTorch 2.1 及更高版本使用 PJRT 作为所有 TPU 版本的默认运行时。本部分介绍如何开始在 v5p Pod 上使用 PJRT,并为所有工作器安装 PyTorch/XLA 2.2.0。
安装依赖项
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' sudo apt-get update sudo apt-get install libopenblas-dev -y pip install numpy pip install torch torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu '
将 Python 脚本与 PJRT 搭配使用,以验证安装情况。该脚本会显示可用的 TPU 设备(此处显示的输出是使用 v5p-32 切片生成的)。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} --zone ${ZONE} --worker=all \ --command=' PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))" '
SSH: Attempting to connect to worker 0... SSH: Attempting to connect to worker 1... SSH: Attempting to connect to worker 2... SSH: Attempting to connect to worker 3... ['xla:0', 'xla:1', 'xla:2', 'xla:3'] ['xla:0', 'xla:1', 'xla:2', 'xla:3'] ['xla:0', 'xla:1', 'xla:2', 'xla:3'] ['xla:0', 'xla:1', 'xla:2', 'xla:3']
使用 --node=all
在所有 Multislice 工作器上运行该命令。
gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \ --command=' PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))" '
尝试学习本文档中的 PyTorch 教程,开始使用 PyTorch 进行 v5p 训练。
监控和配置文件
Cloud TPU v5p 支持使用与上一代 Cloud TPU 相同的方法进行监控和性能分析。您可以参阅使用 Cloud TPU 工具剖析模型的性能,详细了解性能分析;也可以参阅监控 Cloud TPU 虚拟机,详细了解监控。
培训教程
本部分重点介绍单个 Slice 的训练教程。如需将这些教程调整为适用于多 Slice 训练,只需向 SSH 命令添加 --node=all
标志即可。如需了解详情和最佳实践,请参阅多 Slice 简介。
JAX 教程
模块序列扩散 2.1
本教程介绍如何在 Cloud TPU v5p 上使用 Pokémon 数据集训练来自 HuggingFace 的 Stable Diffusion 模型。
Stable Diffusion 模型是一种潜在的文本到图像模型,可根据任何文本输入生成逼真的图片。如需了解详情,请参阅以下资源:
设置
创建环境变量:
export GCS_BUCKET_NAME=your-bucket export PROJECT_ID=your-project-ID export ACCELERATOR_TYPE=v5p-32 export ZONE=europe-west4-b export LOCATION=europe-west4 export RUNTIME_VERSION=v2-alpha-tpuv5 export SERVICE_ACCOUNT=your-service-account export TPU_NAME=your-tpu-name export QUEUED_RESOURCE_ID=your-qr-name export QUOTA_TYPE=spot export VALID_UNTIL_DURATION=1d
命令标志说明
变量 说明 PROJECT_ID Google Cloud 项目名称 ACCELERATOR_TYPE 请参阅 TPU 版本页面,了解您的 TPU 版本。 区域 如需了解支持的区域,请参阅 TPU 区域和可用区文档。 LOCATION 用于创建 Cloud Storage 存储桶的 Google Cloud 区域。 RUNTIME_VERSION 对于 v5p,请为 RUNTIME_VERSION 使用 v2-alpha-tpuv5。 SERVICE_ACCOUNT 这是您的服务账号的地址,您可以在 Google Cloud 控制台 -> IAM -> 服务账号中找到。例如:tpu-service-account@myprojectID。iam.gserviceaccount.com TPU_NAME TPU 的用户分配文本 ID,在分配队列中的资源请求时创建。 QUEUED_RESOURCE_ID 已加入队列的资源请求的用户分配的文本 ID。如需了解已排队的资源,请参阅已排队的资源文档。 QUOTA_TYPE 可以是 reserved
或spot
。如果未指定这两者,则 QUOTA_TYPE 默认为on-demand
。如需了解 Cloud TPU 支持的不同类型的配额,请参阅配额。VALID_UNTIL_DURATION 请求有效的时长。如需了解不同的有效时长,请参阅 队列化资源。 为模型输出设置存储桶。
gcloud storage buckets create gs://$GCS_BUCKET_NAME \ --project=$PROJECT_ID \ --location=$LOCATION
-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_UNTIL_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ --${QUOTA_TYPE}
当已排队的资源处于
ACTIVE
状态时,您就可以通过 SSH 连接到 TPU 虚拟机。运行以下命令检查队列中资源的状态:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE}
当队列中的资源处于
ACTIVE
状态时,输出将类似于以下内容:state: ACTIVE
训练模型
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --project $PROJECT_ID --worker=all --command=" git clone https://github.com/google/maxdiffusion cd maxdiffusion git reset --hard 57629bcf4fa32fe5a57096b60b09f41f2fa5c35d # This identifies the GitHub commit to use. pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # Install the latest version of JAX pip3 install -r requirements.txt pip3 install . export LIBTPU_INIT_ARGS="" python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run base_output_directory=gs://$GCS_BUCKET_NAME enable_profiler=False"
清理
在会话结束时删除 TPU 和已排队的资源请求,或移除处于“失败”状态的已排队资源请求。如需删除已排队的资源,请按以下 2 个步骤删除 slice,然后删除已排队的资源请求:
gcloud compute tpus tpu-vm delete ${TPU_NAME} --project=${PROJECT_ID} \ --zone=${ZONE} --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet
或者,使用 --force
一步即可删除 slice 和队列中的资源请求:
# With --force gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
基准测试结果
Stable Diffusion 训练脚本在 v5p-8、v5p-32 和 v5p-128 上运行。 下表显示了吞吐量。
v5p-8 |
v5p-32 |
v5p-128 |
|
---|---|---|---|
训练步骤 |
150 |
150 |
150 |
全局批量大小 |
32 |
64 |
64 |
吞吐量(示例/秒) |
12.10 |
18.08 |
19.10 |
MaxText
本教程介绍如何在 Cloud TPU 上使用合成数据集训练 MaxText 模型。
MaxText 是一个高性能、任意可伸缩、开源且经过充分测试的 LLM,以纯 Python/JAX 编写,以 Cloud TPU 为目标平台。MaxText 是一款简单易用且可自定义的工具,可帮助研究人员和开发者推进自然语言处理 (NLP) 研究和开发的前沿。
在运行本教程之前,您需要设置 Cloud TPU 环境。
设置环境变量
export PROJECT_ID=your_project_ID export TPU_NAME=your_tpu_name # user defined TPU name export ACCELERATOR_TYPE=v5p-256 export ZONE=us-east5-a export RUNTIME_VERSION=v2-alpha-tpuv5 export RUN_NAME=your_experiment_run_name # user defined name for this run export GCS_BUCKET_NAME=your_bucket_name # Output cloud folder. Should start with gs:// export MAXTEXT_OUTPUT_PATH=${GCS_BUCKET_NAME}/your_experiment_output_path export NUM_SLICES=1 # Update the value to a number >1 for Multislice.
命令标志说明
变量 说明 PROJECT_ID Google Cloud 项目名称 TPU_NAME TPU 的用户定义的名称。 ACCELERATOR_TYPE 请参阅 TPU 版本页面,了解您的 TPU 版本。 区域 如需了解支持的区域,请参阅 TPU 区域和可用区文档。 RUNTIME_VERSION 对于 v5p,请将运行时版本设为 v2-alpha-tpuv5。 RUN_NAME 用户提供的实验运行名称。 建议为多 Slice 设置以下可选设置:
export NETWORK_NAME=your_network_name export FIREWALL_RULE_NAME=your_firewall_rule_name
如果您正在运行多 slice 工作负载并希望获得最佳网络性能,不妨考虑创建一个最大传输单元 (MTU) 为 8896 字节的专用网络,并配置适当的防火墙规则。虽然此步骤是可选的,但可以显著提升性能,尤其是在通过数据中心网络 (DCN) 扩容 slice 数量时。请注意,创建网络需要在项目中拥有
compute.networks.create
权限。以下示例展示了如何创建专用网络和防火墙规则。创建专用网络:
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT_ID} \ --subnet-mode=auto \ --bgp-routing-mode=regional
创建防火墙规则:
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT_ID}
克隆 MaxText 代码库
git clone https://github.com/google/maxtext.git
训练模型
以下部分介绍了训练 MaxText 的两种方法。
选项 1
如果您希望使用脚本来管理整个工作流(从预配 Cloud TPU 和安装依赖项到运行模型和拆解资源),可以使用
multihost_job.py
。cd maxtext && python3 multihost_job.py --PROJECT=${PROJECT_ID} --ZONE=${ZONE} \ --NUM_SLICES=${NUM_SLICES} --TPU_TYPE=${ACCELERATOR_TYPE} \ --VERSION=${RUNTIME_VERSION} --RUN_NAME=${RUN_NAME} #user defined run name \ --BUCKET_NAME=${GCS_BUCKET_NAME} \ #used to store logs and configs --COMMAND="bash setup.sh && bash MaxText/configs/experimental/64b.sh RUN_NAME=${RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"
启动脚本后,您应该会在日志中看到类似于以下内容的消息。输出消息中会引用日志位置。TPU 预配完成后,点击第一个链接即可访问所有工作器的日志。
------------------------------------ multihost_job finished running, TPUs are starting up to run your job remotely. Logs for your job are displayed here: https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22
_log%22%2529;?project=PROJECT_ID To see the output of a single host, you may edit the slice and worker number in the `log_file_path` property here: https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22RUN_NAME_log%22%2529%20AND%0Alabels.%22agent.googleapis.com%2Flog_file_path%22%3D%20%22%2FRUN_NAME%2Fmain_command_log_slice_0_worker_0%22;?project=PROJECT_ID When your job is finished, the main command log is in your Cloud Storage bucket: https://console.cloud.google.com/storage/browser/YOUR_BUCKET_NAME/RUN_NAME?project=PROJECT_ID View the status of the created TPUs using: gcloud compute tpus queued-resources list --filter=RUN_NAME --zone=ZONE --project=PROJECT_ID
选项 2
如需在预配的 Cloud TPU 上多次运行训练脚本,请使用 multihost_runner.py
脚本使用该资源。
设置变量以创建 TPU。
export SERVICE_ACCOUNT=your_service_account export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id export VALID_DURATION=1d export QUOTA_TYPE=quota_type
--node-count ${NODE_COUNT} \ --node-prefix ${NODE_PREFIX} # optional, the default is QUEUED_RESOURCE_ID
创建 TPU 资源。
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ --${QUOTA_TYPE}
当
QueuedResource
处于ACTIVE
状态后,您就可以使用 SSH 连接到 TPU 虚拟机:使用
describe
命令查询队列中资源的状态。gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} --project ${PROJECT_ID} --zone ${ZONE}
当队列中的资源处于“有效”状态时,输出将类似于以下内容:
state: ACTIVE
使用 SSH 连接到 TPU
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE}
安装依赖项
export TPU_NAME=your_tpu_name export MAXTEXT_OUTPUT_PATH=output-path
cd maxtext && python3 multihost_runner.py --TPU_PREFIX=${TPU_NAME} \ --COMMAND='bash setup.sh'
使用各种配置脚本(例如 32b.sh、64b.sh)运行模型。如果您要从 TPU 虚拟机运行脚本,则需要添加标志
--INTERNAL_IP=true
。python3 multihost_runner.py --TPU_PREFIX=${TPU_NAME} \ --COMMAND="bash MaxText/configs/experimental/64b.sh RUN_NAME=${RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"
清理
基准测试结果
MaxText 训练脚本以 bf16 精度从 32B 运行到 1160B。这些运行的结果如以下表所示。
参数数量 |
加速器类型 |
TFLOP/芯片/秒 |
模型 FLOP 利用率 (MFU) |
---|---|---|---|
32B |
v5p-128 |
3.28E+02 |
71.47% |
64B |
v5p-128 |
3.23E+02 |
70.31% |
128B |
v5p-256 |
3.15E+02 |
68.68% |
128B |
v5p-512 |
3.15E+02 |
68.53% |
256B |
v5p-1024 |
3.16E+02 |
68.82% |
512B |
v5p-1024 |
2.94E+02 |
63.99% |
1024B |
v5p-2048 |
2.49E+02 |
64.05% |
1024B |
v5p-4096 |
2.97E+02 |
64.80% |
1160B |
v5p-7680 |
2.95E+02 |
64.27% |
1160B |
v5p-12288 |
3.04E+02 |
66.23% |
我们使用 bf16 和 int8 权重在 v5p-512 和 v5p-1024 上测试了 256B 参数模型。下表显示了这些测试的结果。
v5p-512 |
v5p-512 |
v5p-1024 |
v5p-1024 |
|
---|---|---|---|---|
全局批量大小 (tokens) |
5.24E+05 |
5.24E+05 |
1.05E+06 |
1.05E+06 |
精确率 |
bf16 |
int8 |
bf16 |
int8 |
TFLOP/芯片/秒 |
307 |
408 |
308 |
414 |
模型 FLOP 利用率 (MFU) |
66.98% |
88.85% |
67.09% |
90.23% |
TensorFlow 教程
在单个主机 v5p 上训练 ResNet
本教程介绍了如何使用虚构数据集在 v5p-8
TPU 上训练 ImageNet。如果您想使用其他数据集,请参阅准备数据集。
设置
创建环境变量:
export PROJECT_ID=your-project-ID export ACCELERATOR_TYPE=v5p-32 export ZONE=us-east1-c export RUNTIME_VERSION=tpu-vm-tf-2.18.0-pjrt export TPU_NAME=your-tpu-name export QUEUED_RESOURCE_ID=your-queued-resource-id export QUOTA_TYPE=quota-type
在本教程中,请使用
v5p-8
作为ACCELERATOR_TYPE
。-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --${QUOTA_TYPE}
队列中的资源处于
ACTIVE
状态后,您就可以使用 SSH 连接到 TPU 虚拟机了。如需检查队列中资源的状态,请使用以下命令:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} \ --zone ${ZONE}
使用 SSH 连接到 TPU
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE}
设置一些环境变量
export MODELS_REPO=/usr/share/tpu/models export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}" export MODEL_DIR=gcp-directory-to-store-model export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet export NEXT_PLUGGABLE_DEVICE_USE_C_API=true export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
切换到模型代码库目录并安装要求。
cd ${MODELS_REPO} && git checkout r2.15.0 pip install -r official/requirements.txt
训练模型
运行训练脚本。
python3 official/vision/train.py \ --tpu=local \ --experiment=resnet_imagenet \ --mode=train_and_eval \ --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \ --model_dir=${MODEL_DIR} \ --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
清理
在多主机 v5p 上训练 ResNet
本教程介绍了如何使用虚构数据集在 v5p-16
或更高版本上训练 ImageNet。如果您想使用其他数据集,请参阅准备数据集。
创建环境变量:
export PROJECT_ID=your_project_ID export TPU_NAME=your_tpu_name export ZONE=us-east1-c export ACCELERATOR_TYPE=v5p-16 export RUNTIME_VERSION=tpu-vm-tf-2.18.0-pod-pjrt export QUEUED_RESOURCE_ID=your-queued-resource-id export QUOTA_TYPE=quota-type
ACCELERATOR_TYPE
可以是v5p-16
或更大。-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --${QUOTA_TYPE}
队列中的资源处于
ACTIVE
状态后,您就可以使用 SSH 连接到 TPU 虚拟机了。使用
describe
命令查询队列中资源的状态:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} \ --zone ${ZONE}
使用 SSH 连接到 TPU(工作器 0)
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE}
设置一些环境变量
export TPU_NAME=your_tpu_name export MODELS_REPO=/usr/share/tpu/models export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}" export MODEL_DIR=gcp-directory-to-store-model export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet export TPU_LOAD_LIBRARY=0
切换到模型代码库目录并安装要求。
cd $MODELS_REPO && git checkout r2.15.0 pip install -r official/requirements.txt
训练模型
运行训练脚本。
python3 official/vision/train.py \ --tpu=${TPU_NAME} \ --experiment=resnet_imagenet \ --mode=train_and_eval \ --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \ --model_dir=${MODEL_DIR} \ --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
清理
PyTorch/XLA
Llama 2
本教程将介绍如何使用 PyTorch/XLA 中的 HuggingFace 代码库分支(并使用适用于机器学习计算图的通用可伸缩并行化 [GSPMD])在 v5p 上训练 Llama 2 7B 模型。
设置
创建环境变量。
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5p-8 export ZONE=us-east5-a export RUNTIME_VERSION=v2-alpha-tpuv5 export SERVICE_ACCOUNT=your_service_account export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id export QUOTA_TYPE=quota_type export VALID_DURATION=1d
创建 TPU 资源
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ --${QUOTA_TYPE}
当
QueuedResource
处于ACTIVE
状态后,您就可以使用 SSH 连接到 TPU 虚拟机:使用
describe
命令查询队列中资源的状态。gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} \ --zone ${ZONE}
当队列中的资源处于 ACTIVE 状态时,输出将类似于以下内容:
state: ACTIVE
安装 Pytorch/XLA 及所需的依赖项。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' sudo apt-get update sudo apt-get install libopenblas-dev -y pip3 install numpy pip3 install typing-extensions pip install torch torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html '
下载 HuggingFace 仓库并安装要求。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' git clone -b llama2-google-next-training https://github.com/pytorch-tpu/transformers.git cd transformers pip3 install git+file://$PWD pip3 install datasets accelerate evaluate scikit-learn'
下载 7B 模型配置。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="curl https://huggingface.co/TheBloke/Llama-2-7B-fp16/raw/main/config.json --output ~/config.json"
训练模型
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_BF16=1 export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export LIBTPU_INIT_ARGS="--xla_enable_async_collective_permute=true \ --xla_tpu_enable_async_collective_fusion_multiple_steps=true \ --xla_tpu_enable_async_collective_fusion=true \ --xla_tpu_overlap_compute_collective_tc=true \ --xla_enable_async_all_gather=true \ --xla_jf_spmd_threshold_for_windowed_einsum_mib=0" export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=20000 export PROFILE_LOGDIR=/tmp/home/ cd transformers python examples/pytorch/language-modeling/run_clm.py \ --tokenizer_name hf-internal-testing/llama-tokenizer \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 96 \ --per_device_eval_batch_size 8 \ --num_train_epochs 1 \ --do_train \ --output_dir /tmp/output \ --overwrite_output_dir \ --config_name ~/config.json \ --save_strategy no \ --logging_strategy no \ --remove_unused_columns no \ --optim adafactor \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --block_size 2048 \ --spmd_2d_sharding 1 \ --spmd_grad_chkpt '
如果您是在多 slice 环境中运行,则需要将标志 --spmd_dcn_parallelism
设置为 slice 数量。
SPMD_USER_GUIDE 提供了更深入的用户指南,其中介绍了 HF 脚本的所有不同环境变量和切换开关。请注意,LIBTPU_INIT_ARGS 将在未来的版本中纳入 PyTorch/XLA 中,并默认处于开启状态。
清理
基准测试结果
下表列出了三种 Llama 2 模型大小的吞吐量。
v5p-8 |
v5p-128 |
v5p-128 |
|
---|---|---|---|
模型大小 |
70 亿 |
130 亿 |
700 亿 |
全局批量大小 |
96 |
1024 |
128 |
分片网格形状 |
(4, 1) |
(64, 1) |
(16, 4) |
模型 FLOP 利用率 (MFU) |
56.67% |
55.80% |
51.85% |
支持与反馈
欢迎您提供任何反馈!如需分享反馈或申请支持,请填写 Cloud TPU 支持或反馈表单。