Ejecutar cargas de trabajo de TPU en un contenedor Docker

Los contenedores Docker facilitan la configuración de las aplicaciones, ya que combinan el código y todas las dependencias necesarias en un paquete distribuible. Puedes ejecutar contenedores Docker en VMs de TPU para simplificar la configuración y el uso compartido de tus aplicaciones de TPU de Cloud. En este documento se describe cómo configurar un contenedor de Docker para cada framework de aprendizaje automático compatible con Cloud TPU.

Entrenar un modelo de PyTorch en un contenedor Docker

Dispositivo de TPU

  1. Crear una máquina virtual de TPU de Cloud

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. Conectarse a la VM de TPU mediante SSH

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=europe-west4-a
  3. Asegúrate de que se ha concedido el rol Lector de Artifact Registry a tu usuario de Google Cloud . Para obtener más información, consulta el artículo sobre cómo asignar roles de Artifact Registry.

  4. Iniciar un contenedor en la VM de TPU con la imagen nocturna de PyTorch/XLA

    sudo docker run --net=host -ti --rm --name your-container-name --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
  5. Configurar el tiempo de ejecución de la TPU

    Hay dos opciones de tiempo de ejecución de PyTorch/XLA: PJRT y XRT. Te recomendamos que uses PJRT a menos que tengas un motivo para usar XRT. Para obtener más información sobre las diferentes configuraciones de tiempo de ejecución, consulta la documentación del tiempo de ejecución de PJRT.

    PJRT

    export PJRT_DEVICE=TPU

    XRT

    export XRT_TPU_CONFIG="localservice;0;localhost:51011"
  6. Clona el repositorio de PyTorch XLA

    git clone --recursive https://github.com/pytorch/xla.git
  7. Entrenar ResNet50

    python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

Cuando se complete la secuencia de comandos de formación, asegúrate de limpiar los recursos.

  1. Escribe exit para salir del contenedor Docker.
  2. Escribe exit para salir de la VM de TPU.
  3. Eliminar la VM de TPU

    gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

Slice de TPU

Cuando ejecutas código de PyTorch en un segmento de TPU, debes ejecutarlo en todos los trabajadores de TPU al mismo tiempo. Una forma de hacerlo es usar el comando gcloud compute tpus tpu-vm ssh con las marcas --worker=all y --command. En el siguiente procedimiento se muestra cómo crear una imagen de Docker para facilitar la configuración de cada trabajador de TPU.

  1. Crear una VM de TPU

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-32 \
    --version=tpu-ubuntu2204-base
  2. Añadir el usuario actual al grupo Docker

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=us-central2-b \
    --worker=all \
    --command='sudo usermod -a -G docker $USER'
  3. Clona el repositorio de PyTorch XLA

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="git clone --recursive https://github.com/pytorch/xla.git"
  4. Ejecutar la secuencia de comandos de entrenamiento en un contenedor en todos los trabajadores de TPU

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="docker run --rm --privileged --net=host  -v ~/xla:/xla -e PJRT_DEVICE=TPU us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 python /xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"

    Marcas de comandos de Docker:

    • --rm elimina el contenedor después de que finalice su proceso.
    • --privileged expone el dispositivo TPU al contenedor.
    • --net=host vincula todos los puertos del contenedor a la VM de TPU para permitir la comunicación entre los hosts del pod.
    • -e define variables de entorno.

Cuando se complete la secuencia de comandos de formación, asegúrate de limpiar los recursos.

Elimina la VM de TPU con el siguiente comando:

gcloud compute tpus tpu-vm delete your-tpu-name \
--zone=us-central2-b

Entrenar un modelo de JAX en un contenedor Docker

Dispositivo de TPU

  1. Crear la VM de TPU

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. Conectarse a la VM de TPU mediante SSH

    gcloud compute tpus tpu-vm ssh your-tpu-name  --zone=europe-west4-a
  3. Iniciar el daemon de Docker en una máquina virtual de TPU

    sudo systemctl start docker
  4. Iniciar el contenedor Docker

    sudo docker run --net=host -ti --rm --name your-container-name \
    --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
  5. Instalar JAX

    pip install jax[tpu]
  6. Instalar FLAX

    pip install --upgrade clu
    git clone https://github.com/google/flax.git
    pip install --user -e flax
  7. Instalar los paquetes tensorflow y tensorflow-dataset

    pip install tensorflow
    pip install tensorflow-datasets
  8. Ejecutar la secuencia de comandos de entrenamiento de MNIST de Flax

    cd flax/examples/mnist
    python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

