Fine-Tuning Linear Layers Only Is a Simple yet Effective Way for Task Arithmetic (2024)

Ruochen Jin 1,2, Bojian Hou 2, Jiancong Xiao 2, Weijie Su 2, and Li Shen 2
1 East China Normal University, Shanghai, China
2 University of Pennsylvania, Philadelphia, PA, USA
{kyrie.jin, li.shen}@pennmedicine.upenn.edu

Abstract

Task arithmetic has recently emerged as a cost-effective and scalable approach to editing pre-trained models directly in weight space by adding the fine-tuned weights of different tasks. Its performance can be further improved by a linear property, which is characterized by weight disentanglement. However, conventional linearization methods (e.g., NTK linearization) not only double the time and training cost but also hurt single-task performance. We propose a simple yet effective and efficient method that fine-tunes only the linear layers, improving weight disentanglement and efficiency simultaneously. Specifically, our study reveals that fine-tuning only the linear layers in the attention modules makes the whole model operate in a linear regime, significantly improving weight disentanglement. To further understand how our method improves the disentanglement of task arithmetic, we present a comprehensive study of task arithmetic that differentiates the roles of the representation model and the task-specific model. In particular, we find that the representation model plays an important role in improving weight disentanglement, whereas task-specific models such as classification heads can degrade it. Overall, our work uncovers novel insights into the fundamental mechanisms of task arithmetic and offers a more reliable and effective approach to editing pre-trained models. The code is available at https://github.com/kyrie-23/linear_task_arithmetic.

1 Introduction

The emergence of large pre-trained models in the open-source community maximizes the potential to boost performance on downstream tasks [1, 2, 3], align with human preferences [4, 5, 6, 7, 8], and enhance robustness [9, 10, 11, 12]. Traditional methods involve expensive joint fine-tuning across various tasks [3] and reliance on human feedback [5], limiting their scalability and widespread adoption. Moreover, optimizing performance for specific downstream tasks usually compromises the model's initial pre-training performance or zero-shot accuracy [13, 14].

Building on the extensive resources of open-source models, we often aim to edit pre-trained models without requiring access to additional training data, in order to improve performance on multiple downstream tasks. Task arithmetic [1] is a method proposed for this goal. Central to task arithmetic is the concept of the task vector, which enables models to adapt to new tasks [1]. A task vector can be viewed as a set of weight adjustments specifically calibrated for a given task through fine-tuning, obtained by subtracting the original pre-trained weights from the task-specific fine-tuned weights. Essentially, each task vector encodes a unique representational signature tailored to a particular task. Task arithmetic is illustrated in Figure 1. Recent research on task vector-centric approaches [1, 15, 16, 17, 18] has demonstrated that by aggregating multiple task vectors and integrating them into a pre-trained model, it is possible to create a new model, which we refer to as a unified model, and adapt it to multi-task learning.
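As a rough illustration of how task vectors are formed and combined in weight space (a minimal sketch, not the authors' released implementation; the checkpoint filenames and the scaling coefficient are placeholders):

```python
import torch

def task_vector(pretrained: dict, finetuned: dict) -> dict:
    # tau_t = theta_t* - theta_0, computed parameter by parameter
    return {name: finetuned[name] - pretrained[name] for name in pretrained}

def apply_task_vectors(pretrained: dict, task_vectors: list, alpha: float) -> dict:
    # theta = theta_0 + alpha * sum_t tau_t  (a single mixing coefficient alpha)
    merged = {name: weight.clone() for name, weight in pretrained.items()}
    for tau in task_vectors:
        for name in merged:
            merged[name] += alpha * tau[name]
    return merged

# Hypothetical checkpoints: one pre-trained model and eight task-specific fine-tunes.
theta_0 = torch.load("pretrained.pt")
taus = [task_vector(theta_0, torch.load(f"finetuned_task{t}.pt")) for t in range(8)]
unified = apply_task_vectors(theta_0, taus, alpha=0.3)
```

Negating a task vector before adding it corresponds to removing the corresponding capability, the other arithmetic operation studied in [1].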

However, task arithmetic from non-linear fine-tuning is not without limitations. For each individual task, although the unified model shows some improvement over the pre-trained model, its performance is usually not comparable to that of a model specifically trained for that task. This is because a task vector for one particular task usually has a negative effect on the performance of the other tasks. Therefore, the first question of this paper arises:

Figure 1: Illustration of task arithmetic: task vectors obtained from fine-tuning are added to a pre-trained model to build a unified multi-task model.

Question 1: How can we improve the performance of task arithmetic?

Ideally, Question 1 would be resolved by weight disentanglement, whereby the unified model's performance on one specific task is not affected by the other tasks (illustrated on the right-hand side of Figure 1). However, weight disentanglement is the most challenging problem in task arithmetic. It has been found that enforcing models to be fine-tuned in the tangent space significantly improves weight disentanglement, thanks to the fact that such models inherently operate in a linear regime [17]. Here, the linear regime refers to the behavior of neural networks at the beginning of the fine-tuning phase, where the network updates occur primarily around the pre-trained parameter initialization. This phenomenon is formalized by neural tangent kernel (NTK) theory [19]. While NTK linearization is effective, it requires two to three times more computational resources and doubles the memory footprint compared to its non-linear counterpart. The additional cost of NTK linearization violates the original goal of task arithmetic described above. As a result, Question 1 boils down to the following:

Question 2: How to improve the disentanglement and efficiency of task arithmetic simultaneously?

To answer Question 2, we aim to propose a novel fine-tuning method with the following two properties. For disentanglement, the fine-tuning method should operate in the linear regime. For efficiency, the fine-tuning method should be applied to only part of the network to reduce computational cost. Hence, we propose a simple yet efficient and effective method, fine-tuning only the linear layers, to achieve both goals simultaneously. Specifically, our study reveals that fine-tuning only the linear layers makes the attention modules operate in a linear regime, significantly improving both weight disentanglement and accuracy compared to both the non-linear counterpart and NTK linearization [17] (see Table 1). Meanwhile, our method is much more efficient, as it fine-tunes only a small portion of the layers, thereby reducing the computational burden and memory usage. Additionally, our approach mitigates the non-linear advantage, i.e., the observation that non-linear fine-tuning consistently achieves the highest single-task accuracy, offering a balanced and efficient alternative.
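As a concrete sketch of the proposed paradigm (assuming a PyTorch ViT whose attention projections are exposed as `q_proj`, `k_proj`, `v_proj`, and `out_proj`; the exact parameter names differ between CLIP implementations):

```python
import torch.nn as nn

ATTN_LINEAR_WEIGHTS = ("q_proj.weight", "k_proj.weight", "v_proj.weight", "out_proj.weight")

def select_linear_layers_only(model: nn.Module) -> int:
    """Freeze every parameter except the Q, K, V and output-projection weights of the
    attention modules (biases stay frozen); returns the number of trainable parameters."""
    trainable = 0
    for name, param in model.named_parameters():
        param.requires_grad = name.endswith(ATTN_LINEAR_WEIGHTS)
        if param.requires_grad:
            trainable += param.numel()
    return trainable
```

Fine-tuning then proceeds exactly as in the non-linear baseline; since only the selected weights move away from the pre-trained initialization, the resulting task vector is non-zero only on those entries.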

To further understand how our method improves the disentanglement of task arithmetic, we present a study that differentiates the roles of the representation model and the task-specific model, whereas existing literature [17] formulated task arithmetic using a single model without clearly differentiating them. We conduct a comprehensive study of task arithmetic on contrastively pre-trained vision-language models with ViT backbones, such as CLIP [20], providing new insights into its fundamental mechanisms and proposing novel methods to improve the performance of pre-trained models through task arithmetic. Specifically, we illustrate that the representation model plays an important role in improving weight disentanglement, whereas this has been constrained by the task-specific models, such as classification heads.

Notably, we demonstrate that the attention module lies in a strong linear regime within its linear layers. Without any prior knowledge, we can achieve performance superior to the current state-of-the-art methods by simply fine-tuning all linear layers in the attention module. However, the performance can either improve or degrade depending on whether the bias parameters are fine-tuned. The best results are obtained when the settings align closely with those used in LoRA [21], which fine-tunes only the $Q$, $K$, $V$, and output projection weights, indicating that this aspect warrants further exploration.

In particular, our main contributions are as follows:

  • We propose a simple yet effective and efficient method that fine-tunes only the linear layers, improving weight disentanglement and multi-task performance by up to 2.38% over the state-of-the-art methods and 8.37% over the non-linear baseline on several vision-language benchmarks.

  • We demonstrate that fine-tuning all linear layers within the attention module, which then operates in a linear regime, without selectively freezing or omitting certain layers, achieves performance superior to current state-of-the-art methods on this benchmark.

  • We reformulate the architecture of task arithmetic in [17] into a representation model and several task-specific models, aligning our notation with previous work. This allows us to more clearly illustrate their individual contributions.

  • We illustrate that the representation model plays an important role in improving weight disentanglement, whereas the effectiveness of task arithmetic has been constrained by the task-specific models, such as classification heads.

Overall, our work provides new insights into the fundamental mechanisms of task arithmetic, enhancing the reliability and scalability of model editing. Our findings indicate that fine-tuning pre-trained models, particularly focusing on the orthogonality of task vectors, deserves further investigation due to its significant potential for improving model editing effectiveness. These insights can lead to the development of more efficient and accurate model editing techniques, enabling practitioners to adapt pre-trained models to a wider array of tasks.

2 Task arithmetic is a reflection of weight disentanglement

Existing literature [17] formulated task arithmetic using a single model, which in fact combines a representation model (e.g., CLIP) with several task-specific models (e.g., classification heads). This combination leads to different pre-trained parameters $\theta_0$ for each task due to the use of different task-specific models (e.g., different classification heads for different datasets). The implementation of task arithmetic focuses on the representation model: extracting task vectors, fine-tuning, and swapping in different task-specific models.

To better understand the relationship between the representation model and the task-specific models, we separate their definitions. This allows us to more clearly illustrate their individual contributions to task arithmetic. We formulate the task arithmetic property along with weight disentanglement, providing distinct definitions for the representation and task-specific models while aligning our notation with previous work.

Let $F:\mathcal{X}\times\Theta\rightarrow\mathcal{Y}$ be a neural network taking inputs $x\in\mathcal{X}$ and parameterized by a set of weights $\vartheta\in\Theta$, which consists of a representation model $f(\cdot;\theta)$ and a task-specific model $g(\cdot;\omega)$, where $\vartheta=\{\theta,\omega\}$. We assume $\mathcal{X}\subseteq\mathbb{R}^{d}$, $\Theta\subseteq\mathbb{R}^{m}$ and $\mathcal{Y}\subseteq\mathbb{R}^{c}$. Consider $T$ tasks, with every task $t$ consisting of a triplet $(D_t,\mu_t,F_t^{*})$, where $D_t\subseteq\mathcal{X}$ is a data support (e.g., ImageNet [22] images), $\mu_t$ an input distribution such that $\mathrm{supp}(\mu_t)=D_t$, and $F_t^{*}:D_t\rightarrow\mathcal{Y}$ a target function (e.g., labels). In practice, each task is identified with a training set $\{(x_v,F_t^{*}(x_v))\}_{v\in[n]}$, where $F_t^{*}(x_v)=g(f(x_v;\theta_t^{*});\omega_t)$ with $x_v\sim\mu_t$, that is used to fine-tune the representation model starting from the pre-trained weights $\theta_0$ to obtain the fine-tuned weights $\theta_t^{*}$, while the task-specific model is fixed at $\omega_t$.

Task arithmetic. Let the task vector of task $t$ be the difference between the fine-tuned and the pre-trained weights, i.e., $\tau_t=\theta_t^{*}-\theta_0$. The following property formalizes the notion of task arithmetic introduced in Ortiz-Jimenez et al. [17], where the authors observed that the accuracies of pre-trained models on different datasets can be modified independently through the addition or removal of task vectors.

Property 1 (Task arithmetic)

Consider a set of task vectors $T=\{\tau_t\}_{t\in[T]}$ with associated non-intersecting task supports $D=\{D_t\subset\mathcal{X}\}_{t\in[T]}$, i.e., $\forall t,t'$, if $t\neq t'$ then $D_t\cap D_{t'}=\emptyset$. We say a network $F$ satisfies the task arithmetic property around $\vartheta_0$ with respect to $T$ and $D$ if

$$F\left(x;\vartheta_0+\sum_{t=1}^{T}\alpha_t\tau_t\right)=\begin{cases}F(x;\vartheta_0+\alpha_t\tau_t)&\text{if }x\in D_t,\\F(x;\vartheta_0)&\text{if }x\notin\bigcup_{t=1}^{T}D_t,\end{cases}\qquad(1)$$

where $\vartheta_0=\{\theta_0,\omega_t\}$ and $\vartheta_0+\alpha_t\tau_t=\{\theta_0+\alpha_t\tau_t,\omega_t\}$ for a given task $t$, with $(\alpha_1,\ldots,\alpha_T)\in\mathcal{A}\subseteq\mathbb{R}^{T}$.

In short, a model satisfies Property 1 if adding $\tau_t$ does not modify the output of the model outside $D_t$.

Property 2 (Weight disentanglement)

A parametric function $F:\mathcal{X}\times\Theta\rightarrow\mathcal{Y}$ is weight disentangled with respect to a set of task vectors $T=\{\tau_t\}_{t\in[T]}$ and the corresponding supports $D=\{D_t\}_{t\in[T]}$ if

$$F\left(x;\vartheta_0+\sum_{t=1}^{T}\alpha_t\tau_t\right)=F(x;\vartheta_0+\alpha_t\tau_t)\,\mathds{1}(x\in D_t)+F(x;\vartheta_0)\,\mathds{1}\!\left(x\notin\bigcup_{t\in[T]}D_t\right).$$

From Property 1 and Property 2, we can see that weight disentanglement is not necessarily linked to performance on the various tasks. In other words, a model can be weight disentangled with respect to a group of task vectors but still underperform on a given task. For example, if $f(\cdot;\theta_0+\alpha\tau)$ does not generalize well for some $\alpha$, the model may still fail to perform effectively despite being weight disentangled [17].

3 Fine-tuning linear layers only is much more efficient than NTK linearization

The objective of this work is to illustrate the relationship between the linear regime of fine-tuning ViT models (in particular, their attention modules) and task arithmetic performance. Existing work [17] demonstrated that the linear regime is not necessary for task arithmetic, but models within the linear regime exhibit superior disentanglement performance.

We have seen that linearized models are more weight-disentangled than non-linear ones [17]. However, NTK linearization often degrades single-task performance, a phenomenon referred to as the non-linear advantage. In this study, we show that fine-tuning only the linear layers significantly improves task arithmetic performance by reducing the single-task accuracy gap. This approach maintains the benefits of weight disentanglement while enhancing the overall effectiveness of task arithmetic across various settings.

Neural tangent kernel. Around the initialization weights $\theta_0$, a representation model $f(x;\theta)$ can be approximated with a first-order Taylor expansion:

$$f(x;\theta)=f(x;\theta_0)+(\theta-\theta_0)^{\top}\nabla_{\theta}f(x;\theta_0)+\text{higher-order terms}.\qquad(2)$$

This approximation corresponds to the neural tangent kernel (NTK) [19], $k_{\text{NTK}}(x,x')=\nabla_{\theta}f(x;\theta_0)^{\top}\nabla_{\theta}f(x';\theta_0)$, and defines a neural tangent space in which the relationship between weights and functions is linear. However, this linear approximation is often invalid at finite widths, as the evolution of parameters during training is inadequately captured by Eq. (2). In such cases, training occurs in a non-linear regime. Conversely, during fine-tuning, parameter evolution in many pre-trained models is often minimal, meaning that training does not leave the tangent space and Eq. (2) closely approximates the network behavior [17]. In such cases, training occurs in a linear regime.
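For reference, a sketch of the NTK-linearized forward pass in Eq. (2), computed with a Jacobian-vector product so the Jacobian is never materialized (this is an illustrative use of `torch.func`, not the exact implementation of [17]):

```python
import torch
from torch.func import functional_call, jvp

def linearized_forward(model, theta_0: dict, theta: dict, x: torch.Tensor) -> torch.Tensor:
    """f_lin(x; theta) = f(x; theta_0) + (theta - theta_0)^T grad_theta f(x; theta_0)."""
    delta = {k: theta[k] - theta_0[k] for k in theta_0}        # theta - theta_0
    f0, df = jvp(
        lambda params: functional_call(model, params, (x,)),   # f(x; .)
        (theta_0,),                                             # expansion point theta_0
        (delta,),                                               # direction theta - theta_0
    )
    return f0 + df
```

Because training in this paradigm back-propagates through the JVP, each step is roughly two to three times as expensive as a standard fine-tuning step, which is the overhead discussed above.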

Specifically, rather than constraining models to fine-tune within the tangent space as proposed by [17], we enhance the linear regime within the attention module, which naturally arises from the linear architecture. This enhancement leverages the inherent properties of linear layers to simplify the modelโ€™s adaptation process. By fine-tuning only a select number of linear layers, the model not only lies in a pronounced linear regime but also maintains computational efficiency. This approach achieves performance levels comparable to full fine-tuning while significantly reducing the complexity and resources required.

Figure 2: Illustration of the three fine-tuning paradigms: (a) full fine-tuning, (b) full-model (NTK) linearization, and (c) our linear-layers-only fine-tuning.

Figure 2 illustrates three different fine-tuning paradigms. The first two are existing methods, while the third is our proposed linear-layers-only fine-tuning method. (a) Full fine-tuning, where all parameters $\theta$ are updated during fine-tuning. (b) Full-model linearization, where the model is fine-tuned in the tangent space. It is worth noting that although the Jacobian-vector products can be computed in a single forward pass [23], training and inference in this paradigm are usually two to three times as expensive as full fine-tuning [18]. (c) Linear-layers-only fine-tuning, in which only a small number of linear layers are updated; this paradigm exhibits a linear regime similar to linearized fine-tuning while incurring only a fraction of the training and inference costs of NTK linearization.

Figure 3: Single-task accuracy of our method and of NTK linearization compared against non-linear fine-tuning.

The results in Figure 3 indicate that our method reduces the gap between linear and non-linear models and performs better than NTK linearization (see Appendix B.1 for the performance on each task). Specifically, we fine-tune (FT) several CLIP pre-trained ViTs [24] of different sizes following the same setup as Ilharco et al. [1] on 8 tasks: Cars [25], DTD [26], SUN397 [27], EuroSAT [28], GTSRB [29], MNIST [30], SVHN [31] and RESISC45 [32]. The round dots compare our method with non-linear models, while the triangular dots compare NTK linearization with non-linear models. The proximity of the dots to the diagonal dashed line indicates the accuracy of the linearization method. Our method (round dots) consistently appears closer to the diagonal dashed line than NTK linearization (triangular dots), indicating superior performance.

Even though NTK linearization models exhibit a linear regime and achieve impressive disentanglement performance, they omit higher-order terms, leading to accumulated errors as the number of tasks increases. To address this, we propose leveraging linear layers to enhance disentanglement performance. Our method retains the benefits of NTK linearization while mitigating the drawbacks associated with neglected higher-order terms. By fine-tuning the linear layers within the attention module, we improve computational efficiency and robustness across multiple tasks.

Considering a completely linear representation model $f$, we can easily determine that $\nabla_{\theta}f(x;\theta_0)=x$. This foundational insight supports our approach, as the gradient with respect to the parameters simplifies to the input data itself. Thus, by fine-tuning only the linear layers, we enhance weight disentanglement and maintain model efficiency without extensive retraining, aligning with the concept of weight disentanglement to improve performance across various tasks.

Proposition 1 (Orthogonality between Vectors and Data Points)

For a linear representation model $f$, the gradient with respect to the parameters is given by $\nabla_{\theta}f(x;\theta_0)=x$. Therefore, the model can be expressed as:

$$f(x;\theta)=f(x;\theta_0)+(\theta-\theta_0)^{\top}x.\qquad(3)$$

For a specific task $t=i$, the function $f$ can be further expressed as:

$$f\left(x;\theta_0+\sum_{t=1}^{T}\tau_t\right)=f(x;\theta_0+\tau_i)+\sum_{t\neq i}\tau_t^{\top}x.$$

When $\sum_{t\neq i}\tau_t^{\top}x\to 0$, weight disentanglement is achieved.

Proposition 1 implies that orthogonality between the data points $x$ and the task vectors $\tau_t$ leads to perfect disentanglement. This allows the cosine similarity between data points and task vectors to be evaluated, providing a clear advantage over NTK linearization, where such perfect disentanglement is not guaranteed.
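The following toy check illustrates Proposition 1 numerically (the dimensions and random task vectors are synthetic, purely for illustration): for a linear map $f(x;\theta)=\theta^{\top}x$, the merged model's output on task $i$ differs from the single-task model exactly by the interference term $\sum_{t\neq i}\tau_t^{\top}x$, which vanishes when $x$ is orthogonal to the other task vectors.

```python
import torch

d, c, T, i = 512, 16, 4, 0                  # input dim, output dim, number of tasks, task index
theta_0 = torch.randn(d, c)
taus = [0.01 * torch.randn(d, c) for _ in range(T)]   # synthetic task vectors
x = torch.randn(d)                                     # a sample from the support of task i

merged = theta_0 + sum(taus)                # theta_0 + sum_t tau_t
single = theta_0 + taus[i]                  # theta_0 + tau_i
interference = sum(tau.T @ x for t, tau in enumerate(taus) if t != i)

# f(x; merged) - f(x; single) equals the interference term (up to floating-point error),
# so orthogonality between x and the other task vectors gives perfect disentanglement.
assert torch.allclose(merged.T @ x - single.T @ x, interference, atol=1e-4)
```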

Inspired by LoRA [21], which can be regarded as a special form of linearization, we choose the $Q$, $K$, $V$, and output projection matrices. Concentrating on attention modules, consider a single attention layer as in Vaswani et al. [33]; owing to the linearity of the weight matrices, we can obtain a disentangled output from each attention head for a certain task $i$:

$$\begin{aligned}\text{head}&=\text{Attention}(qW^{Q},\,kW^{K},\,vW^{V})\\&=\text{Attention}\Big(Q_i+\sum_{t\neq i}qW^{Q}_{t},\;K_i+\sum_{t\neq i}kW^{K}_{t},\;V_i+\sum_{t\neq i}vW^{V}_{t}\Big)\\&=\text{Attention}(Q_i,K_i,V_i),\end{aligned}$$

where $q,k,v$ denote the query, key, and value inputs, and $W^{Q},W^{K},W^{V}$ the corresponding weight matrices, e.g., $Q_i=q(W^{Q}_{0}+W^{Q}_{i})$ and $W^{Q}=W^{Q}_{0}+\sum_{t}W^{Q}_{t}$.

Table 1: Multi-task performance via task addition. Absolute (Abs.) and normalized (Norm.) accuracies (%) for each fine-tuning paradigm.

Method        ViT-B-32            ViT-B-16            ViT-L-14
              Abs.(↑)  Norm.(↑)   Abs.(↑)  Norm.(↑)   Abs.(↑)  Norm.(↑)
Pre-trained   48.40    -          55.25    -          66.40    -
Non-linear    70.00    77.04      74.75    80.59      84.40    89.47
Linear        76.26    85.82      79.01    86.32      85.53    91.44
Ours          78.37    87.42      80.44    87.25      87.91    93.66

Remarkably, as we show in Section 4, this increase in single-task performance does not compromise weight disentanglement; instead, disentanglement is much higher than with NTK linearization. As a result, linear-layer fine-tuning allows for improved task arithmetic compared to standard non-linear fine-tuning. To better illustrate the multi-task performance, we employ the benchmark proposed by Ilharco et al. [1] to evaluate the task arithmetic ability of a pre-trained model, which consists of the 8 tasks described before: the sum of the task vectors $\tau=\sum_{t}\tau_t$ is added to a pre-trained checkpoint to produce a multi-task model. The success of this benchmark is measured in terms of the maximum average accuracy over the different tasks. Results are shown in Table 1. In contrast to recent work [16, 18], our method focuses on the original task arithmetic setting, evaluating multi-task performance and further examining weight disentanglement; it neither requires access to all test data nor operates in a parameter-efficient setting.

To obtain the task vectors, we use the fine-tuned weights of the different ViTs from before and use a single mixing coefficient $\alpha=\alpha_1=\dots=\alpha_T$ optimized separately for each fine-tuning paradigm to ensure a fair comparison. We provide all the details of this experiment in Appendix A.
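A sketch of the corresponding evaluation loop (the `evaluate` callable, the task list, and the grid of mixing coefficients are placeholders rather than the exact benchmark code):

```python
import numpy as np

def task_addition_benchmark(theta_0, taus, tasks, evaluate, alphas=np.arange(0.0, 1.01, 0.05)):
    """Add the summed task vector with a single coefficient alpha and report the best
    average accuracy over all tasks, following the benchmark of Ilharco et al. [1]."""
    best = {"alpha": None, "avg_acc": -1.0}
    tau_sum = {k: sum(tau[k] for tau in taus) for k in theta_0}      # sum_t tau_t
    for alpha in alphas:
        merged = {k: theta_0[k] + float(alpha) * tau_sum[k] for k in theta_0}
        accs = [evaluate(merged, task) for task in tasks]            # per-task accuracy on each test set
        if np.mean(accs) > best["avg_acc"]:
            best = {"alpha": float(alpha), "avg_acc": float(np.mean(accs))}
    return best
```

Normalized accuracy, as reported in Table 1, divides each per-task accuracy by the accuracy of the corresponding single-task fine-tuned model before averaging, so that tasks of different difficulty contribute comparably.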

In particular, the last row of Table 1 shows that our method significantly outperforms its non-linear counterpart [1] and achieves state-of-the-art results on the task addition benchmark. Our model achieves higher multi-task accuracies through task addition (up to 2.38% more). Additionally, our method outperforms not only on absolute average accuracy but also on normalized accuracy.

In addition to the input and output projection layers, can we improve disentanglement performance by fine-tuning all linear layers in the attention module, and how does performance change when the bias parameters are also fine-tuned? To dig into this question, we conduct an ablation experiment with four paradigms: (1) fine-tuning only the $Q$, $K$, $V$, and output projection weights (ours), (2) fine-tuning the $Q$, $K$, $V$, and output projection weights and biases, (3) fine-tuning the weights of all linear layers in the attention module, and (4) fine-tuning the weights and biases of all linear layers in the attention module.

Table 2: Ablation over the four fine-tuning paradigms. Single-task accuracy and multi-task absolute (Abs.) and normalized (Norm.) accuracies (%).

Paradigm   Single-task          Multi-task
           Accuracy(↑)          Abs.(↑)   Norm.(↑)
(1)        89.55                78.37     87.42
(2)        89.48                77.71     86.79
(3)        88.95                76.52     86.11
(4)        89.43                77.80     86.93

Surprisingly, all four paradigms outperform NTK linearization in both performance and disentanglement, demonstrating that ViT models exhibit a strong linear regime within the linear layers of the attention module (see Table 2). However, the performance can either improve or degrade depending on whether the bias parameters are fine-tuned. The best results are achieved when the settings align closely with those used in LoRA, suggesting that this is an area deserving further exploration.

4 Weight disentanglement emerges from the representation model

Existing works demonstrate the effectiveness of task arithmetic primarily in the context of downstream tasks. However, the reliance on downstream tasks and task-specific models may limit the generalizability and scalability of these methods. To further illustrate the power of task arithmetic and broaden its applicability, we aim to investigate whether the representation model itself can satisfy Property 3 without the need for a task-specific model.

Unlike the approach outlined in Section 2, we redefine task arithmetic and weight disentanglement within the representation model itself. Our hypothesis is that pre-trained models can exhibit task arithmetic properties independently of downstream tasks, maintaining a similar representation through task arithmetic. By focusing on the representation model alone, we seek to show that the inherent properties of the pre-trained models are sufficient to support task arithmetic, potentially simplifying the process and broadening its applicability.

Figure 4: Disentanglement error of a CLIP ViT-B/32 model for several task vector pairs.
Property 3 (Task arithmetic*)

Consider a set of task vectors $T=\{\tau_t\}_{t\in[T]}$ with associated non-intersecting task supports $D=\{D_t\subset\mathcal{X}\}_{t\in[T]}$, i.e., $\forall t,t'$, if $t\neq t'$ then $D_t\cap D_{t'}=\emptyset$. We say a representation model $f$ satisfies the task arithmetic property around $\theta_0$ with respect to $T$ and $D$ if

$$f\left(x;\theta_0+\sum_{t=1}^{T}\alpha_t\tau_t\right)=\begin{cases}f(x;\theta_0+\alpha_t\tau_t)&\text{if }x\in D_t;\\f(x;\theta_0)&\text{if }x\notin\bigcup_{t=1}^{T}D_t.\end{cases}$$
Property 4 (Weight disentanglement*)

A parametric function $f:\mathcal{X}\times\Theta\rightarrow\mathcal{Y}$ is weight disentangled with respect to a set of task vectors $T=\{\tau_t\}_{t\in[T]}$ and the corresponding supports $D=\{D_t\}_{t\in[T]}$ if

$$f\left(x;\theta_0+\sum_{t=1}^{T}\alpha_t\tau_t\right)=\sum_{t=1}^{T}f(x;\theta_0+\alpha_t\tau_t)\,\mathds{1}(x\in D_t)+f(x;\theta_0)\,\mathds{1}\!\left(x\notin\bigcup_{t\in[T]}D_t\right).\qquad(4)$$

Here, weight disentanglement is evaluated on the representation model rather than on the predictors: it is a property of the representation model and is not tied to performance on the different tasks. That is, a model could be weight disentangled with respect to a set of task vectors and still perform poorly on a task, e.g., if $f(\cdot;\theta_0+\alpha\tau)$ does not generalize for some $\alpha$. More generally, we can visualize the level of weight disentanglement of a model by measuring its discrepancy with Eq. (4). To do so, given two tasks, one can check the disentanglement error of a model,

ฮพโข(ฮฑ1,ฮฑ2)=โˆ‘t=12๐”ผxโˆผฮผtโข[distโข(fโข(x;ฮธ0+ฮฑtโขฯ„t),fโข(x;ฮธ0+ฮฑ1โขฯ„1+ฮฑ2โขฯ„2))],๐œ‰subscript๐›ผ1subscript๐›ผ2superscriptsubscript๐‘ก12subscript๐”ผsimilar-to๐‘ฅsubscript๐œ‡๐‘กdelimited-[]dist๐‘“๐‘ฅsubscript๐œƒ0subscript๐›ผ๐‘กsubscript๐œ๐‘ก๐‘“๐‘ฅsubscript๐œƒ0subscript๐›ผ1subscript๐œ1subscript๐›ผ2subscript๐œ2\displaystyle\xi(\alpha_{1},\alpha_{2})=\sum_{t=1}^{2}\mathbb{E}_{x\sim\mu_{t}%}\left[\text{dist}\left(f(x;\theta_{0}+\alpha_{t}\tau_{t}),f(x;\theta_{0}+%\alpha_{1}\tau_{1}+\alpha_{2}\tau_{2})\right)\right],italic_ฮพ ( italic_ฮฑ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , italic_ฮฑ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ) = โˆ‘ start_POSTSUBSCRIPT italic_t = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT blackboard_E start_POSTSUBSCRIPT italic_x โˆผ italic_ฮผ start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT end_POSTSUBSCRIPT [ dist ( italic_f ( italic_x ; italic_ฮธ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT + italic_ฮฑ start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT italic_ฯ„ start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) , italic_f ( italic_x ; italic_ฮธ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT + italic_ฮฑ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT italic_ฯ„ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT + italic_ฮฑ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT italic_ฯ„ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ) ) ] ,

where dist denotes any distance metric between output vectors. As we are dealing with representation distributions, in what follows we use the KL divergence as the distance metric (for classification tasks, we use the prediction error of the combined model, as Ortiz-Jimenez et al. did [17]). In general, the smaller the value of $\xi(\alpha_1,\alpha_2)$, the more weight disentangled a model is at $(\alpha_1,\alpha_2)$.
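A sketch of how this error can be estimated on a grid of $(\alpha_1,\alpha_2)$ values (the `forward` callable, which evaluates the representation model under a given set of weights, and the data loaders are placeholders):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def disentanglement_error(forward, theta_0, tau_1, tau_2, alpha_1, alpha_2, loaders):
    """xi(alpha_1, alpha_2): summed KL divergence between the single-task models
    f(.; theta_0 + alpha_t tau_t) and the combined model f(.; theta_0 + alpha_1 tau_1 + alpha_2 tau_2)."""
    combined = {k: theta_0[k] + alpha_1 * tau_1[k] + alpha_2 * tau_2[k] for k in theta_0}
    xi = 0.0
    for tau, alpha, loader in zip((tau_1, tau_2), (alpha_1, alpha_2), loaders):
        single = {k: theta_0[k] + alpha * tau[k] for k in theta_0}
        for x in loader:
            p = F.log_softmax(forward(single, x), dim=-1)     # single-task output distribution
            q = F.log_softmax(forward(combined, x), dim=-1)   # combined-model output distribution
            xi += F.kl_div(q, p, reduction="batchmean", log_target=True).item()
    return xi
```

Sweeping $(\alpha_1,\alpha_2)$ over a grid and plotting $\xi$ produces heatmaps like those in Figure 4.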

Figure 4 displays the disentanglement error of a CLIP ViT-B/32 model for several task vector pairs. We observe that the ViT model exhibits a minimal disentanglement error within a small region surrounding $\theta_0$, which enables task arithmetic. Unlike the disentanglement error on downstream tasks, it remains relatively small even for $\alpha_1,\alpha_2>1$, which indicates that the power of task arithmetic has been limited by the performance of the task-specific models (classification heads).

Comparing the disentanglement error of our models and NTK linearization reveals an interesting finding: linearized models exhibit greater disentanglement than their non-linear counterparts. This is evident from the more extensive regions with low disentanglement error in Figure 4 (right). This explains why the linear-layers-only fine-tuned models achieve higher normalized accuracies via task addition (cf. Table 1). The combination of greater disentanglement and better single-task performance yields higher multi-task performance.

The results in Figure 5 show the cosine similarity between task vectors obtained from the three types of fine-tuning of the ViT (cf. Figure 2) on the image classification tasks. We observe that task vectors from linear-layers-only fine-tuning are closer to orthogonal than those from both standard fine-tuning and NTK linearization, which indicates that the task vectors produced by our method are more independent of one another. This finding is consistent with the discussion about task addition in [1, 18], and the experimental results in Table 1 also support our statement. The experimental details are described in the Appendix.
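The measurement behind Figure 5 can be sketched as follows (an assumed implementation: each task vector is flattened into a single vector before computing pairwise cosine similarities):

```python
import torch
import torch.nn.functional as F

def flatten_task_vector(tau: dict) -> torch.Tensor:
    # Concatenate all parameter differences into one long vector.
    return torch.cat([v.reshape(-1) for v in tau.values()])

def pairwise_cosine_similarity(taus: list) -> torch.Tensor:
    flat = torch.stack([flatten_task_vector(tau) for tau in taus])   # (T, P)
    flat = F.normalize(flat, dim=1)                                   # unit-norm rows
    return flat @ flat.T                                              # (T, T) cosine matrix
```

Off-diagonal values close to zero indicate nearly orthogonal task vectors, which is what we observe for linear-layers-only fine-tuning.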

Figure 5: Cosine similarity between task vectors under the three fine-tuning paradigms.

5 Related work

Weight Interpolation and Task Arithmetic. Recent research has explored the use of interpolations between model weights and task arithmetic to manipulate and enhance the capabilities of pre-trained models. Studies have shown that interpolating between a model's fine-tuned weights and its pre-trained initialization can improve performance on single tasks, often surpassing the accuracies achieved through fine-tuning alone [34, 35, 36, 37, 38, 39]. In multi-task settings, averaging the parameters of multiple fine-tuned models has been proposed to create superior multi-task models [40, 15, 41, 2, 1], which help avoid catastrophic forgetting [13, 14, 42, 43] and provide better starting points for subsequent fine-tuning [44, 45]. These benefits also extend to models trained from scratch, provided they are properly aligned before merging [46, 47]. This technique has been observed to enhance downstream performance, highlighting the potential of weight interpolation and task arithmetic.

Model Fusion Techniques. Model fusion integrates knowledge from multiple models into a single unified model. One approach focuses on fusing entire models through weight interpolation. By averaging or combining the weights of multiple models, effective model fusion can be achieved, as demonstrated in prior works [48, 2, 1, 36, 40, 49]. When models are not well-aligned or lie in different loss basins, feature alignment techniques are employed before fusion to match the models' behaviors or transform them to a similar loss basin [46, 50]. Although fusing entire models leverages knowledge from all layers, it can be computationally expensive.

Advances in Model Merging for MTL. Model merging has emerged as a promising solution to enhance model generalization and facilitate multi-task learning (MTL). Research in this area includes merging models trained on the same task to improve overall generalization [51, 52] or to perform federated learning [53, 54]. Another approach focuses on merging models for different tasks to enable MTL [36, 46, 55, 1, 15, 56, 57]. However, simple model averaging can significantly deteriorate performance across multiple tasks. To address this, advanced techniques have been developed. For example, Fisher Merging uses the Fisher information matrix to measure the importance of individual model parameters and guide model merging [36]. Although effective, computing the Fisher information matrix is computationally and memory-intensive.

6 Conclusion

In this work, we conducted a thorough analysis of task arithmetic in deep neural networks, delving into its fundamental mechanisms and enhancing its performance. Our findings demonstrate that the attention module lies in a strong linear regime within its linear layers, which improves both disentanglement and efficiency.

Understanding the nuanced impact of fine-tuning bias on model performance and disentanglement remains an open question. Future research could provide valuable insights into optimizing these settings, potentially leading to more robust and efficient methods for adapting pre-trained models to various tasks. This could significantly enhance their applicability and effectiveness in real-world scenarios.

References

  • [1] Gabriel Ilharco, Marco Túlio Ribeiro, Mitchell Wortsman, Suchin Gururangan, Ludwig Schmidt, Hannaneh Hajishirzi, and Ali Farhadi. Editing models with task arithmetic. In International Conference on Learning Representations (ICLR), 2023.
  • [2] Gabriel Ilharco, Mitchell Wortsman, Samir Yitzhak Gadre, Shuran Song, Hannaneh Hajishirzi, Simon Kornblith, Ali Farhadi, and Ludwig Schmidt. Patching open-vocabulary models by interpolating weights. In Advances in Neural Information Processing Systems (NeurIPS), 2022.
  • [3] Fuzhen Zhuang, Zhiyuan Qi, Keyu Duan, Dongbo Xi, Yongchun Zhu, Hengshu Zhu, Hui Xiong, and Qing He. A comprehensive survey on transfer learning. Proceedings of the IEEE, 2020.
  • [4] Ximing Lu, Sean Welleck, Liwei Jiang, Jack Hessel, Lianhui Qin, Peter West, Prithviraj Ammanabrolu, and Yejin Choi. Quark: Controllable text generation with reinforced unlearning. In Advances in Neural Information Processing Systems (NeurIPS), 2022.
  • [5] Long Ouyang, Jeff Wu, Xu Jiang, Diogo Almeida, Carroll L. Wainwright, Pamela Mishkin, Chong Zhang, Sandhini Agarwal, Katarina Slama, Alex Ray, et al. Training language models to follow instructions with human feedback, 2022.
  • [6] Marco Tulio Ribeiro and Scott Lundberg. Adaptive testing and debugging of NLP models. In Annual Meeting of the Association for Computational Linguistics (ACL), 2022.
  • [7] Amelia Glaese, Nat McAleese, Maja Trebacz, John Aslanides, Vlad Firoiu, Timo Ewalds, Maribeth Rauh, Laura Weidinger, Martin Chadwick, Phoebe Thacker, Lucy Campbell-Gillingham, Jonathan Uesato, Po-Sen Huang, Ramona Comanescu, Fan Yang, Abigail See, Sumanth Dathathri, Rory Greig, Charlie Chen, Doug Fritz, Jaume Sanchez Elias, Richard Green, Sona Mokra, Nicholas Fernando, Boxi Wu, Rachel Foley, Susannah Young, Iason Gabriel, William Isaac, John Mellor, Demis Hassabis, Koray Kavukcuoglu, Lisa Anne Hendricks, and Geoffrey Irving. Improving alignment of dialogue agents via targeted human judgements. https://www.deepmind.com/blog/building-safer-dialogue-agents, 2022.
  • [8] Jiancong Xiao, Ziniu Li, Xingyu Xie, Emily Getzen, Cong Fang, Qi Long, and Weijie J. Su. On the algorithmic bias of aligning large language models with RLHF: Preference collapse and matching regularization. arXiv preprint arXiv:2405.16455, 2024.
  • [9] Bo-Jian Hou, Lijun Zhang, and Zhi-Hua Zhou. Learning with feature evolvable streams. Advances in Neural Information Processing Systems, 30, 2017.
  • [10] Guillermo Ortiz-Jiménez, Apostolos Modas, Seyed-Mohsen Moosavi-Dezfooli, and Pascal Frossard. Optimism in the face of adversity: Understanding and improving deep learning through adversarial robustness. Proceedings of the IEEE, 2021.
  • [11] Shibani Santurkar, Dimitris Tsipras, Mahalaxmi Elango, David Bau, Antonio Torralba, and Aleksander Madry. Editing a classifier by rewriting its prediction rules. In Advances in Neural Information Processing Systems (NeurIPS), 2021.
  • [12] Matthew Tancik, Pratul P. Srinivasan, Ben Mildenhall, Sara Fridovich-Keil, Nithin Raghavan, Utkarsh Singhal, Ravi Ramamoorthi, Jonathan T. Barron, and Ren Ng. Fourier features let networks learn high frequency functions in low dimensional domains. In Advances in Neural Information Processing Systems (NeurIPS), 2020.
  • [13] Robert M. French. Catastrophic forgetting in connectionist networks. Trends in Cognitive Sciences, 1999.
  • [14] Michael McCloskey and Neal J. Cohen. Catastrophic interference in connectionist networks: The sequential learning problem. In Psychology of Learning and Motivation. Elsevier, 1989.
  • [15] Prateek Yadav, Derek Tam, Leshem Choshen, Colin Raffel, and Mohit Bansal. Resolving interference when merging models. In Advances in Neural Information Processing Systems (NeurIPS), 2023.
  • [16] E. Yang, Z. Wang, L. Shen, S. Liu, G. Guo, X. Wang, and D. Tao. AdaMerging: Adaptive model merging for multi-task learning. In The Twelfth International Conference on Learning Representations, October 2023.
  • [17] G. Ortiz-Jimenez, A. Favero, and P. Frossard. Task arithmetic in the tangent space: Improved editing of pre-trained models. In Advances in Neural Information Processing Systems, volume 36, 2024.
  • [18] A. Tang, L. Shen, Y. Luo, Y. Zhan, H. Hu, B. Du, and D. Tao. Parameter-efficient multi-task model fusion with partial linearization. In The Twelfth International Conference on Learning Representations, October 2023.
  • [19] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems (NeurIPS), 2018.
  • [20] Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, and Ilya Sutskever. Learning transferable visual models from natural language supervision. In International Conference on Machine Learning (ICML), 2021.
  • [21] Edward J. Hu, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. LoRA: Low-rank adaptation of large language models. In International Conference on Learning Representations, October 2021.
  • [22] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. ImageNet: A large-scale hierarchical image database. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2009.
  • [23] Barak A. Pearlmutter. Fast exact multiplication by the Hessian. Neural Computation, 6(1):147–160, January 1994.
  • [24] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations (ICLR), 2021.
  • [25] Jonathan Krause, Michael Stark, Jia Deng, and Li Fei-Fei. 3D object representations for fine-grained categorization. In International Conference on Computer Vision Workshops (ICCVW), 2013.
  • [26] Mircea Cimpoi, Subhransu Maji, Iasonas Kokkinos, Sammy Mohamed, and Andrea Vedaldi. Describing textures in the wild. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2014.
  • [27] Jianxiong Xiao, Krista A. Ehinger, James Hays, Antonio Torralba, and Aude Oliva. SUN database: Exploring a large collection of scene categories. International Journal of Computer Vision (IJCV), 2016.
  • [28] Patrick Helber, Benjamin Bischke, Andreas Dengel, and Damian Borth. EuroSAT: A novel dataset and deep learning benchmark for land use and land cover classification. Journal of Selected Topics in Applied Earth Observations and Remote Sensing, 2019.
  • [29] Johannes Stallkamp, Marc Schlipsing, Jan Salmen, and Christian Igel. The German traffic sign recognition benchmark: A multi-class classification competition. In International Joint Conference on Neural Networks (IJCNN), 2011.
  • [30] Yann LeCun. The MNIST database of handwritten digits, 1998.
  • [31] Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, and Andrew Y. Ng. Reading digits in natural images with unsupervised feature learning. In Advances in Neural Information Processing Systems (NeurIPS) Workshops, 2011.
  • [32] Gong Cheng, Junwei Han, and Xiaoqiang Lu. Remote sensing image scene classification: Benchmark and state of the art. Proceedings of the IEEE, 2017.
  • [33] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in Neural Information Processing Systems, 30, 2017.
  • [34] Jonathan Frankle, Gintare Karolina Dziugaite, Daniel Roy, and Michael Carbin. Linear mode connectivity and the lottery ticket hypothesis. In International Conference on Machine Learning (ICML), 2020.
  • [35] Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Averaging weights leads to wider optima and better generalization. In Conference on Uncertainty in Artificial Intelligence (UAI), 2018.
  • [36] Michael Matena and Colin Raffel. Merging models with Fisher-weighted averaging. In Advances in Neural Information Processing Systems (NeurIPS), 2021.
  • [37] Alexandre Ramé, Kartik Ahuja, Jianyu Zhang, Matthieu Cord, Léon Bottou, and David Lopez-Paz. Model ratatouille: Recycling diverse models for out-of-distribution generalization. In International Conference on Machine Learning (ICML), 2022.
  • [38] Alexandre Ramé, Matthieu Kirchmeyer, Thibaud Rahier, Alain Rakotomamonjy, Patrick Gallinari, and Matthieu Cord. Diverse weight averaging for out-of-distribution generalization. In Advances in Neural Information Processing Systems (NeurIPS), 2022.
  • [39] Mitchell Wortsman, Gabriel Ilharco, Mike Li, Jong Wook Kim, Hannaneh Hajishirzi, Ali Farhadi, Hongseok Namkoong, and Ludwig Schmidt. Robust fine-tuning of zero-shot models. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2022.
  • [40] Mitchell Wortsman, Gabriel Ilharco, Samir Yitzhak Gadre, Rebecca Roelofs, Raphael Gontijo-Lopes, Ari S. Morcos, Hongseok Namkoong, Ali Farhadi, Yair Carmon, Simon Kornblith, et al. Model soups: Averaging weights of multiple fine-tuned models improves accuracy without increasing inference time. In International Conference on Machine Learning (ICML), 2022.
  • [41] Margaret Li, Suchin Gururangan, Tim Dettmers, Mike Lewis, Tim Althoff, Noah A. Smith, and Luke Zettlemoyer. Branch-Train-Merge: Embarrassingly parallel training of expert language models, 2022.
  • [42] Bo-Jian Hou, Yu-Hu Yan, Peng Zhao, and Zhi-Hua Zhou. Storage fit learning with feature evolvable streams. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, pages 7729–7736, 2021.
  • [43] Bo-Jian Hou, Lijun Zhang, and Zhi-Hua Zhou. Prediction with unpredictable feature evolution. IEEE Transactions on Neural Networks and Learning Systems, 33(10):5706–5715, 2021.
  • [44] Shachar Don-Yehiya, Elad Venezian, Colin Raffel, Noam Slonim, Yoav Katz, and Leshem Choshen. Cold fusion: Collaborative descent for distributed multitask finetuning, 2022.
  • [45] Leshem Choshen, Elad Venezian, Noam Slonim, and Yoav Katz. Fusing finetuned models for better pretraining, 2022.
  • [46] Samuel K. Ainsworth, Jonathan Hayase, and Siddhartha Srinivasa. Git re-basin: Merging models modulo permutation symmetries. In International Conference on Learning Representations (ICLR), 2023.
  • [47] Sidak Pal Singh and Martin Jaggi. Model fusion via optimal transport. In Advances in Neural Information Processing Systems (NeurIPS), 2020.
  • [48] Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov, and Andrew Gordon Wilson. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Conference on Uncertainty in Artificial Intelligence (UAI), 2018.
  • [49] Tian Yu Liu and Stefano Soatto. Tangent model composition for ensembling and continual fine-tuning, July 2023. arXiv:2307.08114 [cs].
  • [50] Xisen Jin, Xiang Ren, Daniel Preotiuc-Pietro, and Pengxiang Cheng. Dataless knowledge fusion by merging weights of language models. In International Conference on Learning Representations (ICLR), 2023.
  • [51] Vipul Gupta, Santiago Akle Serrano, and Dennis DeCoste. Stochastic weight averaging in parallel: Large-batch training that generalizes well. In 8th International Conference on Learning Representations (ICLR). OpenReview.net, 2020.
  • [52] Junbum Cha, Sanghyuk Chun, Kyungjae Lee, Han-Cheol Cho, Seunghyun Park, Yunsung Lee, and Sungrae Park. SWAD: Domain generalization by seeking flat minima. In Advances in Neural Information Processing Systems, volume 34, pages 22405–22418, 2021.
  • [53] Xiang Li, Kaixuan Huang, Wenhao Yang, Shusen Wang, and Zhihua Zhang. On the convergence of FedAvg on non-IID data. In International Conference on Learning Representations (ICLR), 2019.
  • [54] Hongyi Wang, Mikhail Yurochkin, Yuekai Sun, Dimitris Papailiopoulos, and Yasaman Khazaeni. Federated learning with matched averaging. In International Conference on Learning Representations (ICLR), 2020.
  • [55] George Stoica, Daniel Bolya, Jakob Bjorner, Taylor Hearn, and Judy Hoffman. ZipIt! Merging models from different tasks without training. In International Conference on Learning Representations (ICLR), 2023.
  • [56] Weishi Li, Yong Peng, Miao Zhang, Liang Ding, Han Hu, and Li Shen. Deep model fusion: A survey, 2023. arXiv preprint arXiv:2309.15698.
  • [57] Jinghan Zhang, Shiqi Chen, Junteng Liu, and Junxian He. Composing parameter-efficient modules with arithmetic operations. In Advances in Neural Information Processing Systems (NeurIPS), 2023.
  • [58] Mehdi Cherti, Romain Beaumont, Ross Wightman, Mitchell Wortsman, Gabriel Ilharco, Cade Gordon, Christoph Schuhmann, Ludwig Schmidt, and Jenia Jitsev. Reproducible scaling laws for contrastive language-image learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 2818–2829, June 2023.
  • [59] Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. In International Conference on Learning Representations (ICLR), 2019.

Appendix A Experimental details

All our experiments were performed on the same hardware, consisting of four NVIDIA RTX 3090 GPUs with 24 GB of memory each, and can be reproduced in fewer than 150 GPU hours. The details of each experiment are the following.

Datasets.We evaluate task arithmetic on a set of popular benchmark datasets from various domains. The dataset collection includes:

  • โ€ข

    SVHN [31]: The Street View House Numbers dataset is a real-world image dataset for developing machine learning and object recognition algorithms with minimal requirement on data preprocessing and formatting.

  • โ€ข

    MNIST [30]: A database of handwritten digits, with 60,000 training images and 10,000 testing images.

  • โ€ข

    EuroSAT [28]: A dataset based on Sentinel-2 satellite images covering 13 spectral bands, with 10 classes and a total of 27,000 labeled and geo-referenced images.

  • โ€ข

    RESISC45 [32]: The remote sensing image scene classification dataset, consisting of 31,500 images in 45 scene classes.

  • โ€ข

    Cars [25]: This dataset contains images of cars categorized into various fine-grained classes. It is widely used for fine-grained image classification tasks, providing a rich set of vehicle images for training and evaluation.

  • โ€ข

    DTD (Describable Textures Dataset) [26]: This dataset is designed for texture recognition and categorization. It consists of texture images organized into 47 categories, each labeled with attributes describing the texture patterns. It is commonly used to evaluate texture recognition algorithms.

  • โ€ข

    SUN397 [27]: The Scene UNderstanding (SUN) dataset is a large-scale dataset for scene recognition, containing 397 categories with a total of over 100,000 images. It is used to evaluate scene understanding models and to benchmark scene classification algorithms.

  • โ€ข

    GTSRB (German Traffic Sign Recognition Benchmark) [29]: This dataset comprises images of German traffic signs, classified into over 40 categories. It is used to develop and evaluate traffic sign recognition systems, particularly in the context of autonomous driving and intelligent transportation systems.

Fine-tuning. All the fine-tuning experiments follow the same training protocol specified in Ilharco et al. [2], with minor modifications to the training code to use linearized models when needed. In particular, we fine-tune all datasets starting from the same CLIP pre-trained checkpoint downloaded from the open_clip repository [58]. We fine-tune for 2,000 iterations with a batch size of 128, a learning rate of $10^{-5}$, and a cosine annealing learning rate schedule with 200 warm-up steps, using the AdamW optimizer [59]. As introduced in Ilharco et al. [2], during fine-tuning we freeze the weights of the classification layer obtained by encoding a standard set of zero-shot template prompts for each dataset. Freezing this layer does not harm accuracy and ensures that no additional learnable parameters are introduced during fine-tuning [2]. We use this exact same protocol to fine-tune the non-linear and linearized models and do not perform any form of hyperparameter search in our experiments.
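For concreteness, the following PyTorch-style sketch illustrates this protocol under our assumptions; the model structure (an image_encoder paired with a frozen classification_head) and all helper names are hypothetical and do not correspond to identifiers in the released code.

import math
import torch
from torch.optim.lr_scheduler import LambdaLR

def finetune(model, train_loader, total_steps=2000, warmup_steps=200, lr=1e-5):
    # Freeze the classification layer built from the zero-shot prompt templates.
    for p in model.classification_head.parameters():
        p.requires_grad = False

    params = [p for p in model.image_encoder.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=lr)

    # Cosine annealing with 200 linear warm-up steps.
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = LambdaLR(optimizer, lr_lambda)
    loss_fn = torch.nn.CrossEntropyLoss()

    step = 0
    while step < total_steps:
        for images, labels in train_loader:  # batch size 128
            loss = loss_fn(model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            step += 1
            if step >= total_steps:
                break
    return model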

Tuning of $\alpha$ in task arithmetic benchmarks. As in Ilharco et al. [2], we use a single coefficient $\alpha$ to tune the size of the task vectors used to modify the pre-trained models. This is equivalent to setting $\alpha = \alpha_1 = \ldots = \alpha_T$ in Eq. 1. In the task addition benchmarks, after fine-tuning, we evaluate different scaling coefficients $\alpha \in \{0.0, 0.05, 0.1, \ldots, 1.0\}$ and choose the value that achieves the highest target metric on a small held-out proportion of the training set, as specified in Ilharco et al. [2]: namely, the maximum normalized average accuracy, and the minimum target accuracy on each dataset that still retains at least 95% of the accuracy of the pre-trained model on the control task. The tuning of $\alpha$ is done independently for non-linear FT, linearized FT, and post-hoc linearization.
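A rough sketch of this selection procedure is given below, assuming hypothetical helpers apply_task_vector (which returns a model with weights $\theta_0 + \alpha\tau$) and evaluate (which returns accuracy on a held-out split); neither name is taken from the released code.

import numpy as np

def select_alpha(theta_0, task_vectors, heldout_splits, single_task_accs):
    # A single coefficient is applied to the sum of all task vectors.
    tau_sum = sum(task_vectors)
    best_alpha, best_score = 0.0, -float("inf")
    for alpha in np.arange(0.0, 1.0 + 1e-9, 0.05):
        edited = apply_task_vector(theta_0, tau_sum, alpha)
        # Normalized average accuracy over the held-out splits of all tasks.
        score = np.mean([evaluate(edited, split) / ref
                         for split, ref in zip(heldout_splits, single_task_accs)])
        if score > best_score:
            best_alpha, best_score = alpha, score
    return best_alpha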

Normalized accuracies in task addition. Table 1 shows the normalized accuracies after editing different models by adding the sum of the task vectors on 8 tasks, $\tau = \sum_t \tau_t$. Here, the normalization is performed with respect to the single-task accuracies achieved by the model fine-tuned on each task. Mathematically,

$$\text{Normalized accuracy} = \frac{1}{T}\sum_{t=1}^{T}\frac{\operatorname{acc}\left(f\left(x;\theta_0+\sum_{t'}\tau_{t'}\right)\right)}{\operatorname{acc}\left(f\left(x;\theta_0+\tau_t\right)\right)}. \qquad (5)$$
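A minimal sketch of Eq. (5): given the per-task accuracies of the jointly edited model and of the corresponding single-task fine-tuned models, the normalized accuracy is simply the average of their ratios (the numbers below are illustrative only).

def normalized_accuracy(acc_multi, acc_single):
    # acc_multi[t]: accuracy of f(x; theta_0 + sum_t tau_t) on task t.
    # acc_single[t]: accuracy of f(x; theta_0 + tau_t) on task t.
    assert len(acc_multi) == len(acc_single)
    return sum(m / s for m, s in zip(acc_multi, acc_single)) / len(acc_multi)

# Example: three tasks where the edited model recovers most of the
# single-task performance.
print(normalized_accuracy([0.85, 0.70, 0.92], [0.90, 0.75, 0.95]))  # ~0.949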

Disentanglement error. To produce the weight disentanglement visualizations of Figure 4, we compute the value of $\xi(\alpha_1, \alpha_2)$ on a $15 \times 15$ grid of equispaced values in $[-2, 2] \times [-2, 2]$. To estimate the disentanglement error, we use a random subset of 2,048 test points for each dataset.
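As an illustration of how this grid can be evaluated, the sketch below assumes that $\xi(\alpha_1, \alpha_2)$ sums, over the two tasks, the fraction of sampled test points on which the jointly edited model disagrees with the corresponding single-task edited model (see the definition of the disentanglement error in the main text); apply_task_vectors, predict, and sample_test_points are hypothetical helpers.

import itertools
import numpy as np

def disentanglement_grid(theta_0, tau_1, tau_2, datasets, grid_size=15, lim=2.0):
    alphas = np.linspace(-lim, lim, grid_size)
    xi = np.zeros((grid_size, grid_size))
    for (i, a1), (j, a2) in itertools.product(enumerate(alphas), repeat=2):
        joint = apply_task_vectors(theta_0, [(a1, tau_1), (a2, tau_2)])
        for a, tau, data in ((a1, tau_1, datasets[0]), (a2, tau_2, datasets[1])):
            single = apply_task_vectors(theta_0, [(a, tau)])
            x = sample_test_points(data, n=2048)  # random subset of test points
            xi[i, j] += np.mean(predict(joint, x) != predict(single, x))
    return alphas, xi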

Appendix B Further experimental results

We now present additional experiments that expand the findings discussed in the main text.


B.1 Fine-tuning accuracies

In Figure 6, we report the single-task accuracies achieved by different CLIP models after fine-tuning with different dynamics (referred to as non-linear, NTK linearization, and our method).

B.2 Weight disentanglement on different task pairs

In Figure 7, we illustrate weight disentanglement on different task pairs.


Appendix C Impact Statement

This paper presents work whose goal is to advance the field of Machine Learning. There are many potential societal consequences of our work, none of which we feel must be specifically highlighted here.
