新智元报道
编辑:LRT
华南理工大学和香港大学的研究人员在 ICML 2024 上提出了一个简单而通用的时空提示调整框架 FlashST,通过轻量级的时空提示网络和分布映射机制,使预训练模型能够适应不同的下游数据集特征,显著提高了模型在多种交通预测场景中的泛化能力。
交通预测的目标是准确预测和分析城市未来的交通模式,这一过程需要同时考虑时间和空间因素。
然而,分布偏移的存在在这一领域构成了一个重大挑战,因为现有模型在面对与训练分布显著不同的测试数据时,往往难以很好地泛化。
为了解决这个问题,华南理工大学、香港大学的研究人员提出了一个简单而通用的时空提示调整框架 FlashST,能够使预训练模型适应于不同下游数据集的特定特征,提高了其在多种预测场景中的泛化能力。
代码地址: https://github.com/HKUDS/FlashST
论文地址: https://arxiv.org/abs/2405.17898
具体来说,所提出的的 FlashST 框架采用了一个轻量级的时空提示网络进行上下文学习,捕捉时空不变知识,并有效地适应不同场景。
此外,文中还引入了一个分布映射机制,对齐预训练和下游数据的数据分布,促进时空预测中有效的知识转移。实验表明 FlashST 在不同类型城市交通数据集中的有效性。
概述
现有挑战
尽管现有时空预测方法已显示出其有效性,但大多数时空预测模型在面对不同下游数据集和任务中的分布变化时往往难以有效泛化。
其中,训练数据与测试数据之间分布不一致的假设成为了真实城市场景中准确预测的障碍。如图 1 所示,直接将从数据集A上学到的参数应用于数据集B的测试,可能因不同数据分布间的时空特征显著变化而导致性能不佳。
因此,有必要通过有效适应这种分布变化来增强时空预测模型的泛化能力,设计适应性方法存在以下难点:
(1)时空上下文信息有效提取: 有效地从下游任务中提取特定的复杂时空上下文信息是至关重要的。然而,赋予预训练模型快速理解并整合仅在测试期间可访问的新领域数据的空间和时间特性的能力是一个巨大的挑战。
(2)弥合训练和测试数据的分布差距: 训练和测试数据集之间经常存在显著的分布差距,尤其是当它们来自不同的时空场景和领域时。通过使模型适应框架能够有效地弥合分布差距并捕获时空不变特征,从而增强模型适应性是至关重要的。
图1:FlashST 背后的动机:左图展示了不同交通数据集中数据分布的多样性,而右图显示了端到端模型的参数对训练集A过度拟合,未能泛化到测试集B
本文贡献
(1)本文提出了一个时空上下文提取机制来解决挑战1,该机制能够捕获来自未见数据的上下文信号,有助于适应多种时空场景。
(2)本文引入了一个统一的分布映射机制来增强 FlashST 框架,该机制弥合了预训练与下游任务之间的分布差距。通过正则化提示嵌入来对齐数据分布,促进从预训练到下游时空预测任务的有效知识转移。
方法
图2: FlashST 整体框架
时空上下文学习
时空上下文学习框架通过一个时空提示网络实现,该网络包含两个主要组成部分:
(1)时空上下文提取机制:高效捕捉感知时间和位置的未见数据中的上下文信号。
通过这种方式,它使模型能够从数据的特定上下文中学习,有助于有效地适应各种时空场景。
(2)时空依赖性建模:将时间和地点之间的复杂关系纳入到上下文时空网络中。通过捕捉和建模这些依赖性,网络能够有效地理解不同时空元素之间的相互依赖和交互。
时空上下文蒸馏
(1)时空数据映射。 本文采用Z-Score 和线性层初始化时空表征,其中线性层用于对时间维度特征进行转换。时空初始化表征表示为 Er, f ,其中r和f分别表示第r个区域的f个特征。
(2)时间上下文整合。 为了从多样化的城市数据中捕获动态和周期性的时空模式,我们在我们的提示网络中引入了时间感知的上下文。这种上下文基于多分辨率时间特征,具体包括一天中的某个时刻z^(d), 和一周中的某一天z^(w)。
时间上下文信号提取过程如下:
(3)空间上下文整合。 为了增强提示网络与区域属性相关的地理上下文信息,我们将城市道路网络结构作为反映空间上下文的编码特征。这一过程首先是构建标准化的拉普拉斯矩阵:
其中I、D 和 A 分别代表单位矩阵、度矩阵和邻接矩阵。邻接矩阵是通过考虑区域之间的距离和道路结构来计算的。
由于拉普拉斯特征向量有效地在欧几里得空间中保留全局图结构信息,我们执行特征值分解以得到△=UΛU^T。提取出特征值矩阵Λ和相应的特征向量矩阵U后,通过将U投影以获得 dr 个最小的非平凡特征向量,得出结构感知的节点属性
由于C在训练集和测试集特征空间中的潜在差异,使用 MLP 来映射这些特征,以增强网络对空间上下文的泛化能力。
随后,我们使用拼接操作整合上述嵌入以获得初始的时空嵌入:
时空依赖建模
(1)时间依赖编码器。 为捕获不同时间段间的依赖性并保留时间演变的数据模式,我们引入了一个轻量级的门控机制,如下:
其中和为可训练参数。
编码后的时间嵌入表示为 。
嵌入包含了关于时间动态和区域特征的重要上下文信息,这对于时间依赖编码器至关重要。这些丰富的信息使我们的上下文学习能够有效地识别不同区域和时间间隔中时空模式的变化,从而有助于精确建模时间相关性。
(2)空间依赖编码器。 我们采用基于图卷积网络的消息传递来编码区域间关联,形式化如下:
其中A表示邻接矩阵,表示可训练参数。
残差网络用于减轻多层图神经网络(GNN)可能导致的过度平滑现象。通过堆叠多层时空编码器,提示网络生成了富有时空语义的表征 Epro
统一分布映射机制
为了弥合预训练和下游任务中多样化未见数据之间的分布差距,我们通过加入分布映射机制来增强 FlashST。此机制的目标是将预训练数据和下游数据都转换到一个共享的分布空间中。这种数据分布的对齐使得知识能够无缝转移,确保从预训练阶段获得的知识能够有效地应用于下游的时空上下文。
为实现这一目标,FlashST 采用标准化提示嵌入以确保在多样的下游数据集中保持一致的分布。
我们从对比学习的多项工作中汲取灵感,通过引入基于 infoNCE 的损失函数来规范提示网络的表示生成。
该损失函数旨在使正样本对的表示更加接近,同时将负样本对的表示推开。通过利用无需额外标注的自监督学习,优化 infoNCE 损失有助于实现更均匀的嵌入分布。
相关研究表明,仅通过这一损失,几乎可以实现完全均匀的分布。在此基础上,我们使用 infoNCE 损失来调整学习到的时空提示嵌入的分布:
其中,余弦相似度函数 cos(⋅) 用于衡量嵌入之间的相似度,温度系数τ 用于调整 softmax 的比率。
FlashST 通过增加不同区域对应嵌入之间的分离程度来增强提示嵌入的均匀性。这种改进使得下游模型能够有效地利用提供的提示,以便在新数据和任务中快速泛化。
预训练和下游任务提示范式
在预训练阶段,我们使用专用的预训练数据集来训练和优化所有参数。在提示微调阶段,我们通过在未见过的数据集上进行有限的训练周期来专门更新提示网络的参数。这使得下游模型能够快速地适应新数据。
所提出的 FlashST 框架与模型无关,允许与各种现有的时空预测基线作为下游模型无缝集成。
(1)预训练阶段: 我们的目标是基于预训练数据A的历史时空记录预测未来趋势,并同时更新提示网络和下游模型的参数,过程形式化如下:
(2)提示微调阶段: 我们冻结了下游模型的参数并在测试数据集B中主要微调提示网络,如下:
实验
总体表现
对比实验
对比实验结果如下表,结果显示对比端到端时空模型,所提出的方法在多样化城市数据预测场景中展示出了的显著优势。这些发现有力地证明了 FlashST 在准确捕捉城市数据中存在的复杂时空不变模式方面的有效性。
所提出的上下文学习范式在将这些获得的知识转移到适应新的下游任务方面表现出色。通过有效处理分布差距,FlashST 弥合了预训练模型与实际遇到的特定预测场景之间的语义差距。
表1:FlashST 对比实验
模型无关&模型微调
(1)模型无关优势。
所提出模型的一个显著优点是模型无关,即其能够与各种现有时空基础编码器无缝集成,提供灵活性并避免了特定模型选择的限制。
下表展示了所提出 FlashST 方法与四种最先进的时空模型(即 STGCN、GWN、MTGNN、PDFormer)的轻松适配。评估结果突出了 FlashST 的多功能性,展示了其与出色时空模型结合时的卓越性能提升。与最先进模型的成功整合进一步增强了 FlashST 的适应性和在多样化城市数据场景中提高预测准确性的能力。
(2)与模型微调的比较。
为了进一步展示框架的有效性,我们将提出的提示微调方法与全参数微调进行了比较。"w/o Finetune" 方法指在预训练后直接对目标数据集进行预测,而不进行任何微调。"w/ Finetune" 表示在预训练后使用全参数微调来适应目标数据。
然而,值得注意的是,与端到端预测结果相比,直接全参数微调的结果表明其可能未能从预训练过程中受益。在没有有效对齐预训练模型与下游任务的情况下,可能引入噪声,导致误导性的微调和次优的性能。
表2:模型无关&模型微调实验
模型效率评估
(1)训练时间。
本节通过测量了三种不同场景的训练时间:端到端训练、完全微调和 FlashST 评估模型效率,如下表所示。对于端到端训练和完全微调,我们遵循现有基线的设置,将训练周期配置为 100,提前停止标准设置为 25 个轮次。
FlashST 提示调整的周期数限制为 20,用于证明下游模型对新数据集的快速适应。结果表明,相同的基线模型在端到端训练和全参数微调的效率是相似的。这两种设置之间训练时间的差异主要源于不同初始化参数导致的收敛速度变化。
FlashST 框架显著提高了计算效率,它将基线模型的训练时间减少了 20% 到 80%,这显著提高了它们适应新时空数据的效率。
表3:不同模型计算时间统计(秒)
(2)更快地收敛速度。
本节对 FlashST 在不同数据集上的收敛速度进行了调查。下图显示了在使用 PEMS07(M) 和 CA-D5 数据集时,采用 MTGNN 作为下游模型的验证误差下降趋势。
结果表明,通过整合 FlashST 方法,下游模型在几个调整周期内就能实现收敛。相比之下,端到端训练和微调范式需要更多的训练轮次来适应新数据。这一现象可以归因于我们提出的时空提示网络和数据映射策略的有效性。这些组件使得模型能够结合预训练知识,利用新数据的时空特征,从而快速适应多样的时空场景。
图3:FlashST 收敛速度
消融实验
(1)时空上下文蒸馏的效用: 我们分别移除时间上下文信息(-TC)和空间上下文信息(-SC)。结果显示,当去除时空上下文后大多数指标的性能显著下降。这突出了在上下文学习过程中保留时间和空间上下文的关键重要性。有效地编码时间信息和整合空间信息对于捕捉时空不变模式和增强模型对数据的理解至关重要。
(2)时空依赖建模的作用: 我们单独删除了时间编码器(-TE)和空间编码器(-SE)。结果表明,时空依赖编码在上下文学习过程中有效整合不同时间段和地点之间复杂关系中发挥了至关重要的作用。包含时间和空间依赖编码器使得模型能够理解并利用时间与空间之间的复杂交互。这种能力有助于下游模型更快地适应新的时空场景。
(3)统一分布映射机制的影响: 我们从两个方面评估了统一分布映射策略的实用性。
i)-Uni,去除统一分布映射策略。性能的下降表明了该策略对模型的积极影响。通过将不同的时空数据嵌入映射到一个统一的分布中,FlashST 有效地减轻了预训练数据与未见时空数据之间分布差异的影响。
ii)r/BN。统一分布映射策略被批归一化替换。批归一化根据小批量的局部统计特性标准化数据,缓解了神经网络训练中的内部协变量偏移问题,从而提高了收敛效率。
然而,由于缺乏预训练数据与下游任务数据之间确立的联系,下游模型难以有效地从预训练过程中转移知识。所提出的策略确保模型能够有效利用预训练阶段获得的知识。通过对齐不同数据源的分布,模型能够更好地适应新的时空场景并做出更准确的预测。
图4:FlashST 消融实验
超参分析
本节研究了不同超参数设置的影响,特别是不同的温度系数和损失权重系数对模型性能的影响。我们的发现表明,当参数配置为τ=0.3 和λ=1.0 时,模型达到了最佳性能。
值得注意的是,这些参数的变化对最终结果的影响很小,突出了模型对不同参数设置的有效适应性。即使特征尺度存在差异,模型也能学习到区分不同区域中嵌入特征的高效表示。
此外,模型的性能不会随着统一性损失增加而产生较大波动。这表明我们分布映射策略并不干扰预测损失。这进一步支持了我们策略的可行性,并促进了下游模型对新的时空环境的快速泛化。
图5:关于τ和λ的模型超参实验
案例研究
为了评估我们提出的统一分布映射方法在将各种数据表示转换为统一分布的有效性,我们对使用和未使用分布映射机制的提示嵌入进行了可视化。
我们首先采用 PCA 技术将每个嵌入样本的维度降至二维,随后使用 L2 范数将降维后的嵌入投影到单位圆上,如下图所示。可视化结果为统一分布映射策略有效地将提示嵌入转换成近似的均匀分布提供了有力证据。
相比之下,缺乏这一策略的变体未能实现这种理想的分布属性。通过将新的时空环境中的数据转换为一致的分布,FlashST 获得了利用预训练知识并迅速适应新数据集的能力,从而促进了其在各种交通任务上的表现。
图6:提示嵌入的分布可视化。
总结与展望
本文介绍了 FlashST,用于将时空预测模型适应于未见过数据的各种下游任务。所提出的上下文学习框架利用了一个时空提示网络,该网络包括了时空上下文提炼机制和时空依赖性建模方案。
框架通过捕捉上下文信号和建模时间及地点之间的复杂关系,有效地适应不同的时空场景。
为了解决分布差异问题,我们通过整合一个分布映射机制来增强 FlashST,该机制对齐了预训练数据和下游数据的数据分布,促进了时空预测中有效的知识转移。
广泛的实验表明,我们的 FlashST 在多种下游时空预测场景中的有效性和泛化能力。未来的研究方向之一可能是探索在 FlashST 框架中整合大型语言模型作为知识指导的潜力。
参考资料: