保姆级教程:在Google Colab上从零跑通TFT代码,预测未来一周交通流量

张开发
2026/4/12 9:17:05 15 分钟阅读

分享文章

保姆级教程:在Google Colab上从零跑通TFT代码,预测未来一周交通流量
从零实战在Google Colab上快速部署TFT模型预测交通流量第一次接触Temporal Fusion TransformerTFT时我被它处理复杂时间序列数据的能力震撼到了。这个模型不仅能预测未来多个时间点的数值还能告诉我们哪些因素对预测结果影响最大——这种透明性在商业决策中简直是无价之宝。本文将带你用Google Colab这个免费云平台基于公开的Traffic交通流量数据集从零开始构建一个能预测未来7天交通流量的TFT模型。整个过程就像搭积木一样简单即使你只有基础的Python和机器学习知识跟着步骤操作也能在1小时内看到预测结果。1. 环境准备与数据加载在Google Colab中运行TFT模型前我们需要配置好Python环境并安装必要的库。Colab已经预装了TensorFlow和PyTorch等主流框架但TFT需要一些额外的依赖项。打开Colab笔记本https://colab.research.google.com/新建一个Python 3笔记本然后执行以下安装命令!pip install tensorflow2.8.0 !pip install pytorch-forecasting0.9.2 !pip install pandas numpy matplotlib安装完成后导入基础库并检查GPU是否可用TFT训练很吃计算资源import torch print(GPU可用:, torch.cuda.is_available())接下来加载交通流量数据集。我们将使用PyTorch Forecasting提供的Traffic数据集它记录了旧金山湾区高速公路的每小时占用率from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet from pytorch_forecasting.data.examples import get_stallion_data data get_stallion_data() data[time_idx] data[date].dt.year * 1000 data[date].dt.dayofyear data[time_idx] - data[time_idx].min()这个数据集包含以下关键字段volume交通流量值目标变量date时间戳agency路段所属机构静态分类特征sku具体路段编号静态分类特征2. 数据预处理与特征工程原始数据需要转换成TFT能理解的格式。PyTorch Forecasting提供了TimeSeriesDataSet来标准化这个过程。我们先定义训练集和验证集的切割点max_encoder_length 168 # 使用过去7天168小时的数据 max_prediction_length 24 # 预测未来24小时的流量 training_cutoff data[time_idx].max() - max_prediction_length training TimeSeriesDataSet( data[lambda x: x.time_idx training_cutoff], time_idxtime_idx, targetvolume, group_ids[agency, sku], max_encoder_lengthmax_encoder_length, max_prediction_lengthmax_prediction_length, static_categoricals[agency, sku], time_varying_known_categoricals[], time_varying_known_reals[time_idx], time_varying_unknown_reals[volume], ) validation TimeSeriesDataSet.from_dataset(training, data, predictTrue)这里有几个关键参数需要注意max_encoder_length模型能看到的历史数据长度max_prediction_length要预测的未来时间步长static_categoricals不随时间变化的分类特征如路段编号time_varying_known_reals已知的未来特征如时间索引提示如果显存不足可以尝试减小batch_size或max_encoder_length的值接着创建数据加载器这是PyTorch训练模型的标准接口batch_size 64 train_dataloader training.to_dataloader(trainTrue, batch_sizebatch_size) val_dataloader validation.to_dataloader(trainFalse, batch_sizebatch_size)3. 构建TFT模型现在来到最激动人心的部分——定义TFT模型架构。PyTorch Forecasting已经实现了TFT的核心组件我们只需配置参数tft TemporalFusionTransformer.from_dataset( training, learning_rate0.03, hidden_size32, # 隐层大小 attention_head_size4, # 注意力头数 dropout0.1, hidden_continuous_size16, output_size7, # 预测7个分位数 lossQuantileLoss(), log_interval10, reduce_on_plateau_patience4, )关键参数解析hidden_sizeGRN和VSN网络的隐层维度attention_head_size自注意力机制的头数output_size输出分位数数量这里设置为7个loss使用分位数损失函数模型结构可视化可以帮助我们理解数据流向print(tft.summarize(full)) # 打印模型结构你会看到输出中包含以下几个核心模块VariableSelectionNetwork动态选择重要特征GatedResidualNetwork门控残差网络处理特征StaticCovariateEncoder编码静态特征TemporalSelfAttention时间自注意力机制4. 模型训练与验证配置好模型后就可以开始训练了。使用PyTorch Lightning的Trainer类可以简化训练流程from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor early_stop_callback EarlyStopping( monitorval_loss, patience10, verboseTrue, modemin ) lr_logger LearningRateMonitor() trainer Trainer( max_epochs50, gpus1, gradient_clip_val0.1, callbacks[lr_logger, early_stop_callback], ) trainer.fit( tft, train_dataloaderstrain_dataloader, val_dataloadersval_dataloader, )训练过程中会显示损失值变化。如果一切正常你应该能看到类似这样的输出Epoch 10: val_loss0.123, train_loss0.145 Epoch 20: val_loss0.098, train_loss0.102训练完成后保存模型权重以便后续使用torch.save(tft.state_dict(), tft_traffic_model.pt)5. 预测与结果可视化现在到了检验成果的时刻。我们从验证集中取一个batch进行预测import matplotlib.pyplot as plt raw_predictions, x tft.predict(val_dataloader, moderaw, return_xTrue)选择第一个样本展示预测结果tft.plot_prediction(x, raw_predictions, idx0, add_loss_to_titleTrue) plt.show()这张图会显示蓝色线历史真实值绿色线未来真实值如果验证集中包含红色线模型预测的中位数浅红色区域预测的不确定性范围10%-90%分位数如果想获取具体的预测数值可以这样提取predictions tft.predict(val_dataloader) print(predictions[0]) # 第一个样本的预测结果6. 模型解释与特征重要性TFT最强大的功能之一是它的可解释性。我们可以分析模型在预测时关注了哪些特征和时间点interpretation tft.interpret_output(raw_predictions, reductionsum) tft.plot_interpretation(interpretation)这会生成三张关键图表变量重要性显示哪些特征对预测影响最大注意力权重模型在不同时间点的关注度分位数分析各分位数预测的敏感度例如在交通流量预测中你可能会发现路段编号sku是最重要的静态特征早晚高峰时段获得更高的注意力权重周末的预测不确定性比工作日更大7. 性能优化技巧在Colab的免费资源限制下这里有几个提升TFT性能的实用技巧批量大小调整batch_size 32 # 显存不足时减小此值学习率调度tft TemporalFusionTransformer.from_dataset( training, learning_rate0.01, lr_schedulertorch.optim.lr_scheduler.ReduceLROnPlateau, lr_scheduler_params{patience: 3}, )特征工程增强添加小时、星期几作为已知时间特征对交通流量做对数变换处理长尾分布data[hour] data[date].dt.hour data[day_of_week] data[date].dt.dayofweek早停策略调整early_stop_callback EarlyStopping( monitorval_loss, patience7, # 更早停止防止过拟合 min_delta0.001, )8. 常见问题排查问题1CUDA内存不足解决方案减小batch_size或max_encoder_length检查命令nvidia-smi查看显存使用情况问题2预测结果全是NaN可能原因学习率过高导致梯度爆炸修复方法tft TemporalFusionTransformer.from_dataset( training, learning_rate0.001, # 降低学习率 gradient_clip_val0.5, # 添加梯度裁剪 )问题3验证损失不下降检查数据是否标准化from pytorch_forecasting.data import TorchNormalizer training TimeSeriesDataSet( ..., target_normalizerTorchNormalizer(methodstandard), )问题4注意力权重过于分散可能原因隐层维度太小调整方案tft TemporalFusionTransformer.from_dataset( training, hidden_size64, # 增加隐层维度 attention_head_size8, )在Colab上运行完整流程后我建议将笔记本保存到Google Drive并导出为PDF或Python文件备份。对于更复杂的预测任务可以尝试调整以下超参数组合参数推荐范围影响hidden_size32-128模型容量dropout0.1-0.3防止过拟合learning_rate0.01-0.001训练稳定性attention_head_size4-8注意力机制复杂度max_encoder_length24-168历史窗口大小

更多文章