Train a model using TPU v6e
This document guides you through training models on Cloud TPU v6e (also called Trillium), covering environment setup, performance optimization, and practical training examples using JAX and PyTorch/XLA.
TPU v6e, also called Trillium, is Google's sixth-generation TPU. On all technical surfaces, such as the API and logs, and throughout this document, Trillium is referred to as v6e. With 256 chips per Pod, TPU v6e shares many architectural similarities with v5e. TPU v6e is optimized for transformer, text-to-image, and convolutional neural network (CNN) training, fine-tuning, and serving. For more information about the TPU v6e system architecture and configurations, see TPU v6e.
For information about running inference on Cloud TPU v6e, see the following tutorials:
- JetStream MaxText inference on v6e
- JetStream PyTorch inference on v6e
- MaxDiffusion inference on v6e
- vLLM inference on v6e
- Perform multihost inference using Pathways
Before you begin
Before you begin, you need to:
- Create a Google Cloud account and project with billing enabled
- Install Google Cloud CLI alpha components
- Enable the Cloud TPU API
- Create a Cloud TPU service agent
- Create a Cloud TPU service account and grant permissions
For more information, see Set up the Cloud TPU environment.
Verify quota and permissions
Verify that your project has the following quotas:
- TPU v6e preemptible or on-demand quota
- IP address quota
- Quota for Hyperdisk Balanced and for any other disk types you want to use
If you're using GKE with XPK, you need additional permissions in the Google Cloud console. For more information, see Permissions needed on Google Cloud console.
Provision TPUs
You can provision and manage TPU v6e using the following methods:
- GKE: You can use GKE to provision and manage TPUs as a pool of accelerators for your containerized machine learning workloads. For more information, see About TPUs in GKE.
- GKE and XPK: XPK is a command-line tool that simplifies cluster creation and workload execution on GKE. It's designed for ML practitioners to provision TPUs and run training jobs without needing deep Kubernetes expertise. For more information, see the XPK GitHub repository.
- Cloud TPU queued resources: Queued resources let you request TPU capacity that is provisioned when it becomes available. It's ideal for batch jobs and fault-tolerant workloads that can wait in a queue. You can specify a time window for your request. For more information, see Manage queued resources.
Provision v6e Cloud TPUs with GKE and XPK
If you use GKE with v6e, you can provision Cloud TPUs and train or serve models with either Kubernetes commands or XPK. See Plan for Cloud TPUs in GKE to learn how to plan your Cloud TPU configurations in GKE clusters. The following sections provide commands to create an XPK cluster with single-NIC support and multi-NIC support.
Create an XPK cluster with single-NIC support
export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2
export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
  --mtu=8896 \
  --project=${PROJECT_ID} \
  --subnet-mode=auto \
  --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
  --network=${NETWORK_NAME} \
  --allow tcp,icmp,udp \
  --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
  --cluster-cpu-machine-type=e2-standard-8 \
  --num-slices=${NUM_SLICES} \
  --tpu-type=${TPU_TYPE} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --on-demand \
  --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
  --create-vertex-tensorboard
Variable | Description |
---|---|
CLUSTER_NAME | The user-assigned name for the XPK cluster. |
PROJECT_ID | Google Cloud project name. Use an existing project or create a new one. For more information, see Set up your Google Cloud project. |
ZONE | See the Cloud TPU regions and zones document for the supported zones. |
TPU_TYPE | See Accelerator Types. |
NUM_SLICES | The number of slices to create. |
CLUSTER_ARGUMENTS | The network and subnetwork to use. For example: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME} |
NETWORK_NAME | The name of the 8,896-MTU network to create and use for the cluster. |
NETWORK_FW_NAME | The name of the firewall rule to create for that network. |
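If you script cluster creation from Python instead of typing the shell commands, the following sketch assembles the same single-NIC `xpk.py cluster create` invocation as an argument list. It's an illustrative assumption, not part of XPK: the helper name `xpk_cluster_create_args` is hypothetical, it uses only the flags shown in the command above, and it prints the command instead of running it.

```python
import os
import shlex


def xpk_cluster_create_args(env: dict) -> list:
    """Build the single-NIC `xpk.py cluster create` command as an argv list.

    Hypothetical helper: it mirrors the shell example flag for flag and
    does not execute anything.
    """
    # Same value the shell example stores in CLUSTER_ARGUMENTS.
    cluster_arguments = (
        f"--network={env['NETWORK_NAME']} --subnetwork={env['NETWORK_NAME']}"
    )
    return [
        "python3", "xpk.py", "cluster", "create",
        f"--cluster={env['CLUSTER_NAME']}",
        "--cluster-cpu-machine-type=e2-standard-8",
        f"--num-slices={env['NUM_SLICES']}",
        f"--tpu-type={env['TPU_TYPE']}",
        f"--zone={env['ZONE']}",
        f"--project={env['PROJECT_ID']}",
        "--on-demand",
        # One argv element, exactly like "${CLUSTER_ARGUMENTS}" in the shell version.
        f"--custom-cluster-arguments={cluster_arguments}",
        "--create-vertex-tensorboard",
    ]


if __name__ == "__main__":
    cluster = os.environ.get("CLUSTER_NAME", "xpk-cluster-name")
    env = {
        "CLUSTER_NAME": cluster,
        "NETWORK_NAME": os.environ.get("NETWORK_NAME", f"{cluster}-mtu9k"),
        "NUM_SLICES": os.environ.get("NUM_SLICES", "2"),
        "TPU_TYPE": os.environ.get("TPU_TYPE", "v6e-256"),
        "ZONE": os.environ.get("ZONE", "us-east1-d"),
        "PROJECT_ID": os.environ.get("PROJECT_ID", "your-project-id"),
    }
    # Print a correctly quoted command line you can inspect or paste.
    print(shlex.join(xpk_cluster_create_args(env)))
```

To run it instead of printing it, you could pass the list to `subprocess.run(..., check=True)` from the directory that contains `xpk.py`; using a list avoids shell quoting of the embedded cluster arguments.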
Create an XPK cluster with multi-NIC support
export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2
export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
  --mtu=8896 \
  --bgp-routing-mode=regional \
  --subnet-mode=custom \
  --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
  --network=${NETWORK_NAME_1} \
  --range=10.11.0.0/18 \
  --region=${REGION} \
  --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
  --network=${NETWORK_NAME_1} \
  --allow tcp,icmp,udp \
  --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_1} \
  --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
  --mtu=8896 \
  --bgp-routing-mode=regional \
  --subnet-mode=custom \
  --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
  --network=${NETWORK_NAME_2} \
  --range=10.10.0.0/18 \
  --region=${REGION} \
  --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
  --network=${NETWORK_NAME_2} \
  --allow tcp,icmp,udp \
  --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_2} \
  --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
  --cluster=${CLUSTER_NAME} \
  --cluster-cpu-machine-type=e2-standard-8 \
  --num-slices=${NUM_SLICES} \
  --tpu-type=${TPU_TYPE} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --on-demand \
  --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
  --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
  --create-vertex-tensorboard
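Variable | Description |
---|---|
CLUSTER_NAME | The user-assigned name for the XPK cluster. |
REGION | The region that contains ZONE. For example, us-east1 for the zone us-east1-d. |
PROJECT_ID | Google Cloud project name. Use an existing project or create a new one. For more information, see Set up your Google Cloud project. |
ZONE | See the Cloud TPU regions and zones document for the supported zones. |
TPU_TYPE | See Accelerator Types. |
NUM_SLICES | The number of slices to create (needed for Multislice only). |
NETWORK_NAME_1, SUBNET_NAME_1 | The primary network and its subnetwork, used by the cluster. |
NETWORK_NAME_2, SUBNET_NAME_2 | The secondary network and its subnetwork, attached to the node pools as an additional node network. |
FIREWALL_RULE_NAME, ROUTER_NAME, NAT_CONFIG | The names of the firewall rule, Cloud Router, and Cloud NAT configuration created for each network. The script reassigns these variables before creating the resources for the second network. |
CLUSTER_ARGUMENTS | The network and subnetwork to use. For example: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1} |
NODE_POOL_ARGUMENTS | Additional node network to use. For example: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2} |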
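The comment in the script notes that the secondary subnet needs its own IP range. Before running the commands, you can confirm that the two `--range` values don't overlap by using Python's standard `ipaddress` module. This is an optional local check under the assumption that you keep the ranges shown above; neither gcloud nor XPK requires it, and the helper name is illustrative.

```python
import ipaddress

# Ranges passed to --range in the two `subnets create` commands above.
PRIMARY_SUBNET = ipaddress.ip_network("10.11.0.0/18")    # SUBNET_NAME_1
SECONDARY_SUBNET = ipaddress.ip_network("10.10.0.0/18")  # SUBNET_NAME_2


def subnets_disjoint(a, b) -> bool:
    """Return True if the two IP networks share no addresses."""
    return not a.overlaps(b)


if __name__ == "__main__":
    ok = subnets_disjoint(PRIMARY_SUBNET, SECONDARY_SUBNET)
    print(f"{PRIMARY_SUBNET} and {SECONDARY_SUBNET} disjoint: {ok}")
```

If you change either range, rerun the check: a /18 spans 64 consecutive values in the third octet, so, for example, 10.10.32.0/20 would fall inside 10.10.0.0/18.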
Set up JAX or PyTorch
The following resources show how to set up JAX or PyTorch on your Cloud TPU, depending on which provisioning and management method you use:
- GKE Autopilot: Prepare your TPU application
- GKE Standard: Prepare your workloads
- GKE and XPK: XPK README
- Single-host Cloud TPU using JAX: Run a calculation on a Cloud TPU VM using JAX
- Multi-host Cloud TPU using JAX: Run JAX code on TPU slices
- Single-host Cloud TPU using PyTorch: Run a calculation on a Cloud TPU VM using PyTorch
- Multi-host Cloud TPU using PyTorch: Run PyTorch code on TPU slices
To set up and run XPK with MaxText, see Running MaxText at Scale with XPK.
Optimize network performance
This section describes how to optimize your network performance by configuring the maximum transmission unit (MTU), using multi-NIC for Multislice environments, and improving TCP settings.
Configure MTU
For the best network performance, use a network with 8,896 MTU (maximum transmission unit).
By default, a Virtual Private Cloud (VPC) only provides an MTU of 1,460 bytes, which provides suboptimal network performance. You can set a VPC network's MTU to any value between 1,300 bytes and 8,896 bytes (inclusive). Common custom MTU sizes are 1,500 bytes (standard Ethernet) or 8,896 bytes (the maximum possible). For more information, see Valid VPC network MTU sizes.
For more information about changing the MTU setting for an existing or default network, see Change the MTU setting of a VPC network.
The following example creates a network with 8,896 MTU and a corresponding firewall rule that allows TCP, ICMP, and UDP traffic within the network.
export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
  --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
  --allow tcp,icmp,udp --project=${PROJECT_ID}
Replace your-resource-name with a base name for the network and firewall.
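If you generate network flags from a script, a small guard like the following catches an out-of-range MTU before gcloud rejects it. It's a hypothetical helper, not a gcloud feature: it encodes only the bounds stated above (1,300 to 8,896 bytes, inclusive) and defaults to the recommended 8,896 bytes.

```python
# Valid VPC network MTU range from the section above (bytes, inclusive).
MIN_VPC_MTU = 1300
MAX_VPC_MTU = 8896
DEFAULT_VPC_MTU = 1460  # what a VPC network gets if you don't set --mtu


def is_valid_mtu(mtu: int) -> bool:
    """Return True if mtu is an allowed VPC network MTU."""
    return MIN_VPC_MTU <= mtu <= MAX_VPC_MTU


def mtu_flag(mtu: int = MAX_VPC_MTU) -> str:
    """Return the --mtu flag for `gcloud compute networks create`, rejecting invalid values."""
    if not is_valid_mtu(mtu):
        raise ValueError(f"MTU {mtu} is outside {MIN_VPC_MTU}-{MAX_VPC_MTU} bytes")
    return f"--mtu={mtu}"


if __name__ == "__main__":
    print(mtu_flag())  # --mtu=8896
```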
Use the multi-NIC option for Multislice
If you're using a Multislice environment, set the following environment variables, which are required for a secondary subnet:
export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region
Use the following commands to create custom IP routing for the network and subnet.
Create the secondary network.
gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
  --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
Create a subnetwork for the secondary network.
gcloud compute networks subnets create ${SUBNET_NAME_2} \
  --network=${NETWORK_NAME_2} \
  --range=10.10.0.0/18 --region=${REGION} \
  --project=${PROJECT_ID}
Create a firewall rule to allow traffic within the new subnetwork.
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
  --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
  --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
Create a Cloud Router for the secondary network.
gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_2} \
  --region=${REGION}
Create a NAT configuration for the Cloud Router.
gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging
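The firewall rule in these steps admits traffic only from 10.10.0.0/18, the same range as the secondary subnet. A quick check with Python's `ipaddress` module confirms that the rule covers the whole subnet and shows the subnet's size: a /18 holds 16,384 IPv4 addresses, and Google Cloud reserves four addresses in every subnet's primary range, leaving 16,380 for VMs. The helper names below are illustrative.

```python
import ipaddress

SUBNET_RANGE = "10.10.0.0/18"      # --range used for SUBNET_NAME_2
FIREWALL_SOURCE = "10.10.0.0/18"   # --source-ranges on FIREWALL_RULE_NAME
RESERVED_PER_SUBNET = 4            # addresses Google Cloud reserves in each subnet primary range


def usable_addresses(cidr: str) -> int:
    """Addresses left for VMs after the per-subnet reservations."""
    return ipaddress.ip_network(cidr).num_addresses - RESERVED_PER_SUBNET


def firewall_covers_subnet(firewall_cidr: str, subnet_cidr: str) -> bool:
    """True if every address in the subnet falls inside the firewall source range."""
    return ipaddress.ip_network(subnet_cidr).subnet_of(ipaddress.ip_network(firewall_cidr))


if __name__ == "__main__":
    print(f"usable addresses in {SUBNET_RANGE}: {usable_addresses(SUBNET_RANGE):,}")
    print(f"firewall covers subnet: {firewall_covers_subnet(FIREWALL_SOURCE, SUBNET_RANGE)}")
```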
After you create a multi-network slice, you can validate that both network interface cards (NICs) are being used by setting up an XPK cluster and adding the --command ifconfig flag to the XPK workload creation command.
Use the following workload create command to display the output of the ifconfig command in Google Cloud console logs:
python3 xpk.py workload create \
  --cluster ${CLUSTER_NAME} \
  {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
  --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
  --tpu-type=${ACCELERATOR_TYPE} \
  --num-slices=${NUM_SLICES} \
  --on-demand \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --command "ifconfig"
If you want to enable debug logs or use Vertex AI TensorBoard, add the following optional arguments to the command:
--enable-debug-logs \
--use-vertex-tensorboard
Verify that both eth0 and eth1 have MTU set to 8,896 by checking the output of the XPK workload in Google Cloud console logs.
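If you copy the ifconfig output from the logs, a short parser can do the eth0/eth1 check for you. This sketch assumes the net-tools ifconfig format, where each interface block starts with a line such as `eth0: flags=4163<UP,BROADCAST,RUNNING,MULTICAST>  mtu 8896`. The `SAMPLE` text is made up for illustration, and other tools (for example, `ip link`) print a different format.

```python
import re
import sys

# Matches the interface header line printed by net-tools ifconfig on Linux.
_IFACE_RE = re.compile(r"^(?P<name>[\w.-]+): flags=.*?\bmtu (?P<mtu>\d+)", re.MULTILINE)

# Hypothetical log excerpt used when no file is given.
SAMPLE = """\
eth0: flags=4163<UP,BROADCAST,RUNNING,MULTICAST>  mtu 8896
        inet 10.11.0.5  netmask 255.255.255.255  broadcast 0.0.0.0
eth1: flags=4163<UP,BROADCAST,RUNNING,MULTICAST>  mtu 8896
        inet 10.10.0.5  netmask 255.255.255.255  broadcast 0.0.0.0
lo: flags=73<UP,LOOPBACK,RUNNING>  mtu 65536
        inet 127.0.0.1  netmask 255.0.0.0
"""


def interface_mtus(ifconfig_output: str) -> dict:
    """Map each interface name to its MTU."""
    return {m["name"]: int(m["mtu"]) for m in _IFACE_RE.finditer(ifconfig_output)}


def nics_have_mtu(ifconfig_output: str, nics=("eth0", "eth1"), expected: int = 8896) -> bool:
    """True only if every listed NIC is present and reports the expected MTU."""
    mtus = interface_mtus(ifconfig_output)
    return all(mtus.get(nic) == expected for nic in nics)


if __name__ == "__main__":
    # Usage: python3 check_mtu.py saved_log.txt  (falls back to SAMPLE)
    text = open(sys.argv[1]).read() if len(sys.argv) > 1 else SAMPLE
    print(interface_mtus(text), "OK" if nics_have_mtu(text) else "MTU mismatch")
```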
Improve TCP settings
If you provisioned your Cloud TPUs using queued resources, you can run the following command to improve network performance by increasing TCP receive buffer limits.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "${PROJECT_ID}" \
  --zone "${ZONE}" \
  --node=all \
  --worker=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'
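The three numbers in tcp_rmem are the minimum, default, and maximum size, in bytes, of each TCP socket's receive buffer. Converting them shows what the command sets: a 4 KiB minimum, a 40 MiB (41,943,040-byte) default, and a 300 MiB (314,572,800-byte) maximum. The following sketch does that conversion and can also read the live value on a worker; it's a local helper, not part of gcloud.

```python
from pathlib import Path

# Value written to /proc/sys/net/ipv4/tcp_rmem by the command above.
TCP_RMEM = "4096 41943040 314572800"


def parse_tcp_rmem(value: str) -> dict:
    """Split a tcp_rmem triple into min, default, and max receive-buffer sizes in bytes."""
    minimum, default, maximum = (int(field) for field in value.split())
    return {"min": minimum, "default": default, "max": maximum}


def to_mib(n_bytes: int) -> float:
    return n_bytes / (1024 * 1024)


def current_tcp_rmem(path: str = "/proc/sys/net/ipv4/tcp_rmem") -> dict:
    """Read the live setting on a Linux host, for example on a TPU VM worker."""
    return parse_tcp_rmem(Path(path).read_text())


if __name__ == "__main__":
    for name, size in parse_tcp_rmem(TCP_RMEM).items():
        print(f"{name:>7}: {size:>11,} bytes = {to_mib(size):g} MiB")
```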
Optimize memory allocation performance
The tcmalloc library is used by default on Cloud TPU VMs to improve performance for models with sizable, frequent memory allocations. This is configured through the LD_PRELOAD environment variable.
However, for some workloads (for example, DLRM with very large embedding table allocations), tcmalloc can cause a slowdown. In such cases, you can revert to the standard malloc function by unsetting the LD_PRELOAD variable in your shell session before running your training script:
unset LD_PRELOAD
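If you launch training from a Python wrapper rather than an interactive shell, the equivalent of `unset LD_PRELOAD` is to remove the variable from the child process's environment. The dynamic loader reads LD_PRELOAD when a process starts, so deleting it from `os.environ` inside an interpreter that is already running doesn't change that interpreter's allocator; it only affects processes started afterward. The wrapper and the script name `train.py` in the usage comment are hypothetical.

```python
import os
import subprocess
import sys


def env_without_ld_preload(base=None) -> dict:
    """Copy an environment mapping (default: os.environ) without LD_PRELOAD."""
    env = dict(os.environ if base is None else base)
    # Without LD_PRELOAD, the child process uses the standard malloc instead of tcmalloc.
    env.pop("LD_PRELOAD", None)
    return env


def run_without_tcmalloc(script: str, *args: str) -> int:
    """Run a Python script in a child process without tcmalloc; return its exit code."""
    cmd = [sys.executable, script, *args]
    return subprocess.run(cmd, env=env_without_ld_preload(), check=False).returncode


if __name__ == "__main__":
    # Hypothetical usage: python3 launch.py train.py --your-flags
    if len(sys.argv) > 1:
        sys.exit(run_without_tcmalloc(sys.argv[1], *sys.argv[2:]))
```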
Use SkyPilot
You can use Cloud TPU v6e with SkyPilot. SkyPilot is an open-source framework that simplifies the process of running, managing, and scaling AI workloads. You can add v6e-related location and pricing information to SkyPilot. For more information, see the SkyPilot TPU v6e example.
Training examples
The following sections provide examples for training MaxText, MaxDiffusion, and PyTorch models on Cloud TPU v6e.
These examples have been tested with the following software versions:
- Python 3.10 or later
- Nightly software versions:
  - Nightly JAX 0.4.32.dev20240912
  - Nightly LibTPU 0.1.dev20240912+nightly
- Stable software versions:
  - JAX + JAX Lib of v0.4.37
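To compare a TPU VM or container with the versions listed above, you can print what's installed. This sketch uses only the standard library (`importlib.metadata`), so it works even when JAX fails to import. It checks the Python 3.10 minimum and the installed `jax` and `jaxlib` distributions, and leaves LibTPU out for simplicity; the helper names are illustrative.

```python
import sys
from importlib import metadata


def installed_version(package: str):
    """Return the installed distribution version of package, or None if it's absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def environment_report() -> dict:
    """Summarize the local interpreter and JAX install for comparison with the tested versions."""
    return {
        "python": ".".join(str(part) for part in sys.version_info[:3]),
        "python_ok": sys.version_info >= (3, 10),
        "jax": installed_version("jax"),
        "jaxlib": installed_version("jaxlib"),
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```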
Train MaxText and MaxDiffusion on Cloud TPU v6e
The following sections cover the training lifecycle of the MaxText and MaxDiffusion models.
In general, the high-level steps are:
- Build the workload base image.
- Run your workload using XPK.
- Build the training command for the workload.
- Deploy the workload.
- Follow the workload and view metrics.
- Delete the XPK workload if it isn't needed.
- Delete the XPK cluster when it's no longer needed.
Build base image
Install MaxText or MaxDiffusion and build the Docker image:
Clone the repository you want to use and change to the directory for the repository:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
Configure Docker to use the Google Cloud CLI:
gcloud auth configure-docker
Build the Docker image using the following command or using a JAX AI image. For more information about JAX AI images, see JAX AI images.
MaxText:
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
MaxDiffusion:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
Set your project ID in your active gcloud CLI configuration:
gcloud config set project ${PROJECT_ID}
If you're launching the workload from a machine that doesn't have the image built locally, upload the image.
Set the CLOUD_IMAGE_NAME environment variable:
export CLOUD_IMAGE_NAME=${USER}_runner
Upload the image:
bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
Run your workload using XPK
Set the following environment variables if you're not using the default values set by MaxText or MaxDiffusion:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
export PER_DEVICE_BATCH_SIZE=2
export NUM_STEPS=30
export MAX_TARGET_LENGTH=8192
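PER_DEVICE_BATCH_SIZE is set per device, so the effective batch grows with slice size and slice count. The following sketch works through that arithmetic for a cluster like the one created earlier (TPU_TYPE=v6e-256, NUM_SLICES=2). It assumes one JAX device per v6e chip, that the number after "v6e-" is the chip count, and that every device processes its own per-device batch; the helper names are hypothetical.

```python
def chips_from_tpu_type(tpu_type: str) -> int:
    """For v6e, the accelerator-type suffix is the chip count (v6e-256 -> 256)."""
    return int(tpu_type.rsplit("-", 1)[1])


def global_batch_size(per_device_batch_size: int, tpu_type: str, num_slices: int = 1) -> int:
    """Sequences per training step across all devices in all slices."""
    return per_device_batch_size * chips_from_tpu_type(tpu_type) * num_slices


def tokens_per_step(global_batch: int, max_target_length: int) -> int:
    """Upper bound on tokens per step when every sequence is max_target_length long."""
    return global_batch * max_target_length


if __name__ == "__main__":
    gb = global_batch_size(2, "v6e-256", num_slices=2)  # PER_DEVICE_BATCH_SIZE=2
    print(gb, tokens_per_step(gb, 8192))                # MAX_TARGET_LENGTH=8192 -> 1024 8388608
```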
Build your model script. This script will be copied as a training command in a later step.
Don't execute the model script yet.
MaxText
MaxText is a high-performance, highly scalable, open-source LLM library written in pure Python and JAX, targeting Google Cloud TPUs and GPUs for training and inference.
JAX_PLATFORMS=tpu,cpu \
ENABLE_PJRT_COMPATIBILITY=true \
TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
TPU_SLICE_BUILDER_DUMP_ICI=true \
python3 -m MaxText.train MaxText/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_DIR} \
  dataset_type=synthetic \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  enable_checkpointing=false \
  gcs_metrics=true \
  profiler=xplane \
  skip_first_n_steps_for_profiler=5 \
  steps=${NUM_STEPS}  # attention='dot_product'
Gemma2
Gemma is a family of open-weights LLMs developed by Google DeepMind, based on Gemini research and technology.
python3 -m MaxText.train MaxText/configs/base.yml \
  model_name=gemma2-27b \
  run_name=gemma2-27b-run \
  base_output_directory=${BASE_OUTPUT_DIR} \
  max_target_length=${MAX_TARGET_LENGTH} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  steps=${NUM_STEPS} \
  enable_checkpointing=false \
  use_iota_embed=true \
  gcs_metrics=true \
  dataset_type=synthetic \
  profiler=xplane \
  attention=flash
Mixtral 8x7b
Mixtral is a state-of-the-art AI model developed by Mistral AI, utilizing a sparse mixture-of-experts (MoE) architecture.
python3 -m MaxText.train MaxText/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_DIR} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  model_name=mixtral-8x7b \
  steps=${NUM_STEPS} \
  max_target_length=${MAX_TARGET_LENGTH} \
  tokenizer_path=assets/tokenizer.mistral-v1 \
  attention=flash \
  dtype=bfloat16 \
  dataset_type=synthetic \
  profiler=xplane
Llama3-8b
Llama is a family of open-weights LLMs developed by Meta.
For an example of how to run Llama3 on PyTorch, see torch_xla models in the torchprime repository.
MaxDiffusion
MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python and JAX that run on XLA devices including Cloud TPUs and GPUs. Stable Diffusion is a latent text-to-image model that generates photo-realistic images from any text input.
You need to check out a specific Git commit to run MaxDiffusion, as shown in the following training script.
git clone https://github.com/google/maxdiffusion.git && \
cd maxdiffusion && \
git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 && \
pip install -r requirements.txt && \
pip install . && \
pip install huggingface_hub==0.30.2 && \
OUT_DIR=${BASE_OUTPUT_DIR} && \
python src/maxdiffusion/train_sdxl.py \
  src/maxdiffusion/configs/base_xl.yml \
  revision=refs/pr/95 \
  activations_dtype=bfloat16 \
  weights_dtype=bfloat16 \
  resolution=1024 \
  per_device_batch_size=1 \
  output_dir=${OUT_DIR} \
  jax_cache_dir=${OUT_DIR}/cache_dir/ \
  max_train_steps=200 \
  attention=flash \
  run_name=sdxl-ddp-v6e
Export the following variables:
export CLUSTER_NAME=CLUSTER_NAME
export ACCELERATOR_TYPE=ACCELERATOR_TYPE
export NUM_SLICES=NUM_SLICES
export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT
Environment variable descriptions
Variable | Description |
---|---|
CLUSTER_NAME | The name of your XPK cluster. |
ACCELERATOR_TYPE | The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions. |
NUM_SLICES | The number of TPU slices. |
YOUR_MODEL_SCRIPT | The model script to execute as a training command. |
Run the model using the script you created in the previous step. You must either specify the --base-docker-image flag to use the MaxText base image or specify the --docker-image flag and the image you want to use.
You can choose to add the following optional flags:
- You can enable debug logging by including the --enable-debug-logs flag. For more information, see Debug JAX on MaxText.
- You can create a Vertex AI Experiment to upload data to Vertex AI TensorBoard by including the --use-vertex-tensorboard flag. For more information, see Monitor JAX on MaxText using Vertex AI.
python3 xpk.py workload create \
  --cluster ${CLUSTER_NAME} \
  {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
  --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
  --tpu-type=${ACCELERATOR_TYPE} \
  --num-slices=${NUM_SLICES} \
  --on-demand \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --command="${YOUR_MODEL_SCRIPT}"
The output includes a link to follow your workload. Open the link and click the Logs tab to track your workload in real time.
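A multi-line model script with its own quotes and backslashes can be awkward to embed in `--command="..."` in the shell. If you launch workloads from Python, passing the command as an argument list sidesteps that quoting, because the whole script travels as one argument. The helper `xpk_workload_create_args` is a hypothetical sketch that uses only the flags shown in the command above.

```python
import os
import shlex


def xpk_workload_create_args(env: dict, model_script: str, docker_image: str = "") -> list:
    """Build the `xpk.py workload create` command as an argv list.

    If docker_image is empty, the MaxText base image flag is used instead.
    """
    workload = f"{env['USER']}-xpk-{env['ACCELERATOR_TYPE']}-{env['NUM_SLICES']}"
    if docker_image:
        image_flags = ["--docker-image", docker_image]
    else:
        image_flags = ["--base-docker-image", "maxtext_base_image"]
    return [
        "python3", "xpk.py", "workload", "create",
        "--cluster", env["CLUSTER_NAME"],
        *image_flags,
        f"--workload={workload}",
        f"--tpu-type={env['ACCELERATOR_TYPE']}",
        f"--num-slices={env['NUM_SLICES']}",
        "--on-demand",
        f"--zone={env['ZONE']}",
        f"--project={env['PROJECT_ID']}",
        # The whole script is a single argv element, so it needs no extra escaping.
        f"--command={model_script}",
    ]


if __name__ == "__main__":
    defaults = {
        "USER": "user",
        "CLUSTER_NAME": "xpk-cluster-name",
        "ACCELERATOR_TYPE": "v6e-256",
        "NUM_SLICES": "2",
        "ZONE": "us-east1-d",
        "PROJECT_ID": "your-project-id",
    }
    env = {key: os.environ.get(key, value) for key, value in defaults.items()}
    script = os.environ.get("YOUR_MODEL_SCRIPT", "python3 -m MaxText.train MaxText/configs/base.yml")
    print(shlex.join(xpk_workload_create_args(env, script)))
```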
Debug JAX on MaxText
Use supplemental XPK commands to diagnose why the cluster or workload isn't running:
- XPK workload list
- XPK inspector
- Enable verbose logging in your workload logs using the --enable-debug-logs flag when you create the XPK workload
Monitor JAX on MaxText using Vertex AI
To use TensorBoard, your Google Cloud user account must have the aiplatform.user role. Run the following command to grant this role:
gcloud projects add-iam-policy-binding your-project-id \
  --member='user:your-email' \
  --role='roles/aiplatform.user'
View scalar and profile data through the Vertex AI managed TensorBoard.
Request an increase in the resource management (CRUD) request quota for the zone you're using, from 600 to 5,000. This might not be an issue for small workloads that use fewer than 16 VMs.
Install dependencies such as cloud-accelerator-diagnostics for Vertex AI:
# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
cd ~/xpk
pip install .
Create your XPK cluster using the --create-vertex-tensorboard flag, as documented in Create Vertex AI TensorBoard. You can also run this command on existing clusters.
Create your Vertex AI experiment when running your XPK workload using the --use-vertex-tensorboard flag and the optional --experiment-name flag. For the full list of steps, see Create Vertex AI Experiment to upload data to Vertex AI TensorBoard.
The logs include a link to a Vertex AI TensorBoard, similar to the following:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
You can also find the Vertex AI TensorBoard link in the Google Cloud console. Go to Vertex AI Experiments in the Google Cloud console. Select the appropriate region from the drop-down.
The TensorBoard directory is also written to the Cloud Storage bucket that you specified with ${BASE_OUTPUT_DIR}.
Delete your XPK workload
Use the xpk workload delete command to delete one or more workloads based on the job prefix or job status. This command might be useful if you sent XPK workloads that no longer need to be run, or if you have jobs that are stuck in the queue.
Delete your XPK cluster
Use the xpk cluster delete command to delete your cluster:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
  --zone=${ZONE} --project=${PROJECT_ID}
MaxDiffusion benchmarking results
We ran the training script for MaxDiffusion on a v6e-4, a v6e-16, and two v6e-16 slices. The following table shows the measured throughputs.
Metric | v6e-4 | v6e-16 | Two v6e-16 |
---|---|---|---|
Step time (seconds) | 0.069 | 0.073 | 0.13 |
Global batch size | 8 | 32 | 64 |
Throughput (examples/sec) | 115.9 | 438.4 | 492.3 |
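Reading the first row of the table as the time per training step in seconds, the throughput row is the global batch size divided by the step time. The following sketch reproduces the published throughput figures from the other two rows; it uses only numbers from the table.

```python
# (step time in seconds, global batch size) from the benchmarking table.
RESULTS = {
    "v6e-4": (0.069, 8),
    "v6e-16": (0.073, 32),
    "two v6e-16": (0.13, 64),
}


def throughput(step_time_s: float, global_batch: int) -> float:
    """Examples per second = examples per step / seconds per step."""
    return global_batch / step_time_s


if __name__ == "__main__":
    for config, (step_time, batch) in RESULTS.items():
        print(f"{config:>11}: {throughput(step_time, batch):.1f} examples/sec")
```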
Train Llama models using PyTorch/XLA on Cloud TPU v6e
This section describes how to train Llama models with PyTorch/XLA on Cloud TPU v6e, using the WikiText dataset.
Get access to Hugging Face and the Llama 3 model
You need a Hugging Face user access token for this example. For information about creating user access tokens, see the Hugging Face documentation on user access tokens.
You also need permission to access the Llama-3-8B model on Hugging Face. To get access, go to the Meta-Llama-3-8B model on Hugging Face and request access.
Create a Cloud TPU VM
Create a Cloud TPU v6e with 8 chips for this example.
Set up environment variables:
export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=us-east1-d
export ACCELERATOR_TYPE=v6e-8
export RUNTIME_VERSION=v2-alpha-tpuv6e
Environment variable descriptions
Variable | Description |
---|---|
PROJECT_ID | Your Google Cloud project ID. Use an existing project or create a new one. |
TPU_NAME | The name of the TPU. |
ZONE | The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones. |
ACCELERATOR_TYPE | The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions. |
RUNTIME_VERSION | The Cloud TPU software version. |
Create a Cloud TPU VM:
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \
  --accelerator-type=${ACCELERATOR_TYPE} \
  --zone=${ZONE} \
  --project=${PROJECT_ID}
Installation
Install the pytorch-tpu/transformers fork of Hugging Face transformers and dependencies. This example was tested with the following dependency versions:
- torch: compatible with 2.5.0
- torch_xla[tpu]: compatible with 2.5.0
- jax: 0.4.33
- jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
  --project=${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
  cd transformers
  sudo pip3 install -e .
  pip3 install datasets
  pip3 install evaluate
  pip3 install scikit-learn
  pip3 install accelerate
  pip install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'
Set up model configuration files
The training command in the next section, Run the model, uses two JSON configuration files to define model parameters and Fully Sharded Data Parallel (FSDP) configuration. FSDP sharding lets you use a bigger batch size while training by sharding your model weights across multiple TPUs. When training with smaller models, it might be sufficient to use data parallelism and replicate the weights on each device. For more information about how to shard tensors across devices in PyTorch/XLA, see PyTorch/XLA SPMD user guide.
Create the model parameter configuration file. The following is the model parameter configuration for Llama-3-8B. For other models, find the configuration file on Hugging Face. For example, see the Llama-2-7B config.
cat > llama-config.json << EOF
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.0.dev0",
  "use_cache": false,
  "vocab_size": 128256
}
EOF
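As a sanity check on the configuration, you can count the parameters it implies. The following sketch assumes a standard Llama decoder layout: query, key, value, and output projections without biases (grouped-query attention, so key and value use num_key_value_heads), a gated MLP with three projections, two RMSNorm weight vectors per layer, a final norm, and untied input and output embeddings. With the values above it gives about 8.03 billion parameters, roughly 16 GB of bfloat16 weights. It's an estimate under those assumptions, not output from transformers.

```python
import json
import os

# The fields the count needs, copied from llama-config.json above.
LLAMA3_8B = {
    "hidden_size": 4096,
    "intermediate_size": 14336,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_hidden_layers": 32,
    "vocab_size": 128256,
    "tie_word_embeddings": False,
}


def llama_param_count(cfg: dict) -> int:
    """Estimate the parameters of a Llama-style decoder without biases."""
    h = cfg["hidden_size"]
    head_dim = h // cfg["num_attention_heads"]
    kv_dim = cfg["num_key_value_heads"] * head_dim  # grouped-query attention width

    attention = 2 * h * h + 2 * h * kv_dim   # q and o are h x h; k and v are h x kv_dim
    mlp = 3 * h * cfg["intermediate_size"]    # gate, up, and down projections
    norms = 2 * h                              # two RMSNorm weight vectors per layer
    per_layer = attention + mlp + norms

    embeddings = cfg["vocab_size"] * h
    lm_head = 0 if cfg.get("tie_word_embeddings", False) else cfg["vocab_size"] * h
    final_norm = h
    return cfg["num_hidden_layers"] * per_layer + embeddings + lm_head + final_norm


if __name__ == "__main__":
    path = "llama-config.json"
    if os.path.exists(path):
        with open(path) as f:
            cfg = json.load(f)
    else:
        cfg = LLAMA3_8B
    n = llama_param_count(cfg)
    print(f"{n:,} parameters, about {n * 2 / 1e9:.1f} GB in bfloat16")
```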
Create the FSDP configuration file:
cat > fsdp-config.json << EOF
{
  "fsdp_transformer_layer_cls_to_wrap": [
    "LlamaDecoderLayer"
  ],
  "xla": true,
  "xla_fsdp_v2": true,
  "xla_fsdp_grad_ckpt": true
}
EOF
For more information about FSDP, see Fully Sharded Data Parallel using SPMD.
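To see why sharding weights frees memory for bigger batches, compare the bytes of weights each chip holds when the weights are replicated versus split evenly across the 8 chips of a v6e-8. This back-of-the-envelope sketch assumes about 8.03 billion bfloat16 parameters for the Llama-3-8B configuration and ignores gradients, optimizer state, and activations, which add to the totals; the helper is illustrative, not part of PyTorch/XLA.

```python
BYTES_PER_PARAM = {"bfloat16": 2, "float32": 4}
LLAMA3_8B_PARAMS = 8_030_261_248  # assumed parameter count for the Llama-3-8B config


def weight_bytes_per_device(num_params: int, num_devices: int, dtype: str = "bfloat16") -> float:
    """Weight bytes per device when weights are split evenly across num_devices.

    num_devices=1 models full replication, as in plain data parallelism.
    """
    return num_params * BYTES_PER_PARAM[dtype] / num_devices


if __name__ == "__main__":
    replicated = weight_bytes_per_device(LLAMA3_8B_PARAMS, 1)
    sharded = weight_bytes_per_device(LLAMA3_8B_PARAMS, 8)  # v6e-8
    print(f"replicated: {replicated / 2**30:.2f} GiB of weights per chip")
    print(f"FSDP over 8 chips: {sharded / 2**30:.2f} GiB of weights per chip")
```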
Upload the configuration files to your Cloud TPU VMs using the following command:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
  --worker=all \
  --project=${PROJECT_ID} \
  --zone=${ZONE}
Run the model
Using the configuration files you created in the previous section, run the run_clm.py script to train the Llama-3-8B model on the WikiText dataset. The training script takes approximately 10 minutes to run on a Cloud TPU v6e-8.
Sign in to Hugging Face on your Cloud TPU using the following command:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
  --project=${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --command='
  pip3 install "huggingface_hub[cli]"
  huggingface-cli login --token HUGGING_FACE_TOKEN'
Run the model training:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
  --project=${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --command='
  export PJRT_DEVICE=TPU
  export XLA_USE_SPMD=1
  export ENABLE_PJRT_COMPATIBILITY=true
  # Optional variables for debugging:
  export XLA_IR_DEBUG=1
  export XLA_HLO_DEBUG=1
  export PROFILE_EPOCH=0
  export PROFILE_STEP=3
  export PROFILE_DURATION_MS=100000
  # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
  export PROFILE_LOGDIR=PROFILE_PATH
  python3 transformers/examples/pytorch/language-modeling/run_clm.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 16 \
    --do_train \
    --output_dir /home/$USER/tmp/test-clm \
    --overwrite_output_dir \
    --config_name /home/$USER/llama-config.json \
    --cache_dir /home/$USER/cache \
    --tokenizer_name meta-llama/Meta-Llama-3-8B \
    --block_size 8192 \
    --optim adafactor \
    --save_strategy no \
    --logging_strategy no \
    --fsdp "full_shard" \
    --fsdp_config /home/$USER/fsdp-config.json \
    --torch_dtype bfloat16 \
    --dataloader_drop_last yes \
    --flash_attention \
    --max_steps 20'
Troubleshooting PyTorch/XLA
If you set the optional variables for debugging in the previous section, the profile for the model is stored at the location specified by the PROFILE_LOGDIR variable. You can extract the xplane.pb file stored at this location and use tensorboard to view the profiles in your browser using the TensorBoard instructions.
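PROFILE_LOGDIR can accumulate several captures. If you copied it to your workstation (for a gs:// location, for example with `gcloud storage cp --recursive`), the following helper lists the xplane.pb files it contains, newest first, so you can pick the one to load into TensorBoard. It doesn't assume any particular directory layout below PROFILE_LOGDIR and simply searches recursively; the script name in the usage comment is hypothetical.

```python
import sys
from pathlib import Path


def find_xplane_files(logdir: str) -> list:
    """Return every file under logdir whose name ends in xplane.pb, newest first."""
    files = [p for p in Path(logdir).rglob("*xplane.pb") if p.is_file()]
    return sorted(files, key=lambda p: p.stat().st_mtime, reverse=True)


if __name__ == "__main__":
    # Hypothetical usage: python3 find_profiles.py /path/to/local/copy/of/PROFILE_LOGDIR
    if len(sys.argv) > 1:
        for path in find_xplane_files(sys.argv[1]):
            print(path)
```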
If PyTorch/XLA isn't performing as expected, see the Troubleshooting guide, which has suggestions for debugging, profiling, and optimizing your model.