训练后的 Transformer 可以在上下文中学习线性模型

Posted on Mon, Oct 21, 2024 📖 Note LLM

预备知识

符号 [n]={1,2,,n}[n]=\{1,2,\dots,n\}\otimes 表示 Kronecker 积,Vec\text{Vec} 表示按列序向量化操作符。例如 Vec(1234)=(1,3,2,4)\text{Vec}\begin{pmatrix}1&2\\3&4\end{pmatrix}=(1,3,2,4)^\top。将两矩阵 A,BRm×nA,B\in\mathbb{R}^{m\times n} 的内积写为 A,B=tr(AB)\langle A,B\rangle=\text{tr}(AB^\top)。使用 0n0_n0m×n0_{m\times n} 分别表示大小为 nnn×mn\times m 的零向量和零矩阵。对于一个一般的矩阵 AAAk:A_{k:}A:kA_{:k} 分别表示第 kk 行和第 kk 列。将矩阵的算子范数和 Frobenius 范数记为 op\|\cdot\|_{op}F\|\cdot\|_F。使用 IdI_d 表示 dd 维单位矩阵,有时上下文维度清晰时使用 II。对于一个半正定矩阵 AA,记 xA2xAx\|x\|_A^2\coloneqq x^\top Ax。除非指定,使用小写字母用于标量和向量,使用大写字母用于矩阵。

上下文学习

首先描述一个函数类的上下文学习框架。上下文学习指的是模型的一种行为,其操作输入-输出对的序列,称为提示(x1,y1,,xN,yN,xquery)(x_1,y_1,\dots,x_N,y_N,x_\text{query}),其中 yi=h(xi)y_i=h(x_i) 对于某个(未知的)函数 hh,样例 xix_i,查询 xqueryx_\text{query}。上下文学习器的目标是使用提示来形成一个对于查询的预测 y^(xquery)\hat{y}(x_\text{query}),使 y^(xquery)h(xquery)\hat{y}(x_\text{query})\approx h(x_\text{query})

对于线性模型,可以将普通最小二乘法(OLS)视为一个“上下文学习器”。但上下文学习器一个相当独特的特征是这些学习算法可以是定义在一个提示分布上的随机优化问题的解。我们在下面的定义中形式化这个概念。

定义 1(在上下文样本上训练)Dx\mathcal{D}_x 是输入空间 X\mathcal{X} 上的一个分布,HYX\mathcal{H}\sub\mathcal{Y}^\mathcal{X} 是函数 XY\mathcal{X}\rightarrow\mathcal{Y} 的一个集合,DH\mathcal{D}_\mathcal{H}H\mathcal{H} 中一个函数上的分布。令 :Y×YR\ell:\mathcal{Y}\times\mathcal{Y}\rightarrow\mathbb{R} 为一个损失函数。令 S=nN{(x1,y1,,xn,yn):xiX,yiY}\mathcal{S}=\bigcup_{n\in\mathbb{N}}\{(x_1,y_1,\dots,x_n,y_n):x_i\in\mathcal{X},y_i\in\mathcal{Y}\} 为有限长度 (x,y)(x,y) 对序列集合,且令

FΘ={fθ:S×XY,θΘ}\mathcal{F}_\Theta=\{f_\theta:\mathcal{S}\times\mathcal{X}\rightarrow\mathcal{Y},\theta\in\Theta\}

为参数为在某个集合 Θ\Theta 中的 θ\theta 的一类函数。对于 N>0N>0,我们说模型 f:S×XYf:\mathcal{S}\times\mathcal{X}\rightarrow\mathcal{Y} 是在 H\mathcal{H} 中的函数的上下文样本上在损失 \ell 下关于 (DH,Dx)(\mathcal{D}_\mathcal{H},\mathcal{D}_x) 训练的,如果 f=fθf=f_{\theta^*} 其中 θΘ\theta^*\in\Theta 满足

