Executar um cálculo em uma VM da Cloud TPU usando o JAX

Este documento fornece uma breve introdução sobre como trabalhar com JAX e Cloud TPU.

Antes de seguir este guia de início rápido, você precisa criar uma conta do Google Cloud Platform Google Cloud, instalar a Google Cloud CLI e configurar o comando gcloud. Para mais informações, consulte Configurar uma conta e um projeto do Cloud TPU.

Instalar a CLI do Google Cloud

A CLI do Google Cloud contém ferramentas e bibliotecas para interagir produtos e serviços do Google Cloud. Para mais informações, consulte Como instalar a Google Cloud CLI.

Configurar o comando gcloud

Execute os comandos a seguir para configurar gcloud para usar o projeto do Google Cloud e instalar os componentes necessários para a visualização da VM da TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project your-project-id

Ativar a API Cloud TPU

  1. Ative a API Cloud TPU usando o comando gcloud a seguir no Cloud Shell. Você também pode ativá-lo a partir da Console do Google Cloud).

    $ gcloud services enable tpu.googleapis.com
  2. Execute o comando a seguir para criar uma identidade de serviço.

    $ gcloud beta services identity create --service tpu.googleapis.com

Criar uma VM do Cloud TPU com gcloud

Com as VMs do Cloud TPU, seu modelo e código são executados diretamente no host da TPU máquina virtual. Conecte-se via SSH diretamente ao host da TPU. É possível executar códigos arbitrários, instalar pacotes, visualizar registros e depurar código diretamente no host da TPU.

  1. Crie sua VM da TPU executando o seguinte comando em um Cloud Shell ou o terminal do computador em que a Google Cloud CLI está instalado.

    (vm)$ gcloud compute tpus tpu-vm create tpu-name \
    --zone=us-central1-a \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base

    Campos obrigatórios

    zone
    A zona em que você planeja criar a Cloud TPU.
    accelerator-type
    O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores compatíveis com cada versão de TPU, consulte Versões de TPU.
    version
    A versão do software do Cloud TPU. Para todos os tipos de TPU, use tpu-ubuntu2204-base:

Conecte-se à VM do Cloud TPU

Use o comando a seguir para se conectar à VM da TPU:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central1-a

Campos obrigatórios

tpu_name
O nome da VM de TPU a que você está se conectando.
zone
A zona em que você criou o Cloud TPU.

Instalar o JAX na VM do Cloud TPU

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Verificação do sistema

Verifique se o JAX consegue acessar a TPU e executar operações básicas:

Inicie o interpretador do Python 3:

(vm)$ python3
>>> import jax

Veja o número de núcleos de TPU disponíveis:

>>> jax.device_count()

O número de núcleos de TPU é exibido. Se você estiver usando uma TPU v4, esse valor será 4. Caso esteja usando uma TPU v2 ou v3, use 8.

Faça um cálculo simples:

>>> jax.numpy.add(1, 1)

O resultado da adição "numpy" é exibido:

Resultado do comando:

Array(2, dtype=int32, weak_type=true)

Saia do interpretador Python:

>>> exit()

Como executar código do JAX em uma VM de TPU

Agora você pode executar qualquer código JAX que quiser. Os exemplos de flax são um ótimo lugar para começar a executar modelos padrão de ML no JAX. Por exemplo, para treinar uma rede convolucional básica MNIST:

  1. Instale dependências de exemplos do Flax

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. Instalar o FLAX

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. Execute o script de treinamento MNIST do FLAX

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

O script faz o download do conjunto de dados e inicia o treinamento. A saída do script será semelhante a esta:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Limpar

Quando terminar de usar a VM de TPU, siga estas etapas para limpar os recursos.

  1. Desconecte-se da instância do Compute Engine, caso ainda não tenha feito isso:

    (vm)$ exit
  2. Exclua o Cloud TPU.

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone=us-central1-a
  3. Execute o seguinte comando para verificar se os recursos foram excluídos. Verifique se a TPU não está mais listada. A exclusão pode levar vários minutos.

    $ gcloud compute tpus tpu-vm list \
      --zone=us-central1-a

Notas de desempenho

Veja alguns detalhes importantes que são relevantes principalmente para usar TPUs no JAX.

Preenchimento

Uma das causas mais comuns do desempenho lento em TPUs é o preenchimento involuntário:

  • As matrizes no Cloud TPU estão em blocos. Isso envolve o preenchimento de uma das dimensões em um múltiplo de 8 e de uma dimensão diferente em um múltiplo de 128.
  • A unidade de multiplicação de matriz tem um melhor desempenho com pares de matrizes grandes que minimizam a necessidade de preenchimento.

bfloat16 dtype

Por padrão, a multiplicação de matrizes do JAX em TPUs usa bfloat16 com acumulação float32. Isso pode ser controlado com o argumento de precisão em chamadas de função jax.numpy relevantes (matmul, dot, einsum etc.). Especificamente:

  • precision=jax.lax.Precision.DEFAULT: usa bfloat16 misto precisão (mais rápido)
  • precision=jax.lax.Precision.HIGH: usa vários passes MXU para aumentar a precisão
  • precision=jax.lax.Precision.HIGHEST: usa ainda mais passes MXU. para alcançar a precisão total de float32

O JAX também adiciona o dtype bfloat16, que pode ser usado para transmitir matrizes explicitamente bfloat16, por exemplo: jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Como executar o JAX em um Colab

Quando você executa o código JAX em um bloco do Colab, o Colab automaticamente cria um nó legado de TPU. Os nós de TPU têm uma arquitetura diferente. Para mais informações, consulte Arquitetura do sistema.

A seguir

Para mais informações sobre o Cloud TPU, consulte: