Hard-attention Explanations for Image Classification

Customers can use an API that classifies an image while also providing explanations of which parts of the image contributed to the classification. Previous efforts to build this kind of explanation suffered from either brittle explanations or low classification performance.

Intended use

Problem types:

Although deep convolutional neural networks achieve state-of-the-art performance across nearly all image classification tasks, their decisions are difficult to interpret. One approach that offers some level of interpretability by design is hard attention, which uses only relevant portions of the image. However, training hard attention models with only class label supervision is challenging, and hard attention has proved difficult to scale to complex datasets. Here, we propose a novel hard attention model, which adds a pretraining step that requires only class labels and provides initial attention locations for policy gradient optimization. Our best models narrow the gap to common ImageNet baselines, achieving 75% top-1 and 91% top-5 while attending to less than one-third of the image.
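For intuition, a hard attention model classifies from a short sequence of small glimpse patches cropped out of the image rather than from the full pixel array. The sketch below illustrates only the glimpse-extraction step; the patch size, the locations, and the helper name are illustrative assumptions, not the experiment's implementation (which learns attention locations via the pretraining and policy gradient steps described above).

      import numpy as np

      def extract_glimpse(image, center_y, center_x, size=64):
          """Crop a square glimpse around (center_y, center_x).

          Hypothetical helper: the real model picks locations with a
          learned attention policy; here they are supplied explicitly.
          """
          h, w, _ = image.shape
          half = size // 2
          # Clamp the center so the full patch stays inside the image.
          y = min(max(center_y, half), h - half)
          x = min(max(center_x, half), w - half)
          return image[y - half:y + half, x - half:x + half, :]

      image = np.zeros((224, 224, 3), dtype=np.uint8)  # stand-in input
      # Three 64x64 glimpses cover well under one-third of a 224x224 image.
      glimpses = [extract_glimpse(image, y, x)
                  for y, x in [(60, 60), (112, 112), (160, 150)]]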

Inputs and outputs:

  • Users provide: A square image array of size height x width x 3 with pixel values in the [0, 255] range (a minimal input check is sketched after this list).
  • Users receive: A prediction of the class of the image from the 1000 ImageNet classes and the input image with hard attention patches marked to explain the prediction.
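
Before encoding a request, the input can be checked against this contract. This is a minimal sketch assuming a NumPy array; the function name is hypothetical and not part of the API:

      import numpy as np

      def check_input_image(image: np.ndarray) -> None:
          """Hypothetical client-side check of the input contract above."""
          if image.ndim != 3 or image.shape[2] != 3:
              raise ValueError(f"expected a height x width x 3 array, got shape {image.shape}")
          if image.shape[0] != image.shape[1]:
              raise ValueError("input image must be square")
          if image.min() < 0 or image.max() > 255:
              raise ValueError("pixel values must lie in [0, 255]")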

Industries and functions:

Use cases are not constrained to specific industries or functions. This experiment may be helpful whenever a user needs to classify an image into one of the 1000 ImageNet classes while also receiving a sequence of glimpses (aka saccades) that shows the areas of the image used to make the decision. A prototypical use case is image class labeling combined with localization of the input regions used to produce the label, though other use cases exist.

What data do I need?

Data and label types: This experiment has been designed to help customers label images and receive attention explanations for the labels.

  • It is likely to be effective with natural images (i.e. those that can be captured with a camera in the real world, similar to the images in ImageNet).
  • It may not be effective with highly unusual or abstract image types, such as specialized medical images and scans.
  • Labels for the provided pretrained models are constrained to the 1000 ImageNet classes. To support other class labels, please apply to the experiment and the research team may be able to provide additional guidance.

Specifications:

  • Data:

    • Users must provide an image.
    • Image size: 224x224 pixels (the user should ensure the input image is of that size). Sample code below:

      # Assumes `image` is a PIL image already loaded in memory.
      # The user must resize or crop the image to 224x224x3 first.
      import tensorflow as tf

      IMAGE_URI = "cat.jpg"
      image = image.resize((224, 224))
      with tf.gfile.Open(IMAGE_URI, "wb") as f:
        image.save(f, format="JPEG")
      
    • Stored in the following format: .json. Sample code below:

      import base64
      import json

      # Read the saved image and base64-encode it into a JSON request.
      with tf.gfile.Open(IMAGE_URI, 'rb') as image_file:
        im = image_file.read()
        encoded_string = base64.b64encode(im).decode('utf-8')
      image_bytes = {'b64': encoded_string}
      instances = {'input_image_bytes': image_bytes}
      with tf.gfile.Open("prediction_instances.json", "w") as f:
        f.write(json.dumps(instances))
      
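How the request is then submitted depends on the access granted through the application process. As an illustrative sketch only, assuming the model were exposed through a Cloud AI Platform-style prediction endpoint (the project and model names below are placeholders, not part of this experiment's documented interface), the JSON instance could be sent as follows:

      from googleapiclient import discovery

      # Placeholder names; the actual endpoint depends on the access
      # granted for this experiment.
      service = discovery.build('ml', 'v1')
      name = 'projects/MY_PROJECT/models/MY_MODEL'
      # `instances` is the dict built above; the API expects a list.
      response = service.projects().predict(
          name=name, body={'instances': [instances]}).execute()
      print(response['predictions'])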

What skills do I need?

As with all AI Workshop experiments, successful users are likely to be savvy with core AI concepts and skills in order to both deploy the experiment technology and interact with our AI researchers and engineers.

In particular, users of this experiment should:

  • Be familiar with accessing Google APIs.