
Transfer Learning with Kernel Methods

Adityanarayanan Radhakrishnan¹,²,³, Max Ruiz Luyten¹,², Neha Prasad², Caroline Uhler²,³
¹Equal contribution. ²Laboratory for Information & Decision Systems, and Institute for Data, Systems, and Society, Massachusetts Institute of Technology. ³Broad Institute of MIT and Harvard.
(July 28, 2025)
Abstract

Transfer learning refers to the process of adapting a model trained on a source task to a target task. While kernel methods are conceptually and computationally simple machine learning models that are competitive on a variety of tasks, it has been unclear how to perform transfer learning for kernel methods. In this work, we propose a transfer learning framework for kernel methods by projecting and translating the source model to the target task. We demonstrate the effectiveness of our framework in applications to image classification and virtual drug screening. In particular, we show that transferring modern kernels trained on large-scale image datasets can result in a substantial performance increase compared to using the same kernel trained directly on the target task. In addition, we show that transfer-learned kernels allow more accurate prediction of the effect of drugs on cancer cell lines. For both applications, we identify simple scaling laws that characterize the performance of transfer-learned kernels as a function of the number of target examples. We explain this phenomenon in a simplified linear setting, where we are able to derive the exact scaling laws. By providing a simple and effective transfer learning framework for kernel methods, our work enables kernel methods trained on large datasets to be easily adapted to a variety of downstream target tasks.

1 Introduction

Transfer learning refers to the machine learning problem of utilizing knowledge from a source task to improve performance on a target task. Recent approaches to transfer learning have achieved tremendous empirical success in many applications, including computer vision [45, 17], natural language processing [40, 43, 16], and the biomedical field [19, 15]. Since transfer learning approaches generally rely on complex deep neural networks, it can be difficult to characterize when and why they work [44]. Kernel methods [46] are conceptually and computationally simple machine learning models that have been found to be competitive with neural networks on a variety of tasks, including image classification [3, 29, 42] and drug screening [42]. Their simplicity stems from the fact that training a kernel method involves performing linear regression after transforming the data. There has been renewed interest in kernels due to a recently established equivalence between wide neural networks and kernel methods [25, 2], which has led to the development of modern neural tangent kernels (NTKs) that are competitive with neural networks. Given their simplicity and effectiveness, kernel methods could provide a powerful approach for transfer learning and also help characterize when transfer learning between a source and target task would be beneficial. However, developing an algorithm for transfer learning with kernel methods for general source and target tasks has been an open problem. In particular, while there is a standard transfer learning approach for neural networks that involves replacing and re-training the last layer of a pre-trained network, there is no known corresponding operation for kernels. The limited prior work on transfer learning with kernels focuses on applications in which the source and target tasks have the same label sets [14, 30, 37].
Examples include predicting stock returns for a given sector based on returns available for other sectors [30] or predicting electricity consumption for certain zones of the United States based on the consumption in other zones [37]. These methods are not applicable to general source and target tasks with differing label dimensions, which include classical transfer learning applications such as using a model trained to classify thousands of objects to subsequently classify new objects.

In this work, we present a general framework for performing transfer learning with kernel methods. Unlike prior work, our framework enables transfer learning for kernels regardless of whether the source and target tasks have the same or differing label sets. Furthermore, as with transfer learning for neural networks, our framework allows transferring to a variety of target tasks after training a kernel method only once on a source task. To provide some intuition for our proposed framework: instead of replacing and re-training the last layer of a neural network, as is standard for transfer learning with neural networks, our approach for transfer learning with kernels corresponds to adding a new layer to the end of a neural network.

Figure 1: Our framework for transfer learning with kernel methods for supervised learning tasks. After training a kernel method on a source task, we transfer the source model to the target task via a combination of projection and translation operations. (a) Projection involves training a second kernel method on the predictions of the source model on the target data, as shown for image classification between natural images and house numbers. (b) Projection is effective when the predictions of the source model on target examples provide useful information about target labels; e.g., a model trained to classify natural images may be able to distinguish images of zeros from ones by using the similarity of zeros to balls and ones to poles. (c) Translation involves adding a correction term to the source model, as shown for predicting the effect of a drug on a cell line. (d) Translation is effective when the predictions of the source model can be additively corrected to match labels in the target data; e.g., the predictions of a model trained to predict the effect of drugs on one cell line may be additively adjustable to predict the effect on new cell lines.

Our transfer learning framework trains a kernel method on a source dataset and then applies the following operations to transfer the model to the target task:

  • Projection. We apply the trained source kernel to each sample in the target dataset and then train a secondary model on these source predictions to solve the target task; see Fig. 1a.

  • Translation. When the source and target tasks have the same label sets, we train a correction term that is added to the source model to adapt it to the target task; see Fig. 1c.

Projection is effective when the source model predictions contain information regarding the target labels. We will demonstrate that this is the case in image classification tasks in which the predictions of a classifier trained to distinguish between a thousand objects in ImageNet32 [11] provide information regarding the labels of images in other datasets such as street view house numbers (SVHN) [33]; see Fig. 1b. In particular, we will show across 23 different source and target task combinations that kernels transferred using our approach achieve up to a 10% increase in accuracy over kernels trained on target tasks directly.

On the other hand, translation is effective when the predictions of the source model can be corrected to match the labels of the target task via an additive term. We will show that this is the case in virtual drug screening, in which a model trained to predict the effect of a drug on one cell line can be adjusted to capture the effect on a new cell line; see Fig. 1d. In particular, we will show that our transfer learning approach improves on prior kernel method predictors [42] even when transferring to cell lines and drugs not present in the source task.

Interestingly, we observe that for both applications, image classification and virtual drug screening, transfer-learned kernel methods follow simple scaling laws; i.e., how the number of available target samples affects performance on the target task can be accurately modeled. As a consequence, our work provides a simple method for estimating the impact of collecting more target samples on the performance of the transfer-learned kernel predictors. In the simplified setting of transfer learning with linear kernel methods, we are able to mathematically derive the scaling laws, thereby providing a mathematical basis for the empirical observations. Overall, our work demonstrates that transfer learning with kernel methods between general source and target tasks is possible and highlights the simplicity and effectiveness of the proposed method on a variety of important applications.

2 Results

In the following, we present our framework for transfer learning with kernel methods more formally. Since kernel methods are fundamental to this work, we start with a brief review.

Given training examples $X=[x^{(1)},\ldots,x^{(n)}]\in\mathbb{R}^{d\times n}$ and corresponding labels $y=[y^{(1)},\ldots,y^{(n)}]\in\mathbb{R}^{1\times n}$, a standard nonlinear approach to fitting the training data is to train a kernel machine [46]. This approach involves first transforming the data, $\{x^{(i)}\}_{i=1}^{n}$, with a feature map, $\psi$, and then performing linear regression. To avoid defining and working with feature maps explicitly, kernel machines rely on a kernel function, $K:\mathbb{R}^{d}\times\mathbb{R}^{d}\to\mathbb{R}$, which corresponds to taking inner products of the transformed data, i.e., $K(x^{(i)},x^{(j)})=\langle\psi(x^{(i)}),\psi(x^{(j)})\rangle$. The trained kernel machine predictor uses the kernel instead of the feature map and is given by:

$$\hat{f}(x)=\alpha K(X,x),\quad\text{where }\alpha=\operatorname*{arg\,min}_{w\in\mathbb{R}^{1\times n}}\|y-wK_{n}\|_{2}^{2}, \tag{1}$$

where $K_{n}\in\mathbb{R}^{n\times n}$ with $(K_{n})_{i,j}=K(x^{(i)},x^{(j)})$, and $K(X,x)\in\mathbb{R}^{n}$ with $K(X,x)_{i}=K(x^{(i)},x)$. Note that for datasets with over $10^{5}$ samples, computing the exact minimizer $\alpha$ is computationally prohibitive, and we instead use fast, approximate iterative solvers such as EigenPro [31]. For a more detailed description of kernel methods, see Appendix A.
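As a minimal sketch of Eq. (1), the NumPy code below fits a kernel machine by solving the least-squares problem for $\alpha$ directly. It assumes a dataset small enough for a dense solve (the large-scale experiments instead use iterative solvers such as EigenPro), stores labels as a $c\times n$ matrix so that multi-dimensional outputs are handled by the same code ($c=1$ recovers Eq. (1)), and uses the Laplace kernel described in the next paragraph with an assumed bandwidth `L=1.0`. The function names are hypothetical helpers, not part of any released code.

```python
import numpy as np

def laplace_kernel(A, B, L=1.0):
    """Laplace kernel exp(-L * ||a - b||_2) between the columns of A (d x n) and B (d x m)."""
    # Pairwise squared distances via ||a||^2 + ||b||^2 - 2<a, b>, clipped at 0 to absorb round-off
    sq = (A**2).sum(axis=0)[:, None] + (B**2).sum(axis=0)[None, :] - 2.0 * A.T @ B
    return np.exp(-L * np.sqrt(np.maximum(sq, 0.0)))

def fit_kernel_machine(X, y, kernel=laplace_kernel):
    """Return f_hat(Z) = alpha K(X, Z), with alpha minimizing ||y - alpha K_n||^2.

    X: d x n training inputs; y: c x n labels (c = 1 recovers Eq. (1)).
    """
    Kn = kernel(X, X)                                   # n x n Gram matrix K_n
    # K_n is symmetric, so alpha K_n = y is equivalent to K_n alpha^T = y^T;
    # lstsq returns the minimum-norm least-squares solution.
    alpha = np.linalg.lstsq(Kn, y.T, rcond=None)[0].T   # c x n coefficients
    return lambda Z: alpha @ kernel(X, Z)               # c x m predictions for Z (d x m)

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 50))    # 50 synthetic training points in R^5
y = np.sin(X[:1])                   # 1 x 50 synthetic labels
f_hat = fit_kernel_machine(X, y)
print(np.abs(f_hat(X) - y).max())   # close to zero: the kernel machine interpolates the training data
```

Using `lstsq` rather than `np.linalg.solve` keeps the minimum-norm minimizer well defined even if $K_n$ is singular, for example when training points repeat.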

For the experiments in this work, we utilize a variety of kernel functions. In particular, we consider the classical Laplace kernel given by $K(x,\tilde{x})=\exp\left(-L\|x-\tilde{x}\|_{2}\right)$, which is a standard benchmark kernel that has been widely used for image classification and speech recognition [31]. In addition, we consider recently discovered kernels that correspond to infinitely wide neural networks. While there is an emerging understanding that increasingly wider neural networks generalize better [5, 32], such models are generally computationally difficult to train. Remarkably, recent work identified conditions under which neural networks in the limit of infinite width implement kernel machines; the corresponding kernel is known as the Neural Tangent Kernel (NTK) [25]. In the following, we use the NTK corresponding to training an infinitely wide ReLU fully connected network [25] and also the convolutional NTK (CNTK) corresponding to training an infinitely wide ReLU convolutional network [2].¹

¹We chose to use the CNTK without global average pooling (GAP) [2] for our experiments. While the CNTK model with GAP as well as the models considered in [8] give higher accuracy on image datasets, they are computationally prohibitive for our large-scale experiments. For example, a CNTK with GAP is estimated to take 1200 GPU hours for 50k training samples [29].
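In the infinite-width limit, the NTK of a fully connected ReLU network has a closed form. As a hedged illustration, the sketch below computes the NTK of an infinitely wide, bias-free ReLU network with a single hidden layer using the standard arc-cosine kernels, under a factor-2 (He-style) weight scaling that makes $K(x,x)=2\|x\|^2$. The depth, bias terms, and normalization used in the experiments are not assumed to match this sketch (see Appendix C); inputs are assumed to be nonzero, and the function name is hypothetical.

```python
import numpy as np

def relu_ntk_one_hidden_layer(A, B):
    """NTK of an infinitely wide, bias-free, one-hidden-layer ReLU network (He-style scaling).

    A: d x n, B: d x m, all columns assumed nonzero. Returns the n x m kernel matrix.
    """
    na = np.linalg.norm(A, axis=0)[:, None]          # n x 1 input norms
    nb = np.linalg.norm(B, axis=0)[None, :]          # 1 x m input norms
    u = np.clip((A.T @ B) / (na * nb), -1.0, 1.0)    # cosine of the angle between inputs
    theta = np.arccos(u)
    # Degree-0 arc-cosine kernel: from ReLU derivatives (gradient w.r.t. first-layer weights)
    kappa0 = (np.pi - theta) / np.pi
    # Degree-1 arc-cosine kernel: from ReLU activations (gradient w.r.t. output weights)
    kappa1 = (u * (np.pi - theta) + np.sqrt(1.0 - u**2)) / np.pi
    # <x, z> * kappa0 + ||x|| ||z|| * kappa1
    return na * nb * (u * kappa0 + kappa1)

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 8))
K = relu_ntk_one_hidden_layer(A, A)
print(np.allclose(np.diag(K), 2 * (A**2).sum(axis=0)))   # K(x, x) = 2 ||x||^2 under this scaling
```

The resulting matrix can be passed wherever a kernel Gram matrix is expected, for example in place of the Laplace kernel when solving Eq. (1).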

Unlike the usual supervised learning setting where we train a predictor on a single domain, we will consider the following transfer learning setting from [50], which involves two domains: (1) a source with domain $\mathcal{X}_{s}$ and data distribution $\mathbb{P}_{s}$; and (2) a target with domain $\mathcal{X}_{t}$ and data distribution $\mathbb{P}_{t}$. The goal is to learn a model for a target task $f_{t}:\mathcal{X}_{t}\to\mathcal{Y}_{t}$ by making use of a model trained on a source task $f_{s}:\mathcal{X}_{s}\to\mathcal{Y}_{s}$. We let $c_{s}$ and $c_{t}$ denote the dimensionality of $\mathcal{Y}_{s}$ and $\mathcal{Y}_{t}$, respectively; e.g., for image classification, these denote the number of classes in the source and target. Lastly, we let $(X_{s},y_{s})\in\mathcal{X}_{s}^{n_{s}}\times\mathcal{Y}_{s}^{n_{s}}$ and $(X_{t},y_{t})\in\mathcal{X}_{t}^{n_{t}}\times\mathcal{Y}_{t}^{n_{t}}$ denote the source and target datasets, respectively. Throughout this work, we assume that the source and target domains are equal ($\mathcal{X}_{s}=\mathcal{X}_{t}$), but that the data distributions differ ($\mathbb{P}_{s}\neq\mathbb{P}_{t}$).

Our work is concerned with the recovery of $f_{t}$ by transferring a model, $\hat{f}_{s}$, that is learned by training a kernel machine on the source dataset. To enable transfer learning with kernels, we propose the use of two methods, projection and translation. We first describe these methods individually and demonstrate their performance on transfer learning for image classification using kernel methods. For each method, we empirically establish scaling laws relating the quantities $n_{s},n_{t},c_{s},c_{t}$ to the performance boost given by transfer learning, and we also derive explicit scaling laws when $f_{t},f_{s}$ are linear maps. We then utilize a combination of the two methods to perform transfer learning in an application to virtual drug screening. Code and hardware details are available in Appendix L.

2.1 Transfer learning via projection

Projection involves learning a map from source model predictions to target labels and is thus particularly suited for situations where the number of labels in the source task, $c_{s}$, is much larger than the number of labels in the target task, $c_{t}$.

Definition 1.

Given a source dataset $(X_{s},y_{s})$ and a target dataset $(X_{t},y_{t})$, the projected predictor, $\hat{f}_{t}$, is given by:

$$\hat{f}_{t}(x)=\hat{f}_{p}(\hat{f}_{s}(x)),\quad\text{where }\hat{f}_{p}:=\operatorname*{arg\,min}_{\{f:\mathcal{Y}_{s}\to\mathcal{Y}_{t}\}}\|y_{t}-f(\hat{f}_{s}(X_{t}))\|^{2}, \tag{2}$$

where $\hat{f}_{s}$ is a predictor trained on the source dataset.²

²When there are infinitely many possible values for the parameterized function $\hat{f}_{p}$, we consider the minimum norm solution.

While Definition 1 is applicable to any machine learning method, we focus on predictors $\hat{f}_{s}$ and $\hat{f}_{p}$ parameterized by kernel machines given their conceptual and computational simplicity. As illustrated in Fig. 1a and b, projection is effective when the predictions of the source model already provide useful information for the target task.
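The sketch below illustrates Definition 1 with kernel machines for both $\hat{f}_{s}$ and $\hat{f}_{p}$. Everything here is an assumption for illustration: the Laplace kernel stands in for the kernels used in the experiments, the helper names are hypothetical, and the source and target tasks are synthetic ($c_s=10$ source outputs, $c_t=2$ target outputs). The source model is trained once and only queried afterwards; $\hat{f}_{p}$ operates on $c_s$-dimensional prediction vectors rather than on raw inputs.

```python
import numpy as np

def laplace_kernel(A, B, L=1.0):
    """Laplace kernel exp(-L * ||a - b||_2) between the columns of A (d x n) and B (d x m)."""
    sq = (A**2).sum(axis=0)[:, None] + (B**2).sum(axis=0)[None, :] - 2.0 * A.T @ B
    return np.exp(-L * np.sqrt(np.maximum(sq, 0.0)))

def fit_kernel_machine(X, y, kernel=laplace_kernel):
    """Min-norm kernel regression as in Eq. (1), with c x n labels y; returns a predictor."""
    alpha = np.linalg.lstsq(kernel(X, X), y.T, rcond=None)[0].T
    return lambda Z: alpha @ kernel(X, Z)

def fit_projected(f_s, Xt, yt):
    """Definition 1: fit f_p from source predictions f_s(Xt) (c_s x n_t) to yt (c_t x n_t)."""
    f_p = fit_kernel_machine(f_s(Xt), yt)   # second kernel machine acting on Y_s
    return lambda Z: f_p(f_s(Z))            # projected predictor f_t = f_p o f_s

rng = np.random.default_rng(0)
d, c_s, c_t = 5, 10, 2
W = rng.standard_normal((c_s, d))
Xs = rng.standard_normal((d, 200))
ys = np.tanh(W @ Xs)                        # synthetic source task with c_s outputs
f_s = fit_kernel_machine(Xs, ys)            # trained once on the source task

Xt = rng.standard_normal((d, 30))
yt = np.tanh(W @ Xt)[:c_t] ** 2             # synthetic target task with c_t outputs
f_t = fit_projected(f_s, Xt, yt)
print(f_t(rng.standard_normal((d, 4))).shape)   # one c_t-dimensional prediction per input
```

Because $\hat{f}_{p}$ only sees $c_s$-dimensional vectors, its cost depends on $c_s$ and $n_t$ rather than on the input dimension, which fits the regime $c_s \gg c_t$ described above.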

Figure 2: Analysis of transfer learning with kernels trained on ImageNet32 to CIFAR10, Oxford 102 Flowers, DTD, and a subset of SVHN. All curves in (b, c) are averaged over 3 random seeds. (a) Comparison of the transferred kernel predictor test accuracy (green) to the test accuracy of the baseline kernel predictors trained directly on the target tasks (red). In all cases, the transferred kernel predictors outperform the baseline predictors, and the difference in performance is as high as 10%. (b) Test accuracy of the transferred and baseline predictors as a function of the number of target examples. These curves, which quantitatively describe the benefit of collecting more target examples, follow simple logarithmic trends ($R^{2}>0.95$). (c) Performance of the transferred kernel methods decreases when increasing the number of source classes but keeping the total number of source examples fixed. Corresponding plots for DTD and SVHN are in Appendix Fig. 6.

Improving kernel-based image classifier performance with projection. We now demonstrate the effectiveness of projected kernel predictors for image classification. In particular, we first train kernels to classify among 1000 objects across 1.28 million images in ImageNet32 and then transfer these models to 4 different target image classification datasets: CIFAR10 [28], Oxford 102 Flowers [35], Describable Textures Dataset (DTD) [12], and SVHN [33]. We selected these datasets since they cover a variety of transfer learning settings: all of the CIFAR10 classes are in ImageNet32, ImageNet32 contains only 2 flower classes, and none of the DTD or SVHN classes are in ImageNet32. A full description of the datasets is provided in Appendix B.

For all datasets, we compare the performance of 3 kernels (the Laplace kernel, NTK, and CNTK) when trained just on the target task, i.e., the baseline predictor, and when transferred via projection from ImageNet32. Training details for all kernels are provided in Appendix C. In Fig. 2a, we showcase the improvement of projected kernel predictors over baseline predictors across all datasets and kernels. We observe that projection yields a sizeable increase in accuracy (up to 10%) on the target tasks, thereby highlighting the effectiveness of this method. It is remarkable that this performance increase is observed even for transferring to Oxford 102 Flowers or DTD, datasets that have little to no overlap with images in ImageNet32.

In Appendix Fig. 5a, we compare our results with those of a finite-width neural network analog of the (infinite-width) CNTK where all layers of the source network are fine-tuned on the target task using the standard cross-entropy loss [20] and the Adam optimizer [27]. We observe that the performance gap between transfer-learned finite-width neural networks and the projected CNTK is largely influenced by the performance gap between these models on ImageNet32. In fact, in Appendix Fig. 5a, we show that finite-width neural networks trained to the same test accuracy on ImageNet32 as the (infinite-width) CNTK yield lower performance than the CNTK when transferred to target image classification tasks.

The computational simplicity of kernel methods allows us to compute scaling laws for the projected predictors. In Fig. 2b, we analyze how the performance of projected kernel methods varies as a function of the number of target examples, $n_{t}$, for CIFAR10 and Oxford 102 Flowers. The results for DTD and SVHN are presented in Appendix Fig. 6a and b. For all target datasets, we observe that the accuracy of the projected predictors follows a simple logarithmic trend given by the curve $a\log n_{t}+b$ for constants $a,b$ ($R^{2}$ values on all datasets are above 0.95). By fitting this curve on the accuracy corresponding to just the smallest five values of $n_{t}$, we are able to predict the accuracy of the projected predictors within 2% of the reported accuracy for large values of $n_{t}$ (see Appendix D and Appendix Fig. 8). The robustness of this fit across many target tasks illustrates the practicality of the transferred kernel methods for estimating the number of target examples needed to achieve a given accuracy. Additional results on the scaling laws upon varying the number of source examples per class are presented in Appendix Fig. 7 for transferring between ImageNet32 and CIFAR10. In general, we observe that the performance increases as the number of source training examples per class increases, which is expected given the similarity of the source and target tasks.
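The logarithmic scaling law can be fit and used for extrapolation with ordinary least squares. The sketch below fits $a\log n_t + b$ on the five smallest values of $n_t$, reports $R^2$, extrapolates to larger $n_t$, and inverts the fit to estimate how many target examples a desired accuracy would need. The accuracy values are synthetic, generated from an assumed curve $0.05\log n_t + 0.2$ plus small noise purely for illustration; they are not results from this paper, and the function name is hypothetical.

```python
import numpy as np

def fit_log_scaling(n_t, acc):
    """Least-squares fit of acc ~ a * log(n_t) + b; returns (a, b, R^2)."""
    design = np.column_stack([np.log(n_t), np.ones(len(n_t))])   # columns [log n_t, 1]
    (a, b), *_ = np.linalg.lstsq(design, acc, rcond=None)
    pred = a * np.log(n_t) + b
    r2 = 1.0 - np.sum((acc - pred) ** 2) / np.sum((acc - acc.mean()) ** 2)
    return a, b, r2

# Synthetic accuracies (illustrative only, not results from the paper)
rng = np.random.default_rng(0)
n_t = np.array([50, 100, 200, 400, 800, 1600, 3200, 6400])
acc = 0.05 * np.log(n_t) + 0.2 + rng.normal(0.0, 0.002, n_t.size)

a, b, r2 = fit_log_scaling(n_t[:5], acc[:5])        # fit on the five smallest n_t only
pred_large = a * np.log(n_t[5:]) + b                # extrapolate to the larger n_t
n_needed = np.exp((0.65 - b) / a)                   # invert the fit for a target accuracy of 0.65
print(f"a={a:.4f}, b={b:.4f}, R^2={r2:.3f}, max extrapolation error={np.abs(pred_large - acc[5:]).max():.4f}")
print(f"estimated n_t for 65% accuracy: {n_needed:.0f}")
```

Because the model is linear in $\log n_t$, each additional fixed gain in accuracy requires multiplying the number of target examples by a constant factor, $e^{\Delta/a}$ for a gain $\Delta$.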

Lastly, we analyze the impact of increasing the number of classes while keeping the total number of source training examples fixed at 40k. Fig. 2c shows that having few samples for each of many classes can be worse than having a few classes with many samples each. This may be expected for datasets such as CIFAR10, where the classes overlap with the ImageNet32 classes: having few classes with more examples that overlap with CIFAR10 should be better than having many classes with fewer examples per class and less overlap with CIFAR10. A similar trend can be observed for DTD, but interestingly, the trend differs for SVHN, indicating that SVHN images can be better classified by projecting from a variety of ImageNet32 classes (see Appendix Fig. 6).

2.2 Transfer learning via translation

While projection involves composing a map with the source model, the second component of our framework, translation, involves adding a map to the source model as follows.

Definition 2.

Given a source dataset $(X_{s},y_{s})$ and a target dataset $(X_{t},y_{t})$, the translated predictor, $\hat{f}_{t}$, is given by:

$$\hat{f}_{t}(x)=\hat{f}_{s}(x)+\hat{f}_{c}(x),\quad\text{where }\hat{f}_{c}=\operatorname*{arg\,min}_{\{f:\mathcal{X}_{t}\to\mathcal{Y}_{t}\}}\|y_{t}-\hat{f}_{s}(X_{t})-f(X_{t})\|^{2}, \tag{3}$$

where $\hat{f}_{s}$ is a predictor trained on the source dataset.³

³When there are infinitely many possible values for the parameterized function $\hat{f}_{c}$, we consider the minimum norm solution.
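Definition 2 with kernel machines can be sketched in a few lines: the source model is evaluated on the target inputs, and a second kernel machine is fit to the residual labels $y_t - \hat{f}_s(X_t)$. The sketch is self-contained and assumes a Laplace kernel, hypothetical helper names, and a synthetic target task that equals the source function shifted by a constant, so that the source model alone is systematically off on the target.

```python
import numpy as np

def laplace_kernel(A, B, L=1.0):
    """Laplace kernel exp(-L * ||a - b||_2) between the columns of A (d x n) and B (d x m)."""
    sq = (A**2).sum(axis=0)[:, None] + (B**2).sum(axis=0)[None, :] - 2.0 * A.T @ B
    return np.exp(-L * np.sqrt(np.maximum(sq, 0.0)))

def fit_kernel_machine(X, y, kernel=laplace_kernel):
    """Min-norm kernel regression as in Eq. (1), with c x n labels y; returns a predictor."""
    alpha = np.linalg.lstsq(kernel(X, X), y.T, rcond=None)[0].T
    return lambda Z: alpha @ kernel(X, Z)

def fit_translated(f_s, Xt, yt):
    """Definition 2: f_t(x) = f_s(x) + f_c(x), with f_c fit to the residuals yt - f_s(Xt)."""
    f_c = fit_kernel_machine(Xt, yt - f_s(Xt))   # correction term learned on target data only
    return lambda Z: f_s(Z) + f_c(Z)

rng = np.random.default_rng(0)
target_fn = lambda Z: np.sin(Z) + 0.5           # synthetic target: source function plus a shift
Xs = rng.standard_normal((2, 200))
f_s = fit_kernel_machine(Xs, np.sin(Xs))        # synthetic source task with the same label dimension
Xt = rng.standard_normal((2, 40))
f_t = fit_translated(f_s, Xt, target_fn(Xt))    # only 40 target examples

Z = rng.standard_normal((2, 500))               # held-out target inputs
mse_source = np.mean((f_s(Z) - target_fn(Z)) ** 2)
mse_translated = np.mean((f_t(Z) - target_fn(Z)) ** 2)
print(mse_source, mse_translated)
```

Here the source model alone misses the shift of 0.5 everywhere, while the correction term only has to learn that shift (plus the source model's own errors) from the 40 target examples.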

Translated predictors correspond to first utilizing the trained source model directly on the target task and then applying a correction, $\hat{f}_{c}$, which is learned by training a model on the corrected labels, $y_{t}-\hat{f}_{s}(X_{t})$. As with projected predictors, translated predictors can be implemented using any machine learning model, including kernel methods. When the predictors $\hat{f}_{s}$ and $\hat{f}_{c}$ are parameterized by linear models, translated predictors correspond to training a target predictor with weights initialized by those of the trained source predictor (proof in Appendix J). We note that training translated predictors is also a new form of boosting [9] between the source and target datasets, since the correction term accounts for the error of the source model on the target task. Lastly, we note that while the formulation given in Definition 2 requires the source and target tasks to have the same label dimension, projection and translation can be naturally combined to overcome this restriction.
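The claim that linear translated predictors coincide with training initialized at the source weights can be checked numerically. The sketch below assumes an overparameterized target problem ($d > n_t$), so that the minimum-norm correction of Definition 2 is the relevant solution, and compares it with plain gradient descent on the target least-squares loss started from the source weights. The data are synthetic, and the step size and iteration count are illustrative choices; this is a sanity check, not the proof in Appendix J.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_s, n_t = 20, 30, 5
Xs, Xt = rng.standard_normal((d, n_s)), rng.standard_normal((d, n_t))
w_true = rng.standard_normal((1, d))
ys = w_true @ Xs                                          # synthetic source labels
yt = (w_true + 0.3 * rng.standard_normal((1, d))) @ Xt    # related but different target labels

# Linear source predictor (least squares), stored as a 1 x d row vector
w_s = np.linalg.lstsq(Xs.T, ys.T, rcond=None)[0].T
# Translated predictor: w_s plus the min-norm linear correction fit to residuals yt - w_s Xt
w_c = np.linalg.lstsq(Xt.T, (yt - w_s @ Xt).T, rcond=None)[0].T
w_translated = w_s + w_c

# Gradient descent on ||yt - w Xt||^2 starting from the source weights
w = w_s.copy()
lr = 1.0 / np.linalg.norm(Xt, 2) ** 2       # step size 1 / sigma_max(Xt)^2 keeps the iteration stable
for _ in range(5000):
    w += lr * (yt - w @ Xt) @ Xt.T          # negative gradient, up to a factor of 2
print(np.abs(w - w_translated).max())       # close to zero
```

The agreement follows because every gradient step moves $w$ within the span of the target inputs, so gradient descent from $w_s$ converges to the interpolating solution closest to $w_s$, which is exactly $w_s$ plus the minimum-norm correction.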

Figure 3: Transferring kernel methods from CIFAR10 to adapt to 19 different corruptions in CIFAR10-C. (a) Test accuracy of the baseline kernel method (red), the source predictor given by directly applying the kernel trained on CIFAR10 to CIFAR10-C (gray), and the transferred kernel method (green). The transferred kernel method outperforms the other models on all 19 corruptions and even improves on the baseline kernel method when the source predictor exhibits a decrease in performance. Additional results are presented in Appendix Fig. 10. (b) Performance of the transferred and baseline kernel predictors as a function of the number of target examples. The transferred kernel method can outperform both source and baseline predictors even when transferred using as few as 200 target examples.

Improving kernel-based image classifier performance with translation. We now demonstrate that translated predictors are particularly well-suited for correcting kernel methods to handle distribution shifts in images. Namely, we consider the task of transferring a source model trained on CIFAR10 to corrupted CIFAR10 images in CIFAR10-C [22]. CIFAR10-C consists of the test images in CIFAR10, but the images are corrupted by one of 19 different perturbations, such as adjusting image contrast and introducing natural artifacts such as snow or frost. In our experiments, we select the 10k images of CIFAR10-C with the highest level of perturbation, and we reserve 9k images of each perturbation for training and 1k images for testing. In Appendix Fig. 9, we additionally analyze translating kernels from subsets of ImageNet32 to CIFAR10.

Again, we compare the performance of the three kernel methods considered for projection, but along with the accuracy of the translated predictor and baseline predictor, we also report the accuracy of the source predictor, which is given by using the source model directly on the target task. In Fig. 3a and Appendix Fig. 10, we show that the translated predictors outperform the baseline and source predictors on all 19 perturbations. Interestingly, even for corruptions such as contrast and fog, where the source predictor is worse than the baseline predictor, the translated predictor outperforms all other kernel predictors by up to 11%. In Appendix Fig. 10, we show that for these corruptions, the translated kernel predictors also outperform the projected kernel predictors trained on CIFAR10. In Appendix Fig. 5b, we additionally compare with the performance of a finite-width analog of the CNTK by fine-tuning all layers on the target task with cross-entropy loss and the Adam optimizer. We observe that the translated kernel methods outperform the corresponding neural networks. Remarkably, kernels translated from CIFAR10 can even outperform fine-tuning a neural network pre-trained on ImageNet32 for several perturbations (see Appendix Fig. 5c).

Analogously to our analysis of the projected predictors, we visualize how the accuracy of the translated predictors is affected by the number of target examples, $n_{t}$, for a subset of corruptions shown in Fig. 3b. We observe that the performance of the translated predictors is heavily influenced by the performance of the source predictor. For example, as shown in Fig. 3b for the brightness perturbation, where the source predictor already achieves an accuracy of 60.80%, the translated predictors achieve an accuracy above 60% when trained on only 10 target training samples. For the contrast and fog corruptions, Fig. 3b also shows that very few target examples allow the translated predictors to outperform the source predictors (e.g., by up to 5% with only 200 target examples). Overall, our results show that translation is effective at adapting kernel methods to distribution shifts in image classification.

2.3 Transfer learning via projection and translation in virtual drug screening

Figure 4: Transferring the NTK trained to predict gene expression for given drug and cell line combinations in CMAP to new drug and cell line combinations. (a, b) The transfer-learned NTK (green) outperforms imputation by the mean over each cell line (gray) and previous NTK baseline predictors from [42] across $R^{2}$, cosine similarity, and Pearson r metrics. All results are averaged over the performance on 5 cell lines and are stratified by whether or not the target data contains drugs that are present in the source data. (c, d) The transferred kernel method performance follows a logarithmic trend ($R^{2}>0.9$) as a function of the number of target examples and exhibits a better scaling coefficient than the baselines. The results are averaged over 5 cell lines. (e, f) Visualization of the performance of the transferred NTK in relation to the top two principal components of gene expression for target drug and cell line combinations. The performance of the NTK is generally lower for cell line and drug combinations that are further from the control gene expression for a given cell line. Visualizations for the remaining 3 cell lines are presented in Appendix Fig. 11.

We now demonstrate the effectiveness of projection and translation for the use of kernel methods for virtual drug screening. A common problem in drug screening is that experimentally measuring many different drug and cell line combinations is both costly and time consuming. The goal of virtual drug screening approaches is to computationally identify promising candidates for experimental validation. Such approaches involve training models on existing experimental data to then impute the effect of drugs on cell lines for which there was no experimental data.

The CMAP dataset [47] is a large-scale, publicly available drug screen containing measurements of 978 landmark genes for 116,228 combinations of 20,336 drugs (molecular compounds) and 70 cell lines. This dataset has been an important resource for drug screening [41, 7].⁴ Prior work for virtual drug screening demonstrated the effectiveness of low-rank tensor completion and nearest neighbor predictors for imputing the effect of unseen drug and cell line combinations in CMAP [23]. However, these methods crucially rely on the assumption that for each drug there is at least one measurement for every cell line, which is not the case when considering new chemical compounds. To overcome this issue, recent work [42] introduced kernel methods for drug screening, using the NTK to predict gene expression vectors from drug and cell line embeddings, which capture the similarity between drugs and cell lines.

⁴CMAP also contains data on genetic perturbations, but in this work we focus on imputing the effect of chemical perturbations only.

In the following, we demonstrate that the NTK predictor can be transferred to improve gene expression imputation for drug and cell line combinations, even in cases where neither the particular drug nor the particular cell line was available when training the source model. To utilize the framework of [42], we use the control gene expression vector as the cell line embedding and the 1024-bit circular fingerprints from [1] as the drug embedding. All pre-processing of the CMAP gene expression vectors is described in Appendix E. For the source task, we train the NTK to predict gene expression for the 54,444 drug and cell line combinations corresponding to the 65 cell lines with the least drug availability in CMAP. We then impute the gene expression for each of the 5 cell lines (A375, A549, MCF7, PC3, VCAP) with the most drug availability. We chose these data splits in order to have sufficient target samples to analyze model performance as a function of the number of target samples. In our analysis of the transferred NTK, we always consider transfer to a new cell line, and we stratify by whether a drug in the target task was already available in the source task. For this application, we combine projection and translation into one predictor as follows.

Definition 3.

Given a source dataset $(X_{s},y_{s})$ and a target dataset $(X_{t},y_{t})$, the projected and translated predictor, $\hat{f}_{pt}$, is given by:

$$\hat{f}_{pt}(x)=\hat{f}\left(\left[\hat{f}_{s}(x)\,|\,x\right]\right),\quad\text{where }\hat{f}=\operatorname*{arg\,min}_{f:\mathcal{Y}_{s}\times\mathcal{X}_{t}\to\mathcal{Y}_{s}}\left\|y_{t}-f\left(\left[\hat{f}_{s}(X_{t})\,|\,X_{t}\right]\right)\right\|^{2},$$

where $\hat{f}_{s}$ is a predictor trained on the source dataset and $\left[\hat{f}_{s}(x)\,|\,x\right]\in\mathcal{Y}_{s}\times\mathcal{X}_{t}$ is the concatenation of $\hat{f}_{s}(x)$ and $x$.

Note that if we omit $x, X_{t}$ in the concatenation above, we get the projected predictor, and if we omit $\hat{f}_{s}$ in the concatenation above, we get the translated predictor. Generally, $\hat{f}_{s}(x)$ and $x$ can correspond to different modalities (e.g., class label vectors and images), but in the case of drug screening, both correspond to gene expression vectors of the same dimension. Thus, combining projection and translation is natural in this context.
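A minimal sketch of Definition 3 follows, assuming a Laplace kernel machine in place of the NTK used for drug screening and synthetic data in which source predictions and inputs share the same dimension (as gene expression vectors do here). The helper names are hypothetical; the only change relative to Definition 1 is that the second kernel machine sees $[\hat{f}_s(x)\,|\,x]$ instead of $\hat{f}_s(x)$ alone.

```python
import numpy as np

def laplace_kernel(A, B, L=1.0):
    """Laplace kernel exp(-L * ||a - b||_2) between the columns of A (d x n) and B (d x m)."""
    sq = (A**2).sum(axis=0)[:, None] + (B**2).sum(axis=0)[None, :] - 2.0 * A.T @ B
    return np.exp(-L * np.sqrt(np.maximum(sq, 0.0)))

def fit_kernel_machine(X, y, kernel=laplace_kernel):
    """Min-norm kernel regression as in Eq. (1), with c x n labels y; returns a predictor."""
    alpha = np.linalg.lstsq(kernel(X, X), y.T, rcond=None)[0].T
    return lambda Z: alpha @ kernel(X, Z)

def fit_projected_translated(f_s, Xt, yt):
    """Definition 3: a single kernel machine on the concatenated inputs [f_s(x) | x]."""
    concat = lambda Z: np.vstack([f_s(Z), Z])   # (c_s + d) x m stacked features
    f_hat = fit_kernel_machine(concat(Xt), yt)
    return lambda Z: f_hat(concat(Z))

rng = np.random.default_rng(0)
d = 4                                           # stand-in for the gene expression dimension
M = rng.standard_normal((d, d)) / np.sqrt(d)
Xs = rng.standard_normal((d, 200))
f_s = fit_kernel_machine(Xs, np.tanh(M @ Xs))   # synthetic source: outputs have the same dimension as inputs
Xt = rng.standard_normal((d, 40))
yt = np.tanh(M @ Xt) + 0.1 * Xt                 # synthetic, related target task
f_pt = fit_projected_translated(f_s, Xt, yt)
print(f_pt(rng.standard_normal((d, 3))).shape)  # one d-dimensional prediction per input
```

Stacking with `np.vstack` is one simple way to realize the concatenation; with a kernel on the stacked vector, the model can weigh source predictions and raw inputs jointly rather than only adding a correction.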

Fig. 4a and b show that the transferred kernel predictors outperform both the baseline model from [42] and imputation by the mean gene expression (over each cell line) across three different metrics ($R^2$, cosine similarity, and Pearson $r$) on both tasks (i.e., transferring to drugs that were seen in the source task as well as to completely new drugs). All metrics considered are described in Appendix F. All training details are presented in Appendix C. Interestingly, the transferred kernel methods provide a boost over the baseline kernel methods even when transferring to new cell lines and new drugs. As expected, however, the increase in performance is greater when transferring to drug and cell line combinations for which the drug was available in the source task. Fig. 4c and d show that the transferred kernels again follow simple logarithmic scaling laws (fitting a logarithmic model to the red and green curves yields $R^2 > 0.9$). We note that the transferred NTKs have better scaling coefficients than the baseline models, implying that the performance gap between the transferred NTK and the baseline NTK grows as more target examples are collected. In Fig. 4e and f, we visualize the performance of the transferred NTK in relation to the top 2 principal components of gene expression for drug and cell line combinations. We generally observe that the performance of the NTK is lower for cell line and drug combinations that are further from the control, i.e., the unperturbed state. Plots for the other 3 cell lines are presented in Appendix Fig. 11. In Appendix G and Appendix Fig. 12, we show that this approach can also be used for other transfer learning tasks related to virtual drug screening. In particular, we show that the imputed gene expression vectors can be transferred to predict the viability of a drug and cell line combination in the large-scale, publicly available Cancer Dependency Map (DepMap) dataset [13].
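A logarithmic scaling fit of this kind can be sketched with ordinary least squares on $\log n_t$. The sketch below assumes the model $\text{score} \approx a + b \log n_t$, where $b$ plays the role of the scaling coefficient; the exact parametrization used for Fig. 4 is not stated here, and the helper name `fit_log_scaling` is hypothetical.

```python
import numpy as np


def fit_log_scaling(n_targets, scores):
    """Least-squares fit of scores ~ a + b * log(n_targets); returns (a, b, R^2)."""
    n = np.asarray(n_targets, dtype=float)
    s = np.asarray(scores, dtype=float)
    # Design matrix with an intercept column and a log(n) column
    design = np.column_stack([np.ones_like(n), np.log(n)])
    (a, b), *_ = np.linalg.lstsq(design, s, rcond=None)
    residual = s - (a + b * np.log(n))
    # Coefficient of determination of the fitted logarithmic model
    r2 = 1.0 - residual @ residual / np.sum((s - s.mean()) ** 2)
    return a, b, r2
```

Given (number of target examples, metric) pairs read off a curve, a larger fitted $b$ indicates a model whose performance improves faster as target data grows.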

2.4 Theoretical analysis of projection and translation in the linear setting

In the following, we provide explicit scaling laws for the performance of projected and translated kernel methods in the linear setting, thereby providing a mathematical basis for the empirical observations in the previous sections.

Derivation of the scaling law for the projected predictor in the linear setting. We assume that $\mathcal{X} = \mathbb{R}^d$, $\mathcal{Y}_s = \mathbb{R}^{c_s}$, $\mathcal{Y}_t = \mathbb{R}^{c_t}$ and that $f_s$ and $f_t$ are linear maps, i.e., $f_s = \omega_s \in \mathbb{R}^{c_s \times d}$ and $f_t = \omega_t \in \mathbb{R}^{c_t \times d}$. The following results provide a theoretical foundation for the empirical observations regarding the role of the number of source classes and the number of source samples for transfer learning shown in Fig. 2 as well as in [24]. In particular, we will derive scaling laws for the risk, or expected test error, of the projected predictor as a function of the number of source examples, $n_s$, target examples, $n_t$, and number of source classes, $c_s$. We note that the risk of a predictor is a standard object of study for understanding generalization in statistical learning theory [48] and is defined as follows.

Definition 4.

Let $\mathbb{P}$ be a probability density on $\mathbb{R}^d$ and let $x, x^{(i)} \overset{\text{i.i.d.}}{\sim} \mathbb{P}$ for $i = 1, 2, \ldots, n$. Let $X = [x^{(1)}, \ldots, x^{(n)}] \in \mathbb{R}^{d \times n}$ and $y = [w^* x^{(1)}, \ldots, w^* x^{(n)}] \in \mathbb{R}^{c \times n}$ for $w^* \in \mathbb{R}^{c \times d}$. The risk of a predictor $\hat{w}$ trained on the samples $(X, y)$ is given by

$$\mathcal{R}(\hat{w}) = \mathbb{E}_{x, X}\left[\|w^* x - \hat{w} x\|_F^2\right]. \tag{4}$$

By understanding how the risk scales with the number of source examples, target examples, and source classes, we can characterize the settings in which transfer learning is beneficial. As is standard in analyses of the risk of over-parameterized linear regression [18, 6, 4, 21], we consider the risk of the minimum norm solution given by

$$\hat{w} = \operatorname*{arg\,min}_{w} \|y - wX\|_F^2, \quad \text{i.e., } \hat{w} = y X^{\dagger},$$

where $X^{\dagger}$ is the Moore-Penrose inverse of $X$. Theorem 1 establishes a closed form for the risk of the projected predictor $\hat{\omega}_p \hat{\omega}_s$, thereby giving a closed form for the scaling law for transfer learning in the linear setting; the proof is given in Appendix H.
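The minimum-norm solution and the risk in (4) are straightforward to simulate. The sketch below assumes $\mathbb{P} = \mathcal{N}(0, I_d)$ as the isotropic distribution, so that $\mathbb{E}_x\|(w^* - \hat{w})x\|^2 = \|w^* - \hat{w}\|_F^2$, and compares a Monte Carlo estimate with the standard closed form $(1 - n/d)\|w^*\|_F^2$ for $n \le d$, which holds because $\hat{w} = w^* X X^{\dagger}$ and $X X^{\dagger}$ is then a uniformly random rank-$n$ projection. The function name is illustrative.

```python
import numpy as np


def min_norm_risk(d, n, c, trials=2000, seed=0):
    """Monte Carlo risk of w_hat = y X^+ versus the closed form (1 - n/d) ||w*||_F^2."""
    rng = np.random.default_rng(seed)
    w_star = rng.standard_normal((c, d))
    risks = []
    for _ in range(trials):
        X = rng.standard_normal((d, n))      # columns are samples x ~ N(0, I_d)
        y = w_star @ X                       # noiseless labels
        w_hat = y @ np.linalg.pinv(X)        # minimum-norm least-squares solution
        # With E[x x^T] = I_d, the risk in (4) reduces to a Frobenius norm
        risks.append(np.linalg.norm(w_star - w_hat, "fro") ** 2)
    closed_form = (1 - n / d) * np.linalg.norm(w_star, "fro") ** 2
    return float(np.mean(risks)), closed_form
```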

Theorem 1.

Let $\mathcal{X} = \mathbb{R}^d$, $\mathcal{Y}_s = \mathbb{R}^{c_s}$, $\mathcal{Y}_t = \mathbb{R}^{c_t}$, and let $\hat{\omega}_s = y_s X_s^{\dagger}$ and $\hat{\omega}_p = y_t(\hat{\omega}_s X_t)^{\dagger}$. Assuming that $\mathbb{P}_s$ and $\mathbb{P}_t$ are independent, isotropic distributions on $\mathbb{R}^d$, the risk $\mathcal{R}(\hat{\omega}_p \hat{\omega}_s)$ is given by

$$\mathcal{R}(\hat{\omega}_p \hat{\omega}_s) = \left[\left(C_1 + C_2 K_1\right)\left(1 - \frac{n_t}{d}\right) + \left(1 - C_1 - C_2\right)\right] \|\omega_t\|_F^2 + C_2 K_2 \varepsilon,$$

where $\varepsilon = \|\omega_t(I_{d \times d} - \omega_s^{\dagger} \omega_s)\|_F^2$ and

$$C_1 = \frac{n_s c_s (d - n_s)}{d(d-1)(d+2)}, \qquad C_2 = \frac{n_s\left[d(n_s + 1) - 2\right]}{d(d-1)(d+2)},$$
$$K_1 = 1 - \frac{n_t (d - c_s)}{(d-1)(d+2)}, \qquad K_2 = \frac{n_t}{d} + \frac{n_t (d - n_t)}{(d-1)(d+2)}.$$
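To make the closed form concrete, the function below evaluates the expression of Theorem 1 exactly as stated, with $\|\omega_t\|_F^2$ and $\varepsilon$ passed in as numbers. It is a transcription for exploring the formula, not an independent verification of it, and the function name is hypothetical.

```python
def projected_risk_theorem1(d, n_s, n_t, c_s, wt_norm_sq, eps):
    """Evaluate the Theorem 1 expression for R(omega_p omega_s) as stated."""
    denom = d * (d - 1) * (d + 2)
    C1 = n_s * c_s * (d - n_s) / denom
    C2 = n_s * (d * (n_s + 1) - 2) / denom
    K1 = 1 - n_t * (d - c_s) / ((d - 1) * (d + 2))
    K2 = n_t / d + n_t * (d - n_t) / ((d - 1) * (d + 2))
    # Bracketed coefficient multiplying ||omega_t||_F^2
    bracket = (C1 + C2 * K1) * (1 - n_t / d) + (1 - C1 - C2)
    return bracket * wt_norm_sq + C2 * K2 * eps
```

For example, with no source samples ($n_s = 0$) the constants $C_1, C_2$ vanish and the function returns $\|\omega_t\|_F^2$, while `projected_risk_theorem1(10_000, 10_000, 2_000, 1_000, 1.0, 0.1)` lies within $10^{-3}$ of $(1 - T + TC)(1 - T) + T(2 - T)\varepsilon = 0.692$ for $T = 0.2$, $C = 0.1$.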

The $\varepsilon$ term in Theorem 1 quantifies the similarity between the source and target tasks. For example, if there exists a linear map $\omega_p$ such that $\omega_p \omega_s = \omega_t$, then $\varepsilon = 0$. In the context of classification, this can occur if the target classes are a strict subset of the source classes. Since transfer learning is typically performed between source and target tasks that are similar, we expect $\varepsilon$ to be small. To gain more insights into the behavior of transfer learning using the projected predictor, the following corollary considers the setting where $d \to \infty$ in Theorem 1; the proof is given in Appendix I.

Corollary 1.

Let $S = \frac{n_s}{d}$, $T = \frac{n_t}{d}$, $C = \frac{c_s}{d}$ and assume $\|\omega_t\|_F = \Theta(1)$. Under the setting of Theorem 1, if $S, T, C < \infty$ as $d \to \infty$, then:

  a) $\mathcal{R}(\hat{\omega}_p \hat{\omega}_s)$ is monotonically decreasing for $S \in [0, 1]$ if $\varepsilon < (1 - C)\|\omega_t\|_F^2$.

  b) If $2S - 1 - ST < 0$, then $\mathcal{R}(\hat{\omega}_p \hat{\omega}_s)$ decreases as $C$ increases.

  c) If $S = 1$, then $\mathcal{R}(\hat{\omega}_p \hat{\omega}_s) = (1 - T + TC)\,\mathcal{R}(\hat{\omega}_b) + \varepsilon T(2 - T)$, where $\hat{\omega}_b = y_t X_t^{\dagger}$ denotes the baseline predictor trained only on the target task.

  d) If $S = 1$ and $T, C = \Theta(\delta)$, then $\mathcal{R}(\hat{\omega}_p \hat{\omega}_s) = (1 - 2T)\|\omega_t\|_F^2 + 2T\varepsilon + \Theta(\delta^2)$.
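A sketch of where these statements come from (a direct simplification of the Theorem 1 expression, not the proof in Appendix I): substituting $n_s = Sd$, $n_t = Td$, $c_s = Cd$ and letting $d \to \infty$ gives the limits below. Differentiating the limiting risk with respect to $S$ and $C$ yields the conditions in a) and b), and setting $S = 1$ recovers c), since $(1 - T)\|\omega_t\|_F^2$ is the limiting baseline risk.

```latex
\begin{align*}
C_1 &\to SC(1-S), \qquad C_2 \to S^2, \qquad K_1 \to 1 - T(1-C), \qquad K_2 \to T(2-T),\\
\mathcal{R}(\hat{\omega}_p\hat{\omega}_s) &\to \|\omega_t\|_F^2 - SCT\,\|\omega_t\|_F^2
  - S^2T(2-T)\left[(1-C)\|\omega_t\|_F^2 - \varepsilon\right],\\
\frac{\partial \mathcal{R}}{\partial S} &\to -CT\,\|\omega_t\|_F^2
  - 2ST(2-T)\left[(1-C)\|\omega_t\|_F^2 - \varepsilon\right],\\
\frac{\partial \mathcal{R}}{\partial C} &\to ST\,(2S - 1 - ST)\,\|\omega_t\|_F^2,\\
\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)\big|_{S=1} &\to (1 - T + TC)(1 - T)\,\|\omega_t\|_F^2 + T(2-T)\,\varepsilon.
\end{align*}
```

For $T \in (0, 1]$ and $C \ge 0$, both terms of $\partial\mathcal{R}/\partial S$ are non-positive whenever $\varepsilon < (1 - C)\|\omega_t\|_F^2$, and the sign of $\partial\mathcal{R}/\partial C$ is the sign of $2S - 1 - ST$.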

Remarks. Corollary 1 not only formalizes several intuitions regarding transfer learning, but also theoretically corroborates surprising dependencies on the number of source examples, target examples, and source classes that were empirically observed in Fig. 2 for kernels and in [24] for convolutional networks. First, Corollary 1a implies that increasing the number of source examples is always beneficial for transfer learning when the source and target tasks are related ($\varepsilon \approx 0$), which matches intuition. Next, Corollary 1b implies that increasing the number of source classes while leaving the number of source examples fixed can decrease performance (i.e., if $2S - 1 - ST > 0$), even for similar source and target tasks satisfying $\varepsilon \approx 0$. This matches the experiments in Fig. 2c, where we observed that increasing the number of source classes while keeping the number of source examples fixed can be detrimental to performance. This is intuitive for transferring from ImageNet32 to CIFAR10, since we would be adding classes that are not as useful for predicting objects in CIFAR10. Corollary 1c implies that when the source and target tasks are similar and the number of source classes is less than the data dimension ($C < 1$), transfer learning with the projected predictor is always better than training only on the target task. Moreover, if the number of source classes is finite ($C = 0$ in the limit) and $\varepsilon = 0$, Corollary 1c implies that the risk of the projected predictor decreases much faster than that of the baseline predictor: the risk of the baseline predictor is $(1 - T)\|\omega_t\|_F^2$, while that of the projected predictor is $(1 - T)^2\|\omega_t\|_F^2$. Note also that when the number of target samples is small relative to the dimension, Corollary 1c implies that decreasing the number of source classes has minimal effect on the risk.
Lastly, Corollary 1d implies that when $T$ and $C$ are small and $\varepsilon \approx 0$, the risk of the projected predictor is roughly that of a baseline predictor trained on twice the number of target samples.

Derivation of the scaling law for the translated predictor in the linear setting. Analogously to the case for projection, we analyze the risk of the translated predictor when $\hat{\omega}_s$ is the minimum norm solution to $\|y_s - \omega X_s\|_F^2$ and $\hat{\omega}_c$ is the minimum norm solution to $\|y_t - \hat{\omega}_s X_t - \omega X_t\|_F^2$.

Theorem 2.

Let $\mathcal{X} = \mathbb{R}^d$, $\mathcal{Y}_s = \mathbb{R}^{c_s}$, $\mathcal{Y}_t = \mathbb{R}^{c_t}$, and let $\hat{\omega}_t = \hat{\omega}_s + \hat{\omega}_c$, where $\hat{\omega}_s = y_s X_s^{\dagger}$ and $\hat{\omega}_c = (y_t - \hat{\omega}_s X_t) X_t^{\dagger}$. Assuming that $\mathbb{P}_s$ and $\mathbb{P}_t$ are independent, isotropic distributions on $\mathbb{R}^d$, the risk $\mathcal{R}(\hat{\omega}_t)$ is given by

$$\mathcal{R}(\hat{\omega}_t) = \left[\frac{\|\omega_s - \omega_t\|_F^2}{\|\omega_t\|_F^2} + \left(1 - \frac{n_s}{d}\right)\left(1 - \frac{\|\omega_s - \omega_t\|_F^2}{\|\omega_t\|_F^2}\right)\right] \mathcal{R}(\hat{\omega}_b),$$

where $\hat{\omega}_b = y_t X_t^{\dagger}$ is the baseline predictor.
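Theorem 2 can be checked numerically. The sketch below assumes Gaussian inputs $\mathcal{N}(0, I_d)$ (one instance of the isotropic distributions in the theorem), noiseless labels $y = \omega X$, and $c_s = c_t$ so that $\hat{\omega}_s + \hat{\omega}_c$ is defined. It estimates $\mathcal{R}(\hat{\omega}_t)$ by Monte Carlo and compares it with the closed form, using $\mathcal{R}(\hat{\omega}_b) = (1 - n_t/d)\|\omega_t\|_F^2$ for the baseline. Function names are illustrative.

```python
import numpy as np


def translated_risk_mc(d, n_s, n_t, w_s, w_t, trials=2000, seed=0):
    """Monte Carlo estimate of R(omega_t_hat) for the translated predictor."""
    rng = np.random.default_rng(seed)
    risks = []
    for _ in range(trials):
        X_s = rng.standard_normal((d, n_s))   # independent source and target samples
        X_t = rng.standard_normal((d, n_t))
        w_hat_s = (w_s @ X_s) @ np.linalg.pinv(X_s)                  # y_s X_s^+
        w_hat_c = (w_t @ X_t - w_hat_s @ X_t) @ np.linalg.pinv(X_t)  # fit the residual
        w_hat_t = w_hat_s + w_hat_c
        # With E[x x^T] = I_d, the risk reduces to a Frobenius norm
        risks.append(np.linalg.norm(w_t - w_hat_t, "fro") ** 2)
    return float(np.mean(risks))


def translated_risk_theorem2(d, n_s, n_t, w_s, w_t):
    """Closed form of Theorem 2 with R(omega_b_hat) = (1 - n_t/d) ||w_t||_F^2."""
    wt_sq = np.linalg.norm(w_t, "fro") ** 2
    r = np.linalg.norm(w_s - w_t, "fro") ** 2 / wt_sq
    baseline = (1 - n_t / d) * wt_sq
    return (r + (1 - n_s / d) * (1 - r)) * baseline
```

Setting `w_s = w_t` and `n_s = d` drives the Monte Carlo risk to numerical zero, matching the case in which the source model is recovered exactly.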

The proof is given in Appendix K. Theorem 2 formalizes several intuitions regarding when translation is beneficial. In particular, we first observe that if the source model $\omega_s$ is recovered exactly (i.e., $n_s = d$), then the risk of the translated predictor is governed by the distance between the oracle source model and the target model, i.e., $\|\omega_s - \omega_t\|_F$. Hence, the translated predictor generalizes better than the baseline predictor if the source and target models are similar. In particular, by flattening the matrices $\omega_s$ and $\omega_t$ into vectors and assuming $\|\omega_s\|_F = \|\omega_t\|_F$, the translated predictor outperforms the baseline predictor if the angle between the flattened $\omega_s$ and $\omega_t$ is less than $\frac{\pi}{4}$. On the other hand, when there are no source samples, the translated predictor is exactly the baseline predictor and the corresponding risks are equivalent. In general, we observe that the risk of the translated predictor is simply a weighted average, with weights $1 - \frac{n_s}{d}$ and $\frac{n_s}{d}$, between the baseline risk and the risk in which the source model is recovered exactly.
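The angle condition can be checked directly from Theorem 2. Writing $r = \|\omega_s - \omega_t\|_F^2 / \|\omega_t\|_F^2$ and $S = n_s/d$, the bracket in Theorem 2 equals $1 - S(1 - r)$, so for any $n_s > 0$ (and $n_t < d$, so that $\mathcal{R}(\hat{\omega}_b) > 0$) translation helps exactly when $r < 1$. With equal norms and angle $\theta$ between the flattened matrices:

```latex
\begin{align*}
r &= \frac{\|\omega_s\|_F^2 + \|\omega_t\|_F^2 - 2\langle \omega_s, \omega_t \rangle_F}{\|\omega_t\|_F^2}
   = 2(1 - \cos\theta),\\
\mathcal{R}(\hat{\omega}_t) < \mathcal{R}(\hat{\omega}_b)
  &\iff r < 1 \iff \cos\theta > \tfrac{1}{2} \iff \theta < \tfrac{\pi}{3}.
\end{align*}
```

Any angle below $\pi/4$ therefore suffices, and under these assumptions the condition holds up to $\pi/3$.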

Comparing Theorem 2 to Theorem 1, we note that the projected predictor and the translated predictor generalize based on different quantities. In particular, when $n_s = d$, the risk of the translated predictor is a constant multiple of the baseline risk, while the risk of the projected predictor is a multiple of the baseline risk that decreases with $n_t$. Hence, depending on the distance between $\omega_s$ and $\omega_t$, the translated predictor can outperform the projected predictor or vice versa. As a simple example, consider the setting where $\omega_s = \omega_t$, $n_s = d$, and $n_t, c_s < d$; then the translated predictor achieves zero risk while the projected predictor achieves non-zero risk. When $\mathcal{Y}_s = \mathcal{X}_t$, we suggest combining the projected and translated predictors, as we did in the case of virtual drug screening. Otherwise, our results suggest using the translated predictor for transfer learning problems that involve a distribution shift in the features but no difference in the label sets, and the projected predictor when the label sets differ.

3 Discussion

In this work, we developed a framework that enables transfer learning with kernel methods. In particular, we introduced the projection and translation operations to adjust the predictions of a source model to a specific target task: while projection involves applying a map directly to the predictions given by the source model, translation involves adding a map to the predictions of a source model. We demonstrated the effectiveness of the transfer-learned kernels on image classification and virtual drug screening tasks. Namely, we showed that transfer learning increased the performance of kernel-based image classifiers by up to 10% over training such models directly on the target task. Interestingly, we found that transfer-learned convolutional kernels performed comparably to transfer learning using the corresponding finite-width convolutional networks. In virtual drug screening, we demonstrated that the transferred kernel methods provided an improvement over prior work [42], even in settings where neither the target drug nor the target cell line was present in the source task. For both applications, we analyzed the performance of the transferred kernel model as a function of the number of target examples and observed empirically that the transferred kernel followed a simple logarithmic trend, making it possible to predict the benefit of collecting more target examples on model performance. Lastly, we mathematically derived the scaling laws in the linear setting, thereby providing a theoretical foundation for the empirical observations. We end by discussing various consequences as well as future research directions motivated by our work.

Benefit of pre-training kernel methods on large datasets. A key contribution of our work is enabling kernels trained on large datasets to be transferred to a variety of downstream tasks. As is the case for neural networks, this allows pre-trained kernel models to be saved and shared with downstream users to improve their applications of interest. A key next step to making these models easier to save and share is to reduce their reliance on storing the entire training set, for example by using coresets [49]. We envision that by using such techniques in conjunction with modern advances in kernel methods, the memory and runtime costs could be drastically reduced.

Reducing kernel evaluation time for state-of-the-art convolutional kernels. In this work, we demonstrated that it is possible to train convolutional kernel methods on datasets with over 1 million images. In order to train such models, we resorted to using the CNTK of convolutional networks with a fully connected last layer. While other architectures, such as the CNTK of convolutional networks with a global average pooling last layer, have been shown to achieve superior performance on CIFAR10 [2], training such kernels on 50k images from CIFAR10 is estimated to take 1200 GPU hours [36], which is more than three orders of magnitude slower than the kernels used in this work. The main computational bottleneck for using such improved convolution kernels is evaluating the kernel function itself. Thus, an important problem is to improve the computation time for such kernels, which would allow training better convolutional kernels on large-scale image datasets, which could then be transferred using our framework to improve the performance on a variety of downstream tasks.

Using kernel methods to adapt to distribution shifts. Our work demonstrates that kernels pre-trained on a source task can adapt to a target task with distribution shift when given even just a few target training samples. This opens novel avenues for applying kernel methods to tackle distribution shift in a variety of domains including healthcare or genomics in which models need to be adapted to handle shifts in cell lines, populations, batches, etc. In the context of virtual drug screening, we showed that our transfer learning approach could be used to generalize to new cell lines. The scaling laws described in this work may provide an interesting avenue to understand how many samples are required in the target domain for more complex domain shifts, such as from a model organism like the mouse to humans, a problem of great interest in the pharmaceutical industry.

Acknowledgements

The authors were partially supported by ONR (N00014-22-1-2116), NSF (DMS-1651995), NCCIH/NIH, the MIT-IBM Watson AI Lab, MIT J-Clinic for Machine Learning and Health, the Eric and Wendy Schmidt Center at the Broad Institute, and a Simons Investigator Award (to C.U.).

References

  • [1] Democratizing deep-learning for drug discovery, quantum chemistry, materials science and biology. https://github.com/deepchem/deepchem, 2016.
  • [2] S. Arora, S. S. Du, W. Hu, Z. Li, R. Salakhutdinov, and R. Wang. On exact computation with an infinitely wide neural net. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d’Alché Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems. Curran Associates, Inc., 2019.
  • [3] S. Arora, S. S. Du, Z. Li, R. Salakhutdinov, R. Wang, and D. Yu. Harnessing the power of infinitely wide deep nets on small-data tasks. In International Conference on Learning Representations, 2020.
  • [4] P. L. Bartlett, P. M. Long, G. Lugosi, and A. Tsigler. Benign overfitting in linear regression. Proceedings of the National Academy of Sciences, 117(48):30063–30070, 2020.
  • [5] M. Belkin, D. Hsu, S. Ma, and S. Mandal. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.
  • [6] M. Belkin, D. Hsu, and J. Xu. Two models of double descent for weak features. Society for Industrial and Applied Mathematics Journal on Mathematics of Data Science, 2(4):1167–1180, 2020.
  • [7] A. Belyaeva, L. Cammarata, A. Radhakrishnan, C. Squires, K. Yang, G. Shivashankar, and C. Uhler. Causal network models of SARS-CoV-2 expression and aging to identify candidates for drug repurposing. Nature Communications, 12(1024), 2021.
  • [8] A. Bietti. Approximation and learning with deep convolutional models: a kernel perspective. In International Conference on Learning Representations, 2022.
  • [9] C. M. Bishop. Pattern Recognition and Machine Learning (Information Science and Statistics). Springer-Verlag, Berlin, Heidelberg, 2006.
  • [10] S. Chatterjee and E. Meckes. Multivariate normal approximation using exchangeable pairs. ALEA Lat. Am. J. Probab. Math. Stat., 4, 01 2007.
  • [11] P. Chrabaszcz, I. Loshchilov, and F. Hutter. A downsampled variant of imagenet as an alternative to the cifar datasets. arXiv:1707.08819, 2017.
  • [12] M. Cimpoi, S. Maji, I. Kokkinos, S. Mohamed, and A. Vedaldi. Describing textures in the wild. In Proceedings of the IEEE Conf. on Computer Vision and Pattern Recognition (CVPR), 2014.
  • [13] S. Corsello, R. Nagari, R. Spangler, J. Rossen, M. Kocak, J. Bryan, R. Humeidi, D. Peck, X. Wu, A. Tang, V. Wang, S. Bender, E. Lemire, R. Narayan, P. Montgomery, U. Ben-David, C. Garvie, Y. Chen, M. Rees, and T. Golub. Discovering the anticancer potential of non-oncology drugs by systematic viability profiling. Nature Cancer, 1:1–14, 02 2020.
  • [14] W. Dai, Q. Yang, G.-R. Xue, and Y. Yu. Boosting for transfer learning. In ACM International Conference Proceeding Series, volume 227, pages 193–200, 01 2007.
  • [15] J. De Fauw, J. R. Ledsam, B. Romera-Paredes, S. Nikolov, N. Tomasev, S. Blackwell, H. Askham, X. Glorot, B. O’Donoghue, D. Visentin, et al. Clinically applicable deep learning for diagnosis and referral in retinal disease. Nature Medicine, 24(9):1342–1350, 2018.
  • [16] J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), 2019.
  • [17] J. Donahue, Y. Jia, O. Vinyals, J. Hoffman, N. Zhang, E. Tzeng, and T. Darrell. Decaf: A deep convolutional activation feature for generic visual recognition. In International Conference on Machine Learning, 2014.
  • [18] H. W. Engl, M. Hanke, and A. Neubauer. Regularization of Inverse Problems, volume 375. Springer Science & Business Media, 1996.
  • [19] A. Esteva, B. Kuprel, R. Novoa, J. Ko, S. Swetter, H. Blau, and S. Thrun. Dermatologist-level classification of skin cancer with deep neural networks. Nature, 542, 2017.
  • [20] I. Goodfellow, Y. Bengio, and A. Courville. Deep Learning, volume 1. MIT Press, 2016.
  • [21] T. Hastie, A. Montanari, S. Rosset, and R. J. Tibshirani. Surprises in high-dimensional ridgeless least squares interpolation. arXiv:1903.08560, 2019.
  • [22] D. Hendrycks and T. G. Dietterich. Benchmarking neural network robustness to common corruptions and perturbations. arXiv:1903.12261, 2019.
  • [23] R. Hodos, P. Zhang, H.-C. Lee, Q. Duan, Z. Wang, N. R. Clark, A. Ma’ayan, F. Wang, B. Kidd, J. Hu, D. Sontag, and J. Dudley. Cell-specific prediction and application of drug-induced gene expression profiles. Pacific Symposium on Biocomputing, 23:32–43, 2018.
  • [24] M. Huh, P. Agrawal, and A. A. Efros. What makes ImageNet good for transfer learning? arXiv:1608.08614, 2016.
  • [25] A. Jacot, F. Gabriel, and C. Hongler. Neural Tangent Kernel: Convergence and generalization in neural networks. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, editors, Advances in Neural Information Processing Systems. Curran Associates, Inc., 2018.
  • [26] S. Jaeger-Honz, S. Fulle, and S. Turk. Mol2vec: Unsupervised machine learning approach with chemical intuition. Journal of Chemical Information and Modeling, 58, 12 2017.
  • [27] D. P. Kingma and J. Ba. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
  • [28] A. Krizhevsky. Learning multiple layers of features from tiny images. Master’s thesis, University of Toronto, 2009.
  • [29] J. Lee, S. S. Schoenholz, J. Pennington, B. Adlam, L. Xiao, R. Novak, and J. Sohl-Dickstein. Finite Versus Infinite Neural Networks: an Empirical Study. In Advances in Neural Information Processing Systems, 2020.
  • [30] H. Lin and M. Reimherr. On transfer learning in functional linear regression. arXiv:2206.04277, 2022.
  • [31] S. Ma and M. Belkin. Kernel machines that adapt to GPUs for effective large batch training. In Conference on Machine Learning and Systems, 2019.
  • [32] P. Nakkiran, G. Kaplun, Y. Bansal, T. Yang, B. Barak, and I. Sutskever. Deep double descent: Where bigger models and more data hurt. In International Conference on Learning Representations, 2020.
  • [33] Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, and A. Y. Ng. Reading digits in natural images with unsupervised feature learning. 2011.
  • [34] E. Nichani, A. Radhakrishnan, and C. Uhler. Increasing depth leads to U-shaped test risk in over-parameterized convolutional networks. In International Conference on Machine Learning Workshop on Over-parameterization: Pitfalls and Opportunities, 2021.
  • [35] M.-E. Nilsback and A. Zisserman. Automated flower classification over a large number of classes. 2008.
  • [36] R. Novak, L. Xiao, J. Hron, J. Lee, A. A. Alemi, J. Sohl-Dickstein, and S. S. Schoenholz. Neural Tangents: Fast and easy infinite neural networks in Python. In International Conference on Learning Representations, 2020.
  • [37] D. Obst, B. Ghattas, J. Cugliari, G. Oppenheim, S. Claudel, and Y. Goude. Transfer learning for linear regression: a statistical test of gain. arXiv:2102.09504, 2021.
  • [38] T. E. Oliphant. A guide to NumPy, volume 1. Trelgol Publishing USA, 2006.
  • [39] A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga, A. Desmaison, A. Kopf, E. Yang, Z. DeVito, M. Raison, A. Tejani, S. Chilamkurthy, B. Steiner, L. Fang, J. Bai, and S. Chintala. Pytorch: An imperative style, high-performance deep learning library. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d’Alché Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems. Curran Associates, Inc., 2019.
  • [40] M. E. Peters, M. Neumann, M. Iyyer, M. Gardner, C. Clark, K. Lee, and L. Zettlemoyer. Deep contextualized word representations. In Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers), 2018.
  • [41] S. Pushpakom, F. Iorio, P. A. Eyers, K. J. Escott, S. Hopper, A. Wells, A. Doig, J. Guilliams, T. Latimer, C. McNamee, A. Norris, P. Sanseau, D. Cavalla, and M. Pirmohamed. Drug repurposing: progress, challenges and recommendations. Nature Reviews Drug Discovery, 18(1):41–58, 2019.
  • [42] A. Radhakrishnan, G. Stefanakis, M. Belkin, and C. Uhler. Simple, fast, and flexible framework for matrix completion with infinite width neural networks. arXiv:2108.00131, 2021.
  • [43] C. Raffel, N. Shazeer, A. Roberts, K. Lee, S. Narang, M. Matena, Y. Zhou, W. Li, and P. J. Liu. Exploring the limits of transfer learning with a unified text-to-text transformer. Journal of Machine Learning Research, 21(140):1–67, 2020.
  • [44] M. Raghu, C. Zhang, J. Kleinberg, and S. Bengio. Transfusion: Understanding transfer learning for medical imaging. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d’ Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, 2019.
  • [45] A. S. Razavian, H. Azizpour, J. Sullivan, and S. Carlsson. CNN features off-the-shelf: An astounding baseline for recognition. In IEEE Conference on Computer Vision and Pattern Recognition Workshops, 2014.
  • [46] B. Schölkopf and A. J. Smola. Learning with Kernels: Support Vector Machines, Regularization, Optimization, and Beyond. MIT Press, 2002.
  • [47] A. Subramanian, R. Narayan, S. M. Corsello, et al. A next generation connectivity map: L1000 platform and the first 1,000,000 profiles. Cell, 171(6):1437–1452, 2017.
  • [48] V. N. Vapnik. Statistical Learning Theory. Wiley-Interscience, 1998.
  • [49] Y. Zheng and J. M. Phillips. Coresets for kernel regression. In Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 645–654, 2017.
  • [50] F. Zhuang, Z. Qi, K. Duan, D. Xi, Y. Zhu, H. Zhu, H. Xiong, and Q. He. A comprehensive survey on transfer learning. 2020.

Appendix

Appendix A Review of Kernel Regression

We provide a brief review of training kernel machines using kernel ridge regression. For a detailed description of these methods, we refer the reader to [46].

Let $X = [x^{(1)}, \ldots, x^{(n)}] \in \mathbb{R}^{d \times n}$ denote training examples and $y = [y^{(1)}, \ldots, y^{(n)}] \in \mathbb{R}^{1 \times n}$ denote the corresponding labels. The key idea behind kernel machines is to first transform the training examples, $X$, using a feature map and then perform linear regression on the transformed data to fit the labels, $y$. In particular, given a feature map $\psi: \mathbb{R}^d \to \mathbb{R}^m$, we can nonlinearly fit the data by solving

$$\operatorname*{arg\,min}_{w \in \mathbb{R}^m} \|y - w^T \psi(X)\|^2,$$

where $\psi(X) \in \mathbb{R}^{m \times n}$ and the $i$th column of $\psi(X)$ is $\psi(x^{(i)})$. In cases where $m$ is much larger than $d$, solving the above system becomes computationally expensive. The key idea behind kernel regression is to assume that $w$ is given by a linear combination of the transformed training examples, i.e., $w = \sum_{i=1}^{n} \alpha_i \psi(x^{(i)})$, and then solve the above system for the coefficients $\alpha = [\alpha_1, \ldots, \alpha_n] \in \mathbb{R}^{1 \times n}$ instead. Assuming this form for $w$, the above optimization problem can be written as

$$\operatorname*{arg\,min}_{\alpha \in \mathbb{R}^{1 \times n}} \sum_{i=1}^{n} \left(y^{(i)} - \sum_{j=1}^{n} \alpha_j \langle \psi(x^{(j)}), \psi(x^{(i)}) \rangle\right)^2.$$

Importantly, this optimization problem only depends on the inner product of the feature map between training samples. Thus, instead of working with the feature map directly, we can define a kernel, i.e., a positive semi-definite, symmetric function $K: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}$, such that $K(x^{(j)}, x^{(i)}) = \langle \psi(x^{(j)}), \psi(x^{(i)}) \rangle$. The resulting optimization problem is known as kernel regression and given as follows:

$$\operatorname*{arg\,min}_{\alpha \in \mathbb{R}^{1 \times n}} \|y - \alpha K_n\|_2^2,$$

where $K_n \in \mathbb{R}^{n \times n}$ with $(K_n)_{i,j} = K(x^{(i)}, x^{(j)})$. Importantly, the abstraction to kernel methods allows using feature maps that map into an infinite-dimensional inner product space (i.e., a Hilbert space), which are central to the study of infinite-width neural networks.
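The equivalence between the feature-map and kernel formulations can be checked numerically. The sketch below uses a random ReLU feature map as a hypothetical finite-dimensional $\psi$ (any feature map would do) and compares predictions of the minimum-norm $w$ with those of kernel regression, which only uses inner products. The two agree because the minimum-norm $w$ lies in the span of the $\psi(x^{(i)})$. Columns are samples, as in this appendix.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, m = 10, 20, 500                      # m >> n: many more features than samples
W = rng.standard_normal((m, d))            # hypothetical random ReLU feature map


def psi(X):
    """Feature map psi: R^d -> R^m applied to each column of X."""
    return np.maximum(W @ X, 0.0)


X = rng.standard_normal((d, n))            # training examples as columns
y = rng.standard_normal((1, n))            # labels, shape (1, n)
X_test = rng.standard_normal((d, 3))

# Route 1: explicit features, minimum-norm solution of ||y - w^T psi(X)||^2
w = (y @ np.linalg.pinv(psi(X))).T         # shape (m, 1)
pred_features = w.T @ psi(X_test)

# Route 2: kernel regression, using only inner products <psi(x), psi(x')>
K_n = psi(X).T @ psi(X)                    # (n, n) kernel matrix
alpha = y @ np.linalg.pinv(K_n)            # minimum-norm solution of ||y - alpha K_n||^2
pred_kernel = alpha @ (psi(X).T @ psi(X_test))
```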

In addition to kernel regression described above, we also consider kernel ridge regression, which adds a tunable ridge parameter $\lambda$ to the diagonal of the kernel matrix as follows:

$$\operatorname*{arg\,min}_{\alpha \in \mathbb{R}^{1 \times n}} \|y - \alpha(K_n + \lambda I_{n \times n})\|_2^2.$$

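The minimizer is $\alpha = y(K_n + \lambda I_{n \times n})^{-1}$, which coincides with the standard kernel ridge regression solution. We primarily use a small non-zero ridge term to avoid numerical issues arising from a singular (or non-invertible) kernel matrix, $K_n$.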
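A minimal sketch of kernel ridge regression with the Laplace kernel, using the bandwidth $L = 10$ and ridge $\lambda = 10^{-4}$ reported in Appendix C as defaults. Since $K_n$ is symmetric, $\alpha(K_n + \lambda I) = y$ is solved as a linear system rather than by forming an explicit inverse. Function names are illustrative, and columns are samples as in this appendix.

```python
import numpy as np


def laplace_kernel(A, B, bandwidth=10.0):
    """Laplace kernel exp(-||a - b||_2 / L) between columns of A (d, n_A) and B (d, n_B)."""
    dists = np.linalg.norm(A[:, :, None] - B[:, None, :], axis=0)
    return np.exp(-dists / bandwidth)


def kernel_ridge_fit(X, y, bandwidth=10.0, ridge=1e-4):
    """Solve alpha (K_n + ridge * I) = y for alpha of shape (c, n)."""
    K_n = laplace_kernel(X, X, bandwidth)
    # K_n is symmetric, so transposing gives (K_n + ridge * I) alpha^T = y^T
    return np.linalg.solve(K_n + ridge * np.eye(K_n.shape[0]), y.T).T


def kernel_predict(alpha, X_train, X_new, bandwidth=10.0):
    """Predictions alpha K(X_train, X_new), shape (c, n_new)."""
    return alpha @ laplace_kernel(X_train, X_new, bandwidth)
```

Setting `ridge=0.0` recovers plain kernel regression whenever $K_n$ is invertible, which holds for the Laplace kernel on distinct points.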

Appendix B Overview of Image Classification Datasets

For projection, we used ImageNet32 as the source dataset and CIFAR10, Oxford 102 Flowers, DTD, and a subset of SVHN as the target datasets. For all target datasets, we used the training and test splits given by the PyTorch library [39]. For ImageNet32, we used the training and test splits provided by the authors [11]. An overview of the number of training and test samples used from each of these datasets is outlined below.

  1. ImageNet32 contains 1,281,167 training images across 1000 classes and 50k images for validation. All images are of size $32 \times 32 \times 3$.

  2. CIFAR10 contains 50k training images across 10 classes and 10k images for validation. All images are of size $32 \times 32 \times 3$.

  3. Oxford 102 Flowers contains 1020 training images across 102 classes and 6149 images for validation. Images were resized to $32 \times 32 \times 3$ for the experiments.

  4. DTD contains 1880 training images across 47 classes and 1880 images for validation. Images were resized to $32 \times 32 \times 3$ for the experiments.

  5. SVHN contains 73257 training images across 10 classes and 26032 images for validation. All images are of size $32 \times 32 \times 3$. In Fig. 2, we used the same 500-image training subset for all experiments.

Appendix C Training and Architecture Details

Model descriptions:

  1. Laplace Kernel: For samples $x,\tilde{x}$ and bandwidth parameter $L$, the kernel is of the form

     $$\exp\left(-\frac{\|x-\tilde{x}\|_2}{L}\right).$$

     For our experiments, we used a bandwidth of $L=10$.

  2. NTK: We used the NTK corresponding to an infinite-width ReLU fully connected network with 5 hidden layers. We chose this depth as it gave superior performance on the image classification tasks considered in [34].

  3. CNTK: We used the CNTK corresponding to an infinite-width ReLU convolutional network with 6 convolutional layers followed by a fully connected layer. All convolutional layers used filters of size 3×3. The first 5 convolutional layers used a stride of 2 to downsample the image representations. All convolutional layers used zero padding. The CNTK was computed using the Neural Tangents library [36].

  4. CNN: We compare the CNTK to a finite-width CNN of the same architecture that has 16 filters in the first layer, 32 filters in the second layer, 64 filters in the third layer, 128 filters in the fourth layer, and 256 filters in the fifth and sixth layers. In all experiments, the CNN was trained using Adam with a learning rate of $10^{-4}$.
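A minimal NumPy sketch of the Laplace kernel defined above, with the bandwidth $L=10$ from our experiments as the default. Treating each image as a flattened row vector is an assumption about preprocessing, and the function name is illustrative.

```python
import numpy as np


def laplace_kernel(X, Z, bandwidth=10.0):
    """Laplace kernel matrix: K[i, j] = exp(-||X[i] - Z[j]||_2 / bandwidth).

    X: (n, p) and Z: (m, p), one sample per row (e.g. flattened 32x32x3 images).
    """
    # Expand ||x - z||^2 = ||x||^2 + ||z||^2 - 2<x, z>; clip tiny negatives from round-off.
    sq = (X ** 2).sum(axis=1)[:, None] + (Z ** 2).sum(axis=1)[None, :] - 2.0 * X @ Z.T
    return np.exp(-np.sqrt(np.maximum(sq, 0.0)) / bandwidth)
```

The expanded form avoids materializing an $n\times m\times p$ array of differences, which matters for image-sized inputs.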

Details for projection experiments.

For all kernels trained on ImageNet32, we used EigenPro [31]. For all models, we trained until the training accuracy was greater than 99%, which took at most 6 epochs of EigenPro. For transfer learning to CIFAR10, Oxford 102 Flowers, DTD, and SVHN, we applied a Laplace kernel to the outputs of the trained source model. For CIFAR10 and DTD, we solved the kernel regression exactly using NumPy [38]. For DTD and SVHN, we used ridge regularization with a coefficient of $10^{-4}$ to avoid numerical issues with solving exactly. The CNN was trained for at most 500 epochs on ImageNet32, and the transferred model corresponded to the one with the highest validation accuracy during this time. When transfer learning, we fine-tuned all layers of the CNN for up to 200 epochs (again selecting the model with the highest validation accuracy on the target task).
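The projection step described above (a Laplace kernel applied to the outputs of the trained source model) can be sketched as follows. This is an illustrative NumPy version under stated assumptions: `source_predict` is a hypothetical stand-in for the trained source model, samples are rows, labels are one-hot rows, and the default ridge of $10^{-4}$ mirrors the value used for DTD and SVHN. It is not the released implementation.

```python
import numpy as np


def laplace_kernel(X, Z, bandwidth=10.0):
    sq = (X ** 2).sum(1)[:, None] + (Z ** 2).sum(1)[None, :] - 2.0 * X @ Z.T
    return np.exp(-np.sqrt(np.maximum(sq, 0.0)) / bandwidth)


def project(source_predict, X_t, y_t, bandwidth=10.0, ridge=1e-4):
    """Fit a Laplace-kernel regressor on the source-model outputs f_s(X_t).

    source_predict: hypothetical callable mapping (n, p) inputs to (n, c_s) outputs.
    y_t: (n_t, c_t) target labels (e.g. one-hot rows).
    Returns a callable giving (m, c_t) target predictions.
    """
    F_t = source_predict(X_t)                     # source outputs become the features
    K = laplace_kernel(F_t, F_t, bandwidth)
    alpha = np.linalg.solve(K + ridge * np.eye(len(F_t)), y_t)

    def predict(X):
        return laplace_kernel(source_predict(X), F_t, bandwidth) @ alpha

    return predict
```

Only the source model's outputs enter the target kernel, so the target label set can differ from the source label set.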

Details for translation experiments.

For transferring kernels from CIFAR10 to CIFAR-C, we simply solved kernel regression exactly (no ridge regularization term). For the corresponding CNNs, we trained the source models on CIFAR10 for 100 epochs and selected the model with the best validation performance. When transferring CNNs to CIFAR-C, we fine-tuned all layers of the CNN for 200 epochs and selected the model with the best validation accuracy. When translating kernels from ImageNet32 to CIFAR10 in Appendix Fig. 9, we used the following aggregated class indices in ImageNet32 to match the classes in CIFAR10:

  1. plane = {372, 230, 231, 232}

  2. car = {265, 266, 267, 268}

  3. bird = {383, 384, 385, 386}

  4. cat = {8, 10, 11, 55}

  5. deer = {12, 9, 57}

  6. dog = {131, 132, 133, 134}

  7. frog = {499, 500, 501, 494}

  8. horse = {80, 39}

  9. ship = {243, 246, 247, 235}

  10. truck = {279, 280, 281, 282}.
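To illustrate translation, the sketch below adds to the source predictor a kernel regressor fit on the target residuals $y_t-\hat f_s(X_t)$, which is the correction term characterized in Appendix J. It solves kernel regression exactly by default (no ridge), as in the CIFAR10 to CIFAR-C experiments. Using the Laplace kernel for both the source model and the correction, the rows-as-samples layout, and all helper names are assumptions made for this example.

```python
import numpy as np


def laplace_kernel(X, Z, bandwidth=10.0):
    sq = (X ** 2).sum(1)[:, None] + (Z ** 2).sum(1)[None, :] - 2.0 * X @ Z.T
    return np.exp(-np.sqrt(np.maximum(sq, 0.0)) / bandwidth)


def fit_kernel_regression(X, Y, kernel, ridge=0.0):
    """Kernel (ridge) regression with samples as rows; returns a predictor."""
    alpha = np.linalg.solve(kernel(X, X) + ridge * np.eye(len(X)), Y)
    return lambda Z: kernel(Z, X) @ alpha


def translate(source_predict, X_t, y_t, kernel, ridge=0.0):
    """Translated predictor: source predictions plus a correction fit on target residuals."""
    residual = y_t - source_predict(X_t)
    correction = fit_kernel_regression(X_t, residual, kernel, ridge)
    return lambda Z: source_predict(Z) + correction(Z)
```

With `ridge=0.0` the translated predictor interpolates the target data, and by Proposition 1 in Appendix J the correction is the minimum-norm one in the kernel's feature space.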

Details for virtual drug screening.

We used the NTK corresponding to a 1-hidden-layer ReLU fully connected network with an offset term. The same model was used in [42]. We solved kernel ridge regression when training the source models, baseline models, and transferred models. For the source model, we used ridge regularization with a coefficient of 1000. To select this ridge term, we used a grid search over $\{1,10,100,1000,10000\}$ on a random subset of 10k samples from the source data. We used a ridge term of 1000 when transferring the source model to the target data and a term of 100 when training the baseline model. We again tuned the ridge parameter for these models over the same set of values, but on a random subset of 1000 examples for one cell line (A549) from the target data. We used 5-fold cross-validation for the target task and reported the metrics computed across all folds.
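As an illustration only, the sketch below evaluates a one-hidden-layer ReLU NTK and selects the ridge coefficient from the grid $\{1,10,100,1000,10000\}$ by $k$-fold cross-validated mean squared error. We model the offset term by appending a constant coordinate to each input and use the standard arc-cosine form of the NTK up to a constant scaling; both choices, the use of cross-validation for the selection step, and all names are assumptions and need not match the exact kernel of [42].

```python
import numpy as np


def relu_ntk_1layer(X, Z, offset=1.0):
    """NTK of a one-hidden-layer ReLU network (arc-cosine form, up to scaling).

    The offset is modeled by appending a constant coordinate (an assumption).
    """
    Xa = np.hstack([X, np.full((len(X), 1), offset)])
    Za = np.hstack([Z, np.full((len(Z), 1), offset)])
    nx = np.linalg.norm(Xa, axis=1)[:, None]
    nz = np.linalg.norm(Za, axis=1)[None, :]
    dot = Xa @ Za.T
    cos = np.clip(dot / (nx * nz), -1.0, 1.0)
    theta = np.arccos(cos)
    # sigma1: covariance of ReLU features; sigma1_dot: covariance of ReLU derivatives.
    sigma1 = nx * nz * (np.sin(theta) + (np.pi - theta) * cos) / np.pi
    sigma1_dot = (np.pi - theta) / np.pi
    return sigma1 + dot * sigma1_dot


def select_ridge(X, Y, ridges=(1, 10, 100, 1000, 10000), n_folds=5, seed=0):
    """Return the ridge with the lowest k-fold validation MSE, plus all scores."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(X)), n_folds)
    scores = {}
    for lam in ridges:
        errs = []
        for k in range(n_folds):
            val = folds[k]
            tr = np.concatenate([folds[j] for j in range(n_folds) if j != k])
            K = relu_ntk_1layer(X[tr], X[tr])
            alpha = np.linalg.solve(K + lam * np.eye(len(tr)), Y[tr])
            pred = relu_ntk_1layer(X[val], X[tr]) @ alpha
            errs.append(np.mean((pred - Y[val]) ** 2))
        scores[lam] = float(np.mean(errs))
    return min(scores, key=scores.get), scores
```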

Appendix D Projection Scaling Laws

For the curves showing the performance of the projected predictor as a function of the number of target examples in Fig. 2b and Appendix Fig. 6a, b, we performed a scaling law analysis. In particular, we used linear regression to fit the coefficients $a,b$ of the function $y=a\log_2 x+b$ to the points from each of the curves presented in the figures. Each curve in these figures has 50 evenly spaced points, and all accuracies are averaged over 3 seeds at each point. The $R^2$ values for each of the fits are presented in Appendix Fig. 8. Overall, we observe that all values are above 0.944 and are higher than 0.99 for CIFAR10 and SVHN, which have more than 2000 target training samples. Moreover, by fitting the same function on the first 5 points from these curves for CIFAR10, we are able to predict the accuracy on the last point of the curve within 2% of the reported accuracy.
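The fit above is ordinary least squares on $(\log_2 x, y)$. A minimal NumPy sketch, including the extrapolation check of fitting on the first few points and predicting the last one; the function names are illustrative.

```python
import numpy as np


def fit_log2_scaling(n_examples, accuracy):
    """Fit accuracy ~ a * log2(n_examples) + b by least squares; return (a, b, R^2)."""
    x = np.log2(np.asarray(n_examples, dtype=float))
    y = np.asarray(accuracy, dtype=float)
    a, b = np.polyfit(x, y, deg=1)          # coefficients, highest degree first
    residual = y - (a * x + b)
    r2 = 1.0 - np.sum(residual ** 2) / np.sum((y - y.mean()) ** 2)
    return a, b, r2


def extrapolate(n_examples, accuracy, n_fit=5):
    """Fit on the first n_fit points and predict the accuracy at the last point."""
    a, b, _ = fit_log2_scaling(n_examples[:n_fit], accuracy[:n_fit])
    return a * np.log2(n_examples[-1]) + b
```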

Appendix E Pre-processing for CMAP Data

While CMAP contains 978 landmark genes, we removed all genes that were 1 upon $\log_2(x+1)$ scaling of the data. This eliminates 135 genes and removes batch effects identified in [7] for each cell line. Following the methodology of [7], we also removed all perturbations with dose less than 0 and used only the perturbations that had an associated simplified molecular-input line-entry system (SMILES) string, which resulted in a total of 20,336 perturbations. Following [7], for each of the 116,228 observed drug and cell type combinations, we then averaged the gene expression over all the replicates.
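The replicate-averaging step above, together with the $\log_2(x+1)$ scaling, can be sketched in NumPy as below. The array layout (one row per replicate, with parallel arrays of drug and cell-line identifiers) and the function name are assumptions; the gene and dose filters are omitted.

```python
import numpy as np


def average_replicates(expr, drug_ids, cell_ids):
    """Average log2(x + 1)-scaled expression over replicates of each (drug, cell) pair.

    expr: (n_replicates, n_genes) raw expression; drug_ids, cell_ids: (n_replicates,).
    Returns the (drug, cell) label of each pair and the (n_pairs, n_genes) averages.
    """
    scaled = np.log2(np.asarray(expr, dtype=float) + 1.0)
    _, d_code = np.unique(drug_ids, return_inverse=True)
    _, c_code = np.unique(cell_ids, return_inverse=True)
    d_code, c_code = d_code.ravel(), c_code.ravel()
    # Encode each (drug, cell) pair as one integer so np.unique can group 1-D keys.
    key = d_code * (c_code.max() + 1) + c_code
    _, first, group = np.unique(key, return_index=True, return_inverse=True)
    group = group.ravel()
    sums = np.zeros((len(first), scaled.shape[1]))
    np.add.at(sums, group, scaled)                   # accumulate replicate rows per pair
    counts = np.bincount(group, minlength=len(first))
    pairs = list(zip(np.asarray(drug_ids)[first], np.asarray(cell_ids)[first]))
    return pairs, sums / counts[:, None]
```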

Appendix F Metrics for Evaluating Virtual Drug Screening

Let $\hat{y}\in\mathbb{R}^{n\times d}$ denote the predicted gene expression vectors and let $y^*\in\mathbb{R}^{n\times d}$ denote the ground truth. Let $\bar{y}^{(i)}=\frac{1}{d}\sum_{j=1}^d{y^*_j}^{(i)}$ denote the mean of the $i$-th ground-truth vector. Let $\hat{y}_v,y^*_v\in\mathbb{R}^{dn}$ denote vectorized versions of $\hat{y}$ and $y^*$. We use the same three metrics as those considered in [42, 23]. All evaluation metrics have a maximum value of 1 and are defined below.

1. Pearson r value:

$$r=\frac{\langle\hat{y}_v,y^*_v\rangle}{\|\hat{y}_v\|_2\|y^*_v\|_2}.$$

2. Mean $R^2$:

$$R^2=\frac{1}{n}\sum_{i=1}^n\left(1-\frac{\sum_{j=1}^d(\hat{y}_j^{(i)}-{y^*_j}^{(i)})^2}{\sum_{j=1}^d({y^*_j}^{(i)}-\bar{y}^{(i)})^2}\right).$$

3. Mean Cosine Similarity:

$$c=\frac{1}{n}\sum_{i=1}^n\frac{\langle\hat{y}^{(i)},{y^*}^{(i)}\rangle}{\|\hat{y}^{(i)}\|_2\|{y^*}^{(i)}\|_2}.$$

We additionally subtract out the mean over cell type before computing cosine similarity to avoid inflated cosine similarity arising from points far from the origin.
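The three metrics can be computed with a few lines of NumPy. This sketch follows the formulas as written (in particular, $r$ is computed on the vectorized arrays without additional centering). For the cosine similarity, we assume that subtracting the mean over cell type means removing, within each cell type, the mean ground-truth expression vector from both predictions and targets. Function names are illustrative.

```python
import numpy as np


def pearson_r(y_hat, y_true):
    """r as written above: inner product of the vectorized arrays over their norms."""
    a, b = np.ravel(y_hat), np.ravel(y_true)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def mean_r2(y_hat, y_true):
    """Average over samples (rows) of the per-sample coefficient of determination."""
    y_bar = y_true.mean(axis=1, keepdims=True)
    num = ((y_hat - y_true) ** 2).sum(axis=1)
    den = ((y_true - y_bar) ** 2).sum(axis=1)
    return float(np.mean(1.0 - num / den))


def mean_cosine(y_hat, y_true, cell_types=None):
    """Mean row-wise cosine similarity, optionally after per-cell-type centering."""
    y_hat = np.array(y_hat, dtype=float)    # copies, so the inputs are not modified
    y_true = np.array(y_true, dtype=float)
    if cell_types is not None:
        cell_types = np.asarray(cell_types)
        for ct in np.unique(cell_types):
            rows = cell_types == ct
            mu = y_true[rows].mean(axis=0)  # assumed: ground-truth mean of this cell type
            y_hat[rows] -= mu
            y_true[rows] -= mu
    num = (y_hat * y_true).sum(axis=1)
    den = np.linalg.norm(y_hat, axis=1) * np.linalg.norm(y_true, axis=1)
    return float(np.mean(num / den))
```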

Appendix G DepMap Analysis

To provide another application of our framework in the context of virtual drug screening, we used projection to transfer the kernel methods trained on imputing gene expression vectors in CMAP to predicting the viability of a drug and cell line combination in DepMap [13]. Viability scores in DepMap are real values indicating how lethal a drug is for a given cancer cell line (negative viability indicates cell death). To transfer from CMAP to DepMap, we trained a kernel method to predict the gene expression vectors for 55,462 cell line and drug combinations for the 64 cell lines from CMAP that do not overlap with DepMap. We then used projection to transfer the model to the 6 held-out cell lines present in both CMAP and DepMap, which are PC3, MCF7, A375, A549, HT29, and HEPG2. Analogously to our analysis of CMAP, we stratified the target dataset into drugs that appear in both the source and target tasks (9726 target samples) and drugs that are only found in the target task but not in the source task (2685 target samples). For this application, we found that Mol2Vec [26] embeddings of drugs outperformed 1024-bit circular fingerprints. We again used a 1-hidden-layer ReLU NTK with an offset term for this analysis and solved kernel ridge regression with a ridge coefficient of 100.

Appendix Fig. 12a shows the performance of the projected predictor as a function of the number of target samples when transferring to a target task with drugs that appear in the source task. All results are averaged over 5 folds of cross-validation and across 5 random seeds for the subset of target samples considered in each fold. Performance is greatly improved when there are fewer than 2000 samples, highlighting the benefit of the imputed gene expression vectors in this setting. Interestingly, as in all the previous experiments, we find a clear logarithmic scaling law: fitting the coefficients of the curve $y=a\log_2 x+b$ to the 76 points on the graph yields an $R^2$ of 0.994, and fitting the curve to the first 10 points lets us predict the $R^2$ for the last point on the curve within 0.03. Appendix Fig. 12b shows how the performance on the target task is affected by the number of genes predicted in the source task. Again, performance is averaged over 5-fold cross-validation and across 5 seeds per fold. When transferring to drugs that were available in the source task, performance increases monotonically as more genes are predicted. On the other hand, when transferring to drugs that were not available in the source task, performance begins to degrade as the number of predicted genes increases. This is intuitive, since not all genes would be useful for predicting the effect of an unseen drug, and the extra genes could add noise to the prediction problem upon transfer learning.

Appendix H Proof of Theorem 1

The proof of Theorem 1 relies on the following lemma.

Lemma 1.

Let $D,\Lambda\in\{0,1\}^{d\times d}$ be two diagonal matrices of rank $p$ and $q$, respectively. Let $V\in\mathbb{R}^{d\times d}$ be an orthogonal matrix and $W\in\mathbb{R}^{d\times d}$ a Haar-distributed random orthogonal matrix. If $P=WDW^T$ and $Q=V\Lambda V^T$, then

$$\mathbb{E}_W\left[PQP\right]=\frac{p}{d(d-1)(d+2)}\left[q(d-p)I_{d\times d}+\left[d(p+1)-2\right]Q\right].$$
Proof.

Without loss of generality, assume that $\Lambda=\mathrm{diag}(\mathbf{1}_q,\mathbf{0}_{d-q})$ and $D=\mathrm{diag}(\mathbf{1}_p,\mathbf{0}_{d-p})$. Since the Haar distribution is rotationally invariant, $U=V^TW$ is Haar distributed. Therefore,

$$\begin{aligned}\mathbb{E}_W\left[PQP\right]&=\mathbb{E}_{W\sim\mathrm{Haar}}\left[WDW^TV\Lambda V^TWDW^T\right]\\&=V\,\mathbb{E}_{U\sim\mathrm{Haar}}\left[UDU^T\Lambda UDU^T\right]V^T\\&=V\,\mathbb{E}_{U\sim\mathrm{Haar}}\left[A^TA\right]V^T,\end{aligned}$$

where $A=\Lambda UDU^T$. Now the upper left $q\times p$ block of $\Lambda UD$ is equal to the corresponding block in $U$, and all other entries of $\Lambda UD$ are 0. Letting $\tilde{u}_i=(u_{i1},\dots,u_{ip})$, we have

$$A=\begin{pmatrix}\langle\tilde{u}_1,\tilde{u}_1\rangle&\cdots&\langle\tilde{u}_1,\tilde{u}_d\rangle\\\vdots&\ddots&\vdots\\\langle\tilde{u}_q,\tilde{u}_1\rangle&\cdots&\langle\tilde{u}_q,\tilde{u}_d\rangle\\0&\cdots&0\\\vdots&\ddots&\vdots\\0&\cdots&0\end{pmatrix}.$$

Thus, $(A^TA)_{i,r}=\sum_{k=1}^q\langle\tilde{u}_i,\tilde{u}_k\rangle\langle\tilde{u}_r,\tilde{u}_k\rangle$, and so $\mathbb{E}_U[A^TA]$ only depends on the fourth moments of the entries of $U$. In particular,

$$\mathbb{E}_U\left[(A^TA)_{i,r}\right]=\sum_{\alpha=1}^q\sum_{j,s=1}^p\mathbb{E}_U\left[u_{ij}u_{rs}u_{\alpha j}u_{\alpha s}\right].$$

To calculate these moments, we use Lemma 9 from [10]. In particular, if $i\neq r$, then $\mathbb{E}_U[(A^TA)_{i,r}]=0$, and if $i=r$, then

$$\mathbb{E}_U\left[u_{ij}u_{is}u_{\alpha j}u_{\alpha s}\right]=\frac{1}{d(d-1)(d+2)}\left[d\delta_{js}+d\delta_{i\alpha}+(d-2)\delta_{i\alpha}\delta_{js}-1\right].$$

Therefore, we have the following closed form for the expectation:

$$\mathbb{E}_U\left[A^TA\right]=\frac{p}{d(d-1)(d+2)}\left[q(d-p)I_{d\times d}+\left[d(p+1)-2\right]\Lambda\right].$$

Indeed, summing the moments over $\alpha$, $j$, and $s$ shows that the $(i,i)$ entry equals the prefactor times $q(d-p)+\left[d(p+1)-2\right]\Lambda_{ii}$.

Since $\mathbb{E}_W[PQP]=V\,\mathbb{E}_U[A^TA]\,V^T$, $VV^T=I_{d\times d}$, and $V\Lambda V^T=Q$, the result follows. ∎
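Lemma 1 can be sanity-checked numerically. The sketch below samples Haar-distributed orthogonal matrices via the QR decomposition of a Gaussian matrix (with the standard sign correction) and compares the empirical mean of $PQP$ to the closed form. The dimensions and sample count are arbitrary choices for illustration.

```python
import numpy as np


def haar_orthogonal(d, rng):
    """Sample a Haar-distributed d x d orthogonal matrix (QR of a Gaussian matrix)."""
    Z = rng.standard_normal((d, d))
    Q, R = np.linalg.qr(Z)
    # Fix the sign ambiguity of QR so that the distribution is exactly Haar.
    return Q * np.sign(np.diag(R))


def lemma1_monte_carlo(d=6, p=2, q=3, n_samples=10000, seed=0):
    """Return (empirical E[PQP], closed form from Lemma 1)."""
    rng = np.random.default_rng(seed)
    D = np.diag([1.0] * p + [0.0] * (d - p))
    Lam = np.diag([1.0] * q + [0.0] * (d - q))
    V = haar_orthogonal(d, rng)          # any fixed orthogonal V works
    Q = V @ Lam @ V.T
    total = np.zeros((d, d))
    for _ in range(n_samples):
        W = haar_orthogonal(d, rng)
        P = W @ D @ W.T
        total += P @ Q @ P
    empirical = total / n_samples
    closed_form = p / (d * (d - 1) * (d + 2)) * (
        q * (d - p) * np.eye(d) + (d * (p + 1) - 2) * Q)
    return empirical, closed_form
```

Taking traces gives a quick consistency check: both sides have trace $pq/d$, since $\mathbb{E}[P]=\frac{p}{d}I_{d\times d}$.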

We now prove the following simpler version of Theorem 1 for the case when $n_s=d$ (i.e., when $\hat{\omega}_s=\omega_s$).

Theorem 3.

Let $\mathcal{X}=\mathbb{R}^d$, $\mathcal{Y}_s=\mathbb{R}^{c_s}$, $\mathcal{Y}_t=\mathbb{R}^{c_t}$, and let $\hat{\omega}_p=y_t(\omega_sX_t)^\dagger$. Assuming $\mathbb{P}_s,\mathbb{P}_t$ are independent, isotropic distributions on $\mathbb{R}^d$, then

$$\mathcal{R}(\hat{\omega}_p\omega_s)=\left(1-n_t\frac{d-c_s}{(d+2)(d-1)}\right)\mathcal{R}(\hat{\omega}_t)+n_t\left(\frac{1}{d}+\frac{d-n_t}{(d+2)(d-1)}\right)\varepsilon,$$

where $\varepsilon=\|\omega_t(I_{d\times d}-\omega_s^\dagger\omega_s)\|_F^2$.

Proof.

Let $X=X_t$ to simplify notation. Let $\omega_s^\parallel=\omega_s^\dagger\omega_s$, $\omega_s^\perp=I_{d\times d}-\omega_s^\parallel$, $X^\parallel=XX^\dagger$, and $X^\perp=I_{d\times d}-X^\parallel$, and note that $\omega_s^\perp,\omega_s^\parallel,X^\perp,X^\parallel\in\mathbb{R}^{d\times d}$ are all orthogonal projections. Then, we have:

$$\hat{\omega}_p\omega_s=y_t(\omega_sX)^\dagger\omega_s=\omega_tXX^\dagger\omega_s^\dagger\omega_s=\omega_tX^\parallel\omega_s^\parallel.$$

Therefore,

$$\begin{aligned}\hat{\omega}_p\omega_s-\omega_t&=\omega_t\left(X^\parallel\omega_s^\parallel-I_{d\times d}\right)\\&=\omega_t\left[\left(X^\parallel-I_{d\times d}\right)\omega_s^\parallel+\omega_s^\parallel-I_{d\times d}\right]\\&=-\omega_t\left[X^\perp\omega_s^\parallel+\omega_s^\perp\right].\end{aligned}$$

Using the cyclic property of the trace and the isotropy of $x$ (i.e., $\mathbb{E}_x[xx^T]=I_{d\times d}$), the risk is given by

$$\begin{aligned}\mathcal{R}(\hat{\omega}_p\omega_s)&=\mathbb{E}_{X,x}\left[\left\|\hat{\omega}_p\omega_sx-\omega_tx\right\|^2\right]\\&=\operatorname{Tr}\left(\mathbb{E}_{X,x}\left[\omega_t\left(X^\perp\omega_s^\parallel+\omega_s^\perp\right)xx^T\left(X^\perp\omega_s^\parallel+\omega_s^\perp\right)^T\omega_t^T\right]\right)\\&=\operatorname{Tr}\left(\mathbb{E}_X\left[\omega_t\left(X^\perp\omega_s^\parallel+\omega_s^\perp\right)\left(X^\perp\omega_s^\parallel+\omega_s^\perp\right)^T\omega_t^T\right]\right).\end{aligned}$$

Using the symmetry and idempotence of projections and the fact that $\omega_s^\perp\omega_s^\parallel=\omega_s^\parallel\omega_s^\perp=\mathbf{0}$, we conclude that

$$\begin{aligned}\left(X^\perp\omega_s^\parallel+\omega_s^\perp\right)\left(X^\perp\omega_s^\parallel+\omega_s^\perp\right)^T&=X^\perp\omega_s^\parallel\omega_s^\parallel X^\perp+X^\perp\omega_s^\parallel\omega_s^\perp+\omega_s^\perp\omega_s^\parallel X^\perp+\omega_s^\perp\omega_s^\perp\\&=X^\perp\omega_s^\parallel X^\perp+\omega_s^\perp,\end{aligned}$$

and as a consequence that

$$\begin{aligned}\mathcal{R}(\hat{\omega}_p\omega_s)&=\operatorname{Tr}\left(\mathbb{E}_X\left[\omega_t\left(X^\perp\omega_s^\parallel X^\perp+\omega_s^\perp\right)\omega_t^T\right]\right)\\&=\operatorname{Tr}\left(\omega_t\left(\mathbb{E}_X\left[X^\perp\omega_s^\parallel X^\perp\right]+\omega_s^\perp\right)\omega_t^T\right).\end{aligned}$$

Both $X^\perp$ and $\omega_s^\parallel$ are projections, and since $X$ follows an isotropic distribution, its left singular vectors (which form an eigenbasis of $X^\perp$) are Haar distributed. Now, using Lemma 1 with $p=d-n_t$ and $q=c_s$, we obtain

$$\mathbb{E}_X\left[X^\perp\omega_s^\parallel X^\perp\right]=\left(1-\frac{n_t}{d}\right)\left[\frac{c_sn_t}{(d-1)(d+2)}I_{d\times d}+\left(1-\frac{dn_t}{(d-1)(d+2)}\right)\omega_s^\parallel\right].$$

Using $\omega_s^\parallel=I_{d\times d}-\omega_s^\perp$ and reordering the terms, we obtain

$$\begin{aligned}\mathbb{E}_X\left[X^\perp\omega_s^\parallel X^\perp\right]+\omega_s^\perp&=\left(1-\frac{n_t}{d}\right)\left(1-\frac{n_t(d-c_s)}{(d+2)(d-1)}\right)I_{d\times d}\\&\quad+\left(\frac{n_t}{d}+\frac{n_t(d-n_t)}{(d+2)(d-1)}\right)\omega_s^\perp.\end{aligned}$$

Lastly, we use the standard result that the baseline predictor $\hat{\omega}_t=y_tX_t^\dagger$ has risk $\mathcal{R}(\hat{\omega}_t)=\left(1-\frac{n_t}{d}\right)\|\omega_t\|_F^2=\left(1-\frac{n_t}{d}\right)\operatorname{Tr}\left(\omega_t\omega_t^T\right)$ (see e.g. [6]) and that $\varepsilon=\operatorname{Tr}\left(\omega_t\omega_s^\perp\omega_t^T\right)$ to conclude that

$$\mathcal{R}(\hat{\omega}_p\omega_s)=\left(1-n_t\frac{d-c_s}{(d+2)(d-1)}\right)\mathcal{R}(\hat{\omega}_t)+n_t\left(\frac{1}{d}+\frac{d-n_t}{(d+2)(d-1)}\right)\varepsilon,$$

which completes the proof. ∎

Using Lemma 1 and Theorem 3, we next prove Theorem 1, which is restated below for the reader’s convenience.

Theorem.

Let $\mathcal{X}=\mathbb{R}^d$, $\mathcal{Y}_s=\mathbb{R}^{c_s}$, $\mathcal{Y}_t=\mathbb{R}^{c_t}$, and let $\hat{\omega}_s=y_sX_s^\dagger$ and $\hat{\omega}_p=y_t(\hat{\omega}_sX_t)^\dagger$. Assuming that $\mathbb{P}_s$ and $\mathbb{P}_t$ are independent, isotropic distributions on $\mathbb{R}^d$, the risk $\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)$ is given by

$$\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)=\left[\left(C_1+C_2K_1\right)\left(1-\frac{n_t}{d}\right)+\left(1-C_1-C_2\right)\right]\|\omega_t\|_F^2+C_2K_2\varepsilon,$$

where $\varepsilon=\|\omega_t(I_{d\times d}-\omega_s^\dagger\omega_s)\|_F^2$ and

$$\begin{aligned}C_1&=\frac{n_sc_s(d-n_s)}{d(d-1)(d+2)},\quad C_2=\frac{n_s\left[d(n_s+1)-2\right]}{d(d-1)(d+2)},\\K_1&=1-\frac{n_t(d-c_s)}{(d-1)(d+2)},\quad K_2=\frac{n_t}{d}+\frac{n_t(d-n_t)}{(d-1)(d+2)}.\end{aligned}$$
Proof.

As in the proof of Theorem 3, we let $X=X_t$ to simplify notation. We follow the proof of Theorem 3, but now account for the expectation with respect to $X_s$. Namely,

$$\begin{aligned}\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)&=\operatorname{Tr}\left(\mathbb{E}_{X_s,X}\left[\omega_t\left(X^\perp\hat{\omega}_s^\parallel X^\perp+\hat{\omega}_s^\perp\right)\omega_t^T\right]\right)\\&=\operatorname{Tr}\left(\omega_t\left(\mathbb{E}_{X_s,X}\left[X^\perp\hat{\omega}_s^\parallel X^\perp\right]+\mathbb{E}_{X_s}\left[\hat{\omega}_s^\perp\right]\right)\omega_t^T\right).\end{aligned}$$

Using the independence of $\hat{\omega}_s$ and $X^\perp$ and Fubini's theorem, we compute the expectations sequentially:

$$\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)=\operatorname{Tr}\left(\omega_t\left(\mathbb{E}_X\left[X^\perp\,\mathbb{E}_{X_s}\left[\hat{\omega}_s^\parallel\right]X^\perp\right]+\mathbb{E}_{X_s}\left[\hat{\omega}_s^\perp\right]\right)\omega_t^T\right).$$

Now, since $\hat{\omega}_s=\omega_sX_s^\parallel$, we have that $\hat{\omega}_s^\dagger=X_s^\parallel\omega_s^\dagger$. Therefore, $\hat{\omega}_s^\parallel=\hat{\omega}_s^\dagger\hat{\omega}_s=X_s^\parallel\omega_s^\parallel X_s^\parallel$. Similarly, $\hat{\omega}_s^\perp=I_{d\times d}-\hat{\omega}_s^\parallel$. As a consequence, we calculate the two expectations involving $X_s$ by using Lemma 1 with $p=n_s$ and $q=c_s$. In particular, we conclude that

$$\begin{aligned}\mathbb{E}_{X_s}\left[\hat{\omega}_s^\parallel\right]&=C_1I_{d\times d}+C_2\omega_s^\parallel=(C_1+C_2)I_{d\times d}-C_2\omega_s^\perp,\\\mathbb{E}_{X_s}\left[\hat{\omega}_s^\perp\right]&=(1-C_1-C_2)I_{d\times d}+C_2\omega_s^\perp.\end{aligned}$$

Therefore, $\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)$ is given by the sum of the following terms:

$$\begin{aligned}C_1\operatorname{Tr}\left(\omega_t\,\mathbb{E}_X\left[X^\perp\right]\omega_t^T\right)&=C_1\mathcal{R}(\hat{\omega}_t),\\C_2\operatorname{Tr}\left(\omega_t\,\mathbb{E}_X\left[X^\perp\omega_s^\parallel X^\perp\right]\omega_t^T\right)&=C_2K_1\mathcal{R}(\hat{\omega}_t)+C_2(K_2-1)\varepsilon,\\(1-C_1-C_2)\operatorname{Tr}\left(\omega_t\omega_t^T\right)+C_2\operatorname{Tr}\left(\omega_t\omega_s^\perp\omega_t^T\right)&=(1-C_1-C_2)\operatorname{Tr}\left(\omega_t\omega_t^T\right)+C_2\varepsilon,\end{aligned}$$

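where for the second equality we used the expression for $\mathbb{E}_X[X^\perp\omega_s^\parallel X^\perp]$ derived in the proof of Theorem 3, which gives rise to $K_1$ and $K_2$. Summing the three terms and using $\mathcal{R}(\hat{\omega}_t)=\left(1-\frac{n_t}{d}\right)\|\omega_t\|_F^2$ completes the proof. ∎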

Appendix I Proof of Corollary 1

Proof.

We restate Corollary 1 below for the reader’s convenience.

Corollary.

Let $S=\frac{n_s}{d}$, $T=\frac{n_t}{d}$, $C=\frac{c_s}{d}$, and assume $\|\omega_t\|_F=\Theta(1)$. Under the setting of Theorem 1, if $S,T,C<\infty$ as $d\to\infty$, then:

  a) $\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)$ is monotonically decreasing for $S\in[0,1]$ if $\varepsilon<(1-C)\|\omega_t\|_F^2$.

  b) If $2S-1-ST<0$, then $\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)$ decreases as $C$ increases.

  c) If $S=1$, then $\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)=(1-T+TC)\mathcal{R}(\hat{\omega}_t)+\varepsilon T(2-T)$.

  d) If $S=1$ and $T,C=\Theta(\delta)$, then $\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)=(1-2T)\|\omega_t\|_F^2+2T\varepsilon+\Theta(\delta^2)$.

We first derive the limiting forms of the terms $C_1,C_2,K_1,K_2$ from Theorem 1 as $d\to\infty$. In particular, we have:

$$C_1=SC-S^2C;\quad C_2=S^2;\quad K_1=1-T+TC;\quad K_2=2T-T^2.$$

Substituting these values into $\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)$ in Theorem 1, we obtain

$$\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)=\left[1-2S^2T+S^2T^2+(2S-1-ST)STC\right]\|\omega_t\|_F^2+S^2T[2-T]\varepsilon.\tag{5}$$

Next, we analyze Eq. (5) for $S\in[0,1]$. For fixed $T$ and $C$, $\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)$ is a quadratic in $S$ given by

$$\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)=(1-STC)\|\omega_t\|_F^2-S^2T[2-T][1-C]\|\omega_t\|_F^2+S^2T[2-T]\varepsilon.$$

For $S\in[0,1]$ and $\varepsilon<(1-C)\|\omega_t\|_F^2$, this quadratic is strictly decreasing, and thus we can conclude that $\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)$ is decreasing in $S$. We next observe that $\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)$ is linear in $C$, and thus $\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)$ decreases as $C$ increases if and only if the coefficient of $C$ is negative, i.e., $2S-ST-1<0$. Lastly, if $S=1$, then

$$\begin{aligned}\mathcal{R}(\hat{\omega}_p\hat{\omega}_s)&=\left[(1-T)^2+(1-T)TC\right]\|\omega_t\|_F^2+T[2-T]\varepsilon\\&=(1-T)(1-T+TC)\|\omega_t\|_F^2+T[2-T]\varepsilon\\&=(1-T+TC)\mathcal{R}(\hat{\omega}_t)+T[2-T]\varepsilon.\end{aligned}$$

Corollary 1c, d follow from the above form of the risk, thus completing the proof. ∎
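The limits used in this proof can also be checked numerically. The sketch below evaluates the finite-$d$ risk of Theorem 1 and the asymptotic expression in Eq. (5), and can be used to confirm that Theorem 1 reduces to Theorem 3 when $n_s=d$ (where $C_1=0$ and $C_2=1$). Here `norm_wt_sq` stands for $\|\omega_t\|_F^2$ and `eps` for $\varepsilon$; the function names are illustrative.

```python
def theorem1_risk(d, n_s, n_t, c_s, norm_wt_sq, eps):
    """Closed-form risk of the projected predictor from Theorem 1 (finite d)."""
    den = d * (d - 1) * (d + 2)
    C1 = n_s * c_s * (d - n_s) / den
    C2 = n_s * (d * (n_s + 1) - 2) / den
    K1 = 1 - n_t * (d - c_s) / ((d - 1) * (d + 2))
    K2 = n_t / d + n_t * (d - n_t) / ((d - 1) * (d + 2))
    return ((C1 + C2 * K1) * (1 - n_t / d) + (1 - C1 - C2)) * norm_wt_sq + C2 * K2 * eps


def theorem3_risk(d, n_t, c_s, norm_wt_sq, eps):
    """Theorem 3 (n_s = d), using R(baseline) = (1 - n_t/d) * ||omega_t||_F^2."""
    baseline = (1 - n_t / d) * norm_wt_sq
    return ((1 - n_t * (d - c_s) / ((d + 2) * (d - 1))) * baseline
            + n_t * (1 / d + (d - n_t) / ((d + 2) * (d - 1))) * eps)


def corollary_limit(S, T, C, norm_wt_sq, eps):
    """Eq. (5): the d -> infinity limit with S = n_s/d, T = n_t/d, C = c_s/d."""
    return ((1 - 2 * S**2 * T + S**2 * T**2 + (2 * S - 1 - S * T) * S * T * C) * norm_wt_sq
            + S**2 * T * (2 - T) * eps)
```

For example, with $d=10^6$ and $(n_s,n_t,c_s)=(5\cdot10^5,3\cdot10^5,2\cdot10^5)$, `theorem1_risk` and `corollary_limit(0.5, 0.3, 0.2, ...)` agree up to an $O(1/d)$ difference.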

Appendix J Equivalence of Fine-tuned and Translated Linear Models

We now prove that for linear models transfer learning using the translated predictor from Definition 2 is equivalent to transfer learning via the conventional fine-tuning process. This follows from Proposition 1 below, which implies that when parameterized by a linear model, the translated predictor is the interpolating solution for the target dataset that is nearest to the source predictor.

Proposition 1.

Let $\hat{f}_s(x)=\langle w_s,\psi(x)\rangle_{\mathcal{H}}$, where $\psi:\mathbb{R}^d\to\mathcal{H}$ is a feature map and $\mathcal{H}$ is a Hilbert space. Then the translated predictor, $\hat{f}_t$, is the solution to

$$\begin{aligned}&\arg\min_w\ \|w-w_s\|_{\mathcal{H}}\\&\text{subject to}\ \langle w,\psi(X_t)\rangle_{\mathcal{H}}=y_t.\end{aligned}\tag{6}$$
Proof.

Note that any solution $w$ to Problem (6) can be written as $w=w_s+\tilde{w}$. Hence, we can rewrite Problem (6) as follows:

$$\begin{aligned}&\arg\min_{\tilde{w}}\ \|\tilde{w}\|_{\mathcal{H}}\\&\text{subject to}\ \langle w_s+\tilde{w},\psi(X_t)\rangle_{\mathcal{H}}=y_t,\end{aligned}$$

where the constraint can be simplified to $\langle\tilde{w},\psi(X_t)\rangle_{\mathcal{H}}=y_t-\hat{f}_s(X_t)$. This is precisely the constraint for the translated predictor in Definition 2, thereby completing the proof. ∎
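Proposition 1 can be illustrated with a linear model whose features $\psi(X_t)$ are explicit. Gradient descent on the squared loss initialized at the source weights $w_s$ (i.e., fine-tuning) only moves within the row space of the features, so it converges to the same weights as the closed-form translated solution $w_s+\psi(X_t)^\dagger(y_t-\psi(X_t)w_s)$. A sketch under these assumptions (samples as rows of `Phi`, step size $1/\sigma_{\max}^2$, a fixed number of steps); the function names are illustrative.

```python
import numpy as np


def finetune_gd(Phi, y, w_s, steps=10000):
    """Gradient descent on 0.5 * ||Phi w - y||^2 started from the source weights."""
    lr = 1.0 / np.linalg.norm(Phi, 2) ** 2      # 1 / (largest singular value)^2
    w = w_s.copy()
    for _ in range(steps):
        w -= lr * Phi.T @ (Phi @ w - y)
    return w


def translate_closed_form(Phi, y, w_s):
    """Interpolating solution nearest to w_s: w_s + Phi^+ (y - Phi w_s)."""
    return w_s + np.linalg.pinv(Phi) @ (y - Phi @ w_s)
```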

Appendix K Proof of Theorem 2

We restate Theorem 2 below for convenience and then provide the proof.

Theorem.

Let $\mathcal{X}=\mathbb{R}^d$, $\mathcal{Y}_s=\mathbb{R}^{c_s}$, $\mathcal{Y}_t=\mathbb{R}^{c_t}$, and let $\hat{\omega}_t=\hat{\omega}_s+\hat{\omega}_c$, where $\hat{\omega}_s=y_sX_s^\dagger$ and $\hat{\omega}_c=(y_t-\hat{\omega}_sX_t)X_t^\dagger$. Assuming that $\mathbb{P}_s$ and $\mathbb{P}_t$ are independent, isotropic distributions on $\mathbb{R}^d$, the risk $\mathcal{R}(\hat{\omega}_t)$ is given by

$$\mathcal{R}(\hat{\omega}_t)=\left[\frac{\|\omega_s-\omega_t\|_F^2}{\|\omega_t\|_F^2}+\left(1-\frac{n_s}{d}\right)\left(1-\frac{\|\omega_s-\omega_t\|_F^2}{\|\omega_t\|_F^2}\right)\right]\mathcal{R}(\hat{\omega}_b),$$

where $\hat{\omega}_b=y_tX_t^\dagger$ is the baseline predictor.

Proof.

We prove the statement by directly simplifying the risk as follows.

$$\begin{aligned}\mathcal{R}(\hat{\omega}_t)&=\mathbb{E}_{x,X_s,X_t}\left[\|\hat{\omega}_tx-\omega_tx\|_2^2\right]\\&=\mathbb{E}_{X_s,X_t}\left[\|\hat{\omega}_t-\omega_t\|_F^2\right]\\&=\mathbb{E}_{X_s,X_t}\left[\|\hat{\omega}_s+(y_t-\hat{\omega}_sX_t)X_t^\dagger-\omega_t\|_F^2\right]&&\text{(by Definition 2)}\\&=\mathbb{E}_{X_s,X_t}\left[\|\hat{\omega}_s(I-X_tX_t^\dagger)-\omega_t(I-X_tX_t^\dagger)\|_F^2\right]&&\text{(as }y_t=\omega_tX_t\text{)}\\&=\left(1-\frac{n_t}{d}\right)\mathbb{E}_{X_s}\left[\|\hat{\omega}_s-\omega_t\|_F^2\right]&&\left(\text{as }\mathbb{E}_{X_t}[X_tX_t^\dagger]=\tfrac{n_t}{d}I\right)\\&=\left(1-\frac{n_t}{d}\right)\mathbb{E}_{X_s}\left[\|\omega_sX_sX_s^\dagger\|_F^2+\|\omega_t\|_F^2-2\langle\omega_sX_sX_s^\dagger,\omega_t\rangle\right]\\&=\left(1-\frac{n_t}{d}\right)\left[\frac{n_s}{d}\|\omega_s\|_F^2+\|\omega_t\|_F^2-\frac{2n_s}{d}\langle\omega_s,\omega_t\rangle\right]\\&=\left(1-\frac{n_t}{d}\right)\left[\|\omega_s-\omega_t\|_F^2+\left(1-\frac{n_s}{d}\right)\left(\|\omega_t\|_F^2-\|\omega_t-\omega_s\|_F^2\right)\right]\\&=\left[\frac{\|\omega_s-\omega_t\|_F^2}{\|\omega_t\|_F^2}+\left(1-\frac{n_s}{d}\right)\left(1-\frac{\|\omega_s-\omega_t\|_F^2}{\|\omega_t\|_F^2}\right)\right]\mathcal{R}(\hat{\omega}_b),\end{aligned}$$

where the penultimate equality follows from adding and subtracting the term $\frac{n_s}{d}\|\omega_t\|_F^2$, and the last equality uses $\mathcal{R}(\hat{\omega}_b)=\left(1-\frac{n_t}{d}\right)\|\omega_t\|_F^2$, thereby completing the proof. ∎
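Theorem 2 can be checked by simulation. The sketch below draws Gaussian inputs (one choice of isotropic distribution, with samples as columns), fits the source predictor $\hat{\omega}_s=y_sX_s^\dagger$ and the translated predictor, and compares the Monte Carlo average of $\|\hat{\omega}_t-\omega_t\|_F^2$ (the risk for inputs with $\mathbb{E}[xx^T]=I$) with the closed form. All sizes and the construction of $\omega_s,\omega_t$ are arbitrary choices for illustration.

```python
import numpy as np


def theorem2_check(d=20, n_s=8, n_t=5, c=3, trials=4000, seed=0):
    """Monte Carlo risk of the translated linear predictor vs. the Theorem 2 formula."""
    rng = np.random.default_rng(seed)
    w_s = rng.standard_normal((c, d)) / np.sqrt(d)                # source map omega_s
    w_t = w_s + 0.5 * rng.standard_normal((c, d)) / np.sqrt(d)    # related target map omega_t
    risks = []
    for _ in range(trials):
        X_s = rng.standard_normal((d, n_s))    # columns are samples
        X_t = rng.standard_normal((d, n_t))
        w_hat_s = (w_s @ X_s) @ np.linalg.pinv(X_s)               # y_s X_s^dagger
        y_t = w_t @ X_t
        w_hat_t = w_hat_s + (y_t - w_hat_s @ X_t) @ np.linalg.pinv(X_t)
        # With E[x x^T] = I, E_x ||(w_hat_t - w_t) x||^2 is a squared Frobenius norm.
        risks.append(np.sum((w_hat_t - w_t) ** 2))
    norm_t = np.sum(w_t ** 2)
    gap = np.sum((w_s - w_t) ** 2) / norm_t
    risk_baseline = (1 - n_t / d) * norm_t
    theory = (gap + (1 - n_s / d) * (1 - gap)) * risk_baseline
    return float(np.mean(risks)), float(theory)
```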

Appendix L Code and Hardware Details

All experiments were run using two servers. One server had 128GB of CPU random access memory (RAM) and 2 NVIDIA Titan XP GPUs each with 12GB of memory. This server was used for the virtual drug screening experiments and for training the CNTK on ImageNet32. The second server had 128GB of CPU RAM and 4 NVIDIA Titan RTX GPUs each with 24GB of memory. This server was used for all the remaining experiments. All code is available at https://github.com/uhlerlab/kernel_tf.

Figure 5: Image classification performance of CNNs that are finite-width analogs of the CNTK considered in this work. (a) The accuracy of the CNNs on 4 target tasks when transferred from ImageNet32. All layers of the CNNs are fine-tuned during transfer learning. The CNN in the top row achieves a test accuracy of 16.72% on ImageNet32. The early-stopped CNN in the bottom row achieves an accuracy of 10.692% on ImageNet32, which is comparable with the accuracy of the CNTK (10.64%). (b) Performance of a CNN pre-trained on CIFAR10 when transferred to CIFAR-C. (c) Performance of a CNN pre-trained on ImageNet32 when transferred to CIFAR-C.
Figure 6: (a, b) Performance of the projected kernel method as a function of the number of target examples when transferred from ImageNet32 to DTD and SVHN. (c, d) Performance of the projected kernel method as a function of the number of source classes when transferred from ImageNet32 to DTD and SVHN. The number of source examples was fixed to 40k, and we ensured that the number of source classes divides 40k.
Figure 7: Performance of three different kernels ((a) Laplace kernel, (b) NTK, (c) CNTK) as a function of the number of source and target examples when projected from ImageNet32 to CIFAR10. The baseline predictor performance is shown as a dashed black line. Overall, we find that performance improves as the number of source training samples per class increases.
Figure 8: $R^2$ values obtained by fitting the coefficients $a, b$ of the curve $y = a\log_{2}x + b$ to the empirical curves of projected predictor performance as a function of the number of target samples. For all kernels and datasets, the fit yields $R^2$ values greater than 0.94, and greater than 0.99 on datasets with more samples, such as CIFAR10 and SVHN.
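The fit summarized in Figure 8 can be sketched with ordinary least squares on the features $(\log_2 x, 1)$, scoring it with $R^2 = 1 - SS_{\mathrm{res}}/SS_{\mathrm{tot}}$. The caption does not state the fitting procedure, so least squares is an assumption here, and `fit_log2_curve` is a hypothetical helper; its inputs would be the numbers of target samples $x$ and the corresponding projected predictor accuracies $y$.

```python
import numpy as np


def fit_log2_curve(x, y):
    """Least-squares fit of y = a * log2(x) + b.

    Returns (a, b, r2). Assumes x > 0 and that y is not constant
    (otherwise the total sum of squares is zero and R^2 is undefined).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Design matrix with columns [log2(x), 1].
    A = np.column_stack([np.log2(x), np.ones_like(x)])
    (a, b), *_ = np.linalg.lstsq(A, y, rcond=None)
    y_hat = a * np.log2(x) + b
    ss_res = np.sum((y - y_hat) ** 2)       # residual sum of squares
    ss_tot = np.sum((y - y.mean()) ** 2)    # total sum of squares
    r2 = 1.0 - ss_res / ss_tot
    return float(a), float(b), float(r2)
```

For example, points generated exactly from $y = 0.05\log_2 x + 0.3$ recover $a = 0.05$, $b = 0.3$, and $R^2 = 1$ up to floating-point error; on real accuracy curves, $R^2$ measures how much of the variation is explained by the logarithmic trend.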
Figure 9: Accuracy of the translated predictor from ImageNet32 to CIFAR10 for (a) the Laplace kernel, (b) the neural tangent kernel, and (c) the convolutional neural tangent kernel. The black dashed line corresponds to the baseline predictor, while the colored dashed lines correspond to the source predictors. As the number of target samples increases, the translated predictor outperforms both the projected and baseline predictors, and its performance improves as the number of source examples per class increases.
Figure 10: Performance of kernels translated from CIFAR10 to CIFAR-C. We observe that for all 19 perturbations, the translated predictor outperforms the source, baseline, and projected predictors.
Figure 11: Performance of the transferred kernels with respect to the first two principal components of gene expression for the A375, MCF7, and PC3 cell lines.
Figure 12: Analysis of projecting the kernel predictors of gene expression in CMAP to viability scores in DepMap. (a) We observe a boost of up to 0.2 in $R^2$ values when projecting to new cell lines for which the drugs were available in the source task. (b) Predicting more genes in the source task helps when transferring to new cell lines for which the considered drugs were available in the source task, but hurts when transferring to new cell lines for which the considered drugs were not available in the source task.