Executar um cálculo em uma VM do Cloud TPU usando o JAX

Este documento apresenta uma breve introdução sobre como trabalhar com o JAX e o Cloud TPU.

Antes de começar

Antes de executar os comandos neste documento, crie uma conta do Google Cloud, instale a CLI do Google Cloud e configure o comando gcloud. Para mais informações, consulte Configurar o ambiente do Cloud TPU.

Criar uma VM do Cloud TPU usando gcloud

  1. Defina algumas variáveis de ambiente para facilitar o uso dos comandos.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east5-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    Descrições de variáveis de ambiente

    Variável Descrição
    PROJECT_ID O ID do projeto do Google Cloud . Use um projeto atual ou crie um novo.
    TPU_NAME O nome da TPU.
    ZONE A zona em que a VM de TPU será criada. Para mais informações sobre as zonas disponíveis, consulte Zonas e regiões de TPU.
    ACCELERATOR_TYPE O tipo de acelerador especifica a versão e o tamanho do Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores aceitos por cada versão de TPU, consulte Versões de TPU.
    RUNTIME_VERSION A versão do software do Cloud TPU.

  2. Crie sua VM de TPU executando o comando a seguir em um Cloud Shell ou no terminal do computador em que a CLI do Google Cloud está instalada.

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION

Conectar-se à VM do Cloud TPU

Conecte-se à VM de TPU por SSH usando o seguinte comando:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE

Se você não consegue se conectar a uma VM de TPU usando SSH, ela pode não ter um endereço IP externo. Para acessar uma VM de TPU sem um endereço IP externo, siga as instruções em Conectar-se a uma VM de TPU sem um endereço IP público.

Instalar o JAX na VM do Cloud TPU

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Verificação do sistema

Verifique se o JAX pode acessar a TPU e executar operações básicas:

  1. Inicie o interpretador do Python 3:

    (vm)$ python3
    >>> import jax
  2. Confira o número de núcleos de TPU disponíveis:

    >>> jax.device_count()

O número de núcleos de TPU vai aparecer. O número de núcleos exibidos depende da versão da TPU que você está usando. Para mais informações, consulte Versões de TPU.

Fazer um cálculo

>>> jax.numpy.add(1, 1)

O resultado da adição numpy é exibido:

Resultado do comando:

Array(2, dtype=int32, weak_type=True)

Sair do interpretador do Python

>>> exit()

Como executar códigos do JAX em uma VM de TPU

Agora é possível executar qualquer código do JAX. Os exemplos de Flax são ótimos para começar a executar modelos padrão de ML no JAX. Por exemplo, para treinar uma rede convolucional MNIST básica:

  1. Instale as dependências dos exemplos do Flax:

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. Instale o Flax:

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. Execute o script de treinamento MNIST do Flax:

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
        --config=configs/default.py \
        --config.learning_rate=0.05 \
        --config.num_epochs=5

O script faz o download do conjunto de dados e inicia o treinamento. O resultado do script de treinamento será assim:

I0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Limpeza

Para evitar cobranças na conta do Google Cloud pelos recursos usados nesta página, siga as etapas abaixo.

Quando terminar de usar a VM de TPU, siga estas etapas para limpar os recursos.

  1. Desconecte-se da instância do Cloud TPU, caso ainda não tenha feito isso:

    (vm)$ exit

    Agora o prompt precisa ser username@projectname, mostrando que você está no Cloud Shell.

  2. Exclua o Cloud TPU:

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE
  3. Execute o comando abaixo para verificar se os recursos foram excluídos. Verifique se a TPU não está mais listada. A exclusão pode levar vários minutos.

    $ gcloud compute tpus tpu-vm list \
        --zone=$ZONE

Notas de desempenho

Confira alguns detalhes importantes que são relevantes principalmente para usar TPUs no JAX.

Preenchimento

Uma das causas mais comuns de desempenho lento em TPUs é o preenchimento involuntário:

  • As matrizes no Cloud TPU estão em blocos. Isso envolve o preenchimento de uma das dimensões em um múltiplo de 8 e de uma dimensão diferente em um múltiplo de 128.
  • A unidade de multiplicação de matrizes tem um desempenho melhor com pares de matrizes grandes que minimizam a necessidade de preenchimento.

bfloat16 dtype

Por padrão, a multiplicação de matrizes no JAX em TPUs usa bfloat16 com acúmulo de float32. Isso pode ser controlado com o argumento de precisão em chamadas de função jax.numpy relevantes (matmul, dot, einsum etc.). Especificamente:

  • precision=jax.lax.Precision.DEFAULT: usa a precisão bfloat16 mista (mais rápida).
  • precision=jax.lax.Precision.HIGH: usa vários passes MXU para aumentar a precisão.
  • precision=jax.lax.Precision.HIGHEST: usa ainda mais passes MXU para alcançar uma precisão float32 completa.

O JAX também adiciona o bfloat16 dtype, que pode ser usado para converter matrizes explicitamente em bfloat16. Por exemplo, jax.numpy.array(x, dtype=jax.numpy.bfloat16).

A seguir

Para mais informações sobre o Cloud TPU, consulte: