Salvataggio dei modelli TensorFlow per AI Explanations

Questa pagina spiega come salvare un modello TensorFlow da utilizzare con AI Explanations, se utilizzi TensorFlow 2.x o TensorFlow 1.15.

TensorFlow 2

Se utilizzi TensorFlow 2.x, utilizza tf.saved_model.save per salvare un modello di machine learning.

Un'opzione comune per ottimizzare i modelli TensorFlow salvati è che gli utenti possono fornire firme. Puoi specificare le firme di input durante il salvataggio del modello. Se hai una sola firma di input, AI Explanations utilizza automaticamente il valore funzione di servizio predefinita per le tue richieste di spiegazioni, seguendo la funzione predefinita comportamento di tf.saved_model.save. Scopri di più su come specificare le firme di pubblicazione in TensorFlow.

Più firme di input

Se il modello ha più di una firma di input, AI Explanations non può determinare automaticamente quale definizione di firma utilizzare quando si recupera una previsione dal modello. Pertanto, devi specificare quale definizione di firma vuoi che venga utilizzata da AI Explanations. Quando salvi il modello, specifica la firma della funzione predefinita di pubblicazione in una chiave univoca, xai-model:

tf.saved_model.save(m, model_dir, signatures={
    'serving_default': serving_fn,
    'xai_model': my_signature_default_fn # Required for AI Explanations
    })

In questo caso, AI Explanations utilizza la firma della funzione del modello fornita con Chiave xai_model per interagire con il modello e generare spiegazioni. Utilizza la stringa esatta xai_model per la chiave. Per ulteriori informazioni, consulta questa panoramica delle definizioni della firma.

Funzioni di pre-elaborazione

Se utilizzi una funzione di pre-elaborazione, devi specificare le firme per la funzione di pre-elaborazione e la funzione del modello quando salvi il modello. Utilizza le funzionalità di il tasto xai_preprocess per specificare la funzione di pre-elaborazione:

tf.saved_model.save(m, model_dir, signatures={
    'serving_default': serving_fn,
    'xai_preprocess': preprocess_fn, # Required for AI Explanations
    'xai_model': model_fn # Required for AI Explanations
    })

In questo caso, AI Explanations utilizza la tua funzione di pre-elaborazione e il tuo modello per le tue richieste di spiegazione. Assicurati che l'output della funzione di preelaborazione corrisponda all'input previsto dalla funzione del modello.

Prova i blocchi note di esempio completi di TensorFlow 2:

TensorFlow 1.15

Se utilizzi TensorFlow 1.15, non usare tf.saved_model.save. Questa funzione non è supportata con AI Explanations quando si utilizza TensorFlow 1. Utilizza invece tf.estimator.export_savedmodel insieme a un token appropriato tf.estimator.export.ServingInputReceiver

Modelli creati con Keras

Se crei e addestri il tuo modello in Keras, devi convertirlo in TensorFlow Estimator e poi esportarlo in un SavedModel. Questa sezione è incentrata sul salvataggio di un modello. Per un esempio completo funzionante, vedi i blocchi note di esempio:

Dopo aver creato, compilato, addestrato e valutato il modello Keras, devi seguire questi passaggi:

  • Convertire il modello Keras in un TensorFlow Estimator utilizzando tf.keras.estimator.model_to_estimator
  • Fornisci una funzione input di pubblicazione utilizzando tf.estimator.export.build_raw_serving_input_receiver_fn
  • Esporta il modello come SavedModel utilizzando tf.estimator.export_saved_model.
# Build, compile, train, and evaluate your Keras model
model = tf.keras.Sequential(...)
model.compile(...)
model.fit(...)
model.predict(...)

## Convert your Keras model to an Estimator
keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir='export')

## Define a serving input function appropriate for your model
def serving_input_receiver_fn():
  ...
  return tf.estimator.export.ServingInputReceiver(...)

## Export the SavedModel to Cloud Storage using your serving input function
export_path = keras_estimator.export_saved_model(
    'gs://' + 'YOUR_BUCKET_NAME', serving_input_receiver_fn).decode('utf-8')

print("Model exported to: ", export_path)

Passaggi successivi