Entraîner un modèle TensorFlow avec Keras sur Google Kubernetes Engine
Restez organisé à l'aide des collections
Enregistrez et classez les contenus selon vos préférences.
La section suivante fournit un exemple d'ajustement d'un modèle BERT pour la classification de séquences à l'aide de la bibliothèque Hugging Face Transformers avec TensorFlow. L'ensemble de données est téléchargé dans un volume Parallelstore associé, ce qui permet à l'entraînement du modèle de lire directement les données du volume.
Prérequis
Assurez-vous que votre nœud dispose d'au moins 8 Go de mémoire disponible.
Sauf indication contraire, le contenu de cette page est régi par une licence Creative Commons Attribution 4.0, et les échantillons de code sont régis par une licence Apache 2.0. Pour en savoir plus, consultez les Règles du site Google Developers. Java est une marque déposée d'Oracle et/ou de ses sociétés affiliées.
Dernière mise à jour le 2025/09/04 (UTC).
[[["Facile à comprendre","easyToUnderstand","thumb-up"],["J'ai pu résoudre mon problème","solvedMyProblem","thumb-up"],["Autre","otherUp","thumb-up"]],[["Difficile à comprendre","hardToUnderstand","thumb-down"],["Informations ou exemple de code incorrects","incorrectInformationOrSampleCode","thumb-down"],["Il n'y a pas l'information/les exemples dont j'ai besoin","missingTheInformationSamplesINeed","thumb-down"],["Problème de traduction","translationIssue","thumb-down"],["Autre","otherDown","thumb-down"]],["Dernière mise à jour le 2025/09/04 (UTC)."],[],[],null,["# Train a TensorFlow model with Keras on Google Kubernetes Engine\n\nThe following section provides an example of\n[fine-tuning a BERT model](https://huggingface.co/docs/transformers/training#train-a-tensorflow-model-with-keras)\nfor sequence classification using the\n[Hugging Face transformers](https://github.com/huggingface/transformers) library\nwith TensorFlow. The dataset is downloaded into a mounted\nParallelstore-backed volume, allowing the model training to directly read data\nfrom the volume.\n\nPrerequisites\n-------------\n\n- Ensure your node has at least 8 GiB of memory available.\n- [Create a PersistentVolumeClaim requesting for a Parallelstore-backed volume](/kubernetes-engine/docs/how-to/persistent-volumes/parallelstore-csi-new-volume#pvc).\n\nSave the following YAML manifest (`parallelstore-csi-job-example.yaml`) for your model training Job. \n\n apiVersion: batch/v1\n kind: Job\n metadata:\n name: parallelstore-csi-job-example\n spec:\n template:\n metadata:\n annotations:\n gke-parallelstore/cpu-limit: \"0\"\n gke-parallelstore/memory-limit: \"0\"\n spec:\n securityContext:\n runAsUser: 1000\n runAsGroup: 100\n fsGroup: 100\n containers:\n - name: tensorflow\n image: jupyter/tensorflow-notebook@sha256:173f124f638efe870bb2b535e01a76a80a95217e66ed00751058c51c09d6d85d\n command: [\"bash\", \"-c\"]\n args:\n - |\n pip install transformers datasets\n python - \u003c\u003cEOF\n from datasets import load_dataset\n dataset = load_dataset(\"glue\", \"cola\", cache_dir='/data')\n dataset = dataset[\"train\"]\n from transformers import AutoTokenizer\n import numpy as np\n tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n tokenized_data = tokenizer(dataset[\"sentence\"], return_tensors=\"np\", padding=True)\n tokenized_data = dict(tokenized_data)\n labels = np.array(dataset[\"label\"])\n from transformers import TFAutoModelForSequenceClassification\n from tensorflow.keras.optimizers import Adam\n model = TFAutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\")\n model.compile(optimizer=Adam(3e-5))\n model.fit(tokenized_data, labels)\n EOF\n volumeMounts:\n - name: parallelstore-volume\n mountPath: /data\n volumes:\n - name: parallelstore-volume\n persistentVolumeClaim:\n claimName: parallelstore-pvc\n restartPolicy: Never\n backoffLimit: 1\n\nApply the YAML manifest to the cluster.\n\n`kubectl apply -f parallelstore-csi-job-example.yaml`\n\nCheck your data loading and model training progress with the following command: \n\n POD_NAME=$(kubectl get pod | grep 'parallelstore-csi-job-example' | awk '{print $1}')\n kubectl logs -f $POD_NAME -c tensorflow\n\n| **Note:** The model training takes approximately five minutes to complete."]]