Guia de início rápido: executar um cálculo em uma VM da Cloud TPU usando o Jax

Executar um cálculo em uma VM da Cloud TPU usando o Jax

Neste documento, você verá uma introdução breve sobre como trabalhar com o JAX e o Cloud TPU.

Antes de seguir este guia de início rápido, é preciso criar uma conta do Google Cloud Platform, instalar a CLI do Google Cloud e configurar o comando gcloud. Para mais informações, consulte Configurar uma conta e um projeto do Cloud TPU.

Instale a CLI do Google Cloud

A CLI do Google Cloud contém ferramentas e bibliotecas para interagir com os produtos e serviços do Google Cloud. Para mais informações, consulte Como instalar a CLI do Google Cloud.

Configurar o comando gcloud

Execute os comandos a seguir para configurar gcloud para usar o projeto do GCP e instalar os componentes necessários para a visualização da VM de TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project project-id

Ativar a API Cloud TPU

  1. Ative a API Cloud TPU usando o seguinte comando gcloud no Cloud Shell. Também é possível ativá-lo no Console do Google Cloud.

    $ gcloud services enable tpu.googleapis.com
    
  2. Execute o comando a seguir para criar uma identidade de serviço.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Criar uma VM do Cloud TPU com gcloud

Com as VMs do Cloud TPU, seu modelo e código são executados diretamente na máquina host da TPU. Conecte-se via SSH diretamente ao host da TPU. É possível executar código arbitrário, instalar pacotes, visualizar registros e depurar código diretamente no host da TPU.

  1. Crie sua VM de TPU executando o comando a seguir em um Cloud Shell do GCP ou no terminal do computador em que a CLI do Google Cloud está instalada.

    (vm)$  gcloud alpha compute tpus tpu-vm create tpu-name \
    --zone europe-west4-a \
    --accelerator-type v3-8 \
    --version tpu-vm-tf-2.8.0
    

    Campos obrigatórios

    zone
    A zona em que você planeja criar a Cloud TPU.
    accelerator-type
    O tipo da Cloud TPU a ser criada.
    version
    Versão do software do Cloud TPU.

Conectar-se à VM do Cloud TPU

Use o comando a seguir para se conectar à VM da TPU:

$  gcloud alpha compute tpus tpu-vm ssh tpu-name --zone europe-west4-a

Campos obrigatórios

tpu_name
O nome da VM de TPU a que você está se conectando.
zone
A zona em que você criou o Cloud TPU.

Instalar o JAX na VM do Cloud TPU

(vm)$ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Verificação do sistema

Para verificar se tudo está instalado corretamente, verifique se o JAX vê os núcleos do Cloud TPU e pode executar operações básicas:

Inicie o interpretador do Python 3:

(vm)$ python3
>>> import jax

Veja o número de núcleos de TPU disponíveis:

>>> jax.device_count()

O número de núcleos de TPU será exibido. Ele deve ser 8.

Faça um cálculo simples:

>>> jax.numpy.add(1, 1)

O resultado da adição "numpy" é exibido:

Resultado do comando:

DeviceArray(2, dtype=int32)

Saia do interpretador Python:

>>> exit()

Como executar código do JAX em uma VM de TPU

Agora é possível executar qualquer código do JAX que você quiser. Os exemplos de flax são um ótimo lugar para começar a executar modelos padrão de ML no JAX. Por exemplo, para treinar uma rede convolucional básica MNIST:

  1. Instale conjuntos de dados do TensorFlow

    (vm)$ pip install --upgrade clu
    
  2. Instale o FLAX.

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user -e flax
    
  3. Execute o script de treinamento MNIST do FLAX

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

    O resultado do script de treinamento será assim:

    I0513 21:09:35.448946 140431261813824 train.py:125] train epoch: 1, loss: 0.2312, accuracy: 93.00
    I0513 21:09:36.402860 140431261813824 train.py:176] eval epoch: 1, loss: 0.0563, accuracy: 98.05
    I0513 21:09:37.321380

Limpeza

Quando terminar de usar a VM de TPU, siga estas etapas para limpar os recursos.

  1. Desconecte-se da instância do Compute Engine, caso ainda não tenha feito isso:

    (vm)$ exit
    
  2. Exclua o Cloud TPU.

    $  gcloud alpha compute tpus tpu-vm delete tpu-name \
      --zone europe-west4-a
    
  3. Execute o seguinte comando para verificar se os recursos foram excluídos. Verifique se a TPU não está mais listada. A exclusão pode levar vários minutos.

Notas de desempenho

Veja alguns detalhes importantes que são relevantes principalmente para usar TPUs no JAX.

Preenchimento

Uma das causas mais comuns do desempenho lento em TPUs é o preenchimento involuntário:

  • As matrizes no Cloud TPU estão em blocos. Isso envolve o preenchimento de uma das dimensões em um múltiplo de 8 e de uma dimensão diferente em um múltiplo de 128.
  • A unidade de multiplicação de matriz tem um melhor desempenho com pares de matrizes grandes que minimizam a necessidade de preenchimento.

bfloat16 dtype

Por padrão, a multiplicação de matriz no JAX em TPUs usa bfloat16 com acumulação float32. Isso pode ser controlado com o argumento de precisão em chamadas de função jax.numpy relevantes (matmul, dot, einsum etc.). Especificamente:

  • precision=jax.lax.Precision.DEFAULT: usa a precisão bfloat16 mista (mais rápida)
  • precision=jax.lax.Precision.HIGH: usa vários passes MXU para aumentar a precisão
  • precision=jax.lax.Precision.HIGHEST: usa ainda mais passes MXU para alcançar uma precisão float32 completa.

O JAX também adiciona o bfloat16 dtype, que pode ser usado para transmitir matrizes explicitamente para bfloat16, por exemplo, jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Como executar o JAX em um Colab

Quando você executa o código JAX em um notebook do Colab, o Colab automaticamente cria um nó legado de TPU. Os nós de TPU têm uma arquitetura diferente. Para mais informações, consulte Arquitetura do sistema.