This paper was converted on www.awesomepapers.org from LaTeX by an anonymous user.
Want to know more? Visit the Converter page.

Recurrent Attention Models with Object-centric Capsule Representation for Multi-object Recognition

Hossein Adeli1∗,  Seoyoung Ahn1,  Gregory Zelinsky1,2

 1Department of Psychology, 2Department of Computer Science
Stony Brook University
{hossein.adelijelodar, seoyoung.ahn, gregory.zelinsky}@stonybrook.edu
corresponding author
Abstract

The visual system processes a scene using a sequence of selective glimpses, each driven by spatial and object-based attention. These glimpses reflect what is relevant to the ongoing task and are selected through recurrent processing and recognition of the objects in the scene. In contrast, most models treat attention selection and recognition as separate stages in a feedforward process. Here we show that using capsule networks to create an object-centric hidden representation in an encoder-decoder model with iterative glimpse attention yields effective integration of attention and recognition. We evaluate our model on three multi-object recognition tasks; highly overlapping digits, digits among distracting clutter and house numbers, and show that it learns to effectively move its glimpse window, recognize and reconstruct the objects, all with only the classification as supervision. Our work takes a step toward a general architecture for how to integrate recurrent object-centric representation into the planning of attentional glimpses.

1 Introduction

Visual inputs are perceived selectively and sequentially by the brain through attention sampling processes [Gottlieb et al., 2013]. Inspired by this, a wide range of attention mechanisms have been explored and incorporated recently in Deep Learning models of vision [Mnih et al., 2014, Ba et al., 2014, Zoran et al., 2020, Jaegle et al., 2021, Xu et al., 2015] among other domains [Vaswani et al., 2017]. The mechanisms range from ‘soft’ highlighting of task relevant features [Hu et al., 2018] or spatial areas [Wang et al., 2017] to ‘hard’ glimpse-based mechanisms [Elsayed et al., 2019, Mnih et al., 2014] inspired by fixation behavior. The glimpse-based mechanisms have the promise of making object recognition models more efficient and interpretable as they would only have to focus processing resources on smaller and relevant areas of the image. However, models in this domain typically treat attention selection and recognition as two separate processing stages [Cordonnier et al., 2021], first detecting the salient and relevant parts and features of the visual input before applying subsequent processing (e.g. recognition). This could limit their generalizability to domains where the relevance of certain areas and features in the image need to be determined dynamically and through ongoing object-based hypothesis formation. Understanding the dynamics of integrated attention-recognition mechanisms is important for building models that can learn to optimally sample an image, leading to more human-like and interpretable models.

Another recent development for making object recognition models more human-like and interpretable are Capsule Networks (CapsNets; Sabour et al. [2017], Hinton et al. [2018]). CapsNets attempt to represent scenes as parse trees, with capsules in different layers representing visual entities at different levels of object granularity in the image, from small object parts in the lower levels to whole objects at the top level (with objects being represented by the corresponding category capsules). However, there are limitations that can prevent the wider applicability of CapsNets. First, the model does not address how visual information could be processed across multiple timesteps, as new information becomes available either through attention sampling or subsequent frames in a video. The models so far have focused on how well objects can be represented and recognized based on a single sample of an image (but see Hinton [2021] for recent ideas on how to address this). Second, if there is one top-level capsule assigned to representing each category, how could a model represent and recognize multiple instances of the same category?

In this work we present OCRA, an Object-Centric Recurrent Attention model that combines recurrent glimpse-based attention and capsule methods. Like a CapsNet, it performs encapsulation of features to structure the higher level representations for object recognition. However, we place this structure within an encoder-decoder model with recurrent attention, thereby enabling integration of structured information across multiple attentional glimpses. Our model addresses the aforementioned limitations of the original CapsNets and shows that capsule-based binding of object features and grouping (part-whole matching) is effective in sequential detection and recognition of multiple overlapping and distinct objects. This synthesis of approaches can pave the way for building better recognition models for challenging conditions requiring a mechanism for parsing the scene to entities and a recurrent process for iterative attentional sampling and accumulation of new information.

2 Related Works

Soft attention models:

Many models have been proposed that learn spatial maps or filter weights to allocate processing resources to the most relevant areas or features of an image to improve performance [Chen et al., 2017, Hu et al., 2018, Park et al., 2018, Wang et al., 2017]. These “attention” processes range from bottom-up, where each layer decides what weighting of the features to route to the next [Chen et al., 2017, Hu et al., 2018], to top-down modulations coming from a higher level representation [Cao et al., 2015, Xu et al., 2015]. In this vein, transformer-based attention mechanisms [Vaswani et al., 2017] have recently been used to create models with an integrated process of “soft” sampling and recognition [Zoran et al., 2020, Jaegle et al., 2021]. Note however that soft attention models of recognition were mostly tested on images with a single prominent object, which avoids the problems of object feature binding and grouping. The sampling behavior of these models has also not yet been shown to be qualitatively similar to the fixation behavior of people, although the soft attention mechanisms could potentially capture aspects of bottom-up saliency [Itti and Koch, 2000] or feature-based attention where top-down modulations can weight the incoming representation based on task relevance [Desimone and Duncan, 1995, Maunsell and Treue, 2006].

Glimpse attention models:

In contrast to “soft” attention mechanism of applying a weighting to feature maps obtained from the entire image, many other models have incorporated a process of sampling the visual input by selecting restricted “glimpses” of the image. Models in this domain broadly scatter along a few dimensions: whether they are differentiable [Cordonnier et al., 2021, Gregor et al., 2015] or not [Mnih et al., 2014, Ba et al., 2014], whether they are sequential [Fu et al., 2017] or use a single feedforward pass [Cordonnier et al., 2021], whether they are trained supervised [Ba et al., 2014] or self-supervised (trained to reconstruct the input objects) [Eslami et al., 2016] or whether they are applied to images or videos (motion processing or activity recognition) [Kahou et al., 2017, Kosiorek et al., 2017, Duta et al., 2020]. DRAW [Gregor et al., 2015] introduced a differentiable sequential spatial attention mechanism to Variational Autoencoders (VAE, Kingma and Welling [2013]) and showed that its gradual glimpse-based process improved reconstruction performance. In contrast, supervised glimpse-based models are trained directly for recognizing one or more objects in images. RAM [Mnih et al., 2014] learned to move its glimpse window to sample image locations for object recognition, although its non-differentiable reinforcement learning-based attention mechanism proved difficult to train. The Saccader model applied a similar mechanism to object classification using the ImageNet dataset [Elsayed et al., 2019], but in order to make the training more tractable the model preprocesses the whole image to get the class logits for all categories at all potential patch locations. While this approach can make the models more interpretable, it does not take full advantage of the glimpse behavior to efficiently sample the image as part of an integrated recognition-attention process. Glimpse selection can also be done by having a separate stage of selecting patches before feeding them to the next stage of processing (e.g. classification) [Cordonnier et al., 2021]. These models therefore do not learn a policy for moving attention over an image that depends on the current object hypothesis, as people do in their application of object-based attention. An extension of RAM [Ba et al., 2014] used the whole image in a separate pathway to provide a context for glimpse selection, thereby similarly segregating the processes of attentional selection and recognition. A broad definition of glimpse models would also include models designed to have specific “zooming in” mechanisms for single object recognition, used to detect objects in large high-resolution images [Papadopoulos et al., 2021] or for fine-grained classification [Fu et al., 2017].

Models with recurrent and feedback connections:

There is a growing body of work on modeling the role of feedback and recurrent connections [Gilbert and Li, 2013] for different perceptual tasks [Kreiman and Serre, 2020]. Some “task optimized” models take an agnostic approach, adding recurrent and top-down connections and training end-to-end (similar to feedforward networks) to achieve the best performance in a task [Liang and Hu, 2015, Zamir et al., 2017, Kim et al., 2019]. An insight that these efforts have yielded is that the use of recurrence can allow for more compact models to have similar computational depth and reach similar levels of accuracy as bigger feedforward models [Nayebi et al., 2021, van Bergen and Kriegeskorte, 2020]. Another insight from these models is that recurrence assists challenging recognition tasks, such as those with high degrees of occlusion [Spoerer et al., 2017, Wyatte et al., 2014], which is hypothesized to be due to leveraging contextualized iterative computations [van Bergen and Kriegeskorte, 2020]. Other work has assigned more specific roles to the recurrent and feedback connections, for example to provide hypotheses for object locations [Cao et al., 2015] or the optimal feature weightings [Li et al., 2018] for the subsequent feedforward pass, using mechanisms that overlap with soft attention models [Stollenga et al., 2014, Wang et al., 2014].

