Set-based Meta-Interpolation for
Few-Task Meta-Learning
Abstract
Meta-learning approaches enable machine learning systems to adapt to new tasks given few examples by leveraging knowledge from related tasks. However, a large number of meta-training tasks are still required for generalization to unseen tasks during meta-testing, which introduces a critical bottleneck for real-world problems that come with only a few tasks, due to various reasons including the difficulty and cost of constructing tasks. Recently, several task augmentation methods have been proposed that use domain-specific knowledge to design augmentation techniques which densify the meta-training task distribution. However, such reliance on domain-specific knowledge renders these methods inapplicable to other domains. While Manifold Mixup based task augmentation methods are domain-agnostic, we empirically find them ineffective on non-image domains. To tackle these limitations, we propose a novel domain-agnostic task augmentation method, Meta-Interpolation, which utilizes expressive neural set functions to densify the meta-training task distribution via bilevel optimization. We empirically validate the efficacy of Meta-Interpolation on eight datasets spanning various domains such as image classification, molecule property prediction, text classification, and sound classification. Experimentally, we show that Meta-Interpolation consistently outperforms all the relevant baselines. Theoretically, we prove that task interpolation with the set function regularizes the meta-learner to improve generalization.
1 Introduction
The ability to learn a new task given only a few examples is crucial for artificial intelligence. Recently, meta-learning [39, 3] has emerged as a viable method to achieve this objective, enabling machine learning systems to quickly adapt to a new task by leveraging knowledge from related tasks seen during meta-training. Although existing meta-learning methods can efficiently adapt to new tasks with few data samples, a large set of meta-training tasks is still required to learn meta-knowledge that transfers to unseen tasks. For many real-world applications, such extensive collections of meta-training tasks may be unavailable. These scenarios give rise to the few-task meta-learning problem, where a meta-learner can easily memorize the meta-training tasks but fails to generalize well to unseen tasks. The few-task meta-learning problem usually results from the difficulty of task generation and data collection. For instance, in the medical domain, it is infeasible to collect a large amount of data to construct extensive meta-training tasks due to privacy concerns. Moreover, for natural language processing, it is not straightforward to split a dataset into tasks, and hence entire datasets are treated as tasks [30].
Several works have been proposed to tackle the few-task meta-learning problem using task augmentation techniques, such as clustering a dataset into multiple tasks [30], leveraging strong image augmentations such as vertical flipping to construct new classes [32], and employing Manifold Mixup [44] to densify the meta-training task distribution [49, 50]. However, the majority of these techniques require domain-specific knowledge to design the task augmentations and hence cannot be applied to other domains. While Manifold Mixup based methods [49, 50] are domain-agnostic, we empirically find them ineffective at mitigating meta-overfitting in few-task meta-learning, especially in non-image domains such as the chemical and text domains, and find that they sometimes degrade generalization performance.
In this work, we focus solely on domain-agnostic task augmentation methods that can densify the meta-training task distribution to prevent meta-overfitting and improve generalization at meta-testing for few-task meta-learning. To tackle the limitations already discussed, we propose a novel domain-agnostic task augmentation method for metric based meta-learning models. Our method, Meta-Interpolation, utilizes expressive neural set functions to interpolate two tasks and the set functions are trained with bilevel optimization so that a meta-learner trained on the interpolated tasks generalizes to tasks in the meta-validation set. As a consequence of end-to-end training, the learned augmentation strategy is tailored to each specific domain without the need for specialized domain knowledge.

For example, for $N$-way classification, we sample two tasks consisting of support and query sets and assign a new class $c$ to each pair of classes $(\sigma_i(c), \sigma_j(c))$ for $c = 1, \ldots, N$, where $\sigma_i, \sigma_j$ are permutations on $\{1, \ldots, N\}$, as depicted in Figure 1a. Hidden representations of the support sets with classes $\sigma_i(c)$ and $\sigma_j(c)$ are then transformed into a single support set using a set function that maps a set of two vectors to a single vector. We refer to the output of the set function as the interpolated support set, which is used to compute class prototypes. As shown in Figure 1b, the interpolated support set is paired with a query set (query set 1 in Figure 1a), randomly selected from the two tasks, to obtain a new task. Lastly, we optimize the set function so that a meta-learner trained on the augmented task minimizes the loss on the meta-validation tasks, as illustrated in Figure 1c.
To verify the efficacy of our method, we empirically show that it significantly improves the performance of Prototypical Networks [40] on the few-task meta-learning problem across multiple domains. Our method outperforms the relevant baselines on eight few-task meta-learning benchmark datasets spanning image classification, chemical property prediction, text classification, and sound classification. Furthermore, our theoretical analysis shows that our task interpolation method with the set function regularizes the meta-learner and improves generalization performance.
Our contribution is threefold:

• We propose a novel domain-agnostic task augmentation method, Meta-Interpolation, which leverages expressive set functions to densify the meta-training task distribution for the few-task meta-learning problem.

• We theoretically analyze our model and show that it regularizes the meta-learner for better generalization.

• Through extensive experiments, we show that Meta-Interpolation significantly improves the performance of the Prototypical Network on the few-task meta-learning problem across various domains such as image, text, chemical molecule, and sound classification.
2 Related Work
Meta-Learning
The two mainstream approaches to meta-learning are gradient based [10, 33, 14, 24, 12, 36, 37] and metric based [45, 40, 42, 29, 26, 6, 38] meta-learning. The former formulates meta-knowledge as meta-parameters, such as the initial model parameters, and performs bilevel optimization to estimate the meta-parameters so that a meta-learner can generalize to unseen tasks within a few gradient steps. The latter learns an embedding space where classification is performed by measuring the distance between a query and a set of class prototypes. In this work, we focus on metric based meta-learning with a small number of meta-training tasks, i.e., few-task meta-learning. We propose a novel task augmentation method that densifies the meta-training task distribution and mitigates the overfitting caused by the small number of meta-training tasks, for better generalization to unseen tasks.
Task Augmentation for Few-Task Meta-learning
Several methods have been proposed to augment the number of meta-training tasks to mitigate overfitting in the context of few-task meta-learning. Ni et al. [32] apply strong data augmentations such as vertical flips to images to create new classes. For text classification, Murty et al. [30] split meta-training tasks into latent reasoning categories by clustering data with a pretrained language model. However, these approaches require domain-specific knowledge to design the augmentations, and hence the resulting techniques are inapplicable to domains without well-defined data augmentations or pretrained models. To tackle this limitation, Manifold Mixup based task augmentations have also been proposed. MetaMix [49] interpolates support and query sets with Manifold Mixup [44] to construct a new query set. MLTI [50] performs Manifold Mixup [44] on support and query sets from two tasks for task augmentation, as sketched below. Although these methods are domain-agnostic, we empirically find that they are not effective in some domains and can degrade generalization performance. In contrast, we propose to train an expressive neural set function to interpolate two tasks with bilevel optimization, finding an augmentation strategy tailored to each domain.
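To make the contrast concrete, the following minimal sketch shows Manifold Mixup-style interpolation of hidden representations as used by MetaMix and MLTI; the function name and tensor shapes are illustrative assumptions, not the baselines' exact implementations.

```python
import torch

def mixup_interpolate(h_i, h_j, alpha=0.5):
    """Mix two batches of hidden representations with a Beta-distributed coefficient."""
    lam = torch.distributions.Beta(alpha, alpha).sample()
    # The result always lies on the line segment between h_i and h_j,
    # i.e., within the simplex spanned by the given representations.
    return lam * h_i + (1.0 - lam) * h_j
```

Because the mixed representation is restricted to convex combinations, the diversity of the generated tasks is limited, which motivates replacing this fixed mixing rule with a learned set function.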
3 Method
Preliminaries
In meta-learning, we are given a finite set of tasks $\{\tau_1, \ldots, \tau_T\}$, which are i.i.d. samples from an unknown task distribution $p(\tau)$. Each task $\tau_t$ consists of a support set $\mathcal{S}_t$ and a query set $\mathcal{Q}_t$, whose elements $(x, y)$ denote a data point and its corresponding label, respectively. Given a predictive model $f_\theta$ with $L$ layers, we want to estimate the parameters $\theta$ that minimize the meta-training loss and generalize to query sets $\mathcal{Q}$ sampled from an unseen task $\tau$ using its support set $\mathcal{S}$. In this work, we primarily focus on metric based meta-learning methods rather than gradient based meta-learning methods, due to their efficiency and empirically higher performance over the gradient based methods on the tasks we consider.
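As a concrete illustration of this episodic setup, the following sketch samples an $N$-way $K$-shot task; the dictionary-based dataset layout is an assumption for illustration rather than part of the paper's formulation.

```python
import random

def sample_episode(data_by_class, n_way=5, k_shot=5, n_query=15):
    """Sample one task tau = (support set S, query set Q) for N-way K-shot learning."""
    classes = random.sample(sorted(data_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        examples = random.sample(data_by_class[cls], k_shot + n_query)
        support += [(x, label) for x in examples[:k_shot]]   # K shots per class
        query += [(x, label) for x in examples[k_shot:]]     # held-out queries
    return support, query
```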
Problem Statement
In this work, we focus solely on few-task meta-learning, where the number of meta-training tasks drawn from the meta-training distribution is extremely small and the goal of the meta-learner is to learn meta-knowledge from such limited tasks that can be transferred to unseen tasks during meta-testing. The key challenges are preventing the meta-learner from overfitting on the meta-training tasks and generalizing to unseen tasks drawn from the meta-test set.
Metric Based Meta-Learning
The goal of metric based meta-learning is to learn an embedding space induced by $f_\theta$, in which we perform classification by computing distances between data points and class prototypes. We adopt the Prototypical Network (ProtoNet) [40] for $f_\theta$, where the output layer is the identity function. Specifically, for each task with its corresponding support set $\mathcal{S}$ and query set $\mathcal{Q}$, we compute the prototype $c_k$ of class $k$ as the average of the hidden representations of the support samples belonging to that class:

$$c_k = \frac{1}{N_k} \sum_{(x_s, y_s) \in \mathcal{S}} \mathbb{1}(y_s = k)\, f_\theta(x_s) \qquad (1)$$

where $N_k$ denotes the number of support instances belonging to class $k$. Given a metric $d$, we compute the probability of a query point $x_q$ being assigned to class $k$ by measuring the distance between its hidden representation $f_\theta(x_q)$ and the class prototype $c_k$, followed by a softmax. With this class probability, we compute the cross-entropy loss for ProtoNet as follows:

$$\mathcal{L}(\theta; \mathcal{S}, \mathcal{Q}) = -\sum_{(x_q, y_q) \in \mathcal{Q}} \sum_{k} \mathbb{1}(y_q = k) \log \frac{\exp(-d(f_\theta(x_q), c_k))}{\sum_{k'} \exp(-d(f_\theta(x_q), c_{k'}))} \qquad (2)$$

where $\mathbb{1}(\cdot)$ is an indicator function. At meta-test time, a test query $x_q$ is assigned the label of its nearest class prototype, i.e., $\hat{y}_q = \arg\min_k d(f_\theta(x_q), c_k)$. However, optimizing $\mathcal{L}$ w.r.t. $\theta$ is prone to overfitting since we are given only a small number of meta-training tasks. The meta-learner tends to memorize the meta-training tasks, which limits its generalization to new tasks at meta-test time [51, 35].
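A minimal PyTorch sketch of Eqs. 1 and 2 follows, assuming the support and query points have already been embedded by $f_\theta$; the tensor shapes and function name are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def protonet_loss(support_emb, support_labels, query_emb, query_labels, n_way):
    """ProtoNet loss: prototypes (Eq. 1) + softmax over negative distances (Eq. 2)."""
    # Eq. 1: each prototype is the mean embedding of that class's support samples.
    prototypes = torch.stack(
        [support_emb[support_labels == k].mean(dim=0) for k in range(n_way)]
    )                                                  # (n_way, d)
    # Squared Euclidean distance from every query to every prototype.
    dists = torch.cdist(query_emb, prototypes) ** 2    # (n_query, n_way)
    # Eq. 2: cross-entropy on logits given by negative distances.
    return F.cross_entropy(-dists, query_labels)
```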
Meta-Interpolation for Task Augmentation
In order to tackle the meta-overfitting problem with a small number of tasks, we propose a novel data-driven, domain-agnostic task augmentation framework that enables a meta-learner trained on few tasks to generalize to unseen few-shot classification tasks. Several methods have been proposed to densify the meta-training task distribution, but they heavily depend on image augmentations [32] or a pretrained language model [30]. Although Manifold Mixup based methods [49, 50] are domain-agnostic, we empirically find them ineffective in certain domains. Instead, we optimize expressive neural set functions to augment tasks so as to enhance the generalization of a meta-learner to unseen tasks. As a consequence of end-to-end training, the learned augmentation strategy is tailored to each domain.
Specifically, let $g_\lambda$ be a set function which maps a set of $d$-dimensional vectors to a single $d$-dimensional vector. In all our experiments, we use the Set Transformer [23] for $g_\lambda$. Given a pair of tasks $\tau_i$ and $\tau_j$ with corresponding support and query sets for $N$-way classification, we construct $N$ new classes by choosing pairs of classes from the two tasks. We sample permutations $\sigma_i$ and $\sigma_j$ on $\{1, \ldots, N\}$ for tasks $\tau_i$ and $\tau_j$, respectively, and assign class $c$ to the pair $(\sigma_i(c), \sigma_j(c))$ for $c = 1, \ldots, N$. For the newly assigned class $c$, we pair two support instances from classes $\sigma_i(c)$ and $\sigma_j(c)$ and interpolate their hidden representations with the set function $g_\lambda$. The prototype for class $c$ is computed using the outputs of $g_\lambda$ as follows:

$$\tilde{c}_c = \frac{1}{K} \sum_{s=1}^{K} f^{(l+1:L)}_\theta\Big( g_\lambda\big(\{\, f^{(1:l)}_\theta(x^{(i)}_s),\; f^{(1:l)}_\theta(x^{(j)}_s) \,\}\big) \Big) \qquad (3)$$

where $f^{(1:l)}_\theta$ and $f^{(l+1:L)}_\theta$ denote the first $l$ layers and the remaining layers of the model, $x^{(i)}_s$ and $x^{(j)}_s$ are the $s$-th support instances from classes $\sigma_i(c)$ and $\sigma_j(c)$, and $K$ is the number of shots. We define $\tilde{C} = \{\tilde{c}_1, \ldots, \tilde{c}_N\}$ to be the set of all interpolated prototypes. For queries, we do not perform any interpolation. Instead, we use the query set $\mathcal{Q}$ of one of the two tasks and compute its hidden representation $f_\theta(x_q)$. We then measure the distance between the query representation and the interpolated prototype of class $c$ to compute the loss as follows:

$$\mathcal{L}(\theta, \lambda; \tilde{C}, \mathcal{Q}) = -\sum_{(x_q, y_q) \in \mathcal{Q}} \sum_{c=1}^{N} \mathbb{1}(\tilde{y}_q = c) \log \frac{\exp(-d(f_\theta(x_q), \tilde{c}_c))}{\sum_{c'=1}^{N} \exp(-d(f_\theta(x_q), \tilde{c}_{c'}))} \qquad (4)$$

where $\tilde{y}_q$ is the query label relabeled according to the sampled permutation. The intuition behind interpolating only support sets is to construct harder tasks that a meta-learner cannot memorize. Alternatively, we could interpolate only query sets; however, this is computationally more expensive since query sets are usually larger than support sets. In Section 5, we empirically show that interpolating either support or query sets achieves higher training loss than interpolating both, which supports this intuition. Lastly, we also evaluate the loss in Eq. 2 on the original task by passing the corresponding support and query sets to $f_\theta$. This additional forward pass enriches the diversity of the augmented tasks and makes meta-training consistent with meta-testing, since we do not perform any task augmentation at the meta-testing stage.
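To make the interpolation step concrete, here is a minimal PyTorch sketch of Eq. 3 under assumed tensor shapes; the helper name `set_fn` and the shape conventions are illustrative, and the remaining encoder layers $f^{(l+1:L)}_\theta$ are omitted for brevity.

```python
import torch

def interpolate_support(h_i, h_j, set_fn):
    """Build interpolated prototypes from two tasks' hidden support features.

    h_i, h_j: (N, K, d) hidden features of the two support sets, with the class
    axis already permuted according to sigma_i and sigma_j.
    set_fn:   the set function g_lambda mapping a set of two d-vectors to one d-vector.
    """
    n_way, k_shot, d = h_i.shape
    pairs = torch.stack([h_i, h_j], dim=2)        # (N, K, 2, d): sets of cardinality two
    mixed = set_fn(pairs.reshape(-1, 2, d))       # (N*K, d): interpolated support features
    mixed = mixed.reshape(n_way, k_shot, d)
    # Average per class to obtain the interpolated prototypes of Eq. 3.
    return mixed.mean(dim=1)                      # (N, d)
```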
Optimization
Since jointly optimizing $\theta$, the parameters of the ProtoNet, and $\lambda$, the parameters of the set function $g_\lambda$, with few tasks is prone to overfitting, we consider $\lambda$ as hyperparameters and perform bilevel optimization with meta-training and meta-validation tasks as follows:

$$\lambda^* = \arg\min_{\lambda}\; \mathbb{E}_{\tau_v}\left[ \mathcal{L}(\theta^*(\lambda); \mathcal{S}_v, \mathcal{Q}_v) \right] \qquad (5)$$

$$\text{s.t.}\quad \theta^*(\lambda) = \arg\min_{\theta}\; \mathbb{E}_{\tau_t, \tilde{\tau}}\left[ \mathcal{L}(\theta, \lambda; \tilde{C}, \mathcal{Q}) + \mathcal{L}(\theta; \mathcal{S}_t, \mathcal{Q}_t) \right] \qquad (6)$$

where $\tau_t$, $\tau_v$, and $\tilde{\tau}$ denote a meta-training, meta-validation, and interpolated task, respectively. Since computing the exact gradient w.r.t. $\lambda$ is intractable due to the long inner optimization in Eq. 6, we leverage the implicit function theorem to approximate the gradient, following Lorraine et al. [27]. Moreover, we alternately update $\theta$ and $\lambda$ for computational efficiency, as described in Algorithms 1 and 2.
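A minimal sketch of the Neumann-series hypergradient approximation in the spirit of Lorraine et al. [27] is shown below; it assumes PyTorch, that `train_loss` depends on both $\theta$ and $\lambda$, and illustrative values for the step size and truncation length.

```python
import torch

def neumann_hypergrad(val_loss, train_loss, theta, lam, alpha=0.1, num_iter=5):
    """Approximate dL_val/d(lambda) with the implicit function theorem,
    inverting the Hessian via a truncated Neumann series."""
    v = torch.autograd.grad(val_loss, theta, retain_graph=True)
    d_train = torch.autograd.grad(train_loss, theta, create_graph=True)
    p = [vi.clone() for vi in v]                  # accumulates sum_i (I - alpha*H)^i v
    for _ in range(num_iter):
        # Hessian-vector product H v via a second backward pass.
        hvp = torch.autograd.grad(d_train, theta, grad_outputs=v, retain_graph=True)
        v = [vi - alpha * hi for vi, hi in zip(v, hvp)]
        p = [pi + vi for pi, vi in zip(p, v)]
    # Mixed second derivative: p^T d^2 L_train / (d theta d lambda).
    mixed = torch.autograd.grad(d_train, lam, grad_outputs=p)
    return [-alpha * m for m in mixed]            # approximate hypergradient
```

The truncation length corresponds to the "iterations for Neumann series" hyperparameter reported in Tables 17 and 18.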
4 Theoretical Analysis
In this section, we theoretically investigate the behavior of the Set Transformer and how it induces a distribution dependent regularization, which is then shown to have the ability to control the Rademacher complexity for better generalization. To analyze the behavior of the Set Transformer, we first define it concretely with the attention mechanism . Given , define and . Then, for any , the output of the Set Transformer is defined as follows:
(7) |
where , , , (for ), and . denote query, key, and value for the attention mechanism for , respectively. Here, , , , , and . Let .
Our analysis will show the importance of the following quantity of the Set Transformer in our method:
(8) |
where , , with and . For a matrix denotes the entry for -th row and -th column of the matrix .
We now introduce additional notation and the problem setting to present our results. Define , , , and , where . We also define the empirical measure over the index with the Dirac measures . Let be the uniform distribution over . For any function and point in its domain, we define the -th order tensor by . For example, and are the gradient and the Hessian of evaluated at . For any -th order tensor , we define the vectorization of the tensor by . For a vector , we define , where represents the Kronecker product. We assume that for all , where . This assumption is satisfied, for example, if represents a deep neural network with ReLU activations. It is also satisfied in the simpler special case considered in the proposition below.
The following theorem shows that the interpolated loss is approximately the original loss plus regularization terms on the directional derivatives of the loss in the direction determined by the Set Transformer:
Theorem 1.
For any , if is -times differentiable for all , then the -th order approximation of is given by where and
To illustrate the effect of this data-dependent regularization, we now consider the following special case that is used by Yao et al. [50] for ProtoNet: where , , and denotes dot product. Define , where and . Note that if is no worse than the random guess; e.g., for all . We write for any positive semi-definite matrix . In this special case, we consider that is balanced: i.e., for all . This is used to prevent the Set Transformer from over-fitting to the training sets; i.e., in such simple special cases, the Set Transformer without any restriction is too expressive relative to the rest of the model (and may memorize the training sets without using the rest of the model). The following proposition shows that the additional regularization term is simplified to the form of in this special case:
Proposition 1.
In the special case explained above, the second approximation of is given by where
In the above regularization form, we have an implicit regularization effect on where . The following theorem shows that this implicit regularization can reduce the Rademacher complexity for better generalization:
Proposition 2.
Let with . Then, .
All the proofs are presented in Appendix A.
5 Experiments
We now demonstrate the efficacy of our set-based task augmentation method on multiple few-task benchmark datasets and compare against the relevant baselines.
Datasets
We perform classification on eight datasets to validate our method. (1), (2), & (3) Metabolism [17], NCI [31], and Tox21 [18]: these are binary classification datasets for predicting the properties of chemical molecules. For Metabolism, we use three subdatasets for meta-training, meta-validation, and meta-testing, respectively. For NCI, we use four subdatasets for meta-training, two for meta-validation, and the remaining three for meta-testing. For Tox21, we use six subdatasets for meta-training, two for meta-validation, and four for meta-testing. (4) GLUE-SciTail [30]: it consists of four natural language inference datasets where we predict whether a hypothesis sentence contradicts a premise sentence. We use MNLI [47] and QNLI [46] for meta-training, SNLI [5] and RTE [46] for meta-validation, and SciTail [20] for meta-testing. (5) ESC-50 [34]: this is an environmental sound recognition dataset. We make a 20/15/15 split out of 50 base classes for meta-training/validation/testing and sample 5 classes from each split to construct a 5-way classification task. (6) Rainbow MNIST (RMNIST) [11]: this is a 10-way classification dataset. Following Yao et al. [50], we construct each task by applying compositions of image transformations to the images of the MNIST [9] dataset. (7) & (8) Mini-ImageNet-S [45] and CIFAR100-FS [22]: these are 5-way classification datasets where we choose 12/16/20 classes out of 100 base classes for meta-training/validation/testing, respectively, and sample 5 classes from each split to construct a task.
Note that Metabolism, Tox21, NCI, GLUE-SciTail, and ESC-50 are real-world few-task meta-learning datasets with a very small number of tasks. For Mini-ImageNet-S and CIFAR100-FS, following Yao et al. [50], we artificially reduce the number of tasks from the original datasets for few-task meta-learning. RMNIST is synthetically generated by applying augmentations to MNIST.
Implementation Details
For RMNIST, Mini-ImageNet-S, and CIFAR100-FS, we use four convolutional blocks, with each block consisting of a convolution, ReLU, batch normalization [19], and max pooling. For Metabolism, Tox21, and NCI, we convert the chemical molecules into SMILES format and extract a 1024-bit fingerprint feature using RDKit [15], where each bit captures a fragment of the molecule. We use two blocks of affine transformation, batch normalization, and Leaky ReLU, followed by an affine transformation for the last layer. For the GLUE-SciTail dataset, we stack 3 fully connected layers with ReLU on top of the pretrained language model ELECTRA [8]. For the ESC-50 dataset, we pass the raw audio signal to the pretrained VGGish [16] feature extractor to obtain an embedding vector, which we use as input to a classifier identical to the one used for Metabolism, Tox21, and NCI. For our Meta-Interpolation, we use the Set Transformer [23] for the set function $g_\lambda$.
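A sketch of the convolutional encoder for the image datasets follows, assuming a standard four-block architecture; the kernel size and channel widths are assumptions, since the architecture tables in the appendix omit the exact values.

```python
import torch.nn as nn

def conv_block(in_ch, out_ch):
    # Order follows the text: convolution, ReLU, batch normalization, max pooling.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_ch),
        nn.MaxPool2d(2),
    )

class ConvEncoder(nn.Module):
    """Four-block convolutional encoder producing a flat embedding."""
    def __init__(self, in_ch=3, hidden=32):
        super().__init__()
        self.blocks = nn.Sequential(
            *[conv_block(in_ch if i == 0 else hidden, hidden) for i in range(4)]
        )

    def forward(self, x):
        return self.blocks(x).flatten(start_dim=1)
```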
Baselines
We compare our method against the following domain-agnostic baselines.
1. ProtoNet [40]: The standard Prototypical Network trained without any task augmentation.

2. MetaReg [2]: ProtoNet with regularization, where element-wise coefficients are learned with bilevel optimization.

3. MetaMix [49]: ProtoNet trained with support sets and mixed query sets, where we interpolate one instance from the support sets and the other from the original query sets with Manifold Mixup.

4. MLTI [50]: ProtoNet trained with Manifold Mixup based task augmentation. We sample two tasks and interpolate their query sets and support sets with Manifold Mixup, respectively.

5. ProtoNet+ST: ProtoNet and a Set Transformer ($g_\lambda$) trained with bilevel optimization but without the task augmentation term in Eq. 6.

6. Meta-Interpolation: Our full model, which learns to interpolate support sets from two tasks using bilevel optimization and trains the ProtoNet with both the original and interpolated tasks.
Table 1: Accuracy on the chemical (Metabolism, Tox21, NCI), text (GLUE-SciTail), and speech (ESC-50) domains.

| Method | Metabolism (5-shot) | Tox21 (5-shot) | NCI (5-shot) | GLUE-SciTail (4-shot) | ESC-50 (5-shot) |
|---|---|---|---|---|---|
| ProtoNet | | | | | |
| MetaReg | | | | | |
| MetaMix | | | | | |
| MLTI | | | | | |
| ProtoNet+ST | | | | | |
| Meta-Interpolation | | | | | |
Results
As shown in Table 1, Meta-Interpolation consistently outperforms all the domain-agnostic task augmentation and regularization baselines on non-image domains. Notably, it significantly improves the performance on ESC-50, which is a challenging dataset that only contains 40 examples per class. In addition, Meta-Interpolation effectively tackles the Metabolism and GLUE-SciTail datasets, which have an extremely small number of meta-training tasks: three and two meta-training tasks, respectively. In contrast, the baselines do not achieve consistent improvements across all the domains and tasks considered. For example, MetaReg is effective on the sound domain (ESC-50) and Metabolism, but does not work on the chemical (Tox21 and NCI) and text (GLUE-SciTail) domains. Similarly, MetaMix and MLTI achieve performance improvements on some datasets but degrade the test accuracy on others. This empirical evidence supports the hypothesis that the optimal task augmentation strategy varies across domains and justifies the motivation for Meta-Interpolation, which learns augmentation strategies tailored to each domain.
Table 2: Accuracy on the image domain datasets.

| Method | RMNIST (1-shot) | Mini-ImageNet-S (1-shot) | Mini-ImageNet-S (5-shot) | CIFAR-100-FS (1-shot) | CIFAR-100-FS (5-shot) |
|---|---|---|---|---|---|
| ProtoNet | | | | | |
| MetaReg | | | | | |
| MetaMix | | | | | |
| MLTI | | | | | |
| ProtoNet+ST | | | | | |
| Meta-Interpolation | | | | | |
We provide additional experimental results on the image domain in Table 2. Again, Meta-Interpolation outperforms all the baselines. In contrast to the previous experiments, MetaReg hurts the generalization performance on all the image datasets except RMNIST. Note that the Manifold Mixup based augmentation methods, MetaMix and MLTI, only marginally improve the generalization performance for 1-shot classification on Mini-ImageNet-S and CIFAR-100-FS, although they boost the accuracy in the 5-shot experiments. This suggests that different task augmentation strategies are required for the 1-shot and 5-shot settings on the same dataset. Meta-Interpolation, on the other hand, learns task augmentation strategies tailored to each task and dataset and consistently improves the performance of the vanilla ProtoNet across all the experiments on the image datasets.





Moreover, we plot the meta-training and meta-validation loss on the RMNIST and NCI datasets in Figure 2. Meta-Interpolation obtains higher training loss but much lower validation loss than the others on both datasets. This implies that interpolating only support sets constructs harder tasks that a meta-learner cannot memorize and regularizes the meta-learner for better generalization. ProtoNet overfits to the meta-training tasks on both datasets. MLTI mitigates the overfitting issue on RMNIST but is not effective on the NCI dataset, where it shows high validation loss in Figure 2(d). On the other hand, MetaMix, which constructs a new query set by interpolating a support and query set with Manifold Mixup, generates overly difficult tasks that cause underfitting on RMNIST, where the training loss is not properly minimized in Figure 2(a). However, this augmentation strategy is effective at tackling meta-overfitting on NCI, where the validation loss is lower than that of ProtoNet. The loss curves of ProtoNet+ST support the claim that merely increasing the model size and using bilevel optimization cannot handle the few-task meta-learning problem: it shows higher validation loss on both RMNIST and NCI, as presented in Figures 2(b) and 2(d). Similarly, MetaReg, which learns coefficients for regularization, fails to prevent meta-overfitting on both datasets.
Lastly, we empirically show that the performance gains mostly come from the task augmentation with Meta-Interpolation, rather than from bilevel optimization or the extra parameters introduced by the set function. As shown in Tables 1 and 2, ProtoNet+ST, which is Meta-Interpolation trained without any task augmentation, significantly degrades the performance of ProtoNet on Mini-ImageNet-S and CIFAR-100-FS. On the other datasets, ProtoNet+ST obtains marginal improvements or largely underperforms the other baselines. Thus, the task augmentation strategy of interpolating two support sets with the set function is indeed crucial for tackling the few-task meta-learning problem.
![[Uncaptioned image]](https://cdn.awesomepapers.org/papers/a514bbe3-3f4f-4a83-9a96-0f6618c56a78/x7.png)




Table 3: Ablation of Meta-Interpolation components on ESC-50.

| Model | Accuracy |
|---|---|
| Meta-Interpolation | |
| w/o Interpolation | |
| w/o Bilevel | |
| w/o original-task loss | |
Table 4: Comparison of set functions on ESC-50.

| Set Function | Accuracy |
|---|---|
| ProtoNet | |
| DeepSets | |
| Set Transformer | |
Table 5: Comparison of interpolation strategies on ESC-50.

| Interpolation Strategy | Accuracy |
|---|---|
| Query + Support | |
| Query | |
| Support + Noise | |
| Support | |
Ablation Study
We further perform ablation studies to verify the effectiveness of each component of Meta-Interpolation. In Table 3, we show experimental results on the ESC-50 dataset obtained by removing various components of our model. Firstly, we train our model without any task interpolation but keep the set function $g_\lambda$, denoted as w/o Interpolation. The model without task interpolation significantly underperforms the full model, Meta-Interpolation, which shows that the improvements come from task interpolation rather than from the extra parameters introduced by the set encoding layer. Moreover, bilevel optimization is shown to be effective for estimating $\lambda$, the parameters of the set function: jointly training the ProtoNet and the set function without bilevel optimization, denoted as w/o Bilevel, largely degrades the test accuracy. Lastly, we remove the original-task loss from the inner optimization in Eq. 6, denoted as w/o original-task loss. This hurts the generalization performance since it decreases the diversity of tasks and causes inconsistency between meta-training and meta-testing, as we do not perform any interpolation of support sets at meta-test time.
We also explore an alternative set function, DeepSets [52], on the ESC-50 dataset to show the general effectiveness of our method regardless of the set encoding scheme. In Table 4, Meta-Interpolation with DeepSets improves the generalization performance of the Prototypical Network, and the model with the Set Transformer further boosts the performance as a consequence of modeling pairwise and higher-order interactions among the set elements via the attention mechanism.
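For reference, a minimal sketch of a DeepSets-style set function is shown below; the hidden sizes are illustrative assumptions. Unlike the Set Transformer, elements only interact through a sum, which limits the pairwise interactions that the attention mechanism provides.

```python
import torch.nn as nn

class DeepSets(nn.Module):
    """Permutation-invariant set function: encode, sum over the set, decode."""
    def __init__(self, dim=500, hidden=1024):
        super().__init__()
        self.phi = nn.Sequential(nn.Linear(dim, hidden), nn.ReLU())  # per-element encoder
        self.rho = nn.Linear(hidden, dim)                            # decoder

    def forward(self, x):          # x: (batch, set_size, dim)
        return self.rho(self.phi(x).sum(dim=1))
```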
Table 6: Interpolation strategies on RMNIST.

| Interpolation Strategy | Accuracy |
|---|---|
| Support + Noise | |
| Support | |
Lastly, we empirically validate our interpolation strategy that mixes only support sets. We compare our method against various interpolation strategies, including one that mixes a support set with zero-mean, unit-variance Gaussian noise. In Table 5, we empirically show that the interpolation strategy that mixes only support sets outperforms the other mixing strategies. Note that although interpolating a support set with Gaussian noise works well on the ESC-50 dataset, we find that it significantly degrades the performance of ProtoNet on RMNIST, as shown in Table 6, which justifies our approach of mixing two support sets.
Visualization
In Figure 3, we present t-SNE [43] visualizations of the original and interpolated tasks with MLTI and Meta-Interpolation, respectively. Following Yao et al. [50], we sample three original tasks from the NCI and ESC-50 datasets, where each task is a two-way five-shot and a five-way five-shot classification problem, respectively. The tasks are interpolated with MLTI or Meta-Interpolation to construct 300 additional tasks, and each task is represented as the set of all its class prototypes. To visualize the prototypes, we first perform Principal Component Analysis (PCA) [13] to reduce the dimension of each prototype. The first 50 principal components are then used to compute the t-SNE visualizations. As shown in Figures 3(b) and 3(d), Meta-Interpolation successfully learns an expressive neural set function that densifies the task distribution. The task augmentations with MLTI, however, do not cover a wide embedding space, as shown in Figures 3(a) and 3(c), since the mixup strategy can generate tasks only within the simplex defined by the given set of tasks.
Limitation
Although we have shown promising results in various domains, our method requires extra computation for bilevel optimization to estimate $\lambda$, the parameters of the set function $g_\lambda$, which makes it challenging to apply our method to gradient based meta-learning methods such as MAML. Moreover, our interpolation is limited to classification problems, and it is not straightforward to apply it to regression tasks. Reducing the computational cost of bilevel optimization and extending our framework to regression are important directions for future work.
6 Conclusion
We proposed a novel domain-agnostic task augmentation method, Meta-Interpolation, to tackle the meta-overfitting problem in few-task meta-learning. Specifically, we leveraged expressive neural set functions to interpolate a given set of tasks and trained the interpolating function with bilevel optimization, so that the meta-learner trained with the augmented tasks generalizes to meta-validation tasks. Since the set function is optimized to minimize the loss on the validation tasks, it allows us to tailor the task augmentation strategy to each specific domain. We empirically validated the efficacy of our proposed method on various domains, including image classification, chemical property prediction, and text and sound classification, showing that Meta-Interpolation achieves consistent improvements across all domains. This is in stark contrast to the baselines, which improve generalization in certain domains but degrade performance in others. Furthermore, our theoretical analysis sheds light on how Meta-Interpolation regularizes the meta-learner and improves its generalization performance. Lastly, we discussed the limitations of our method.
Acknowledgments and Disclosure of Funding
This work was supported by Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government(MSIT) (No.2019-0-00075, Artificial Intelligence Graduate School Program(KAIST)), the Engineering Research Center Program through the National Research Foundation of Korea (NRF) funded by the Korean Government MSIT (NRF-2018R1A5A1059921), Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government(MSIT) (No. 2021-0-02068, Artificial Intelligence Innovation Hub), the National Research Foundation of Korea (NRF) funded by the Ministry of Education (NRF-2021R1F1A1061655), Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government(MSIT) (No.2022-0-00713), Samsung Electronics (IO201214-08145-01), and Google Research Grant. It was also results of a study on the “HPC Support” Project, supported by the ‘Ministry of Science and ICT’ and NIPA.
References
- Ba et al. [2016] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
- Balaji et al. [2018] Yogesh Balaji, Swami Sankaranarayanan, and Rama Chellappa. MetaReg: Towards domain generalization using meta-regularization. Advances in neural information processing systems, 31, 2018.
- Bengio et al. [1991] Yoshua Bengio, Samy Bengio, and Jocelyn Cloutier. Learning a synaptic learning rule. In IJCNN-91-Seattle International Joint Conference on Neural Networks, volume 2, pages 969–vol. IEEE, 1991.
- Blair [1992] Charles Blair. The computational complexity of multi-level linear programs. Annals of Operations Research, 34, 1992.
- Bowman et al. [2015] Samuel Bowman, Gabor Angeli, Christopher Potts, and Christopher D Manning. A large annotated corpus for learning natural language inference. In Proceedings of the 2015 Conference on Empirical Methods in Natural Language Processing, pages 632–642, 2015.
- Cao et al. [2021] Kaidi Cao, Maria Brbic, and Jure Leskovec. Concept learners for few-shot learning. In International Conference on Learning Representations, 2021.
- Chen et al. [2011] Richard Li-Yang Chen, Amy Cohn, and Ali Pinar. An implicit optimization approach for survivable network design. In 2011 IEEE network science workshop, pages 180–187. IEEE, 2011.
- Clark et al. [2020] Kevin Clark, Minh-Thang Luong, Quoc V. Le, and Christopher D. Manning. ELECTRA: Pre-training text encoders as discriminators rather than generators. In International Conference on Learning Representations, 2020.
- Deng [2012] Li Deng. The mnist database of handwritten digit images for machine learning research [best of the web]. IEEE signal processing magazine, 29(6):141–142, 2012.
- Finn et al. [2017] Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In International conference on machine learning, pages 1126–1135. PMLR, 2017.
- Finn et al. [2019] Chelsea Finn, Aravind Rajeswaran, Sham Kakade, and Sergey Levine. Online meta-learning. In International Conference on Machine Learning, pages 1920–1930. PMLR, 2019.
- Flennerhag et al. [2020] Sebastian Flennerhag, Andrei A. Rusu, Razvan Pascanu, Francesco Visin, Hujun Yin, and Raia Hadsell. Meta-learning with warped gradient descent. In International Conference on Learning Representations, 2020.
- F.R.S. [1901] Karl Pearson F.R.S. LIII. On lines and planes of closest fit to systems of points in space. The London, Edinburgh, and Dublin Philosophical Magazine and Journal of Science, 2(11):559–572, 1901. doi: 10.1080/14786440109462720.
- Grant et al. [2019] Erin Grant, Chelsea Finn, Sergey Levine, Trevor Darrell, and Thomas Griffiths. Recasting gradient-based meta-learning as hierarchical bayes. In International Conference on Learning Representations, 2019.
- Greg Landrum [2018] Greg Landrum. RDKit: Open-source cheminformatics software., 2018. https://github.com/rdkit/rdkit.
- Hershey et al. [2017] Shawn Hershey, Sourish Chaudhuri, Daniel PW Ellis, Jort F Gemmeke, Aren Jansen, R Channing Moore, Manoj Plakal, Devin Platt, Rif A Saurous, Bryan Seybold, et al. CNN architectures for large-scale audio classification. In 2017 IEEE International Conference on Acoustics, Speech and Signal Processing, ICASSP, pages 131–135. IEEE, 2017.
- Huang et al. [2021] Kexin Huang, Tianfan Fu, Wenhao Gao, Yue Zhao, Yusuf Roohani, Jure Leskovec, Connor W Coley, Cao Xiao, Jimeng Sun, and Marinka Zitnik. Therapeutics data commons: machine learning datasets and tasks for therapeutics. arXiv e-prints, pages arXiv–2102, 2021.
- Huang et al. [2016] Ruili Huang, Menghang Xia, Dac-Trung Nguyen, Tongan Zhao, Srilatha Sakamuru, Jinghua Zhao, Sampada A Shahane, Anna Rossoshek, and Anton Simeonov. Tox21 challenge to build predictive models of nuclear receptor and stress response pathways as mediated by exposure to environmental chemicals and drugs. Frontiers in Environmental Science, 3:85, 2016.
- Ioffe and Szegedy [2015] Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International conference on machine learning, pages 448–456. PMLR, 2015.
- Khot et al. [2018] Tushar Khot, Ashish Sabharwal, and Peter Clark. SciTail: A textual entailment dataset from science question answering. In Thirty-Second AAAI Conference on Artificial Intelligence, 2018.
- Kingma and Ba [2015] Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
- Krizhevsky et al. [2009] Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. https://www.cs.toronto.edu/~kriz/cifar.html. Citeseer, 2009.
- Lee et al. [2019] Juho Lee, Yoonho Lee, Jungtaek Kim, Adam Kosiorek, Seungjin Choi, and Yee Whye Teh. Set transformer: A framework for attention-based permutation-invariant neural networks. In International Conference on Machine Learning, pages 3744–3753. PMLR, 2019.
- Lee and Choi [2018] Yoonho Lee and Seungjin Choi. Gradient-based meta-learning with learned layerwise metric and subspace. In International Conference on Machine Learning, pages 2927–2936. PMLR, 2018.
- Lhoest et al. [2021] Quentin Lhoest, Albert Villanova del Moral, Yacine Jernite, Abhishek Thakur, Patrick von Platen, Suraj Patil, Julien Chaumond, Mariama Drame, Julien Plu, Lewis Tunstall, Joe Davison, Mario Šaško, Gunjan Chhablani, Bhavitvya Malik, Simon Brandeis, Teven Le Scao, Victor Sanh, Canwen Xu, Nicolas Patry, Angelina McMillan-Major, Philipp Schmid, Sylvain Gugger, Clément Delangue, Théo Matussière, Lysandre Debut, Stas Bekman, Pierric Cistac, Thibault Goehringer, Victor Mustar, François Lagunas, Alexander Rush, and Thomas Wolf. Datasets: A community library for natural language processing. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, pages 175–184, Online and Punta Cana, Dominican Republic, 2021. Association for Computational Linguistics.
- Liu et al. [2019] Yanbin Liu, Juho Lee, Minseop Park, Saehoon Kim, Eunho Yang, Sungju Hwang, and Yi Yang. Learning to propagate labels: Transductive propagation network for few-shot learning. In International Conference on Learning Representations, 2019.
- Lorraine et al. [2020] Jonathan Lorraine, Paul Vicol, and David Duvenaud. Optimizing millions of hyperparameters by implicit differentiation. In International Conference on Artificial Intelligence and Statistics, pages 1540–1552. PMLR, 2020.
- Loshchilov and Hutter [2019] Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. In International Conference on Learning Representations, 2019.
- Mishra et al. [2018] Nikhil Mishra, Mostafa Rohaninejad, Xi Chen, and Pieter Abbeel. A simple neural attentive meta-learner. In International Conference on Learning Representations, 2018.
- Murty et al. [2021] Shikhar Murty, Tatsunori B Hashimoto, and Christopher D Manning. DReCa: A general task augmentation strategy for few-shot natural language inference. In Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pages 1113–1125, 2021.
- NCI [2018] NCI. NCI dataset, 2018. https://github.com/GRAND-Lab/graph_datasets.
- Ni et al. [2021] Renkun Ni, Micah Goldblum, Amr Sharaf, Kezhi Kong, and Tom Goldstein. Data augmentation for meta-learning. In International Conference on Machine Learning, pages 8152–8161. PMLR, 2021.
- Nichol et al. [2018] Alex Nichol, Joshua Achiam, and John Schulman. On first-order meta-learning algorithms. arXiv preprint arXiv:1803.02999, 2018.
- Piczak [2015] Karol J. Piczak. ESC: Dataset for Environmental Sound Classification. In Proceedings of the 23rd Annual ACM Conference on Multimedia, pages 1015–1018. ACM Press, 2015. ISBN 978-1-4503-3459-4. doi: 10.1145/2733373.2806390. URL http://dl.acm.org/citation.cfm?doid=2733373.2806390.
- Rajendran et al. [2020] Janarthanan Rajendran, Alexander Irpan, and Eric Jang. Meta-learning requires meta-augmentation. Advances in Neural Information Processing Systems, 33:5705–5715, 2020.
- Rajeswaran et al. [2019] Aravind Rajeswaran, Chelsea Finn, Sham M Kakade, and Sergey Levine. Meta-learning with implicit gradients. Advances in neural information processing systems, 32, 2019.
- Rusu et al. [2019] Andrei A. Rusu, Dushyant Rao, Jakub Sygnowski, Oriol Vinyals, Razvan Pascanu, Simon Osindero, and Raia Hadsell. Meta-learning with latent embedding optimization. In International Conference on Learning Representations, 2019.
- Satorras and Estrach [2018] Victor Garcia Satorras and Joan Bruna Estrach. Few-shot learning with graph neural networks. In International Conference on Learning Representations, 2018.
- Schmidhuber [1987] Jurgen Schmidhuber. Evolutionary principles in self-referential learning. On learning how to learn: The meta-meta-… hook.) Diploma thesis, Institut f. Informatik, Tech. Univ. Munich, 1(2), 1987.
- Snell et al. [2017] Jake Snell, Kevin Swersky, and Richard Zemel. Prototypical networks for few-shot learning. Advances in neural information processing systems, 30, 2017.
- Srivastava et al. [2014] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: a simple way to prevent neural networks from overfitting. The journal of machine learning research, 15(1):1929–1958, 2014.
- Sung et al. [2018] Flood Sung, Yongxin Yang, Li Zhang, Tao Xiang, Philip HS Torr, and Timothy M Hospedales. Learning to compare: Relation network for few-shot learning. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 1199–1208, 2018.
- Van der Maaten and Hinton [2008] Laurens Van der Maaten and Geoffrey Hinton. Visualizing data using t-sne. Journal of machine learning research, 9(11), 2008.
- Verma et al. [2019] Vikas Verma, Alex Lamb, Christopher Beckham, Amir Najafi, Ioannis Mitliagkas, David Lopez-Paz, and Yoshua Bengio. Manifold mixup: Better representations by interpolating hidden states. In International Conference on Machine Learning, pages 6438–6447. PMLR, 2019.
- Vinyals et al. [2016] Oriol Vinyals, Charles Blundell, Timothy Lillicrap, Daan Wierstra, et al. Matching networks for one shot learning. Advances in neural information processing systems, 29, 2016.
- Wang et al. [2019] Alex Wang, Amanpreet Singh, Julian Michael, Felix Hill, Omer Levy, and Samuel R. Bowman. GLUE: A multi-task benchmark and analysis platform for natural language understanding. In International Conference on Learning Representations, 2019.
- Williams et al. [2018] Adina Williams, Nikita Nangia, and Samuel Bowman. A broad-coverage challenge corpus for sentence understanding through inference. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers), pages 1112–1122, 2018.
- Wolf et al. [2020] Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Clement Delangue, Anthony Moi, Pierric Cistac, Tim Rault, Rémi Louf, Morgan Funtowicz, Joe Davison, Sam Shleifer, Patrick von Platen, Clara Ma, Yacine Jernite, Julien Plu, Canwen Xu, Teven Le Scao, Sylvain Gugger, Mariama Drame, Quentin Lhoest, and Alexander M. Rush. Transformers: State-of-the-art natural language processing. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, pages 38–45, Online, October 2020. Association for Computational Linguistics.
- Yao et al. [2021] Huaxiu Yao, Long-Kai Huang, Linjun Zhang, Ying Wei, Li Tian, James Zou, Junzhou Huang, et al. Improving generalization in meta-learning via task augmentation. In International Conference on Machine Learning, pages 11887–11897. PMLR, 2021.
- Yao et al. [2022] Huaxiu Yao, Linjun Zhang, and Chelsea Finn. Meta-learning with fewer tasks through task interpolation. In International Conference on Learning Representations, 2022.
- Yin et al. [2020] Mingzhang Yin, George Tucker, Mingyuan Zhou, Sergey Levine, and Chelsea Finn. Meta-learning without memorization. In International Conference on Learning Representations, 2020.
- Zaheer et al. [2017] Manzil Zaheer, Satwik Kottur, Siamak Ravanbakhsh, Barnabas Poczos, Russ R Salakhutdinov, and Alexander J Smola. Deep sets. Advances in neural information processing systems, 30, 2017.
Appendix
Organization
The appendix is organized as follows: In Section A, we provide proofs of the theorem and propositions in Section 4. In Section B, we show additional experimental results — Meta-Interpolation with first-order MAML, the effect of the number of meta-training and meta-validation tasks, an ablation study on the location of interpolation, and the effect of the number of tasks used for interpolation. In Section C, we provide detailed descriptions of the experimental setup used in the main paper, together with the exact data splits for meta-training, meta-validation, and meta-testing for all the datasets used. We then specify the exact architecture of the Prototypical Networks used for all the experiments and further describe the Set Transformer in detail in Section C.3. All hyperparameters are specified in Section C.4.
Appendix A Proofs
A.1 Proof of Theorem 1
Proof.
Define , and Define the dimensionality as and From the definition, since we use in both training and testing time, we have
where
In the analysis of the loss functions, without loss of generality, we can set the permutation to be the identity, since every combination can be realized by one permutation instead of two permutations. Therefore, using the definition of the loss, we can write the corresponding loss by
where
This can be rewritten as:
Thus,
Define the set
Then,
Summarizing the computation so far, we have that
and
with
We now analyze the relationship of and . From this definition, we first compute as follows. By writing and ,
Here, by writing , , and , we can rewrite that
Thus, by letting ,
Using these,
and
Thus, by letting and , and by defining , we have that
For , similarly, by writing and ,
Here, since and , we have and thus,
Therefore, we have that
where and does not depend on the . Using these and by defining , we have that
With these preparations, we are now ready to prove the regularization form. Fix and write . Define a vector such that . Then,
Then, from the results of the calculations above, we have that
Using the assumptions that for all , for any , the -th approximation of is given by
Let to be set later and define
Then,
Define and . For any given , define
Then, we have that
The -th approximation of is given by
where is the -th order derivative of . Here, for any , by defining ,
Then, by using the vectorization of the tensor , we can rewrite this equation as
(9) |
where . By combining these with ,
where
∎
A.2 Proof of Proposition 1
Proof.
We now apply our general regularization form theorem to the special case used in prior work [50] for the ProtoNet loss. In this special case, we have that and
Here, for with
(10) |
and
(11) |
we recover that
Thus, to instantiate our general theorem to this special case, we only need to compute the and up to the second order approximation.
Define . Since and by Equation 10 and 11,
Thus, we have . Then,
Similarly,
For the second order, by defining the logistic function ,
Similarly,
Therefore, the second approximation of is
where with
where
For the first order term,
where
For the second order term,
Since with the being balanced, the second approximation of becomes:
∎
A.3 Proof of Proposition 2
Proof.
Let be independent uniform random variables taking values in ; i.e., Rademacher variables. We first bound the empirical Rademacher complexity part as follows:
where denotes the Moore–Penrose inverse of . By taking expectation and using this bound on the empirical Rademacher complexity, we now bound the Rademacher complexity as follows:
∎
Appendix B Additional Experiments
Table 7: Results with first-order MAML.

| Method | Metabolism (5-shot) | ESC-50 (5-shot) |
|---|---|---|
| FOMAML | | |
| FOMAML + MLTI | | |
| FOMAML + Meta-Interpolation | | |
| ProtoNet + Meta-Interpolation | | |
B.1 First-order MAML
As stated in the introduction, we focus solely on metric based meta-learning due to its efficiency and better empirical performance over gradient based methods on the few-task meta-learning problem. Moreover, it is challenging to combine our method with Model-Agnostic Meta-Learning (MAML) [10], since it yields a tri-level optimization problem which requires differentiating through second order derivatives. Furthermore, tri-level optimization is still known to be a challenging problem [4, 7] and is currently an active line of research. Thus, instead of using the original MAML, we perform additional experiments on the ESC-50 and Metabolism datasets with first-order MAML (FOMAML), which approximates the Hessian with a zero matrix. As shown in Table 7, Meta-Interpolation with first-order MAML outperforms MLTI, which again confirms the general effectiveness of our set-based task augmentation scheme. However, it largely underperforms our original Meta-Interpolation framework with metric based meta-learning.
Table 8: Effect of the number of meta-training tasks on ESC-50.

| Model | 5 Tasks | 10 Tasks | 15 Tasks | 20 Tasks |
|---|---|---|---|---|
| ProtoNet | | | | |
| MLTI | | | | |
| Ours | | | | |
Table 9: Effect of the number of meta-validation tasks on ESC-50.

| Model | 5 Tasks | 10 Tasks | 15 Tasks |
|---|---|---|---|
| ProtoNet | | | |
| MLTI | | | |
| Ours | | | |
B.2 Effect of the number of meta-training and validation tasks
We analyze the effect of the number of meta-training tasks on the ESC-50 dataset. As shown in Table 8, Meta-Interpolation consistently outperforms the baselines by large margins, regardless of the number of tasks. Furthermore, we report the test accuracy as a function of the number of meta-validation tasks in Table 9. Although the generalization performance of Meta-Interpolation slightly decreases as we reduce the number of meta-validation tasks, it still outperforms the relevant baselines by a large margin.
B.3 Location of interpolation
Table 10: Effect of the interpolation layer on ESC-50.

| Layer | Accuracy |
|---|---|
| Input Layer | |
| Layer 1 | |
| Layer 2 | |
| Layer 3 | |
Contrary to Manifold Mixup, we fix the layer at which interpolation is performed, as shown in Table 17. Otherwise, we could not use the same Set Transformer architecture to interpolate the outputs of different layers, since the hidden dimension of each layer differs. Moreover, we report the test accuracy in Table 10 as the layer of interpolation varies. Interpolating the hidden representations of support sets at layer 2, which is the configuration used in the main ESC-50 experiments, achieves the best performance.
B.4 Effect of Cardinality for Interpolation

Since the set function $g_\lambda$ can handle sets of arbitrary cardinality, we plot the test accuracy on the ESC-50 dataset while varying the set size from one to five. As shown in Figure 4, for sets of cardinality one, where we do not perform any interpolation, we observe significant degradation in performance. This suggests that the interpolation of instances is crucial for better generalization. On the other hand, for set sizes from two to five, the gain from increasing cardinality is marginal. Furthermore, increasing the set size introduces extra computational overhead. Thus, we set the cardinality to two for task interpolation in all the experiments.
Appendix C Experimental Setup
C.1 Dataset Description
Metabolism
We use the following split for the Metabolism dataset:
Meta-Train | |||
Meta-Validation | |||
Meta-Test |
Following Yao et al. [50], we balance each subdataset by selecting 1000 samples from it. Each data sample is processed by extracting a 1024-bit fingerprint feature from the SMILES representation of each chemical compound using the RDKit [15] library.
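A sketch of the fingerprint extraction with RDKit follows; using a Morgan fingerprint with radius 2 is an assumption, since the paper only specifies a 1024-bit fingerprint derived from the SMILES representation.

```python
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem

def smiles_to_fingerprint(smiles, n_bits=1024):
    """Convert a SMILES string into a 1024-bit binary fingerprint vector."""
    mol = Chem.MolFromSmiles(smiles)
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=n_bits)  # radius 2 (assumed)
    return np.array(fp, dtype=np.float32)
```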
Tox21
We use the following split for the Tox21 dataset (https://tdcommons.ai/single_pred_tasks/tox/#tox21):
Meta-Train | |||
Meta-Validation | |||
Meta-Test |
NCI
We download the dataset from the GitHub repository (https://github.com/GRAND-Lab/graph_datasets/tree/master/Graph_Repository) and use the following splits for meta-training, meta-validation, and meta-testing:
Meta-Train | |||
Meta-Validation | |||
Meta-Test |
GLUE-SciTail
We use the Hugging Face Datasets library [25] to download the MNLI, QNLI, SNLI, RTE, and SciTail datasets and tokenize them with the ELECTRA tokenizer, setting the maximum length to 128. We list the meta-train, meta-validation, and meta-test splits as follows:
Meta-Train | |||
Meta-Validation | |||
Meta-Test |
ESC-50
We download the ESC-50 dataset [34] from the GitHub repository (https://github.com/karolpiczak/ESC-50) and use the following meta-training, meta-validation, and meta-test splits:
Meta-train set:
Meta-validation set:
Meta-test set:
RMNIST
Following Yao et al. [50], we create the RMNIST dataset by changing the size, color, and angle of images from the original MNIST dataset. To be specific, we merge the training and test splits of MNIST, randomly select 5,600 images for each class, and split them into 56 subdatasets where each class has 100 examples. We choose 16/6/10 subdatasets for the meta-train, meta-validation, and meta-test splits, respectively, and do not use the rest. Each subdataset with its corresponding composition of image transformations is considered a distinct task. The specific splits are as follows:
Meta-train set:
Meta-validation set:
Meta-test set:
Mini-ImageNet-S
Following Yao et al. [50], we reduce the number of meta-training tasks by choosing a subset of the original meta-training classes. We specify the classes used for meta-training tasks as follows:
For meta-validation and meta-testing classes, we use the same classes as the original Mini-ImageNet.
CIFAR-100-FS
Similar to Mini-ImageNet-S, we choose 12/16/20 classes for the meta-training, meta-validation, and meta-test splits. The specific classes for each split are as follows:
Meta-Train | |||
Meta-Validation | |||
Meta-Test |
C.2 Prototypical Network Architecture
We summarize the neural network architectures of the ProtoNet for each dataset in Tables 11, 12, 13, 14, 15, and 16. For GLUE-SciTail, we use the pretrained ELECTRA-small [8], which we download from Hugging Face [48].
Output Size | Layers |
---|---|
Input Image | |
conv2d(, stride 1, padding 1), BatchNorm2D, ReLU, Maxpool(, stride 2) | |
conv2d(, stride 1, padding 1), BatchNorm2D, ReLU, Maxpool(, stride 2) | |
conv2d(, stride 1, padding 1), BatchNorm2D, ReLU, Maxpool(, stride 2) | |
conv2d(, stride 1, padding 1), BatchNorm2D, ReLU, Maxpool(, stride 2) | |
32 | Flatten |
Output Size | Layers |
---|---|
Input Image | |
conv2d(, stride 1, padding 1), BatchNorm2D, ReLU, Maxpool(, stride 2) | |
conv2d(, stride 1, padding 1), BatchNorm2D, ReLU, Maxpool(, stride 2) | |
conv2d(, stride 1, padding 1), BatchNorm2D, ReLU, Maxpool(, stride 2) | |
conv2d(, stride 1, padding 1), BatchNorm2D, ReLU, Maxpool(, stride 2) | |
800 | Flatten |
Output Size | Layers |
---|---|
Input Image | |
conv2d(, stride 1, padding 1), BatchNorm2D, ReLU, Maxpool(, stride 2) | |
conv2d(, stride 1, padding 1), BatchNorm2D, ReLU, Maxpool(, stride 2) | |
conv2d(, stride 1, padding 1), BatchNorm2D, ReLU, Maxpool(, stride 2) | |
conv2d(, stride 1, padding 1), BatchNorm2D, ReLU, Maxpool(, stride 2) | |
128 | Flatten |
Output Size | Layers |
---|---|
Input SMILE | |
Linear(1024, 500, bias=True), BatchNorm1D, LeakyReLU | |
Linear(500, 500, bias=True), BatchNorm1D, LeakyReLU | |
Linear(500, 500, bias=True) |
Output Size | Layers |
---|---|
Input VGGish feature | |
Linear(650, 500, bias=True), BatchNorm1D, LeakyReLU | |
Linear(500, 500, bias=True), BatchNorm1D, LeakyReLU | |
Linear(500, 500, bias=True) |
Output Size | Layers |
---|---|
Input sentence | |
ElectraModel(“google/electra-small-discriminator”) | |
[CLS] embedding | |
Linear(256, 256, bias=True), ReLU | |
Linear(256, 256, bias=True), ReLU | |
Linear(256, 256, bias=True) |
C.3 Set Transformer
We describe Set Transformer [23], , in more detail. Let be a stack of -dimensional row vectors. Let be weight matrices for self-attention and let be bias vectors for . For encoding an input , we compute self-attention as follows:
(12) |
where is a vector of ones, , and softmax is applied for each row. After self-attention, we add a skip connection with layer normalization [1] as follows:
(13) |
where .
Similarly, we compute another self-attention on top of with
weight matrices for self-attention and let be bias vectors, for .
(14) |
After self-attention, we also add a skip connection after the second self-attention with dropout [41].
(15) | ||||
(16) |
where .
After encoding with two-layers of self-attention, we perform pooling with attention. Let be learnable parameters and be weight matrices for the pooling-attention and let be bias vectors, for . We pool as follows:
(17) |
After pooling, we add another skip connection with dropout as follows:
(18) |
where . Finally, we perform affine-transformation after the pooling as follows:
(19) |
where .
To summarize, the Set Transformer is a set function that can handle sets of arbitrary cardinality.
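A condensed PyTorch sketch of the architecture described above follows: two self-attention blocks with residual connections and layer normalization, attention pooling with a learnable seed vector, and a final affine layer. The head count, dropout rate, and the use of `nn.MultiheadAttention` are assumptions for illustration rather than the exact implementation, and the embedding dimension is assumed to be divisible by the number of heads.

```python
import torch
import torch.nn as nn

class SetTransformer(nn.Module):
    """Two self-attention blocks, attention pooling with a learnable seed, affine output."""
    def __init__(self, dim, num_heads=4, dropout=0.1):
        super().__init__()
        self.enc1 = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ln1 = nn.LayerNorm(dim)
        self.enc2 = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.seed = nn.Parameter(torch.randn(1, 1, dim))  # learnable pooling query
        self.pool = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ln3 = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout)
        self.out = nn.Linear(dim, dim)

    def forward(self, x):                         # x: (batch, set_size, dim)
        h, _ = self.enc1(x, x, x)
        h = self.ln1(x + h)                       # skip connection + layer norm
        h2, _ = self.enc2(h, h, h)
        h = self.ln2(h + self.drop(h2))           # second skip connection with dropout
        seed = self.seed.expand(x.size(0), -1, -1)
        pooled, _ = self.pool(seed, h, h)         # pool the whole set with attention
        pooled = self.ln3(seed + self.drop(pooled))
        return self.out(pooled.squeeze(1))        # (batch, dim)
```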
C.4 Hyperparameters
In Tables 17 and 18, we summarize all the hyperparameters for each dataset, where MI stands for Mini-ImageNet-S. For CIFAR-100-FS, we use the same hyperparameters for both 1-shot and 5-shot.
Hyperparameters | Metabolism | Tox21 | NCI | GLUE-SciTail | ESC-50
---|---|---|---|---|---
learning rate | |||||
optimizer | Adam [21] | Adam | Adam | AdamW [28] | Adam |
scheduler | none | none | none | linear | none |
batch size | 4 | 4 | 4 | 1 | 4 |
query size for meta-training | 10 | 10 | 10 | 10 | 5 |
maximum training iterations | 10,000 | 10,000 | 10,000 | 50,000 | 10,000 |
number of episodes for meta-test | 3,000 | 3,000 | 3,000 | 3,000 | 3,000 |
hyper learning rate | |||||
hyper optimizer | Adam | Adam | Adam | Adam | Adam |
hyper scheduler | linear | linear | linear | linear | linear |
update period | 100 | 100 | 100 | 1,000 | 100 |
input size for set function | 500 | 500 | 500 | 256 | 500 |
hidden size for set function | 1,024 | 1,024 | 500 | 1,024 | 1,024 |
layer for interpolation | 1 | 1 | 1 | 2 | 2 |
iterations for Neumann series | 5 | 5 | 5 | 10 | 5 |
distance metric | Euclidean | Euclidean | Euclidean | Euclidean | Euclidean |
Hyperparameters | RMNIST | MI (1-shot) | MI (5-shot) | CIFAR-100-FS (1-shot, 5-shot) |
---|---|---|---|---|
learning rate | ||||
optimizer | SGD | Adam | Adam | Adam |
scheduler | none | none | none | none |
batch size | 4 | 4 | 4 | 4 |
query size for meta-training | 1 | 15 | 15 | 15 |
maximum training iterations | 10,000 | 50,000 | 50,000 | 50,000 |
number of episodes for meta-test | 3,000 | 3,000 | 3,000 | 3,000 |
hyper learning rate | ||||
hyper optimizer | Adam | Adam | Adam | Adam |
hyper scheduler | linear | linear | linear | linear |
update period | 1000 | 1,000 | 1,000 | 1,000 |
input size for set function | 1,568 | 500 | 500 | 256 |
hidden size for set function | 1,568 | 14,112 | 56,448 | 8,192 |
layer for interpolation | 2 | 2 | 1 | 1 |
iterations for Neumann series | 5 | 5 | 5 | 5 |
distance metric | Euclidean | Euclidean | Euclidean | Euclidean |