
Task Arithmetic in Trust Region: A Training-Free Model Merging Approach to Navigate Knowledge Conflicts


Let $X_{i}\in\mathbb{R}^{n\times d}$ and $W_{i}\in\mathbb{R}^{d\times h}$ be the feature matrix and the task vector of task $i$ for a linear layer. Our target is to learn a removal basis $B_{i}\in\mathbb{R}^{h\times c}$ with orthonormal columns ($B_{i}^{\top}B_{i}=I$) for task $i$ such that:

$$\max_{B_{i}}\ \sum_{j\neq i}\left\|X_{i}W_{j}B_{i}B_{i}^{\top}\right\|_{F}^{2}-\lambda\left\|X_{j}W_{j}B_{i}B_{i}^{\top}\right\|_{F}^{2}. \qquad (1)$$

Then we have:

$$\begin{aligned}
&\sum_{j\neq i}\left\|X_{i}W_{j}B_{i}B_{i}^{\top}\right\|_{F}^{2}-\lambda\left\|X_{j}W_{j}B_{i}B_{i}^{\top}\right\|_{F}^{2}\\
&\quad=\sum_{j\neq i}\operatorname{Tr}\left(X_{i}W_{j}B_{i}B_{i}^{\top}W_{j}^{\top}X_{i}^{\top}\right)-\lambda\operatorname{Tr}\left(X_{j}W_{j}B_{i}B_{i}^{\top}W_{j}^{\top}X_{j}^{\top}\right)\\
&\quad=\sum_{j\neq i}\operatorname{Tr}\left(W_{j}B_{i}B_{i}^{\top}W_{j}^{\top}X_{i}^{\top}X_{i}\right)-\lambda\operatorname{Tr}\left(W_{j}B_{i}B_{i}^{\top}W_{j}^{\top}X_{j}^{\top}X_{j}\right)\\
&\quad=\sum_{j\neq i}\operatorname{Tr}\left(W_{j}B_{i}B_{i}^{\top}W_{j}^{\top}\left(X_{i}^{\top}X_{i}-\lambda X_{j}^{\top}X_{j}\right)\right)\\
&\quad=\operatorname{Tr}\Bigl(B_{i}^{\top}\underbrace{\Bigl(\sum_{j\neq i}W_{j}^{\top}\left(X_{i}^{\top}X_{i}-\lambda X_{j}^{\top}X_{j}\right)W_{j}\Bigr)}_{G}B_{i}\Bigr). \qquad (2)
\end{aligned}$$

Since $G$ is symmetric and $B_{i}$ has orthonormal columns, this is a standard trace-maximization problem: the eigenvectors associated with the $c$ largest eigenvalues of $G$ form an optimal solution $B_{i}$.
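To make the closed form concrete, here is a minimal NumPy sketch of the eigen-solution; the function name `removal_basis` and the list layout of `X` and `W` are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

def removal_basis(i, X, W, lam=1.0, c=8):
    """Return B_i: the top-c eigenvectors of G (hypothetical helper).

    X : list of feature matrices, X[j] of shape (n_j, d)
    W : list of task vectors, W[j] of shape (d, h)
    """
    h = W[i].shape[1]
    G = np.zeros((h, h))
    for j in range(len(W)):
        if j == i:
            continue
        # M = X_i^T X_i - lam * X_j^T X_j  (the d x d core of Eq. 2)
        M = X[i].T @ X[i] - lam * (X[j].T @ X[j])
        G += W[j].T @ M @ W[j]
    # eigh sorts eigenvalues in ascending order; keep the last c columns
    _, V = np.linalg.eigh(G)
    return V[:, -c:]  # (h, c), orthonormal columns
```

Because $G$ is symmetric, `eigh` applies and the returned columns are orthonormal, matching the constraint $B_{i}^{\top}B_{i}=I$.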

Let $D_{i}=\{x_{i}\in\mathbb{R}^{d}\}$ and $t_{i}\in\mathbb{R}^{d}$ be the feature set and the task vector of task $i$ for the weight of a layer normalization. Our target is to learn a binary removal mask $m_{i}\in\{0,1\}^{d}$ for task $i$ such that:

$$\max_{m_{i}}\ \sum_{j\neq i}\sum_{x_{i}\in D_{i}}\sum_{x_{j}\in D_{j}}\left\|x_{i}\odot t_{j}\odot m_{i}\right\|^{2}-\lambda\left\|x_{j}\odot t_{j}\odot m_{i}\right\|^{2}. \qquad (3)$$

Then we have:

$$\begin{aligned}
&\sum_{j\neq i}\sum_{x_{i}\in D_{i}}\sum_{x_{j}\in D_{j}}\left\|x_{i}\odot t_{j}\odot m_{i}\right\|^{2}-\lambda\left\|x_{j}\odot t_{j}\odot m_{i}\right\|^{2}\\
&\quad=\sum_{j\neq i}\sum_{x_{i}\in D_{i}}\sum_{x_{j}\in D_{j}}\sum_{k=1}^{d}\left(x_{i,k}t_{j,k}m_{i,k}\right)^{2}-\lambda\sum_{j\neq i}\sum_{x_{j}\in D_{j}}\sum_{k=1}^{d}\left(x_{j,k}t_{j,k}m_{i,k}\right)^{2}\\
&\quad=\sum_{k=1}^{d}m_{i,k}^{2}\sum_{j\neq i}\sum_{x_{i}\in D_{i}}\sum_{x_{j}\in D_{j}}\left(x_{i,k}t_{j,k}\right)^{2}-\lambda\sum_{k=1}^{d}m_{i,k}^{2}\sum_{j\neq i}\sum_{x_{j}\in D_{j}}\left(x_{j,k}t_{j,k}\right)^{2}\\
&\quad=\sum_{k=1}^{d}m_{i,k}\underbrace{\left(\sum_{j\neq i}\sum_{x_{i}\in D_{i}}\sum_{x_{j}\in D_{j}}\left(x_{i,k}t_{j,k}\right)^{2}-\lambda\sum_{j\neq i}\sum_{x_{j}\in D_{j}}\left(x_{j,k}t_{j,k}\right)^{2}\right)}_{g_{k}}\\
&\quad=\sum_{k=1}^{d}m_{i,k}\,g_{k}, \qquad (4)
\end{aligned}$$

where the last two equalities use $m_{i,k}^{2}=m_{i,k}$ for a binary mask. The objective is therefore linear in $m_{i}$, so under a budget of $c$ active entries it is maximized by setting $m_{i,k}=1$ on the $c$ largest values of $g_{k}$ and $m_{i,k}=0$ elsewhere.
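Analogously, the mask can be read off from $g_{k}$ directly. Below is a minimal NumPy sketch; `removal_mask`, the array layout of `D` and `t`, and the budget `c` are all assumed for illustration.

```python
import numpy as np

def removal_mask(i, D, t, lam=1.0, c=8):
    """Return the binary mask m_i selecting the top-c coordinates of g.

    D : list of feature sets, D[j] of shape (n_j, d)
    t : list of task vectors, t[j] of shape (d,)
    """
    d = t[i].shape[0]
    g = np.zeros(d)
    for j in range(len(t)):
        if j == i:
            continue
        tj2 = t[j] ** 2
        # First term of g_k: the inner sum over x_j contributes a factor |D_j|
        g += len(D[j]) * (D[i] ** 2).sum(axis=0) * tj2
        # Second term of g_k
        g -= lam * (D[j] ** 2).sum(axis=0) * tj2
    m = np.zeros(d)
    m[np.argsort(g)[-c:]] = 1.0  # m_{i,k} = 1 on the c largest g_k
    return m
```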

Input: Pre-trained model $W_{\text{pre}}$; task vectors $\{T_{1},\dots,T_{K}\}$; unlabeled exemplar sets $\{D_{1},\dots,D_{K}\}$
Output: Merged model $\theta_{\text{MTL}}$

// Collect the input features for each task
for $k=1$ to $K$ do
      Initialize task inputs: $X^{1}_{k}=D_{k}$
      for $l=1$ to $L$ do
            Update task features: $X^{l+1}_{k}=f(X^{l}_{k};\,W_{\text{pre}}^{l}+T_{k}^{l})$

// Clip all task vectors
for $k=1$ to $K$ do
      for $l=1$ to $L$ do
            if parameter $l$ is a linear layer (weight matrix) then
                  Compute basis: $B^{l}_{k}=\arg\max_{B}\sum_{i\neq k}\|X_{k}^{l}T_{i}^{l}BB^{\top}\|_{F}^{2}-\alpha\|X_{i}^{l}T_{i}^{l}BB^{\top}\|_{F}^{2}$
                  for $i=1$ to $K$, $i\neq k$ do
                        Project and update: $T_{i}^{l}=T_{i}^{l}-T_{i}^{l}B^{l}_{k}{B^{l}_{k}}^{\top}$
            else if parameter $l$ is a normalization layer weight then
                  Compute mask: $m^{l}_{k}=\arg\max_{m}\sum_{i\neq k}\sum_{x_{k}^{l}\in X_{k}^{l}}\|x_{k}^{l}\odot T_{i}^{l}\odot m\|^{2}-\alpha\sum_{x_{i}^{l}\in X_{i}^{l}}\|x_{i}^{l}\odot T_{i}^{l}\odot m\|^{2}$
                  for $i=1$ to $K$, $i\neq k$ do
                        Element-wise clipping: $T_{i}^{l}=T_{i}^{l}-T_{i}^{l}\odot m^{l}_{k}$
            else if parameter $l$ is a bias term then
                  Compute mask: $m^{l}_{k}=\arg\max_{m}\sum_{i\neq k}\|T_{i}^{l}\odot m\|^{2}$
                  for $i=1$ to $K$, $i\neq k$ do
                        Element-wise clipping: $T_{i}^{l}=T_{i}^{l}-T_{i}^{l}\odot m^{l}_{k}$

// Merging
Compute merged model: $\theta_{\text{MTL}}=\theta_{\text{pre}}+\lambda\sum_{k}T_{k}$
return $\theta_{\text{MTL}}$
Algorithm 1: The model merging process
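Putting the pieces together, the following compact NumPy sketch mirrors Algorithm 1 for the linear-layer branch only (the normalization and bias branches clip with masks in the same loop structure). The dictionary layout of the models, the `features` structure, and the name `merge_coef` are assumptions for illustration, not the authors' code.

```python
import numpy as np

def merge_models(pretrained, task_vectors, features, alpha=1.0, c=8, merge_coef=0.3):
    """pretrained: {name: (d, h) array}; task_vectors: list of K such dicts;
    features[k][name]: (n, d) inputs X_k^l collected under W_pre + T_k."""
    K = len(task_vectors)
    for name, W_pre in pretrained.items():
        if W_pre.ndim != 2:  # only the weight-matrix branch is sketched
            continue
        for k in range(K):
            # Build G for the basis B_k^l (cf. Eq. 2 with X_k and T_i)
            h = task_vectors[k][name].shape[1]
            G = np.zeros((h, h))
            for i in range(K):
                if i == k:
                    continue
                Xk, Xi = features[k][name], features[i][name]
                M = Xk.T @ Xk - alpha * (Xi.T @ Xi)
                Ti = task_vectors[i][name]
                G += Ti.T @ M @ Ti
            B = np.linalg.eigh(G)[1][:, -c:]  # top-c eigenvectors
            # Clip the other task vectors against task k's basis
            for i in range(K):
                if i != k:
                    Ti = task_vectors[i][name]
                    task_vectors[i][name] = Ti - Ti @ B @ B.T
    # Final merge: theta_MTL = theta_pre + lambda * sum_k T_k
    return {name: W + merge_coef * sum(tv[name] for tv in task_vectors)
            for name, W in pretrained.items()}
```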
Table 1: Multi-task performance when merging ViT-B/32 models on eight tasks. The "# Best" column gives the number of datasets on which each method achieves the best performance; the best and second-best results are highlighted in bold and underlined, respectively. Results with * are taken from the original paper, which may use a different setting.
| Method | SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD | # Best | Avg Acc |
|---|---|---|---|---|---|---|---|---|---|---|
| Basic baseline methods | | | | | | | | | | |
| Pre-trained | 62.3 | 59.7 | 60.7 | 45.5 | 31.4 | 32.6 | 48.5 | 43.8 | – | 48.0 |
| Individual | 75.3 | 77.7 | 96.1 | 99.7 | 97.5 | 98.7 | 99.7 | 79.4 | – | 90.5 |
| Traditional MTL | 73.9 | 74.4 | 93.9 | 98.2 | 95.8 | 98.9 | 99.5 | 77.9 | – | 88.9 |
| Test-time training based methods | | | | | | | | | | |
| TW AdaMerging | 58.0 | 53.2 | 68.8 | 85.7 | 81.1 | 84.4 | 92.4 | 44.8 | 0 | 71.1 |
| TW AdaMerging++ | 60.8 | 56.9 | 73.1 | 83.4 | 87.3 | 82.4 | 95.7 | 50.1 | 0 | 73.7 |
| LW AdaMerging | 64.5 | 68.1 | 79.2 | 93.8 | 87.0 | 91.9 | 97.5 | 59.1 | 1 | 80.1 |
| LW AdaMerging++ | 66.6 | 68.3 | 82.2 | 94.2 | 89.6 | 89.0 | 98.3 | 60.6 | 0 | 81.1 |
| Surgery Merging | 63.8 | 59.9 | 83.3 | 97.9 | 87.0 | 87.0 | 98.6 | 69.4 | 1 | 80.9 |
| LW AdaMerging++ & TATR | 69.8 | 70.3 | 83.7 | 93.7 | 90.0 | 90.2 | 98.3 | 63.7 | 1 | 82.5 |
| Surgery & TATR | 67.1 | 62.2 | 87.1 | 97.4 | 87.3 | 88.5 | 98.7 | 70.9 | 2 | 82.4 |
| Training-free methods | | | | | | | | | | |
| Weight Averaging | 65.3 | 63.4 | 71.4 | 71.7 | 64.2 | 52.8 | 87.5 | 50.1 | – | 65.8 |
| Fisher Merging | 68.6 | 69.2 | 70.7 | 66.4 | 72.9 | 51.1 | 87.9 | 59.9 | – | 68.3 |
| RegMean | 65.3 | 63.5 | 75.6 | 78.6 | 78.1 | 67.4 | 93.7 | 52.0 | – | 71.8 |
| Task Arithmetic | 55.2 | 54.9 | 66.7 | 78.9 | 80.2 | 69.7 | 97.3 | 50.4 | – | 69.1 |
| Ties-Merging | 59.8 | 58.6 | 70.7 | 79.7 | 86.2 | 72.1 | 98.3 | 54.2 | – | 72.4 |
| TATR | 62.7 | 59.3 | 72.3 | 82.3 | 80.5 | 72.6 | 97.0 | 55.4 | – | 72.8 |
| Ties-Merging & TATR | 66.3 | 65.9 | 75.9 | 79.4 | 79.9 | 68.1 | 96.2 | 54.8 | – | 73.3 |
| Consensus Merging | 65.7 | 63.6 | 76.5 | 77.2 | 81.7 | 70.3 | 97.0 | 57.1 | – | 73.6 |
| PCB Merging | – | – | – | – | – | – | – | – | – | 75.9* |
| Ours | 68.8 | 66.5 | 79.4 | 85.9 | 80.1 | 72.9 | 97.8 | 57.9 | – | 76.1 |
Table 2: Multi-task performance when merging ViT-L/14 models on eight tasks. Results with * are taken from the original paper, which may use a different setting.
| Method | SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD | # Best | Avg Acc |
|---|---|---|---|---|---|---|---|---|---|---|
| Basic baseline methods | | | | | | | | | | |
| Pre-trained | 66.8 | 77.7 | 71.0 | 59.9 | 58.4 | 50.5 | 76.3 | 55.3 | – | 64.5 |
| Individual | 82.3 | 92.4 | 97.4 | 100.0 | 98.1 | 99.2 | 99.7 | 84.1 | – | 94.2 |
| Traditional MTL | 80.8 | 90.6 | 96.3 | 96.3 | 97.6 | 99.1 | 99.6 | 84.4 | – | 93.5 |
| Test-time training based methods | | | | | | | | | | |
| AdaMerging | 79.0 | 90.3 | 90.8 | 96.2 | 93.4 | 98.0 | 99.0 | 79.9 | 2 | 90.8 |
| AdaMerging++ | 79.4 | 90.3 | 91.6 | 97.4 | 93.4 | 97.5 | 99.0 | 79.2 | 1 | 91.0 |
| Surgery Merging | 75.7 | 84.4 | 93.1 | 98.8 | 91.3 | 93.4 | 99.1 | 76.1 | 1 | 89.0 |
| AdaMerging++ & TATR | 81.6 | 95.9 | 95.8 | 95.5 | 83.2 | 92.6 | 99.7 | 87.5 | 4 | 91.5 |
| Surgery & TATR | 76.3 | 85.8 | 93.8 | 98.8 | 91.4 | 93.0 | 99.2 | 77.9 | 1 | 89.5 |
| Training-free methods | | | | | | | | | | |
| Weight Averaging | 72.1 | 81.6 | 82.6 | 91.9 | 78.2 | 70.7 | 97.1 | 62.8 | – | 79.6 |
| Fisher Merging | 69.2 | 88.6 | 87.5 | 93.5 | 80.6 | 74.8 | 93.3 | 70.0 | – | 82.2 |
| RegMean | 73.3 | 81.8 | 86.1 | 97.0 | 88.0 | 84.2 | 98.5 | 60.8 | – | 83.7 |
| Task Arithmetic | 73.9 | 82.1 | 86.6 | 94.1 | 87.9 | 86.7 | 98.9 | 65.6 | – | 84.5 |
| Ties-Merging | 76.5 | 85.0 | 89.3 | 95.7 | 90.3 | 83.3 | 99.0 | 68.8 | – | 86.0 |
| TATR | 74.6 | 83.7 | 87.6 | 93.7 | 88.6 | 88.1 | 99.0 | 66.8 | – | 85.3 |
| Ties-Merging & TATR | 76.3 | 85.3 | 88.8 | 94.4 | 90.8 | 88.7 | 99.2 | 68.8 | – | 86.5 |
| Consensus Merging | 75.0 | 84.3 | 89.4 | 95.6 | 88.3 | 82.4 | 98.9 | 68.0 | – | 85.2 |
| PCB Merging | – | – | – | – | – | – | – | – | – | 86.9* |
| Ours | 75.2 | 85.8 | 90.6 | 95.2 | 85.4 | 91.9 | 99.1 | 70.7 | – | 86.7 |
Table 3: Impact of the number of exemplars per task when merging ViT-B/32 models on eight tasks (mean ± std over 5 runs).
| Exemplar number | Avg Acc (ViT-B/32) |
|---|---|
| 1 per task | 78.00 ± 0.2 |
| 2 per task | 77.75 ± 0.1 |
| 4 per task | 77.59 ± 0.1 |
| 8 per task | 77.40 ± 0.1 |
| 16 per task | 77.21 ± 0.1 |
| 32 per task | 77.00 ± 0.1 |