θarg minθΘEP=(x1,h(x1),,xN,h(xN),xquery)[(fθ(P),h(xquery))],(1)\theta^*\in\argmin_{\theta\in\Theta}\mathbb{E}_{P=(x_1,h(x_1),\dots,x_N,h(x_N),x_\text{query})}[\ell(f_\theta(P),h(x_\text{query}))], \tag{1}

其中 xi,xqueryi.i.dDxx_i,x_\text{query}\overset{\text{i.i.d}}{\sim}\mathcal{D}_xhDHh\sim\mathcal{D}_\mathcal{H} 是独立的。我们称 NN 为训练期间看到的提示长度。

该定义自然引出了一个从数据中学习学习算法的方法:通过采样随机函数 hDHh\sim\mathcal{D}_\mathcal{H} 和特征向量 xi.xqueryi.i.dDxx_i.x_\text{query}\overset{\text{i.i.d}}{\sim}\mathcal{D}_x 来采样对立的提示,然后使用随机梯度下降或其他随机优化算法来最小化 (1) 中的目标函数。这个过程返回一个从上下文样本中学习的且能用给定的训练数据序列来形成测试(查询)样本的预测的模型。下面的定义量化了对应于一个特定的假设类这样一个模型能在上下文样本上表现多好。

定义 2(假设类的上下文学习)Dx\mathcal{D}_x 是输入空间 X\mathcal{X} 上的一个分布,HYX\mathcal{H}\sub\mathcal{Y}^\mathcal{X} 是函数 XY\mathcal{X}\rightarrow\mathcal{Y} 的一个集合,DH\mathcal{D}_\mathcal{H}H\mathcal{H} 中一个函数上的分布。令 :Y×YR\ell:\mathcal{Y}\times\mathcal{Y}\rightarrow\mathbb{R} 为一个损失函数。令 S=nN{(x1,y1,,xn,yn):xiX,yiY}\mathcal{S}=\bigcup_{n\in\mathbb{N}}\{(x_1,y_1,\dots,x_n,y_n):x_i\in\mathcal{X},y_i\in\mathcal{Y}\} 为有限长度 (x,y)(x,y) 对序列集合。我们说定义在提示形式为 P=(x1,h(x1),,xM,h(xM),xquery)P=(x_1,h(x_1),\dots,x_M,h(x_M),x_\text{query}) 的模型 f:S×XYf:\mathcal{S}\times\mathcal{X}\rightarrow\mathcal{Y} 在上下文中在损失 \ell 下学习了一个关于 (DH,Dx)(\mathcal{D}_\mathcal{H},\mathcal{D}_x) 的假设类 H\mathcal{H} 至误差 ηR\eta\in\mathbb{R} 如果存在一个函数 MDH,Dx(ϵ):(0,1)NM_{\mathcal{D}_\mathcal{H},\mathcal{D}_x}(\epsilon):(0,1)\rightarrow\mathbb{N} 使得对于每个 ϵ(0,1)\epsilon\in(0,1),且对于每个长度为 MMDH,Dx(ϵ)M\ge M_{\mathcal{D}_\mathcal{H},\mathcal{D}_x}(\epsilon) 的提示 PP

EP=(x1,h(x1),,xM,h(xM),xquery)[(f(P),h(xquery))]η+ϵ,(2)\mathbb{E}_{P=(x_1,h(x_1),\dots,x_M,h(x_M),x_\text{query})}\left[\ell\left(f(P),h(x_\text{query})\right)\right]\le\eta+\epsilon, \tag{2}

其中期望是在 xi.xqueryi.i.dDxx_i.x_\text{query}\overset{\text{i.i.d}}{\sim}\mathcal{D}_xhDHh\sim\mathcal{D}_\mathcal{H} 上的随机性。

上述定义 2 中的加性误差项 η\eta 允许模型不实现任意小的误差的可能性。这个误差可能来源于使用一个不够复杂到学习 H\mathcal{H} 中函数的模型或者来源于考虑一个不可实现的设定,其中不可能实现任意小的误差。

有了这两个定义,我们可以形式化下面的问题:假设给定一个函数类 FΘ\mathcal{F}_\ThetaDH\mathcal{D}_\mathcal{H} 对应于假设类 H\mathcal{H} 中假设的随机实例。在关于 (DH,Dx)(\mathcal{D}_\mathcal{H},\mathcal{D}_x)H\mathcal{H} 中函数的上下文样本上训练的从 FΘ\mathcal{F}_\Theta 中的模型可以在上下文中学习到有小的预测误差的关于 (DH,Dx)(\mathcal{D}_\mathcal{H},\mathcal{D}_x) 的假设类 H\mathcal{H} 吗?标准的基于梯度的优化算法是否足以从上下文样本中训练模型?上下文在训练期间和测试时必须多长才能实现小的预测误差?在剩余小节中,当假设类是线性模型时,我们将用线性自注意力模块回答这些问题,感兴趣的损失是平方损失,边缘是(可能是各向异性的)高斯边缘。

线性自注意力网络

我们首先回顾基于 softmax 的单头自注意力模块。令 ERde×dNE\in\mathbb{R}^{d_e\times d_N} 为一个使用长度为 NN 的提示 (x1,y1,,xN,yN,xquery)(x_1,y_1,\dots,x_N,y_N,x_\text{query}) 形成的嵌入矩阵。用户自由决定如何用提示形成这个嵌入矩阵。一种自然的方法是在 EE 的前 NN 列堆叠 (xi,yi)Rd+1(x_i,y_i)^\top\in\mathbb{R}^{d+1} 然后最后一列是 (xquery,0)(x_\text{query},0)^\top,如果 xiRd,yiRx_i\in\mathbb{R}^d,\,y_i\in\mathbb{R},我们有 de=d+1d_e=d+1dN=N+1d_N=N+1。令 WK,WQRdk×deW^K,W^Q\in\mathbb{R}^{d_k\times d_e}WVRdv×deW^V\in\mathbb{R}^{d_v\times d_e} 为键,查询和值权重矩阵,WPRde×dvW^P\in\mathbb{R}^{d_e\times d_v} 为投影矩阵,以及 ρ>0\rho>0 为归一化因子。softmax 自注意力模块以一个宽度为 dNd_N 的嵌入矩阵 EE 作为输入,输出是一个同样大小的矩阵

fAttn(E;WK,WQ,WV,WP)=E+WPWVEsoftmax((WKE)WQEρ),f_\text{Attn}(E;W^K,W^Q,W^V,W^P)=E+W^PW^VE\cdot\text{softmax}\left(\frac{(W^KE)^\top W^QE}{\rho}\right),

其中 softmax 是按列向应用的,且给定一个向量输入 vvsoftmax(v)\text{softmax}(v) 的第 ii 项为 exp(vi)/sexp(vs)\exp(v_i)/\sum_s\exp(v_s)。softmax 中的 dN×dNd_N\times d_N 矩阵称为自注意力矩阵。注意 fAttnf_\text{Attn} 可以取任意长度的序列作为其输入。

本工作中考虑单层自注意力模块的一个简化版本,其更适合理论分析且仍能够进行上下文学习线性模型。具体来说,我们考虑单层线性自注意力(LSA)模型,其是 fAttnf_\text{Attn} 的一个修改版本,移除了 softmax 非线性,将投影和值矩阵合并为一个矩阵 WPVRde×deW^{PV}\in\mathbb{R}^{d_e\times d_e},将查询和键矩阵合并为一个矩阵 WKQRde×deW^{KQ}\in\mathbb{R}^{d_e\times d_e}。我们将这些矩阵拼接为 θ=(WKQ,WPV)\theta=(W^{KQ},W^{PV}) 且记

fLSA(E;θ)=E+WPVEEWKQEρ.(3)f_\text{LSA}(E;\theta)=E+W^{PV}E\cdot\frac{E^\top W^{KQ}E}{\rho}. \tag{3}

注意到最近关于理解 Transformer 的理论工作着眼于相同的模型。值得注意的是,最近的实证工作表明,具有标准的基于 softmax 的注意力模块的最先进的训练后的视觉 Transformers 使得 (WK)WQ(W^K)^\top W^QWPWVW^PW^V 几乎为单位矩阵的倍数,可以表示在我们考虑的参数化下。

本工作中,对于长度为 NN 的提示,我们是用下面的嵌入,堆叠 (xi,yi)Rd+1(x_i,y_i)^\top\in\mathbb{R}^{d+1} 到前 NN 列,最后一列是 (xquery,0)Rd+1(x_\text{query},0)^\top\in\mathbb{R}^{d+1}

E=E(P)=(x1x2xNxqueryy1y2yN0)R(d+1)×(N+1).(4)E=E(P)=\begin{pmatrix} x_1& x_2&\cdots& x_N& x_\text{query}\\ y_1&y_2&\cdots&y_N&0 \end{pmatrix}\in\mathbb{R}^{(d+1)\times(N+1)}. \tag{4}

取归一化因子 ρ\rho 为嵌入矩阵 EE 的宽度减一,即 ρ=dN1\rho=d_N-1,因为 EEE\cdot E^\top 中的每个元素是长度为 dNd_N 的两个向量的内积。在上述的词元嵌入下,我们取 ρ=N\rho=N。注意到也有其他方式使用提示数据形成嵌入矩阵,例如,将所有输入和标签填充至等长向量然后将它们排布到矩阵中,或堆叠拼接 (xi,yi)(x_i,y_i) 的线性转置的列,尽管上下文学习的动态在不同参数化下会有所不同。

网络对于词元 xqueryx_\text{query} 的预测是 fLSAf_\text{LSA} 矩阵输出的底部右侧的项,即

y^query=y^query(E;θ)=[fLSA(E;θ)](d+1),(N+1).\hat{y}_\text{query}=\hat{y}_\text{query}(E;\theta)=[f_\text{LSA}(E;\theta)]_{(d+1),(N+1)}.

由于预测只取 LSA 层的词元矩阵输出的底部右侧的项,事实上只有 WPVW^{PV}WKQW^{KQ} 的部分会影响预测。为观察如何计算,让我们令

WPV=(W11PVw12PV(w21PV)w22PV)R(d+1)×(d+1),WKQ=(W11KQw12KQ(w21KQ)w22KQ)R(d+1)×(d+1),(5)W^{PV}=\begin{pmatrix}W_{11}^{PV}&w_{12}^{PV}\\(w_{21}^{PV})^\top&w_{22}^{PV}\end{pmatrix}\in\mathbb{R}^{(d+1)\times(d+1)},\enspace W^{KQ}=\begin{pmatrix}W_{11}^{KQ}&w_{12}^{KQ}\\(w_{21}^{KQ})^\top&w_{22}^{KQ}\end{pmatrix}\in\mathbb{R}^{(d+1)\times(d+1)}, \tag{5}

其中 W11PVRd×d;w12PV,w21PVRd;w22PVR;W_{11}^{PV}\in\mathbb{R}^{d\times d};w_{12}^{PV},w_{21}^{PV}\in\mathbb{R}^d;w_{22}^{PV}\in\mathbb{R};W11KQRd×d;w12KQ,w21KQRd;w22KQRW_{11}^{KQ}\in\mathbb{R}^{d\times d};w_{12}^{KQ},w_{21}^{KQ}\in\mathbb{R}^d;w_{22}^{KQ}\in\mathbb{R}。则预测 y^query\hat{y}_\text{query}

y^query=((w21PV)w22PV)(EEN)(W11KQ(w21KQ))xquery,(6)\hat{y}_\text{query}=\left((w_{21}^{PV})^\top\enspace w_{22}^{PV}\right)\cdot\left(\frac{EE^\top}{N}\right)\begin{pmatrix}W_{11}^{KQ}\\(w_{21}^{KQ})^\top\end{pmatrix}x_\text{query}, \tag{6}

