Guia de início rápido da VM do Cloud TPU

Neste documento, você verá uma breve introdução de como trabalhar com JAX e Cloud TPU.

Faça login na sua Conta do Google. Se você ainda não tiver uma, inscreva-se em uma nova conta. No Console do Google Cloud, selecione ou crie um projeto do Cloud na página do seletor de projetos. Verifique se o faturamento está ativado no seu projeto.

Instalar o Google Cloud SDK

O SDK do Google Cloud contém ferramentas e bibliotecas para interagir com os produtos e serviços do Google Cloud. Para mais informações, consulte Como instalar o SDK do Google Cloud.

Configurar o comando gcloud

Execute os seguintes comandos para configurar gcloud para usar o projeto do GCP e instalar os componentes necessários para a visualização da VM do TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project project-id

Ativar a API Cloud TPU

  1. Ative a API Cloud TPU usando o seguinte comando gcloud no Cloud Shell. Também é possível ativá-lo no Console do Google Cloud.

    $ gcloud services enable tpu.googleapis.com
    
  2. Execute o seguinte comando para criar uma identidade de serviço.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Crie uma VM do Cloud TPU com gcloud

Com VMs do Cloud TPU, seu modelo e código são executados diretamente na máquina host da TPU. Você se conecta SSH diretamente ao host da TPU. É possível executar códigos arbitrários, instalar pacotes, ver registros e depurar códigos diretamente no host da TPU.

  1. Crie a VM de TPU executando o seguinte comando em um Cloud Shell do GCP ou no terminal do computador em que o SDK do Google Cloud está instalado.

    (vm)$ gcloud alpha compute tpus tpu-vm create tpu-name \
    --zone europe-west4-a \
    --accelerator-type v3-8 \
    --version v2-alpha

    Campos obrigatórios

    zone
    A zona em que você planeja criar a Cloud TPU.
    accelerator-type
    O tipo do Cloud TPU a ser criado.
    version
    A versão do ambiente de execução do Cloud TPU. Defina isto como "v2-alpha" quando você estiver usando o JAX em dispositivos de TPU única, frações de pod ou pods inteiros.

Conecte-se à VM do Cloud TPU

SSH na VM de TPU usando o seguinte comando:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name --zone europe-west4-a

Campos obrigatórios

tpu_name
O nome da VM da TPU a que você está se conectando.
zone
A zona em que você criou o Cloud TPU.

Instalar o JAX na VM do Cloud TPU

(vm)$ pip3 install --upgrade jax jaxlib

Verificação do sistema

Verifique se tudo está instalado corretamente verificando se o JAX vê os núcleos do Cloud TPU e pode executar operações básicas:

Inicie o interpretador do Python 3:

(vm)$ python3
>>> import jax

Exiba o número de núcleos de TPU disponíveis:

>>> jax.device_count()

O número de núcleos de TPU é exibido. Ele precisa ser 8.

Realize um cálculo simples:

>>> jax.numpy.add(1, 1)

O resultado da adição do numpy é exibido:

Saída do comando:

DeviceArray(2, dtype=int32)

Saia do interpretador do Python:

>>> exit()

Como executar o código JAX em uma VM de TPU

Agora, você pode executar qualquer código JAX por conta própria. Os exemplos de flax são um ótimo lugar para começar a executar modelos de ML padrão no JAX. Por exemplo, para treinar uma rede convolucional básica do MNIST:

  1. Instalar conjuntos de dados do TensorFlow

    (vm)$ pip install --upgrade clu
    
  2. Instale o XcodeX.

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user -e flax
    
  3. Executar o script de treinamento MNISTX MNIST

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

    A saída do script será semelhante a esta:

    I0513 21:09:35.448946 140431261813824 train.py:125] train epoch: 1, loss: 0.2312, accuracy: 93.00
    I0513 21:09:36.402860 140431261813824 train.py:176] eval epoch: 1, loss: 0.0563, accuracy: 98.05
    I0513 21:09:37.321380

Limpeza

Quando terminar sua VM de TPU, siga estas etapas para limpar seus recursos.

  1. Desconecte-se da instância do Compute Engine, caso ainda não tenha feito isso:

    (vm)$ exit
    
  2. Exclua o Cloud TPU.

    $ gcloud alpha compute tpus tpu-vm delete tpu-name \
      --zone europe-west4-a
    
  3. Verifique se os recursos foram excluídos executando o comando a seguir. Verifique se a TPU não está mais listada. A exclusão pode levar vários minutos.

Observações sobre o desempenho

Veja alguns detalhes importantes que são particularmente relevantes para o uso de TPUs no JAX.

Preenchimento

Uma das causas mais comuns de desempenho lento em TPUs é incluir preenchimento inadvertido:

  • As matrizes no Cloud TPU estão em blocos. Isso envolve o preenchimento de uma das dimensões em um múltiplo de 8 e de uma dimensão diferente em um múltiplo de 128.
  • A unidade de multiplicação de matrizes funciona melhor com pares de matrizes grandes que minimizam a necessidade de preenchimento.

Dtype bfloat16

Por padrão, a multiplicação de matrizes em JAX em TPUs usa bfloat16 com acumulação float32. Isso pode ser controlado com o argumento de precisão em chamadas de função jax.numpy relevantes (matmul, ponto, einsum etc.). Especificamente:

  • precision=jax.lax.Precision.DEFAULT: usa precisão bfloat16 mista (mais rápida).
  • precision=jax.lax.Precision.HIGH: usa vários cartões MXU para conseguir uma precisão maior.
  • precision=jax.lax.Precision.HIGHEST: usa ainda mais passagens MXU para conseguir uma precisão total de float32

O JAX também adiciona o tipo bfloat16, que pode ser usado para converter matrizes explicitamente em bfloat16, por exemplo, jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Como executar JAX em um Colab

Ao executar o código JAX em um notebook do Colab, o Colab cria automaticamente um nó de TPU legado. Os nós da TPU têm uma arquitetura diferente. Para mais informações, consulte Arquitetura do sistema.