Object-centric models (“Slots” and “Capsules”):

Objects are how people interact with the world and are therefore central to human scene understanding [Scholl, 2001, Spelke, 1990]. Visual objects are formed by (bottom-up) part-whole matching and Gestalt processes interacting with (top-down) objectness priors and knowledge of object categories [Greff et al., 2020, Vecera, 2000, Wagemans et al., 2012]. Object-centric models use these processes to discover objects and segregate their representations into different “slots” [Greff et al., 2020, Goyal et al., 2019]. Attention mechanisms have played a major role in object-centric models by enabling the iterative discovery and representation of an object’s properties. AIR [Eslami et al., 2016] used a spatial transformer component [Jaderberg et al., 2015] to attend to objects, infer their properties and reconstruct them sequentially. MONET [Burgess et al., 2019] first creates spatial object masks that are then used sequentially to reconstruct each entity using VAEs, effectively separating the attention routing and representation learning processes into two stages. This approach has been shown to be effective for downstream object reasoning tasks [Ding et al., 2020]. Although not using sequential spatial attention, some other models in this domain relatedly employ iterative processes to dynamically route and segregate objects to different slots [Greff et al., 2019, Locatello et al., 2020]. Capsule networks [Hinton et al., 2018, Sabour et al., 2017] are slot-based models that have capsule slots for each category placed at the top of a hierarchy of capsules representing a scene parse tree. Similar to other object-centric models, the entity encapsulation performed by CapsNets, creates object-centric representations that can lead to better downstream task performance [Qin et al., 2020].

3 Approach

Our approach is object-centric through the use of the entity encapsulation property of CapsNets, using it to structure representations for classification and attention planning. It is also consistent with soft attention and recurrent models of recognition. Adding glimpses to these models of recognition enables our model not only to leverage recurrent computation, but also to select the most relevant incoming visual inputs. OCRA therefore intersects all of the discussed modeling approaches, which we demonstrate to be key to its performance.

OCRA’s attention mechanism builds on the DRAW architecture [Gregor et al., 2015], which we found to be cognitively plausible. The “attention window” is a grid of filters applied to a small or a large area of the image. However, because the number of filters covering the attention window remains constant, as the window gets bigger it samples increasingly low-resolution information, creating a tradeoff between the size of the attention window and the resolution of information extracted. This property is aligned with “zoom lens” theories of human attention [Eriksen and James, 1986, Müller et al., 2003], that similarly propose a variable-sized attention process that can be broadly or narrowly allocated to an input, but one that is constrained by the same trade off. The original DRAW model [Gregor et al., 2015] was formulated as an autoregressive VAE, as it was trained for stepwise self-supervised reconstruction. In contrast, our formulation uses a deterministic encoder-decoder approach to make the model easier to train using both classification and reconstruction losses, because we want it to predict both the category classification based on latent class capsules, and image reconstruction based on decoder output.

3.1 OCRA Architecture

Refer to caption
Figure 1: OCRA performing multi-object recognition on an overlapping-digit image.

The OCRA architecture is shown in Fig. 1. At each timestep, a new glimpse from the image is encoded by a hierarchy of modules to yield a structured latent representation. The decoder then reconstructs the glimpse using this latent representation. The encoder-decoder steps are taken sequentially, determining the attention glimpse location for the following step. We provide an overview of the OCRA components and loss functions below. Pseudocodes and implementation details are provided in the Supplementary A. A pytorch [Paszke et al., 2019] implementation of OCRA with additional details and results are provided in this repository: github.com/hosseinadeli/OCRA.

Read and Write Attention:

At each timestep a glimpse, gtg_{t}, is sampled through applying a grid of N ×\times N Gaussian filters on the input image xx. We set the glimpse size to 18×\times18 for our experiments, with a sample glimpse shown in Fig. 1 left. The Gaussian filters are generated using four parameters: gX,gY,δ,σ2g_{X},g_{Y},\delta,\sigma^{2}, which specify the center coordinates of the attention window, the distance between equally spaced Gaussian filters in the grid, and the variance of the filters, respectively. All of these parameters are computed via a linear transformation of the previous step decoder RNN (Recurrent Neural Network) output ht1dech^{dec}_{t-1} using a weight matrix WreadW_{read}, which makes the attention mechanism fully-differentiable. A similar procedure applies to the write attention operation. The decoder RNN output htdech^{dec}_{t} is linearly transformed into an M ×\times M write patch wtw_{t} (set to 18 ×\times 18 in our experiments), which is then multiplied by the Gaussian filters to reconstruct the written parts in the original image size (Fig. 1 right). The Gaussian filters used for the write operation differed from those used for the read operation, and were computed from four parameters obtained from a separate linear transformation, Wwrite_attentionW_{write\_attention}, of the decoder RNN output htdech^{dec}_{t}. Detailed algorithms for Read and Write attention operations are provided in Pseudocode 3 and 4 in the Supplementary along with an illustration of the attention mechanisms (A.7).

Encoder:

After a glimpse is selected from the input image by the read attention operation, it is processed first using a two-layer convolutional neural network (CNN) with 32 filters in each layer. Kernel sizes were set to 5 and 3 respectively for the first and the second layers. Each convolutional layer is followed by max pooling with a kernel size of 2 and a stride of 2, and Rectified Linear Units (ReLU) were used for non-linear activation functions. Given the glimpse size of 18 ×\times 18, the resulting 32 feature maps are of size 4 ×\times 4. The feature maps, gtconvg^{conv}_{t}, are reshaped (to a vector of size 512) and used as input to the encoder Recurrent Neural Network (RNN), along with the encoder RNN hidden state from the previous step; ht1ench^{enc}_{t-1}. We used LSTM [Hochreiter and Schmidhuber, 1997] units (size 512) for the recurrent layers in our model.

Latent Capsule Representations and Dynamic Routing:

We use a vector implementation of capsules [Sabour et al., 2017] where the length of the vector represents the existence of the visual entity and the orientation characterizes its visual properties. The primary level capsules are generated through a linear read out of the encoder RNN; htench^{enc}_{t}. These capsules are meant to represent lower-level visual entities (“parts”) that belong to one of the higher-level capsules in the object capsule layer (“whole"). To find this part-whole relationship, we used the dynamic routing algorithm proposed by Sabour et al. [2017]. Dynamic routing is an iterative process where the assignments of parts to whole objects (coupling coefficients) are progressively determined by agreement between the two capsules (measured by the dot product between the two vector representations). For example, if the prediction for a digit capsule jj from a primary capsule ii, (p^tj|iWtijpti\hat{p}^{j|i}_{t}\leftarrow W^{ij}_{t}p^{i}_{t}), highly agrees with the computed digit capsule (ictijp^tj|i)(\sum_{i}c^{ij}_{t}\hat{p}^{j|i}_{t}), the coupling coefficient ctijc^{ij}_{t} is enhanced so that more information is routed from primary capsule ii to object capsule jj. Coupling coefficients are normalized across the class capsule dimension following the max-min normalization [Zhao et al., 2019] as in the Supplementary Eq. 3. This routing procedure iterates three times. We used this method instead of the softmax normalization in Sabour et al. [2017] because we observed the latter method would not differentiate between the coupling coefficients. In our experiments we used 40 primary level capsules, each a vector of size 8. The object capsules are vectors of size 16 and there are 10 of them corresponding to the 10 digit categories. For the object level capsules, we use a squash function (the Supplementary Eq. 4) to ensure that its vector length is within the range of 0 to 1 to represent the probability of a digit being present in the glimpse at each step. Once the routing is completed, we compute the vector length (L2 norm) of each object capsule to obtain classification scores. The final digit classification is predicted based on the scores accumulated over all timesteps. Algorithms are provided in Pseudocode 2.

Decoder:

The object capsules provide a structured representation that can be used for decoding and glimpse selection. We first mask the object capsules so that only the vector representation from the most active capsule is forwarded to the decoder RNN, which also inputs the hidden state from the previous step, ht1dech^{dec}_{t-1}. Because the decoder maintains through recurrence the ongoing and evolving object-based representation of the image, it is best suited to determine the next read glimpse location (as discussed earlier). The state of the decoder RNN is also used through two linear operations to determine what and where to write in the canvas to reconstruct the image.

3.2 Loss Function

OCRA outputs object classification scores (cumulative capsule lengths) and image reconstruction (cumulative write canvas). Losses are computed for each output and combined with a weighting as in Eq. 1. For reconstruction loss, we simply computed the mean squared differences in pixel intensities between the input image and the model’s reconstruction. For classification, we used margin loss (Eq. 2). For each class capsule j, the first term is only active if the target object is present (Tj>0T_{j}>0) where minimizing the loss pushes the capsule length to be bigger than target capsule length minus a small margin (m). The second term is only active when the target capsule length is zero and in that case minimizing the loss pushes the predicted capsule length to be below a small margin (m). For all the experiments in this paper we used Adam optimizer [Kingma and Ba, 2014]. See Supplementary A.8 for a detailed explanation and the pseudocode for the loss functions.

TotalLoss=jclassClassLossj+λreconReconLoss\displaystyle Total\,Loss=\sum_{j\in class}Class\,Loss_{j}+\lambda_{recon}Recon\,Loss (1)
ClassLoss=jclassmax(0,min(Tj,1))max(0,(Tjm)dj)2+λabsentmax(0,1Tj)max(0,djm)2\displaystyle\begin{aligned} Class\,Loss=\sum_{j\in class}max(0,min(T_{j},1))\cdot max(0,(T_{j}-m)-\|d_{j}\|)^{2}+\\ \lambda_{absent}\cdot max(0,1-T_{j})\cdot max(0,\|d_{j}\|-m)^{2}\end{aligned} (2)

4 Results

Attention is hypothesized to be helpful in difficult recognition tasks, involving multiple (small) objects or difficult feature discrimination, heavily occluded objects, or objects appearing against noisy object-similar backgrounds. Considering this, here we use three proof of concept multi-object recognition tasks to illustrate the effectiveness of our proposed attention-recognition mechanism. The MultiMNIST task [Sabour et al., 2017] tests the model’s recognition performance on highly overlapping digits. We hypothesize that our object-based attention sampling, paired with recurrent processing, will be effective in recognizing objects despite high degrees of occlusion. The second task uses the MultiMNIST cluttered dataset [Ba et al., 2014]. This task tests the model’s ability to learn to make attention glimpses to individual distant objects while ignoring object-similar background clutter. We selected these two MultiMNIST tasks to make sure that individual objects can be recognized easily and that the difficulty of these (still challenging) tasks would be in having multiple objects, occlusions and distractors, to be able to isolate the effectiveness of the proposed object-centric recurrent attention mechanism. Also these two tasks require different attention mechanisms and to our knowledge no model has been shown to perform both at a high level. The third task is sequence prediction on the Street View House Numbers dataset (SVHN) [Goodfellow et al., 2013] and can test our model on its ability to scale up to more real-world datasets with a different number of objects in images.

All model accuracies presented in this section are averages of 5 runs (see A.10 for details of training and hyperparameters selection for different experiments). All error rates are measured on image level, meaning that the response is incorrect if any object in the image is not recognized correctly.

4.1 MultiMNIST

Refer to caption
Figure 2: Examples of OCRA outputs on MultiMNIST classification (error cases on the right)
Table 1: MultiMNIST classification error rates. The error rates for the CapsNet and the CNN baseline are taken directly from Sabour et al., (2018)
Model Model size Train size Error rate
CNN Baseline 24.56M 60M 8.1%
CapsNet (Sabour et al., 2018) 11.36M 60M 5.2%
OCRA-3glimpse 3.87M 3M 7.24% (±\pm0.11)
OCRA-10glimpse 3.87M 3M 5.08% (±\pm0.17)
MultiMNIST Dataset:

We generate the MultiMNIST dataset following the method from Sabour et al. [2017]. Each image in this dataset contains two overlapping digits sampled from different classes of the MNIST hand-written digits dataset [LeCun et al., 1998] (size 28×\times28 pixels). After the two digits are overlaid, each is shifted randomly up to 4 pixels in horizontal and vertical directions, resulting in images of size 36×\times36 pixels with on average 80% overlap between the digit bounding boxes. We generated 3M images for training, and 500K images for testing, and ensured that the training/testing sets were generated from the corresponding MNIST sets (i.e., the training set in MultiMNIST was only generated from the training set in MNIST). See Fig. 2 for examples.

4.1.1 MultiMNIST Results

Fig. 2 (left) shows some samples of the model making correct predictions and the resulting reconstructions. The 6 images on the right show all the errors that OCRA made in one test batch (size 128). Table 1 shows the OCRA accuracy for this task compared to two competing models. Our model with 10 timesteps outperforms the CapsNet model Sabour et al. [2017] while having only a third of the number of parameters. Moreover, the training set for our model is 20 fold smaller (3M compared to 60M), and in contrast to the CapsNet model ours does not use individual segmentations for each object during training. We saw a clear effect of the number of timesteps in our model, with the error rate dropping from 3 timesteps to 10 timesteps. Fig. 1 (bottom) shows the glimpse behavior on a sample image. The model starts with a more global glimpse but then moves its attention window, first to one object and then to the other, recognizing and reconstructing each sequentially. The gradual spreading of the reconstruction shown (bottom row) is consistent with the spreading of attention within an object hypothesized by object-based models of attention [Jeurissen et al., 2016, Ekman et al., 2020]. The most active capsule at each timestep is indicated by the digits in between the two rows.

4.1.2 Ablation experiments

In this section we start from the OCRA-3glimpse model and probe different model components; the object-centric representation, the glimpse mechanism and the recurrent processing; to measure their impact on accuracy. Results are shown in Table 2.

The effect of recurrent attention mechanisms:

In the first ablation experiment we asked how well a recurrent model using object-centric representation would perform without the ability to sample glimpses. OCRA-Recurrent (Table 2) performs multi-step processing on the input image using the recurrence in its encoder and decoder RNNs but lacks the ability to glimpse at specific locations. Therefore the model receives the entire image as input at each processing step for which it requires a bigger number of parameters (36×\times36 pixel input in contrast to 18×\times18 pixel glimpse). We then trained this model using three timesteps to make it comparable to OCRA-3glimpse. As shown in Table 2, this model performs worse than OCRA-3glimpse highlighting the important role of glimpse mechanism in our model performance. We then asked what if we removed the recurrent attention mechanism completely from our model. What is left is a model that makes one feedforward pass with the full resolution image as its input, binds features for either objects in separate category capsules, and then feeds them to the decoder (without masking the object capsules) to reconstruct the whole image at once. This is the Feedfoward model in Table 2. Similar to the OCRA-recurrent, having the model input the whole image in full resolution requires a bigger backbone and a larger number of parameters. This model shows higher error rate compared to the OCRA-recurrent supporting previous work on how recurrent dynamics can assist recognition tasks high degrees of occlusion [Spoerer et al., 2017, Wyatte et al., 2014]. Taken together, the results show that while recurrent computation can be effective for challenging recognition tasks with high degrees of occlusion, when it is paired with an attention glimpse mechanism more compact and better performing models can be built.

Table 2: Ablation study results on MultiMNIST classification
Model Model size Error rate
OCRA-3glimpse (3.87M) 7.24% (±\pm0.11)
Bigger models with the whole image as input
(no glimpse mechanism)
OCRA-Recurrent-3step (8.58M) 8.98% (±\pm0.14)
-Feedforward (6.47M) 10.63% (±\pm0.10)
OCRA-3glimpse with 1 routing step (3.87M) 7.77% (±\pm0.24)
-3glimpse without capsule representation (3.87M) 8.04% (±\pm0.13)
The role of capsule architecture and dynamic routing between capsules:

We performed two ablation studies to examine the impact of capsule architecture on model performance. We first asked whether the multiple dynamic routing steps between the two capsule layers has an impact on the error rate since the efficiency of the routing mechanism used by CapsNets has had mixed results to date [Gu et al., 2021, Tsai et al., 2020, Wang and Liu, 2018, Zhao et al., 2019]. As shown in Table 2, decreasing the routing step to 1 (uniform coupling coefficients) resulted in increased error rate compared to OCRA-3glimpse which had three dynamic routing steps. We then examined the role of Capsules in the model accuracy, by replacing that architecture with two fully connect layers (with the same size as two capsule layers; 320 and 160 units) and a classification readout. The new model still iteratively processes the image and the classification scores are read out from the second fully connected layer representation at each time step, which are then all combined across timesteps to make the final classification decision.  As shown in Table 2, we saw further increase in the error rate compared to the previous model showing the effect of encapsulation of information for performing this task. While these effects offer evidence for the role of capsule architecture and dynamic routing, they are relatively smaller compared to the ones from ablating the the recurrent attention. This suggests that the models can to some extent compensate for smaller routing steps or removal of the capsule architecture on this task. This could be due to two factors. First, the dynamic routing mechanism can flexibly learn to route information in a single routing step, allowing the weights between the capsule layers to do the binding without the need for multiple refining steps. Second, because the model has a glimpse mechanism it can route information globally, thereby potentially reducing the benefit derived from feature encapsulation and local dynamic routing.

4.2 MultiMNIST Cluttered

In this section we show that the attention mechanism in OCRA can handle noisy backgrounds and that its capsules can be used to recognize multiple objects from the same category.

MultiMNIST Cluttered dataset:

We generated the MultiMNIST Cluttered dataset, similar to the cluttered translated MNIST dataset from Mnih et al. [2014]. For each image, 2 digits and 6 digit-like clutter pieces are placed in random locations on a 100×\times100 blank canvas. All digits were sampled from the original MNIST dataset [LeCun et al., 1998] and the two digits in each image could be from the same or different categories. Clutter pieces were generated from other MNIST images by randomly cropping 8×\times8 patches. We generated 180K images for training and 30K for testing, ensuring to maintain the same MNIST training/testing separation.

Refer to caption
Figure 3: Attention, recognition and reconstruction processes of OCRA on MultiMNIST Cluttered

4.3 MultiMNIST Cluttered results

For this task we added a background capsule to the 10 class capsules, similar in approach to Qin et al. [2020]). This gives the model the choice to dynamically route background noise in its attention windows to a non-class capsule, thereby allowing the model to exclude the noise from its object representations. We also define a reconstruction mask that is the averaged sum of all the read glimpses converted into image dimensions (Fig. 3 middle row, right). We integrate this mask into our loss function by multiplying it by the input image before comparing it to the model reconstruction. This mask effectively focuses the loss so that the model is accountable for reconstructing only the areas where it had glimpsed, allowing the model to be selective with its glimpses and write operations.

