Offene Gemma-Modelle mit Dataflow verwenden

Gemma ist eine Familie einfacher, hochmoderner offener Modelle, die auf Forschung und Technologie basieren, mit denen die Gemini-Modelle erstellt werden. Sie können Gemma-Modelle in Ihren Apache Beam-Inferenzpipelines verwenden. Der Begriff Offene Gewichtung bedeutet, dass die vortrainierten Parameter oder Gewichtungen eines Modells veröffentlicht werden. Details wie das ursprüngliche Dataset, die Modellarchitektur und der Trainingscode werden nicht angegeben.

Anwendungsfälle

Sie können Gemma-Modelle mit Dataflow für die Sentimentanalyse verwenden. Mit Dataflow und den Gemma-Modellen können Sie Ereignisse wie Kundenrezensionen verarbeiten, sobald sie eintreffen. Lassen Sie die Rezensionen durch das Modell laufen, um sie zu analysieren und dann Empfehlungen zu generieren. Durch die Kombination von Gemma mit Apache Beam können Sie diesen Workflow nahtlos ausführen.

Unterstützung und Einschränkungen

Offene Gemma-Modelle werden mit Apache Beam und Dataflow mit den folgenden Anforderungen unterstützt:

  • Verfügbar für Batch- und Streamingpipelines, die die Apache Beam Python SDK-Versionen 2.46.0 und höher verwenden.
  • Dataflow-Jobs müssen Runner v2 verwenden.
  • Dataflow-Jobs müssen GPUs verwenden. Eine Liste der mit Dataflow unterstützten GPU-Typen finden Sie unter Verfügbarkeit. Die GPU-Typen T4 und L4 werden empfohlen.
  • Das Modell muss heruntergeladen und im Dateiformat .keras gespeichert werden.
  • Der TensorFlow-Modell-Handler wird empfohlen, ist aber nicht erforderlich.

Vorbereitung

  • Greifen Sie auf Gemma-Modelle über Kaggle zu.
  • Füllen Sie das Zustimmungsformular aus und akzeptieren Sie die Nutzungsbedingungen.
  • Laden Sie das Gemma-Modell herunter. Speichern Sie sie im Dateiformat .keras an einem Speicherort, auf den der Dataflow-Job zugreifen kann, z. B. in einem Cloud Storage-Bucket. Wenn Sie einen Wert für die Modellpfadvariable angeben, verwenden Sie den Pfad zu diesem Speicherort.
  • Erstellen Sie ein benutzerdefiniertes Container-Image, um Ihren Job in Dataflow auszuführen. Mit diesem Schritt kann die Pipeline mit GPUs im Dataflow-Dienst ausgeführt werden.

Gemma in einer Pipeline verwenden

So verwenden Sie ein Gemma-Modell in Ihrer Apache Beam-Pipeline:

  1. Fügen Sie in Ihrem Apache Beam-Code nach dem Import der Pipeline-Abhängigkeiten einen Pfad zu Ihrem gespeicherten Modell ein:

    model_path = "MODEL_PATH"
    

    Ersetzen Sie MODEL_PATH durch den Pfad, unter dem Sie das heruntergeladene Modell gespeichert haben. Wenn Sie beispielsweise Ihr Modell in einem Cloud Storage-Bucket speichern, hat der Pfad das Format gs://STORAGE_PATH/FILENAME.keras.

  2. Die Keras-Implementierung der Gemma-Modelle verfügt über eine generate()-Methode, die Text anhand eines Prompts generiert. Verwenden Sie eine benutzerdefinierte Inferenzfunktion, um Elemente an die Methode generate() zu übergeben.

    def gemma_inference_function(model, batch, inference_args, model_id):
      vectorized_batch = np.stack(batch, axis=0)
      # The only inference_arg expected here is a max_length parameter to
      # determine how many words are included in the output.
      predictions = model.generate(vectorized_batch, **inference_args)
      return utils._convert_to_result(batch, predictions, model_id)
    
  3. Führen Sie die Pipeline aus und geben Sie den Pfad zum trainierten Modell an. In diesem Beispiel wird ein TensorFlow-Modell-Handler verwendet.

    class FormatOutput(beam.DoFn):
      def process(self, element, *args, **kwargs):
        yield "Input: {input}, Output: {output}".format(input=element.example, output=element.inference)
    
    # Instantiate a NumPy array of string prompts for the model.
    examples = np.array(["Tell me the sentiment of the phrase 'I like pizza': "])
    # Specify the model handler, providing a path and the custom inference function.
    model_handler = TFModelHandlerNumpy(model_path, inference_fn=gemma_inference_function)
    with beam.Pipeline() as p:
      _ = (p | beam.Create(examples) # Create a PCollection of the prompts.
             | RunInference(model_handler, inference_args={'max_length': 32}) # Send the prompts to the model and get responses.
             | beam.ParDo(FormatOutput()) # Format the output.
             | beam.Map(print) # Print the formatted output.
      )
    

Nächste Schritte