Attentional Prototype Inference for Few-Shot Segmentation
Abstract
This paper aims to address few-shot segmentation. While existing prototype-based methods have achieved considerable success, they suffer from uncertainty and ambiguity caused by limited labeled examples. In this work, we propose attentional prototype inference (API), a probabilistic latent variable framework for few-shot segmentation. We define a global latent variable to represent the prototype of each object category, which we model as a probability distribution. The probabilistic modeling of the prototype enhances the model's generalization ability by handling the inherent uncertainty caused by limited data and intra-class variations of objects. To further enhance the model, we introduce a local latent variable to represent the attention map of each query image, which enables the model to attend to foreground objects while suppressing the background. The optimization of the proposed model is formulated as a variational Bayesian inference problem, which is solved with amortized inference networks. We conduct extensive experiments on four benchmarks, where our proposal obtains at least competitive and often better performance than state-of-the-art prototype-based methods. We also provide comprehensive analyses and ablation studies to gain insight into the effectiveness of our method for few-shot segmentation.
keywords:
Few-Shot Segmentation, Variational Inference, Probabilistic Model, Latent Attention
1 Introduction
Semantic segmentation [1] has been a fundamental problem in computer vision with widespread application potential in a great variety of areas, e.g., autonomous driving and scene understanding. Existing models based on deep convolutional neural networks and trained on massive amounts of manually labeled images, e.g., [2, 3], have obtained impressive results. However, since their performance relies heavily on access to a large number of pixel-wise annotations, it remains challenging to achieve desirable results in practice when the training data is scarce. Therefore, few-shot segmentation [4] has emerged as a popular task to address the annotation scarcity issue of traditional semantic segmentation methods.
Few-shot segmentation generalizes the idea of few-shot classification under the setting of meta-learning [5]. In meta-learning, the dataset is split into meta-training, meta-validation, and meta-testing sets. Few-shot segmentation methods usually sample episodes, each consisting of a support set and a query set, from these meta-sets to train a learning procedure that takes the support set as input and produces predictions on the query set. The model can then adapt effectively to new tasks at test time. In this paper, we focus on few-shot segmentation, where we aim to segment the object of an unseen category in a query image with the support of only a few annotated images.
To alleviate the scarcity of annotated data, most existing works extract supervisory information for the objects from a small set of support images. Inspired by the prototype theory from cognitive science [6] and prototype networks for few-shot classification [7], many segmentation models are designed to learn a prototype vector to represent a category in the support set [8]. The optimization goal is then to obtain a shared feature extractor that generalizes to the segmentation of new objects [9]. While prototype-based methods have proven effective for few-shot segmentation, they still have three major deficiencies. 1) Existing methods map the support images into a deterministic prototype vector, which is often ambiguous and vulnerable to noise under the few-shot setting, especially for one-shot learning tasks. As illustrated in Figure 1, intrinsic ambiguities usually exist in images. Deterministic models neglect these ambiguities and merely provide the most likely hypothesis, which might lead to sub-optimal decisions. 2) The prototype vector loses the structural information of the object in the query image. 3) A deterministic prototype vector contains only first-order statistics, which cannot represent large intra-class variations of objects in the same category. To address these issues, we propose a probabilistic latent variable framework, referred to as attentional prototype inference (API), for few-shot segmentation.

We make three contributions in this work. (1) We provide a fully probabilistic framework for few-shot segmentation. We introduce a global latent variable into this model, which represents the prototype of each object category. The probabilistic framework models the prototype as a distribution rather than a vector, making it more robust to noise and better equipped to deal with ambiguity than the deterministic model. The probabilistic prototype also better represents object categories by effectively capturing intra-class variations. (2) The second contribution is that we introduce a variational attention mechanism. We define the attention vector as the local latent variable associated with each image and infer its probabilistic distribution, which is jointly estimated within the same framework as the variational prototype inference. The main motivation of variational attention is to enable the model to capture the appearance variation of an object. (3) Our third contribution is to formulate the optimization as a variational inference problem to jointly estimate posteriors over latent variables. The optimization objective is built upon a newly derived evidence lower bound (ELBO), which fits the few-shot segmentation problem well and offers a principled way to model prototypes and attention for few-shot segmentation.
To evaluate our attentional prototype inference, we conduct extensive experiments on four benchmarks, i.e., PASCAL-5i, COCO-20i, FSS-1000, and LIDC-IDRI. The comparison results show that our attentional prototype inference achieves at least competitive and often better performance than state-of-the-art prototype-based methods on both the 1-shot and 5-shot segmentation tasks, demonstrating its effectiveness for few-shot segmentation. Quantitative comparison in terms of the cross energy distance on the medical image dataset shows the good statistical properties of the proposed algorithm, which can faithfully handle the uncertainty stemming from intrinsic ambiguities in images. We also conduct ablation studies to gain insight into the proposed attentional prototype inference by demonstrating the benefit of different model components to the overall performance.
2 Related Work
2.1 Many-Shot Semantic Segmentation
Semantic segmentation aims to segment a given image into several pre-defined classes and is often regarded as a pixel-level classification task [1]. State-of-the-art semantic segmentation methods [10, 11] based on deep convolutional neural networks have achieved remarkable success. The fully convolutional network (FCN) [3] was the first model to introduce end-to-end convolutional neural networks into segmentation tasks. The essential innovation in FCN is replacing the fully-connected layers with a fully convolutional architecture that preserves spatial information. Follow-up efforts have attempted to aggregate multiple pixels to explicitly model context. For example, DeepLab [2] introduces a dilated convolution operation to enlarge the receptive field while maintaining the resolution, and PSPNet [12] employs a pyramid pooling module to aggregate multi-scale context information. Another pyramid module [13] is designed to capture and filter multi-scale information in a gated and pair-wise manner.
Though they achieve impressive performance, these methods heavily rely on labeled training samples with pixel-level annotations. However, pixel-level annotations are expensive and difficult to obtain. Moreover, the deep semantic segmentation models usually perform modestly on new categories of objects that are unseen in the training set, which restricts their use in practical applications.
2.2 Few-Shot Segmentation
Few-shot segmentation aims to segment images from arbitrary classes by learning transferable knowledge from scarce annotated support images, which has recently gained popularity in computer vision applications. Shaban et al. [14] introduced the first few-shot segmentation network based on a two-branch architecture, which uses a support branch to predict the parameters of the last layer of the query branch for segmentation. Recent works [15] also follow this two-branch architecture for few-shot segmentation. Dong et al. [9] generalized the idea of prototype networks [7] from few-shot recognition for few-shot segmentation. They designed the PLNet, in which the first branch learns a prototype vector that takes images and annotations as input and outputs the prototype. Meanwhile, the second branch takes both a new image and the prototype as input and outputs the segmentation mask. Since then, prototype-based methods have been further developed using different strategies [8, 16].
To achieve sufficient representation for the class prototype, Yang et al. [17] designed prototype mixture models to correlate diverse image regions with multiple prototypes. Tian et al. [18] proposed the feature enrichment module to overcome spatial inconsistency. Li et al. [19] designed a prototype alignment regularization to generate a more consistent prototype between support and query. Okazawa [16] proposed to reduce the similarity between prototypes of each class and leverage the relationship between classes in a batch. These works have demonstrated the effectiveness of prototype learning for few-shot segmentation. However, a deterministic prototype vector is not sufficiently representative for capturing the categorical concepts of objects and therefore can cause bias and reduced generalization when objects in the same categories vary. In this work, we develop a variational attention mechanism by placing a distribution over the attention vector, which enables the model to better capture the appearance variation of individual objects.
Other strategies have been adopted in recent years and have achieved substantial improvements on this task. Singh et al. [20] adopted a gradient-based meta-learning algorithm and integrated augmentations for few-shot medical image segmentation. Min et al. [21] introduced hypercorrelation squeeze networks with 4D convolution layers to characterize correspondences in multiple visual aspects between support and query images. Recently, the Swin Transformer [22] has been applied to process high-dimensional correlation maps, achieving superior performance on these benchmarks. Gaussian processes [23] have also been introduced to extract detailed relations among support images. Fan et al. [4] designed a self-support matching strategy motivated by the Gestalt principle.
2.3 Variational Inference
The variational auto-encoder (VAE) [24] is a generative model that introduces variational inference (VI) [25] into the learning of directed graphical models. It has also been broadly applied in segmentation tasks. For example, Kohl et al. proposed the probabilistic U-Net [26], which combines a conditional VAE (C-VAE) with U-Net for image segmentation. It learns a distribution over segmentation masks to handle ambiguities, especially in medical images. Zhang et al. [27] deployed a latent variable to denote the distribution of the entire dataset, which is inferred from the support set. They also showed that their variational learning strategy can be modified to classify proposals for instance segmentation.
We address few-shot segmentation based on prototypes using a probabilistic latent variable model. We treat the prototype that represents the concept of the object category as a global latent variable, which is modeled as a distribution instead of a single deterministic vector. We further introduce a local latent variable to generate the attention map, which is learned for each image to highlight foreground objects. We solve the whole model by variational Bayesian inference, in which latent variables are jointly learned in the same framework.
3 Methodology
We adopt the meta-learning setting for few-shot segmentation. We learn a segmentation model on the meta-training set $\mathcal{D}_{\mathrm{train}}$ and then evaluate it on the meta-testing set $\mathcal{D}_{\mathrm{test}}$. Different from traditional semantic segmentation, the object categories in $\mathcal{D}_{\mathrm{train}}$ and $\mathcal{D}_{\mathrm{test}}$ do not overlap. We follow the episodic paradigm for training and testing under a $K$-shot setting, where $K$ denotes the number of annotated support images in an episode. In practice, we sample one episode at a time from $\mathcal{D}_{\mathrm{train}}$ for training or from $\mathcal{D}_{\mathrm{test}}$ for evaluation. Each episode is composed of a support set $\mathcal{S} = \{(x_s^k, y_s^k)\}_{k=1}^{K}$ and a query set $\mathcal{Q} = \{(x_q, y_q)\}$. Here, $x_s^k \in \mathbb{R}^{H \times W \times 3}$ denotes the $k$-th support image with a height of $H$ and a width of $W$, and $y_s^k \in \{0, 1\}^{H \times W}$ is its corresponding support mask. Similarly, $x_q$ is the query image and $y_q$ is the associated ground-truth mask of the object to be segmented. The goal of the few-shot segmentation model is to extract transferable knowledge from the support set $\mathcal{S}$ and apply it to the segmentation of the query image $x_q$. The predicted segmentation map for $x_q$ is denoted as $\hat{y}_q$.
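The episodic protocol can be illustrated with a minimal sketch of drawing one $K$-shot episode. It assumes a hypothetical dictionary `images_by_class` that maps each class of the current meta-split to a list of (image, mask) pairs; support and query pairs are drawn from one class without replacement, so the query never appears in its own support set.

```python
import random


def sample_episode(images_by_class, classes, k_shot, rng=random):
    """Sample one K-shot episode from a single class.

    images_by_class: hypothetical dict mapping a class name to a list of
    (image, mask) pairs; classes: the classes of the current meta-split.
    Returns the class, K support pairs, and one query pair, all distinct.
    """
    cls = rng.choice(classes)
    pairs = rng.sample(images_by_class[cls], k_shot + 1)  # without replacement
    support, query = pairs[:k_shot], pairs[k_shot]
    return cls, support, query
```

Because the meta-training and meta-testing classes are disjoint, the `classes` argument would come from $\mathcal{D}_{\mathrm{train}}$ during training and from $\mathcal{D}_{\mathrm{test}}$ during evaluation.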
3.1 Attentional Prototype Inference
From a probabilistic perspective, few-shot segmentation aims to find the prediction that maximizes the conditional predictive distribution $p(y_q \mid x_q, \mathcal{S})$ over the segmentation map $y_q$ for a given query image $x_q$ and support set $\mathcal{S}$.
3.1.1 Probabilistic Modeling
We introduce a latent variable $z$ to represent the class prototype, which is conditioned on the corresponding support set. By incorporating the latent variable, we obtain the conditional predictive log-likelihood as follows:

$$\log p(y_q \mid x_q, \mathcal{S}) = \log \int p(y_q \mid x_q, z)\, p(z \mid \mathcal{S})\, dz \tag{1}$$
where $p(z \mid \mathcal{S})$ is a conditional prior. The model in (1) provides a probabilistic model of prototypes for semantic segmentation, which was introduced in our preliminary work [28]. In this way, our prototype serves as a global representation of an object category, while previous ones do not take into account the local spatial structure of the image [9].
To further enhance the model in (1), we introduce a local latent variable $m$ to represent the attention map associated with each image, which highlights the foreground object. The conditional predictive log-likelihood with respect to the two latent variables takes the following form:

$$\log p(y_q \mid x_q, \mathcal{S}) = \log \iint p(y_q \mid x_q, z, m)\, p(z \mid \mathcal{S})\, p(m \mid x_q)\, dz\, dm \tag{2}$$

where we also deploy a conditional prior $p(m \mid x_q)$ for $m$, since the attention maps should be specific to each individual query image $x_q$.
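However, the true posteriors over $z$ and $m$ are intractable in practice. Thus, we introduce variational posteriors to approximate the true posteriors by minimizing their Kullback-Leibler (KL) divergence. We employ the variational distributions $q(z \mid \mathcal{S}, \mathcal{Q})$ and $q(m \mid \mathcal{Q})$ for the prototype $z$ and the attention map $m$, respectively. By incorporating the variational posteriors into the conditional log-likelihood of (2), we arrive at

$$\log p(y_q \mid x_q, \mathcal{S}) = \log \iint p(y_q \mid x_q, z, m)\, \frac{p(z \mid \mathcal{S})\, p(m \mid x_q)}{q(z \mid \mathcal{S}, \mathcal{Q})\, q(m \mid \mathcal{Q})}\, q(z \mid \mathcal{S}, \mathcal{Q})\, q(m \mid \mathcal{Q})\, dz\, dm \tag{3}$$

Applying Jensen's inequality gives rise to the ELBO as follows:

$$\log p(y_q \mid x_q, \mathcal{S}) \geq \mathbb{E}_{q(z \mid \mathcal{S}, \mathcal{Q})\, q(m \mid \mathcal{Q})}\big[\log p(y_q \mid x_q, z, m)\big] - D_{\mathrm{KL}}\big[q(z \mid \mathcal{S}, \mathcal{Q}) \,\|\, p(z \mid \mathcal{S})\big] - D_{\mathrm{KL}}\big[q(m \mid \mathcal{Q}) \,\|\, p(m \mid x_q)\big] \tag{4}$$

The first term of the ELBO is the expected log-likelihood of the conditional generative distribution $p(y_q \mid x_q, z, m)$ based on the inferred prototypes $z$ and attention maps $m$. The second term is the KL divergence between the estimated posterior distribution $q(z \mid \mathcal{S}, \mathcal{Q})$ and the prior distribution $p(z \mid \mathcal{S})$. Minimizing this term encourages the model to leverage the object information for the segmentation of the query image. Minimizing the third term, the KL divergence between $q(m \mid \mathcal{Q})$ and $p(m \mid x_q)$, enables the model to generate attention maps that highlight the foreground object. We derive the optimization objective based on the ELBO.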
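Because the priors and variational posteriors are parameterized as factorized Gaussians with diagonal covariance (Section 3.2), both KL terms in the ELBO can be evaluated in closed form. For reference, the standard identity for two diagonal Gaussians over a $D$-dimensional latent variable is:

$$D_{\mathrm{KL}}\big[\mathcal{N}(\mu_q, \mathrm{diag}(\sigma_q^2)) \,\|\, \mathcal{N}(\mu_p, \mathrm{diag}(\sigma_p^2))\big] = \sum_{d=1}^{D} \left( \log\frac{\sigma_{p,d}}{\sigma_{q,d}} + \frac{\sigma_{q,d}^2 + (\mu_{q,d} - \mu_{p,d})^2}{2\sigma_{p,d}^2} - \frac{1}{2} \right)$$

Each summand is non-negative and vanishes only when $\mu_{q,d} = \mu_{p,d}$ and $\sigma_{q,d} = \sigma_{p,d}$, so minimizing these terms pulls the prior networks toward the posterior networks dimension by dimension.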
3.1.2 Optimization Objective
Maximizing the ELBO yields accurate predictions of the segmentation masks and narrows the gap between the posterior and prior distributions. This encourages 1) the prototype inferred from the support set alone, $p(z \mid \mathcal{S})$, to match the one inferred from both the support and query sets, $q(z \mid \mathcal{S}, \mathcal{Q})$, by minimizing the first KL term; and 2) the attention map inferred from the annotated query pair, $q(m \mid \mathcal{Q})$, to approach the one based merely on the query image, $p(m \mid x_q)$, by minimizing the second KL term. Based on the ELBO, we define the empirical objective function for optimization.
Given a batch of $N$ episodes, the empirical objective for stochastic optimization with the Monte Carlo estimation of expectations is as follows:

$$\mathcal{L} = \frac{1}{N}\sum_{n=1}^{N}\bigg[-\frac{1}{L_z L_m}\sum_{\ell=1}^{L_z}\sum_{j=1}^{L_m}\log p\big(y_q^{n} \mid x_q^{n}, z^{(\ell)}, m^{(j)}\big) + D_{\mathrm{KL}}\big[q(z \mid \mathcal{S}^{n}, \mathcal{Q}^{n}) \,\|\, p(z \mid \mathcal{S}^{n})\big] + D_{\mathrm{KL}}\big[q(m \mid \mathcal{Q}^{n}) \,\|\, p(m \mid x_q^{n})\big]\bigg] \tag{5}$$

where $n$ indexes the episodes $(\mathcal{S}^n, \mathcal{Q}^n)$ sampled from the meta-training set $\mathcal{D}_{\mathrm{train}}$, $z^{(\ell)} \sim q(z \mid \mathcal{S}^n, \mathcal{Q}^n)$ and $m^{(j)} \sim q(m \mid \mathcal{Q}^n)$ are the variables sampled from their variational distributions, and $L_z$ and $L_m$ are the numbers of samples. The parameters of all networks are optimized jointly with stochastic gradient descent. However, the sampling operation is not differentiable with respect to the distribution parameters. Therefore, we deploy the reparameterization trick [24] to make the sampling process differentiable. Specifically, we assume the two variational posterior distributions are multivariate Gaussians with diagonal covariance:
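$$q(z \mid \mathcal{S}, \mathcal{Q}) = \mathcal{N}\big(z;\, \mu_z, \mathrm{diag}(\sigma_z^2)\big), \qquad q(m \mid \mathcal{Q}) = \mathcal{N}\big(m;\, \mu_m, \mathrm{diag}(\sigma_m^2)\big) \tag{6}$$

During training, the samples of the class prototype are obtained by

$$z^{(\ell)} = \mu_z + \sigma_z \odot \epsilon^{(\ell)} \tag{7}$$

where $\odot$ denotes element-wise multiplication and $\epsilon^{(\ell)} \sim \mathcal{N}(0, I)$. The same operation is also deployed for sampling $m^{(j)}$.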
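The reparameterized draw in (7) can be sketched in plain Python as follows. This is an illustrative stand-in, not the authors' implementation: in a deep learning framework the same expression lets gradients flow to $\mu_z$ and $\sigma_z$, because with the noise fixed, $\partial z / \partial \mu_z = 1$ and $\partial z / \partial \sigma_z = \epsilon$. The function returns the noise as well so that this mapping is explicit.

```python
import random


def reparameterize(mu, sigma, rng=random):
    """Draw z = mu + sigma * eps with eps ~ N(0, I), element-wise (Eq. 7).

    Returns the sample and the noise; for fixed eps, z is a differentiable
    function of (mu, sigma), which is the point of the reparameterization trick.
    """
    eps = [rng.gauss(0.0, 1.0) for _ in mu]
    z = [m + s * e for m, s, e in zip(mu, sigma, eps)]
    return z, eps
```

With $L_z = L_m = 1$ during training, one call per latent variable per episode suffices.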
In the VAE [24], the first term of the empirical loss in (5) is implemented as a least-squares loss; we instead use a pixel-wise cross-entropy loss to penalize the difference between the predicted segmentation map $\hat{y}_q$ and the ground truth $y_q$. The numbers of samples $L_z$ and $L_m$ are set to 1 during training to speed up learning. Since the KL terms minimize the discrepancy between the two distributions, the prior networks learn to mimic the posterior networks, which have access to the query annotation and produce effective prototypes and attention maps at training time.
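A minimal sketch of the per-episode objective in (5) with $L_z = L_m = 1$ is given below. It assumes binary (foreground/background) masks flattened to lists, so the pixel-wise cross-entropy reduces to binary cross-entropy, and it represents each Gaussian by a hypothetical `(mu, sigma)` tuple as produced by the prior and posterior networks. The names are illustrative and not taken from the released code.

```python
import math


def pixel_cross_entropy(pred_probs, target, eps=1e-7):
    """Mean binary cross-entropy over flattened pixels.

    pred_probs: predicted foreground probabilities; target: 0/1 ground-truth mask.
    """
    total = 0.0
    for p, y in zip(pred_probs, target):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(target)


def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL(q || p) for diagonal Gaussians, summed over dimensions."""
    return sum(
        math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2.0 * sp ** 2) - 0.5
        for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p)
    )


def episode_loss(pred_probs, target, post_z, prior_z, post_m, prior_m):
    """One-episode term of Eq. (5) with L_z = L_m = 1 (negative ELBO estimate).

    post_* / prior_* are hypothetical (mu, sigma) tuples for z and m.
    """
    return (pixel_cross_entropy(pred_probs, target)
            + kl_diag_gaussians(*post_z, *prior_z)
            + kl_diag_gaussians(*post_m, *prior_m))


def batch_loss(episodes):
    """Average the per-episode losses over a batch of N episodes."""
    return sum(episode_loss(**ep) for ep in episodes) / len(episodes)
```

In an actual training loop, `pred_probs` would be computed from reparameterized posterior samples so that the loss can be backpropagated into all networks.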
3.1.3 Segmentation Map Inference
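The inference of segmentation maps differs between the training and test stages. At test time, instead of sampling from the variational posterior distributions, we draw $L_z$ samples of prototypes from the prior $p(z \mid \mathcal{S})$ and $L_m$ samples of attention vectors from $p(m \mid x_q)$. The prediction $\hat{y}_q$ is obtained by averaging the segmentation maps computed from these samples:

$$\hat{y}_q = \frac{1}{L_z L_m}\sum_{\ell=1}^{L_z}\sum_{j=1}^{L_m} p\big(y_q \mid x_q, z^{(\ell)}, m^{(j)}\big) \tag{8}$$

where

$$z^{(\ell)} \sim p(z \mid \mathcal{S}), \quad \ell = 1, \dots, L_z, \tag{9}$$

and

$$m^{(j)} \sim p(m \mid x_q), \quad j = 1, \dots, L_m. \tag{10}$$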
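The test-time averaging in (8)-(10) can be sketched as a double loop over prior samples. Here `seg_fn` is a hypothetical callable standing in for the segmentation network, returning per-pixel foreground probabilities, and `prior_z` and `prior_m` are assumed `(mu, sigma)` pairs from the two prior networks. The default of 10 samples each mirrors the test-phase setting reported in Section 4.1.2.

```python
import random


def sample_diag_gaussian(mu, sigma, rng):
    """One draw from N(mu, diag(sigma^2))."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


def mc_predict(seg_fn, prior_z, prior_m, n_z=10, n_m=10, rng=random):
    """Average seg_fn(z, m) over n_z prototype samples and n_m attention samples.

    seg_fn: hypothetical segmentation callable returning per-pixel probabilities.
    prior_z, prior_m: (mu, sigma) pairs from the prior networks (Eqs. 9-10).
    """
    zs = [sample_diag_gaussian(*prior_z, rng) for _ in range(n_z)]
    ms = [sample_diag_gaussian(*prior_m, rng) for _ in range(n_m)]
    total = None
    for z in zs:
        for m in ms:  # every (z, m) pair contributes one segmentation map
            out = seg_fn(z, m)
            total = list(out) if total is None else [t + o for t, o in zip(total, out)]
    return [t / (n_z * n_m) for t in total]
```

Note that this requires $L_z \times L_m$ forward passes of the segmentation head, while the encoder features of $x_q$ can be computed once and reused.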
3.2 Neural Networks Implementation
We implement our attentional prototype inference with neural networks using the amortization technique [24], which is seamlessly integrated into the autoencoder architecture, as shown in Figure 2. We parameterize the distributions as factorized Gaussian distributions with diagonal covariance matrices.

Prior Networks. The prior network for prototypes embeds the support set $\mathcal{S}$ into a function space where the conditional prior distribution $p(z \mid \mathcal{S})$ lies. The prior network for attention maps, $p(m \mid x_q)$, is encouraged to generate attention maps as effective as those of the posterior $q(m \mid \mathcal{Q})$. For the prior network of prototypes, we construct a multi-layer perceptron (MLP) with three fully connected layers. The deep features of the support images are masked by their segmentation masks to obtain the foreground features. A permutation-invariant pooling layer [15] then squeezes them into a single vector $\bar{f}_{\mathcal{S}}$. In this work, we assume that the prior follows a Gaussian distribution with diagonal covariance. Given the single feature vector, the mean and variance of $z$ are the outputs of the MLP:

$$[\mu_z, \sigma_z^2] = \mathrm{MLP}(\bar{f}_{\mathcal{S}}) \tag{11}$$
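The masked pooling that feeds (11) can be sketched as follows, with features stored as lists of per-pixel vectors. Two assumptions are labeled here: the permutation-invariant pooling is taken to be a mean over the $K$ shots, and the three-layer MLP is replaced by a hypothetical single linear head whose standard deviation passes through a softplus to stay positive.

```python
import math


def masked_average(features, mask):
    """Average the feature vectors of foreground pixels (mask value 1)."""
    fg = [f for f, y in zip(features, mask) if y == 1]
    dim = len(features[0])
    if not fg:  # no foreground pixels: fall back to a zero vector
        return [0.0] * dim
    return [sum(f[d] for f in fg) / len(fg) for d in range(dim)]


def pool_support(support):
    """Permutation-invariant pooling over the K shots: mean of per-shot vectors."""
    vecs = [masked_average(feats, mask) for feats, mask in support]
    dim = len(vecs[0])
    return [sum(v[d] for v in vecs) / len(vecs) for d in range(dim)]


def softplus(x):
    """Numerically stable log(1 + exp(x)); keeps sigma strictly positive."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def gaussian_head(v, w_mu, w_sigma):
    """Hypothetical single linear layer standing in for the MLP of Eq. (11)."""
    mu = [sum(w * x for w, x in zip(row, v)) for row in w_mu]
    sigma = [softplus(sum(w * x for w, x in zip(row, v))) for row in w_sigma]
    return mu, sigma
```

Averaging makes the pooled vector independent of the order of the support shots, which is the property the pooling layer is meant to guarantee.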
The prior network for attention maps follows the same idea, but we employ a transformer architecture [29] to extract structure information with high fidelity. This prior transformer applies pixel-level self-attention and aggregates all pixels into a vector. A two-layer perceptron is then appended to output the mean $\mu_m$ and standard deviation $\sigma_m$ of $m$. The sampled $m$ is directly used for computing the attention map.
Posterior Networks. From the VI view, the posterior network for prototypes, $q(z \mid \mathcal{S}, \mathcal{Q})$, is trained to approximate the true posterior distribution given the query pair $(x_q, y_q)$ and the support set $\mathcal{S}$. The posterior network for attention maps, $q(m \mid \mathcal{Q})$, is trained in a similar way but uses only the query pair. The posterior networks have the same architecture as the prior network of prototypes. However, the output of the prototype posterior network is aggregated over $K+1$ pairs of inputs, i.e., the $K$ support pairs and the query pair. Here, the attention posterior network generates an attention vector $m$ from its input pair. We then compute the cosine similarity between the attention vector $m$ and the feature embedding $f(x_q)$ at each pixel to construct an attention map $A$:

$$A_{ij} = \frac{m^{\top} f(x_q)_{ij}}{\lVert m \rVert\, \lVert f(x_q)_{ij} \rVert} \tag{12}$$
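The pixel-wise cosine similarity in (12) can be sketched as below, with the query feature map flattened to a list of per-pixel vectors. The small `eps` guard against zero-norm vectors is an added assumption rather than something stated in the text.

```python
import math


def attention_map(m, features, eps=1e-8):
    """Cosine similarity between the attention vector m and each pixel feature.

    features: list of per-pixel feature vectors f(x_q)_ij (flattened H' x W').
    Returns one value in [-1, 1] per pixel; eps guards against zero norms.
    """
    norm_m = math.sqrt(sum(x * x for x in m))
    out = []
    for f in features:
        norm_f = math.sqrt(sum(x * x for x in f))
        dot = sum(a * b for a, b in zip(m, f))
        out.append(dot / max(norm_m * norm_f, eps))
    return out
```

Pixels whose features point in the same direction as $m$ receive values near 1, which is how the map can highlight the foreground.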
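Segmentation Network. Finally, the segmentation network takes the query image $x_q$, the sampled prototype vector $z$, and the estimated attention map $A$ as inputs to predict the segmentation map $\hat{y}_q$, which is the Monte Carlo estimation of the conditional generative distribution $p(y_q \mid x_q, z, m)$. Once we sample an attention vector and compute the attention map, we multiply the map with the deep feature of the query image to obtain an attentive, structure-aware embedding. The segmentation network concatenates the attentive embedding and the sampled prototype vector (drawn from the prior at test time; see Figure 8) and produces the output segmentation map:

$$\hat{y}_q = g_{\theta}\big(\big[A \odot f(x_q);\, z\big]\big) \tag{13}$$

where $g_{\theta}$ denotes the segmentation network and $[\cdot\,;\cdot]$ is channel-wise concatenation, with $z$ broadcast to every spatial position. At test time, API generates multiple samples of the prototype and the attention map, and achieves a more accurate prediction by ensembling all outputs:

$$\hat{y}_q = \frac{1}{L_z L_m}\sum_{\ell=1}^{L_z}\sum_{j=1}^{L_m} g_{\theta}\big(\big[A^{(j)} \odot f(x_q);\, z^{(\ell)}\big]\big) \tag{14}$$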
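The input assembled for the segmentation network in (13) can be sketched on flattened pixel lists. The list-of-pixels layout is an assumption made for readability; an actual implementation would multiply and concatenate feature tensors along the channel dimension.

```python
def segmentation_input(features, attn, z):
    """Build the per-pixel input of Eq. (13): [A_ij * f_ij ; z].

    features: list of per-pixel feature vectors f(x_q)_ij (flattened H' x W').
    attn: attention value A_ij for each pixel.
    z: prototype vector, broadcast (tiled) to every pixel before concatenation.
    """
    return [[a * x for x in f] + list(z) for f, a in zip(features, attn)]
```

Every pixel therefore carries the same prototype, while the attention weight rescales only the image features.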
The segmentation network adopts a multi-layer skip-connection structure [30] to incorporate more spatial information. In addition, a CNN-based encoder for feature embedding is shared by the prior, posterior, and segmentation networks. API is an end-to-end framework: all networks are jointly optimized by minimizing the objective (5). The learning algorithm is summarized in Algorithm 1. All gradient computations can be efficiently implemented with automatic differentiation tools.
4 Experiments
4.1 Datasets and Implementation Details
We conduct experiments on three commonly used few-shot segmentation benchmarks, PASCAL-5i [14], COCO-20i [34], and FSS-1000 [35], and on one medical imaging dataset, LIDC-IDRI [36]. We describe these datasets and the associated experimental settings as follows. The code for replicating our experiments is available on GitHub (https://github.com/haolsun/API).
4.1.1 Datasets
PASCAL-5i is built from PASCAL VOC 2012 with extended annotations. We follow the settings in [17], splitting the 20 original classes into four folds and conducting cross-validation among them. Specifically, in each fold we select 15 classes for training, while the remaining 5 classes are used for testing. For a fair comparison, we adopt the same strategy as [17] and randomly sample 1,000 episodes of support-query pairs for evaluation.
COCO-20i is a challenging dataset built upon MS-COCO with 80 object categories. We also divide the 80 classes of MS-COCO into four folds and conduct four-fold cross-validation. Under the same settings as PASCAL-5i, 60 object categories are selected for training, while the remaining 20 categories are used for testing. In each fold, we sample 1,000 support-query pairs from the 20 testing classes for evaluation, following [17].
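The fold arithmetic above (15/5 classes for PASCAL-5i, 60/20 for COCO-20i) can be sketched as follows. The sketch assumes class ids 1 to C and contiguous folds, as in the common PASCAL-5i convention; the exact class-to-fold assignment used by [17], particularly for COCO-20i, may differ, so this only illustrates the counts and the disjointness of training and testing classes.

```python
def fold_classes(num_classes, num_folds, fold):
    """Return (train_classes, test_classes) for one cross-validation fold.

    Assumes class ids 1..num_classes and contiguous folds: fold i tests on
    ids i*P+1 .. (i+1)*P with P = num_classes // num_folds.
    """
    per_fold = num_classes // num_folds
    test = list(range(fold * per_fold + 1, (fold + 1) * per_fold + 1))
    train = [c for c in range(1, num_classes + 1) if c not in test]
    return train, test
```

For example, `fold_classes(20, 4, 1)` tests on classes 6 to 10 and trains on the other 15.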
FSS-1000 is a dedicated few-shot segmentation dataset. It contains 1,000 object categories, including 520 classes for training, 240 classes for validation, and 240 classes for testing. Following [35], we test on the same 240 categories and train the model on the specified 520 classes, using the validation set for model selection. The number of testing episodes is 1,000.
LIDC-IDRI is a dataset that exhibits typical ambiguities in visual signals. It contains 1,018 thoracic CT cases from 1,010 lung patients. The task is to segment lung nodule lesions given a CT slice. Each CT case has been annotated by 4 individual radiologists who provide segmentation masks for independently detected lesions. We obtain 2,630 sub-cases by extracting series of slices from 3D volumes and cropping the 2D slices to 128 × 128 pixels centered at the lesion positions. Each sub-case contains a set of 2D images, and the number of images varies across sub-cases. Each image corresponds to four annotated masks from independent radiologists (see Figure 1). For the few-shot segmentation setting, we consider a sub-case as one episode, i.e., support and query images are randomly sampled from one sub-case without replacement. Of the 2,630 sub-cases, we use the first 2,030 for training, the last 300 for testing, and the remaining 300 for validation. The total numbers of images for training, testing, and validation are 11,631, 1,943, and 1,974, respectively.



4.1.2 Implementation Details
We adopt a ResNet101 backbone pre-trained on ImageNet as the encoder. The decoder is designed as a skip-connection structure [30] composed of three convolutional blocks that generate segmentation maps. Each block takes as input the concatenation of the decoding features and the corresponding encoded features passed through the skip connections. The structures of the decoder, prior, and posterior networks are shown in Figures 3, 4, and 5. We use Adam as the optimizer and train the model on four NVIDIA Tesla V100 GPUs for around 30 epochs. The learning rate is fixed to for the backbone and for other layers, and the batch normalization (BN) layers are frozen during training. The numbers of samples $L_z$ and $L_m$ are set to 10 during the test phase, which is analyzed in detail in our ablation study in Section 4.3.1.
We adopt the same metrics as [14] for evaluation, i.e., Class-IoU (C-IoU) and Binary-IoU (B-IoU). Class-IoU measures the intersection-over-union for each class and averages it over the $C$ testing classes:

$$\text{C-IoU} = \frac{1}{C}\sum_{c=1}^{C} \frac{\mathrm{TP}_c}{\mathrm{TP}_c + \mathrm{FP}_c + \mathrm{FN}_c} \tag{15}$$

where $\mathrm{TP}_c$, $\mathrm{FP}_c$, and $\mathrm{FN}_c$ are the numbers of true-positive, false-positive, and false-negative pixels of the predicted segmentation masks for foreground category $c$. Binary-IoU treats all object classes as a single foreground class and averages the IoU of the foreground and the background.
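Both metrics can be sketched from pixel counts on flattened binary masks. Two assumptions are labeled here: per-class counts are accumulated over all test episodes of a class before dividing, and an IoU with an empty denominator is set to 0.

```python
def counts(preds, targets, label):
    """TP, FP, FN pixel counts for one label over flattened masks."""
    tp = sum(1 for p, t in zip(preds, targets) if p == label and t == label)
    fp = sum(1 for p, t in zip(preds, targets) if p == label and t != label)
    fn = sum(1 for p, t in zip(preds, targets) if p != label and t == label)
    return tp, fp, fn


def iou(tp, fp, fn):
    """TP / (TP + FP + FN); 0.0 when the denominator is empty (assumption)."""
    denom = tp + fp + fn
    return tp / denom if denom else 0.0


def class_iou(per_class_counts):
    """Eq. (15): mean over classes of TP_c / (TP_c + FP_c + FN_c)."""
    ious = [iou(*c) for c in per_class_counts.values()]
    return sum(ious) / len(ious)


def binary_iou(preds, targets):
    """Mean of the foreground IoU and the background IoU."""
    return (iou(*counts(preds, targets, 1)) + iou(*counts(preds, targets, 0))) / 2
```

For instance, predicting `[1, 1, 0, 0]` against the target `[1, 0, 0, 0]` gives a foreground IoU of 1/2 and a background IoU of 2/3, so the Binary-IoU is 7/12.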
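We generalize the IoU metric to the cross energy distance (CED) on the LIDC-IDRI dataset to measure the distance between two distributions. The core idea of CED is to leverage distances between observations. Let $Y$ and $\hat{Y}$ denote an annotation and a prediction, respectively, and let $d(\cdot, \cdot)$ be the distance between two observations. For $n_a$ independent annotation samples and $n_p$ prediction samples, we have

$$D_{\mathrm{CED}}^2 = \frac{2}{n_a n_p}\sum_{i=1}^{n_a}\sum_{j=1}^{n_p} d(Y_i, \hat{Y}_j) - \frac{1}{n_a^2}\sum_{i=1}^{n_a}\sum_{i'=1}^{n_a} d(Y_i, Y_{i'}) - \frac{1}{n_p^2}\sum_{j=1}^{n_p}\sum_{j'=1}^{n_p} d(\hat{Y}_j, \hat{Y}_{j'}) \tag{16}$$

Here, we choose $d(Y, \hat{Y}) = 1 - \mathrm{IoU}(Y, \hat{Y})$. A smaller CED value indicates that the prediction distribution is closer to the ground-truth distribution and better characterizes the ambiguity that appears in images. In our LIDC-IDRI experiments, we set $n_a = 4$ since there are four annotation masks per image.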
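The sample estimate of the cross energy distance with $d = 1 - \mathrm{IoU}$ can be sketched on flattened binary masks as follows. Treating two empty masks as identical (IoU = 1, distance 0) is an assumption made here to keep the distance defined for lesion-free annotations.

```python
def mask_iou(a, b):
    """IoU of two flattened binary masks; two empty masks count as identical."""
    inter = sum(1 for x, y in zip(a, b) if x and y)
    union = sum(1 for x, y in zip(a, b) if x or y)
    return 1.0 if union == 0 else inter / union


def distance(a, b):
    """Distance between two observations: d = 1 - IoU."""
    return 1.0 - mask_iou(a, b)


def ced(annotations, predictions):
    """Squared cross energy distance between n_a annotations and n_p predictions."""
    n_a, n_p = len(annotations), len(predictions)
    cross = sum(distance(y, s) for y in annotations for s in predictions) / (n_a * n_p)
    within_a = sum(distance(y, y2) for y in annotations for y2 in annotations) / n_a ** 2
    within_p = sum(distance(s, s2) for s in predictions for s2 in predictions) / n_p ** 2
    return 2.0 * cross - within_a - within_p
```

When the set of predictions reproduces the spread of the annotations, the cross term is cancelled by the two within-set terms and the distance drops to zero; a single confident prediction cannot achieve this against disagreeing annotators.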
4.2 Comparison with the State of the Art
Table 1: Comparison on PASCAL-5i in terms of Class-IoU (C-IoU, %) and Binary-IoU (B-IoU, %). API* excludes the class with the lowest confidence in each fold.

Method | Backbone | C-IoU 1-shot | C-IoU 5-shot | B-IoU 1-shot | B-IoU 5-shot |
---|---|---|---|---|---|
PMM [17] | ResNet50 | 56.3 | - | 57.3 | - |
CRNet [31] | ResNet50 | 55.7 | - | 58.8 | - |
FS-PARN [19] | ResNet50 | 53.7 | 67.9 | 58.3 | 72.6 |
FWB [32] | ResNet101 | 56.2 | 59.9 | - | - |
API (Ours) | ResNet101 | 57.4 | 71.4 | 60.7 | 73.2 |
HSNet [21] | ResNet101 | 66.2 | 72.5 | 70.4 | 80.6 |
VAT [22] | ResNet101 | 67.5 | 78.8 | 71.6 | 82.0 |
DGPNet [23] | ResNet101 | 64.8 | - | 75.4 | - |
SSP [4] | ResNet101 | 64.6 | - | 73.1 | - |
API* | ResNet101 | 64.7 | 76.3 | 70.2 | 80.4 |
4.2.1 Performance on PASCAL-5i
In Table 1, we compare the performance of API with representative methods on PASCAL-5i in terms of Class-IoU and Binary-IoU. API outperforms the compared prototype-based methods under both the 1-shot and 5-shot settings, reaching 57.4% and 71.4% Class-IoU, respectively. The 1-shot setting is more challenging than the 5-shot setting due to the much larger intra-class variation. The Monte Carlo estimation in our probabilistic model serves as an ensemble of the prediction results. Notably, API outperforms PMM, which computes segmentation masks with multiple prototypes; this reflects the robustness of API in the 1-shot case. We also evaluate the model in terms of Binary-IoU, where it again yields competitive performance, with 60.7% and 73.2% under the 1-shot and 5-shot settings, respectively.
However, due to the intrinsic limitations of prototype-based methods, API has limited capability to handle objects with small sizes or complex shapes. Since the prototype is computed from high-level feature maps, discriminative information about small objects may vanish through downsampling operations, especially for small objects in complex backgrounds. This challenge also arises when a large object has a complex shape, e.g., some parts of the object, such as table legs, might be omitted. Hierarchical prototypes could potentially tackle this issue, as discussed in the Conclusion. We therefore also report results with the class of lowest confidence removed in each fold, denoted as API*. As shown in the bottom part of Table 1, API* achieves performance comparable to the state of the art, e.g., 64.7% Class-IoU in the 1-shot setting. Compared with other methods, our prototype-based model is more efficient and easier to implement. Owing to its advanced transformer architecture, VAT [22] obtains the highest IoU in most cases on PASCAL-5i.

Some qualitative results on PASCAL-5i are visualized in Figure 6. The proposed API produces accurate segmentation under various challenging scenarios, where the query images differ in appearance and object size from the associated support images. For instance, in the fifth column, the size and viewpoint of the bus in the query image are significantly different from those of the annotated plane in the support image; in the eighth column, the annotated boat in the support image is much smaller than the one in the query image.
Table 2: Comparison on COCO-20i in terms of Class-IoU (C-IoU, %) and Binary-IoU (B-IoU, %). API* excludes the five classes with the lowest confidence in each fold.

Method | Backbone | C-IoU 1-shot | C-IoU 5-shot | B-IoU 1-shot | B-IoU 5-shot |
---|---|---|---|---|---|
PANet [33] | VGG16 | 20.9 | 29.7 | 59.2 | 63.0 |
PMM [17] | ResNet50 | 30.6 | 35.5 | - | - |
FS-PARN [19] | ResNet50 | 29.5 | 36.2 | - | - |
A-MCG [34] | ResNet101 | - | - | 52.0 | 54.7 |
FWB [32] | ResNet101 | 21.2 | 23.7 | - | - |
API (Ours) | ResNet101 | 36.3 | 41.0 | 61.9 | 62.7 |
HSNet [21] | ResNet101 | 41.2 | 49.5 | 69.1 | 72.4 |
VAT [22] | ResNet101 | 41.3 | 47.9 | 68.8 | 72.4 |
DGPNet [23] | ResNet101 | 46.7 | 57.9 | - | - |
SSP [4] | ResNet101 | 42.0 | 50.2 | - | - |
API* | ResNet101 | 41.8 | 49.4 | 68.2 | 70.8 |
4.2.2 Performance on COCO-20i
COCO-20i is more challenging than PASCAL-5i since its scenes are more complex, with more intra-class diversity. Therefore, few-shot segmentation on COCO-20i involves more ambiguity, and it is difficult to acquire an effective class-specific prototype. As can be seen in Table 2, our method outperforms the related prototype-based method PMM [17] by 5.7% and 5.5% in terms of Class-IoU under the 1-shot and 5-shot settings, respectively. Since COCO-20i contains more complex objects than PASCAL-5i, we drop the five classes with the lowest confidence in each fold (API*). As shown in Table 2, we obtain reasonable results on the remaining 15 classes in each fold, e.g., 41.8% Class-IoU in the 1-shot setting. Since DGPNet [23] applies a powerful Gaussian process model and introduces extra mask encoders, it clearly outperforms the other recent works. The qualitative results are provided in Figure 7. Our method successfully predicts the segmentation maps for query images, even though the objects in the query images differ significantly from those in the support images in terms of appearance, size, and viewpoint. As shown in the third column, the object can also be segmented successfully even under severe occlusion.

4.2.3 Performance on FSS-1000
We evaluate our method following the official evaluation protocol in [35]. The evaluation metric for FSS-1000 is the IoU of positive labels in a binary segmentation map. Since this dataset provides a validation set, we select the prediction model on the validation set (i.e., the model with the best validation Class-IoU) and then report its result on the testing set. A performance comparison with other models in terms of Positive-IoU (P-IoU) is provided in Table 3. Our method improves over the state of the art set by Wei et al. [35] in both the 1-shot and 5-shot settings, demonstrating the effectiveness of our proposal for few-shot segmentation across a large number of semantic categories. Figure 8 visualizes segmentation results on the FSS-1000 dataset, where API produces accurate segmentation maps close to the ground truth. As shown in Figure 8, the foreground object is usually located at the center of the image and is distinct from the background, which likely contributes to the strong performance of our model on FSS-1000.

4.2.4 Performance on LIDC-IDRI
To study the capability of handling ambiguities in few-shot segmentation, we conduct experiments on LIDC-IDRI, which contains typical image ambiguities. The evaluation metric for LIDC-IDRI is the cross energy distance (CED) rather than IoU. We compare API with two state-of-the-art deterministic models, i.e., FS-PARN [19] and HSNet [21]. Since these deterministic models cannot directly produce multiple prediction samples, we adopt an ensemble strategy that trains a set of predictors. As each image has 4 annotations, we initialize 15 models and train each one on a different combination of annotations; the total number of such combinations is $\binom{4}{1}+\binom{4}{2}+\binom{4}{3}+\binom{4}{4}=15$. Each model is trained on data batches whose annotation masks are randomly sampled from its assigned combination. Once the 15 predictors are trained, we randomly select a subset of them to produce multiple segmentation maps for testing. For our API model, we randomly choose one of the 4 annotation masks at each update step. In this experiment, the number of samples in CED (Eq. 16) is set to 9 for all methods. We train the models on the training set of 2,030 sub-cases, using the validation set of 300 sub-cases for model selection. Table 4 reports CED values on the test set; we test each model five times with different randomly selected predictors and report the average. API achieves a notable improvement in CED over the state-of-the-art methods. Unlike class-IoU, CED takes all prediction samples and all annotation masks into account, so this improvement indicates that our framework better captures the distribution of plausible segmentations.
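Eq. 16 is not restated in this section, so the sketch below assumes that CED takes the form of the squared generalized energy distance used by Kohl et al. [26], $D^2 = 2\,\mathbb{E}[d(S,Y)] - \mathbb{E}[d(S,S')] - \mathbb{E}[d(Y,Y')]$ with $d(x,y) = 1 - \mathrm{IoU}(x,y)$, where $S, S'$ are predicted samples and $Y, Y'$ are annotation masks, and with each expectation replaced by an average over all pairs. It also enumerates the non-empty annotation combinations behind the 15-model ensemble used for the deterministic baselines. All function names are hypothetical, and this is not the authors' evaluation code.

```python
from itertools import combinations
from typing import List, Sequence, Tuple

Mask = Sequence[int]  # flattened binary mask of 0/1 values


def iou_distance(a: Mask, b: Mask) -> float:
    """d(a, b) = 1 - IoU(a, b); two empty masks are treated as identical (assumption)."""
    inter = sum(1 for x, y in zip(a, b) if x and y)
    union = sum(1 for x, y in zip(a, b) if x or y)
    return 0.0 if union == 0 else 1.0 - inter / union


def mean_pairwise_distance(xs: List[Mask], ys: List[Mask]) -> float:
    """Average of d(x, y) over all pairs (x, y), self-pairs included."""
    total = sum(iou_distance(x, y) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def energy_distance(samples: List[Mask], annotations: List[Mask]) -> float:
    """Squared energy distance between predicted samples and annotation masks."""
    cross = mean_pairwise_distance(samples, annotations)
    within_samples = mean_pairwise_distance(samples, samples)
    within_annotations = mean_pairwise_distance(annotations, annotations)
    return 2.0 * cross - within_samples - within_annotations


def annotation_subsets(num_annotations: int) -> List[Tuple[int, ...]]:
    """All non-empty combinations of annotator indices (one per ensemble member)."""
    return [c for r in range(1, num_annotations + 1)
            for c in combinations(range(num_annotations), r)]


print(len(annotation_subsets(4)))  # 15
print(energy_distance([[1, 1]], [[1, 0]]))  # 1.0
```

In the experiment above, `samples` would hold the 9 sampled predictions and `annotations` the 4 expert masks of one test case.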
4.3 Ablation Study
4.3.1 Benefit of Probabilistic Modeling
Unlike previous deterministic models that learn a single prototype vector, our probabilistic method infers distributions over the class prototype and over the attention map of each image. To demonstrate the advantage of the proposed probabilistic modeling, we implement a deterministic counterpart. For a fair comparison, it uses the same network architecture, directly predicts a deterministic class prototype vector and attention map from the corresponding branch, and drops the KL divergence term during training. We implement both models with VGG-16, ResNet50, and ResNet101 backbones, which are commonly adopted in previous works [32, 37].
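The difference between the two variants can be illustrated with a minimal sketch, assuming (hypothetically) that both the approximate posterior over the prototype and its prior are diagonal Gaussians, each parameterized by a mean and a log-variance vector. The probabilistic model draws a reparameterized sample and adds the closed-form KL divergence to the loss; the deterministic counterpart uses the mean vector directly and has no KL term. Function names are illustrative and not taken from the paper's code.

```python
import math
import random
from typing import List, Sequence


def sample_prototype(mu: Sequence[float], log_var: Sequence[float],
                     rng: random.Random) -> List[float]:
    """Reparameterized sample z = mu + sigma * eps, with eps ~ N(0, I)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]


def kl_diag_gaussians(mu_q: Sequence[float], log_var_q: Sequence[float],
                      mu_p: Sequence[float], log_var_p: Sequence[float]) -> float:
    """Closed-form KL(q || p) between diagonal Gaussians, summed over dimensions."""
    kl = 0.0
    for mq, lq, mp, lp in zip(mu_q, log_var_q, mu_p, log_var_p):
        kl += 0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
    return kl


def deterministic_prototype(mu: Sequence[float]) -> List[float]:
    """Deterministic counterpart: use the predicted mean as the prototype, no KL term."""
    return list(mu)


rng = random.Random(0)
z = sample_prototype([0.0, 1.0], [0.0, 0.0], rng)
print(len(z))  # 2
print(kl_diag_gaussians([1.0], [0.0], [0.0], [0.0]))  # 0.5
```

Because the sample is written as a deterministic function of `mu`, `log_var`, and external noise, gradients flow to both parameters, which is what makes the probabilistic variant trainable end to end.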
The results on PASCAL-5i are shown in Table 5. Attentional prototype inference consistently outperforms the deterministic models under both the 1-shot and 5-shot settings in terms of both Class-IoU and Binary-IoU. We attribute this to the probabilistic modeling of prototypes and attention maps, which is more expressive of object classes and better captures their categorical concepts; the learned model therefore generalizes better to query images, which usually exhibit large variations. As the ResNet101 backbone outperforms both VGG-16 and ResNet50, we adopt ResNet101 as the backbone network in our experiments.

| Backbone | Model | C-IoU 1-shot | C-IoU 5-shot | B-IoU 1-shot | B-IoU 5-shot |
|---|---|---|---|---|---|
| VGG-16 | Deter. | 51.6 | 53.1 | 64.1 | 65.3 |
| VGG-16 | API | 52.7 | 55.6 | 65.2 | 66.0 |
| ResNet50 | Deter. | 54.1 | 57.4 | 65.7 | 68.9 |
| ResNet50 | API | 55.9 | 59.8 | 69.3 | 71.7 |
| ResNet101 | Deter. | 55.4 | 58.7 | 66.3 | 69.9 |
| ResNet101 | API | 57.4 | 60.7 | 71.3 | 73.2 |

4.3.2 Effectiveness of Latent Attention Mechanism
The newly introduced attention mechanism exploits the spatial structure information that the prototype neglects, which is essential for mask prediction. The transformer architecture adopted in our model utilizes contextual information across pixels and enhances the foreground to achieve accurate predictions. We conduct extensive comparisons with the baseline variational prototype inference (VPI) [28] to show the effectiveness of the latent attention mechanism. As shown in Table 6, we evaluate both methods on three benchmarks, PASCAL-5i, COCO-20i, and FSS-1000, with three metrics: Class-IoU, Binary-IoU, and Positive-IoU. API achieves considerable improvements over VPI in most cases. In particular, API improves over VPI in Class-IoU on COCO-20i by 7.8 percentage points under the 1-shot setting and by 8.1 percentage points under the 5-shot setting. On COCO-20i, the advantage of API over VPI in Binary-IoU is much smaller than in Class-IoU. This is likely due to the bias of Binary-IoU towards large regions: it averages the IoU of the foreground and the background regardless of class, so it is dominated by large foreground and background areas and is less sensitive to per-class improvements.
| Model | PASCAL-5i C-IoU 1-shot | PASCAL-5i C-IoU 5-shot | PASCAL-5i B-IoU 1-shot | PASCAL-5i B-IoU 5-shot | FSS-1000 P-IoU 1-shot |
|---|---|---|---|---|---|
| VPI | 57.3 | 60.4 | 70.3 | 72.1 | 84.3 |
| API | 57.4 | 60.7 | 71.3 | 73.2 | 85.6 |

| Model | COCO-20i C-IoU 1-shot | COCO-20i C-IoU 5-shot | COCO-20i B-IoU 1-shot | COCO-20i B-IoU 5-shot | FSS-1000 P-IoU 1-shot |
|---|---|---|---|---|---|
| VPI | 23.4 | 27.8 | 61.1 | 63.0 | 87.7 |
| API | 31.2 | 35.9 | 61.4 | 62.3 | 88.0 |

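The explanation for the smaller Binary-IoU gap on COCO-20i can be checked with a toy example. The sketch assumes Binary-IoU is the mean of foreground and background IoU (a common definition) and, for simplicity, computes it on a single image, whereas benchmark protocols accumulate statistics over many episodes. The masks are made-up toy data, not results from the paper, and the function names are hypothetical.

```python
from typing import Sequence


def label_iou(pred: Sequence[int], gt: Sequence[int], label: int) -> float:
    """IoU of a single label between two flattened masks."""
    inter = sum(1 for p, g in zip(pred, gt) if p == label and g == label)
    union = sum(1 for p, g in zip(pred, gt) if p == label or g == label)
    return 1.0 if union == 0 else inter / union


def binary_iou(pred: Sequence[int], gt: Sequence[int]) -> float:
    """Mean of foreground (label 1) and background (label 0) IoU."""
    return 0.5 * (label_iou(pred, gt, 1) + label_iou(pred, gt, 0))


# Toy image: a small object (2 of 10 pixels) on a large background.
gt = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
coarse = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]  # misses half of the object
perfect = list(gt)

fg_gain = label_iou(perfect, gt, 1) - label_iou(coarse, gt, 1)
biou_gain = binary_iou(perfect, gt) - binary_iou(coarse, gt)
print(round(fg_gain, 3), round(biou_gain, 3))  # 0.5 0.306
```

The same foreground improvement moves Binary-IoU by only about 60% as much, because the already-high background IoU is averaged in.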
We also visualize the attention maps computed by our API model in Figure 9. Pixels on the foreground object are highlighted, which enhances the predicted segmentation maps, and the proposed latent attention module captures the outline of the foreground objects. Moreover, the module adds little inference time: as shown in Figure 11, when we fix the remaining settings and increase the number of samples from 1 to 15, the extra time introduced by the attention module is less than 0.1 seconds, even under the 5-shot setting. These observations indicate that the attention module is computationally efficient.
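To make the role of the local latent attention variable concrete, the sketch below shows one hypothetical way a sampled per-pixel attention map could suppress background responses: attention logits are sampled with the reparameterization trick, squashed to (0, 1) by a sigmoid, and used to gate a dot-product match between each query-pixel feature and the prototype. This gating scheme is an illustrative assumption only; the model described above uses a transformer-based module whose details are not reproduced here, and all names are hypothetical.

```python
import math
import random
from typing import List, Sequence


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sample_attention(mu: Sequence[float], log_var: Sequence[float],
                     rng: random.Random) -> List[float]:
    """Per-pixel attention m_i = sigmoid(mu_i + sigma_i * eps_i), eps_i ~ N(0, 1)."""
    return [sigmoid(m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0))
            for m, lv in zip(mu, log_var)]


def gated_scores(features: Sequence[Sequence[float]], attention: Sequence[float],
                 prototype: Sequence[float]) -> List[float]:
    """Dot product between each pixel feature and the prototype, scaled by attention."""
    return [a * sum(f * z for f, z in zip(feat, prototype))
            for feat, a in zip(features, attention)]


# Two pixels with identical features; attention favours the first (foreground) one.
rng = random.Random(0)
attn = sample_attention([4.0, -4.0], [-40.0, -40.0], rng)
scores = gated_scores([[1.0, 0.0], [1.0, 0.0]], attn, [1.0, 0.0])
print([round(s, 3) for s in scores])  # [0.982, 0.018]
```

A dot product is used rather than cosine similarity because rescaling a feature by a positive attention weight would leave its cosine similarity unchanged.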

4.3.3 Effect of Monte Carlo Sampling
The segmentation map is estimated by Monte Carlo sampling: we draw multiple prototypes z and attention maps m, produce one output per sample, and aggregate all outputs into the final segmentation map. We conduct a qualitative analysis of the effect of Monte Carlo sampling on the segmentation results. As shown in Figure 11, the segmentation map obtained from an individual sample is not always adequate. For example, in the first row, the segmentation map obtained from the third sample does not completely recover the object. By averaging the segmentation maps produced by the individual samples, the final segmentation map tends to be more complete and robust.
Quantitatively, Class-IoU improves from 0.56 with 5 samples to 0.57 with 15 samples, showing that the prediction improves as the number of samples increases. However, more samples also require more inference time, and we observe that the performance tends to saturate as the number of samples grows. Therefore, in our experiments, we set the number of samples to 5 during inference to obtain precise segmentation maps at an acceptable inference cost.
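The Monte Carlo aggregation described above can be sketched as follows, assuming a diagonal-Gaussian prototype and a hypothetical prediction head that maps the cosine similarity between each pixel feature and a sampled prototype to a foreground probability through a scaled sigmoid. Probabilities from `num_samples` draws are averaged before thresholding. The attention sample is omitted for brevity, and the head, `temperature`, and `threshold` are illustrative assumptions rather than the paper's settings.

```python
import math
import random
from typing import List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0 if either vector has zero norm."""
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    if na == 0.0 or nb == 0.0:
        return 0.0
    return sum(x * y for x, y in zip(a, b)) / (na * nb)


def foreground_probs(features: Sequence[Sequence[float]], prototype: Sequence[float],
                     temperature: float = 10.0) -> List[float]:
    """Hypothetical head: p_i = sigmoid(temperature * cos(f_i, z))."""
    return [1.0 / (1.0 + math.exp(-temperature * cosine(f, prototype)))
            for f in features]


def mc_segmentation(features: Sequence[Sequence[float]], mu: Sequence[float],
                    log_var: Sequence[float], num_samples: int, rng: random.Random,
                    threshold: float = 0.5) -> Tuple[List[float], List[int]]:
    """Average per-pixel probabilities over sampled prototypes, then threshold."""
    totals = [0.0] * len(features)
    for _ in range(num_samples):
        # Reparameterized prototype sample z = mu + sigma * eps.
        z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]
        for i, p in enumerate(foreground_probs(features, z)):
            totals[i] += p
    avg = [t / num_samples for t in totals]
    mask = [1 if p >= threshold else 0 for p in avg]
    return avg, mask


features = [[1.0, 0.0], [0.9, 0.1], [-1.0, 0.0]]
avg, mask = mc_segmentation(features, mu=[1.0, 0.0], log_var=[-2.0, -2.0],
                            num_samples=5, rng=random.Random(0))
print(mask)
```

Averaging probabilities rather than hard masks lets samples that disagree on a pixel pull its score towards 0.5 instead of flipping it outright, which matches the observation that the aggregated map is more complete and robust than any single sample.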


5 Discussion and Conclusion
This paper tackles few-shot segmentation from a probabilistic perspective. Our model contains two latent variables at different levels: 1) a global latent variable, inferred from the data, that represents the prototype of each object category; and 2) a local latent variable that represents an attention map for each image, which highlights the foreground object for improved segmentation. We develop attentional prototype inference (API), which leverages variational inference for efficient optimization. Through probabilistic modeling, API enhances its generalization ability by handling the inherent uncertainty caused by limited data and the intra-class variations of objects, which is essential for generalizing to unseen categories.
Prototype-based methods inherently compute the foreground prototype under the guidance of label masks at the image level. Such prototypes are often ambiguous for foreground objects at the feature level, especially for small instances. In future work, this could be addressed with hierarchical prototypes computed at different levels of the encoder, especially levels close to the input layer, which retain finer spatial detail.
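The mask-guided prototype computation mentioned above is commonly implemented in prototype-based methods as masked average pooling, where support-pixel features are averaged over the pixels labelled foreground. The sketch below is a minimal illustration under that assumption (features already resized to the mask resolution, and k-shot pooling done jointly over all support pixels); it is not the paper's exact implementation, and the names are hypothetical. It also makes the stated limitation concrete: at a coarse feature level, a small object contributes only a few pixels to the average.

```python
from typing import List, Sequence, Tuple


def masked_average_pooling(features: Sequence[Sequence[float]],
                           mask: Sequence[int]) -> List[float]:
    """Average the feature vectors of pixels whose mask value is 1."""
    fg = [f for f, m in zip(features, mask) if m == 1]
    if not fg:
        raise ValueError("mask contains no foreground pixels")
    dim = len(fg[0])
    return [sum(f[d] for f in fg) / len(fg) for d in range(dim)]


def k_shot_prototype(
        supports: Sequence[Tuple[Sequence[Sequence[float]], Sequence[int]]]) -> List[float]:
    """Pool foreground pixels jointly across all k support images (assumption)."""
    all_feats: List[Sequence[float]] = []
    all_mask: List[int] = []
    for feats, mask in supports:
        all_feats.extend(feats)
        all_mask.extend(mask)
    return masked_average_pooling(all_feats, all_mask)


feats = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
print(masked_average_pooling(feats, [1, 1, 0]))  # [2.0, 3.0]
```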
Probabilistic modeling is essential in API, and its key step is the Monte Carlo (MC) estimation of the segmentation prediction. A better estimate usually requires sampling multiple hypotheses, which brings extra computational cost, especially during training. However, an accurate MC estimate of the conditional predictive distribution is unnecessary at the training stage. Much as in stochastic gradient descent, where noisy gradient estimates inject randomness into the learning process and can lead to better generalization, a noisy estimate suffices here. We therefore use a single sample during training, which has the additional benefit of acting as a regularizer by injecting uncertainty into the learning process. At test time, inference speed remains acceptable because the number of MC samples is small (i.e., 10 in all experiments).
The probabilistic formulation can handle ambiguities in the segmentation task by producing a set of diverse but plausible segmentation results. This could be particularly valuable in clinical applications, where multiple segmentation hypotheses from our model could provide diagnostic probabilities or guide steps to resolve ambiguities.
Acknowledgments
This research was supported in part by Natural Science Foundation of China (Nos. 62106129, 62176139, 62106128), Natural Science Foundation of Shandong Province (Nos. ZR2021QF053, ZR2021QF001), Major Basic Research Project of Natural Science Foundation of Shandong Province (No. ZR2021ZD15), The Fundamental Research Funds of Shandong University and The Open Research Project Programme of the State Key Laboratory of Internet of Things for Smart City (University of Macau) (No. SKL-IoTSC(UM)-2021-2023/ORP/GA05/2022).
References
- [1] G. J. Brostow, J. Fauqueur, R. Cipolla, Semantic object classes in video: A high-definition ground truth database, Pattern Recognition Letters 30 (2) (2009) 88–97.
- [2] L.-C. Chen, G. Papandreou, I. Kokkinos, K. Murphy, A. L. Yuille, DeepLab: Semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected CRFs, IEEE Transactions on Pattern Analysis and Machine Intelligence 40 (4) (2017) 834–848.
- [3] J. Long, E. Shelhamer, T. Darrell, Fully convolutional networks for semantic segmentation, in: Computer Vision and Pattern Recognition, 2015.
- [4] Q. Fan, W. Pei, Y.-W. Tai, C.-K. Tang, Self-support few-shot semantic segmentation, in: European Conference on Computer Vision, Springer, 2022, pp. 701–719.
- [5] S. Luo, Y. Li, P. Gao, Y. Wang, S. Serikawa, Meta-seg: A survey of meta-learning for image segmentation, Pattern Recognition (2022) 108586.
- [6] E. H. Rosch, Natural categories, Cognitive Psychology 4 (3) (1973) 328–350.
- [7] J. Snell, K. Swersky, R. Zemel, Prototypical networks for few-shot learning, in: Neural Information Processing Systems, 2017, pp. 4077–4087.
- [8] G. Li, V. Jampani, L. Sevilla-Lara, D. Sun, J. Kim, J. Kim, Adaptive prototype learning and allocation for few-shot segmentation, in: Computer Vision and Pattern Recognition, 2021, pp. 8334–8343.
- [9] N. Dong, E. Xing, Few-shot semantic segmentation with prototype learning, in: British Machine Vision Conference, 2018.
- [10] Q. Zhou, X. Wu, S. Zhang, B. Kang, Z. Ge, L. Jan Latecki, Contextual ensemble network for semantic segmentation, Pattern Recognition 122 (2022) 108290.
- [11] X. Lu, W. Wang, J. Shen, D. Crandall, L. Van Gool, Segmenting objects from relational visual data, IEEE Transactions on Pattern Analysis and Machine Intelligence (01) (2021) 1–1.
- [12] H. Zhao, J. Shi, X. Qi, X. Wang, J. Jia, Pyramid scene parsing network, in: Computer Vision and Pattern Recognition, 2017, pp. 2881–2890.
- [13] Y. Zhang, X. Sun, J. Dong, C. Chen, Q. Lv, GPNet: Gated pyramid network for semantic segmentation, Pattern Recognition 115 (2021) 107940.
- [14] A. Shaban, S. Bansal, Z. Liu, I. Essa, B. Boots, One-shot learning for semantic segmentation, in: British Machine Vision Conference, 2017.
- [15] M. Siam, B. N. Oreshkin, M. Jagersand, AMP: Adaptive masked proxies for few-shot segmentation, in: International Conference on Computer Vision, 2019, pp. 5249–5258.
- [16] A. Okazawa, Interclass prototype relation for few-shot segmentation, in: European Conference on Computer Vision, Springer, 2022, pp. 362–378.
- [17] B. Yang, C. Liu, B. Li, J. Jiao, Q. Ye, Prototype mixture models for few-shot semantic segmentation, in: European Conference on Computer Vision, 2020, pp. 763–778.
- [18] Z. Tian, H. Zhao, M. Shu, Z. Yang, R. Li, J. Jia, Prior guided feature enrichment network for few-shot segmentation, IEEE Transactions on Pattern Analysis and Machine Intelligence (01) (2020) 1–1.
- [19] Y. Li, P. Zhang, X. Xu, Y. Lai, F. Shen, L. Chen, P. Gao, Few-shot prototype alignment regularization network for document image layout segementation, Pattern Recognition 115 (2021) 107882.
- [20] R. Singh, V. Bharti, V. Purohit, A. Kumar, A. K. Singh, S. K. Singh, Metamed: Few-shot medical image classification using gradient-based meta-learning, Pattern Recognition 120 (2021) 108111.
- [21] J. Min, D. Kang, M. Cho, Hypercorrelation squeeze for few-shot segmentation, in: International Conference on Computer Vision, 2021, pp. 6941–6952.
- [22] S. Hong, S. Cho, J. Nam, S. Lin, S. Kim, Cost aggregation with 4D convolutional Swin Transformer for few-shot segmentation, in: European Conference on Computer Vision, Springer, 2022, pp. 108–126.
- [23] J. Johnander, J. Edstedt, M. Felsberg, F. S. Khan, M. Danelljan, Dense Gaussian processes for few-shot segmentation, in: European Conference on Computer Vision, Springer, 2022, pp. 217–234.
- [24] D. P. Kingma, M. Welling, Auto-encoding variational Bayes, in: International Conference on Learning Representations, 2014.
- [25] K.-L. Lim, X. Jiang, Variational posterior approximation using stochastic gradient ascent with adaptive stepsize, Pattern Recognition 112 (2021) 107783.
- [26] S. Kohl, B. Romera-Paredes, C. Meyer, J. De Fauw, J. R. Ledsam, K. Maier-Hein, S. A. Eslami, D. J. Rezende, O. Ronneberger, A probabilistic U-Net for segmentation of ambiguous images, in: Neural Information Processing Systems, 2018, pp. 6965–6975.
- [27] J. Zhang, C. Zhao, B. Ni, M. Xu, X. Yang, Variational few-shot learning, in: International Conference on Computer Vision, 2019, pp. 1685–1694.
- [28] H. Wang, Y. Yang, X. Cao, X. Zhen, C. Snoek, L. Shao, Variational prototype inference for few-shot semantic segmentation, in: Winter Conference on Applications of Computer Vision, 2021, pp. 525–534.
- [29] G. Bhat, F. J. Lawin, M. Danelljan, A. Robinson, M. Felsberg, L. Van Gool, R. Timofte, Learning what to learn for video object segmentation, in: European Conference on Computer Vision, 2020, pp. 777–794.
- [30] O. Ronneberger, P. Fischer, T. Brox, U-Net: Convolutional networks for biomedical image segmentation, in: International Conference on Medical Image Computing and Computer Assisted Intervention, 2015, pp. 234–241.
- [31] W. Liu, C. Zhang, G. Lin, F. Liu, CRNet: Cross-reference networks for few-shot segmentation, in: Computer Vision and Pattern Recognition, 2020, pp. 4165–4173.
- [32] K. Nguyen, S. Todorovic, Feature weighting and boosting for few-shot segmentation, in: International Conference on Computer Vision, 2019, pp. 622–631.
- [33] K. Wang, J. H. Liew, Y. Zou, D. Zhou, J. Feng, PANet: Few-shot image semantic segmentation with prototype alignment, in: International Conference on Computer Vision, 2019, pp. 9197–9206.
- [34] T. Hu, P. Yang, C. Zhang, G. Yu, Y. Mu, C. G. M. Snoek, Attention-based multi-context guiding for few-shot semantic segmentation, in: AAAI Conference on Artificial Intelligence, 2019, pp. 8441–8448.
- [35] T. Wei, X. Li, Y. P. Chen, Y.-W. Tai, C.-K. Tang, FSS-1000: A 1000-class dataset for few-shot segmentation, in: Computer Vision and Pattern Recognition, 2020.
- [36] S. G. Armato III, G. McLennan, L. Bidaut, M. F. McNitt-Gray, C. R. Meyer, A. P. Reeves, B. Zhao, D. R. Aberle, C. I. Henschke, E. A. Hoffman, et al., The Lung Image Database Consortium (LIDC) and Image Database Resource Initiative (IDRI): A completed reference database of lung nodules on CT scans, Medical Physics 38 (2) (2011) 915–931.
- [37] C. Zhang, G. Lin, F. Liu, J. Guo, Q. Wu, R. Yao, Pyramid graph networks with connection attentions for region-based one-shot semantic segmentation, in: Computer Vision and Pattern Recognition, 2019, pp. 9587–9595.