Cloud TPU v5p training

Cloud TPU v5p is Google Cloud's fifth-generation Cloud TPU and the successor to the v4 TPU. v5p is optimized for large-scale training and is designed as a leading platform for developing foundational LLMs, diffusion models, and generative AI. At a high level, v5p provides up to 2x the performance of v4 while packing 2x more TPUs into a Pod (6k in the largest slice versus 3k in v4), resulting in up to 4x the performance at the Pod level. It also runs at a higher clock frequency (1.75 GHz versus 1.05 GHz), adds SparseCore for large-scale embeddings, and triples High Bandwidth Memory (HBM) capacity.

Cloud TPU v5p concepts

If you are new to Cloud TPUs, check out the TPU documentation home.

Cloud TPU concepts (for example, slices, hosts, and TensorCores) and Cloud TPU system architecture for all Cloud TPU versions are described in the Cloud TPU System Architecture page.

Each Cloud TPU version requires specific accelerator types for training or inference. These accelerator types are described in v5p configurations.

Manage TPU resources

The commands you can use to manage your TPU VMs are described in Managing TPUs. For managing queued resources, see the Queued resources user guide.

Framework Setup

This section describes the general setup process for model training using JAX or PyTorch with TPU v5p.

Setup for JAX

If you have slice shapes greater than 4 chips, you will have multiple VMs in one slice. In this case, you need to use the --worker=all flag to run the installation on all TPU VMs using a single command:

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Run the following command to check the number of devices (the outputs shown here were produced with a v5p-32 slice). The command verifies the installation by checking that JAX sees the Cloud TPU TensorCores and can run basic operations:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

The output will be similar to the following:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() shows the total number of chips in the given slice. jax.local_device_count() indicates the count of chips accessible by a single VM in this slice.
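
The counts in the sample output follow from how v5p accelerator types are named. The following is a minimal sketch, assuming that the numeric suffix of the accelerator type counts TensorCores, that each v5p chip has two TensorCores, and that each TPU VM hosts four chips. The helper name slice_layout is hypothetical and is not part of JAX or any Cloud TPU API.

```python
# Hypothetical helper: derive chip and VM counts from an accelerator type.
# Assumptions (not stated in this section): the suffix of "v5p-N" counts
# TensorCores, a v5p chip has 2 TensorCores, and each TPU VM hosts 4 chips.
TENSORCORES_PER_CHIP = 2
CHIPS_PER_VM = 4


def slice_layout(accelerator_type: str) -> dict:
    """Return the expected chip and VM counts for a v5p accelerator type."""
    version, _, suffix = accelerator_type.partition("-")
    if version != "v5p" or not suffix.isdigit():
        raise ValueError(f"not a v5p accelerator type: {accelerator_type!r}")
    chips = int(suffix) // TENSORCORES_PER_CHIP
    vms = max(1, chips // CHIPS_PER_VM)
    return {
        "chips": chips,                # compare with jax.device_count()
        "vms": vms,                    # compare with the number of SSH workers
        "chips_per_vm": chips // vms,  # compare with jax.local_device_count()
    }


print(slice_layout("v5p-32"))
```

Under these assumptions, a v5p-32 slice maps to 16 chips on 4 VMs with 4 chips each, which agrees with the 16 and 4 printed by every worker in the sample output.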

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

The output will be similar to the following:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

Use --node=all to run the command on all Multislice workers.

gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

Try the JAX tutorials in this document to get started with v5p training using JAX.

Setup for PyTorch

The PJRT runtime is the only supported runtime for v5p, and PyTorch 2.1+ uses PJRT as the default runtime for all TPU versions. This section describes how to start using PJRT on v5p Pods with PyTorch/XLA 2.2.0 for all workers.

Install dependencies

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='
sudo apt-get update
sudo apt-get install libopenblas-dev -y
pip3 install numpy
pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
'

Validate the installation by running a Python script with PJRT that lists the available TPU devices (the outputs shown here were produced with a v5p-32 slice).

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} --zone ${ZONE} --worker=all \
--command='
PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"
'

The output will be similar to the following:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']
['xla:0', 'xla:1', 'xla:2', 'xla:3']

Use --node=all to run the command on all Multislice workers.

gcloud compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --node=all --worker=all \
--command='
PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"
'

Try the PyTorch tutorials in this document to get started with v5p training using PyTorch.

Monitor and profile

Cloud TPU v5p supports monitoring and profiling using the same methods as previous generations of Cloud TPU. You can read Profile your model with Cloud TPU tools to learn more about profiling and Monitoring Cloud TPU VMs to learn more about monitoring.

Training tutorials

This section focuses on training tutorials for a single slice. To adapt these tutorials to Multislice training, add the --node=all flag to the SSH commands. For details and best practices, refer to the Multislice introduction.

JAX tutorials

Train Diffusion 2.1

This tutorial shows you how to train the Stable Diffusion model from HuggingFace using the Pokémon dataset on Cloud TPU v5p.

The Stable Diffusion model is a latent text-to-image model that generates photo-realistic images from any text input.

Set up

  1. Set up a storage bucket for your model output.

    gcloud storage buckets create gs://your_bucket \
     --project=your_project \
     --location=us-east5
    
  2. Create environment variables:

    export GCS_BUCKET_NAME=your-bucket
    export PROJECT_ID=your-project-ID
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export SERVICE_ACCOUNT=your-service-account
    export TPU_NAME=your-tpu-name
    export QUEUED_RESOURCE_ID=your-qr-name
    export QUOTA_TYPE=spot
    export VALID_UNTIL_DURATION=1d
    

    Command flag descriptions

    Variable Description
    PROJECT_ID Google Cloud Project Name
    ACCELERATOR_TYPE See the TPU versions page for your TPU version.
    ZONE See the TPU regions and zones document for the supported zones.
    RUNTIME_VERSION For v5p use v2-alpha-tpuv5 for the RUNTIME_VERSION.
    SERVICE_ACCOUNT This is the address of your service account that you can find in Google Cloud console -> IAM -> Service Accounts. For example: tpu-service-account@myprojectID.iam.gserviceaccount.com
    TPU_NAME The user-assigned text ID of the TPU which is created when the queued resource request is allocated.
    QUEUED_RESOURCE_ID The user-assigned text ID of the queued resource request. See the Queued Resources document for information about queued resources.
    QUOTA_TYPE Can be reserved or spot. If neither is specified, QUOTA_TYPE defaults to on-demand. See quotas for information on the different types of quotas supported by Cloud TPU.
    VALID_UNTIL_DURATION The duration for which the request is valid. See Queued resources for information about the different valid durations.
  3. Create a TPU resource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --valid-until-duration ${VALID_UNTIL_DURATION} \
    --service-account ${SERVICE_ACCOUNT} \
    --${QUOTA_TYPE}
    

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state. Check the state of your queued resource by running the following command:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
    --project ${PROJECT_ID} --zone ${ZONE}
    

    When the queued resource is in the ACTIVE state, the output will be similar to the following:

    state: ACTIVE
    
  4. Train the model

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --project $PROJECT_ID --worker=all --command="
    git clone https://github.com/google/maxdiffusion
    cd maxdiffusion
    git reset --hard 57629bcf4fa32fe5a57096b60b09f41f2fa5c35d # This identifies the GitHub commit to use.
    pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # Install the latest version of JAX
    pip3 install -r requirements.txt
    pip3 install .
    export LIBTPU_INIT_ARGS=""
    python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run base_output_directory=gs://$GCS_BUCKET_NAME enable_profiler=False"
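
Checking for the ACTIVE state in step 3 can also be scripted. The following is a rough sketch, not an official tool: the function names parse_state, gcloud_describe, and wait_until_active are hypothetical, the parser assumes the describe output contains a line of the form "state: VALUE" as in the sample output, and the canned strings in the demonstration are illustrative only.

```python
import re
import subprocess
import time
from typing import Callable


def parse_state(describe_output: str) -> str:
    """Return the last `state: VALUE` entry in the describe output, or ''."""
    matches = re.findall(r"^[ \t]*state:[ \t]*(\w+)[ \t]*$",
                         describe_output, flags=re.MULTILINE)
    return matches[-1] if matches else ""


def gcloud_describe(queued_resource_id: str, project: str, zone: str) -> str:
    """Run the describe command from step 3 and return its standard output."""
    result = subprocess.run(
        ["gcloud", "compute", "tpus", "queued-resources", "describe",
         queued_resource_id, "--project", project, "--zone", zone],
        capture_output=True, text=True, check=True)
    return result.stdout


def wait_until_active(describe: Callable[[], str],
                      poll_seconds: float = 30.0,
                      max_polls: int = 120) -> bool:
    """Poll until the state is ACTIVE; give up on FAILED or after max_polls."""
    for _ in range(max_polls):
        state = parse_state(describe())
        if state == "ACTIVE":
            return True
        if state == "FAILED":
            return False
        time.sleep(poll_seconds)
    return False


# Offline demonstration with canned output instead of a real gcloud call.
canned = iter(["state:\n  state: WAITING_FOR_RESOURCES\n",
               "state:\n  state: ACTIVE\n"])
print(wait_until_active(lambda: next(canned), poll_seconds=0))
```

To poll a real request, pass a callable such as lambda: gcloud_describe(qr_id, project, zone), where the three arguments hold the values of QUEUED_RESOURCE_ID, PROJECT_ID, and ZONE.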
    

Clean up

Delete your TPU and queued resource request at the end of your session or to remove queued resource requests that are in the "FAILED" state. To delete a queued resource, delete the slice(s) and then the queued resource request in 2 steps:

   gcloud compute tpus tpu-vm delete ${TPU_NAME} --project=${PROJECT_ID} \
      --zone=${ZONE} --quiet
   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet

Or, use --force to delete the slice(s) and the queued resource request in a single step:

# With --force
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --quiet --force

Benchmark results

The Stable Diffusion training script ran on v5p-8, v5p-32, and v5p-128. The following table shows the throughput.

                            v5p-8    v5p-32   v5p-128
Train steps                 150      150      150
Global batch size           32       64       64
Throughput (examples/sec)   12.10    18.08    19.10
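
To compare the three slice sizes on an equal footing, the throughput can be normalized per chip. The following is a small sketch that uses only the numbers in the table above. The chip counts are an assumption: they take the accelerator type suffix to count TensorCores at two TensorCores per chip.

```python
# Values copied from the benchmark table above. Chip counts assume the
# v5p-N suffix counts TensorCores at two TensorCores per chip.
results = {
    "v5p-8": {"chips": 4, "global_batch": 32, "examples_per_sec": 12.10},
    "v5p-32": {"chips": 16, "global_batch": 64, "examples_per_sec": 18.08},
    "v5p-128": {"chips": 64, "global_batch": 64, "examples_per_sec": 19.10},
}

# Throughput divided by the number of chips in the slice.
per_chip = {
    name: r["examples_per_sec"] / r["chips"] for name, r in results.items()
}

for name, value in per_chip.items():
    batch_per_chip = results[name]["global_batch"] / results[name]["chips"]
    print(f"{name}: {value:.2f} examples/sec/chip, "
          f"{batch_per_chip:g} examples per chip per step")
```

Per-chip throughput falls from roughly 3.0 to 1.1 to 0.3 examples/sec as the slice grows. One plausible reading is that the global batch size stays at 64 while the chip count grows, so each chip receives fewer examples per step.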

MaxText

This tutorial shows you how to train the MaxText model using a synthetic dataset on Cloud TPU.

MaxText is a high performance, arbitrarily scalable, open source, well-tested LLM written in pure Python/JAX targeting Cloud TPUs. MaxText empowers researchers and developers with an accessible and adaptable tool for advancing the frontiers of Natural Language Processing (NLP) research and development.

Before running this tutorial, you need to set up your Cloud TPU environment.

  1. Set up environment variables

    export PROJECT_ID=your_project_ID
    export TPU_NAME=your_tpu_name # user defined TPU name
    export ACCELERATOR_TYPE=v5p-256
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export RUN_NAME=your_experiment_run_name # user defined name for this run
    export GCS_BUCKET_NAME=your_bucket_name # Output cloud folder. Should start with gs://
    export MAXTEXT_OUTPUT_PATH=${GCS_BUCKET_NAME}/your_experiment_output_path
    export NUM_SLICES=1 # Update the value to a number >1 for Multislice.
    

    Command flag descriptions

    Variable Description
    PROJECT_ID Google Cloud Project Name
    TPU_NAME A user defined name for your TPU.
    ACCELERATOR_TYPE See the TPU versions page for your TPU version.
    ZONE See the TPU regions and zones document for the supported zones.
    RUNTIME_VERSION For v5p use v2-alpha-tpuv5 for the runtime version.
    RUN_NAME User supplied experiment run name.

    Optional setup recommended for Multislice:

    export NETWORK_NAME=your_network_name
    export FIREWALL_RULE_NAME=your_firewall_rule_name
    

    If you're running Multislice workloads and want optimal network performance, consider creating a dedicated network with a Maximum Transmission Unit (MTU) of 8896 bytes and configuring appropriate firewall rules. While optional, this step can significantly improve performance, especially when scaling up the number of slices over the data-center network (DCN). Note that creating a network requires the compute.networks.create permission in the project. The following examples show how to create a dedicated network and firewall rule.

    Create a dedicated network:

    gcloud compute networks create ${NETWORK_NAME} \
    --mtu=8896 \
    --project=${PROJECT_ID} \
    --subnet-mode=auto \
    --bgp-routing-mode=regional
    

    Create a firewall rule:

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT_ID}
    
  2. Clone the MaxText repository

    git clone https://github.com/google/maxtext.git
    
  3. Train the model

    The following sections describe two options for training MaxText.

    Option 1

    If you want a script to manage the entire workflow, from provisioning Cloud TPUs and installing dependencies to running your model and tearing down resources, you can use multihost_job.py.

    # --RUN_NAME is a user-defined run name; --BUCKET_NAME is used to store logs and configs.
    cd maxtext && python3 multihost_job.py --PROJECT=${PROJECT_ID} --ZONE=${ZONE} \
    --NUM_SLICES=${NUM_SLICES} --TPU_TYPE=${ACCELERATOR_TYPE} \
    --VERSION=${RUNTIME_VERSION} --RUN_NAME=${RUN_NAME} \
    --BUCKET_NAME=${GCS_BUCKET_NAME} \
    --COMMAND="bash setup.sh && bash MaxText/configs/experimental/64b.sh RUN_NAME=${RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"
    

    After initiating the script, you should see a message similar to the following in the log. The log location is referenced in the output message. Click the first link to access the logs of all workers once TPU provisioning is complete.

    ------------------------------------
    
    multihost_job finished running, TPUs are starting up to run your job remotely.
    
    Logs for your job are displayed here:
    https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22_log%22%2529;?project=PROJECT_ID
    
    To see the output of a single host, you may edit the slice and worker
    number in the `log_file_path` property here:
    
    https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22RUN_NAME_log%22%2529%20AND%0Alabels.%22agent.googleapis.com%2Flog_file_path%22%3D%20%22%2FRUN_NAME%2Fmain_command_log_slice_0_worker_0%22;?project=PROJECT_ID
    
    When your job is finished, the main command log is in your Cloud Storage
    bucket:
    https://console.cloud.google.com/storage/browser/YOUR_BUCKET_NAME/RUN_NAME?project=PROJECT_ID
    
    View the status of the created TPUs using:
    gcloud compute tpus queued-resources list --filter=RUN_NAME
    --zone=ZONE --project=PROJECT_ID
    
Option 2

To run the training script multiple times on an already provisioned Cloud TPU, use the multihost_runner.py script.

  1. Set up variables to create a TPU.

    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your_queued_resource_id
    export VALID_DURATION=1d
    export QUOTA_TYPE=quota_type
    
    # For Multislice, the create command in the next step also accepts:
    #   --node-count ${NODE_COUNT}
    #   --node-prefix ${NODE_PREFIX}   (optional, the default is QUEUED_RESOURCE_ID)
    
  2. Create a TPU resource.

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --valid-until-duration ${VALID_DURATION} \
    --service-account ${SERVICE_ACCOUNT} \
    --${QUOTA_TYPE}
    

    You will be able to connect to your TPU VMs using SSH once your queued resource is in the ACTIVE state.

    Use the describe command to query the status of your queued resource.

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project ${PROJECT_ID} --zone ${ZONE}
    

    When your queued resource is in the ACTIVE state, the output will be similar to the following:

     state: ACTIVE
    
  3. Connect to your TPU using SSH

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  4. Install dependencies

    export TPU_NAME=your_tpu_name
    export MAXTEXT_OUTPUT_PATH=output-path
    
    cd maxtext && python3 multihost_runner.py --TPU_PREFIX=${TPU_NAME} \
    --COMMAND='bash setup.sh'
    
  5. Run the model with one of the configuration scripts, such as 32b.sh or 64b.sh. If you are running the script from a TPU VM, you need to add the flag --INTERNAL_IP=true.

    python3 multihost_runner.py --TPU_PREFIX=${TPU_NAME} \
    --COMMAND="bash MaxText/configs/experimental/64b.sh RUN_NAME=${RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} PLATFORM=gce"
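
If you launch the runner from a Python wrapper rather than a shell, one way to avoid shell quoting issues is to assemble the remote command as a single-line string and pass the arguments as a list. The following is a sketch under that assumption. The function name runner_argv and the placeholder values are hypothetical; only the multihost_runner.py flags shown in step 5 are used.

```python
import shlex


def runner_argv(tpu_prefix: str, run_name: str, output_path: str,
                config_script: str = "MaxText/configs/experimental/64b.sh"):
    """Build the argument list for the multihost_runner.py call in step 5."""
    # Join the remote command with single spaces so it contains no newlines.
    remote_command = " ".join([
        "bash", config_script,
        f"RUN_NAME={run_name}",
        f"OUTPUT_PATH={output_path}",
        "PLATFORM=gce",
    ])
    return [
        "python3", "multihost_runner.py",
        f"--TPU_PREFIX={tpu_prefix}",
        f"--COMMAND={remote_command}",
    ]


# Placeholder values; substitute your own TPU name, run name, and output path.
argv = runner_argv("your_tpu_name", "your_experiment_run_name", "output-path")

# shlex.join shows the equivalent, correctly quoted shell command.
print(shlex.join(argv))
```

The list can be handed to subprocess.run(argv, check=True) from inside the maxtext directory, which bypasses the local shell entirely.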
    

Clean up

Delete your TPU and queued resources.

Benchmark results

The MaxText training script was run with model sizes from 32B to 1160B parameters at bf16 precision. The results of these runs are shown in the following table.

No. of params   Accelerator type   TFLOP/chip/sec   Model flops utilization (MFU)
32B             v5p-128            3.28E+02         71.47%
64B             v5p-128            3.23E+02         70.31%
128B            v5p-256            3.15E+02         68.68%
128B            v5p-512            3.15E+02         68.53%
256B            v5p-1024           3.16E+02         68.82%
512B            v5p-1024           2.94E+02         63.99%
1024B           v5p-2048           2.94E+02         64.05%
1024B           v5p-4096           2.97E+02         64.80%
1160B           v5p-7680           2.95E+02         64.27%
1160B           v5p-12288          3.04E+02         66.23%
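
The two result columns are related. The following sketch assumes the usual definition of MFU as achieved FLOP/s divided by peak FLOP/s, and back-computes the per-chip peak implied by three rows copied from the table. The helper name implied_peak_tflops is hypothetical.

```python
# (parameters, accelerator type, TFLOP/chip/sec, reported MFU) copied from
# three rows of the table above.
rows = [
    ("32B", "v5p-128", 328.0, 0.7147),
    ("64B", "v5p-128", 323.0, 0.7031),
    ("1160B", "v5p-12288", 304.0, 0.6623),
]


def implied_peak_tflops(achieved: float, mfu: float) -> float:
    """Peak per-chip TFLOP/s implied by the assumed MFU = achieved / peak."""
    return achieved / mfu


peaks = [implied_peak_tflops(tflops, mfu) for _, _, tflops, mfu in rows]
for (params, accel, _, _), peak in zip(rows, peaks):
    print(f"{params} on {accel}: implied peak ~{peak:.0f} TFLOP/chip/sec")
```

All three rows imply a peak of about 459 TFLOP/chip/sec, so the same check can be used to spot a row whose two columns disagree.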

The 256B parameter model has been tested on v5p-512 and v5p-1024 using both bf16 and int8 weights. The following table displays the results of these tests.

                                v5p-512    v5p-512    v5p-1024   v5p-1024
Global batch size (tokens)      5.24E+05   5.24E+05   1.05E+06   1.05E+06
Precision                       bf16       int8       bf16       int8
TFLOP/chip/sec                  307        408        308        414
Model flops utilization (MFU)   66.98%     88.85%     67.09%     90.23%
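
The gain from int8 can be read off the table directly. The following sketch computes the int8-to-bf16 ratio for each slice size and, assuming MFU is achieved FLOP/s divided by peak FLOP/s, the peak that each MFU figure implies. All numbers are copied from the table above.

```python
# Values copied from the 256B table above.
runs = [
    # (accelerator type, precision, TFLOP/chip/sec, reported MFU)
    ("v5p-512", "bf16", 307, 0.6698),
    ("v5p-512", "int8", 408, 0.8885),
    ("v5p-1024", "bf16", 308, 0.6709),
    ("v5p-1024", "int8", 414, 0.9023),
]

tflops = {(accel, prec): t for accel, prec, t, _ in runs}

# Ratio of int8 to bf16 throughput on the same slice size.
speedup = {
    accel: tflops[(accel, "int8")] / tflops[(accel, "bf16")]
    for accel in ("v5p-512", "v5p-1024")
}

# Denominator implied by each MFU figure (assumes MFU = achieved / peak).
implied_peak = {(accel, prec): t / mfu for accel, prec, t, mfu in runs}

for accel, ratio in speedup.items():
    print(f"{accel}: int8 is {ratio:.2f}x bf16 in TFLOP/chip/sec")
for key, peak in implied_peak.items():
    print(f"{key}: MFU denominator ~{peak:.0f} TFLOP/chip/sec")
```

The implied denominators are nearly identical for both precisions, which suggests that the int8 MFU figures are reported against the same peak as the bf16 figures.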

TensorFlow tutorials

Train ResNet on a single-host v5p

This tutorial describes how to train ResNet on a v5p-8 TPU using a fake ImageNet dataset. If you want to use a different dataset, refer to Preparing the dataset.

Set up

  1. Create environment variables:

    export PROJECT_ID=your-project-ID
    export ACCELERATOR_TYPE=v5p-8
    export ZONE=us-east1-c
    export RUNTIME_VERSION=tpu-vm-tf-2.17.0-pjrt
    export TPU_NAME=your-tpu-name
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export QUOTA_TYPE=quota-type
    

    For this tutorial, use v5p-8 as the ACCELERATOR_TYPE.

  2. Create a TPU resource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --${QUOTA_TYPE}
    

    You will be able to connect to your TPU VM using SSH once your queued resource is in the ACTIVE state. To check the state of your queued resource, use the following command:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  3. Connect to your TPU using SSH

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  4. Set some environment variables

    export MODELS_REPO=/usr/share/tpu/models
    export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
    export MODEL_DIR=gcp-directory-to-store-model
    export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
    export NEXT_PLUGGABLE_DEVICE_USE_C_API=true
    export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
    
  5. Change to the models repository directory and install requirements.

    cd ${MODELS_REPO} && git checkout r2.15.0
    pip install -r official/requirements.txt
    

Train the model

  1. Run the training script.

    python3 official/vision/train.py \
      --tpu=local \
      --experiment=resnet_imagenet \
      --mode=train_and_eval \
      --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
      --model_dir=${MODEL_DIR} \
      --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
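
The --params_override value is a comma-separated list of dotted keys and values. When you change several settings at once, it can be easier to keep them in a dictionary and generate the string. The following is a minimal sketch; the helper name build_params_override is hypothetical and the keys are the ones used in the command above.

```python
def build_params_override(overrides: dict) -> str:
    """Join dotted-key overrides into the comma-separated form used above."""
    return ",".join(f"{key}={value}" for key, value in overrides.items())


data_dir = "gs://cloud-tpu-test-datasets/fake_imagenet"
overrides = {
    "runtime.distribution_strategy": "tpu",
    "task.train_data.input_path": f"{data_dir}/train*",
    "task.validation_data.input_path": f"{data_dir}/validation*",
    "task.train_data.global_batch_size": 2048,
    "task.validation_data.global_batch_size": 2048,
    "trainer.train_steps": 100,
}

params_override = build_params_override(overrides)
print(params_override)

# Training examples consumed by this run: batch size times number of steps.
examples_seen = (overrides["task.train_data.global_batch_size"]
                 * overrides["trainer.train_steps"])
print(examples_seen)
```

With a global batch size of 2048 and 100 training steps, the run processes 204,800 training examples, so it is a short smoke test rather than a full training run.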
    

Clean up

Delete your TPU and queued resources.

Train ResNet on a multi-host v5p

This tutorial describes how to train ResNet on v5p-16 or larger using a fake ImageNet dataset. If you want to use a different dataset, see Preparing the dataset.

  1. Create environment variables:

    export PROJECT_ID=your_project_ID
    export TPU_NAME=your_tpu_name
    export ZONE=us-east1-c
    export ACCELERATOR_TYPE=v5p-16
    export RUNTIME_VERSION=tpu-vm-tf-2.17.0-pod-pjrt
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export QUOTA_TYPE=quota-type
    

    ACCELERATOR_TYPE can be v5p-16 or larger.

  2. Create a TPU resource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --${QUOTA_TYPE}
    

    You will be able to connect to your TPU VM using SSH once your queued resource is in the ACTIVE state.

    Use the describe command to query the status of your queued resource:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  3. Connect to your TPU (worker zero) using SSH

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
      --project ${PROJECT_ID} \
      --zone ${ZONE}
    
  4. Set some environment variables

    export TPU_NAME=your_tpu_name
    export MODELS_REPO=/usr/share/tpu/models
    export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
    export MODEL_DIR=gcp-directory-to-store-model
    export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
    export TPU_LOAD_LIBRARY=0
    
  5. Change to the models repository directory and install requirements.

    cd $MODELS_REPO && git checkout r2.15.0
    pip install -r official/requirements.txt
    

Train the model

  1. Run the training script.

    python3 official/vision/train.py \
      --tpu=${TPU_NAME} \
      --experiment=resnet_imagenet \
      --mode=train_and_eval \
      --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
      --model_dir=${MODEL_DIR} \
      --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
    

Clean up

Delete your TPU and queued resources.

PyTorch/XLA

Llama 2

This tutorial shows how to train the Llama 2 7B model on v5p using a fork of the HuggingFace repository on PyTorch/XLA with General and Scalable Parallelization for ML Computation Graphs (GSPMD).

Setup

  1. Create variables for project ID, accelerator type, zone, runtime version, and TPU name.

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5p-8
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your_queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_DURATION=1d
    
  2. Create a TPU resource

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE} \
    --accelerator-type ${ACCELERATOR_TYPE} \
    --runtime-version ${RUNTIME_VERSION} \
    --valid-until-duration ${VALID_DURATION} \
    --service-account ${SERVICE_ACCOUNT} \
    --${QUOTA_TYPE}
    

    You will be able to connect to your TPU VM using SSH once your queued resource is in the ACTIVE state.

    Use the describe command to query the status of your queued resource.

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project ${PROJECT_ID} \
    --zone ${ZONE}
    

    When your queued resource is in the ACTIVE state, the output will be similar to the following:

     state: ACTIVE
    

  3. Install PyTorch/XLA and required dependencies.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
    --project ${PROJECT_ID} \
    --zone  ${ZONE} \
    --worker=all \
    --command='
    sudo apt-get update
    sudo apt-get install libopenblas-dev -y
    pip3 install numpy
    pip3 install typing-extensions
    pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
    '
    
  4. Download the HuggingFace repository and install requirements.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    git clone -b llama2-google-next-training https://github.com/pytorch-tpu/transformers.git
    cd transformers
    pip3 install git+file://$PWD
    pip3 install datasets accelerate evaluate scikit-learn'
    
  5. Download the 7B model config.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command="curl https://huggingface.co/TheBloke/Llama-2-7B-fp16/raw/main/config.json --output ~/config.json"
    
  6. Train the model

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
    export PJRT_DEVICE=TPU
    export XLA_USE_BF16=1
    export XLA_IR_DEBUG=1
    export XLA_HLO_DEBUG=1
    
    export LIBTPU_INIT_ARGS="--xla_enable_async_collective_permute=true
    --xla_tpu_enable_async_collective_fusion_multiple_steps=true
    --xla_tpu_enable_async_collective_fusion=true
    --xla_tpu_overlap_compute_collective_tc=true
    --xla_enable_async_all_gather=true
    --xla_jf_spmd_threshold_for_windowed_einsum_mib=0"
    
    export PROFILE_EPOCH=0
    export PROFILE_STEP=3
    export PROFILE_DURATION_MS=20000
    export PROFILE_LOGDIR=/tmp/home/
    
    cd transformers
    python examples/pytorch/language-modeling/run_clm.py \
     --tokenizer_name hf-internal-testing/llama-tokenizer \
     --dataset_name wikitext \
     --dataset_config_name wikitext-2-raw-v1 \
     --per_device_train_batch_size 96 \
     --per_device_eval_batch_size 8 \
     --num_train_epochs 1 \
     --do_train \
     --output_dir /tmp/output \
     --overwrite_output_dir \
     --config_name ~/config.json \
     --save_strategy no \
     --logging_strategy no \
     --remove_unused_columns no \
     --optim adafactor \
     --torch_dtype bfloat16 \
     --dataloader_drop_last yes \
     --block_size 2048 \
     --spmd_2d_sharding 1 \
     --spmd_grad_chkpt
    '
    

If you're running in a Multislice environment, you need to set the flag --spmd_dcn_parallelism to the number of slices.

The SPMD_USER_GUIDE provides a more in-depth guide that explains the environment variables and toggles of the HF script. Note that the LIBTPU_INIT_ARGS settings will be incorporated into PyTorch/XLA and enabled by default in future releases.

Clean up

Delete your TPU and queued resources.

Benchmark results

Results for all three Llama 2 model sizes are included in the following table.

                                v5p-8     v5p-128   v5p-128
Model size                      7B        13B       70B
Global batch size               96        1024      128
Sharding mesh shape             (4, 1)    (64, 1)   (16, 4)
Model flops utilization (MFU)   56.67%    55.80%    51.85%
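
Each sharding mesh shape in the table multiplies out to the number of chips in its slice. The following sketch checks that relationship, assuming the accelerator type suffix counts TensorCores at two TensorCores per chip. The helper name chips_in is hypothetical.

```python
from math import prod

# Assumption: "v5p-N" counts TensorCores and a v5p chip has 2 TensorCores.
TENSORCORES_PER_CHIP = 2

# (model size, accelerator type, sharding mesh shape) copied from the table.
runs = [
    ("7B", "v5p-8", (4, 1)),
    ("13B", "v5p-128", (64, 1)),
    ("70B", "v5p-128", (16, 4)),
]


def chips_in(accelerator_type: str) -> int:
    """Number of chips in a slice under the naming assumption above."""
    return int(accelerator_type.split("-")[1]) // TENSORCORES_PER_CHIP


mesh_matches = {}
for model, accel, mesh in runs:
    # The mesh covers the slice exactly when its axes multiply to the chip count.
    mesh_matches[model] = prod(mesh) == chips_in(accel)
    print(f"{model} on {accel}: mesh {mesh} covers {prod(mesh)} devices, "
          f"slice has {chips_in(accel)} chips")
```

When you pick a mesh shape for a different slice size, the same check is a quick way to confirm that the mesh uses every chip in the slice.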

Support and Feedback

We welcome all feedback! To share feedback or request support, fill out the Cloud TPU Support or Feedback form.