Fig. 3 shows OCRA performing detection, classification and reconstruction of the digits in five timesteps for a sample image (top right) from this task. The top row shows the attention windows and the middle row shows the read glimpse at each step. The glimpses are converted to image dimensions for creating the reconstruction mask, as seen on the middle right. The most active capsule that is routed to the decoder at each step, and its length, are provided above the cumulative canvas (bottom row). Utilizing the reconstruction mask, the model, with only classification supervision, learns that the best strategy is to move its glimpse to digits and to write to the canvas only when it is confident of the digit classification. The model leverages the object-centric representation for attention planning and reconstruction, reflected in the behavior of clearly selecting objects and ignoring the distractors. We provide many more examples of the model behavior in the github repo. Fig. 4 shows more examples of model predictions for both recognition and reconstruction, with correct responses on the left side and errors on the right. Most errors by OCRA on this task are due to a digit overlapping with the other digit or noise pieces in ways that change their appearance from the underlying ground truth.

Table 3: MultiMNIST Cluttered classification error rates. The error rates for other models are taken directly from the corresponding papers (no error bars were provided)
Model Error rate
RAM 9%
DRAM w/o context 7%
DRAM 5%
OCRA-5glimpse 8.37% (±\pm1.43)
OCRA-7glimpse 7.12% (±\pm1.05)
Refer to caption
Figure 4: OCRA MultiMNIST cluttered sample predictions for recognition and reconstruction (errors are in bold on the right side)

Table 3 shows the classification results on this task. Two versions of the DRAM model [Ba et al., 2014] are included, with the one utilizing a context network performing best among the baseline models with a 5 percent error rate (taken directly from the reference paper where no error bars were provided). However, the context network in this model inputs the whole image in a separate pathway to plan attention selection, thereby separating it from the recognition pathway. OCRA has one pathway and strings together glimpses from multiple steps to have an integrated recognition and attention planning mechanism. Our model does this through its “zoom lens” attention processes that can switch between local and global processing as needed. The early glimpses in the model taken from the whole image, even though they are low resolution gists (Fig. 3, middle row, left), allow the model to plan its future glimpses. Later glimpses are focused on individual objects to mediate the recognition and reconstruction processes. OCRA’s performance on this task is comparable to the baseline models, despite the model being much smaller in size (DRAM has over 10M parameters) and easier to train (given its differentiable attention mechanism). We performed a similar ablation study to the previous task here by replacing the CapsNet architecture with two fully connected layers and a classification read out. We were not able to achieve error rates below 20 percent for the resulting model, indicating how essential the object-centric representation of CapsNet is for successfully performing this task.

4.4 SVHN

Street View House Numbers (SVHN) dataset [Goodfellow et al., 2013] consists of real-world images of house numbers, each containing a sequence of one to five digits. This dataset tests OCRA on both its applicability to more complex real-world stimuli and also on handling a varying number of objects in an image with multiple instances of any category. The original dataset is divided to train, test and extra images. We combined the train and extra sets to create a bigger training set and also converted the images to grayscale following [Ba et al., 2014], resulting in a train set of size 235K (10 percent for validatoin) and test set of size 13K. We made two changes to the model for this task. First, we increased the number of convolutional filters in the backbone from 32 to 64 in each of the two layers. Second, we added a readout layer to predict the digits in a sequence based on the capsule lengths as the model makes its pass across the image. The resulting model had 5.1 Million parameters. We train the model to "read" the digits from left to right by having the order of the predicted sequence match the ground truth from left to right. We allow the model to make 12 glimpses, with the first two not used for readout and the object capsule lengths from every following two glimpses will be read out for the output digits (e.g. the capsule lengths from the 3rd and 4th glimpses are read out to predict digit number 1; the left-most digit; and so on). Fig. 5 shows the model behavior for a sample image, as it glimpses, recognizes and reconstructs the objects in the sequence. The model achieved 2.65% (±\pm0.11) error rate on recognizing individual digits and 10.07% (±\pm0.53) error rate for recognizing whole sequences. We believe exploring bigger convolutional backbones (compared to a simple two-layer CNN we used) would be essential for improving the sequence prediction accuracy. More examples of model behavior and some classification errors are provided in the github repo.

Refer to caption
Figure 5: OCRA glimpse behavior (top row) and cumulative canvas (bottom) on an SVHN image

5 Discussion

We believe visual perception to be a sequential process, but one that requires the integration of attention and recognition. While models might use pre-processing or patch-selection mechanisms to solve specific tasks, methods should strive to capture a more dynamic integration of attention and recognition to be able to replicate human-level performance. We also believe that this modeling approach can be used in cognitive and neuroscience research to inform debates on the role of different connections in the visual system (feedforward, recurrent and feedback) for different tasks.

OCRA builds on capsule methods [Sabour et al., 2017], recurrent attention and DRAW architecture [Gregor et al., 2015]. While it addressed some of their limitations, there is still room for improvement. For one, we believe that a better capsule routing algorithm will be important in applying our method to more complex objects that require more complex part-whole matching. Also using a bigger backbone and generally a bigger architecture will be crucial. It is a direction for future work to determine whether our approach will scale up well to more complex objects and scenes.

