Executar um cálculo em uma VM da Cloud TPU usando o JAX

Este documento apresenta uma breve introdução sobre como trabalhar com o JAX e o Cloud TPU.

Antes de começar

Antes de executar os comandos neste documento, é necessário criar uma conta Google Cloud, instalar a Google Cloud CLI e configurar o comando gcloud. Para mais informações, consulte Configurar o ambiente do Cloud TPU.

Criar uma VM do Cloud TPU usando gcloud

  1. Defina algumas variáveis de ambiente para facilitar o uso dos comandos.

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-8
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    Descrições das variáveis de ambiente

    PROJECT_ID
    O ID do Google Cloud projeto.
    ACCELERATOR_TYPE
    O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores aceitos para cada versão de TPU, consulte Versões de TPU.
    ZONE
    A zona em que você planeja criar o Cloud TPU.
    RUNTIME_VERSION
    A versão do ambiente de execução do Cloud TPU. Para mais informações, consulte Imagens de VM de TPU
    .
    TPU_NAME
    O nome atribuído pelo usuário ao Cloud TPU.
  2. Crie sua VM de TPU executando o comando a seguir em um Cloud Shell ou no terminal do computador em que a Google Cloud CLI está instalada.

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$RUNTIME_VERSION

Conecte-se à VM do Cloud TPU

Conecte-se à VM da TPU usando SSH com o seguinte comando:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
--project=$PROJECT_ID \
--zone=$ZONE

Instalar o JAX na VM do Cloud TPU

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Verificação do sistema

Verifique se o JAX pode acessar o TPU e executar operações básicas:

  1. Inicie o interpretador do Python 3:

    (vm)$ python3
    >>> import jax
  2. Veja o número de núcleos de TPU disponíveis:

    >>> jax.device_count()

O número de núcleos de TPU é exibido. O número de cores exibidas depende da versão da TPU que você está usando. Para mais informações, consulte Versões da TPU.

Faça um cálculo:

>>> jax.numpy.add(1, 1)

O resultado da adição "numpy" é exibido:

Resultado do comando:

Array(2, dtype=int32, weak_type=true)

Saia do interpretador Python:

>>> exit()

Como executar código do JAX em uma VM de TPU

Agora é possível executar qualquer código do JAX que você quiser. Os exemplos de flax são um ótimo lugar para começar a executar modelos padrão de ML no JAX. Por exemplo, para treinar uma rede convolucional básica MNIST:

  1. Instalar dependências de exemplos do Flax

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. Instalar o FLAX

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. Execute o script de treinamento MNIST do FLAX

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

O script faz o download do conjunto de dados e inicia o treinamento. A saída do script será semelhante a esta:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Limpar

Para evitar cobranças na conta do Google Cloud pelos recursos usados nesta página, siga estas etapas.

Quando terminar de usar a VM de TPU, siga estas etapas para limpar os recursos.

  1. Desconecte-se da instância do Compute Engine, caso ainda não tenha feito isso:

    (vm)$ exit
  2. Exclua o Cloud TPU.

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
      --project=$PROJECT_ID \
      --zone=$ZONE
  3. Execute o seguinte comando para verificar se os recursos foram excluídos. Verifique se a TPU não está mais listada. A exclusão pode levar vários minutos.

    $ gcloud compute tpus tpu-vm list \
      --zone=$ZONE

Notas de performance

Veja alguns detalhes importantes que são relevantes principalmente para usar TPUs no JAX.

Preenchimento

Uma das causas mais comuns do desempenho lento em TPUs é o preenchimento involuntário:

  • As matrizes no Cloud TPU estão em blocos. Isso envolve o preenchimento de uma das dimensões em um múltiplo de 8 e de uma dimensão diferente em um múltiplo de 128.
  • A unidade de multiplicação de matriz tem um melhor desempenho com pares de matrizes grandes que minimizam a necessidade de preenchimento.

bfloat16 dtype

Por padrão, a multiplicação de matriz no JAX em TPUs usa bfloat16 com acumulação float32. Isso pode ser controlado com o argumento de precisão em chamadas de função jax.numpy relevantes (matmul, dot, einsum etc.). Especificamente:

  • precision=jax.lax.Precision.DEFAULT: usa a precisão bfloat16 mista (mais rápida)
  • precision=jax.lax.Precision.HIGH: usa vários passes MXU para aumentar a precisão
  • precision=jax.lax.Precision.HIGHEST: usa ainda mais passes MXU para alcançar uma precisão float32 completa.

O JAX também adiciona o tipo de dados bfloat16, que pode ser usado para transmitir matrizes explicitamente para bfloat16, por exemplo, jax.numpy.array(x, dtype=jax.numpy.bfloat16).

A seguir

Para mais informações sobre o Cloud TPU, consulte: