《Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks》
监督学习是机器学习和数据挖掘的基本任务之一。目标是推断出一个函数,该函数可以将给定的预测变量(predictor variables
)(又叫做特征)作为输入来预测目标(target
)。监督学习具有广泛的应用,包括推荐系统、图像识别等等。
在对离散(categorical
)变量进行监督学习时,重要的是要考虑离散变量之间的交互(interactions
)。例如,考虑使用三个离散变量预测客户收入的 “玩具” 问题:
occupation = {banker,engineer,...}
level = {unior,senior}
gender = {male,female}
虽然 junior bankers
的收入低于 junior engineers
,但是对于 senior level
的客户可能正好相反:senior bankers
的收入通常高于 senior engineers
。如果机器学习模型假设预测变量之间的独立性,并忽略它们之间的交互,那么模型将无法准确预测。例如线性回归模型为每个特征关联一个权重,并将目标预测为所有特征的加权和。
为了利用特征之间的交互,一种常见的解决方案是使用特征的乘积(又叫做交叉特征cross features
)显式增加特征向量。例如多项式回归(polynomial regression: PR
)中学习每个交叉特征的权重。然而 PR
(以及其它类似的基于交叉特征的解决方案,例如 Wide & Deep
中的 wide
组件)的关键问题在于,对于稀疏数据集(其中仅仅一小部分交叉特征被观测到)无法估计未观察到(unobserved
)的交叉特征的参数。
为了解决 PR
的泛化问题,人们提出了分解机(factorization machine: FM
)。FM
将交叉特征的权重参数化为每个特征(构成交叉特征的)的 embedding
向量的内积。通过学习每个特征的 embedding
向量,FM
可以估计任何交叉特征的权重。由于这种泛化能力,FM
已成功应用于各种 application
,从推荐系统到 NLP
。
尽管前景广阔,但是AFM
的作者认为 FM
可能会因为它对所有特征交互使用相同权重来建模而受到阻碍。在实际应用中,不同的预测变量通常具有不同的预测能力,并且并非所有特征都包含用于估计目标的有用信号。例如,和无用特征的交互甚至可能引入噪声并对性能产生不利影响。因此,应该为不太有用的特征的交互分配较低的权重,因为它们对预测的贡献较小。然而,FM
缺乏区分特征交互重要性的能力,这可能导致次优预测。
在论文 《Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks》
中,论文通过区分特征交互的重要性来改进 FM
。作者设计了一个叫做注意力分解机(Attentional Factorization Machine: AFM
)的新模型,它利用了神经网络建模的最新进展(即注意力机制),从而使得不同特征交互对预测有不同的贡献。更重要的是,特征交互的重要性是从数据中自动学习的,无需任何人类领域知识(human domain knowledge
)。
论文对上下文感知预测(context-aware prediction
)和个性化tag
推荐的两个公共 benchmark
数据集上进行了实验。大量实验结果表明,在 FM
上使用注意力机制有两个好处:不仅可以带来更好的性能,而且可以深入了解哪些特征交互对预测的贡献更大。这大大增强了 FM
的可解释性(interpretability
)和透明性(transparency
),允许从业人员对模型行为进行更深入的分析。另外,AFM
始终优于 SOTA
的深度学习方法 Wide & Deep
和 Deep-Cross
,并且 AFM
结构更简单、模型参数更少。
相关工作:FM
主要用于稀疏(setting
)下的监督学习,例如在离散变量通过 one-hot encoding
转换为稀疏特征向量的情况下。与在图像和音频中发现的连续(continuous
)的原始特征不同,web
领域的输入特征大多数是不连续(discrete
)的、离散(categorical
)的。当使用这类稀疏数据进行预测时,对特征之间的交互进行建模至关重要。
与仅对两个实体之间的交互进行建模的矩阵分解(matrix factorization: MF
)相比,FM
被设计为通用的 machine learner
,用于对任意数量的实体之间的交互进行建模。通过指定(specifying
)输入特征,论文 《Factorization machines》
表明FM
可以模拟很多特定的分解模型,例如标准的 MF
、并行因子分析(parallel factor analysis
)、SVD++
。因此,FM
被认为是稀疏数据预测最有效的线性 embedding
方法。
已经提出了FM
的许多变体,如 NFM
在神经网络框架下 deepen
了 FM
从而学习高阶特征交互,FFM
将一个特征关联了多个 embedding
向量从而区分不同 field
的其它特征的交互。在这项工作中,我们通过区分特征交互的重要性来改进 FM
。我们知道有一项与我们方法类似的工作 GBFM
,它通过梯度提升选择 “好的” 特征,并且只对好的特征之间的交互进行建模。对于选定特征之间的交互,GBFM
也是像 FM
一样以相同的权重对它们求和。因此,GBFM
本质上是一种特征选择算法,它与我们的 AFM
有着本质的不同,我们的 AFM
可以学习每个特征交互的重要性。
另一方面,深度神经网络正在变得越来越流行,并且最近被用于在稀疏 setting
下进行预测。具体而言:
Wide & Deep
用于 App
推荐,其中 Deep
组件是在特征 embedding
向量拼接之后的 MLP
,从而学习特征交互。
DeepCross
用于点击率预估,它用深度残差 MLP
来学习交叉特征。
我们指出,在这些方法中,特征交互被深度神经网络隐式地捕获,而不是将每个交互显式建模为两个特征的内积的 FM
。因此,这些深度方法是不可解释的,因为每个特征交互的贡献是未知的。通过使用学习每个特征交互重要性的注意力机制直接扩展 FM
,我们的 AFM
更具有可解释性,并且实验证明优于 Wide & Deep
和 DeepCross
的性能。
作为监督学习的通用机器学习模型,FM
最初用于协同推荐(collaborative recommendation
)。给定一个实值特征向量 FM
通过分解交互参数(interaction parameters
),从而对每个对特征之间的所有交互进行建模,进而估计目标(target
):
其中:
bias
,target
的交互。
factorized interaction
),代表交叉特征 embedding
向量,embedding
向量的维度。
注意,由于存在系数
值得注意的是,FM
以相同的方式对所有特征交互进行建模:
首先,在估计第
其次,所有估计的特征交互 1
。
在实践中,一个常见的现象是:并非所有特征都与预测相关。以新闻分类问题为例,考虑新闻 “美国继续在外国支付透明度方面发挥主导作用” 的分类。显然,除了 “对外支付透明度” 之外的单词并不代表新闻主题(即金融主题)。那些涉及不相关特征的交互可以被视为对预测没有贡献的噪声。
然而,FM
以相同的权重对所有可能的特征交互进行建模,这可能会不利地降低模型的泛化性能。
下图说明了我们提出的 AFM
模型的神经网络架构。为了清晰起见,我们省略了图中的线性回归部分。输入层和 embedding
层与 FM
相同,即:输入特征采用稀疏表示,并将每个非零特征嵌入到一个稠密向量中。接下来我们详细介绍了 pair-wise
交互层和 attention-based
池化层,这是本文的主要贡献。
pair-wise
交互层( Pair-wise Interaction Layer
):受使用内积对每对特征之间的交互进行建模的 FM
的启发,我们在神经网络建模中提出了一个新的 Pairwise Interaction Layer
。该 layer
将 interacted vector
),其中每个交互向量是一对向量的逐元素乘积从而编码这对向量的交互。
形式上,假设特征向量 embedding layer
的输出为 Pairwise Interaction Layer
的输出表示为一组向量:
其中:
通过定制化 Pairwise Interaction Layer
,我们可以在神经网络架构下表达 FM
。为了证明这一点,我们首先使用 sum
池化来压缩 prediction score
上:
其中 prediction layer
的权重和bias
。显然,当我们固定 FM
模型。
注意,我们在 NFM
中提出的Bi-Interaction
池化操作可以视为在 Pairwise Interaction Layer
上使用 sum
池化而来。即:
attention-based
池化层(Attention-based Pooling Layer
):注意力机制自从引入神经网络建模以来,已被广泛应用于许多任务,例如推荐、信息检索、计算机视觉等等。注意力机制的思想是:在将不同部分压缩为单个 representation
时允许不同部分做出不同的贡献。
受 FM
缺点的启发,我们提出通过对交互向量执行加权求和来对特征交互采用注意力机制:
其中 attention score
,可以解释为
为了估计 attention score
。为了解决这个泛化问题,我们使用多层感知机(multi-layer perceptron: MLP
)进一步参数化 attention score
,我们称之为注意力网络(attention network
)。
注意力网络的输入是交互向量,这个交互向量在 embedding
空间中对两个特征的交互信息进行编码。形式上,注意力网络定义为:
其中: attention network
的隐向量维度,称作 attention factor
。
attention score
通过 softmax
函数进行归一化,这是已有工作的常见做法。我们使用 relu
激活函数,实验表明它的性能良好。
attention-based
池化层的输出是一个 embedding
空间中的所有特征交互。然后我们将这个 prediction score
。
总而言之,我们给出了 AFM
模型的整体公式为:
模型的参数为:
当移除 attention network
时,AFM
模型退化为 NFM-0
模型,即标准的 FM
模型。和 NFM
相比,AFM
模型缺少隐层来提取高阶特征交互。
由于 AFM
从数据建模的角度直接增强了 FM
,因此它也可以应用于各种预测任务,包括回归、分类、以及 ranking
。应该使用不同的目标函数来为不同的任务量身定制AFM
模型学习。对于目标
其中:label
,
对于二分类任务、或者带隐式反馈的推荐任务,我们可以最小化 log loss
。在本文中,我们聚焦于回归任务并优化平方损失。
为了优化目标函数,我们采用了随机梯度下降 SGD
。实现 SGD
算法的关键是获得预测值
防止过拟合(overfitting
):过拟合是机器学习模型的永恒问题。结果表明,FM
可能会出现过拟合,因此 L2
正则化是防止 FM
过拟合的重要手段。由于 AFM
比 FM
具有更强的表达能力,因此可能更容易过拟合训练数据。这里我们考虑两种在神经网络模型中广泛使用的防止过拟合的技术:dropout
和 L2
正则化。
dropout
的思想是在训练期间随机丢弃一些神经元。dropout
被证明能够防止神经元对训练数据的复杂的协同适应(co-adaptation
)。
由于 AFM
对特征之间的所有 pairwise
交互进行建模,但并非所有交互都是有用的,因此Pair-wise Interaction Layer
的神经元可能很容易协同适应(co-adapt
)并导致过拟合。因此,我们在 Pair-wise Interaction Layer
上应用 dropout
来避免协同适应。
此外,由于dropout
在测试期间被禁用并且整个网络用于预测,因此 dropout
可以视为使用大量较小的神经网络执行模型平均,这可能会提高性能。
对于单层 MLP
的注意力网络组件,我们在权重矩阵 L2
正则化从而防止可能的过拟合。也就是说,我们优化的实际目标函数是:
其中
我们不在注意力网络上使用 dropout
,因为我们发现在Pair-wise Interaction Layer
和注意力网络上联合使用 dropout
会导致一些稳定性问题并降低性能。
我们进行实验以回答以下问题:
RQ1
:AFM
的关键超参数(即特征交互上的 dropout
和注意力网络的正则化)如何影响其性能?
RQ2
:注意力网络能否有效地学习特征交互的重要性?
RQ3
:和 SOTA
的稀疏数据预测方法相比,AFM
的表现如何?
数据集:我们使用两个公开可用的数据集 Frappe
、MovieLens
。
Frappe
:给出了不同上下文时用户的 app
使用日志记录,一共包含 96203
个 app
。除了 userID, appID
之外,每条日志还包含 8
个上下文特征:天气、城市、daytime
(如:早晨、上午、下午) 等。
采用 one-hot
编码之后,特征有 5382
维。label = 1
表示用户使用了 app
。
MovieLens
:GroupLens
发布的最新 MovieLens
数据集的完整版,包含 17045
个用户在 23743
个 item
上的 49657
种不同的 tag
。这里我们研究个性化的 tag
推荐任务,而不是仅考虑二阶交互的协同过滤。
将 userID,movieID,tag
进行 one-hot
编码之后,特征有 90445
维; label = 1
表示用户给 movie
贴了 tag
。
由于两个原始数据集都只包含正样本(即所有样本的 label
都是 1
),我们为每个正样本随机采样了两个负样本,从而确保预测模型的泛化。
对于 Frappe
数据集,对每条记录,随机采样每个用户在上下文中未使用的其它两个 app
。
对于 MovieLens
数据集,对每个用户每个电影的每个 tag
,随机分配给该电影其它两个该用户尚未分配的 tag
。
每个负样本的 label = -1
。下表给出了最终评估数据集的统计数据。
评估指标:我们将数据集随机划分为训练集(70%
)、验证集(20%
)、测试集(10%
)。验证集用于调优超参数,并在测试集上进行最终的性能比较。为了评估性能,我们采用了均方根误差 RMSE
,其中较低的 RMSE
得分表示更好的性能。
对于模型的预测结果,如果超出了 1
或者 -1
的范围,那么我们将结果截断为 1
或者 -1
。
baseline
方法:我们将 AFM
和以下设计用于稀疏数据预测的竞争方法进行比较。
LibFM
:这是 Rendle
发布的 FM
的官方实现。我们选择 SD learner
与其它所有使用 SGD
(或者变体)优化的方法进行公平比较。
HOFM
:这是高阶 FM
的 TensorFlow
实现,如论文 《Factorization machines》
所述。我们尝试了三阶,因为 MovieLens
数据涉及用户、电影、tag
之间的三元关系。
Wide & Deep
:deep
部分首先拼接特征 embedding
,然后是 MLP
对特征交互进行建模。由于 DNN
的结构难以完全调优 fully tuned
,我们使用了与原始论文中报道的相同的结构:三层 MLP
,隐层维度分别为 [1024, 512, 256]
。而 wide
组件与 FM
的线性回归部分相同。
DeepCross
:它在特征 embedding
拼接之后应用多层残差网络来学习特征交互。我们使用了原始论文中报道的相同结构:五层残差单元(每层单元有两个子隐层),每个残差单元的维度为[512, 512, 256, 128, 64]
。
配置:
所有模型的优化目标为平方损失。
除了 LibFM
,所有方法都是通过 mini-batch Adagrad
算法学习的,其中: Frappe
数据集的 batch size = 128
;MovieLens
数据集的 batch size = 4096
。libFM
使用常规SGD
优化。
所有方法的 embedding size = 256
。如无特殊说明,则注意力因子也是 256
,与 embedding size
相同。
我们仔细调优了 LibFM
和 HOFM
的 L2
正则化,以及 Wide & Deep
和 DeepCross
的 droupout rate
。
我们根据验证集的性能使用了早停策略。
对于 Wide & Deep, DeepCross, AFM
,我们发现 FM
预训练初始化 feature embedding
会导致比随机初始化更低的 RMSE
。因此我们报告了这些模型在预训练初始化时的性能。
首先我们探讨了 dropout
对 Pair-wise Interaction Layer
的影响。我们将 L2
正则化。我们还通过移除 AFM
的注意力组件来验证我们实现的 FM
的 dropout
。下图显示了 AFM
和 FM
在不同 dropout rate
下的验证误差。我们也显示了 LibFM
的结果作为 benchmark
。可以看到:
通过将 dropout rate
设为合适的值,AFM
和 FM
都可以得到显著改善。具体而言,对于 AFM
,Frappe
和 MovieLens
的最佳 dropout rate
分别为 0.2
和 0.5
。这验证了 Pair-wise Interaction Layer
上 dropout
的有效性,从而提高了 FM
和 AFM
的泛化能力。
我们的 FM
实现提供了比 LibFM
更好的性能。原因有两个:
首先,LibFM
使用平凡的 SGD
进行优化,它对所有参数采用固定的学习率。而我们使用 Adagrad
优化 FM
,它根据每个参数的频率来调整每个参数的学习率(即,对高频特征的参数进行较小的更新、对低频特征的参数进行较大的更新)。
其次,LibFM
通过 L2
正则化防止过拟合,而我们采用 dropout
。由于 dropout
的模型平均效应,这可以更有效。
AFM
大大优于 FM
和 LibFM
。即使在不使用 dropout
并且在一定程度上存在过拟合问题的情况下,AFM
的性能也明显优于 LibFM
和 FM
的最佳性能(参考dropout rate = 0
的结果)。这证明了注意力网络在学习特征交互权重方面的好处。
然后我们研究注意力网络上的 L2
正则化是否有利于 AFM
。正如前面的实验所验证的那样,dropout rate
被设置为每个数据集的最佳值。下图为验证误差和
可以看到:
结论:当 AFM
得到了改善。注意,当 AFM
获得的最佳性能。这意味着仅在 Pair-wise Interaction Layer
上应用 dropout
不足以防止 AFM
过拟合。更重要的是,调优注意力网络可以进一步提高 AFM
的泛化能力。
注意:FM
和 LibFM
均没有注意力网络,因此它们在图上都是直线。
我们现在聚焦于分析注意力网络对 AFM
的影响。首先要回答的问题是如何选择一个合适的注意力因子?下图显示了 AFM
的验证误差和注意力因子的关系。注意,针对每个注意力因子已经独立地调优了 AFM
的性能在所有注意力因子上都相当稳定。具体而言,当注意力因子为 1
时,矩阵 AFM
仍然非常强大,并且比 FM
有显著改进。
这证明了 AFM
设计的合理性,该设计基于交互向量来估计特征交互的重要性score
,这是这项工作的关键发现。
下图比较了每个 epoch
的 AFM
和 FM
的训练误差和测试误差。我们观察到 AFM
比 FM
收敛得更快。
在 Frappe
上,AFM
的训练误差和测试误差都远低于 FM
,表明 AFM
可以更好地拟合数据并导致更准确的预测。
在 MovieLens
上,虽然 AFM
的训练误差略高于 FM
,但是较低的测试误差表明 AFM
对未见(unseen
)的数据的泛化能力更好。
除了提高性能之外,AFM
的另一个关键优势是通过每个特征交互的 attention score
来得到更好的可解释性。为了证明这一点,我们通过调研 MovieLens
上每个特征交互的得分来进行一些微观分析( micro-level analysis
)。
我们首先固定 simulate
了 FM
,因此我们记作 FM
。
然后我们固定 feature embedding
,仅训练注意力网络。该模型记作 FM + A
。模型收敛之后,性能提高了大约 3%
,这证明了注意力网络的有效性。
我们从测试集中随机选择 3
个正样本 ,下表中显示了每个特征交互的 attention score
(即 interaction score
(即 0.33*-1.81
表示 attention sore = 0.33
、interaction score = -1.81
(interaction score
可能为负值)。
可以看到:在所有三个交互中,item-tag
交互是最重要的。然而,FM
为所有交互分配相同的重要性得分,导致更大的预测误差。通过使用注意力网络增强FM
(参考 FM+A
那一行),item-tag
交互被分配了更高的重要性得分,从而减少了预测误差。
最后我们比较了不同方法在测试集上的性能。下表给出了每种方法在 embedding size = 256
上获得的最佳性能,以及每种方法的可训练参数数量。M
表示百万。
首先,我们看到 AFM
在所有方法中取得了最佳性能。具体而言:
AFM
通过少于 0.1M
的附加参数、以 8.6%
的相对提升优于 LibFM
。
AFM
以 4.3%
的相对提升优于第二好的方法 Wide & Deep
,同时模型参数少得多。
这证明了 AFM
的有效性。尽管 AFM
是一个浅层模型,但是它比深度学习方法实现了更好的性能。
其次,HOFM
优于 FM
,这归因于 HOFM
对高阶特征交互的建模。然而,略微的提升是基于参数数量几乎翻倍的相当昂贵的代价,因为 HOFM
使用一组独立的 embedding
来建模每个阶次的特征交互。
这指出了未来研究的一个有希望的方向:设计更有效的方法来捕获高阶特征交互。
最后,由于过拟合的严重问题,DeepCross
表现最差。我们发现,DeepCross
的 dropout
效果不佳,这可能是由于它使用了 batch normalization
造成的。考虑到 DeepCross
是所有比较方法中最深层的方法(在 embedding layer
之后叠加了 10
层),这表明更深层的学习并不总是有帮助的,因为更深的模型可能会过拟合并且在实践中更难优化。
AFM
和 NFM
的工作是正交的,其中 NFM
聚焦于建模高阶的、非线性的特征交互的 FM
的深层变体,而 AFM
将注意力机制引入 FM
。未来工作:
通过在 attention-based
池化层之后堆叠多个非线性层来探索 AFM
的深度版本,看看是否可以进一步提高性能。
由于 AFM
具有与非零特征数量呈二次关系的、相对较高的复杂度,因此我们将考虑提高学习效率,例如通过使用 learning to hash
、或者数据采样技术。
另一个有前途的方向是开发用于半监督学习和多视图学习的 FM
变体,例如通过结合广泛使用的图拉普拉斯算子,以及结合协同正则化(co-regularization
)的设计。
最后,我们探索 AFM
为不同 application
建模其它类型的数据,如用于问答的文本数据、以及语义更丰富的多媒体内容。