We made architecture choices in OCRA to generally align with constraints known to exist in visual perception. We believe that these constraints, if captured correctly, would not only have the potential to improve model performance, but will also allow the model to predict human fixation behavior. We further see a broader application of our model in understating and modeling expert attention. For example, in the context of medical imaging and cancer screening by radiologists [Shen et al., 2021, Mall et al., 2018, Pesce et al., 2019], a model of experts’ attention and recognition behavior could lead to a greater understanding of underlying reasons for errors (misses and false positives). Moreover, having more interpretable models that could be build using effective glimpse attention and part-whole matching mechanisms can mitigate the risk of these expert models in practice.

References

  • Ba et al. [2014] J. Ba, V. Mnih, and K. Kavukcuoglu. Multiple object recognition with visual attention. arXiv preprint arXiv:1412.7755, 2014.
  • Burgess et al. [2019] C. P. Burgess, L. Matthey, N. Watters, R. Kabra, I. Higgins, M. Botvinick, and A. Lerchner. Monet: Unsupervised scene decomposition and representation. arXiv preprint arXiv:1901.11390, 2019.
  • Cao et al. [2015] C. Cao, X. Liu, Y. Yang, Y. Yu, J. Wang, Z. Wang, Y. Huang, L. Wang, C. Huang, W. Xu, et al. Look and think twice: Capturing top-down visual attention with feedback convolutional neural networks. In Proceedings of the IEEE international conference on computer vision, pages 2956–2964, 2015.
  • Chen et al. [2017] L. Chen, H. Zhang, J. Xiao, L. Nie, J. Shao, W. Liu, and T.-S. Chua. Sca-cnn: Spatial and channel-wise attention in convolutional networks for image captioning. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 5659–5667, 2017.
  • Cordonnier et al. [2021] J.-B. Cordonnier, A. Mahendran, A. Dosovitskiy, D. Weissenborn, J. Uszkoreit, and T. Unterthiner. Differentiable patch selection for image recognition. arXiv preprint arXiv:2104.03059, 2021.
  • Desimone and Duncan [1995] R. Desimone and J. Duncan. Neural mechanisms of selective visual attention. Annual review of neuroscience, 18(1):193–222, 1995.
  • Ding et al. [2020] D. Ding, F. Hill, A. Santoro, and M. Botvinick. Object-based attention for spatio-temporal reasoning: Outperforming neuro-symbolic models with flexible distributed architectures. arXiv preprint arXiv:2012.08508, 2020.
  • Duta et al. [2020] I. Duta, A. Nicolicioiu, and M. Leordeanu. Discovering dynamic salient regions with spatio-temporal graph neural networks. arXiv preprint arXiv:2009.08427, 2020.
  • Ekman et al. [2020] M. Ekman, P. R. Roelfsema, and F. P. de Lange. Object selection by automatic spreading of top-down attentional signals in v1. Journal of Neuroscience, 40(48):9250–9259, 2020.
  • Elsayed et al. [2019] G. F. Elsayed, S. Kornblith, and Q. V. Le. Saccader: improving accuracy of hard attention models for vision. arXiv preprint arXiv:1908.07644, 2019.
  • Eriksen and James [1986] C. W. Eriksen and J. D. S. James. Visual attention within and around the field of focal attention: A zoom lens model. Perception & psychophysics, 40(4):225–240, 1986.
  • Eslami et al. [2016] S. Eslami, N. Heess, T. Weber, Y. Tassa, D. Szepesvari, K. Kavukcuoglu, and G. E. Hinton. Attend, infer, repeat: Fast scene understanding with generative models. arXiv preprint arXiv:1603.08575, 2016.
  • Fu et al. [2017] J. Fu, H. Zheng, and T. Mei. Look closer to see better: Recurrent attention convolutional neural network for fine-grained image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 4438–4446, 2017.
  • Gilbert and Li [2013] C. D. Gilbert and W. Li. Top-down influences on visual processing. Nature Reviews Neuroscience, 14(5):350–363, 2013.
  • Goodfellow et al. [2013] I. J. Goodfellow, Y. Bulatov, J. Ibarz, S. Arnoud, and V. Shet. Multi-digit number recognition from street view imagery using deep convolutional neural networks. arXiv preprint arXiv:1312.6082, 2013.
  • Gottlieb et al. [2013] J. Gottlieb, P.-Y. Oudeyer, M. Lopes, and A. Baranes. Information-seeking, curiosity, and attention: computational and neural mechanisms. Trends in cognitive sciences, 17(11):585–593, 2013.
  • Goyal et al. [2019] A. Goyal, A. Lamb, J. Hoffmann, S. Sodhani, S. Levine, Y. Bengio, and B. Schölkopf. Recurrent independent mechanisms. arXiv preprint arXiv:1909.10893, 2019.
  • Greff et al. [2019] K. Greff, R. L. Kaufman, R. Kabra, N. Watters, C. Burgess, D. Zoran, L. Matthey, M. Botvinick, and A. Lerchner. Multi-object representation learning with iterative variational inference. In International Conference on Machine Learning, pages 2424–2433. PMLR, 2019.
  • Greff et al. [2020] K. Greff, S. van Steenkiste, and J. Schmidhuber. On the binding problem in artificial neural networks. arXiv preprint arXiv:2012.05208, 2020.
  • Gregor et al. [2015] K. Gregor, I. Danihelka, A. Graves, D. Rezende, and D. Wierstra. Draw: A recurrent neural network for image generation. In International Conference on Machine Learning, pages 1462–1471. PMLR, 2015.
  • Gu et al. [2021] J. Gu, V. Tresp, and H. Hu. Capsule network is not more robust than convolutional network. arXiv preprint arXiv:2103.15459, 2021.
  • Hinton [2021] G. Hinton. How to represent part-whole hierarchies in a neural network. arXiv preprint arXiv:2102.12627, 2021.
  • Hinton et al. [2018] G. E. Hinton, S. Sabour, and N. Frosst. Matrix capsules with em routing. In International conference on learning representations, 2018.
  • Hochreiter and Schmidhuber [1997] S. Hochreiter and J. Schmidhuber. Long short-term memory. Neural computation, 9(8):1735–1780, 1997.
  • Hu et al. [2018] J. Hu, L. Shen, and G. Sun. Squeeze-and-excitation networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 7132–7141, 2018.
  • Itti and Koch [2000] L. Itti and C. Koch. A saliency-based search mechanism for overt and covert shifts of visual attention. Vision research, 40(10-12):1489–1506, 2000.
  • Jaderberg et al. [2015] M. Jaderberg, K. Simonyan, A. Zisserman, and K. Kavukcuoglu. Spatial transformer networks. arXiv preprint arXiv:1506.02025, 2015.
  • Jaegle et al. [2021] A. Jaegle, F. Gimeno, A. Brock, A. Zisserman, O. Vinyals, and J. Carreira. Perceiver: General perception with iterative attention. arXiv preprint arXiv:2103.03206, 2021.
  • Jeurissen et al. [2016] D. Jeurissen, M. W. Self, and P. R. Roelfsema. Serial grouping of 2d-image regions with object-based attention in humans. Elife, 5:e14320, 2016.
  • Kahou et al. [2017] S. E. Kahou, V. Michalski, R. Memisevic, C. Pal, and P. Vincent. Ratm: recurrent attentive tracking model. In 2017 IEEE Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), pages 1613–1622. IEEE, 2017.
  • Kim et al. [2019] J. Kim, D. Linsley, K. Thakkar, and T. Serre. Disentangling neural mechanisms for perceptual grouping. arXiv preprint arXiv:1906.01558, 2019.
  • Kingma and Ba [2014] D. P. Kingma and J. Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • Kingma and Welling [2013] D. P. Kingma and M. Welling. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114, 2013.
  • Kosiorek et al. [2017] A. R. Kosiorek, A. Bewley, and I. Posner. Hierarchical attentive recurrent tracking. arXiv preprint arXiv:1706.09262, 2017.
  • Kreiman and Serre [2020] G. Kreiman and T. Serre. Beyond the feedforward sweep: feedback computations in the visual cortex. Annals of the New York Academy of Sciences, 1464(1):222, 2020.
  • LeCun et al. [1998] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
  • Li et al. [2018] X. Li, Z. Jie, J. Feng, C. Liu, and S. Yan. Learning with rethinking: Recurrently improving convolutional neural networks through feedback. Pattern Recognition, 79:183–194, 2018.
  • Liang and Hu [2015] M. Liang and X. Hu. Recurrent convolutional neural network for object recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 3367–3375, 2015.
  • Locatello et al. [2020] F. Locatello, D. Weissenborn, T. Unterthiner, A. Mahendran, G. Heigold, J. Uszkoreit, A. Dosovitskiy, and T. Kipf. Object-centric learning with slot attention. arXiv preprint arXiv:2006.15055, 2020.
  • Mall et al. [2018] S. Mall, P. C. Brennan, and C. Mello-Thoms. Modeling visual search behavior of breast radiologists using a deep convolution neural network. Journal of Medical Imaging, 5(3):035502, 2018.
  • Maunsell and Treue [2006] J. H. Maunsell and S. Treue. Feature-based attention in visual cortex. Trends in neurosciences, 29(6):317–322, 2006.
  • Mnih et al. [2014] V. Mnih, N. Heess, A. Graves, and K. Kavukcuoglu. Recurrent models of visual attention. arXiv preprint arXiv:1406.6247, 2014.
  • Müller et al. [2003] N. G. Müller, O. A. Bartelt, T. H. Donner, A. Villringer, and S. A. Brandt. A physiological correlate of the “zoom lens” of visual attention. Journal of Neuroscience, 23(9):3561–3565, 2003.
  • Nayebi et al. [2021] A. Nayebi, J. Sagastuy-Brena, D. M. Bear, K. Kar, J. Kubilius, S. Ganguli, D. Sussillo, J. J. DiCarlo, and D. L. Yamins. Goal-driven recurrent neural network models of the ventral visual stream. bioRxiv, 2021.
  • Papadopoulos et al. [2021] A. Papadopoulos, P. Korus, and N. Memon. Hard-attention for scalable image classification. arXiv preprint arXiv:2102.10212, 2021.
  • Park et al. [2018] J. Park, S. Woo, J.-Y. Lee, and I. S. Kweon. Bam: Bottleneck attention module. arXiv preprint arXiv:1807.06514, 2018.
  • Paszke et al. [2019] A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga, et al. Pytorch: An imperative style, high-performance deep learning library. arXiv preprint arXiv:1912.01703, 2019.
  • Pesce et al. [2019] E. Pesce, S. J. Withey, P.-P. Ypsilantis, R. Bakewell, V. Goh, and G. Montana. Learning to detect chest radiographs containing pulmonary lesions using visual attention networks. Medical image analysis, 53:26–38, 2019.
  • Qin et al. [2020] Y. Qin, N. Frosst, C. Raffel, G. Cottrell, and G. Hinton. Deflecting adversarial attacks. arXiv preprint arXiv:2002.07405, 2020.
  • Sabour et al. [2017] S. Sabour, N. Frosst, and G. E. Hinton. Dynamic routing between capsules. arXiv preprint arXiv:1710.09829, 2017.
  • Scholl [2001] B. J. Scholl. Objects and attention: The state of the art. Cognition, 80(1-2):1–46, 2001.
  • Shen et al. [2021] Y. Shen, N. Wu, J. Phang, J. Park, K. Liu, S. Tyagi, L. Heacock, S. G. Kim, L. Moy, K. Cho, et al. An interpretable classifier for high-resolution breast cancer screening images utilizing weakly supervised localization. Medical image analysis, 68:101908, 2021.
  • Spelke [1990] E. S. Spelke. Principles of object perception. Cognitive science, 14(1):29–56, 1990.
  • Spoerer et al. [2017] C. J. Spoerer, P. McClure, and N. Kriegeskorte. Recurrent convolutional neural networks: a better model of biological object recognition. Frontiers in psychology, 8:1551, 2017.
  • Stollenga et al. [2014] M. Stollenga, J. Masci, F. Gomez, and J. Schmidhuber. Deep networks with internal selective attention through feedback connections. arXiv preprint arXiv:1407.3068, 2014.
  • Tsai et al. [2020] Y.-H. H. Tsai, N. Srivastava, H. Goh, and R. Salakhutdinov. Capsules with inverted dot-product attention routing. arXiv preprint arXiv:2002.04764, 2020.
  • van Bergen and Kriegeskorte [2020] R. S. van Bergen and N. Kriegeskorte. Going in circles is the way forward: the role of recurrence in visual inference. Current Opinion in Neurobiology, 65:176–193, 2020.
  • Vaswani et al. [2017] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin. Attention is all you need. arXiv preprint arXiv:1706.03762, 2017.
  • Vecera [2000] S. P. Vecera. Toward a biased competition account of object-based segregation and attention. Brain and Mind, 1(3):353–384, 2000.
  • Wagemans et al. [2012] J. Wagemans, J. H. Elder, M. Kubovy, S. E. Palmer, M. A. Peterson, M. Singh, and R. von der Heydt. A century of gestalt psychology in visual perception: I. perceptual grouping and figure–ground organization. Psychological bulletin, 138(6):1172, 2012.
  • Wang and Liu [2018] D. Wang and Q. Liu. An optimization view on dynamic routing between capsules. 2018.
  • Wang et al. [2017] F. Wang, M. Jiang, C. Qian, S. Yang, C. Li, H. Zhang, X. Wang, and X. Tang. Residual attention network for image classification. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 3156–3164, 2017.
  • Wang et al. [2014] Q. Wang, J. Zhang, S. Song, and Z. Zhang. Attentional neural network: Feature selection using cognitive feedback. arXiv preprint arXiv:1411.5140, 2014.
  • Wyatte et al. [2014] D. Wyatte, D. J. Jilk, and R. C. O’Reilly. Early recurrent feedback facilitates visual object recognition under challenging conditions. Frontiers in psychology, 5:674, 2014.
  • Xu et al. [2015] K. Xu, J. Ba, R. Kiros, K. Cho, A. Courville, R. Salakhudinov, R. Zemel, and Y. Bengio. Show, attend and tell: Neural image caption generation with visual attention. In International conference on machine learning, pages 2048–2057. PMLR, 2015.
  • Zamir et al. [2017] A. R. Zamir, T.-L. Wu, L. Sun, W. B. Shen, B. E. Shi, J. Malik, and S. Savarese. Feedback networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 1308–1317, 2017.
  • Zhao et al. [2019] Z. Zhao, A. Kleinhans, G. Sandhu, I. Patel, and K. Unnikrishnan. Capsule networks with max-min normalization. arXiv preprint arXiv:1903.09662, 2019.
  • Zoran et al. [2020] D. Zoran, M. Chrzanowski, P.-S. Huang, S. Gowal, A. Mott, and P. Kohli. Towards robust image classification using sequential attention models. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 9483–9492, 2020.

Supplementary Material

Appendix A Method Details

A.1 Pseudocode for OCRA Overview

Pseudocode 1 OCRA Architecture Overview
1:Initialize encoder and decoder RNN hidden states, h0enc,h0dech_{0}^{enc},h_{0}^{dec}, to zero
2:Initialize classification score vector s0s_{0} and reconstruction canvas0canvas_{0} to zero
3:for tt in timesteps=1,2,timesteps=1,2,\ldots do
4:     Get read glimpse: gtREAD(x,ht1dec)g_{t}\leftarrow READ(x,h^{dec}_{t-1}) \\ read operation in Pseudocode 3
5:     Apply two convolutional layers with max pooling: gtconvCNN(gt)g^{conv}_{t}\leftarrow\,CNN(g_{t})
6:     Update encoder RNN’s hidden representations: htencRNNenc(ht1enc,gtconv)h^{enc}_{t}\leftarrow RNN^{enc}(h^{enc}_{t-1},g^{conv}_{t})
7:     Compute object capsules through dynamic routing: dtCAPSULE(htenc)d_{t}\leftarrow CAPSULE(h^{enc}_{t}) \\ capsule operation in Pseudocode 2
8:     Accumulate classification score for digit jj: stjst1j+dtjs_{t}^{j}\leftarrow s_{t-1}^{j}+\|d^{j}_{t}\|
9:     Mask with zeros all but the digit capsule with max classification score: dtmaskmask(dt)d^{mask}_{t}\leftarrow mask(d_{t})
10:     Update decoder RNN’s hidden representations: htdecRNNdec(ht1dec,dtmask)h^{dec}_{t}\leftarrow RNN^{dec}(h^{dec}_{t-1},d^{mask}_{t})
11:     Reconstruct image with write attention: wtWRITE(htdec)w_{t}\leftarrow WRITE(h^{dec}_{t}) \\ write operation in Pseudocode 4
12:     Update reconstruction canvas: canvast=canvast1+wtcanvas_{t}=canvas_{t-1}+w_{t}
13:end for

A.2 Pseudocode for Capsule Representation and Dynamic Routing

The dynamic routing algorithm performs as follows: At each routing step, each primary level capsule ii provides a prediction for each object level capsule jj. These predictions are then combined using the coupling coefficients cijc^{ij} to compute the object level capsules. Then the agreement (dot product) between the object level capsules and the predictions from each primary level capsule impacts the coupling coefficients for the next routing step.

Pseudocode 2 Capsule Representation and Dynamic Routing
1:htench^{enc}_{t} = current encoder representation
2:procedure Capsule(htench^{enc}_{t})
3:     Apply linear transform on htench^{enc}_{t} to create primary capsule: ptWp(htenc)p_{t}\leftarrow{W_{p}}(h^{enc}_{t})
4:     For all object capsule jj and primary capsule ii initialize btijb^{ij}_{t} to zero
5:     for rr in routings=1,2,routings=1,2,\ldots do
6:         For all object capsule jj compute prediction from all primary capsule ii: p^tj|iWtijpti\hat{p}^{j|i}_{t}\leftarrow W^{ij}_{t}p^{i}_{t}
7:         For all object capsule jj: dtjsquash(ictijp^tj|i)d^{j}_{t}\leftarrow squash(\sum_{i}c^{ij}_{t}\hat{p}^{j|i}_{t}) \\ squash function in Eq. 4
8:         For all object capsule jj and primary capsule ii: btijbtij+p^tj|idtjb^{ij}_{t}\leftarrow b^{ij}_{t}+\hat{p}^{j|i}_{t}\cdot{d^{j}_{t}}
9:         Max-Min Normalize: ctijmaxmin(btij)c^{ij}_{t}\leftarrow maxmin(b^{ij}_{t}) \\ maxmin function in Eq.3
10:     end for
11:     For all object capsule jj: dtjsquash(ictijp^tj|i)d^{j}_{t}\leftarrow squash(\sum_{i}c^{ij}_{t}\hat{p}^{j|i}_{t})
12:end procedure

A.3 Max-Min Normalization

In the original CapsNets [Sabour et al., 2017], the softmax normalization was used for normalizing coupling coefficients in dynamic routing. However, the softmax operation fails to differentiate between the coupling coefficients. To remedy this, we used max-min normalization that is proposed in Zhao et al. [2019], applying it after each routing operation. Lower- and upper-bounds for normalization, lblb and ubub, were set to 0.01 and 1.0, respectively.

ctij=lb+(ublb)ctijmin(ctij)max(ctij)min(ctij)c^{ij}_{t}=lb+(ub-lb)\frac{c^{ij}_{t}-min(c^{ij}_{t})}{max(c^{ij}_{t})-min(c^{ij}_{t})} (3)

A.4 Squash Function

We use a vector implementation of capsules [Sabour et al., 2017] where the length of the vector represents the existence of the visual entity and the orientation characterizes its visual properties. To ensure that the capsule vector length ranges from 0 to 1, e.g., 0 for absence and 1 for presence, we used the non-linear squash function as below. The vtjv^{j}_{t} is the weighted sum of all prediction vectors from primary capsules for an object capsule jj (line 7 in Pseudocode 2).

dtj=vtj21+vtj2vtjvtjd^{j}_{t}=\frac{\|v^{j}_{t}\|^{2}}{1+\|v^{j}_{t}\|^{2}}\frac{v^{j}_{t}}{\|v^{j}_{t}\|} (4)

A.5 Pseudocode for Read Attention

Pseudocode 3 Read Attention
1:ht1dech^{dec}_{t-1} = previous decoder hidden representation
2:NN= the size of read attention glimpse; W,HW,H= (image width, height)
3:procedure ReadAttention(x,ht1decx,h^{dec}_{t-1}) \\ Get a N×\timesN read glimpse gtg_{t} by applying horizontal and vertical Gaussian filterbank matrices FX,FYF_{X},F_{Y} to the image
4:     Get attention grid parameters: gX,gY,logδ,logσ2Wread_attention(ht1dec)g_{X},g_{Y},\log\delta,\log\sigma^{2}\leftarrow W_{read\_attention}(h^{dec}_{t-1})
5:     Scale attention grid centers: gXA+12(gX+1)g_{X}\leftarrow\frac{A+1}{2}(g_{X}+1), gYA+12(gY+1)g_{Y}\leftarrow\frac{A+1}{2}(g_{Y}+1)
6:     Scale attention grid stride: δmax(A,B)1N1δ\delta\leftarrow\frac{max(A,B)-1}{N-1}\delta
7:     Get x, y locations of each grid point i,ji,j:
8:           μXi=gX+(iN/20.5)δ\mu^{i}_{X}=g_{X}+(i-N/2-0.5)\delta
9:           μYi=gY+(jN/20.5)δ\mu^{i}_{Y}=g_{Y}+(j-N/2-0.5)\delta
10:     Compute Gaussian filterbank matrices (ZXZ_{X} and ZYZ_{Y} are normalization constants):
11:           FX[i,w]=1ZXexp((wμXi)22σ2)F_{X}[i,w]=\frac{1}{Z_{X}}\exp({-\frac{(w-\mu^{i}_{X})^{2}}{2\sigma^{2}}})
12:           FY[j,h]=1ZYexp((hμXj)22σ2)F_{Y}[j,h]=\frac{1}{Z_{Y}}\exp({-\frac{(h-\mu^{j}_{X})^{2}}{2\sigma^{2}}})
13:     Get a read glimpse gtg_{t}: gtFYxFXTg_{t}\leftarrow F_{Y}xF^{T}_{X}
14:end procedure

A.6 Pseudocode for Write Attention

Pseudocode 4 Write Attention
1:htdech^{dec}_{t} = current decoder hidden representation
2:MM= the size of write attention glimpse, W,HW,H= (image width, height)
3:procedure WriteAttention(htdech^{dec}_{t}) \\ Reconstruct a M×\timesM write glimpse wtw_{t} to the original image size by applying transposed Gaussian filterbank matrices FX^,FY^\hat{F_{X}},\hat{F_{Y}}
4:     Get a write glimpse: wtWwrite(htdec)w_{t}\leftarrow W_{write}(h^{dec}_{t})
5:     Get attention grid parameters: gX^,gY^,logδ^,logσ^2Wwrite_attention(htdec)\hat{g_{X}},\hat{g_{Y}},\log\hat{\delta},\log\hat{\sigma}^{2}\leftarrow W_{write\_attention}(h^{dec}_{t})
6:     Scale attention grid centers: gX^A+12(gX+1)\hat{g_{X}}\leftarrow\frac{A+1}{2}(g_{X}+1), gY^A+12(gY+1)\hat{g_{Y}}\leftarrow\frac{A+1}{2}(g_{Y}+1)
7:     Scale attention grid stride: δ^max(A,B)1M1δ^\hat{\delta}\leftarrow\frac{max(A,B)-1}{M-1}\hat{\delta}
8:     Get x, y locations of each grid point i,ji,j:
9:           μ^Xi=gX^+(iM/20.5)δ^\hat{\mu}^{i}_{X}=\hat{g_{X}}+(i-M/2-0.5)\hat{\delta}
10:           μ^Yi=gY^+(jM/20.5)δ^\hat{\mu}^{i}_{Y}=\hat{g_{Y}}+(j-M/2-0.5)\hat{\delta}
11:     Compute Gaussian filterbank matrices (ZX^\hat{Z_{X}} and ZY^\hat{Z_{Y}} are normalization constants):
12:           FX^[i,w]=1ZX^exp((wμ^Xi)22σ^2)\hat{F_{X}}[i,w]=\frac{1}{\hat{Z_{X}}}\exp({-\frac{(w-\hat{\mu}^{i}_{X})^{2}}{2\hat{\sigma}^{2}}})
13:           FY^[j,h]=1ZY^exp((hμ^Xj)22σ^2)\hat{F_{Y}}[j,h]=\frac{1}{\hat{Z_{Y}}}\exp({-\frac{(h-\hat{\mu}^{j}_{X})^{2}}{2\hat{\sigma}^{2}}})
14:     Convert the write glimpse into the original image size: wtF^YTwtF^Xw_{t}\leftarrow\hat{F}^{T}_{Y}w_{t}\hat{F}_{X}
15:end procedure

A.7 Visual Illustrations of Read and Write Attention

A glimpse, gtg_{t}, is sampled through application of a grid of N ×\times N (18×\times18 in our experiments) Gaussian filters on the input image xx. The Gaussian filters are generated using four parameters: gX,gY,δ,σ2g_{X},g_{Y},\delta,\sigma^{2}, which specify the center coordinates of the attention window, the distance between equally spaced Gaussian filters in the grid, and the variance of the filters, respectively. The decoder output htdech^{dec}_{t} is linearly transformed into an M ×\times M write patch wtw_{t} (set to 18 ×\times 18 in our experiments), which is then multiplied by the Gaussian filters to reconstruct the written parts in original image size. The Gaussian filters used for the write operation differed from those used for the read operation.

Refer to caption
Figure 6: Visual illustrations of read and write attention mechanisms

A.8 Loss functions

OCRA outputs object classification scores (cumulative capsule lengths) and image reconstructions (cumulative write canvas). Losses are computed for each output and combined with a weighting λrecon\lambda_{recon} (Eq. 1). For reconstruction loss, we simply computed the mean squared differences in pixel intensities between the input image and the model’s reconstruction. For classification, we used margin loss (Eq. 2) so as to give us the flexibility to define ground truth classification that can have multiple objects in each category. This loss ensures that the network yields classification scores similar to ground truth for each image. The loss term has two terms. For each class j, the first term is multiplied by the ground truth TjT_{j} so this term only would come into play when an object class is present in the ground truth (LpresentL_{present} in Pytorch code). Minimizing this loss eventually pushes the model scores for this capsule to be bigger than (Tjm)(T_{j}-m), and with mm set at 0.1 this means bigger than .9 when one object is present or bigger than 1.9 when two objects from this category are present in the ground truth. The second term is multiplied by a weighting first and then by max(0,1Tj)\max(0,1-T_{j}), which makes this loss come into play only when the class is not present in the ground truth (LabsentL_{absent} in Pytorch code). In that case the loss pushes the model scores for this class to be smaller than mm (<0.1<0.1). These two terms together, when summed across all classes, provide the margin loss for classification. Pytorch implementation of the loss functions are provided below.

TotalLoss=jclassClassLossj+λreconReconLoss\displaystyle Total\,Loss=\sum_{j\in class}Class\,Loss_{j}+\lambda_{recon}Recon\,Loss (5)
ClassLoss=jclassmax(0,min(Tj,1))max(0,(Tjm)dj)2+λabsentmax(0,1Tj)max(0,djm)2\displaystyle\begin{aligned} Class\,Loss=\sum_{j\in class}max(0,min(T_{j},1))\cdot max(0,(T_{j}-m)-\|d_{j}\|)^{2}+\\ \lambda_{absent}\cdot max(0,1-T_{j})\cdot max(0,\|d_{j}\|-m)^{2}\end{aligned} (6)
1# Margin classification loss
2m = 0.1 # margin of error
3lam_abs = 0.5 # down-weighting loss for absent digits
4
5L_present = torch.clamp(y_true, min=0., max=1.) * torch.clamp((y_true-m) - y_pred, min=0.) ** 2
6L_absent = lam_abs * torch.clamp(1 - y_true, min=0.) * torch.clamp(y_pred-m, min=0.) ** 2
7L_margin = (L_present+L_absent).sum(dim=1).mean()
8
9# Reconstruction loss
10L_recon = nn.MSELoss()(c, x)
11
12# Total loss
13Loss = L_margin+ lam_recon*L_recon
Listing 1: Pytorch Implementation of Margin Loss and Reconstruction Loss

A.9 Measuring accuracy

When we convert the model scores to multi-object classification, we take into account the thresholds that are set in the loss function. If the model prediction for one class capsule is larger than 1.8 (conservatively selected to be slightly lower than 1.9), this signals the presence of two objects from this class in the image. If no class score is larger than this threshold, the top two highest scores are selected as the model predictions for the two objects in the image.

A.10 Training Regime and Hyperparameter selection

OCRA is implemented in Pytorch and the code is available in this repository. All the model training took place on a single GPU workstation. All the Models trained on the MultiMNIST task were early stopped after 50 epochs of training (taking about 40 hours for the 10glimpse model). The models trained on the MultiMNIST cluttered task were early stopped after 1000 epochs (taking about 40 hours for the 5glimpse model). The models trained on the MultiSVHN task were also stopped after 1000 epochs (taking about 50 hours). We used the Adam optimizer [Kingma and Ba, 2014] for all the experiments.

Table 4: Hyperparameter Setting for MultiMNIST and MultiMNIST Cluttered tasks
Hyperparameters MultiMNIST-(10/3)glimpse Cluttered-(5/7)glimpse MultiSVHN
# timesteps, t 10/3 5/7 12
# epoch 50 1000 1000
lr 0.001 0.001 0.001
batch size 128 128 128
read glimpse size, NN 18 18 18
write glimpse size, MM 18 18 18
# conv1 filters 32 32 64
# conv2 filters 32 32 64
lstm size, dim(henc),dim(hdec)dim(h_{enc}),dim(h_{dec}) 512 512 512
# primary capsule 40 40 40
primary capsule dimension, dim(p)dim(p) 8 8 8
# routings r 3 3 3
object capsule dimension, dim(d)dim(d) 16 16 16
# background capsules 0 1 0
reconstruction loss weight, λrecon\lambda_{recon} 10/3 175/200 200
clipping final canvas to [0,1] TRUE FALSE TRUE
use reconstruction mask FALSE TRUE FALSE

Table 4 provides all the hyperparameters used for the three tasks. The hyperparameters were mostly the same between the tasks with the few differences. For the MultiMNIST Cluttered task, we added one background capsule and utilized a reconstruction mask, both to allow the model to focus on the main objects and ignore the background clutter. The image is larger (100×\times100 vs 36×\times36) in the MultiMNIST Cluttered task, with most of it being empty, the reconstruction loss therefore has a much smaller range compared to the other task. For this reason, and the use of the reconstruction mask, we use a much larger weight for the reconstruction loss to make it comparable to the margin loss.

For the MultiMNIST task, we clip the cumulative reconstruction canvas to be between [0,1] before comparing it to the input image. This allows the model to overlap different segments in the multi-step process of writing to the canvas without increasing the loss, improving model reconstruction given the high degree of overlap between the digits.

For the MultiSVHN task, we increased the number of the filters in each of the two convolutional layers to 64.

Appendix B Additional Results

Refer to caption
Figure 7: OCRA MultiMNIST output with 10 glimpses, the top 5 rows are the inputs and the bottom 5 rows are model reconstructions.
Refer to caption
Figure 8: OCRA MultiMNIST output with 3 glimpses, the top 5 rows are the inputs and the bottom 5 rows are model reconstructions.
Refer to caption
Figure 9: OCRA MultiMNIST Cluttered output with 5 glimpses, the top 5 rows are the inputs and the bottom 5 rows are model reconstructions.
Refer to caption
Figure 10: OCRA MultiMNIST Cluttered output with 5 glimpses showing the gradual object-based reconstruction on the cumulative canvas.
Refer to caption
Figure 11: OCRA MultiSVHN output showing the stepwise reconstruction.