
Hybrid Generative-Contrastive Representation Learning

Saehoon Kim
Kakao Brain
shkim@kakaobrain.com
Sungwoong Kim
Kakao Brain
swkim@kakaobrain.com
Juho Lee
KAIST
juholee@kaist.ac.kr
Abstract

Unsupervised representation learning has recently received significant interest due to its powerful generalizability obtained by effectively leveraging large-scale unlabeled data. There are two prevalent approaches: contrastive learning and generative pre-training, where the former learns representations from instance-wise discrimination tasks and the latter learns them by estimating the likelihood. These seemingly orthogonal approaches have their own strengths and weaknesses. Contrastive learning tends to extract semantic information and discard details irrelevant for classifying objects, making the representations effective for discriminative tasks while degrading robustness to out-of-distribution data. On the other hand, generative pre-training directly estimates the data distribution, so the representations tend to be robust but not optimal for discriminative tasks. In this paper, we show that we can achieve the best of both worlds with a hybrid training scheme. Specifically, we demonstrate that a transformer-based encoder-decoder architecture trained with both contrastive and generative losses can learn highly discriminative and robust representations without hurting the generative performance. We extensively validate our approach on various tasks. Code will be available at https://github.com/kakaobrain/gcrl.

Keywords Hybrid Learning, Generative Pretraining, Contrastive Learning

1 Introduction

Learning representations without human annotation has recently achieved remarkable progress, especially in natural language processing and computer vision [11, 29, 37, 25, 27, 9, 44, 43, 2, 34, 48, 22, 41, 19, 7, 6, 3, 17, 47, 4, 12]. Many unsupervised representation learning algorithms have been proposed, and they can be broadly categorized into either generative or self-supervised learning algorithms. Generative learning usually aims to obtain representations that can reconstruct an input using an encoder and a decoder, while self-supervised learning generally trains an encoder to solve a pretext task derived from unlabeled data (generative learning can itself be considered a form of self-supervised learning [36]; however, we separately refer to methods that perform an explicit decoding back to the input space as generative learning). In natural language processing, both generative learning and self-supervised learning are widely used for unsupervised representation learning. In computer vision, self-supervised learning, especially contrastive learning [48, 6] based on strong augmentations, is mostly adopted, and generative pre-training [4, 12] has recently shown some progress.

For downstream discriminative vision tasks such as image classification, contrastive learning has shown better performance than generative pre-training, since it is trained on pretext tasks requiring instance discrimination and is thus encouraged to learn semantic representations rather than minor details. However, since contrastive learning only learns to discriminate images by their identities, it may be less robust under distributional shifts, e.g., less calibrated for out-of-distribution data or low-data transfer settings [51, 14]. On the other hand, while the representations learned from generative pre-training may not be as effective for discrimination as the ones from contrastive learning, they are more likely to be robust under distributional shift or in the low-data regime. In addition, generative learning can reduce the reliance on manually designed pretext tasks or augmentations, which often incur overfitting problems [46, 33]. Similar to the trade-off between discriminative and generative modeling [52, 40, 16], these two unsupervised representation learning objectives seem to be orthogonal, and moreover incompatible, and hence there is almost no existing work that reaps the benefits of both objectives through multi-task pre-training.

In this paper, we propose a hybrid multi-task learning framework that achieves the merits of both generative and contrastive representation learning. In particular, while maintaining the structure of generative pre-training composed of autoregressive transformer blocks [50, 4], since they impose minimal inductive biases and are thus effective for generative modeling, we introduce an encoder-decoder architecture to explicitly separate the roles of the blocks. Then, the instance-wise contrastive loss is applied to the pooled representations from the encoder, while the generative loss is imposed on the output of the decoder. This separation alleviates the trade-off between the two objectives and thus enables the encoder to learn both discriminative and robust representations. Experimental results on various image classification benchmarks show that the proposed hybrid approach, which we call Generative-Contrastive Representation Learning (GCRL), outperforms both generative pre-training and contrastive learning when applied to downstream classification tasks with linear evaluation as well as out-of-distribution detection tasks. In addition, GCRL improves the calibration of prediction uncertainty and the performance on low-shot transfer tasks. Furthermore, GCRL does not degrade the generative performance of the decoder. Our main contributions can be summarized as follows:

  • We propose GCRL, a novel hybrid generative-contrastive representation learning framework. To the best of our knowledge, this is the first work to combine the generative and contrastive objectives for unsupervised representation learning on computer vision tasks.

  • GCRL does not introduce any specialized modules or inductive biases. Instead, we reinterpret the standard transformer blocks as encoder-decoder structures, to which the contrastive and generative losses are separately applied, allowing us to retain the benefits of both objectives in representation learning.

  • We demonstrate that GCRL outperforms baselines on several downstream image classification tasks and out-of-distribution detection tasks, and we provide extensive ablation studies.

2 Related Work

Self-supervised learning

Many self-supervised representation learning algorithms have recently been proposed for leveraging large-scale unlabeled data. In natural language processing, BERT [11] is a representative work that exploits masked word prediction and next-sentence prediction as pretext tasks with bidirectional transformers and has achieved large performance improvements on many downstream tasks. Since BERT was introduced, various variants [29, 37, 25] have been suggested with different pretext tasks, such as masked phrase prediction and sentence order prediction. Different from BERT, InfoWord [27] maximizes the mutual information between a global sentence representation and the n-grams in it, which can be considered contrastive learning, while ELECTRA [9] explores replaced token prediction with adversarial training. In computer vision, contrastive learning [48] based on manually designed positive and negative pairs has typically been used for self-supervised representation learning. MoCo [19, 7] utilizes a dynamic queue and a moving-averaged encoder for efficiently handling a large number of negative samples. SimCLR [6] improves the quality of representations by finding a more proper composition of transformations and non-linear projection heads. More recently, SwAV [3] modifies previous pairwise representation comparisons by introducing cluster assignment and swapped prediction, and BYOL [17] removes the necessity of negative pairs, bootstrapping its own representation by keeping up with a moving-averaged version of itself. Here, we use SimCLR as our baseline contrastive learning algorithm due to its simplicity and strong performance; however, our method can be combined with other algorithms.

Generative pre-training

Since the most basic self-supervised task is to reconstruct the input itself and to maximize its likelihood, early pre-training methods were based on generative modeling. Currently, autoregressive transformers show state-of-the-art performance in maximizing the data likelihood, and generative pre-training (GPT) of representations using autoregressive transformers has been increasingly investigated in both natural language processing and computer vision [44, 43, 2, 34, 4]. Also, as generative adversarial networks (GANs) [15] have been widely used for generating high-quality images, the use of GANs for unsupervised image representation learning has recently been explored [12]. These generative pre-training methods have generally shown strong generalization ability when combined with large-scale data and big models. The proposed GCRL builds upon image GPT (iGPT) [4] in explicitly estimating the image likelihood.

Hybrid modeling

Hybrid generative and discriminative models have been widely studied to take advantage of both modeling directions [45, 30, 52, 5, 40], and among them, energy-based models (EBMs) [16, 35, 32, 13] have recently shown strong performance with improved training algorithms. However, these hybrid approaches have mostly been developed for regularizing models under supervised learning, whereas GCRL attempts to improve unsupervised representation learning. Furthermore, we separate the target representations for each objective to overcome the fundamental trade-off in hybrid modeling.

3 Approach

This section presents how GCRL is implemented so as to avoid the trade-off between the generative and contrastive objectives.

3.1 Hybrid Objective & Network Architecture

Figure 1: Illustration of the GCRL architecture. Here, LN and GAP stand for layer normalization and global average pooling. The Proj. layer before the contrastive loss refers to a projection head, implemented as a two-layer neural network. We use the representation obtained before the projection head for downstream classification tasks.

Given a batch of $N$ unlabeled samples $\{\mathbf{x}_i\}_{i=1}^{N}$ and the corresponding set of augmentation pairs $\{(\mathbf{x}_i^{(1)}, \mathbf{x}_i^{(2)})\}_{i=1}^{N}$, where $\mathbf{x}_i^{(1)}$ and $\mathbf{x}_i^{(2)}$ are two random augmentations of $\mathbf{x}_i$, our hybrid loss function is defined as

$\mathcal{L} = \alpha\,\mathcal{L}_g + \beta\,\mathcal{L}_c,$   (1)

where $\alpha$ and $\beta$ are the loss weights, and the generative loss $\mathcal{L}_g$ and the contrastive loss $\mathcal{L}_c$ are respectively formulated as

$\mathcal{L}_g = -\frac{1}{N}\sum_{i=1}^{N}\log p(\mathbf{x}_i^{(1)}) = -\frac{1}{N}\sum_{i=1}^{N}\sum_{k=2}^{D}\log p\big(\mathbf{x}_{i,k}^{(1)} \mid \mathbf{x}_{i,<k}^{(1)}\big),$   (2)

$\mathcal{L}_c = -\frac{1}{2N}\sum_{i=1}^{N}\Bigg(\log\frac{\exp(\mathbf{z}_i^{(1)}\cdot\mathbf{z}_i^{(2)}/\tau)}{\sum_{j\neq i}\exp(\mathbf{z}_i^{(1)}\cdot\mathbf{z}_j^{(1)}/\tau)+\sum_{j}\exp(\mathbf{z}_i^{(1)}\cdot\mathbf{z}_j^{(2)}/\tau)} + \log\frac{\exp(\mathbf{z}_i^{(2)}\cdot\mathbf{z}_i^{(1)}/\tau)}{\sum_{j\neq i}\exp(\mathbf{z}_i^{(2)}\cdot\mathbf{z}_j^{(2)}/\tau)+\sum_{j}\exp(\mathbf{z}_i^{(2)}\cdot\mathbf{z}_j^{(1)}/\tau)}\Bigg).$   (3)

Here, $D$ is the data dimension, $\mathbf{z}_i^{(1)}$ and $\mathbf{z}_i^{(2)}$ are projected and normalized representations extracted from $\mathbf{x}_i^{(1)}$ and $\mathbf{x}_i^{(2)}$, respectively, and $\mathbf{x}_{i,k}^{(1)}$ is the $k$th element of $\mathbf{x}_i^{(1)}$. The generative loss $\mathcal{L}_g$ is the negative log-likelihood of the model autoregressively factorized in a typical raster-scan order [42], and the contrastive loss $\mathcal{L}_c$ is the symmetric normalized-temperature cross-entropy loss as in SimCLR [6]. A minimal implementation sketch of Eq. (3) follows.
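The contrastive term in Eq. (3) can be implemented compactly by treating the two views as a single batch of $2N$ samples and masking out the self-similarity terms. The sketch below assumes z1 and z2 are already projected and L2-normalized; the temperature value is illustrative, not the one used in our experiments.

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, tau=0.1):
    """Symmetric normalized-temperature cross-entropy of Eq. (3).

    z1, z2: (N, d) projected and L2-normalized representations of the two views.
    """
    n = z1.size(0)
    z = torch.cat([z1, z2], dim=0)                      # (2N, d)
    sim = z @ z.t() / tau                               # pairwise similarities, (2N, 2N)
    sim.fill_diagonal_(float("-inf"))                   # drop self-similarity terms
    # the positive for row i (view 1) is row i + n (view 2), and vice versa
    targets = torch.cat([torch.arange(n) + n, torch.arange(n)]).to(z.device)
    return F.cross_entropy(sim, targets)
```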

In order to optimize the loss function, we propose a transformer-based architecture composed of an encoder and a decoder, motivated by empirical observations presented in iGPT [4], where early transformer blocks behave similarly to the encoder of a typical auto-encoder and the remaining blocks behave like a decoder. Therefore, without altering the spatial resolution, we explicitly split the set of transformer blocks into an encoder and a decoder, as shown in Figure 1. For contrastive learning, the encoder takes two versions of an input: one weakly augmented and the other strongly augmented. The two representations obtained from the encoder are then used to compute the contrastive loss $\mathcal{L}_c$. The decoder processes only the weakly augmented image to compute the generative loss $\mathcal{L}_g$. We observe that training the generative loss with strong augmentations deteriorates robustness, since the model then focuses on producing unrealistic samples from a severely distorted distribution. A sketch of the resulting training step is given below.
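The sketch below outlines one hybrid training step under our reading of Figure 1; the module names (`encoder`, `decoder`, `proj_head`, `to_logits`) are ours and not from the released code, and `nt_xent_loss` is the sketch from the previous section.

```python
import torch.nn.functional as F
# nt_xent_loss as in the sketch above

def gcrl_training_step(encoder, decoder, proj_head, to_logits,
                       x_weak, x_strong, alpha=1.0, beta=1.0):
    """One hybrid step: contrastive loss on pooled encoder features of both views,
    generative (next-token) loss on the decoder output of the weakly augmented view."""
    h_weak = encoder(x_weak)          # (N, L, C) token features of the weak view
    h_strong = encoder(x_strong)      # (N, L, C) token features of the strong view

    # contrastive branch: global average pooling, projection head, L2-normalization
    z1 = F.normalize(proj_head(h_weak.mean(dim=1)), dim=-1)
    z2 = F.normalize(proj_head(h_strong.mean(dim=1)), dim=-1)
    loss_c = nt_xent_loss(z1, z2)

    # generative branch: autoregressive likelihood of the weakly augmented tokens
    logits = to_logits(decoder(h_weak))                 # (N, L, vocab)
    loss_g = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
                             x_weak[:, 1:].reshape(-1))
    return alpha * loss_g + beta * loss_c
```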

Table 1: Model specification of each method (left) and an illustration of our implementation of the transformer and axial attention blocks (right). Our attention block is composed of three types of attention: unmasked row-wise, masked column-wise, and masked row-wise attention, where the parameters are shared across attention types within the same block.
| Method | #params | #blocks | #enc. blocks |
|--------|---------|---------|--------------|
| iGPT   | 10M     | 12      | 12           |
| iGPT   | 76M     | 24      | 24           |
| SimCLR | 5M      | 6       | 6            |
| SimCLR | 10M     | 12      | 12           |
| SimCLR | 76M     | 24      | 24           |
| GCRL   | (5+5)M  | 12      | 6            |
| GCRL   | (38+38)M | 24     | 12           |

[Illustration of the transformer and axial attention blocks]

3.2 Model Details

We use a transformer block similar to that of iGPT for the encoder and decoder, where each block is constructed as follows. Let $\mathbf{h}_{l-1}$ be the output from the $(l-1)$th block. The $l$th transformer block is composed of the following operations,

$\mathbf{h}_l = \mathrm{LN}(\mathbf{h}_{l-1}),\quad \mathbf{h}_l = \mathbf{h}_l + \mathrm{CausalMultiHeadSelfAttention}(\mathbf{h}_l),\quad \mathbf{h}_l = \mathbf{h}_l + \mathrm{MLP}(\mathrm{LN}(\mathbf{h}_l)),$   (4)

where LN refers to layer normalization [1], the causal multi-head self-attention module computes interactions across a sequence only in a raster-scan order, and the MLP block, of the same structure as in iGPT, computes nonlinear embeddings of the tokens in a sequence independently. To accelerate the self-attention module, we reduce the size of the sequence by color quantization [4] and adopt the sparse attention block used in AxialTransformer [23] with slight modification. For color quantization, we compress the RGB color channels into one of 512 codes as in iGPT, where the codebook is simply constructed by $k$-means clustering. We use the codebook provided with the official implementation of iGPT for fair comparison. Given a color image of 32x32 resolution, this simple processing effectively reduces the sequence length from 3,072 to 1,024 (a rough sketch of the quantization step is given below).
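A rough sketch of the quantization step under our reading of the procedure; the codebook is assumed to be a (512, 3) tensor of $k$-means cluster centers, and the exact value ranges and format of the iGPT codebook are not reproduced here.

```python
import torch

def quantize_colors(images, codebook):
    """Map each RGB pixel to its nearest codebook entry.

    images:   (B, 3, 32, 32) tensor in the same value range as the codebook.
    codebook: (512, 3) tensor of cluster centers.
    Returns:  (B, 1024) integer token sequences in raster-scan order.
    """
    b = images.size(0)
    pixels = images.permute(0, 2, 3, 1).reshape(b, -1, 3)           # (B, 1024, 3)
    # squared Euclidean distance from every pixel to every codebook entry
    dist = (pixels.unsqueeze(2) - codebook.view(1, 1, -1, 3)).pow(2).sum(dim=-1)
    return dist.argmin(dim=-1)                                       # (B, 1024)
```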

AxialTransformer [23] accelerates the self-attention module by exploiting the two-dimensional structure of an image, where row-wise and column-wise attentions are combined to access a full image in a raster-scan order. In GCRL, unlike the original implementation of AxialTransformer, we sequentially apply unmasked row-wise attention, causal column-wise attention, and causal row-wise attention to access a full image in every transformer block, as shown in Table 1. We empirically observe that our implementation of the axial attention block reduces working memory and training time significantly compared to dense attention, without losing performance. For a detailed description, including empirical results on the type of attention blocks, please refer to our PyTorch-like pseudo-code and ablation study in the Appendix.

To fairly compare GCRL with iGPT and SimCLR, we use the same architecture for all objectives. We perform experiments with several models by varying the number of transformer blocks (6, 12, and 24 blocks). For all models, two-dimensional positional embeddings are added after the first linear embedding, and the number of heads for axial attention is set to 8. Table 1 compares the sizes of the models tested for iGPT, SimCLR, and GCRL. Note that since GCRL is composed of an encoder and a decoder, only half of the total number of blocks is needed when evaluating the representations. For concreteness, a minimal sketch of a block implementing Eq. (4) is given below.
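The sketch below is a minimal PyTorch implementation of the block in Eq. (4); it uses dense causal self-attention in place of the axial variant described above, and the dimensions and MLP width are illustrative.

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Pre-LN block of Eq. (4): LN, causal self-attention, then an MLP, with residuals."""

    def __init__(self, dim=512, heads=8, mlp_ratio=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_ratio * dim), nn.GELU(),
                                 nn.Linear(mlp_ratio * dim, dim))

    def forward(self, h):                               # h: (B, L, dim)
        h = self.ln1(h)                                 # h_l = LN(h_{l-1})
        L = h.size(1)
        causal = torch.triu(torch.ones(L, L, dtype=torch.bool, device=h.device), diagonal=1)
        h = h + self.attn(h, h, h, attn_mask=causal, need_weights=False)[0]
        return h + self.mlp(self.ln2(h))                # h_l = h_l + MLP(LN(h_l))
```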

4 Experiments

This section compares the characteristics of the representations learned by iGPT, SimCLR, and GCRL in terms of linear evaluation, generative performance, robustness to out-of-distribution samples, and low-shot transfer learning.

4.1 Experiment Details

Table 2: Pipeline of data augmentations used for SimCLR and GCRL. For GCRL, this pipeline is only applied to the strongly augmented input stream.
| Order | Augmentation | Arguments | PyTorch function |
|-------|--------------|-----------|------------------|
| 1 | Resized crop | crop size = (32, 32), crop scale = (0.2, 1.0) | RandomResizedCrop |
| 2 | Color dist. (jitter), prob. 0.8 | brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1 | ColorJitter |
| 3 | Color dist. (gray) | probability=0.2 | RandomGrayscale |
| 4 | Horizontal flip | probability=0.5 | RandomHorizontalFlip |
Table 3: Detailed configurations of experiments in Table 4.
| Method | Model size | ImageNet32 batch size | ImageNet32 epochs | ImageNet32 peak LR | CIFAR10/100 batch size | CIFAR10/100 epochs | CIFAR10/100 peak LR |
|--------|-----------|------|---------|-------------------|-----|-----------|-------------------|
| iGPT   | 10M       | 1024 | 100     | 4.0e-04           | 512 | 200       | 1.6e-03           |
| iGPT   | 76M       | 384  |         |                   | 384 | 15        | 3.2e-03           |
| SimCLR | 5M        | 1024 | 100     | 4.0e-04           | 512 | 400       | 4.8e-04           |
| SimCLR | 10M       | 1024 |         |                   | 512 | 400       | 4.8e-04           |
| SimCLR | 76M       | 768  |         |                   | 384 | 15        | 9.6e-04           |
| GCRL   | (5+5)M    | 1024 | 50 + 50 | 4.0e-04 / 4.0e-04 | 512 | 200 + 200 | 1.6e-03 / 4.8e-04 |
| GCRL   | (38+38)M  | 768  |         |                   | 384 | 15        | 9.6e-04           |

We use three datasets: CIFAR10, CIFAR100 [28], and ImageNet32 [10], where CIFAR contains 50K training samples and ImageNet32 provides 1.2M training samples at 32x32 resolution. We use the Adam optimizer with decoupled weight decay (AdamW) [39] with gradient clipping of norm 1.0 and a weight decay factor of $10^{-4}$, and we do not apply weight decay to the trainable parameters of layer normalization and token embeddings. We apply dropout with rate 0.1 before every residual connection in the transformer blocks. We linearly increase the learning rate from zero to a specified value over 5 epochs and apply cosine decay afterwards [38] (a sketch of this setup follows). We use PyTorch version 1.6.0 and conduct all experiments on V100 GPUs with 32GB memory. For CIFAR experiments, we report means and standard deviations over repetitions with random seeds 0, 1, and 2. For ImageNet experiments, we report the result with random seed 0.
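A sketch of the optimizer and schedule setup, assuming epoch-level scheduler steps and a simple name-based rule for excluding layer-normalization and embedding parameters from weight decay (the exact exclusion rule in our code may differ):

```python
import math
import torch

def build_optimizer_and_scheduler(model, peak_lr, total_epochs, warmup_epochs=5, wd=1e-4):
    """AdamW with weight decay 1e-4 (LayerNorm/embedding parameters excluded),
    5-epoch linear warmup to peak_lr, then cosine decay."""
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        (no_decay if ("norm" in name.lower() or "embed" in name.lower()) else decay).append(p)
    opt = torch.optim.AdamW([{"params": decay, "weight_decay": wd},
                             {"params": no_decay, "weight_decay": 0.0}], lr=peak_lr)

    def lr_factor(epoch):
        if epoch < warmup_epochs:
            return epoch / warmup_epochs
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_factor)
    return opt, scheduler

# inside the training loop, gradients are clipped to norm 1.0 before each step:
# torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
```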

We find it practically useful to train GCRL with only the generative loss $\mathcal{L}_g$ for the first half of the total training epochs and then with the full objective for the remaining half. Also, we set both loss weights $\alpha$ and $\beta$ of the hybrid objective to 1.0. Below, we summarize the data augmentation policies we used.

iGPT. We use the standard random crop with padding, followed by a horizontal flip with probability 0.5, for all experiments, where the pad size is 4 and the padding is done with reflections of the image.

SimCLR. We use a similar augmentation policy to [6]. We use the same two-layer MLP projection head as in SimCLR, where batch normalization is replaced by layer normalization, the hidden dimension is set to 1,024, and the final embedding dimension is set to 64 for CIFAR and 128 for ImageNet experiments.

GCRL. This requires both strong and weak augmentations, where the weakly augmented images are used for the generative loss. We choose the weak augmentation to be the one used for iGPT and the strong augmentation to be the one used for SimCLR. Table 2 shows the pipeline of data augmentations for SimCLR and GCRL (a torchvision sketch of the two pipelines follows). We believe this augmentation policy is fair enough to expose the different behaviors of SimCLR and our approach, though it is possible to obtain more discriminative representations by employing a more sophisticated augmentation strategy. For the projection head, we use the same network as in SimCLR.
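A rough torchvision sketch of the two pipelines; tensor conversion and any normalization are omitted, and the reflection-padded crop for the weak stream follows the iGPT policy above.

```python
from torchvision import transforms

# strong augmentation (Table 2), applied only to the contrastive input stream of GCRL
strong_augmentation = transforms.Compose([
    transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
    transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                                   saturation=0.4, hue=0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomHorizontalFlip(p=0.5),
])

# weak augmentation, used for the generative stream (and for iGPT)
weak_augmentation = transforms.Compose([
    transforms.RandomCrop(32, padding=4, padding_mode="reflect"),
    transforms.RandomHorizontalFlip(p=0.5),
])
```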

It is not clear which representation should be used for downstream tasks in iGPT. Hence, we test two versions of representations for iGPT: the first is pooled from the middle transformer block, and the second is pooled from the last transformer block. For SimCLR and GCRL, the output before the projection head is used as the final representation.

4.2 Discriminative & Generative Performance

Table 4: Comparison between baselines and GCRL in terms of linear evaluation and bits-per-dimension (bpd) on the color-quantized space. “Pos.” refers to which transformer block is used for the final representation of linear evaluation.
| Method | Model size | Pos. | ImageNet32 Acc. (↑) | ImageNet32 Bpd (↓) | CIFAR10 Acc. (↑) | CIFAR10 Bpd (↓) | CIFAR100 Acc. (↑) | CIFAR100 Bpd (↓) |
|---|---|---|---|---|---|---|---|---|
| iGPT | 10M | half | 0.2508 | 3.1117 | 0.7750 ± 0.0063 | 2.7775 ± 0.0021 | 0.4914 ± 0.0054 | 2.6756 ± 0.0019 |
| iGPT | 10M | last | 0.1941 | | 0.7482 ± 0.0023 | | 0.4516 ± 0.0030 | |
| iGPT | 76M | half | 0.4162 | 3.0419 | 0.9372 ± 0.0024 | 2.6969 ± 0.0003 | 0.7163 ± 0.0031 | 2.5978 ± 0.0003 |
| iGPT | 76M | last | 0.3177 | | 0.8910 ± 0.0014 | | 0.6171 ± 0.0011 | |
| iGPT-S [4] | 76M | best | 0.4190 | 3.0145 | – | – | – | – |
| SimCLR | 5M | last | 0.2717 | – | 0.8355 ± 0.0018 | – | 0.5853 ± 0.0058 | – |
| SimCLR | 10M | last | 0.2955 | – | 0.8442 ± 0.0064 | – | 0.6022 ± 0.0041 | – |
| SimCLR | 76M | last | 0.3856 | – | 0.9046 ± 0.0026 | – | 0.6954 ± 0.0055 | – |
| GCRL | (5+5)M | half | 0.3010 | 3.1140 | 0.8391 ± 0.0020 | 2.7839 ± 0.0023 | 0.5868 ± 0.0022 | 2.6834 ± 0.0019 |
| GCRL | (38+38)M | half | 0.4359 | 3.0448 | 0.9506 ± 0.0004 | 2.6784 ± 0.0001 | 0.7603 ± 0.0005 | 2.5756 ± 0.0001 |
Figure 2: Generated samples from iGPT-76M (left) and GCRL-76M (right) trained on ImageNet32 (temperature is 1.0). More samples can be found in the Appendix.

We use the linear evaluation protocol on representations to measure the discriminative performance of visual features, and we report bits-per-dimension (bpd) on the color-quantized space for assessing the generative performance (note that the bpd values reported in this paper are not directly comparable to those in the literature [8, 26], because we quantize the color space to reduce the sequence length). For linear evaluation, we follow the setup in previous work [4, 6]: we train a single-layer linear classifier with the AdamW optimizer for 100 epochs with weight decay $10^{-4}$ and batch size 512. We use a learning rate of 0.01 for CIFAR experiments and 0.1 for ImageNet experiments. No data augmentation is applied for linear evaluation. A sketch of the linear probe is given below.
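A sketch of the linear probe, assuming representations have been pre-extracted from the frozen encoder; data loading and the evaluation loop are omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def linear_probe(features, labels, num_classes, lr=0.01, epochs=100, batch_size=512):
    """Train a single linear classifier on frozen representations with AdamW."""
    clf = nn.Linear(features.size(1), num_classes)
    opt = torch.optim.AdamW(clf.parameters(), lr=lr, weight_decay=1e-4)
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, labels),
        batch_size=batch_size, shuffle=True)
    for _ in range(epochs):
        for f, y in loader:
            opt.zero_grad()
            F.cross_entropy(clf(f), y).backward()
            opt.step()
    return clf
```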

Table 4 compares GCRL with the baselines under various model capacities and datasets, where the hyper-parameters used for this experiment are reported in Table 3. For iGPT-76M, SimCLR-76M, and GCRL-(38+38)M on CIFAR10 and CIFAR100, we fine-tune our pre-trained ImageNet32 models for 15 epochs. Table 4 shows that GCRL is able to learn representations with both discriminative and generative qualities. When the model capacity is high enough, we observe that GCRL consistently improves the linear evaluation performance compared to iGPT (our implementation), iGPT-S [4], and SimCLR, while maintaining the generative performance. Specifically, on ImageNet32, GCRL achieves 43.6% linear evaluation accuracy, which is superior to iGPT-S (41.9%). In addition, when fine-tuning the 76M and (38+38)M models trained on ImageNet to the CIFAR datasets, GCRL achieves a lower bpd than iGPT (2.6784 vs. 2.6969 on CIFAR10), which indicates that the generative performance of GCRL is superior to that of iGPT.

Figure 2 compares the perceptual quality of generated images from iGPT and GCRL on ImageNet32, where both approaches are able to generate realistic images. In addition, we measure FID-10K [21] for iGPT and GCRL with 10M parameters trained on CIFAR10, where the score of iGPT is 45.02 ± 0.99 and the score of GCRL is 44.43 ± 0.36. Considering these observations together with the small difference in bpd in Table 4, we argue that GCRL maintains the generative performance of iGPT.

4.3 Robustness on OOD samples

Table 5: AUROC results of a supervised OOD detection task from iGPT, SimCLR, and GCRL, where we use ImageNet32, CIFAR10, and CIFAR100 as in-distribution datasets, while SVHN and STL-10 are OOD datasets.
| Method | Model size | Pos. | ImageNet32 SVHN | ImageNet32 STL-10 | CIFAR10 SVHN | CIFAR10 STL-10 | CIFAR100 SVHN | CIFAR100 STL-10 |
|---|---|---|---|---|---|---|---|---|
| iGPT | 10M | half | 0.9824 | 0.5458 | 0.9638 ± 0.0040 | 0.6191 ± 0.0065 | 0.9395 ± 0.0069 | 0.6214 ± 0.0049 |
| iGPT | 10M | last | 0.9538 | 0.5214 | 0.9414 ± 0.0048 | 0.6311 ± 0.0014 | 0.9236 ± 0.0028 | 0.6501 ± 0.0010 |
| iGPT | 76M | half | 0.9803 | 0.6187 | 0.9896 ± 0.0009 | 0.7325 ± 0.0021 | 0.9697 ± 0.0009 | 0.7342 ± 0.0050 |
| iGPT | 76M | last | 0.9231 | 0.5243 | 0.9261 ± 0.0027 | 0.6593 ± 0.0006 | 0.8723 ± 0.0031 | 0.6797 ± 0.0025 |
| SimCLR | 5M | last | 0.5867 | 0.4705 | 0.7945 ± 0.0949 | 0.4946 ± 0.0076 | 0.8243 ± 0.0728 | 0.5820 ± 0.0118 |
| SimCLR | 10M | last | 0.6415 | 0.5142 | 0.7262 ± 0.0144 | 0.4850 ± 0.0055 | 0.7326 ± 0.0355 | 0.5660 ± 0.0344 |
| SimCLR | 76M | last | 0.6196 | 0.5148 | 0.7533 ± 0.0267 | 0.5026 ± 0.0107 | 0.7102 ± 0.0163 | 0.6172 ± 0.0066 |
| GCRL | (5+5)M | half | 0.9975 | 0.5951 | 0.9971 ± 0.0003 | 0.6323 ± 0.0052 | 0.9896 ± 0.0016 | 0.6375 ± 0.0025 |
| GCRL | (38+38)M | half | 0.9982 | 0.6138 | 0.9975 ± 0.0001 | 0.7000 ± 0.0031 | 0.9841 ± 0.0005 | 0.7092 ± 0.0023 |
Table 6: AUROC and AUPRC results of an unsupervised OOD detection task by approximating mass of a density function, where we use CIFAR10 (C10), CIFAR100 (C100), and ImageNet32 (I32) as in-distribution datasets.
| Method | #params | I32 → SVHN AUROC | I32 → SVHN AUPRC | C10 → SVHN AUROC | C10 → SVHN AUPRC | C100 → SVHN AUROC | C100 → SVHN AUPRC |
|---|---|---|---|---|---|---|---|
| iGPT | 10M | 0.9740 | 0.9847 | 0.8301 ± 0.0253 | 0.6129 ± 0.0474 | 0.8571 ± 0.0121 | 0.6762 ± 0.0249 |
| iGPT | 76M | 0.7472 | 0.8543 | 0.3932 ± 0.0093 | 0.2222 ± 0.0030 | 0.4545 ± 0.0062 | 0.2502 ± 0.0021 |
| GCRL | (5+5)M | 0.9710 | 0.9826 | 0.8814 ± 0.0227 | 0.7080 ± 0.0484 | 0.8769 ± 0.0094 | 0.7167 ± 0.0260 |
| GCRL | (38+38)M | 0.9950 | 0.9970 | 0.9713 ± 0.0022 | 0.9097 ± 0.0064 | 0.9565 ± 0.0034 | 0.8754 ± 0.0080 |

We expect our hybrid scheme to learn features that are more robust to out-of-distribution (OOD) samples than either generative or discriminative learning alone, since it combines the merits of the two. Specifically, the generative loss is helpful for OOD discrimination because it learns the data distribution directly, and the contrastive loss also helps to some extent because it pulls the representations of similar in-distribution samples together to form a cluster-like structure. To validate this hypothesis, we conduct two types of OOD detection tasks using the representations learned from each objective. First, we consider a supervised OOD detection setting with detection scores computed from class-conditional Gaussian densities [31, 51]: $s_{\mathrm{sup}}(\mathbf{x}) = \max_{c \in \{1,\dots,|C|\}} \mathcal{N}\big(\tilde{\mathbf{z}}(\mathbf{x}) \mid \boldsymbol{\mu}_c, \boldsymbol{\Sigma}_c\big)$, where $\tilde{\mathbf{z}}(\mathbf{x})$ is the representation of $\mathbf{x}$ without projection, $|C|$ is the number of classes, and the parameters $(\boldsymbol{\mu}_c, \boldsymbol{\Sigma}_c)$ are estimated from the training samples of class $c$ (a sketch of this score follows). Second, we consider an unsupervised OOD detection setting where the score is based on the area of the high-probability region [16], approximated by the magnitude of the score function: $s_{\mathrm{unsup}}(\mathbf{x}) = -\big\|\frac{\partial \log p(\mathbf{x})}{\partial \mathbf{x}}\big\|_2$. Since iGPT and GCRL quantize the input space, we cannot exactly compute this gradient with respect to the input image (using a stochastic quantization [24] would make it possible to derive the input gradient). Instead, we compute the gradient with respect to the token embedding. Note that SimCLR is not applicable to the unsupervised OOD detection because it does not provide a density.
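A sketch of the supervised score under our reading of [31, 51]; the diagonal jitter `eps` is our addition for numerical stability, and we use log-densities, which give the same ranking as the densities in the expression above.

```python
import torch

def fit_class_gaussians(train_feats, train_labels, num_classes, eps=1e-4):
    """Fit one Gaussian per class on (unprojected) training representations."""
    gaussians = []
    for c in range(num_classes):
        f = train_feats[train_labels == c]
        mu = f.mean(dim=0)
        cov = torch.cov(f.t()) + eps * torch.eye(f.size(1))
        gaussians.append(torch.distributions.MultivariateNormal(mu, covariance_matrix=cov))
    return gaussians

def supervised_ood_score(z, gaussians):
    """s_sup: maximum class-conditional log-density of the representations z, shape (N, d)."""
    log_probs = torch.stack([g.log_prob(z) for g in gaussians], dim=-1)   # (N, |C|)
    return log_probs.max(dim=-1).values
```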

Figure 3: Distribution of supervised OOD scores of in-distribution (CIFAR10 and ImageNet32) and out-of-distribution (SVHN), where we use the 10M models for CIFAR10 and 76M models for ImageNet32. In case of iGPT, we use the representation extracted from the middle transformer block.
Figure 4: Two-dimensional visualization of 2,000 randomly selected test samples of CIFAR10 and SVHN, where iGPT (10M), SimCLR (10M), and GCRL of (5+5)M are trained on CIFAR10. Here, samples of SVHN are marked in gray, and samples of the same class in CIFAR10 are marked in the same color.

Table 5 presents the supervised OOD detection results, where GCRL consistently outperforms the baselines when SVHN is the OOD dataset. Since STL-10 is a subset of ImageNet, it may be considered an in-distribution set; however, we report it to show how our approach behaves when a subset of the in-distribution data is given as an OOD dataset. The OOD detection performance is measured by the area under the ROC curve (AUROC); area under the precision-recall curve (AUPRC) results are provided in the Appendix. Figure 3 shows several distributions of $s_{\mathrm{sup}}(\mathbf{x})$ for in-distribution and out-of-distribution test samples, supporting the results of Table 5. Table 6 shows the results of the unsupervised OOD detection, where GCRL outperforms iGPT. To empirically analyze why GCRL works better than the baselines, we visualize the representations learned by each method in a two-dimensional space with t-SNE [49], as shown in Figure 4. We observe that iGPT is able to discriminate in-distribution and out-of-distribution samples, but the in-distribution samples are not well clustered. On the other hand, SimCLR pulls visually similar samples together in the embedding space, but it is not able to discriminate in-distribution from out-of-distribution samples. GCRL achieves both in/out discrimination and a cluster-like structure.

4.4 Ablation Study

This section presents several additional experiments to support our claims: (a) an ablation study on the motivation of the network design and low-shot transfer learning, and (b) the regularization effect of the generative loss in GCRL.

Table 7: Ablation study of GCRL by varying the number of blocks and the coefficients $(\alpha, \beta)$ of the objective on CIFAR10.
| #enc. / #total | Model size | (α, β) | Linear eval. | bpd |
|---|---|---|---|---|
| 6/6 | 5M | (1.0, 1.0) | 0.8034 | 2.7863 |
| 6/6 | 5M | (0.5, 1.5) | 0.8095 | 2.8056 |
| 6/6 | 5M | (1.5, 0.5) | 0.7929 | 2.7788 |
| 12/12 | 10M | (1.0, 1.0) | 0.8145 | 2.7813 |
| 12/12 | 10M | (0.5, 1.5) | 0.8278 | 2.7772 |
| 12/12 | 10M | (1.5, 0.5) | 0.8017 | 2.7872 |
| 6/12 | (5+5)M | (1.0, 1.0) | 0.8376 | 2.7866 |
| 6/12 | (5+5)M | (0.5, 1.5) | 0.8355 | 2.7870 |
| 6/12 | (5+5)M | (1.5, 0.5) | 0.8336 | 2.7887 |

Table 7 shows an ablation study supporting our network design choice: when the number of blocks in the encoder is half of the total number of blocks, the final performance is robust to the hyperparameters of the GCRL objective ($\alpha$ and $\beta$). Note that a noticeable trade-off between the two objectives can be observed when the hybrid loss is imposed directly on the same target representation from the final block, whereas the proposed separation of target representations significantly reduces this trade-off.

Table 8 shows the low-shot transfer classification accuracy from ImageNet32 to CIFAR10. We observe that GCRL improves the classification accuracy as well as expected calibration error (ECE) [18] by a large margin. We believe that the generative loss greatly improves the calibration of predictions and generalization ability because it prevents overfitting to some extent, especially when labeled samples are scarce.

Table 8: Low-shot transfer classification results of iGPT (76M), SimCLR (76M), and GCRL ((38+38)M) with five random seeds (0, 1, 2, 3, and 4) on CIFAR10.
| Method | 500 labels (1%) ACC | 500 labels (1%) ECE | 2500 labels (5%) ACC | 2500 labels (5%) ECE | 5000 labels (10%) ACC | 5000 labels (10%) ECE |
|---|---|---|---|---|---|---|
| iGPT | 0.7417 ± 0.0055 | 0.1349 ± 0.0103 | 0.8332 ± 0.0015 | 0.0566 ± 0.0029 | 0.8583 ± 0.0016 | 0.0353 ± 0.0019 |
| SimCLR | 0.7609 ± 0.0029 | 0.1680 ± 0.0063 | 0.8332 ± 0.0013 | 0.0459 ± 0.0023 | 0.8528 ± 0.0027 | 0.0215 ± 0.0033 |
| GCRL | 0.7956 ± 0.0035 | 0.0808 ± 0.0086 | 0.8746 ± 0.0047 | 0.0260 ± 0.0045 | 0.8937 ± 0.0027 | 0.0127 ± 0.0014 |
Table 9: Regularizing effects of the data likelihood in the case of easy pretext tasks for the contrastive loss, where we compare GCRL of (5+5)M with SimCLR of 10M trained with a weak augmentation (WA) policy. As in Table 7, we conduct this experiment with random seed 0.

[Plot: 5-NN classification loss of SimCLR and GCRL on test data during training]

| Method | Epoch | CIFAR10 ACC | CIFAR10 ECE | CIFAR100 ACC | CIFAR100 ECE |
|---|---|---|---|---|---|
| SimCLR + WA | 400 | 0.7414 | 0.0231 | 0.4976 | 0.0230 |
| SimCLR + WA | 800 | 0.7526 | 0.0580 | 0.4763 | 0.0500 |
| GCRL + WA | 400 | 0.8252 | 0.0145 | 0.5684 | 0.0192 |
| GCRL + WA | 800 | 0.8331 | 0.0125 | 0.5789 | 0.0199 |

Table 10: Comparison of test classification accuracy and expected calibration error (ECE) between ResNet, iGPT, and GCRL on CIFAR10 and CIFAR100. We train ResNet-18 for 50 epochs and iGPT for 200 epochs. We train GCRL with only a generative loss for the first 100 epochs with a peak learning rate of 1.6e-3, and then with the hybrid objective for the remaining 100 epochs with a peak learning rate of 9.6e-4. We repeat this experiment with random seeds 0, 1, and 2.
| Method | Size | Peak LR | Epochs | CIFAR10 ACC | CIFAR10 ECE | CIFAR100 ACC | CIFAR100 ECE |
|---|---|---|---|---|---|---|---|
| ResNet-18 | 11M | 1.0e-3 | 50 | 0.9360 ± 0.0001 | 0.0445 ± 0.0002 | 0.7406 ± 0.0044 | 0.1589 ± 0.0041 |
| iGPT | 5M | 4.8e-4 | 200 | 0.8979 ± 0.0034 | 0.0513 ± 0.0034 | 0.6709 ± 0.0048 | 0.1200 ± 0.0046 |
| iGPT | 10M | 4.8e-4 | 200 | 0.9024 ± 0.0017 | 0.0495 ± 0.0040 | 0.6763 ± 0.0064 | 0.1581 ± 0.0073 |
| GCRL | (5+5)M | 1.6e-3 / 9.6e-4 | 200 | 0.9412 ± 0.0018 | 0.0337 ± 0.0078 | 0.7403 ± 0.0026 | 0.0725 ± 0.0104 |
Figure 5: Reliability diagrams of iGPT and GCRL on CIFAR10 and CIFAR100, where the diagrams and ECE are from the results of random seed 0 in Table 10.

Table 9 shows the regularization effect of the generative loss, where we deliberately use an easy pretext task for contrastive learning to cause overfitting. We raise the lower bound of the random crop scale from 0.2 to 0.8, meaning that at least 80% of the original image is included in the crop. To understand the effect of such weak augmentations on SimCLR, the plot in Table 9 shows the 5-NN classification losses of SimCLR and GCRL on test data during training (a sketch of this nearest-neighbour monitor follows). We observe that the loss of SimCLR starts increasing after 400 epochs, meaning that the model overfits the pretext task and thus generalizes worse. In addition, the ECE of SimCLR at 800 epochs is higher than at 400 epochs, which also indicates overfitting. In contrast, GCRL does not suffer from such overfitting and still performs comparably to the models trained with strong augmentations, thanks to the generative loss.
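A sketch of the nearest-neighbour monitor; we show an accuracy variant with cosine similarity, and the exact distance and loss formulation used for the plot may differ.

```python
import torch
import torch.nn.functional as F

def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=5):
    """Classify each test representation by a majority vote over its k nearest
    training representations under cosine similarity."""
    train = F.normalize(train_feats, dim=-1)
    test = F.normalize(test_feats, dim=-1)
    sim = test @ train.t()                            # (N_test, N_train)
    nn_idx = sim.topk(k, dim=-1).indices              # indices of the k nearest neighbours
    votes = train_labels[nn_idx]                      # (N_test, k)
    preds = votes.mode(dim=-1).values
    return preds.eq(test_labels).float().mean()
```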

To show that GCRL prevents a classification model from overfitting, we conduct supervised classification with two baselines: ResNet-18 [20] and iGPT with the generative loss replaced by the supervised cross-entropy loss. In the case of GCRL, we replace the contrastive loss with the supervised loss. For training ResNet, we employ a default augmentation policy [53], where a 32x32 crop is randomly selected from an image padded with 4 pixels of reflection on each side, followed by a horizontal flip with probability 0.5. For GCRL as well as iGPT, we use the same augmentation policy as SimCLR in Section 4.2. Table 10 shows the classification accuracy and ECE of the compared methods on CIFAR10 and CIFAR100. Compared to iGPT, GCRL clearly achieves better accuracy, showing that the regularization by the generative loss is effective. In addition, GCRL achieves very competitive performance compared to ResNet, even though our model has minimal inductive biases in its network structure. Figure 5 presents reliability diagrams of the iGPT-10M and GCRL-(5+5)M models, showing that our approach provides better-calibrated predictions than the baseline. A sketch of the ECE estimator used in these experiments is given below.
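A sketch of the binning estimator of ECE from [18] as we understand it; the bin count is illustrative and equal-width confidence bins are assumed.

```python
import torch

def expected_calibration_error(probs, labels, n_bins=15):
    """Bin predictions by confidence and average |accuracy - confidence| per bin,
    weighted by the fraction of samples falling in that bin."""
    conf, preds = probs.max(dim=-1)
    correct = preds.eq(labels).float()
    edges = torch.linspace(0.0, 1.0, n_bins + 1)
    ece = torch.zeros(())
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            ece += in_bin.float().mean() * (correct[in_bin].mean() - conf[in_bin].mean()).abs()
    return ece
```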

5 Conclusion

We study a hybrid unsupervised training scheme that jointly optimizes generative and contrastive objectives in a single network. Instead of developing a complex and specialized module for the hybrid objective, we reinterpret the standard transformer blocks as an encoder-decoder structure, to which the contrastive and generative losses are applied separately. We observe that our hybrid approach learns more discriminative and robust features than either objective alone, especially when the model capacity is high enough. For future work, we will scale up our network to achieve more discriminative features than those from CNN-based architectures trained with recent contrastive learning approaches, while maintaining all the benefits of generative models.

Acknowledgment

JHL was supported by the Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. 2019-0-00075, Artificial Intelligence Graduate School Program (KAIST)).

References

  • [1] J. L. Ba, J. R. Kiros, and G. E. Hinton. Layer normalization. In NeurIPS, 2016.
  • [2] T. B. Brown, B. Mann, N. Ryder, M. Subbiah, J. Kaplan, P. Dhariwal, A. Neelakantan, P. Shyam, G. Sastry, A. Askell, S. Agarwal, A. Herbert-Voss, G. Krueger, T. Henighan, R. Child, A. Ramesh, D. M. Ziegler, J. Wu, C. Winter, C. Hesse, M. Chen, E. Sigler, M. Litwin, S. Gray, B. Chess, J. Clark, C. Berner, S. McCandlish, A. Radford, I. Sutskever, and D. Amodei. Language models are few-shot learners. In arXiv preprint arXiv:2005.14165, 2020.
  • [3] M. Caron, I. Misra, J. Mairal, P. Goyal, P. Bojanowski, and A. Joulin. Unsupervised learning of visual features by contrasting cluster assignments. In NeurIPS, 2020.
  • [4] M. Chen, A. Radford, R. Child, J. Wu, H. Jun, P. Dhariwal, D. Luan, and I. Sutskever. Generative pretraining from pixels. In ICML, 2020.
  • [5] R. T. Q. Chen, J. Behrmann, D. Duvenaud, and J.-H. Jacobsen. Residual flows for invertible generative modeling. In NeurIPS, 2019.
  • [6] T. Chen, S. Kornblith, M. Norouzi, and G. Hinton. A simple framework for contrastive learning of visual representations. In ICML, 2020.
  • [7] X. Chen, H. Fan, R. Girshick, and K. He. Improved baselines with momentum contrastive learning. In arXiv preprint arXiv:2003.04297, 2020.
  • [8] R. Child, S. Gray, A. Radford, and I. Sutskever. Generating long sequences with sparse transformers. In ICML, 2019.
  • [9] K. Clark, M.-T. Luong, Q. V. Le, and C. D. Manning. Electra: Pre-training text encoders as discriminators rather than generators. In ICLR, 2020.
  • [10] J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, and L. Fei-Fei. ImageNet: A Large-Scale Hierarchical Image Database. In CVPR, 2009.
  • [11] J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. In arXiv preprint arXiv:1810.04805, 2018.
  • [12] J. Donahue and K. Simonyan. Large scale adversarial representation learning. In NeurIPS, 2019.
  • [13] Y. Du and I. Mordatch. Implicit generation and modeling with energy-based models. In NeurIPS, 2019.
  • [14] L. Ericsson, H. Gouk, and T. M. Hospedales. How well do self-supervised models transfer? In arXiv preprint arXiv:2011.13377, 2020.
  • [15] I. J. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio. Generative adversarial nets. In NeurIPS, 2014.
  • [16] W. Grathwohl, K.-C. Wang, J.-H. Jacobsen, D. Duvenaud, M. Norouzi, and K. Swersky. Your classifier is secretly an energy based model and you should treat it like one. In ICLR, 2020.
  • [17] J.-B. Grill, F. Strub, F. Altché, C. Tallec, P. H. Richemond, E. Buchatskaya, C. Doersch, B. A. Pires, Z. D. Guo, M. G. Azar, B. Piot, K. Kavukcuoglu, R. Munos, and M. Valko. Bootstrap your own latent: A new approach to self-supervised learning. In NeurIPS, 2020.
  • [18] C. Guo, G. Pleiss, Y. Sun, and K. Q. Weinberger. On calibration of modern neural networks. In ICML, 2017.
  • [19] K. He, H. Fan, Y. Wu, S. Xie, and R. Girshick. Momentum contrast for unsupervised visual representation learning. In CVPR, 2020.
  • [20] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
  • [21] M. Heusel, H. Ramsauer, T. Unterthiner, B. Nessler, and S. Hochreiter. Gans trained by a two time-scale update rule converge to a local nash equilibrium. In NeurIPS, 2017.
  • [22] R. D. Hjelm, A. Fedorov, S. Lavoie-Marchildon, K. Grewal, P. Bachman, A. Trischler, and Y. Bengio. Learning deep representations by mutual information estimation and maximization. In ICLR, 2019.
  • [23] J. Ho, N. Kalchbrenner, D. Weissenborn, and T. Salimans. Axial attention in multidimensional transformers. In arXiv preprint arXiv:1912.12180, 2019.
  • [24] E. Jang, S. Gu, and B. Poole. Categorical reparameterization with gumbel-softmax. In ICLR, 2017.
  • [25] M. Joshi, D. Chen, Y. Liu, D. S. Weld, L. Zettlemoyer, and O. Levy. Spanbert: Improving pre-training by representing and predicting spans. In arXiv preprint arXiv:1907.10529, 2019.
  • [26] H. Jun, R. Child, M. Chen, J. Schulman, A. Ramesh, A. Radford, and I. Sutskever. Distribution augmentation for generative modeling. In ICML, 2020.
  • [27] L. Kong, C. de Masson d’Autume, W. Ling, L. Yu, Z. Dai, and D. Yogatama. A mutual information maximization perspective of language representation learning. In arXiv preprint arXiv:1910.08350, 2019.
  • [28] A. Krizhevsky, V. Nair, and G. Hinton. Cifar-10 (canadian institute for advanced research).
  • [29] Z. Lan, M. Chen, S. Goodman, K. Gimpel, P. Sharma, and R. Soricut. Albert: A lite bert for self-supervised learning of language representations. In arXiv preprint arXiv:1909.11942, 2019.
  • [30] J. A. Lasserre, C. M. Bishop, and T. P. Minka. Principled hybrids of generative and discriminative models. In CVPR, 2006.
  • [31] K. Lee, K. Lee, H. Lee, and J. Shin. A simple unified framework for detecting out-of-distribution samples and adversarial attacks. In NeurIPS, 2018.
  • [32] K. Lee, W. Xu, F. Fan, and Z. Tu. Wasserstein introspective neural networks. In CVPR, 2018.
  • [33] K. Lee, Y. Zhu, K. Sohn, C.-L. Li, J. Shin, and H. Lee. i-mix: A strategy for regularizing contrastive representation learning. In arXiv preprint arXiv:2010.08887, 2020.
  • [34] M. Lewis, Y. Liu, N. Goyal, M. Ghazvininejad, A. Mohamed, O. Levy, V. Stoyanov, and L. Zettlemoyer. Bart: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension. In arXiv preprint arXiv:1910.13461, 2019.
  • [35] H. Liu and P. Abbeel. Hybrid discriminative-generative training via contrastive learning. In arXiv preprint arXiv:2007.09070, 2020.
  • [36] X. Liu, F. Zhang, Z. Hou, Z. Wang, L. Mian, J. Zhang, and J. Tang. Self-supervised learning: Generative or contrastive. In arXiv preprint arXiv:2006.08218, 2020.
  • [37] Y. Liu, M. Ott, N. Goyal, J. Du, M. Joshi, D. Chen, O. Levy, M. Lewis, L. Zettlemoyer, and V. Stoyanov. Roberta: A robustly optimized bert pretraining approach. In arXiv preprint arXiv:1907.11692, 2019.
  • [38] I. Loshchilov and F. Hutter. Sgdr: Stochastic gradient descent with warm restarts. In ICLR, 2017.
  • [39] I. Loshchilov and F. Hutter. Decoupled weight decay regularization. In ICLR, 2019.
  • [40] R. Mackowiak, L. Ardizzone, U. Köthe, and C. Rother. Generative classifiers as a basis for trustworthy image classification. In arXiv preprint arXiv:2007.15036, 2020.
  • [41] I. Misra and L. van der Maaten. Self-supervised learning of pretext-invariant representations. In CVPR, 2020.
  • [42] A. van den Oord, N. Kalchbrenner, and K. Kavukcuoglu. Pixel recurrent neural networks. In ICML, 2016.
  • [43] A. Radford, J. Wu, R. Child, D. Luan, D. Amodei, and I. Sutskever. Language models are unsupervised multitask learners. 2019.
  • [44] A. Radford, K. Narasimhan, T. Salimans, and I. Sutskever. Improving language understanding by generative pre-training. 2018.
  • [45] R. Raina, Y. Shen, A. Mccallum, and A. Y. Ng. Classification with hybrid generative/discriminative models. In NeurIPS, 2004.
  • [46] A. Tamkin, M. Wu, and N. Goodman. Viewmaker networks: Learning views for unsupervised representation learning. In ICLR, 2021.
  • [47] Y. Tian, C. Sun, B. Poole, D. Krishnan, C. Schmid, and P. Isola. What makes for good views for contrastive learning. In arXiv preprint arXiv:2005.10243, 2020.
  • [48] A. van den Oord, Y. Li, and O. Vinyals. Representation learning with contrastive predictive coding. In arXiv preprint arXiv:1807.03748, 2018.
  • [49] L. van der Maaten and G. Hinton. Visualizing data using t-SNE. Journal of Machine Learning Research, 2008.
  • [50] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin. Attention is all you need. In NeurIPS, 2017.
  • [51] J. Winkens, R. Bunel, A. G. Roy, R. Stanforth, V. Natarajan, J. R. Ledsam, P. MacWilliams, P. Kohli, A. Karthikesalingam, S. Kohl, T. Cemgil, S. M. A. Eslami, and O. Ronneberger. Contrastive training for improved out-of-distribution detection. In arXiv preprint arXiv:2007.05566, 2020.
  • [52] J.-H. Xue and D. M. Titterington. On the generative-discriminative tradeoff approach: Interpretation, asymptotic efficiency and classification performance. Computational Statistics and Data Analysis, 54:438–451, 2010.
  • [53] S. Zagoruyko and N. Komodakis. Wide residual networks. In Proceedings of the British Machine Vision Conference (BMVC), 2016.

Appendix

# h: the representation from the previous transformer block, shape (B, H, W, C), where
# B is the batch size, (H, W) the spatial resolution, and C the embedding dimension.
# MSA: multi-head self-attention (parameters shared across the three calls in a block).
def attention_block(h):
    B, H, W, C = h.shape
    # (1) unmasked row-wise attention
    ha = MSA(h.reshape(B * H, W, C), causal_mask=False).reshape(B, H, W, C).transpose(1, 2)
    # (2) masked (causal) column-wise attention
    ha = MSA(ha.reshape(B * W, H, C), causal_mask=True).reshape(B, W, H, C).transpose(1, 2)
    # shift the column-wise context down by one row so that row i only receives context from rows < i
    ha = torch.cat([torch.zeros(B, 1, W, C, device=h.device), ha[:, :-1, :, :]], dim=1)
    # (3) masked (causal) row-wise attention over the combined context
    h = MSA((h + ha).reshape(B * H, W, C), causal_mask=True).reshape(B, H, W, C)
    return h
Algorithm 1 PyTorch-like Pseudo-code of an Axial Attention Block
Table 11: Comparison of different attention types in iGPT (10M) on MNIST and CIFAR10 with three random seeds (0, 1, and 2).
| Attention type | MNIST throughput (images/s) | MNIST test bpd | CIFAR10 throughput (images/s) | CIFAR10 test bpd |
|---|---|---|---|---|
| Dense | 303.0862 ± 0.3009 | 0.5538 ± 0.0032 | 202.2085 ± 0.2702 | 2.7676 ± 0.0016 |
| Axial | 383.6595 ± 0.5432 | 0.5629 ± 0.0030 | 355.0681 ± 0.1357 | 2.7775 ± 0.0021 |
Table 12: AUPRC results of a supervised OOD detection task, where the in-distribution datasets are ImageNet32, CIFAR10 and CIFAR100.
| Method | Model size | Pos. | ImageNet32 SVHN | ImageNet32 STL-10 | CIFAR10 SVHN | CIFAR10 STL-10 | CIFAR100 SVHN | CIFAR100 STL-10 |
|---|---|---|---|---|---|---|---|---|
| iGPT | 10M | half | 0.9924 | 0.8851 | 0.9436 ± 0.0048 | 0.6715 ± 0.0064 | 0.9158 ± 0.0080 | 0.6586 ± 0.0034 |
| iGPT | 10M | last | 0.9801 | 0.8775 | 0.9208 ± 0.0061 | 0.7085 ± 0.0018 | 0.8976 ± 0.0023 | 0.7079 ± 0.0021 |
| iGPT | 76M | half | 0.9916 | 0.9125 | 0.9829 ± 0.0012 | 0.7931 ± 0.0016 | 0.9491 ± 0.0009 | 0.7952 ± 0.0035 |
| iGPT | 76M | last | 0.9666 | 0.8800 | 0.9040 ± 0.0031 | 0.7318 ± 0.0002 | 0.8195 ± 0.0031 | 0.7494 ± 0.0029 |
| SimCLR | 5M | last | 0.7258 | 0.8522 | 0.6211 ± 0.1460 | 0.5636 ± 0.0052 | 0.6857 ± 0.1142 | 0.6754 ± 0.0066 |
| SimCLR | 10M | last | 0.7682 | 0.8691 | 0.5276 ± 0.0193 | 0.5620 ± 0.0059 | 0.5525 ± 0.0501 | 0.6687 ± 0.0258 |
| SimCLR | 76M | last | 0.7706 | 0.8646 | 0.6310 ± 0.0337 | 0.5743 ± 0.0105 | 0.5573 ± 0.0150 | 0.7065 ± 0.0064 |
| GCRL | (5+5)M | half | 0.9989 | 0.9061 | 0.9938 ± 0.0006 | 0.7026 ± 0.0038 | 0.9803 ± 0.0027 | 0.7018 ± 0.0025 |
| GCRL | (38+38)M | half | 0.9992 | 0.9090 | 0.9953 ± 0.0002 | 0.7600 ± 0.0030 | 0.9698 ± 0.0011 | 0.7820 ± 0.0015 |

We introduce a PyTorch-like pseudo-code for our implementation of the axial attention block and compare its performance to that of dense attention blocks. Algorithm 1 shows how the representation from the previous block is transformed in the axial attention block. We remark that all parameters across attention types in the same block are shared.

Table 11 compares our axial attention with dense attention in iGPT of 10M parameters in terms of inference throughput and test bpd. We train the models on MNIST and CIFAR10 for 30 and 200 epochs, respectively. We do not apply data augmentation to MNIST. For CIFAR10, we apply horizontal flipping with probability 0.5 and random cropping with pad size 4. We measure inference throughput on a V100 GPU with 100 repetitions, using batch sizes of 48 and 24 for MNIST and CIFAR10 with the dense attention block, and 128 and 64 with the axial attention block. Table 11 shows that our axial attention block performs comparably to the dense attention block, while improving inference throughput by 26% (MNIST) and 75% (CIFAR10). We remark that other sparse attention implementations can be freely substituted, since the main contribution of our work is to demonstrate the benefits of representations learned from a hybrid objective in an unsupervised fashion.

Table 12 presents AUPRC results of the supervised OOD detection task, showing the same trend as the AUROC results. Figures 6-8 show 500 generated samples with softmax temperature (denoted $\tau$) of 1.0, 0.99, and 0.98, where the same random seed is used. Regardless of the temperature, our model is able to generate images of high perceptual quality.

Figure 6: Generated images from GCRL of (38+38)M trained on ImageNet32 ($\tau=1.0$).
Figure 7: Generated images from GCRL of (38+38)M trained on ImageNet32 ($\tau=0.99$).
Figure 8: Generated images from GCRL of (38+38)M trained on ImageNet32 ($\tau=0.98$).