Executar um cálculo em uma VM da Cloud TPU usando o JAX
Este documento fornece uma breve introdução sobre como trabalhar com JAX e Cloud TPU.
Antes de seguir este guia de início rápido, você precisa criar uma conta do Google Cloud Platform
Google Cloud, instalar a Google Cloud CLI e configurar o comando gcloud
.
Para mais informações, consulte Configurar uma conta e um projeto do Cloud TPU.
Instalar a CLI do Google Cloud
A CLI do Google Cloud contém ferramentas e bibliotecas para interagir produtos e serviços do Google Cloud. Para mais informações, consulte Como instalar a Google Cloud CLI.
Configurar o comando gcloud
Execute os comandos a seguir para configurar gcloud
para usar o projeto do Google Cloud
e
instalar os componentes necessários para a visualização da VM da TPU.
$ gcloud config set account your-email-account $ gcloud config set project your-project-id
Ativar a API Cloud TPU
Ative a API Cloud TPU usando o comando
gcloud
a seguir no Cloud Shell. Você também pode ativá-lo a partir da Console do Google Cloud).$ gcloud services enable tpu.googleapis.com
Execute o comando a seguir para criar uma identidade de serviço.
$ gcloud beta services identity create --service tpu.googleapis.com
Criar uma VM do Cloud TPU com gcloud
Com as VMs do Cloud TPU, seu modelo e código são executados diretamente no host da TPU máquina virtual. Conecte-se via SSH diretamente ao host da TPU. É possível executar códigos arbitrários, instalar pacotes, visualizar registros e depurar código diretamente no host da TPU.
Crie sua VM da TPU executando o seguinte comando em um Cloud Shell ou o terminal do computador em que a Google Cloud CLI está instalado.
(vm)$ gcloud compute tpus tpu-vm create tpu-name \ --zone=us-central1-a \ --accelerator-type=v3-8 \ --version=tpu-ubuntu2204-base
Campos obrigatórios
zone
- A zona em que você planeja criar a Cloud TPU.
accelerator-type
- O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores compatíveis com cada versão de TPU, consulte Versões de TPU.
version
- A versão do software do Cloud TPU. Para todos os tipos de TPU, use
tpu-ubuntu2204-base
:
Conecte-se à VM do Cloud TPU
Use o comando a seguir para se conectar à VM da TPU:
$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central1-a
Campos obrigatórios
tpu_name
- O nome da VM de TPU a que você está se conectando.
zone
- A zona em que você criou o Cloud TPU.
Instalar o JAX na VM do Cloud TPU
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Verificação do sistema
Verifique se o JAX consegue acessar a TPU e executar operações básicas:
Inicie o interpretador do Python 3:
(vm)$ python3
>>> import jax
Veja o número de núcleos de TPU disponíveis:
>>> jax.device_count()
O número de núcleos de TPU é exibido. Se você estiver usando uma TPU v4, esse valor será
4
. Caso esteja usando uma TPU v2 ou v3, use 8
.
Faça um cálculo simples:
>>> jax.numpy.add(1, 1)
O resultado da adição "numpy" é exibido:
Resultado do comando:
Array(2, dtype=int32, weak_type=true)
Saia do interpretador Python:
>>> exit()
Como executar código do JAX em uma VM de TPU
Agora você pode executar qualquer código JAX que quiser. Os exemplos de flax são um ótimo lugar para começar a executar modelos padrão de ML no JAX. Por exemplo, para treinar uma rede convolucional básica MNIST:
Instale dependências de exemplos do Flax
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
Instalar o FLAX
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
Execute o script de treinamento MNIST do FLAX
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
O script faz o download do conjunto de dados e inicia o treinamento. A saída do script será semelhante a esta:
0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
Limpar
Quando terminar de usar a VM de TPU, siga estas etapas para limpar os recursos.
Desconecte-se da instância do Compute Engine, caso ainda não tenha feito isso:
(vm)$ exit
Exclua o Cloud TPU.
$ gcloud compute tpus tpu-vm delete tpu-name \ --zone=us-central1-a
Execute o seguinte comando para verificar se os recursos foram excluídos. Verifique se a TPU não está mais listada. A exclusão pode levar vários minutos.
$ gcloud compute tpus tpu-vm list \ --zone=us-central1-a
Notas de desempenho
Veja alguns detalhes importantes que são relevantes principalmente para usar TPUs no JAX.
Preenchimento
Uma das causas mais comuns do desempenho lento em TPUs é o preenchimento involuntário:
- As matrizes no Cloud TPU estão em blocos. Isso envolve o preenchimento de uma das dimensões em um múltiplo de 8 e de uma dimensão diferente em um múltiplo de 128.
- A unidade de multiplicação de matriz tem um melhor desempenho com pares de matrizes grandes que minimizam a necessidade de preenchimento.
bfloat16 dtype
Por padrão, a multiplicação de matrizes do JAX em TPUs usa bfloat16 com acumulação float32. Isso pode ser controlado com o argumento de precisão em chamadas de função jax.numpy relevantes (matmul, dot, einsum etc.). Especificamente:
precision=jax.lax.Precision.DEFAULT
: usa bfloat16 misto precisão (mais rápido)precision=jax.lax.Precision.HIGH
: usa vários passes MXU para aumentar a precisãoprecision=jax.lax.Precision.HIGHEST
: usa ainda mais passes MXU. para alcançar a precisão total de float32
O JAX também adiciona o dtype bfloat16, que pode ser usado para transmitir matrizes explicitamente
bfloat16
, por exemplo:
jax.numpy.array(x, dtype=jax.numpy.bfloat16)
.
Como executar o JAX em um Colab
Quando você executa o código JAX em um bloco do Colab, o Colab automaticamente cria um nó legado de TPU. Os nós de TPU têm uma arquitetura diferente. Para mais informações, consulte Arquitetura do sistema.
A seguir
Para mais informações sobre o Cloud TPU, consulte: