Utiliser des modèles ouverts Gemma avec Dataflow

Gemma est une famille de modèles ouverts, légers et à la pointe de la technologie, basés sur la recherche et la technologie utilisées pour créer les modèles Gemini. Vous pouvez utiliser des modèles Gemma dans vos pipelines d'inférence Apache Beam. Le terme pondération ouverte signifie que les paramètres pré-entraînés d'un modèle, ou pondérations, sont publiés. Les détails tels que l'ensemble de données d'origine, l'architecture du modèle et le code d'entraînement ne sont pas fournis.

Cas d'utilisation

Vous pouvez utiliser des modèles Gemma avec Dataflow pour l'analyse des sentiments. Avec Dataflow et les modèles Gemma, vous pouvez traiter des événements, tels que les avis de clients, à mesure qu'ils arrivent. Faites passer les avis dans le modèle pour les analyser, puis générez des recommandations. En associant Gemma à Apache Beam, vous pouvez facilement mettre en œuvre ce workflow.

Compatibilité et limites

Les modèles ouverts Gemma sont compatibles avec Apache Beam et Dataflow, avec les exigences suivantes :

  • Disponible pour les pipelines par lot et par flux qui utilisent le SDK Apache Beam pour Python version 2.46.0 et ultérieure.
  • Les jobs Dataflow doivent utiliser Runner v2.
  • Les jobs Dataflow doivent utiliser des GPU. Pour obtenir la liste des types de GPU compatibles avec Dataflow, consultez la section Disponibilité. Les types de GPU T4 et L4 sont recommandés.
  • Le modèle doit être téléchargé et enregistré au format de fichier .keras.
  • Le gestionnaire de modèles TensorFlow est recommandé, mais pas obligatoire.

Prérequis

  • Accédez aux modèles Gemma via Kaggle.
  • Remplissez le formulaire d'autorisation et acceptez les conditions d'utilisation.
  • Téléchargez le modèle Gemma. Enregistrez-le au format .keras dans un emplacement auquel votre tâche Dataflow peut accéder, tel qu'un bucket Cloud Storage. Lorsque vous spécifiez une valeur pour la variable de chemin d'accès du modèle, utilisez le chemin d'accès à cet emplacement de stockage.
  • Pour exécuter votre job sur Dataflow, créez une image de conteneur personnalisée. Cette étape permet d'exécuter le pipeline avec des GPU sur le service Dataflow.

Utiliser Gemma dans votre pipeline

Pour utiliser un modèle Gemma dans votre pipeline Apache Beam, procédez comme suit :

  1. Dans votre code Apache Beam, après avoir importé vos dépendances de pipeline, incluez un chemin d'accès au modèle enregistré :

    model_path = "MODEL_PATH"
    

    Remplacez MODEL_PATH par le chemin d'accès où vous avez enregistré le modèle téléchargé. Par exemple, si vous enregistrez votre modèle dans un bucket Cloud Storage, le chemin d'accès est au format gs://STORAGE_PATH/FILENAME.keras.

  2. L'implémentation Keras des modèles Gemma dispose d'une méthode generate() qui génère du texte basé sur une invite. Pour transmettre des éléments à la méthode generate(), utilisez une fonction d'inférence personnalisée.

    def gemma_inference_function(model, batch, inference_args, model_id):
      vectorized_batch = np.stack(batch, axis=0)
      # The only inference_arg expected here is a max_length parameter to
      # determine how many words are included in the output.
      predictions = model.generate(vectorized_batch, **inference_args)
      return utils._convert_to_result(batch, predictions, model_id)
    
  3. Exécutez votre pipeline en spécifiant le chemin d'accès au modèle entraîné. Cet exemple utilise un gestionnaire de modèles TensorFlow.

    class FormatOutput(beam.DoFn):
      def process(self, element, *args, **kwargs):
        yield "Input: {input}, Output: {output}".format(input=element.example, output=element.inference)
    
    # Instantiate a NumPy array of string prompts for the model.
    examples = np.array(["Tell me the sentiment of the phrase 'I like pizza': "])
    # Specify the model handler, providing a path and the custom inference function.
    model_handler = TFModelHandlerNumpy(model_path, inference_fn=gemma_inference_function)
    with beam.Pipeline() as p:
      _ = (p | beam.Create(examples) # Create a PCollection of the prompts.
             | RunInference(model_handler, inference_args={'max_length': 32}) # Send the prompts to the model and get responses.
             | beam.ParDo(FormatOutput()) # Format the output.
             | beam.Map(print) # Print the formatted output.
      )
    

Étapes suivantes