
Emergence of hierarchical modes from deep learning

Chan Li^{1} and Haiping Huang^{1,2} (huanghp7@mail.sysu.edu.cn)
^{1}PMI Lab, School of Physics, Sun Yat-sen University, Guangzhou 510275, People’s Republic of China
^{2}Guangdong Provincial Key Laboratory of Magnetoelectric Physics and Devices, Sun Yat-sen University, Guangzhou 510275, People’s Republic of China
Abstract

Training large-scale deep neural networks is expensive, and the trained weight matrices that constitute the networks are hard to interpret. Here, we propose a mode decomposition learning that interprets the weight matrices as a hierarchy of latent modes. These modes are akin to patterns in physics studies of memory networks, but the least number of modes increases only logarithmically with the network width and even becomes a constant when the width grows further. Mode decomposition learning not only saves a significant amount of training cost, but also explains the network performance in terms of the leading modes, which display a striking piecewise power-law behavior. The modes specify a progressively compact latent space across the network hierarchy, yielding more disentangled subspaces than standard training. We also study mode decomposition learning in an analytic on-line learning setting, which reveals a multi-stage learning dynamics with a continuous specialization of hidden nodes. The proposed mode decomposition learning therefore points to a cheap and interpretable route towards the magic of deep learning.

Introduction.— Deep neural networks are dominant tools with a broad range of applications, not only in image and language processing but also in scientific research Goodfellow et al. (2016); Carleo et al. (2019). These networks are parameterized by a huge number of trainable weight matrices, which makes training expensive. However, these weight matrices are hard to interpret, and thus the mechanisms underlying the macroscopic performance of the networks remain a big mystery in theoretical studies of neural networks Huang (2022); Roberts et al. (2022).

To save the computational cost, previous studies of deep networks applied singular value decomposition to the weight matrices Jaderberg et al. (2014); Yang et al. (2020); Giambagli et al. (2021); Chicchi et al. (2021). This decomposition requires the orthogonality condition for the singular vectors and positive singular values. The training also involves a carefully-designed structure for the trainable decomposition scheme Giambagli et al. (2021); Chicchi et al. (2021). These constraints and designs make the training process complicated, and thus a concise physics interpretation is still lacking. In addition, previous studies of recurrent memory networks showed that the network weight can be decomposed into separate random orthogonal patterns with corresponding importance scores Jiang et al. (2021); Zhou et al. (2021). Inspired by these studies, we conjecture that the learning in deep networks is shaped by a hierarchy of latent modes, which are not necessarily orthogonal, and the weight matrix can be expressed by these modes.

The mode decomposition learning (MDL) leads to a progressively compact latent mode space across the network hierarchy, and meanwhile the subspaces corresponding to different types of input become strongly disentangled, facilitating discrimination. The least number of latent modes needed to achieve performance comparable to that of costly standard methods grows only logarithmically with the network width and can even become a constant, thereby significantly reducing the training cost. The mode spectrum exhibits an intriguing piecewise power-law behavior. In particular, these properties do not depend on details of the training setting. Therefore, our proposed MDL calls for a rethinking of conventional weight-based deep learning through the lens of cheap and interpretable mode-based learning.

Model.— To show the effectiveness of the MDL scheme, we train a deep network to implement a classification task of handwritten digits mni . The deep network has $L$ layers ($L-2$ hidden layers) with $N_{l}$ neurons in the $l$-th layer. The weight value of the connection from neuron $i$ at the upstream layer $l$ to neuron $j$ at the downstream layer $l+1$ is specified by $w_{ij}^{l}$. The activation of neuron $j$ at the downstream layer is $h_{j}^{l+1}=f(z_{j}^{l+1})=\max(0,z_{j}^{l+1})$, where the pre-activation $z_{j}^{l+1}=\sum_{i}w_{ij}^{l}h_{i}^{l}$. For the output layer, the softmax function $h_{k}=e^{z_{k}}/\sum_{i}e^{z_{i}}$ is chosen to specify the probability over all classes of the input images. The cross entropy $\mathcal{C}=-\sum_{i}\hat{h}_{i}\ln h_{i}$ is used as the cost function for the supervised learning, where $\hat{h}_{i}$ is the target label (one-hot form). After training (the cross entropy is repeatedly averaged over mini-batches of training examples), we evaluate the generalization performance of the network on an unseen test dataset.

Single weight values are not interpretable. According to our hypothesis, latent patterns would emerge from training in each layer. We call these patterns hierarchical modes for deep learning. Therefore, the relationship between the modes and weight values is expressed by the following mode decomposition,

$$\mathbf{w}^{l}=\hat{\bm{\xi}}^{l}\bm{\Sigma}^{l}(\bm{\xi}^{l+1})^{\operatorname{T}}, \qquad (1)$$

where there are $p^{l}$ upstream modes $\hat{\bm{\xi}}^{l}\in\mathbb{R}^{N_{l}\times p^{l}}$ and the same number of downstream modes $\bm{\xi}^{l+1}\in\mathbb{R}^{N_{l+1}\times p^{l}}$. The importance of each pair of adjacent modes is specified by the diagonal of the importance matrix $\bm{\Sigma}^{l}\in\mathbb{R}^{p^{l}\times p^{l}}$. These modes need not be orthogonal to each other, and the importance scores can take real values. This setting allows more degrees of freedom for learning features of input-output mappings. We will detail their geometric and physical interpretations below.
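As a concrete illustration, the decomposition in Eq. (1) can be sketched in a few lines of numpy. This is a minimal sketch under our own naming conventions; the scaling factor anticipates the initialization discussed later in the text and is our reading of it, not a prescription from the paper.

```python
import numpy as np

def init_mode_layer(N_in, N_out, p, rng):
    """Sample one MDL layer: upstream modes, importance scores, downstream modes.
    The 1/sqrt(p*N_in) factor keeps each reconstructed weight of order 1/sqrt(N_in)."""
    xi_hat = rng.standard_normal((N_in, p))    # upstream modes,  shape (N_l, p)
    sigma = rng.standard_normal(p)             # diagonal of the importance matrix
    xi = rng.standard_normal((N_out, p))       # downstream modes, shape (N_{l+1}, p)
    scale = 1.0 / np.sqrt(p * N_in)
    return xi_hat, sigma, xi, scale

def weight_from_modes(xi_hat, sigma, xi, scale=1.0):
    """w^l = xi_hat diag(Sigma) xi^T, Eq. (1); returns shape (N_l, N_{l+1})."""
    return scale * (xi_hat * sigma) @ xi.T

rng = np.random.default_rng(0)
xi_hat, sigma, xi, scale = init_mode_layer(784, 100, 30, rng)
w = weight_from_modes(xi_hat, sigma, xi, scale)
print(w.shape, w.std())   # (784, 100), entries of order 1/sqrt(784)
```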

Figure 1: A simple illustration of the mode decomposition learning. (a) A deep neural network of three layers, including one hidden layer with three hidden nodes, for a classification task of non-linearly separable data. The weight matrix $w_{ij}^{l}=\sum_{\alpha=1}^{p}\hat{\xi}_{i,\alpha}^{l}\Sigma_{\alpha}^{l}\xi_{j,\alpha}^{l}$, where $p=3$. The distribution of input data is modeled as a Gaussian mixture (see the main text) from which samples are assigned to labels $t=\pm 1$ based on the corresponding mixture component. The training performance is measured by the mean-squared-error loss function $\ell_{\operatorname{MSE}}(y,t)=\|y-t\|^{2}/2$. (b) The representation of hidden neurons $\mathbf{h}$ plotted in the 3D space, displaying the geometric separation. (c) The successive mappings from input sample $\mathbf{x}$ (grey) to $(\hat{\bm{\xi}}^{1})^{\operatorname{T}}\mathbf{x}$ (dark red), followed by $\bm{\Sigma}^{1}(\hat{\bm{\xi}}^{1})^{\operatorname{T}}\mathbf{x}$ (green), and finally $\bm{\xi}^{2}\bm{\Sigma}^{1}(\hat{\bm{\xi}}^{1})^{\operatorname{T}}\mathbf{x}$ (blue).

A geometric interpretation of Eq. (1) in a simple learning task is shown in Fig. 1. We use a three-layer network with three hidden neurons. The input data is sampled from a four-component Gaussian mixture Fischer et al. (2022),

$$\mathbb{P}(\mathbf{x},t)=P(t)\sum_{\pm}P_{\pm}\,\mathcal{N}\left(\mathbf{x}\,|\,\mu_{x}^{t,\pm},\Sigma_{x}^{t,\pm}\right), \qquad (2)$$

where $\mathcal{N}(\mathbf{x}|\mu_{x}^{t,\pm},\Sigma_{x}^{t,\pm})$ denotes a Gaussian distribution with mean $\mu_{x}^{t,\pm}$ and covariance $\Sigma_{x}^{t,\pm}$, and $P(t)=P_{\pm}=\frac{1}{2}$. For the label $t=+1$, $\mu_{x}^{t=+1,\pm}=\pm(0.5,0.5)^{\operatorname{T}}$, while for $t=-1$, $\mu_{x}^{t=-1,\pm}=\pm(-0.5,0.5)^{\operatorname{T}}$. Covariances are isotropic throughout with $\Sigma_{x}^{t,\pm}=0.05\,\mathds{1}$. The input samples $\mathbf{x}\in\mathbb{R}^{2}$ are first projected to the input pattern space spanned by $(\hat{\bm{\xi}}^{1})^{\operatorname{T}}_{i}$ ($i=1,2,3$). Then all three directions of this projection get expanded or contracted via $\bm{\Sigma}^{1}(\hat{\bm{\xi}}^{1})^{\operatorname{T}}\mathbf{x}$. Finally the geometrically modified representation is re-mapped to the downstream representation space of a higher dimensionality, as $\bm{\xi}^{2}\bm{\Sigma}^{1}(\hat{\bm{\xi}}^{1})^{\operatorname{T}}\mathbf{x}$ [Fig. 1(c)]. The non-linearity of the transfer function is then applied to the last linear transformation, leading to the geometric separation [Fig. 1(b)]. We conclude that the MDL provides rich angles from which to view the geometric transformation of the input information along the hierarchy of deep networks.
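A minimal sketch of the toy data set of Eq. (2), assuming the four component means and the isotropic covariance stated above; the sample count and function name are our own choices.

```python
import numpy as np

def sample_xor_mixture(n, rng):
    """Draw n samples from the four-component Gaussian mixture of Eq. (2).
    Label t=+1 uses means +-(0.5, 0.5); t=-1 uses means +-(-0.5, 0.5); covariance 0.05*I."""
    t = rng.choice([+1, -1], size=n)                 # P(t) = 1/2
    s = rng.choice([+1, -1], size=n)                 # mixture component, P_+- = 1/2
    mu_plus = np.array([0.5, 0.5])
    mu_minus = np.array([-0.5, 0.5])
    means = np.where((t == 1)[:, None], mu_plus, mu_minus) * s[:, None]
    x = means + np.sqrt(0.05) * rng.standard_normal((n, 2))
    return x, t

rng = np.random.default_rng(1)
x, t = sample_xor_mixture(1000, rng)
print(x.shape, t[:5])
```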

Rather than the conventional weight values in standard backpropagation (BP) algorithms Goodfellow et al. (2016), the trainable parameters in the MDL are latent patterns. The training is implemented by stochastic gradient descent in the mode space $\bm{\theta}^{l}=(\hat{\bm{\xi}}^{l},\bm{\Sigma}^{l},\bm{\xi}^{l+1})$ SM ,

$$
\begin{aligned}
\Delta\xi^{l+1}_{j\alpha} &\equiv -\eta\frac{\partial\mathcal{L}}{\partial\xi^{l+1}_{j\alpha}} = -\eta\,\mathcal{K}_{j}^{l+1}\Sigma_{\alpha}^{l}\sum_{i}\hat{\xi}_{i\alpha}^{l}h_{i}^{l},\\
\Delta\Sigma_{\alpha}^{l} &\equiv -\eta\frac{\partial\mathcal{L}}{\partial\Sigma^{l}_{\alpha}} = -\eta\sum_{j}\mathcal{K}_{j}^{l+1}\xi_{j\alpha}^{l+1}\sum_{i}\hat{\xi}_{i\alpha}^{l}h_{i}^{l},\\
\Delta\hat{\xi}^{l}_{i\alpha} &\equiv -\eta\frac{\partial\mathcal{L}}{\partial\hat{\xi}_{i\alpha}^{l}} = -\eta\,\Sigma_{\alpha}^{l}h_{i}^{l}\sum_{j}\mathcal{K}_{j}^{l+1}\xi_{j\alpha}^{l+1},
\end{aligned}
\qquad (3)
$$

where $\mathcal{L}$ denotes the cost function (e.g., cross-entropy or mean-squared error) over a mini-batch of training data, $\eta$ denotes the learning rate, and $\mathcal{K}_{j}^{l+1}\equiv\partial\mathcal{L}/\partial z_{j}^{l+1}$ denotes the error term, which back-propagates from the top layer where $\mathcal{K}_{j}^{L}=-\hat{h}_{j}^{L}\left(1-h_{j}^{L}\right)$ for $\mathcal{L}=\mathcal{C}$ (cross entropy). Based on the chain rule, the error backpropagation equation can be derived as $\mathcal{K}_{i}^{l}=\sum_{j}\mathcal{K}_{j}^{l+1}\sum_{\alpha}\hat{\xi}_{i\alpha}^{l}\Sigma_{\alpha}^{l}\xi_{j\alpha}^{l+1}f^{\prime}(z_{i}^{l})$ SM . To ensure the pre-activation is independent of the upstream-layer width, we take the initialization scheme $[\bm{\xi}^{l+1}\bm{\Sigma}^{l}(\hat{\bm{\xi}}^{l})^{\operatorname{T}}]_{ij}\sim\mathcal{O}(\frac{1}{\sqrt{N_{l}}})$ Jiang et al. (2021). To avoid the ambiguity of choosing patterns (e.g., scaled by a factor), we impose an identical regularization with strength $10^{-4}$ for all trainable parameters. However, our result does not change qualitatively with the specific value of the regularization SM .
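The updates in Eq. (3) translate directly into array operations. Below is a minimal sketch of the gradients for one layer (1L2P), assuming the error signal K of the downstream layer has already been back-propagated; all function and variable names are ours.

```python
import numpy as np

def mdl_layer_grads(xi_hat, sigma, xi, h, K):
    """Gradients of Eq. (3) for one layer.
    h: upstream activations, shape (N_l,); K: error terms dL/dz downstream, shape (N_{l+1},)."""
    kappa = xi_hat.T @ h                         # kappa_alpha = sum_i xi_hat_{i,alpha} h_i
    g_xi = np.outer(K, sigma * kappa)            # dL/dxi_{j,alpha}     = K_j Sigma_alpha kappa_alpha
    g_sigma = (xi.T @ K) * kappa                 # dL/dSigma_alpha      = (sum_j K_j xi_{j,alpha}) kappa_alpha
    g_xi_hat = np.outer(h, sigma * (xi.T @ K))   # dL/dxi_hat_{i,alpha} = Sigma_alpha h_i sum_j K_j xi_{j,alpha}
    return g_xi_hat, g_sigma, g_xi

def backprop_error(xi_hat, sigma, xi, z, K_next):
    """K_i = f'(z_i) sum_j K_j sum_alpha xi_hat_{i,alpha} Sigma_alpha xi_{j,alpha}, for ReLU f."""
    w = (xi_hat * sigma) @ xi.T                  # reconstructed weights, shape (N_l, N_{l+1})
    return (w @ K_next) * (z > 0)

# one SGD step with learning rate eta (toy shapes)
rng = np.random.default_rng(2)
N_in, N_out, p, eta = 100, 50, 20, 0.05
xi_hat, sigma, xi = (rng.standard_normal(s) for s in [(N_in, p), (p,), (N_out, p)])
h, K = rng.standard_normal(N_in), rng.standard_normal(N_out)
g_xi_hat, g_sigma, g_xi = mdl_layer_grads(xi_hat, sigma, xi, h, K)
xi_hat, sigma, xi = xi_hat - eta * g_xi_hat, sigma - eta * g_sigma, xi - eta * g_xi
```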

We remark that for each hidden layer, there exist two types of pattern ($\bm{\xi}^{l}\neq\hat{\bm{\xi}}^{l}$). Equation (3) is used to learn these patterns, and we call this case 1L2P. If we instead assume $\bm{\xi}^{l}=\hat{\bm{\xi}}^{l}$, the training can be further simplified as in SM , and we call this case 1L1P. The nature of this mode-based computation can be understood as an expanded linear-nonlinear layered computation, $f(z_{j}^{l+1})=f(\sum_{\alpha}c_{\alpha j}\kappa_{\alpha})$, where the linear field $\kappa_{\alpha}=\sum_{i}\hat{\xi}_{i\alpha}^{l}h_{i}^{l}$ and the equivalent weight $c_{\alpha j}=\xi_{j\alpha}^{l+1}\Sigma_{\alpha}^{l}$. Therefore, the number of modes acts as the width of the linear layer. We leave a systematic statistical-mechanics exploration of this linear-nonlinear structure to forthcoming works.
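The equivalent linear-nonlinear view can be written as a two-step forward pass: a narrow linear layer of width $p$ followed by the usual nonlinearity. A sketch in our own notation:

```python
import numpy as np

def mdl_forward(h, xi_hat, sigma, xi):
    """One MDL layer viewed as a linear-nonlinear computation:
    kappa_alpha = sum_i xi_hat_{i,alpha} h_i   (projection onto the p modes),
    z_j = sum_alpha c_{alpha j} kappa_alpha    with c_{alpha j} = xi_{j,alpha} Sigma_alpha,
    followed by the ReLU nonlinearity."""
    kappa = xi_hat.T @ h          # (p,)       linear fields of the hidden linear layer
    c = xi * sigma                # (N_out, p) equivalent weights c_{alpha j}
    z = c @ kappa                 # (N_out,)   identical to (xi_hat diag(Sigma) xi^T)^T @ h
    return np.maximum(z, 0.0)
```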

On-line learning dynamics in a shallow network.— The MDL can be analytically understood in an on-line learning setting, where we consider a one-hidden-layer architecture. On-line learning can be considered as a special case of the above mini-batch learning (i.e., the batch size is set to one, and each sample is visited only once). The training dataset consists of $n$ pairs $\{\mathbf{x}^{\nu},y^{\nu}\}_{\nu=1}^{n}$. Each training example is independently sampled from a probability distribution $\mathbb{P}(\mathbf{x},y)=\mathbb{P}(y|\mathbf{x})\mathbb{P}(\mathbf{x})$, where $\mathbb{P}(\mathbf{x})$ is a standard Gaussian distribution, and the scalar label $y^{\nu}$ is generated by a neural network of $k$ hidden neurons (i.e., the teacher, indicated by the symbol $*$ below). Given an input $\mathbf{x}^{\nu}\in\mathbb{R}^{d}$, the corresponding label is created by

$$y^{\nu}=\frac{1}{k}\sum_{r=1}^{k}\sigma\left(\frac{[\bm{\xi}^{*}\bm{\Sigma}^{*}(\hat{\bm{\xi}}^{*})^{\operatorname{T}}]_{r}\,\mathbf{x}^{\nu}}{\sqrt{d}}\right)=\frac{1}{k}\sum_{r=1}^{k}\sigma\left(\lambda_{r}^{*\nu}\right), \qquad (4)$$

where $[\bm{\xi}^{*}\bm{\Sigma}^{*}(\hat{\bm{\xi}}^{*})^{\operatorname{T}}]_{r}$ denotes the $r$-th row of the matrix $\bm{\xi}^{*}\bm{\Sigma}^{*}(\hat{\bm{\xi}}^{*})^{\operatorname{T}}$, and $\lambda_{r}^{*\nu}=[\bm{\xi}^{*}\bm{\Sigma}^{*}(\hat{\bm{\xi}}^{*})^{\operatorname{T}}]_{r}\mathbf{x}^{\nu}/\sqrt{d}$ represents the $r$-th element of the teacher local field vector $\bm{\lambda}^{*\nu}\in\mathbb{R}^{k}$. The teacher network is quenched with $[\bm{\xi}^{*}\bm{\Sigma}^{*}(\hat{\bm{\xi}}^{*})^{\operatorname{T}}]_{ij}\sim\mathcal{O}(1)$. Here, we focus on the non-linear transfer function $\sigma(x)=\operatorname{erf}(x/\sqrt{2})$. In addition, we train another shallow network, the student network, by minimizing the loss function $\mathcal{L}(y,\hat{f}(\mathbf{x},\Theta))$ over the training data (labels are given by the teacher network), where $\Theta$ denotes the trainable parameters. The student’s prediction for a fresh sample $\mathbf{x}$ is given by

$$\hat{f}(\mathbf{x},\hat{\bm{\xi}},\bm{\Sigma},\bm{\xi})=\frac{1}{m}\sum_{r=1}^{m}\sigma\left(\frac{[\bm{\xi}\bm{\Sigma}(\hat{\bm{\xi}})^{\operatorname{T}}]_{r}\,\mathbf{x}}{\sqrt{d}}\right)=\frac{1}{m}\sum_{r=1}^{m}\sigma\left(\lambda_{r}\right), \qquad (5)$$

where $\lambda_{r}$ denotes the $r$-th component of the student local field $\bm{\lambda}=\bm{\xi}\bm{\Sigma}(\hat{\bm{\xi}})^{\operatorname{T}}\mathbf{x}/\sqrt{d}$, and the student has $m$ hidden neurons. The student is supplied with data samples in sequence (one sample per time step). We next use $\nu$ to indicate the time step as well.
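A sketch of the teacher-student data generation of Eqs. (4)-(5), with erf activation and the $\mathcal{O}(1)$ teacher scaling described in the Supplemental Material. The naming is ours, and we apply the $1/\sqrt{\ln d}$ rescaling to one pattern factor, which is equivalent for the reconstructed weights.

```python
import numpy as np
from scipy.special import erf

def make_teacher(d=100, k=8, p=None, rng=None):
    """Quenched teacher patterns with [xi* Sigma* xi_hat*^T]_ij ~ O(1); p* ~ ln d modes."""
    rng = rng or np.random.default_rng(0)
    p = p or max(1, int(np.log(d)))
    xi_s = rng.standard_normal((k, p))                          # teacher downstream modes
    sig_s = rng.standard_normal(p)                              # teacher importance scores
    xih_s = rng.standard_normal((d, p)) / np.sqrt(np.log(d))    # teacher upstream modes (rescaled)
    return xi_s, sig_s, xih_s

def network_output(x, xi, sig, xih):
    """y = (1/k) sum_r erf(lambda_r / sqrt(2)),  lambda_r = [xi Sigma xih^T]_r x / sqrt(d)."""
    d = x.shape[0]
    lam = (xi * sig) @ (xih.T @ x) / np.sqrt(d)   # local fields, shape (k,)
    return erf(lam / np.sqrt(2)).mean()

rng = np.random.default_rng(3)
xi_s, sig_s, xih_s = make_teacher(d=100, k=8, rng=rng)
x = rng.standard_normal(100)
y = network_output(x, xi_s, sig_s, xih_s)         # scalar teacher label
```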

The mean-squared error can be evaluated as

$$\ell_{\mathrm{MSE}}(\bm{\Omega})=\frac{1}{2}\,\mathbb{E}_{\bm{\lambda},\bm{\lambda}^{*}\sim\mathcal{N}\left(\bm{\lambda},\bm{\lambda}^{*}\mid 0,\bm{\Omega}\right)}\left[\left(\hat{f}(\bm{\lambda})-f\left(\bm{\lambda}^{*}\right)\right)^{2}\right], \qquad (6)$$

where $f(\cdot)$ indicates the teacher’s output, and we have replaced the expectation $\mathbb{E}_{\mathbf{x},y\sim\mathbb{P}(\mathbf{x},y)}[\cdot]$ by $\mathbb{E}_{\bm{\lambda},\bm{\lambda}^{*}\sim\mathcal{N}(\bm{\lambda},\bm{\lambda}^{*}\mid 0,\bm{\Omega})}[\cdot]$, because of the central-limit theorem and the i.i.d. setting we consider Biehl and Schwarze (1995); Saad and Solla (1995); Goldt et al. (2019). The covariance of the local fields $\bm{\Omega}^{\nu}\in\mathbb{R}^{(k+m)\times(k+m)}$ can be specified as follows,

$$\bm{\Omega}^{\nu}\equiv\begin{bmatrix}\mathbf{Q}^{\nu}&\mathbf{M}^{\nu}\\ (\mathbf{M}^{\nu})^{\operatorname{T}}&\mathbf{P}\end{bmatrix}, \qquad (7)$$

where $\mathbf{Q}^{\nu}\equiv\mathbb{E}_{\mathbf{x},y\sim\mathbb{P}(\mathbf{x},y)}\left[\bm{\lambda}^{\nu}(\bm{\lambda}^{\nu})^{\operatorname{T}}\right]$, $\mathbf{M}^{\nu}\equiv\mathbb{E}_{\mathbf{x},y\sim\mathbb{P}(\mathbf{x},y)}\left[\bm{\lambda}^{\nu}(\bm{\lambda}^{*\nu})^{\operatorname{T}}\right]$, and $\mathbf{P}^{\nu}\equiv\mathbb{E}_{\mathbf{x},y\sim\mathbb{P}(\mathbf{x},y)}\left[\bm{\lambda}^{*\nu}(\bm{\lambda}^{*\nu})^{\operatorname{T}}\right]$. By definition, $\mathbf{P}$ is fixed, while $\mathbf{Q}^{\nu}$ and $\mathbf{M}^{\nu}$ evolve according to the gradient updates, following a set of deterministic ordinary differential equations (ODEs) as the input dimension $d\to\infty$ SM . These matrices are exactly the order parameters in physics. For simplicity, we consider $\bm{\xi}=\bm{\xi}^{*}$ and $\bm{\Sigma}=\bm{\Sigma}^{*}$, i.e., only the upstream patterns are learned.
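Given the patterns, the order parameters are plain matrix products. A sketch in our own notation, assuming student arrays of shape $(m,p)$, $(p,)$, $(d,p)$ and teacher arrays (suffix `_s` for "star") of shape $(k,p^{*})$, $(p^{*},)$, $(d,p^{*})$:

```python
import numpy as np

def order_parameters(xi, sig, xih, xi_s, sig_s, xih_s):
    """Q = W W^T / d, M = W W*^T / d, P = W* W*^T / d, with W = xi diag(Sigma) xih^T."""
    d = xih.shape[0]
    W = (xi * sig) @ xih.T          # student weight matrix, shape (m, d)
    Ws = (xi_s * sig_s) @ xih_s.T   # teacher weight matrix, shape (k, d)
    Q = W @ W.T / d
    M = W @ Ws.T / d
    P = Ws @ Ws.T / d
    return Q, M, P
```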

Figure 2: Test performance and mode hierarchy of MDL in deep neural networks. (a) Training trajectories of a four-layer network, indicated by $784$-$100$-$100$-$10$, where each number indicates the corresponding layer width. The number of modes $p^{l}=p$ for layer $l$, where $l=1,\ldots,L$; $p=20$ or $p=30$. Networks are trained on the full MNIST dataset ($6\times 10^{4}$ images) and tested on an unseen dataset containing $10^{4}$ images. The fluctuation is computed over five independent runs. (b) Testing accuracy versus $p$ (the number of modes is the same for all layers). The same architecture as (a) is used. The error bar characterizes the fluctuation across five independently trained networks, and each marker denotes the average result. The least number of modes is indicated by the dash-dot line. (c) The performance changes with the network width. The inset shows the least number of modes versus the layer width $N$ (in the logarithmic scale). The network architecture is given by $784$-$N$-$N$-$10$. The dash-dot line in the inset separates the piecewise logarithmic-increase ($\propto\ln N$) regions. The result is obtained from five independent runs. (d) The averaged Euclidean distance (dispersion) from the pattern-cloud center ($\frac{1}{p}\sum_{\alpha}\bm{\xi}^{l}_{\alpha}$) as a function of layer index. The network architecture is specified by $784$-$100$-$100$-$100$-$10$ ($p=30$). (e-f) Subspace overlap (principal angle) versus layer. The overlap is averaged over five independent runs, and seven-layer networks with hidden-layer width $100$ are trained ($p=30$).

Results.— MDL can reach a test accuracy similar to that of BP performed in the weight space, when $p$ is sufficiently large [Fig. 2(a)]. The computational cost of BP scales with $N_{l}^{2}$. In contrast, MDL works in the mode space, requiring a training cost of only the order of $pN_{l}$. Note that $p$ is much smaller than $N_{l}$ (or $\lim_{N_{l}\to\infty}p^{l}/N_{l}=0$), and our MDL does not need any additional training constraints (compared to other matrix factorization algorithms SM ). Remarkably, when $p=30$, the performance of MDL already matches that of BP [Fig. 2(b)], while utilizing only $40\%$ of the full set of parameters consumed by BP. In fact, each hidden layer can have two different types of latent pattern (1L2P) due to the mode decomposition. But if we assume that $\bm{\xi}^{l}=\hat{\bm{\xi}}^{l}$, i.e., each layer shares a single type of pattern (1L1P), we can further reduce the computational cost by an amount of $\sum_{l}p^{l}N_{l}$, without sacrificing the test accuracy [Fig. 2(b)]. Varying the network width, we reveal a logarithmic increase of the least number of modes [Fig. 2(c)], which is a novel property of deep learning in the mode space, in stark contrast to the linear number of memory patterns in previous studies Jiang et al. (2021). When the network width grows further, the least number can even become a constant. We argue that this manifests three separated phases of poor-good-saturated performance with increasing layer width (see Fig. S9 in SM ).

Figure 3: The robustness properties of well-trained four-layer MDL models with the architecture $784$-$100$-$100$-$10$. The case of 1L1P is considered with $p=70$ in the hidden layers. (a) Effects of removing modes through two protocols: removing modes with weak measure $\tau$ first (solid line) and removing modes randomly (dashed line). The fluctuation is computed over ten independent runs. (b) The rescaled $\ell_{2}$ norms $\gamma\|\xi\|_{2}$, $\gamma\|\hat{\xi}\|_{2}$ and the absolute values of $\Sigma$ versus their rank (in descending order) in the hidden layers, where $\gamma=\sum_{\alpha}|\Sigma_{\alpha}|/\sum_{\alpha}(\|\xi_{\alpha}\|_{2}+\|\hat{\xi}_{\alpha}\|_{2})$. The inset shows a log-log plot of the $\tau$ measure, displaying a piecewise power-law behavior. The error bar is computed over five independent runs. The marked percentage indicates the generalization accuracy after removing the corresponding side of modes.

To see how the latent patterns are transformed geometrically along the network hierarchy, we first calculate the center of the pattern space. Then the Euclidean distance from this center to each pattern is analyzed. We find that the pattern space becomes progressively compact when going to deep layers [Fig. 2(d)]. To further characterize the geometric details, we define the subspace spanned by the principal eigenvectors of the layer neural responses to one type of inputs. Then the subspace overlap is calculated as the cosine of the principal angle between two subspaces corresponding to two types of inputs Bjoerck and Golub (1973); SM . We find that the hidden-layer representation becomes more disentangled with layer in comparison with BP [Fig. 2(e,f)]. MDL shows great computational benefits of representation disentanglement, thereby facilitating discrimination. A slight increase of the overlap is observed for deeper layers, which is caused by the saturation of the test performance (see more analyses in SM ).

Compared to other matrix factorization methods, MDL imposes no additional constraints on the modes and importance scores, and is therefore flexible for feature extraction. We find that the interlayer patterns are less orthogonal than the intralayer ones. The geometric transformation carried out by these latent pattern matrices is not strictly a rotation, for which the $\ell_{2}$ norm would be preserved. This flexibility may be the key to making our method better than other matrix factorization methods in both training cost and learning performance (see details in SM ).

Figure 4: Mean-squared error dynamics in terms of $t=\frac{\nu}{d}$, where $\nu$ denotes the on-line sample index and $d$ is the input dimension. The teacher and student networks share the same number of hidden neurons ($m=k=8$). Markers represent results of the simulation, while the solid lines denote the theoretical predictions from solving the mean-field ODEs. The number of modes $p^{*}=p=\alpha\ln d$ ($\alpha$ denotes the mode load here). (a) Fixed $\alpha=1$. (b) Fixed $d=100$. The color deepens as $\alpha$ or $d$ increases. The insets display the evolving $\mathbf{M}$ matrix for $d=30$ and $\alpha=1.0$, respectively.

We next ask whether some modes are more important than others. Therefore, we rank the modes according to the measure $\tau_{\alpha}=\gamma\|\xi_{\alpha}\|_{2}+\gamma\|\hat{\xi}_{\alpha}\|_{2}+|\Sigma_{\alpha}|$, where $\gamma=\sum_{\alpha}|\Sigma_{\alpha}|/\sum_{\alpha}(\|\xi_{\alpha}\|_{2}+\|\hat{\xi}_{\alpha}\|_{2})$ makes the magnitudes of the pattern and importance ($\bm{\Sigma}$) scores comparable. Removing modes with weak values of $\tau$ first yields much higher accuracy than the random-removal protocol [Fig. 3(a)], suggesting the existence of leading modes. Moreover, deeper layers are more robust. Figure 3(b) shows the measure as a function of rank in descending order, which can be approximately captured by a piecewise power-law behavior (with a transition point at rank $10$). Ranking with only the importance scores yields similar behavior SM . A small exponent is observed for the leading measures, while the remaining measures bear a large exponent, thereby revealing the coding hierarchy of latent modes in the deep networks. This intriguing behavior does not change with the regularization strength or the hidden-layer width SM .
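The pruning experiment in Fig. 3(a) amounts to computing $\tau$ per mode and keeping only the leading ones. A sketch for one layer in our own notation; the accuracy evaluation after pruning is left out.

```python
import numpy as np

def tau_measure(xi_hat, sigma, xi):
    """tau_alpha = gamma*||xi_alpha|| + gamma*||xi_hat_alpha|| + |Sigma_alpha|,
    with gamma = sum|Sigma| / sum(||xi_alpha|| + ||xi_hat_alpha||)."""
    n_xi = np.linalg.norm(xi, axis=0)
    n_xih = np.linalg.norm(xi_hat, axis=0)
    gamma = np.abs(sigma).sum() / (n_xi + n_xih).sum()
    return gamma * n_xi + gamma * n_xih + np.abs(sigma)

def prune_weak_modes(xi_hat, sigma, xi, keep_fraction=0.2):
    """Keep only the leading modes ranked by tau (descending); the rest are removed."""
    tau = tau_measure(xi_hat, sigma, xi)
    order = np.argsort(tau)[::-1]
    keep = order[: max(1, int(keep_fraction * len(tau)))]
    return xi_hat[:, keep], sigma[keep], xi[:, keep]
```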

Finally, the on-line mean-squared error dynamics of our model can be predicted perfectly in a teacher-student setting. The number of modes strongly affects the shape of the learning dynamics, and a large mode load can make the plateaus disappear (Fig. 4). Moreover, during learning, the alignment between the receptive fields of the student’s hidden nodes and those of the teacher emerges continuously, which is called the specialization transition Schwarze (1993); Goldt et al. (2019).

Conclusion.— In this Letter, we propose a mode decomposition learning that works in the mode space rather than the conventional weight space. This learning scheme brings several technical and conceptual advances. First, the learning achieves performance comparable to standard methods, with a significant reduction of training costs. We also find that the least number of modes grows only logarithmically with the network width and even becomes independent of the width when it grows further, in stark contrast to the linear number of patterns in recurrent memory networks. Second, the learning leads to progressively compact pattern spaces, which promotes highly disentangled hierarchical representations. The upstream pattern maps the activity into a low-dimensional space, and the resulting embedding is then expanded or contracted. After that, the modified embedding is re-mapped into the high-dimensional activity space. This sequence of geometric transformations can be understood as a linear-nonlinear hidden structure. Third, not all modes are equally important for the generalization ability of the network, showing an intriguing piecewise power-law behavior. Finally, the mode learning dynamics can be predicted by the mean-field ODEs, revealing the mode specialization transition. Therefore, the MDL inspires a rethinking of conventional deep learning, offering a faster and more interpretable training framework. Future work along this direction includes the impact of other structured datasets, mode dynamics in over-parameterized or recurrent networks, and the origin of adversarial vulnerability of deep networks in terms of the geometry of the mode space.

I Acknowledgments

This research was supported by the National Natural Science Foundation of China for Grant number 12122515, and Guangdong Provincial Key Laboratory of Magnetoelectric Physics and Devices (No. 2022B1212010008), and Guangdong Basic and Applied Basic Research Foundation (Grant No. 2023B1515040023).

Supplemental Material

Appendix A Derivation of learning equations

In this section, we show how to derive the update equations for the mode parameters $\bm{\theta}^{l}=(\hat{\bm{\xi}}^{l},\bm{\Sigma}^{l},\bm{\xi}^{l+1})$, where the superscript $l$ indicates the layer index ranging from $1$ to $L$. The loss function is the cross entropy $\mathcal{C}=-\sum_{i}\hat{h}_{i}\ln h_{i}$ averaged over all training examples (divided into mini-batches in stochastic gradient descent), where $\hat{h}_{i}$ is defined as the target label (one-hot representation as common in machine learning). After training the network on the training dataset of size $T$, we evaluate the generalization performance of the network on an unseen dataset of size $V$.

In our framework of mode decomposition learning, the weight is decomposed into the form as follows,

$$\mathbf{w}^{l}=\hat{\bm{\xi}}^{l}\bm{\Sigma}^{l}(\bm{\xi}^{l+1})^{\operatorname{T}}. \qquad (S1)$$

The mode parameters are updated according to gradient descent of the loss function,

$$\Delta\bm{\theta}_{ij}^{l}=-\eta\,\mathcal{K}_{j}^{l+1}\frac{\partial z_{j}^{l+1}}{\partial\bm{\theta}_{ij}^{l}}, \qquad (S2)$$

where $\eta$ denotes the learning rate, and the error propagation term $\mathcal{K}_{j}^{l+1}\equiv\partial\mathcal{C}/\partial z_{j}^{l+1}$. On the top layer, $\mathcal{K}_{j}^{L}$ can be computed with the result $\mathcal{K}_{j}^{L}=-\hat{h}_{j}^{L}\left(1-h_{j}^{L}\right)$. For lower layers, the term $\mathcal{K}_{i}^{l}$ can be iteratively computed using the chain rule. More precisely,

$$\mathcal{K}_{i}^{l}=\partial\mathcal{C}/\partial z_{i}^{l}=\sum_{j}\frac{\partial\mathcal{C}}{\partial z_{j}^{l+1}}\frac{\partial z_{j}^{l+1}}{\partial z_{i}^{l}}=\sum_{j}\mathcal{K}_{j}^{l+1}\sum_{\alpha}\hat{\xi}_{i\alpha}^{l}\Sigma^{l}_{\alpha}\xi^{l+1}_{j\alpha}f^{\prime}(z_{i}^{l}). \qquad (S3)$$

The explicit expressions of gradient steps for the three sets of mode parameters are given as follows,

$$
\begin{aligned}
\Delta\xi^{l+1}_{j\alpha} &= -\eta\frac{\partial\mathcal{C}}{\partial\xi^{l+1}_{j\alpha}} = -\eta\,\mathcal{K}_{j}^{l+1}\sum_{i}\Sigma_{\alpha}^{l}\hat{\xi}_{i\alpha}^{l}h_{i}^{l},\\
\Delta\Sigma_{\alpha}^{l} &= -\eta\frac{\partial\mathcal{C}}{\partial\Sigma^{l}_{\alpha}} = -\eta\sum_{j}\mathcal{K}_{j}^{l+1}\sum_{i}\xi_{j\alpha}^{l+1}\hat{\xi}_{i\alpha}^{l}h_{i}^{l},\\
\Delta\hat{\xi}^{l}_{i\alpha} &= -\eta\frac{\partial\mathcal{C}}{\partial\hat{\xi}_{i\alpha}^{l}} = -\eta\sum_{j}\mathcal{K}_{j}^{l+1}\xi_{j\alpha}^{l+1}\Sigma_{\alpha}^{l}h_{i}^{l}.
\end{aligned}
\qquad (S4)
$$

The above learning equations apply to the 1L2P case.

Next, we consider the 1L1P case ($\bm{\xi}^{l}=\hat{\bm{\xi}}^{l}$). Apart from the single input pattern for the first layer $\hat{\bm{\xi}}^{1}$ and the single output pattern for the last layer $\bm{\xi}^{L}$, the two types of pattern in each hidden layer $[\bm{\xi}^{l},\hat{\bm{\xi}}^{l}]$ take the same form, and we denote $\bm{\xi}^{l}=\hat{\bm{\xi}}^{l}=\bm{\Xi}^{l}$. The expression of $\mathcal{K}_{i}^{l}$ remains unchanged, and we can then update $[\hat{\bm{\xi}}^{1},\bm{\Sigma}^{l},\bm{\xi}^{L}]$ according to Eq. (S4). Next, we give the gradient descent equation for $\bm{\Xi}^{l}$ with $l=2,\ldots,L-1$ as follows,

$$\Delta\Xi^{l}_{j\alpha}=-\eta\frac{\partial\mathcal{C}}{\partial\Xi^{l}_{j\alpha}}=-\eta\,\mathcal{K}_{j}^{l}\sum_{i}\Sigma_{\alpha}^{l-1}\hat{\xi}_{i\alpha}^{l-1}h_{i}^{l-1}-\eta\sum_{i}\mathcal{K}_{i}^{l+1}\xi_{i\alpha}^{l+1}\Sigma_{\alpha}^{l}h_{j}^{l}, \qquad (S5)$$

where two terms contribute to the gradient: the first comes from the contribution of $\bm{\xi}^{l}$, while the second originates from the fact that the same pattern also acts as $\hat{\bm{\xi}}^{l}$.

To ensure that the weighted sum in the pre-activation is independent of the upstream layer width and the number of modes $p^{l}$, we choose the initialization scheme such that $[\bm{\xi}^{l+1}\bm{\Sigma}^{l}(\hat{\bm{\xi}}^{l})^{\operatorname{T}}]_{ij}\sim\mathcal{O}(\frac{1}{\sqrt{N_{l}}})$. This scaling is inspired by studies of Hopfield models Jiang et al. (2021). In practice, we independently and identically sample the initial elements $\xi^{l+1}_{i\alpha},\Sigma^{l}_{\alpha},\hat{\xi}^{l}_{j\alpha}$ from the standard Gaussian distribution, and then the weight values are multiplied by a factor of $1/\sqrt{N_{l}\ln N_{l}}$. Note that the number of modes is assumed to be proportional to $\ln N_{l}$. But if the number is a constant denoted by $P^{l}$, then the factor becomes $1/\sqrt{P^{l}N_{l}}$.
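A quick numerical check of this scaling, with our own toy numbers; the upstream activations are stood in by rectified Gaussians to mimic $\mathcal{O}(1)$ ReLU outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
N_l, N_next = 1000, 1000
p = int(round(np.log(N_l)))                       # number of modes ~ ln N_l
xi_hat = rng.standard_normal((N_l, p))
sigma = rng.standard_normal(p)
xi = rng.standard_normal((N_next, p))
w = (xi_hat * sigma) @ xi.T / np.sqrt(N_l * np.log(N_l))   # entries ~ O(1/sqrt(N_l))

h = np.abs(rng.standard_normal(N_l))              # stand-in for O(1) upstream activations
z = w.T @ h                                       # pre-activations of the downstream layer
print(w.std(), z.std())                           # ~1/sqrt(N_l) and ~O(1), independent of N_l
```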

Figure S1: Comparison among the SVD training, MDL (1L2P) and traditional BP in learning performance. The network structure is specified by $[784,100,100,10]$ in all cases. The full MDL indicates the MDL with the same number of parameters as that of the SVD training, while the blue dot (pruned SVD) indicates pruning $60\%$ of the modes (those with small $|s_{i}|$, ranked in descending order) off each layer (except the output layer) of the full SVD model, to make the consumed parameter amount comparable with that of MDL with $p=30$.

Figure S2: Comparison among low-rank decomposition (LRD), MDL (1L2P) and traditional BP. Both decomposition methods use $p=30$. The network structure is $[784,100,100,10]$ in all cases.

Appendix B Comparison to other matrix factorization methods

Here, we compare our MDL method to other matrix factorization methods in learning performance. These other methods include singular value decomposition (SVD), low-rank decomposition (LRD) and spectral training Yang et al. (2020); Chicchi et al. (2021).

First, the SVD learning scheme is implemented by decomposing the weight of each layer as

$$\bm{W}^{l}=\bm{U}^{l}\operatorname{diag}(\bm{s}^{l})(\bm{V}^{l})^{\top}, \qquad (S6)$$

where the diagonal matrix contains $\min(N_{l},N_{l+1})$ non-zero elements on the diagonal, and the elements of $\bm{s}^{l}$ are constrained to be positive. The orthogonality is enforced by two regularization terms as

$$L(\bm{U},\bm{s},\bm{V})=L_{T}+\lambda_{o}\sum_{l=1}^{D}L_{o}\left(\bm{U}_{l},\bm{V}_{l}\right)+\lambda_{s}\sum_{l=1}^{D}L_{s}\left(\bm{s}_{l}\right), \qquad (S7)$$

where $L_{T}$ is the original training loss, $L_{o}(\bm{U},\bm{V})=\frac{1}{r^{2}}\left(\|\bm{U}^{\operatorname{T}}\bm{U}-\bm{I}\|_{F}^{2}+\|\bm{V}^{\operatorname{T}}\bm{V}-\bm{I}\|_{F}^{2}\right)$, and $L_{s}(\bm{s})=\frac{\|\bm{s}\|_{1}}{\|\bm{s}\|_{2}}=\frac{\sum_{i}|s_{i}|}{\sqrt{\sum_{i}s_{i}^{2}}}$. Here $r$ is the rank of $\bm{U}$ and $\bm{V}$, and $\|\bullet\|_{F}$ denotes the Frobenius norm of a matrix. The regularization term $L_{o}$ forces $\bm{U}$ and $\bm{V}$ to be orthogonal, while $L_{s}$ adjusts the sparsity level of $\bm{s}$. The gradients for each set of parameters are derived below,

$$
\begin{aligned}
\frac{\partial L_{o}}{\partial\bm{U}} &=\frac{4}{r^{2}}\left(\bm{U}^{\top}\bm{U}-\bm{I}\right)^{\top}\bm{U}^{\top},\\
\frac{\partial L_{o}}{\partial\bm{V}} &=\frac{4}{r^{2}}\left(\bm{V}^{\top}\bm{V}-\bm{I}\right)^{\top}\bm{V}^{\top},\\
\frac{\partial L_{s}}{\partial s_{i}} &=\frac{\operatorname{sign}(s_{i})\sqrt{\sum_{j}s_{j}^{2}}-\sum_{j}|s_{j}|\left(\sum_{j}s_{j}^{2}\right)^{-\frac{1}{2}}s_{i}}{\sum_{j}s_{j}^{2}}.
\end{aligned}
\qquad (S8)
$$

For comparison, we carried out the SVD learning with regularization strengths $\lambda_{o}=100$, $\lambda_{s}=0.0$ and $\lambda_{o}=100$, $\lambda_{s}=5.0$, as shown in Fig. S1. We remark that the training cost is larger for SVD models, which can be calculated as $\sum_{l}[N_{l}\times N_{l+1}+\min(N_{l},N_{l+1})^{2}+\min(N_{l},N_{l+1})]$. Taking $[784,100,100,10]$ as an example, the learning needs $109710$ parameters in total. However, for our MDL with $p=30$, which already reaches the traditional BP performance, the learning only needs $35910$ parameters (while traditional BP needs $89400$ parameters). In simulations, we prune $60\%$ of the modes (those with small $|s_{i}|$, ranked in descending order) off each layer (except the output layer) of the full SVD model to make the number of trainable parameters comparable with that of MDL with $p=30$. We conclude that the MDL consumes fewer parameters, yet produces rapid learning with even better performance.
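The parameter counts quoted above can be reproduced with a short calculation. The count formulas below are our reading of the architectures: BP stores $N_{l}\times N_{l+1}$ weights per layer, and 1L2P MDL stores $p(N_{l}+N_{l+1})+p$ parameters per layer.

```python
widths = [784, 100, 100, 10]
layers = list(zip(widths[:-1], widths[1:]))

bp = sum(nl * nn for nl, nn in layers)                                    # 89400
svd = sum(nl * nn + min(nl, nn) ** 2 + min(nl, nn) for nl, nn in layers)  # 109710
p = 30
mdl_1l2p = sum(p * (nl + nn) + p for nl, nn in layers)                    # 35910
print(bp, svd, mdl_1l2p)
```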

Figure S3: Comparison among the spectral learning, MDL (1L2P) and traditional BP. The network structure is $[784,100,100,10]$ in all cases.

Next, we fix $\bm{\Sigma}=\mathbb{I}$ in our MDL, and this reduced form is called the low-rank decomposition as follows,

$$\mathbf{W}^{l}=\hat{\bm{\xi}}^{l}(\bm{\xi}^{l+1})^{\top}. \qquad (S9)$$

In the simulation, we set $p=30$. We can see in Fig. S2 that the performance of the LRD is much worse than that of MDL and traditional BP.

For the recently proposed spectral learning Chicchi et al. (2021), a carefully-designed transformation matrix $\mathbf{A}^{k}$ (an $N\times N$ matrix, where $N$ is the total number of units in the network and $k$ is a layer index) is used with a spectral decomposition. The eigenvalues and the associated basis are optimized. However, this training performs worse compared to our MDL in the examples shown in Fig. S3.

Appendix C Ranking the modes according to the importance matrix

Here, we rank the modes according to the diagonal of the importance matrix, rather than the $\tau$ measure. We find that these two ranking schemes lead to qualitatively identical results. Removing the most important modes (according to either the $\tau$ measure or the importance score) significantly impairs the generalization ability of the network. Details are illustrated in Fig. S4. The non-smooth behavior can be attributed to the existence of a mode-contribution gap, i.e., the most important modes ($<15\%$ for the $\tau$ measure; $<30\%$ for the $\Sigma$ measure) dominate the generalization capability of the network, while the other modes capture irrelevant noise in the data.

Figure S4: Ranking modes. The network structure is $[784,100,100,10]$, and we analyze the 1L1P case here with $p=70$. The marked percentage in the inset indicates the generalization accuracy after removing the corresponding side of modes in the hidden layer. The piecewise power-law behavior is retained for both types of ranking.

Appendix D The qualitative behavior of the MDL does not change with the regularization strength or the hidden-layer width

Furthermore, our MDL is in essence a matrix factorization. Therefore, the pattern and importance matrices are not unique. However, in practice, we impose an $\ell_{2}$-norm regularization on these patterns and importance scores. In fact, we find that the intriguing properties of the MDL in deep learning do not change with the regularization strength of the $\ell_{2}$ norm (denoted as $\lambda$, see Fig. S8). Figure S5 shows an example of the behavior of the optimal number of modes versus the hidden-layer width, while Fig. S6 shows that the piecewise power-law behavior of the $\tau$ measure does not change with the regularization strength. In addition, the piecewise power-law behavior of the $\tau$ measure does not change with the hidden-layer width either (Fig. S7).

Figure S5: The piecewise increasing behavior of the least $p$ with the hidden-layer width. The network has structure $[784,N,N,10]$, and we vary $N$ to get the corresponding least $p$, defined as the least number of modes that MDL needs to reach the performance of the traditional BP. The 1L2P case is considered. (a) and (b) are obtained under different regularization strengths: (a) $\lambda=10^{-3}$; (b) $\lambda=10^{-4}$.

Figure S6: The piecewise power-law behavior of the $\tau$ measure does not change with the regularization strength $\lambda$ in the 1L1P case. The network structure is specified by $[784,100,100,10]$. (a) $\lambda=0.01$. (b) $\lambda=0.001$. (c) $\lambda=0.0001$.

Figure S7: The piecewise power-law behavior of the $\tau$ measure does not change with the hidden-layer width in the 1L1P case. The network structure is specified by $[784,N,N,10]$. (a) $N=100$. (b) $N=150$. (c) $N=200$. (d) $N=300$.

Figure S8: The $\ell_{2}$ norm of the parameters $[\bm{\xi},\hat{\bm{\xi}},\bm{\Sigma}]$ under three regularization strengths $\lambda=10^{-2},10^{-3},10^{-4}$. The network structure is specified by $[784,100,100,10]$ and $p=70$ for all layers, where the 1L1P case is considered (the 1L2P case yields qualitatively the same results). (a, b, c) are plotted for the first layer, the second layer, and the output layer, respectively.

Appendix E Subspace overlap of layered response to pairs of stimuli

In this section, we provide details of estimating the average degree of correlation between neural responses $\mathbf{h}^{\ell}$ to pairs of different input stimuli (e.g., one stimulus contains images of the same class). The covariance of the neural response in each layer to a stimulus can be diagonalized to specify a low-dimensional subspace. The subspace is spanned by the first $K$ principal components. The subspace overlap can then be evaluated via the cosine of the principal angle between the two subspaces corresponding to two different stimuli. In practice, for neural responses in each layer to the stimulus $s_{1}$ (e.g., many images of digit 0), we first identify the first $K$ principal components of the covariance of $\mathbf{h}_{1}^{\ell}$, which explain over $80\%$ of the total variance, and then reorganize the eigenvectors into an $N_{\ell}\times K$ matrix, namely $\mathbf{Q}^{\ell}(s_{1})$. We repeat this procedure for another stimulus $s_{2}$ and get another matrix $\mathbf{Q}^{\ell}(s_{2})$. Therefore, the columns of $\mathbf{Q}^{\ell}(s_{1})$ and $\mathbf{Q}^{\ell}(s_{2})$ span two subspaces corresponding to the neural responses to $s_{1}$ and $s_{2}$, respectively. The cosine of the principal angle between these two subspaces is calculated as follows Bjoerck and Golub (1973)

$$\cos\theta_{p}\left(s_{1},s_{2}\right)=\sigma_{\max}\left(\mathbf{Q}^{\ell}(s_{1})^{\operatorname{T}}\mathbf{Q}^{\ell}(s_{2})\right), \qquad (S10)$$

where $\sigma_{\max}(\mathbf{B})$ denotes the largest singular value of the matrix $\mathbf{B}$. In simulations, we consider the classification task of the MNIST dataset, where ten classes of digits are fed into a seven-layer neural network. Specifically, we choose the $K^{\ell}$ that explains over $80\%$ of the total variance for each stimulus and each layer, and therefore the value of $K^{\ell}$ varies with layer and input stimulus.
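A sketch of the subspace-overlap estimate of Eq. (S10): PCA per stimulus, keep enough components to explain $80\%$ of the variance, then take the largest singular value of $\mathbf{Q}_{1}^{\operatorname{T}}\mathbf{Q}_{2}$. The implementation choices (centering, eigen-solver) are ours.

```python
import numpy as np

def principal_subspace(H, var_threshold=0.80):
    """H: responses of one layer to one stimulus, shape (n_samples, N_l).
    Returns the top-K eigenvectors of the covariance explaining > var_threshold of the variance."""
    Hc = H - H.mean(axis=0)
    cov = Hc.T @ Hc / (len(H) - 1)
    evals, evecs = np.linalg.eigh(cov)              # ascending order
    evals, evecs = evals[::-1], evecs[:, ::-1]      # descending order
    K = np.searchsorted(np.cumsum(evals) / evals.sum(), var_threshold) + 1
    return evecs[:, :K]                             # orthonormal basis, shape (N_l, K)

def subspace_overlap(H1, H2):
    """Cosine of the principal angle between the response subspaces to two stimuli, Eq. (S10)."""
    Q1, Q2 = principal_subspace(H1), principal_subspace(H2)
    return np.linalg.svd(Q1.T @ Q2, compute_uv=False)[0]   # largest singular value
```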

In the main text, we observe a mild increase of the subspace overlap in deep layers. Here, as shown in Fig. S9, we link this behavior to the saturation of the test performance with increasing number of layers and network width. In addition, the task we consider is relatively simple, and thus three hidden layers (five layers in total) are sufficient to classify the digits with a high accuracy. The subspace overlap under the MDL setting thus suggests a consistent way to determine the optimal number of layers and the network width in practical training.

Figure S9: The averaged subspace overlap versus layer. Different numbers of layers with different hidden-layer widths are considered. The results are averaged over five independent trainings. The inset shows the corresponding test accuracy changing with the hidden-layer width.

Figure S10: A simple illustration of the teacher-student setup with i.i.d. standard Gaussian input. The teacher network has $k$ hidden nodes and $p^{*}$ modes, while the student network has $m$ hidden nodes and $p$ modes. The goal of the student network is to predict the labels generated by the teacher network, minimizing the mean-squared error. The weights to the output layer are set to one in the linear readout for both teacher and student networks ($\mathbf{v}=\mathbf{I}_{m}$, $\mathbf{v}^{*}=\mathbf{I}_{k}$, where $\mathbf{I}_{x}$ is an all-one vector of length $x$).

Appendix F Mean-field predictions of on-line learning dynamics

In this section, we give a detailed derivation of the mean-field ordinary differential equations for the on-line dynamics. A sketch of the toy model setting is shown in Fig. S10. The label for each sample $\mathbf{x}^{\nu}$ (i.i.d. standard Gaussian variables) is generated by the teacher network,

$$y^{\nu}=f(\bm{\lambda}^{*\nu})=\frac{1}{k}\sum_{r=1}^{k}\sigma\left(\frac{[\bm{\xi}^{*}\bm{\Sigma}^{*}(\hat{\bm{\xi}}^{*})^{\operatorname{T}}]_{r}\,\mathbf{x}^{\nu}}{\sqrt{d}}\right)=\frac{1}{k}\sum_{r=1}^{k}\sigma\left(\lambda_{r}^{*\nu}\right), \qquad (S11)$$

where $[\mathbf{A}]_{r}$ denotes the $r$-th row of the matrix $\mathbf{A}$, and $\lambda_{r}^{*\nu}=[\bm{\xi}^{*}\bm{\Sigma}^{*}(\hat{\bm{\xi}}^{*})^{\operatorname{T}}]_{r}\mathbf{x}^{\nu}/\sqrt{d}$ represents the $r$-th element of the teacher local field vector $\bm{\lambda}^{*\nu}\in\mathbb{R}^{k}$. To ensure the local field is independent of the input dimension, we choose the initialization scheme for the teacher network such that $[\bm{\xi}^{*}\bm{\Sigma}^{*}(\hat{\bm{\xi}}^{*})^{\operatorname{T}}]_{ij}\sim\mathcal{O}(1)$. More precisely, we set the elements $\xi^{*}_{ik},\Sigma^{*}_{k},\hat{\xi}^{*}_{jk}$ to be independent standard Gaussian variables, and then multiply the weight values by a factor of $\frac{1}{\sqrt{\ln d}}$ for a logarithmically increasing number of modes. This scaling ensures that the magnitude of the weight values is of order one. Different forms of the transfer function $\sigma(\cdot)$ can be considered, but we choose the error function for the simplicity of the following theoretical analysis. The prediction of the label by the student network for a new sample $\mathbf{x}$ is given by

$$\hat{f}(\mathbf{x},\hat{\bm{\xi}},\bm{\Sigma},\bm{\xi})=\frac{1}{m}\sum_{r=1}^{m}\sigma\left(\frac{[\bm{\xi}\bm{\Sigma}(\hat{\bm{\xi}})^{\operatorname{T}}]_{r}\,\mathbf{x}}{\sqrt{d}}\right)=\frac{1}{m}\sum_{r=1}^{m}\sigma\left(\lambda_{r}\right), \qquad (S12)$$

where $\lambda_{r}$ denotes the $r$-th component of the student local field $\bm{\lambda}=\bm{\xi}\bm{\Sigma}\hat{\bm{\xi}}^{\operatorname{T}}\mathbf{x}/\sqrt{d}$. The student network has $m$ hidden nodes and $p$ patterns. For simplicity, we assume $m=k$, $p=p^{*}$, and only the pattern $\hat{\bm{\xi}}$ is learned.

Training the student network with the one-pass gradient descent (on-line learning) directly minimizes the following mean-squared error (MSE):

$$\ell_{\mathrm{MSE}}(\bm{\lambda},\bm{\lambda}^{*})=\frac{1}{2}\left\langle\left(\hat{f}(\bm{\lambda})-f\left(\bm{\lambda}^{*}\right)\right)^{2}\right\rangle, \qquad (S13)$$

where $\langle\cdot\rangle$ indicates the average over $\{\mathbf{x},y\}$, which can be replaced by the average over local fields. For the Gaussian data $\mathbb{P}(\mathbf{x})=\mathcal{N}(\mathbf{x}|0,\mathds{1})$, the dynamics of $\ell_{\mathrm{MSE}}$ can be completely determined by the following order parameters: $\mathbf{Q}^{\nu}\equiv\mathbb{E}_{\mathbf{x},y\sim\mathbb{P}(\mathbf{x},y)}\left[\bm{\lambda}^{\nu}(\bm{\lambda}^{\nu})^{\operatorname{T}}\right]=\frac{1}{d}\bm{\xi}^{\nu}\bm{\Sigma}^{\nu}(\hat{\bm{\xi}}^{\nu})^{\operatorname{T}}\hat{\bm{\xi}}^{\nu}\bm{\Sigma}^{\nu}(\bm{\xi}^{\nu})^{\operatorname{T}}$, $\mathbf{M}^{\nu}\equiv\mathbb{E}_{\mathbf{x},y\sim\mathbb{P}(\mathbf{x},y)}\left[\bm{\lambda}^{\nu}(\bm{\lambda}^{*\nu})^{\operatorname{T}}\right]=\frac{1}{d}\bm{\xi}^{\nu}\bm{\Sigma}^{\nu}(\hat{\bm{\xi}}^{\nu})^{\operatorname{T}}\hat{\bm{\xi}}^{*\nu}\bm{\Sigma}^{*\nu}(\bm{\xi}^{*\nu})^{\operatorname{T}}$, and $\mathbf{P}^{\nu}\equiv\mathbb{E}_{\mathbf{x},y\sim\mathbb{P}(\mathbf{x},y)}\left[\bm{\lambda}^{*\nu}(\bm{\lambda}^{*\nu})^{\operatorname{T}}\right]=\frac{1}{d}\bm{\xi}^{*\nu}\bm{\Sigma}^{*\nu}(\hat{\bm{\xi}}^{*\nu})^{\operatorname{T}}\hat{\bm{\xi}}^{*\nu}\bm{\Sigma}^{*\nu}(\bm{\xi}^{*\nu})^{\operatorname{T}}$. The corresponding matrix elements are denoted as $q_{jl}^{\nu}\equiv\left[\mathbf{Q}^{\nu}\right]_{jl}$, $m_{jr}^{\nu}\equiv\left[\mathbf{M}^{\nu}\right]_{jr}$ and $\rho_{rs}\equiv[\mathbf{P}]_{rs}$. Then we can define the local-field covariance matrix $\bm{\Omega}^{\nu}\in\mathbb{R}^{(k+m)\times(k+m)}$ at time step $\nu$ as follows,

$$\bm{\Omega}^{\nu}\equiv\begin{bmatrix}\mathbf{Q}^{\nu}&\mathbf{M}^{\nu}\\ (\mathbf{M}^{\nu})^{\operatorname{T}}&\mathbf{P}\end{bmatrix}, \qquad (S14)$$

where $\mathbf{P}$ is fixed by definition (the parameters of the teacher network are quenched), the sample index $\nu$ is also the time step in the on-line learning setting, and the evolution of the other order parameters $\mathbf{Q}^{\nu}$ and $\mathbf{M}^{\nu}$ is driven by the gradient flow of the mode parameters $\bm{\theta}^{l}=(\hat{\bm{\xi}}^{l},\bm{\Sigma}^{l},\bm{\xi}^{l+1})$. The loss is completely determined by the evolving order parameters,

$$\ell_{\mathrm{MSE}}(\bm{\Omega})=\frac{1}{2}\,\mathbb{E}_{\bm{\lambda},\bm{\lambda}^{*}\sim\mathcal{N}\left(\bm{\lambda},\bm{\lambda}^{*}\mid 0,\bm{\Omega}\right)}\left[\left(\hat{f}(\bm{\lambda})-f\left(\bm{\lambda}^{*}\right)\right)^{2}\right]={\ell}_{\mathrm{t}}(\mathbf{P})+{\ell}_{\mathrm{s}}(\mathbf{Q})+\ell_{\mathrm{st}}(\mathbf{P},\mathbf{Q},\mathbf{M}), \qquad (S15)$$

where

$$
\begin{aligned}
{\ell}_{\mathrm{t}}(\mathbf{P}) &\equiv\frac{1}{2}\,\mathbb{E}_{\bm{\lambda}^{*}\sim\mathcal{N}\left(\bm{\lambda}^{*}\mid 0,\mathbf{P}\right)}\left[f\left(\bm{\lambda}^{*}\right)^{2}\right]=\frac{1}{2k^{2}}\sum_{r,s=1}^{k}\mathbb{E}_{\bm{\lambda}^{*}\sim\mathcal{N}\left(\bm{\lambda}^{*}\mid 0,\mathbf{P}\right)}\left[\sigma\left(\lambda_{r}^{*}\right)\sigma\left(\lambda_{s}^{*}\right)\right],\\
{\ell}_{\mathrm{s}}(\mathbf{Q}) &\equiv\frac{1}{2}\,\mathbb{E}_{\bm{\lambda}\sim\mathcal{N}(\bm{\lambda}\mid 0,\mathbf{Q})}\left[\hat{f}(\bm{\lambda})^{2}\right]=\frac{1}{2m^{2}}\sum_{j,l=1}^{m}\mathbb{E}_{\bm{\lambda}\sim\mathcal{N}(\bm{\lambda}\mid 0,\mathbf{Q})}\left[\sigma\left(\lambda_{j}\right)\sigma\left(\lambda_{l}\right)\right],\\
\ell_{\mathrm{st}}(\mathbf{P},\mathbf{Q},\mathbf{M}) &\equiv-\mathbb{E}_{\bm{\lambda},\bm{\lambda}^{*}\sim\mathcal{N}\left(\bm{\lambda},\bm{\lambda}^{*}\mid 0,\bm{\Omega}\right)}\left[\hat{f}(\bm{\lambda})f\left(\bm{\lambda}^{*}\right)\right]=-\frac{1}{mk}\sum_{j=1}^{m}\sum_{r=1}^{k}\mathbb{E}_{\bm{\lambda},\bm{\lambda}^{*}\sim\mathcal{N}\left(\bm{\lambda},\bm{\lambda}^{*}\mid 0,\bm{\Omega}\right)}\left[\sigma\left(\lambda_{j}\right)\sigma\left(\lambda_{r}^{*}\right)\right].
\end{aligned}
\qquad (S16)
$$

To proceed, we define the integral $I_{2}=\mathbb{E}_{\bm{\lambda},\bm{\lambda}^{*}\sim\mathcal{N}(\bm{\lambda},\bm{\lambda}^{*}\mid 0,\bm{\Omega})}\left[\sigma\left(\bm{\lambda}^{\alpha}\right)\sigma\left(\bm{\lambda}^{\beta}\right)\right]$, which has an analytic form for $\sigma(x)=\operatorname{erf}(x/\sqrt{2})$ as follows,

$$\mathbb{E}_{\bm{\lambda},\bm{\lambda}^{*}\sim\mathcal{N}\left(\bm{\lambda},\bm{\lambda}^{*}\mid 0,\bm{\Omega}\right)}\left[\sigma\left(\bm{\lambda}^{\alpha}\right)\sigma\left(\bm{\lambda}^{\beta}\right)\right]=\frac{2}{\pi}\arcsin\left(\frac{\Omega_{12}^{\alpha\beta}}{\sqrt{\left(1+\Omega_{11}^{\alpha\beta}\right)\left(1+\Omega_{22}^{\alpha\beta}\right)}}\right), \qquad (S17)$$

where $\Omega_{ij}^{\alpha\beta}$ denotes the element of the overlap matrix for $\bm{\lambda}^{\alpha}$ and $\bm{\lambda}^{\beta}$, in which $\alpha$ and $\beta$ indicate the attribute of the network (teacher or student). Therefore, the generalization error can be estimated as follows,

$$
\begin{aligned}
\ell_{\mathrm{MSE}}(\bm{\Omega})&=\frac{1}{k^{2}}\sum_{r,s=1}^{k}\frac{1}{\pi}\arcsin\left(\frac{\rho_{rs}}{\sqrt{\left(1+\rho_{rr}\right)\left(1+\rho_{ss}\right)}}\right)+\frac{1}{m^{2}}\sum_{j,l=1}^{m}\frac{1}{\pi}\arcsin\left(\frac{q_{jl}}{\sqrt{\left(1+q_{jj}\right)\left(1+q_{ll}\right)}}\right)\\
&\quad-\frac{2}{mk}\sum_{j=1}^{m}\sum_{r=1}^{k}\frac{1}{\pi}\arcsin\left(\frac{m_{jr}}{\sqrt{\left(1+q_{jj}\right)\left(1+\rho_{rr}\right)}}\right).
\end{aligned}
\qquad (S18)
$$
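Eq. (S18) is straightforward to evaluate numerically once the order parameters are known. A minimal sketch in our notation, assuming Q is $m\times m$, P is $k\times k$, and M is $m\times k$:

```python
import numpy as np

def mse_from_order_params(Q, M, P):
    """Generalization error of Eq. (S18) for sigma(x) = erf(x/sqrt(2))."""
    m, k = Q.shape[0], P.shape[0]

    def term(C, A, B):
        # sum over (1/pi) * arcsin( C_ab / sqrt((1 + A_aa)(1 + B_bb)) )
        den = np.sqrt(np.outer(1 + np.diag(A), 1 + np.diag(B)))
        return np.arcsin(C / den).sum() / np.pi

    return term(P, P, P) / k**2 + term(Q, Q, Q) / m**2 - 2 * term(M, Q, P) / (m * k)
```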

We next consider the evolution of the order parameters, which involves only the update of $\hat{\bm{\xi}}$ in our toy model setting. Therefore, we derive the evolution of the order parameters $q^{\nu}_{jl}$ and $m^{\nu}_{jr}$ based on the gradient of $\hat{\bm{\xi}}$: $\Delta\hat{\xi}_{j\alpha}=-\eta\frac{\hat{f}(\bm{\lambda})-f(\bm{\lambda}^{*})}{m\sqrt{d}}\sum_{i}\sigma^{\prime}(\lambda_{i})\xi_{i\alpha}\Sigma_{\alpha}x_{j}^{\nu}$. In the high-dimensional limit ($d\to\infty$), we use the self-averaging property of the order parameters, considering the disorder average over the input data distribution Biehl and Schwarze (1995); Saad and Solla (1995); Goldt et al. (2019). Then we have the following expressions,

$$
\begin{aligned}
q_{jl}^{\nu+1}-q_{jl}^{\nu} &=\frac{1}{d}\,\mathbb{E}\left[\sum_{n,\alpha,\beta}\xi_{j\alpha}\Sigma_{\alpha}(\hat{\xi}_{n\alpha}+\Delta\hat{\xi}_{n\alpha})(\hat{\xi}_{n\beta}+\Delta\hat{\xi}_{n\beta})\Sigma_{\beta}\xi_{l\beta}\right]-q_{jl}^{\nu},\\
m_{jr}^{\nu+1}-m_{jr}^{\nu} &=\frac{1}{d}\,\mathbb{E}\left[\sum_{n,\alpha,\beta}\xi_{j\alpha}\Sigma_{\alpha}(\hat{\xi}_{n\alpha}+\Delta\hat{\xi}_{n\alpha})\hat{\xi}^{*}_{n\beta}\Sigma^{*}_{\beta}\xi^{*}_{r\beta}\right]-m_{jr}^{\nu},
\end{aligned}
\qquad (S19)
$$

where the expectation is carried out with respect to the data distribution.

Inserting the update equation of $\hat{\bm{\xi}}$ into the equation of the order parameter $q_{jl}$, we get

$$
\begin{aligned}
q_{jl}^{\nu+1}-q_{jl}^{\nu} &=\frac{1}{d}\,\mathbb{E}\left[\sum_{n,\alpha,\beta}\xi_{j\alpha}\Sigma_{\alpha}\left(\hat{\xi}_{n\alpha}+\eta\frac{f(\bm{\lambda}^{*})-\hat{f}(\bm{\lambda})}{m\sqrt{d}}\sum_{i}\sigma^{\prime}(\lambda_{i})\xi_{i\alpha}\Sigma_{\alpha}x_{n}^{\nu}\right)\left(\hat{\xi}_{n\beta}+\eta\frac{f(\bm{\lambda}^{*})-\hat{f}(\bm{\lambda})}{m\sqrt{d}}\sum_{i}\sigma^{\prime}(\lambda_{i})\xi_{i\beta}\Sigma_{\beta}x_{n}^{\nu}\right)\Sigma_{\beta}\xi_{l\beta}\right]-q_{jl}^{\nu}\\
&=\frac{1}{d}\,\mathbb{E}\left[\sum_{n,\alpha,\beta}\xi_{j\alpha}\Sigma_{\alpha}\hat{\xi}_{n\alpha}\,\eta\frac{f(\bm{\lambda}^{*})-\hat{f}(\bm{\lambda})}{m\sqrt{d}}\sum_{i}\sigma^{\prime}(\lambda_{i})\xi_{i\beta}\Sigma_{\beta}x_{n}^{\nu}\,\Sigma_{\beta}\xi_{l\beta}\right]\\
&\quad+\frac{1}{d}\,\mathbb{E}\left[\sum_{n,\alpha,\beta}\xi_{j\alpha}\Sigma_{\alpha}\,\eta\frac{f(\bm{\lambda}^{*})-\hat{f}(\bm{\lambda})}{m\sqrt{d}}\sum_{i}\sigma^{\prime}(\lambda_{i})\xi_{i\alpha}\Sigma_{\alpha}x_{n}^{\nu}\,\hat{\xi}_{n\beta}\Sigma_{\beta}\xi_{l\beta}\right]\\
&\quad+\frac{1}{d}\,\mathbb{E}\left[\sum_{n,\alpha,\beta}\xi_{j\alpha}\Sigma_{\alpha}\,\eta\frac{f(\bm{\lambda}^{*})-\hat{f}(\bm{\lambda})}{m\sqrt{d}}\sum_{i}\sigma^{\prime}(\lambda_{i})\xi_{i\alpha}\Sigma_{\alpha}x_{n}^{\nu}\left(\eta\frac{f(\bm{\lambda}^{*})-\hat{f}(\bm{\lambda})}{m\sqrt{d}}\sum_{i}\sigma^{\prime}(\lambda_{i})\xi_{i\beta}\Sigma_{\beta}x_{n}^{\nu}\right)\Sigma_{\beta}\xi_{l\beta}\right],
\end{aligned}
\qquad (S20)
$$

where we have applied the definition $q_{jl}^{\nu}=\frac{1}{d}\sum_{n,\alpha,\beta}\xi_{j\alpha}\Sigma_{\alpha}\hat{\xi}_{n\alpha}\hat{\xi}_{n\beta}\Sigma_{\beta}\xi_{l\beta}$ to derive the second equality. Considering the definitions of $\hat{f}(\bm{\lambda})$, $f(\bm{\lambda}^{*})$, $\bm{\lambda}$, and $\bm{\lambda}^{*}$, we recast Eq. (S20) as follows,

$$
\begin{aligned}
q_{jl}^{\nu+1}-q_{jl}^{\nu} &=\frac{\eta}{dkm}\,\mathbb{E}\left[\sum_{\beta,r,i}\lambda_{j}\sigma\left(\lambda_{r}^{*\nu}\right)\sigma^{\prime}(\lambda_{i})\xi_{i\beta}\Sigma_{\beta}^{2}\xi_{l\beta}\right]-\frac{\eta}{dm^{2}}\,\mathbb{E}\left[\sum_{\beta,\hat{r},i}\lambda_{j}\sigma\left(\lambda_{\hat{r}}^{\nu}\right)\sigma^{\prime}(\lambda_{i})\xi_{i\beta}\Sigma_{\beta}^{2}\xi_{l\beta}\right]\\
&\quad+\frac{\eta}{dkm}\,\mathbb{E}\left[\sum_{\alpha,r,i}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}\lambda_{l}\sigma\left(\lambda_{r}^{*\nu}\right)\sigma^{\prime}(\lambda_{i})\right]-\frac{\eta}{dm^{2}}\,\mathbb{E}\left[\sum_{\alpha,\hat{r},i}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}\lambda_{l}\sigma\left(\lambda_{\hat{r}}^{\nu}\right)\sigma^{\prime}(\lambda_{i})\right]\\
&\quad+\frac{\eta^{2}}{dm^{2}k^{2}}\,\mathbb{E}\left[\sum_{\alpha,\beta,i,a,r,\hat{r}}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}\sigma^{\prime}(\lambda_{i})\sigma^{\prime}(\lambda_{a})\sigma\left(\lambda_{r}^{*\nu}\right)\sigma\left(\lambda_{\hat{r}}^{*\nu}\right)\xi_{a\beta}\Sigma_{\beta}^{2}\xi_{l\beta}\right]\\
&\quad+\frac{\eta^{2}}{dm^{4}}\,\mathbb{E}\left[\sum_{\alpha,\beta,i,a,r,\hat{r}}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}\sigma^{\prime}(\lambda_{i})\sigma^{\prime}(\lambda_{a})\sigma\left(\lambda_{r}^{\nu}\right)\sigma\left(\lambda_{\hat{r}}^{\nu}\right)\xi_{a\beta}\Sigma_{\beta}^{2}\xi_{l\beta}\right]\\
&\quad-\frac{2\eta^{2}}{dm^{3}k}\,\mathbb{E}\left[\sum_{\alpha,\beta,i,a,r,\hat{r}}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}\sigma^{\prime}(\lambda_{i})\sigma^{\prime}(\lambda_{a})\sigma\left(\lambda_{r}^{*\nu}\right)\sigma\left(\lambda_{\hat{r}}^{\nu}\right)\xi_{a\beta}\Sigma_{\beta}^{2}\xi_{l\beta}\right].
\end{aligned}
\qquad (S21)
$$

To proceed, we have to evaluate the integral I_{3}(\alpha,\beta,\eta)=\mathbb{E}_{\bm{\lambda},\bm{\lambda}^{*}\sim\mathcal{N}\left(\bm{\lambda},\bm{\lambda}^{*}\mid 0,\mathbf{\Omega}\right)}\left[\sigma^{\prime}\left(\lambda^{\alpha}\right)\lambda^{\beta}\sigma\left(\lambda^{\eta}\right)\right]=\frac{2}{\pi}\frac{\Omega_{23}^{\alpha\beta\eta}\left(1+\Omega_{11}^{\alpha\beta\eta}\right)-\Omega_{12}^{\alpha\beta\eta}\Omega_{13}^{\alpha\beta\eta}}{\left(1+\Omega_{11}^{\alpha\beta\eta}\right)\sqrt{\left(1+\Omega_{11}^{\alpha\beta\eta}\right)\left(1+\Omega_{33}^{\alpha\beta\eta}\right)-\left(\Omega_{13}^{\alpha\beta\eta}\right)^{2}}} for our transfer function \sigma(x)=\operatorname{erf}(x/\sqrt{2}), where \Omega_{ij}^{\alpha\beta\eta} denotes the element of the field-covariance matrix \mathbf{\Omega}^{\alpha\beta\eta}, and the arguments (\alpha,\beta,\eta) of I_{3} are dummy labels of local fields (not to be confused with the learning rate \eta or the mode indices used above). We also have to evaluate the second integral I_{4}(i,j,k,l)=\mathbb{E}_{\bm{\lambda},\bm{\lambda}^{*}\sim\mathcal{N}\left(\bm{\lambda},\bm{\lambda}^{*}\mid 0,\mathbf{\Omega}\right)}\left[\sigma^{\prime}(\lambda_{i})\sigma^{\prime}(\lambda_{j})\sigma(\lambda_{k})\sigma(\lambda_{l})\right], which has the closed form

I_{4}(i,j,k,l)=\frac{4}{\pi^{2}}\frac{1}{\sqrt{\bar{\Omega}_{0}^{ijkl}}}\arcsin\left(\frac{\bar{\Omega}_{1}^{ijkl}}{\sqrt{\bar{\Omega}_{2}^{ijkl}\bar{\Omega}_{3}^{ijkl}}}\right), (S22)

where

\bar{\Omega}_{0}^{ijkl}\equiv\left(1+\Omega_{11}^{ijkl}\right)\left(1+\Omega_{22}^{ijkl}\right)-\left(\Omega_{12}^{ijkl}\right)^{2}, (S23)
\bar{\Omega}_{1}^{ijkl}\equiv\bar{\Omega}_{0}^{ijkl}\Omega_{34}^{ijkl}-\Omega_{23}^{ijkl}\Omega_{24}^{ijkl}\left(1+\Omega_{11}^{ijkl}\right)-\Omega_{13}^{ijkl}\Omega_{14}^{ijkl}\left(1+\Omega_{22}^{ijkl}\right)+\Omega_{12}^{ijkl}\Omega_{13}^{ijkl}\Omega_{24}^{ijkl}+\Omega_{12}^{ijkl}\Omega_{14}^{ijkl}\Omega_{23}^{ijkl},
\bar{\Omega}_{2}^{ijkl}\equiv\bar{\Omega}_{0}^{ijkl}\left(1+\Omega_{33}^{ijkl}\right)-\left(\Omega_{23}^{ijkl}\right)^{2}\left(1+\Omega_{11}^{ijkl}\right)-\left(\Omega_{13}^{ijkl}\right)^{2}\left(1+\Omega_{22}^{ijkl}\right)+2\Omega_{12}^{ijkl}\Omega_{13}^{ijkl}\Omega_{23}^{ijkl},
\bar{\Omega}_{3}^{ijkl}\equiv\bar{\Omega}_{0}^{ijkl}\left(1+\Omega_{44}^{ijkl}\right)-\left(\Omega_{24}^{ijkl}\right)^{2}\left(1+\Omega_{11}^{ijkl}\right)-\left(\Omega_{14}^{ijkl}\right)^{2}\left(1+\Omega_{22}^{ijkl}\right)+2\Omega_{12}^{ijkl}\Omega_{14}^{ijkl}\Omega_{24}^{ijkl}.
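
The closed forms above can be checked directly by Monte Carlo sampling. The sketch below is only a numerical sanity check of Eqs. (S22)-(S23) and of the expression for I_{3}; the covariance matrix Omega used here is an arbitrary positive-definite example standing in for the field-covariance matrix.

```python
import numpy as np
from scipy.special import erf

# Monte-Carlo check of the closed forms of I_3 and I_4 for sigma(x) = erf(x/sqrt(2)).
sigma  = lambda x: erf(x / np.sqrt(2))
dsigma = lambda x: np.sqrt(2 / np.pi) * np.exp(-x ** 2 / 2)

def I3_closed(O):
    # O: 3x3 covariance of the three local fields (lambda_a, lambda_b, lambda_c)
    num = O[1, 2] * (1 + O[0, 0]) - O[0, 1] * O[0, 2]
    den = (1 + O[0, 0]) * np.sqrt((1 + O[0, 0]) * (1 + O[2, 2]) - O[0, 2] ** 2)
    return 2 / np.pi * num / den

def I4_closed(O):
    # O: 4x4 covariance of (lambda_a, lambda_b, lambda_c, lambda_d), cf. Eqs. (S22)-(S23)
    L0 = (1 + O[0, 0]) * (1 + O[1, 1]) - O[0, 1] ** 2
    L1 = (L0 * O[2, 3] - O[1, 2] * O[1, 3] * (1 + O[0, 0])
          - O[0, 2] * O[0, 3] * (1 + O[1, 1])
          + O[0, 1] * O[0, 2] * O[1, 3] + O[0, 1] * O[0, 3] * O[1, 2])
    L2 = (L0 * (1 + O[2, 2]) - O[1, 2] ** 2 * (1 + O[0, 0])
          - O[0, 2] ** 2 * (1 + O[1, 1]) + 2 * O[0, 1] * O[0, 2] * O[1, 2])
    L3 = (L0 * (1 + O[3, 3]) - O[1, 3] ** 2 * (1 + O[0, 0])
          - O[0, 3] ** 2 * (1 + O[1, 1]) + 2 * O[0, 1] * O[0, 3] * O[1, 3])
    return 4 / np.pi ** 2 / np.sqrt(L0) * np.arcsin(L1 / np.sqrt(L2 * L3))

rng = np.random.default_rng(1)
A = rng.standard_normal((4, 4))
Omega = A @ A.T / 4                       # arbitrary positive-definite covariance
z = rng.multivariate_normal(np.zeros(4), Omega, size=1_000_000)

mc_I3 = np.mean(dsigma(z[:, 0]) * z[:, 1] * sigma(z[:, 2]))
mc_I4 = np.mean(dsigma(z[:, 0]) * dsigma(z[:, 1]) * sigma(z[:, 2]) * sigma(z[:, 3]))
print(mc_I3, I3_closed(Omega[:3, :3]))    # the two numbers should agree up to sampling noise
print(mc_I4, I4_closed(Omega))
```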

In an analogous way, we can derive the mean-field evolution of m_{jr}^{\nu} as follows,

m_{jr}^{\nu+1}-m_{jr}^{\nu} = \frac{1}{d}\mathbb{E}\left[\sum_{n,\alpha,\beta}\xi_{j\alpha}\Sigma_{\alpha}(\hat{\xi}_{n\alpha}+\Delta\hat{\xi}_{n\alpha})\hat{\xi}^{*}_{n\beta}\Sigma^{*}_{\beta}\xi^{*}_{r\beta}\right]-m_{jr}^{\nu} (S24)
= \frac{\eta}{kmd}\mathbb{E}\left[\sum_{\alpha,i,a}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}\sigma(\lambda_{a}^{*})\sigma^{\prime}(\lambda_{i})\lambda_{r}^{*}\right]-\frac{\eta}{m^{2}d}\mathbb{E}\left[\sum_{\alpha,i,a}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}\sigma(\lambda_{a})\sigma^{\prime}(\lambda_{i})\lambda_{r}^{*}\right],

where the definition of m_{jr} has been used. If we define \tau\equiv\nu/d and take the thermodynamic limit d\to\infty, the time step becomes continuous, and we can thus write down the following ODEs,

\frac{\mathrm{d}q_{jl}}{\mathrm{d}\tau} = \frac{\eta}{km}\left[\sum_{\beta,r^{*},i}I_{3}(i,j,r^{*})\xi_{i\beta}\Sigma_{\beta}^{2}\xi_{l\beta}\right]-\frac{\eta}{m^{2}}\left[\sum_{\beta,\hat{r},i}I_{3}(i,j,\hat{r})\xi_{i\beta}\Sigma_{\beta}^{2}\xi_{l\beta}\right] (S25)
+\frac{\eta}{km}\left[\sum_{\alpha,r^{*},i}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}I_{3}(i,l,r^{*})\right]-\frac{\eta}{m^{2}}\left[\sum_{\alpha,\hat{r},i}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}I_{3}(i,l,\hat{r})\right]
+\frac{\eta^{2}}{m^{4}}\left[\sum_{\alpha,\beta,i,a,r^{*},\hat{r}^{*}}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}I_{4}(i,a,r^{*},\hat{r}^{*})\xi_{a\beta}\Sigma_{\beta}^{2}\xi_{l\beta}\right]
+\frac{\eta^{2}}{m^{2}k^{2}}\left[\sum_{\alpha,\beta,i,a,r,\hat{r}}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}I_{4}(i,a,r,\hat{r})\xi_{a\beta}\Sigma_{\beta}^{2}\xi_{l\beta}\right]
-\frac{2\eta^{2}}{m^{3}k}\left[\sum_{\alpha,\beta,i,a,r^{*},\hat{r}}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}I_{4}(i,a,r^{*},\hat{r})\xi_{a\beta}\Sigma_{\beta}^{2}\xi_{l\beta}\right],
\frac{\mathrm{d}m_{jr}}{\mathrm{d}\tau} = \frac{\eta}{km}\left[\sum_{\alpha,i,a}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}I_{3}(i,r^{*},a^{*})\right]-\frac{\eta}{m^{2}}\left[\sum_{\alpha,i,a}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}I_{3}(i,r^{*},a)\right],

where an index of I_{3} or I_{4} carrying the symbol * labels the corresponding teacher local field.
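
As an illustration of how these ODEs can be integrated numerically, the sketch below performs a forward-Euler step for \mathrm{d}m_{jr}/\mathrm{d}\tau. The 3\times 3 field covariances entering I_{3} are assembled from the student-student overlap q, the student-teacher overlap m, and a teacher self-overlap T that is taken here as a fixed input; the helper names, the constant matrix C_{ji}=\sum_{\alpha}\xi_{j\alpha}\Sigma_{\alpha}^{2}\xi_{i\alpha}, and the Euler discretization are illustrative choices rather than the exact numerical routine used in the paper. The equation for \mathrm{d}q_{jl}/\mathrm{d}\tau is handled in the same way with I_{4}.

```python
import numpy as np

# Forward-Euler sketch for dm_{jr}/dtau in Eq. (S25); hypothetical helper names.
# q:    (m x m) student-student overlap
# m_st: (m x k) student-teacher overlap
# T:    (k x k) teacher self-overlap (assumed fixed during training)
# C:    (m x m) matrix C[j, i] = sum_alpha xi_{j,alpha} Sigma_alpha^2 xi_{i,alpha}

def I3_closed(O):
    # closed form of I_3 for sigma(x) = erf(x/sqrt(2)); O is a 3x3 field covariance
    num = O[1, 2] * (1 + O[0, 0]) - O[0, 1] * O[0, 2]
    den = (1 + O[0, 0]) * np.sqrt((1 + O[0, 0]) * (1 + O[2, 2]) - O[0, 2] ** 2)
    return 2 / np.pi * num / den

def cov3(idx, starred, q, m_st, T):
    """3x3 covariance of the three local fields selected by idx;
    starred[t] = True means the t-th field belongs to the teacher."""
    def block(s1, i1, s2, i2):
        if s1 and s2:
            return T[i1, i2]
        if not s1 and not s2:
            return q[i1, i2]
        return m_st[i2, i1] if s1 else m_st[i1, i2]
    return np.array([[block(starred[s], idx[s], starred[t], idx[t])
                      for t in range(3)] for s in range(3)])

def dm_dtau(q, m_st, T, C, eta):
    m, k = m_st.shape
    out = np.zeros_like(m_st)
    for j in range(m):
        for r in range(k):
            # teacher term: I_3(i, r*, a*), summed over student node i and teacher node a
            t1 = sum(C[j, i] * I3_closed(cov3((i, r, a), (False, True, True), q, m_st, T))
                     for i in range(m) for a in range(k))
            # student term: I_3(i, r*, a), summed over student nodes i and a
            t2 = sum(C[j, i] * I3_closed(cov3((i, r, a), (False, True, False), q, m_st, T))
                     for i in range(m) for a in range(m))
            out[j, r] = eta / (k * m) * t1 - eta / m ** 2 * t2
    return out

# one Euler step of size dtau (dq/dtau is advanced analogously with I_4):
# m_st = m_st + dtau * dm_dtau(q, m_st, T, C, eta)
```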

References

  • Goodfellow et al. (2016) I. Goodfellow, Y. Bengio, and A. Courville, Deep Learning (MIT Press, Cambridge, MA, 2016).
  • Carleo et al. (2019) G. Carleo, I. Cirac, K. Cranmer, L. Daudet, M. Schuld, N. Tishby, L. Vogt-Maranto, and L. Zdeborová, Rev. Mod. Phys. 91, 045002 (2019).
  • Huang (2022) H. Huang, Statistical Mechanics of Neural Networks (Springer, Singapore, 2022).
  • Roberts et al. (2022) D. A. Roberts, S. Yaida, and B. Hanin, The Principles of Deep Learning Theory: An Effective Theory Approach to Understanding Neural Networks (Cambridge University Press, Cambridge, 2022).
  • Jaderberg et al. (2014) M. Jaderberg, A. Vedaldi, and A. Zisserman, arXiv:1405.3866 (2014).
  • Yang et al. (2020) H. Yang, M. Tang, W. Wen, F. Yan, D. Hu, A. Li, H. Li, and Y. Chen, 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW) (2020).
  • Giambagli et al. (2021) L. Giambagli, L. Buffoni, T. Carletti, W. Nocentini, and D. Fanelli, Nature Communications 12, 1330 (2021).
  • Chicchi et al. (2021) L. Chicchi, L. Giambagli, L. Buffoni, T. Carletti, M. Ciavarella, and D. Fanelli, Phys. Rev. E 104, 054312 (2021).
  • Jiang et al. (2021) Z. Jiang, J. Zhou, T. Hou, K. Y. M. Wong, and H. Huang, Phys. Rev. E 104, 064306 (2021).
  • Zhou et al. (2021) J. Zhou, Z. Jiang, T. Hou, Z. Chen, K. Y. M. Wong, and H. Huang, Phys. Rev. E 104, 064307 (2021).
  • Y. LeCun, The MNIST database of handwritten digits, retrieved from http://yann.lecun.com/exdb/mnist.
  • Fischer et al. (2022) K. Fischer, A. René, C. Keup, M. Layer, D. Dahmen, and M. Helias, arXiv:2202.04925 (2022).
  • See the Supplemental Material at http://… for technical and experimental details.
  • Biehl and Schwarze (1995) M. Biehl and H. Schwarze, Journal of Physics A: Mathematical and General 28, 643 (1995).
  • Saad and Solla (1995) D. Saad and S. A. Solla, Phys. Rev. Lett. 74, 4337 (1995).
  • Goldt et al. (2019) S. Goldt, M. S. Advani, A. M. Saxe, F. Krzakala, and L. Zdeborová, arXiv:1901.09085 (2019).
  • Bjoerck and Golub (1973) A. Bjoerck and G. H. Golub, Mathematics of Computation 27, 579 (1973).
  • Schwarze (1993) H. Schwarze, J. Phys. A: Math. Gen. 26, 5781 (1993).