Cuando se complete la secuencia de comandos de formación, asegúrate de limpiar los recursos.

  1. Escribe exit para salir del contenedor Docker.
  2. Escribe exit para salir de la VM de TPU.
  3. Eliminar la VM de TPU

    gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

Slice de TPU

Cuando ejecutas código JAX en un segmento de TPU, debes ejecutarlo en todos los trabajadores de TPU al mismo tiempo. Una forma de hacerlo es usar el comando gcloud compute tpus tpu-vm ssh con las marcas --worker=all y --command. En el siguiente procedimiento se muestra cómo crear una imagen de Docker para facilitar la configuración de cada trabajador de TPU.

  1. Crea un archivo llamado Dockerfile en el directorio actual y pega el siguiente texto:

    FROM python:3.10
    RUN pip install jax[tpu]
    RUN pip install --upgrade clu
    RUN git clone https://github.com/google/flax.git
    RUN pip install --user -e flax
    RUN pip install tensorflow
    RUN pip install tensorflow-datasets
    WORKDIR ./flax/examples/mnist
  2. Preparar un Artifact Registry

    gcloud artifacts repositories create your-repo \
    --repository-format=docker \
    --location=europe-west4 --description="Docker repository" \
    --project=your-project
    
    gcloud artifacts repositories list \
    --project=your-project
    
    gcloud auth configure-docker europe-west4-docker.pkg.dev
  3. Crear la imagen de Docker

    docker build -t your-image-name .
  4. Añade una etiqueta a tu imagen de Docker antes de enviarla a Artifact Registry. Para obtener más información sobre cómo trabajar con Artifact Registry, consulta el artículo Trabajar con imágenes de contenedor.

    docker tag your-image-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  5. Enviar la imagen Docker a Artifact Registry

    docker push europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  6. Crear una VM de TPU

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  7. Extrae la imagen Docker de Artifact Registry en todos los trabajadores de TPU

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command='sudo usermod -a -G docker ${USER}'
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="gcloud auth configure-docker europe-west4-docker.pkg.dev --quiet"
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker pull europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag"
  8. Ejecutar el contenedor en todos los trabajadores de TPU

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker run -ti -d --privileged --net=host --name your-container-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag bash"
  9. Ejecuta la secuencia de comandos de entrenamiento en todos los trabajadores de TPU

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker exec --privileged your-container-name python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5"

Cuando se complete la secuencia de comandos de formación, asegúrate de limpiar los recursos.

  1. Cerrar el contenedor en todos los trabajadores

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker kill your-container-name"
  2. Eliminar la VM de TPU

    gcloud compute tpus tpu-vm delete your-tpu-name \
    --zone=europe-west4-a

Entrenar un modelo de JAX en un contenedor Docker con JAX Stable Stack

Puedes crear las imágenes de Docker MaxText y MaxDiffusion con la imagen base de JAX Stable Stack.

JAX Stable Stack proporciona un entorno coherente para MaxText y MaxDiffusion al agrupar JAX con paquetes principales como orbax, flax, optax y libtpu.so. Estas bibliotecas se han probado para asegurar la compatibilidad y proporcionan una base estable para crear y ejecutar MaxText y MaxDiffusion. De esta forma, se eliminan posibles conflictos debidos a versiones de paquetes incompatibles.

JAX Stable Stack incluye una versión completa y cualificada de libtpu.so, la biblioteca principal que impulsa la compilación, la ejecución y la configuración de la red ICI de los programas de TPU. La versión de libtpu sustituye a la compilación nocturna que usaba JAX anteriormente y asegura que las computaciones de XLA en TPU tengan una funcionalidad coherente con las pruebas de calificación a nivel de PJRT en IRs de HLO y StableHLO.

Para compilar la imagen de Docker de MaxText y MaxDiffusion con JAX Stable Stack, cuando ejecutes la secuencia de comandos docker_build_dependency_image.sh, asigna el valor stable_stack a la variable MODE y asigna a la variable BASEIMAGE la imagen base que quieras usar.

docker_build_dependency_image.sh se encuentra en el repositorio de GitHub MaxDiffusion y en el repositorio de GitHub MaxText. Clona el repositorio que quieras usar y ejecuta el script docker_build_dependency_image.sh de ese repositorio para crear la imagen Docker.

git clone https://github.com/AI-Hypercomputer/maxdiffusion.git
git clone https://github.com/AI-Hypercomputer/maxtext.git

El siguiente comando genera una imagen Docker para usarla con MaxText y MaxDiffusion usando us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 como imagen base.

sudo bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1

Para ver una lista de las imágenes base de JAX Stable Stack disponibles, consulta Imágenes de JAX Stable Stack en Artifact Registry.

Siguientes pasos