Executar um cálculo em uma VM da Cloud TPU usando o JAX
Este documento apresenta uma breve introdução sobre como trabalhar com o JAX e o Cloud TPU.
Antes de começar
Antes de executar os comandos neste documento, é necessário criar uma conta Google Cloud, instalar a Google Cloud CLI e configurar o comando gcloud
. Para
mais informações, consulte Configurar o ambiente do Cloud TPU.
Criar uma VM do Cloud TPU usando gcloud
Defina algumas variáveis de ambiente para facilitar o uso dos comandos.
export PROJECT_ID=your-project export ACCELERATOR_TYPE=v5p-8 export ZONE=us-east5-a export RUNTIME_VERSION=v2-alpha-tpuv5 export TPU_NAME=your-tpu-name
Descrições das variáveis de ambiente
PROJECT_ID
- O ID do Google Cloud projeto.
ACCELERATOR_TYPE
- O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores aceitos para cada versão de TPU, consulte Versões de TPU.
ZONE
- A zona em que você planeja criar o Cloud TPU.
RUNTIME_VERSION
- A versão do ambiente de execução do Cloud TPU. Para mais informações, consulte Imagens de VM de TPU .
TPU_NAME
- O nome atribuído pelo usuário ao Cloud TPU.
Crie sua VM de TPU executando o comando a seguir em um Cloud Shell ou no terminal do computador em que a Google Cloud CLI está instalada.
$ gcloud compute tpus tpu-vm create $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
Conecte-se à VM do Cloud TPU
Conecte-se à VM da TPU usando SSH com o seguinte comando:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
Instalar o JAX na VM do Cloud TPU
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Verificação do sistema
Verifique se o JAX pode acessar o TPU e executar operações básicas:
Inicie o interpretador do Python 3:
(vm)$ python3
>>> import jax
Veja o número de núcleos de TPU disponíveis:
>>> jax.device_count()
O número de núcleos de TPU é exibido. O número de cores exibidas depende da versão da TPU que você está usando. Para mais informações, consulte Versões da TPU.
Faça um cálculo:
>>> jax.numpy.add(1, 1)
O resultado da adição "numpy" é exibido:
Resultado do comando:
Array(2, dtype=int32, weak_type=true)
Saia do interpretador Python:
>>> exit()
Como executar código do JAX em uma VM de TPU
Agora é possível executar qualquer código do JAX que você quiser. Os exemplos de flax são um ótimo lugar para começar a executar modelos padrão de ML no JAX. Por exemplo, para treinar uma rede convolucional básica MNIST:
Instalar dependências de exemplos do Flax
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
Instalar o FLAX
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
Execute o script de treinamento MNIST do FLAX
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
O script faz o download do conjunto de dados e inicia o treinamento. A saída do script será semelhante a esta:
0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
Limpar
Para evitar cobranças na conta do Google Cloud pelos recursos usados nesta página, siga estas etapas.
Quando terminar de usar a VM de TPU, siga estas etapas para limpar os recursos.
Desconecte-se da instância do Compute Engine, caso ainda não tenha feito isso:
(vm)$ exit
Exclua o Cloud TPU.
$ gcloud compute tpus tpu-vm delete $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
Execute o seguinte comando para verificar se os recursos foram excluídos. Verifique se a TPU não está mais listada. A exclusão pode levar vários minutos.
$ gcloud compute tpus tpu-vm list \ --zone=$ZONE
Notas de performance
Veja alguns detalhes importantes que são relevantes principalmente para usar TPUs no JAX.
Preenchimento
Uma das causas mais comuns do desempenho lento em TPUs é o preenchimento involuntário:
- As matrizes no Cloud TPU estão em blocos. Isso envolve o preenchimento de uma das dimensões em um múltiplo de 8 e de uma dimensão diferente em um múltiplo de 128.
- A unidade de multiplicação de matriz tem um melhor desempenho com pares de matrizes grandes que minimizam a necessidade de preenchimento.
bfloat16 dtype
Por padrão, a multiplicação de matriz no JAX em TPUs usa bfloat16 com acumulação float32. Isso pode ser controlado com o argumento de precisão em chamadas de função jax.numpy relevantes (matmul, dot, einsum etc.). Especificamente:
precision=jax.lax.Precision.DEFAULT
: usa a precisão bfloat16 mista (mais rápida)precision=jax.lax.Precision.HIGH
: usa vários passes MXU para aumentar a precisãoprecision=jax.lax.Precision.HIGHEST
: usa ainda mais passes MXU para alcançar uma precisão float32 completa.
O JAX também adiciona o tipo de dados bfloat16, que pode ser usado para transmitir matrizes explicitamente para
bfloat16
, por exemplo,
jax.numpy.array(x, dtype=jax.numpy.bfloat16)
.
A seguir
Para mais informações sobre o Cloud TPU, consulte: