Utilizzare i modelli aperti Gemma con Dataflow

Gemma è una famiglia di modelli aperti leggeri e all'avanguardia, creati a partire dalla ricerca e dalla tecnologia utilizzati per creare i modelli Gemini. Puoi utilizzare i modelli Gemma nelle pipeline di inferenza Apache Beam. Il termine ponderazione aperta indica che vengono rilasciati i parametri preaddestrati di un modello, o ponderazioni. Non vengono forniti dettagli come il set di dati originale, l'architettura del modello e il codice di addestramento.

Casi d'uso

Puoi utilizzare i modelli Gemma con Dataflow per l'analisi del sentiment. Con Dataflow e i modelli Gemma, puoi elaborare eventi, come le recensioni dei clienti, man mano che arrivano. Esegui le revisioni attraverso il modello per analizzarle, quindi genera consigli. Combinando Gemma con Apache Beam, puoi completare questo flusso di lavoro senza problemi.

Assistenza e limitazioni

I modelli aperti Gemma sono supportati con Apache Beam e Dataflow con i seguenti requisiti:

  • Disponibile per le pipeline in modalità batch e flusso che utilizzano l'SDK Apache Beam Python 2.46.0 e versioni successive.
  • I job Dataflow devono utilizzare Runner v2.
  • I job Dataflow devono utilizzare GPU. Per un elenco dei tipi di GPU supportati con Dataflow, consulta Disponibilità. Sono consigliati i tipi di GPU T4 e L4.
  • Il modello deve essere scaricato e salvato nel formato file .keras.
  • Il gestore del modello TensorFlow è consigliato, ma non obbligatorio.

Prerequisiti

  • Accedi ai modelli Gemma tramite Kaggle.
  • Compila il modulo di consenso e accetta i termini e le condizioni.
  • Scarica il modello Gemma. Salvalo nel formato file .keras in una posizione a cui può accedere il job Dataflow, ad esempio un bucket Cloud Storage. Quando specifichi un valore per la variabile di percorso del modello, utilizza il percorso per questa posizione di archiviazione.
  • Per eseguire il job su Dataflow, crea un'immagine container personalizzata. Questo passaggio consente di eseguire la pipeline con GPU sul servizio Dataflow. Per maggiori informazioni, consulta Creare un'immagine container personalizzata in "Esegui una pipeline con GPU".

Utilizza Gemma nella tua pipeline

Per utilizzare un modello Gemma nella tua pipeline Apache Beam, segui questi passaggi.

  1. Nel codice Apache Beam, dopo aver importato le dipendenze della pipeline, includi un percorso al modello salvato:

    model_path = "MODEL_PATH"
    

    Sostituisci MODEL_PATH con il percorso in cui hai salvato il modello scaricato. Ad esempio, se salvi il modello in un bucket Cloud Storage, il percorso avrà il formato gs://STORAGE_PATH/FILENAME.keras.

  2. L'implementazione Keras dei modelli Gemma prevede un metodo generate() che genera testo in base a un prompt. Per passare elementi al metodo generate(), utilizza una funzione di inferenza personalizzata.

    def gemma_inference_function(model, batch, inference_args, model_id):
      vectorized_batch = np.stack(batch, axis=0)
      # The only inference_arg expected here is a max_length parameter to
      # determine how many words are included in the output.
      predictions = model.generate(vectorized_batch, **inference_args)
      return utils._convert_to_result(batch, predictions, model_id)
    
  3. Esegui la pipeline, specificando il percorso del modello addestrato. Questo esempio utilizza un gestore del modello TensorFlow.

    class FormatOutput(beam.DoFn):
      def process(self, element, *args, **kwargs):
        yield "Input: {input}, Output: {output}".format(input=element.example, output=element.inference)
    
    # Instantiate a NumPy array of string prompts for the model.
    examples = np.array(["Tell me the sentiment of the phrase 'I like pizza': "])
    # Specify the model handler, providing a path and the custom inference function.
    model_handler = TFModelHandlerNumpy(model_path, inference_fn=gemma_inference_function)
    with beam.Pipeline() as p:
      _ = (p | beam.Create(examples) # Create a PCollection of the prompts.
             | RunInference(model_handler, inference_args={'max_length': 32}) # Send the prompts to the model and get responses.
             | beam.ParDo(FormatOutput()) # Format the output.
             | beam.Map(print) # Print the formatted output.
      )
    

Passaggi successivi