由于只有 WPVW^{PV} 的最后一行和 WKQW^{KQ} 的前 dd 列会影响预测,这意味着我们可以简单地在下面的小节中将其他所有项取零。

训练过程

本工作中,我们考虑在上下文中学习线性预测器的任务。我们假设训练样本如下采样出。令 Λ\Lambda 为正定协方差矩阵。每个训练提示,索引为 τN\tau\in\mathbb{N},有形式

Pτ=(xτ,1,hτ(xτ,1),,xτ,N,hτ(xτ,N),xτ,query),P_\tau=(x_{\tau,1},h_\tau(x_{\tau,1}),\dots,x_{\tau,N},h_\tau(x_{\tau,N}),x_{\tau,\text{query}}),

其中任务权重 wτi.i.dN(0,Id)w_\tau\overset{\text{i.i.d}}{\sim}\mathbf{N}(0,I_d),输入 xτ,i,xτ,queryi.i.d.N(0,Λ)x_{\tau,i},x_{\tau,\text{query}}\overset{\text{i.i.d.}}{\sim}\mathbf{N}(0,\Lambda),以及标签 hτ(x)=wτ,xh_\tau(x)=\langle w_\tau,x\rangle

每个提示对应于一个嵌入矩阵 EτE_\tau,使用变形 (4) 得到:

Eτ(xτ,1xτ,2xτ,Nxτ,querywτ,xτ,1wτ,xτ,2wτ,xτ,N0)R(d+1)×(N+1).E_\tau\coloneqq\begin{pmatrix} x_{\tau,1}& x_{\tau,2}&\cdots& x_{\tau,N}& x_{\tau,\text{query}}\\ \langle w_\tau,x_{\tau,1}\rangle&\langle w_\tau,x_{\tau,2}\rangle&\cdots&\langle w_\tau,x_{\tau,N}\rangle&0 \end{pmatrix}\in\mathbb{R}^{(d+1)\times(N+1)}.

记在任务 τ\tau 中 LSA 模型在查询标签上的预测是 y^τ,query\hat{y}_{\tau,\text{query}},其是 fLSA(Eτ)f_\text{LSA}(E_\tau) 底部右侧的元素,其中 fLSAf_\text{LSA} 是在 (3) 中定义的线性自注意力模型。BB 个独立提示上的经验风险定义为

L^(θ)=12Bτ=1B(y^τ,querywτ,xτ,query)2.(7)\hat{L}(\theta)=\frac{1}{2B}\sum_{\tau=1}^B\left(\hat{y}_{\tau,\text{query}}-\langle w_\tau,x_{\tau,\text{query}}\rangle\right)^2. \tag{7}

我们考虑梯度流训练后的网络在总体损失上的行为,由无限训练任务/提示的极限 BB\rightarrow\infin 来推导:

L(θ)=limBL^(θ)=12Ewτ,xτ,1,,xτ,N,xτ,query[(y^τ,querywτ,xτ,query)2](8)L(\theta)=\lim_{B\rightarrow\infin}\hat{L}(\theta)=\frac{1}{2}\mathbb{E}_{w_\tau,x_{\tau,1},\cdots,x_{\tau,N},x_{\tau,\text{query}}}[(\hat{y}_{\tau,\text{query}}-\langle w_\tau,x_{\tau,\text{query}}\rangle)^2] \tag{8}

上面的期望是关于在提示中的协变量 {xτ,i}i=1N{xquery}\{x_{\tau,i}\}_{i=1}^N\cup\{x_\text{query}\} 和权重向量 wτw_\tau 取的,即 xτ,i,xτ,queryi.i.d.N(0,Λ)x_{\tau,i},x_{\tau,\text{query}}\overset{\text{i.i.d.}}{\sim}\mathbf{N}(0,\Lambda)wτi.i.dN(0,Id)w_\tau\overset{\text{i.i.d}}{\sim}\mathbf{N}(0,I_d)。梯度流捕捉了在无穷小的步长下的梯度下降的行为,有由下面微分方程给出的动态:

