《Fi-GNN: Modeling Feature Interactions via Graph Neural Networks for CTR Prediction》
建模复杂的特征交互,对CTR
预测的成功起到了核心作用。FM
是一个著名的模型,它通过向量内积来建模二阶特征交互。FFM
进一步考虑了 field
信息并引入了 field-aware embedding
。然而,这些 FM-based
模型只能建模二阶交互。
最近,许多基于深度学习的模型被提出来从而学习高阶特征交互,这些模型遵循一个通用的范式:简单地拼接 field embedding
向量,并将其馈入 DNN
或其他专门设计的模型,从而学习交互。例如 FNN, NFM, Wide&Deep, DeepFM
等。然而,这些基于 DNN
的模型都是以 bit-wise
的、隐式的方式来学习高阶特征交互,这缺乏良好的模型解释。
一些模型试图通过引入专门设计的网络来显式地学习高阶交互。例如,Deep&Cross
引入了 Cross Network: CrossNet
,xDeepFM
引入了压缩交互网络(Compressed Interaction Network: CIN
)。尽管如此,它们仍然不够有效和显式,因为它们仍然遵循将 feature field
组合在一起的通用范式来建模交互。简单的 unstructured combination
将不可避免地限制了灵活地、显式地建模不同 feature field
之间复杂交互的能力。
在论文 《Fi-GNN: Modeling Feature Interactions via Graph Neural Networks for CTR Prediction》
中,作者考虑了 multi-field feature
的结构。具体来说,作者用一个名为 feature graph
的图结构来表示 multi-field feature
。直观而言,图中的每个节点对应于一个 feature field
,不同 field
可以通过边进行交互。因此,建模 feature field
之间复杂交互的任务可以转化为建模 feature graph
上的节点交互的任务。为此,作者在 Graph Neural Network: GNN
的基础上设计了一个新的 Feature interaction Graph Neural Network: Fi-GNN
,它能够以灵活的、显式的方式建模复杂的节点交互(即,特征交互)。在 Fi-GNN
中,节点将通过与邻居节点交流 node state
来进行交互,并以 recurrent
的方式更新自己。
AutoInt
用Transformer Encoder Block
建模multi-field feature
,而这里用GNN
来建模multi-field feature
。Transformer Encoder Block
可以视为一个简单的GNN
。
在每一个 time step
中,模型与邻居节点进行 one-hop
的交互。因此, interaction step
的数量等同于特征交互的阶次。此外,在 feature graph
中,边的权重反映了不同 feature interaction
对于 CTR
预测的重要性,而节点的权重反映了每个 feature field
对于 CTR
预测的重要性。这可以提供很好的解释。
总的来说,论文提出的模型能够以显式的、灵活的方式建模复杂的特征交互,并提供良好的可解释性。论文贡献如下:
论文指出了现有工作的局限性,即把 multi-field feature
视为 feature field
的 unstructured combination
。为此,作者首次提出用图结构来表示 multi-field feature
。
论文设计了一个新的模型 Feature interaction Graph Neural Network: Fi-GNN
,从而以更灵活的、显式的方式建模 graph-structured feature
上 feature field
之间的复杂交互。
论文在两个真实世界的数据集上进行的广泛实验表明:所提出的方法不仅可以超越 SOTA
的方法,而且可以提供良好的模型解释。
相关工作:
Feature Interaction in CTR Predict
:建模特征交互是 CTR
预测成功的关键,因此在文献中得到了广泛的研究。
LR
是一种线性方法,它只能对原始单个特征的线性组合建模一阶交互。
FM
通过向量内积来建模二阶特征交互。之后,FM
的不同变体也被提出:
Field-aware factorization machine: FFM
考虑了 field
信息并引入了 field-aware embedding
。
AFM
考虑了不同二阶特征交互的权重。
然而,这些方法只能建模二阶交互,这是不够的。
随着 DNN
在各个领域的成功,研究人员开始用它来学习高阶特征交互,因为它有更深的结构和非线性激活函数。一般的范式是将 field embedding
向量拼接在一起,并将其馈入 DNN
来学习高阶特征交互。
《A convolutional click prediction model》
利用卷积网络建模特征交互。
FNN
在应用 DNN
之前,在 field embedding
上使用预训练的 FM
。
PNN
通过在 field embedding layer
和 DNN layer
之间引入一个 product layer
来建模二阶交互和高阶交互。
类似地,NFM
通过在 field embedding layer
和 DNN layer
之间引入一个 Bi-Interaction Pooling layer
来建模二阶交互,但是随后的操作是 sum
操作,而不是像 PNN
中的拼接操作。
另一个方向上的一些工作试图通过混合架构来联合建模二阶交互和高阶交互:Wide&Deep
和 DeepFM
包含一个 wide
组件来建模低阶交互、一个 deep
组件来建模高阶交互。
然而,所有这些利用 DNN
的方法都是以隐式的、 bit-wise
的方式学习高阶特征交互,因此缺乏良好的模型解释能力。 最近,一些工作试图通过专门设计的网络以显式的方式学习特征交互:
Deep&Cross
引入了一个在 bit-level
上对特征进行外积的 CrossNet
。
相反,xDeepFM
引入了一个在 vector-level
对特征进行外积的 CIN
。
然而,他们仍然没有解决最根本的问题,即把 field embedding
向量拼接起来。
对 feature field
进行简单的 unstructured combination
将不可避免地限制了以灵活的、显式的方式建模不同 field
之间复杂交互的能力。为此,我们提出用图结构表示 multi-field feature
,每个节点代表一个 field
,不同的 feature field
可以通过边进行交互。因此,我们可以在图上建模不同 feature field
之间的灵活交互。
Graph Neural Network
:图是一种数据结构,它对一组对象(节点)和它们的关系(边)进行建模。早期的工作通常将图结构的数据转换成序列结构的数据来处理。
无监督的 DeepWalk
算法受 word2vec
的启发,用于学习基于 random walk
的 node embedding
。
之后,LINE
算法保留了图的一阶结构信息和二阶结构信息。
node2vec
引入了一个有偏的随机行走。
然而,这些方法的计算成本很高,而且对于大型图而言也不是最优的。图形神经网络(graph neural network: GNN
)就是为了解决这些问题而设计的,它是基于深度学习的方法,在 graph domain
上运行。现在已经有很多 GNN
的变种,这里我们只介绍一些有代表性的经典方法:
Gated Graph Neural Network: GGNN
使用 GRU
作为更新器。
Graph Convolutional Network: GCN
考虑了图的 spectral structure
并利用卷积聚合器。
GraphSAGE
考虑了空间信息,并引入了三种聚合器:mean aggregator, LSTM aggregator, Pooling aggregator
。
graph attention network: GAT
将注意力机制纳入消息传播步骤。
由于 GNN
具有令人信服的性能和较高的可解释性,GNN
已经成为一种广泛应用的图分析方法。在这项工作中,我们提出了一个基于 GGNN
的模型 Fi-GNN
来为 CTR
预测建模特征交互。
假设训练数据集由 field
的 categorical feature
、以及表示用户点击行为的 label
CTR
预测任务是对输入特征(包含 field
)来预测用户点击的概率
下图是我们所提出方法的概览(
输入的 sparse m-field feature vector
首先被映射成稀疏的 one-hot
向量,然后通过 embedding layer
和 multi-head self-attention layer
嵌入到稠密的 field embedding
向量中。
然后, field embedding
向量被表示为一个 feature graph
,其中每个节点对应于一个 feature field
,不同的 feature field
可以通过边进行交互。因此,建模交互的任务可以转换为建模 feature graph
上的节点交互。因此, feature graph
被馈入 Fi-GNN
从而建模节点交互。
最后,在 Fi-GNN
的输出上应用一个 Attentional Scoring Layer
来估计点击率
这里的
Multi-head Self-Attention Layer
就是单层的AutoInt
,因此,Fi-GNN
相当于是AutoInt
和GNN
的堆叠。实验并没有表明AutoInt
在这里的贡献,而且即使是AutoInt + Fi-GNN
,模型在所有数据集上的整体效果提升也不明显,因此论文价值不大。
Embedding Layer
:我们将每个 field
表示为一个 ont-hot encoding
向量,然后将其嵌入到一个稠密向量中,记做 field embedding
向量 。field
的 field embedding
向量被拼接为(沿着 feature field
维度拼接):
其中:field
embedding
向量,field embedding
向量的维度,feature field
维度拼接。
Multi-head Self-attention Layer
:我们利用多头自注意力机制来捕获不同语义子空间中的 pairwise
特征交互。
遵从 《AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks》
,给定 feature embedding
矩阵 feature representation
,它覆盖了 attention head
pairwise interaction
:
其中:attention head
head
然后,我们将学到的每个 head
的 feature representation
结合起来,以保留每个语义子空间中的 pairwise feature interaction
:
其中:embedding
维度),attention head
数量,embedding
维度。
Feature Graph
:与以往简单地将 field embedding
向量拼接在一起并将其馈入模型中从而学习特征交互所不同的是,我们用图结构来表示 feature field
。具体而言,我们将每个输入的 multi-field feature
表示为一个 feature graph
feature field
field
可以通过边进行交互,所以
每个样本对应一张图,因此这是一个
graph-level
分类任务(二分类)。
Feature Interaction Graph Neural Network
:Fi-GNN
旨在以一种灵活的、显式的方式建模 feature graph
上的节点交互。在Fi-GNN
中,每个节点 hidden state
其中 interaction step
。由多头自注意力层学到的 feature representation
作为图的初始状态
如下图所示,节点以循环方式进行交互并更新其状态。在每一个 interaction step
中,节点聚合邻居节点的状态信息(经过变换之后),然后根据聚合信息、以及节点历史状态通过 GRU
和残差连接来更新节点状态。
State Aggregation
:在 interaction step
其中:
显然,投影矩阵和邻接矩阵决定了节点之间的交互。由于每条边上的交互应该是不同的,我们的目标是建模边上的交互,这需要对每条边有一个 unique
的权重和投影矩阵。
基于注意力的边权重:为了推断不同节点之间交互的重要性,我们建议通过注意力机制来学习边权重。具体而言,从节点 field embedding
向量)来计算:
其中:embedding
维度)。利用 softmax
函数进行归一化,使不同节点的权重容易比较。
最终邻接矩阵为:
由于边的权重反映了不同交互的重要性,Fi-GNN
可以很好地解释输入样本的不同 feature field
之间的关系,这一点将在实验部分进一步讨论。
edge-wise
变换:如前所述,所有边上的固定的投影矩阵无法建模灵活的交互,对每个边进行 unique
的变换是必要的。然而,我们的图是完全图(complete graph
)(即,任意两个节点之间都存在边),因此包含大量的边。简单地给每条边分配一个 unique
的投影矩阵将消耗太多的参数空间和运行时间。为了减少时间和空间的复杂性,同时实现 edge-wise transformation
,我们为每个节点
因此聚合信息
这样一来,参数的数量与节点的数量成正比,而不是与边的数量成正比,这就大大降低了空间复杂性和时间复杂性,同时也实现了 edge-wise interaction
。
State Update
:聚合状态信息之后,节点将通过 GRU
和残差连接来更新状态向量。
通过 GRU
进行状态更新:根据传统的 GGNN
,节点 step
的状态通过 GRU
更新的:
通过残差连接进行状态更新:我们引入了额外的残差连接(来自初始状态),与 GRU
一起更新节点状态,这可以促进低阶特征重用和梯度反向传播:
注意,这里是
的残差连接,而不是 。
Attentional Scoring Layer
:经过 propagation step
之后,我们得到了 final node state
:
由于节点已经与它们的 Fi-GNN
建模了 graph-level output
来预测 CTR
。
我们分别对每个 field
的 final state
预测一个得分,并通过注意力机制将它们相加,这个注意力机制衡量它们对整体预测的影响。正式地,每个节点 attentional node weight
可以通过两个 MLP
分别得到:
整体预测是所有节点的预测的加权和:
训练:损失函数为 logloss
,即:
其中:label
;CTR
。
我们采用 RMSProp
优化器。此外,为了平衡正负样本比例,在训练过程中,对于每个 batch
我们随机选择相同数量的正样本和负样本。
数据集:Criteo, Avazu
。对于这两个数据集:
我们移除了低频特征,并将低频特征替换为 "<unknown>"
。频次阈值分别为:Criteo
数据集为 10
、Avazu
数据集为 5
。即出现频次低于该阈值则移除。
由于数值特征可能具有较大的方差,因此我们进行对数变换:
这是由 Criteo
竞赛的获胜者提出的。
数据集以 8:1:1
的比例随机拆分为训练集、验证集、测试集。
数据集的统计信息如下表所示。
评估指标:AUC, LogLoss, Relative Improvement (RI)
。
应该注意的是,对于真实世界的 CTR
任务来说,AUC
方面的微小改进被认为是显著的。为了估计我们的模型相对于 baseline
模型的相对改进,我们在此测量 RI-AUC
和 RI-Logloss
:
其中 X
为 AUC
或 LogLoss
。
baseline
方法:
LR
:通过原始特征的线性组合来建模一阶特征交互。
FM
:通过 field embedding
向量的内积来建模二阶特征交互。
AFM
:是 FM
的一个扩展,利用注意力机制考虑不同二阶特征交互的权重。
DeepCrossing
:利用具有残差连接的 DNN
以隐式的方式学习高阶特征交互。
NFM
:利用 Bi-Interaction Pooling layer
来建模二阶特征交互,然后将拼接的二阶组合特征馈入 DNN
来建模高阶特征交互。
CrossNet(Deep&Cross)
:是 Deep&Cross
模型的核心,它通过采用拼接的 feature vector
的外积,从而显式地在 bit-wise level
上建模特征交互。
CIN(xDeepFM)
:是 xDeepFM
模型的核心,它通过采用堆叠的 feature matrix
的外积,从而显式地在 vector-wise level
上建模特征交互。
实现细节:我们使用 Tensorflow
实现我们的方法。最优超参数由网格搜索策略确定。baseline
的实现遵循 《AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks》
。
对于所有方法, field embedding
向量的维度是 16
, batch size = 1024
。
DeepCrossing
有四个前馈层,每层有 100
个隐单元。
NFM
在 Bi-Interaction layer
之上有一个大小为 200
的隐层,如原始论文中所推荐的。
CrossNet
和 CIN
都有三个交互层。
所有实验都是在配备了 8
个 NVIDIA Titan X GPU
的服务器上进行的。
不同模型的性能比较,如下表所示。可以看到:
LR
在这些 baseline
中效果最差,这证明了单个特征在 CTR
预测中是不够的。
在所有数据集上,建模二阶特征交互的 FM
和 AFM
优于 LR
,这表明建模 feature field
之间的 pairwise
交互是有效的。此外,AFM
比 FM
具有更好的表现,证明了在不同交互上的注意力的有效性。
建模高阶特征交互的方法大多优于建模二阶特征交互的方法。这表明二阶特征交互是不够的。
DeepCrossing
优于 NFM
,证明了残差连接在 CTR
预测中的有效性。
在两个数据集上,Fi-GNN
在所有这些方法中取得了最好的性能,尤其是在 Criteo
数据集上。
Fi-GNN
在 Criteo
数据集上取得的相对改进,高于在 Avazu
数据集上取得的相对改进。这可能是因为 Criteo
数据集中有更多的 feature field
,可以更好地利用图结构的表达能力。
消融研究:我们提出的 Fi-GNN
模型是基于 GGNN
的,在此基础上我们主要做了两个改进:
通过 attentional edge weight
和 edge-wise transformation
实现 edge-wise node interaction
。
引入残差连接从而与 GRU
一起更新节点状态。
为了评估两种改进的有效性,我们对比了三个变体:
Fi-GNN(-E/R)
:同时没有上述两个改进的变体。
Fi-GNN(-E)
:没有 edge-wise interaction: E
的变体。即,用二元邻接矩阵、以及所有边上共享的投影矩阵。
Fi-GNN(-R)
:没有 residual connection: R
的变体。
对比结果如下图 (a)
所示。可以看到:
Fi-GNN(-E)
的性能相比完整的 Fi-GNN
大幅下降,这表明建模 edge-wise interaction
是至关重要的。
Fi-GNN(-E)
取得了比 Fi-GNN(-E/R)
更好的性能,证明了残差连接确实可以提供有用的信息。
完整的 Fi-GNN
优于三种变体,表明我们所做的两种改进,即残差连接和 edge-wise interaction
,可以联合提高性能。
在 Fi-GNN
中,我们采用两种方法来实现 edge-wise node interaction
:attentional edge weight: W
、edge-wise transformation: T
。为了进一步研究巨大的改进来自哪,我们比较了另外三个变体:
Fi-GNN(-W/T)
:即 Fi-GNN-(E)
。
Fi-GNN(-W)
:没有 attentional edge weight
。
Fi-GNN(-T)
:没有 edge-wise transformation
,即所有边上共享投影矩阵。
对比结果如下图 (b)
所示。可以看到:
Fi-GNN(-T)
和 Fi-GNN(-W)
都优于 Fi-GNN(-W/T)
,这证明了它们的有效性。
Fi-GNN(-W)
比 Fi-GNN(-T)
实现了更大的改进,这表明在建模 edge-wise interaction
方面, edge-wise transformation
比 attentional edge weight
更有效。这是非常合理的,因为投影矩阵应该比标量的 attentional edge weight
对 edge-wise interaction
有更强的影响。
超参数研究:
state
维度 32
(Avazu
数据集)、64
(Criteo
数据集)时性能最佳。这是合理的,因为 Criteo
数据集更复杂,需要更大的维度来保持足够的信息。
没有考虑
attention head
的影响?
interaction step
interaction step
等于特征交互的最高阶次。模型性能随着 2
(Avazu
数据集)、3
(Criteo
数据集)时性能最佳。这是合理的,因为 Avazu
数据集有 23
个 feature field
、Criteo
数据集有 39
个 feature field
。因此,Criteo
数据集需要更多的 interaction step
来使 field node
与feature graph
中的其他节点完全交互。
模型可解释性:我们在 feature graph
的边上和节点上都应用了注意力机制,分别得到了 attentional edge weight
和 attentional node weight
,可以从不同的角度给出解释。
Multi-head Self-attention Layer
捕获的pair-wise
交互是否也是可解释的?论文并没有说明这一点。
attentional edge weight
:attentional edge weight
反映了两个相连的 field node
之间交互的重要性,也反映了两个feature field
之间的关系。下图展示了 Avazu
数据集中所有样本的全局平均邻接矩阵的热力图,它可以在全局水平上反映不同 field
之间的关系。 由于有一些 field
是匿名的,我们只显示剩余的 13
个具有真实含义的 feature field
。
可以看到:
一些 feature field
倾向于与其他 field
有很强的关系,例如 site_category
和 site_id
。这是有意义的,因为两个 feature field
都对应于投放广告的网站。
hour
是另一个与其他 field
有密切关系的特征。这是合理的,因为Avazu
专注于移动场景,用户可以在一天的任何时间在线冲浪。上网时间对其他的广告特征有很大的影响。
另一方面,device_ip
和 device_id
似乎与其他 feature field
的关系较弱。这可能是因为它们几乎等同于 user id
,相对固定,不易受其他特征的影响。
attentional node weight
:attentional node weight
反映了 feature field
对整体预测分数的影响的重要性。下图显示了 global-level
和 case-level
的 attentional node weight
的热力图。左边的是 Avazu
数据集中所有样本的全局平均值,右边的是Avazu
数据集中随机选择的四个样本(预测分数分别为 [0.97, 0.12, 0.91, 0.99]
,标签分别为 [1, 0, 1, 1]
)。
在 global level
,我们可以看到 featuer field app_category
对点击行为的影响最大。这是合理的,因为 Avazu
专注于移动场景,而 app
是最重要的因素。
在 case level
,我们观察到,在大多数情况下,最终的点击行为主要取决于一个关键的 feature field
。