Executar um cálculo em uma VM da Cloud TPU usando o Jax
Neste documento, você verá uma introdução breve sobre como trabalhar com o JAX e o Cloud TPU.
Antes de seguir este guia de início rápido, é preciso criar uma conta do Google Cloud Platform, instalar a CLI do Google Cloud e configurar o comando gcloud
.
Para mais informações, consulte Configurar uma conta e um projeto do Cloud TPU.
Instale a CLI do Google Cloud
A CLI do Google Cloud contém ferramentas e bibliotecas para interagir com os produtos e serviços do Google Cloud. Para mais informações, consulte Como instalar a CLI do Google Cloud.
Configurar o comando gcloud
Execute os comandos a seguir para configurar gcloud
para usar o projeto do GCP e
instalar os componentes necessários para a visualização da VM de TPU.
$ gcloud config set account your-email-account $ gcloud config set project project-id
Ativar a API Cloud TPU
Ative a API Cloud TPU usando o seguinte comando
gcloud
no Cloud Shell. Também é possível ativá-lo no Console do Google Cloud.$ gcloud services enable tpu.googleapis.com
Execute o comando a seguir para criar uma identidade de serviço.
$ gcloud beta services identity create --service tpu.googleapis.com
Criar uma VM do Cloud TPU com gcloud
Com as VMs do Cloud TPU, seu modelo e código são executados diretamente na máquina host da TPU. Conecte-se via SSH diretamente ao host da TPU. É possível executar código arbitrário, instalar pacotes, visualizar registros e depurar código diretamente no host da TPU.
Crie sua VM de TPU executando o comando a seguir em um Cloud Shell do GCP ou no terminal do computador em que a CLI do Google Cloud está instalada.
(vm)$ gcloud alpha compute tpus tpu-vm create tpu-name \ --zone europe-west4-a \ --accelerator-type v3-8 \ --version tpu-vm-tf-2.8.0
Conectar-se à VM do Cloud TPU
Use o comando a seguir para se conectar à VM da TPU:
$ gcloud alpha compute tpus tpu-vm ssh tpu-name --zone europe-west4-a
Campos obrigatórios
tpu_name
- O nome da VM de TPU a que você está se conectando.
zone
- A zona em que você criou o Cloud TPU.
Instalar o JAX na VM do Cloud TPU
(vm)$ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Verificação do sistema
Para verificar se tudo está instalado corretamente, verifique se o JAX vê os núcleos do Cloud TPU e pode executar operações básicas:
Inicie o interpretador do Python 3:
(vm)$ python3
>>> import jax
Veja o número de núcleos de TPU disponíveis:
>>> jax.device_count()
O número de núcleos de TPU será exibido. Ele deve ser 8
.
Faça um cálculo simples:
>>> jax.numpy.add(1, 1)
O resultado da adição "numpy" é exibido:
Resultado do comando:
DeviceArray(2, dtype=int32)
Saia do interpretador Python:
>>> exit()
Como executar código do JAX em uma VM de TPU
Agora é possível executar qualquer código do JAX que você quiser. Os exemplos de flax são um ótimo lugar para começar a executar modelos padrão de ML no JAX. Por exemplo, para treinar uma rede convolucional básica MNIST:
Instale conjuntos de dados do TensorFlow
(vm)$ pip install --upgrade clu
Instale o FLAX.
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user -e flax
Execute o script de treinamento MNIST do FLAX
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
O resultado do script de treinamento será assim:
I0513 21:09:35.448946 140431261813824 train.py:125] train epoch: 1, loss: 0.2312, accuracy: 93.00 I0513 21:09:36.402860 140431261813824 train.py:176] eval epoch: 1, loss: 0.0563, accuracy: 98.05 I0513 21:09:37.321380
Limpeza
Quando terminar de usar a VM de TPU, siga estas etapas para limpar os recursos.
Desconecte-se da instância do Compute Engine, caso ainda não tenha feito isso:
(vm)$ exit
Exclua o Cloud TPU.
$ gcloud alpha compute tpus tpu-vm delete tpu-name \ --zone europe-west4-a
Execute o seguinte comando para verificar se os recursos foram excluídos. Verifique se a TPU não está mais listada. A exclusão pode levar vários minutos.
Notas de desempenho
Veja alguns detalhes importantes que são relevantes principalmente para usar TPUs no JAX.
Preenchimento
Uma das causas mais comuns do desempenho lento em TPUs é o preenchimento involuntário:
- As matrizes no Cloud TPU estão em blocos. Isso envolve o preenchimento de uma das dimensões em um múltiplo de 8 e de uma dimensão diferente em um múltiplo de 128.
- A unidade de multiplicação de matriz tem um melhor desempenho com pares de matrizes grandes que minimizam a necessidade de preenchimento.
bfloat16 dtype
Por padrão, a multiplicação de matriz no JAX em TPUs usa bfloat16 com acumulação float32. Isso pode ser controlado com o argumento de precisão em chamadas de função jax.numpy relevantes (matmul, dot, einsum etc.). Especificamente:
precision=jax.lax.Precision.DEFAULT
: usa a precisão bfloat16 mista (mais rápida)precision=jax.lax.Precision.HIGH
: usa vários passes MXU para aumentar a precisãoprecision=jax.lax.Precision.HIGHEST
: usa ainda mais passes MXU para alcançar uma precisão float32 completa.
O JAX também adiciona o bfloat16 dtype, que pode ser usado para transmitir matrizes explicitamente para
bfloat16
, por exemplo, jax.numpy.array(x, dtype=jax.numpy.bfloat16)
.
Como executar o JAX em um Colab
Quando você executa o código JAX em um notebook do Colab, o Colab automaticamente cria um nó legado de TPU. Os nós de TPU têm uma arquitetura diferente. Para mais informações, consulte Arquitetura do sistema.