ddtθ=L(θ).(9)\frac{\text{d}}{\text{d}t}\theta=-\nabla L(\theta). \tag{9}

我们考虑满足下面的梯度流的初始化。

假设 3(初始化)σ>0\sigma>0 为一个参数,令 ΘRd×d\Theta\in\mathbb{R^{d\times d}} 为任意满足 ΘΘF=1\|\Theta\Theta^\top\|_F=1 的矩阵且 ΘΛ0d×d\Theta\Lambda\ne 0_{d\times d}。我们假设

WPV(0)=σ(0d×d0d0d1),WKQ(0)=σ(ΘΘ0d0d0).(10)W^{PV}(0)=\sigma\begin{pmatrix}0_{d\times d}&0_d\\0_d^\top&1\end{pmatrix},\enspace W^{KQ}(0)=\sigma\begin{pmatrix}\Theta\Theta^\top&0_d\\0_d^\top&0\end{pmatrix}. \tag{10}

对于一类特定的随机初始化方案,满足这种初始化:如果 MM 有 i.i.d. 从连续分布中抽样的项,则设定 ΘΘ=MM/MMF\Theta\Theta^\top=MM^\top/\|MM^\top\|_F,假设几乎必然被满足。我们使用这种特定的初始化方案的原因将在第 5 节中描述证明时更清楚地说明,但高层次视角来看,这是由于预测 (6) 可以被视为一个两层线性网络的输出,而满足假设 3 的初始化可以让层通过梯度流序列被“平衡”。导致这种平衡条件的随机初始化已被用于许多深度线性网络的理论工作中。我们将在其他随机初始化方案下的收敛的问题留作未来工作。

主要结果

本小节介绍文章的主要结果。首先,4.1 小节我们证明了总体损失上的梯度流会收敛到一个特定的全局最优。我们描述了当给定一个新的预测任务的提示时训练后的 transformer 在这个全局最优的预测误差。我们的描述考虑到了新的提示来自非线性预测任务的可能性。然后,我们为指定的线性回归提示实例化我们的结果,并描述实现小预测误差所需的样本数量,表明当对线性模型的上下文示例进行训练时,transformer 可以在上下文中学习线性模型。

接下来,在 4.2 小节中,我们分析了训练后的 transformer 在各种分布偏移下的行为。我们证明了 transformer 对许多分布偏移是鲁棒的,包括任务偏移(当提示中的标签不是输入的确定性的线性函数时)以及查询偏移(当查询样本 xqueryx_\text{query} 和测试提示比有可能不同的分布)。另一方面,我们证明了 transformer 受协变量分布偏移影响,即当训练提示协变量分布和测试提示协变量分布不同时。

最后,受协变量分布偏移下训练后的 transformer 的失败启发,我们在 4.3 小节中考虑在提示间改变协变量分布的上下文示例上训练的设定。我们证明了由梯度流训练的单层线性自注意力层的 transformer 收敛到总体目标的一个全局最小值上,但训练后的 transformer 仍在新的提示上表现不好。我们补充了我们在线性自注意力案例中的证明,在大型非线性 transformer 架构上进行了实验,表明其在协变量偏移下更加鲁棒。

梯度流的收敛和新任务上的预测误差

首先我们证明在合适的初始化下,梯度流会收敛到全局最优。

定理 4(收敛与极限)考虑 (3) 中定义的线性自注意力网络 fLSAf_\text{LSA} 在总体损失 (8) 上的梯度流。假设初始化满足假设 3 且初始化尺度 σ>0\sigma>0 满足 σ2Γopd<2\sigma^2\|\Gamma\|_{op}\sqrt{d}<2,其中我们定义

Γ(1+1N)Λ+1Ntr(Λ)IdRd×d.\Gamma\coloneqq\left(1+\frac{1}{N}\right)\Lambda+\frac{1}{N}\text{tr}(\Lambda)I_d\in\mathbb{R}^{d\times d}.

