Guia de início rápido do JAX da VM do Cloud TPU

Neste documento, apresentamos uma breve introdução ao trabalho com o JAX e o Cloud TPU.

Antes de seguir este guia de início rápido, é preciso criar uma conta do Google Cloud Platform, instalar o SDK do Google Cloud Platform. e configure o comando gcloud. Para mais informações, consulte Configurar uma conta e um projeto do Cloud TPU.

Instalar o Google Cloud SDK

O SDK do Google Cloud contém ferramentas e bibliotecas para interagir com os produtos e serviços do Google Cloud. Para mais informações, consulte Como instalar o SDK do Google Cloud.

Configurar o comando gcloud

Execute os comandos a seguir para configurar gcloud para usar o projeto do GCP e instalar os componentes necessários para a visualização da VM da TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project project-id

Ativar a API Cloud TPU

  1. Ative a API Cloud TPU usando o seguinte comando gcloud no Cloud Shell. Também é possível ativá-lo no Console do Google Cloud.

    $ gcloud services enable tpu.googleapis.com
    
  2. Execute o comando a seguir para criar uma identidade de serviço:

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Crie uma VM do Cloud TPU com gcloud

Com as VMs do Cloud TPU, seu modelo e código são executados diretamente na máquina host do TPU. Conecte-se via SSH diretamente ao host da TPU. É possível executar código arbitrário, instalar pacotes, visualizar registros e depurar código diretamente no TPU TPU.

  1. Crie sua VM de TPU executando o comando a seguir em um GCP Cloud Shell ou no terminal do computador em que o SDK do Google Cloud está instalado.

    (vm)$ gcloud alpha compute tpus tpu-vm create tpu-name \
    --zone europe-west4-a \
    --accelerator-type v3-8 \
    --version v2-alpha

    Campos obrigatórios

    zone
    A zona em que você planeja criar a Cloud TPU.
    accelerator-type
    O tipo da Cloud TPU a ser criada.
    version
    A versão do ambiente de execução do Cloud TPU. Defina como "v2-alpha" quando usar o JAX em dispositivos de TPU única, em frações de pod ou em pods inteiros.

Conecte-se à VM da Cloud TPU

Use o comando a seguir para se conectar à VM da TPU:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name --zone europe-west4-a

Campos obrigatórios

tpu_name
O nome da VM da TPU à qual você está se conectando.
zone
A zona em que você criou o Cloud TPU.

Instalar o JAX na VM do Cloud TPU

(vm)$ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Verificação do sistema

Para verificar se tudo está instalado corretamente, verifique se o JAX vê os núcleos do Cloud TPU e pode executar operações básicas:

Inicie o interpretador do Python 3:

(vm)$ python3
>>> import jax

Exiba o número de núcleos de TPU disponíveis:

>>> jax.device_count()

O número de núcleos de TPU será exibido. Ele precisa ser 8.

Faça um cálculo simples:

>>> jax.numpy.add(1, 1)

O resultado da adição de "numpy" é exibido:

Saída do comando:

DeviceArray(2, dtype=int32)

Saia do interpretador Python:

>>> exit()

Como executar código JAX em uma VM de TPU

Agora você pode executar qualquer código JAX que quiser. Os exemplos de flax são um ótimo lugar para começar a executar modelos padrão de ML no JAX. Por exemplo, para treinar uma rede convolucional básica MNIST:

  1. Instalar conjuntos de dados do TensorFlow

    (vm)$ pip install --upgrade clu
    
  2. Instale o FLAX.

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user -e flax
    
  3. Executar o script de treinamento MNIST do FLAX

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

    A saída do script será semelhante a esta:

    I0513 21:09:35.448946 140431261813824 train.py:125] train epoch: 1, loss: 0.2312, accuracy: 93.00
    I0513 21:09:36.402860 140431261813824 train.py:176] eval epoch: 1, loss: 0.0563, accuracy: 98.05
    I0513 21:09:37.321380

Como fazer a limpeza

Quando terminar de usar a VM da TPU, siga estas etapas para limpar os recursos.

  1. Desconecte-se da instância do Compute Engine, caso ainda não tenha feito isso:

    (vm)$ exit
    
  2. Exclua o Cloud TPU.

    $ gcloud alpha compute tpus tpu-vm delete tpu-name \
      --zone europe-west4-a
    
  3. Verifique se os recursos foram excluídos executando o seguinte comando. Verifique se a TPU não está mais listada. A exclusão pode levar vários minutos.

Notas de desempenho

Veja alguns detalhes importantes que são relevantes principalmente para o uso de TPUs no JAX.

Preenchimento

Uma das causas mais comuns de desempenho lento em TPUs é o preenchimento involuntário:

  • As matrizes no Cloud TPU estão em blocos. Isso envolve o preenchimento de uma das dimensões em um múltiplo de 8 e uma dimensão diferente em um múltiplo de 128.
  • A unidade de multiplicação de matriz tem melhor desempenho com pares de matrizes grandes que minimizam a necessidade de preenchimento.

dtype bfloat16

Por padrão, a multiplicação de matrizes no JAX nas TPUs usa bfloat16 com acumulação float32. Isso pode ser controlado com o argumento de precisão em chamadas de função jax.numpy relevantes (matmul, ponto, einsum etc.). Especificamente:

  • precision=jax.lax.Precision.DEFAULT: usa a precisão bfloat16 mista (mais rápida)
  • precision=jax.lax.Precision.HIGH: usa vários cartões MXU para alcançar maior precisão
  • precision=jax.lax.Precision.HIGHEST: usa ainda mais passagens MXU para alcançar precisão completa float32.

O JAX também adiciona o dtype bfloat16, que pode ser usado para transmitir matrizes explicitamente para bfloat16, por exemplo, jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Como executar o JAX em um Colab

Quando você executa o código JAX em um notebook Colab, o Colab cria automaticamente um nó legado da TPU. Os nós da TPU têm uma arquitetura diferente. Para ver mais informações, consulte Arquitetura do sistema.