梯度流收敛到总体损失 (8) 的一个全局最小值。此外,WPVW^{PV}WKQW^{KQ} 分别收敛到 WPVW^{PV}_*WKQW^{KQ}_*,其中

WKQ=[tr(Γ2)]14(Γ10d0d0),WPV=[tr(Γ2)]14(0d×d0d0d1).(11)W^{KQ}_*=[\text{tr}(\Gamma^{-2})]^{-\frac{1}{4}}\begin{pmatrix}\Gamma^{-1}&0_d\\0_d^\top&0\end{pmatrix},\enspace W^{PV}_*=[\text{tr}(\Gamma^{-2})]^{-\frac{1}{4}}\begin{pmatrix}0_{d\times d}&0_d\\0_d^\top&1\end{pmatrix}. \tag{11}

该定理的完整证明见附录 A。注意到如果我们限制设定为 Λ=Id\Lambda=I_d,则梯度流描述的极限解与 von Oswald 等人(2022)的构造非常相似。因为如果我们乘 WPVW^{PV} 一个常数 c0c\ne0 且同时乘 WKQW^{KQ} c1c^{-1}, transformer 的预测是一样的,对于 Λ=Id\Lambda=I_d 情况,唯一的区别(至缩放)是他们的 WKQW^{KQ} 矩阵的顶部左侧项是 IdI_d 而非我们找到的 (1+(d+1)/N)1Id(1+(d+1)/N)^{-1}I_d

接下来,我们想要描述上述的训练后的网络在给定一个新的提示时的预测误差。让我们考虑形式为 (x1,w,x1,,xM,w,xM,xquery)(x_1,\langle w,x_1\rangle,\dots,x_M,\langle w,x_M\rangle,x_\text{query}) 的提示,其中 wRdw\in\mathbb{R}^dxi,xqueryi.i.d.N(0,Λ)x_i,x_\text{query}\overset{\text{i.i.d.}}{\sim}\mathbf{N}(0,\Lambda)。简单的计算表明全局最优处有参数 WPVW^{PV}_*WKQW^{KQ}_* 的预测 y^query\hat{y}_\text{query}

y^query=(0d1)(1Mi=1Mxixi+1Mxqueryxquery1Mi=1Mxixiw1Mi=1Mwxixi1Mi=1Mwxixiw)(Γ10d0d0)(xquery0)=xqueryΓ1(1Mi=1Mxixi)w.(12)\begin{aligned} \hat{y}_\text{query}&=\begin{pmatrix}0_d^\top&1\end{pmatrix}\begin{pmatrix}\frac{1}{M}\sum_{i=1}^Mx_ix_i^\top+\frac{1}{M}x_\text{query}x_\text{query}^\top&\frac{1}{M}\sum_{i=1}^Mx_ix_i^\top w\\\frac{1}{M}\sum_{i=1}^Mw^\top x_ix_i^\top&\frac{1}{M}\sum_{i=1}^Mw^\top x_ix_i^\top w\end{pmatrix}\begin{pmatrix}\Gamma^{-1}&0_d\\0_d^\top&0\end{pmatrix}\begin{pmatrix}x_\text{query}\\0\end{pmatrix}\\ &=x_\text{query}^\top\Gamma^{-1}\left(\frac{1}{M}\sum_{i=1}^Mx_ix_i^\top\right)w. \end{aligned}\tag{12}

当在训练期间看到的提示的长度 NN 很大时, Γ1Λ1\Gamma^{-1}\approx\Lambda^{-1},当测试提示长度 MM 很大时,1Mi=1MxixiΛ\frac{1}{M}\sum_{i=1}^Mx_ix_i^\top\approx\Lambda,以致 y^queryxqueryw\hat{y}_\text{query}\approx x_\text{query}^\top w。因此,对于足够长的提示长度,训练后的 transformer 确实在上下文中学到了线性预测器类