Authors:
Yuankai Wu, Dingyi Zhuang, Aurelie Labbe, Lijun Sun
Abstract
Time series forecasting and spatiotemporal kriging are the two most important tasks in spatiotemporal data analysis. Recent research on graph neural networks has made substantial progress in time series forecasting, while little attention has been paid to the kriging problem—recovering signals for unsampled locations/sensors. Most existing scalable kriging methods (e.g., matrix/tensor completion) are transductive, and thus full retraining is required when we have a new sensor to interpolate. In this paper, we develop an Inductive Graph Neural Network Kriging (IGNNK) model to recover data for unsampled sensors on a network/graph structure. To generalize the effect of distance and reachability, we generate random subgraphs as samples and the corresponding adjacency matrix for each sample. By reconstructing all signals on each sampled subgraph, IGNNK can effectively learn the spatial message passing mechanism. Empirical results on several real-world spatiotemporal datasets demonstrate the effectiveness of our model. In addition, we also find that the learned model can be successfully transferred to the same type of kriging tasks on an unseen dataset. Our results show that: 1) GNN is an efficient and effective tool for spatial kriging; 2) inductive GNNs can be trained using dynamic adjacency matrices; 3) a trained model can be transferred to new graph structures; and 4) IGNNK can be used to generate virtual sensors.
This notebook demonstrates the performance of our work. Further details can be found in our GitHub repo.
from __future__ import division
import torch
import numpy as np
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt
from utils import load_metr_la_rdata, get_normalized_adj, get_Laplace, calculate_random_walk_matrix, test_error
import random
import pandas as pd
from basic_structure import D_GCN, C_GCN, K_GCN, IGNNK
import geopandas as gp
import matplotlib as mlt
Define the hyperparameters
n_o_n_m = 150   # number of nodes sampled per subgraph (observed + masked)
h = 24          # length of the sampled time window
z = 100         # hidden dimension of the graph convolution layers
K = 1           # with diffusion convolution, the actual number of diffusion steps is K+1
n_m = 50        # number of masked nodes during training
N_u = 50        # number of target locations; these N_u locations are removed from the training data
Max_episode = 750       # maximum number of training epochs
learning_rate = 0.0001  # learning rate for the Adam optimizer
E_maxvalue = 80         # empirical maximum value, used to normalize the inputs
batch_size = 4
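As a quick sanity check on how these settings relate (assuming the METR-LA training split used below, roughly 23,990 five-minute steps; the exact figure depends on the data loader):
n_o = n_o_n_m - n_m                              # 150 - 50 = 100 observed nodes per random subgraph sample
diffusion_steps = K + 1                          # each diffusion convolution layer uses 1 + 1 = 2 diffusion steps
batches_per_epoch = 23990 // (h * batch_size)    # about 249 training batches per epoch
print(n_o, diffusion_steps, batches_per_epoch)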
Build the IGNNK model
STmodel = IGNNK(h, z, K)  # the graph neural network
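For intuition, the sketch below shows a minimal, self-contained diffusion graph convolution stack in the spirit of the D_GCN layers that IGNNK(h, z, K) is built from. It is only an illustration of the message-passing idea, not the actual implementation, which lives in basic_structure.py and differs in its parameterization and details.
class TinyDiffusionGCN(nn.Module):
    """One diffusion graph convolution: concatenates K-step forward/backward
    random-walk propagations of the node features and mixes them linearly."""
    def __init__(self, in_dim, out_dim, k):
        super().__init__()
        self.k = k
        self.theta = nn.Linear((2 * k + 1) * in_dim, out_dim)

    def forward(self, x, a_q, a_h):
        # x: (batch, n_nodes, in_dim); a_q, a_h: (n_nodes, n_nodes) random-walk matrices
        outs = [x]
        x_q, x_h = x, x
        for _ in range(self.k):
            x_q = torch.einsum('ij,bjf->bif', a_q, x_q)  # forward diffusion step
            x_h = torch.einsum('ij,bjf->bif', a_h, x_h)  # backward diffusion step
            outs += [x_q, x_h]
        return self.theta(torch.cat(outs, dim=-1))

class TinyIGNNK(nn.Module):
    """Three stacked diffusion GCN layers mapping h time steps -> z hidden -> h time steps,
    mirroring the layer sizes used to construct IGNNK(h, z, K) above."""
    def __init__(self, h, z, k):
        super().__init__()
        self.gc1 = TinyDiffusionGCN(h, z, k)
        self.gc2 = TinyDiffusionGCN(z, z, k)
        self.gc3 = TinyDiffusionGCN(z, h, k)

    def forward(self, x, a_q, a_h):
        x = x.permute(0, 2, 1)       # (batch, h, n_nodes) -> (batch, n_nodes, h)
        x = torch.relu(self.gc1(x, a_q, a_h))
        x = torch.relu(self.gc2(x, a_q, a_h))
        x = self.gc3(x, a_q, a_h)
        return x.permute(0, 2, 1)    # back to (batch, h, n_nodes), like the training inputs below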
Load data
A, X = load_metr_la_rdata()
split_line1 = int(X.shape[2] * 0.7)
training_set = X[:, 0, :split_line1].transpose()
test_set = X[:, 0, split_line1:].transpose()  # split into training and test periods
rand = np.random.RandomState(0)  # fixed random seed (seed = 0) so the example is reproducible
unknow_set = rand.choice(list(range(0, X.shape[0])), N_u, replace=False)
unknow_set = set(unknow_set)
full_set = set(range(0, 207))
know_set = full_set - unknow_set
training_set_s = training_set[:, list(know_set)]  # keep only the known sensors in the training period
A_s = A[:, list(know_set)][list(know_set), :]     # observed adjacency matrix extracted from the full adjacency matrix;
# the adjacency matrix is based on pairwise distances, so we do not need to rebuild it for each batch:
# we simply index into it to obtain the dynamic adjacency matrix
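The training loop below converts each dynamic adjacency matrix into forward and backward random-walk transition matrices with calculate_random_walk_matrix (imported from utils). Assuming that function implements the standard row normalization D^{-1}W, a minimal numpy equivalent would look like the sketch below (for illustration only; all results here use the utils version).
def random_walk_matrix_sketch(adj):
    """Row-normalize an adjacency matrix into a random-walk transition matrix D^{-1} W;
    rows with zero degree are left as all-zero rows."""
    deg = adj.sum(axis=1)
    d_inv = np.zeros_like(deg, dtype=float)
    d_inv[deg > 0] = 1.0 / deg[deg > 0]
    return np.diag(d_inv) @ adj

# e.g. random_walk_matrix_sketch(A_s) should resemble calculate_random_walk_matrix(A_s)
# if utils uses the same D^{-1} W convention.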
Train the IGNNK model
criterion = nn.MSELoss()
optimizer = optim.Adam(STmodel.parameters(), lr=learning_rate)
RMSE_list = []
MAE_list = []
MAPE_list = []
for epoch in range(Max_episode):
    for i in range(training_set.shape[0] // (h * batch_size)):  # the number of batches per epoch is determined by the training time length
        t_random = np.random.randint(0, high=(training_set_s.shape[0] - h), size=batch_size, dtype='l')
        know_mask = set(random.sample(range(0, training_set_s.shape[1]), n_o_n_m))  # sample n_o + n_m nodes
        feed_batch = []
        for j in range(batch_size):
            feed_batch.append(training_set_s[t_random[j]:t_random[j] + h, :][:, list(know_mask)])  # generate batch_size time slices
        inputs = np.array(feed_batch)
        inputs_omask = np.ones(np.shape(inputs))
        inputs_omask[inputs == 0] = 0  # METR-LA contains irregular 0 values, which we treat as missing data;
                                       # for other datasets it is not necessary to mask 0 values
        missing_index = np.ones(inputs.shape)
        for j in range(batch_size):
            missing_mask = random.sample(range(0, n_o_n_m), n_m)  # masked locations
            missing_index[j, :, missing_mask] = 0
        Mf_inputs = inputs * inputs_omask * missing_index / E_maxvalue  # normalize the values by the empirical maximum
        Mf_inputs = torch.from_numpy(Mf_inputs.astype('float32'))
        mask = torch.from_numpy(inputs_omask.astype('float32'))  # reconstruction errors on the irregular 0s are not used for training
        A_dynamic = A_s[list(know_mask), :][:, list(know_mask)]  # obtain the dynamic adjacency matrix
        A_q = torch.from_numpy((calculate_random_walk_matrix(A_dynamic).T).astype('float32'))
        A_h = torch.from_numpy((calculate_random_walk_matrix(A_dynamic.T).T).astype('float32'))
        outputs = torch.from_numpy(inputs / E_maxvalue)  # the label
        optimizer.zero_grad()
        X_res = STmodel(Mf_inputs, A_q, A_h)  # obtain the reconstruction
        loss = criterion(X_res * mask, outputs * mask)
        loss.backward()  # propagate the errors backward
        optimizer.step()
    MAE_t, RMSE_t, MAPE_t, metr_ignnk_res = test_error(STmodel, unknow_set, test_set, A, E_maxvalue, True)
    RMSE_list.append(RMSE_t)
    MAE_list.append(MAE_t)
    MAPE_list.append(MAPE_t)
    if epoch % 50 == 0:
        print(epoch, MAE_t, RMSE_t, MAPE_t)
# torch.save(STmodel.state_dict(), 'model/IGNNK.pth')  # save the trained model
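For reference, utils.test_error is what produces the metrics above and the reconstruction metr_ignnk_res used in the plots below. A simplified sketch of that kind of evaluation, assuming non-overlapping windows of length h and the same zero-masking convention (the exact logic is in utils.py), is:
def simple_kriging_eval(model, unknow_set, test_data, A_full, E_max, h):
    """Approximate hold-out evaluation: hide the unknown sensors, reconstruct them
    window by window with the trained model, and report MAE/RMSE on those sensors."""
    unknow_idx = list(unknow_set)
    T, N = test_data.shape
    o_mask = np.ones(N)
    o_mask[unknow_idx] = 0                               # 1 = observed sensor, 0 = hidden sensor
    A_q = torch.from_numpy(calculate_random_walk_matrix(A_full).T.astype('float32'))
    A_h = torch.from_numpy(calculate_random_walk_matrix(A_full.T).T.astype('float32'))
    preds, trues = [], []
    for t0 in range(0, T - h + 1, h):
        window = test_data[t0:t0 + h, :]                 # (h, N)
        inp = window * o_mask / E_max                    # zero out the hidden sensors and normalize
        inp = torch.from_numpy(inp[np.newaxis].astype('float32'))
        with torch.no_grad():
            rec = model(inp, A_q, A_h).numpy()[0] * E_max
        preds.append(rec[:, unknow_idx])
        trues.append(window[:, unknow_idx])
    preds, trues = np.concatenate(preds), np.concatenate(trues)
    valid = trues != 0                                   # skip the irregular zeros in METR-LA
    mae = np.mean(np.abs(preds[valid] - trues[valid]))
    rmse = np.sqrt(np.mean((preds[valid] - trues[valid]) ** 2))
    return mae, rmse

# mae, rmse = simple_kriging_eval(STmodel, unknow_set, test_set, A, E_maxvalue, h)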
Draw the learning curve of the testing error
fig, ax = plt.subplots()
ax.plot(RMSE_list, label='RMSE_on_test_set', linewidth=3.5)
ax.set_xlabel('Training Batch (x249)', fontsize=20)
ax.set_ylabel('RMSE', fontsize=20)
ax.tick_params(axis="x", labelsize=14)
ax.tick_params(axis="y", labelsize=14)
ax.legend(fontsize=16)
plt.grid(True)
plt.tight_layout()
plt.savefig('fig/ignnk_learning_curve_metr-la.pdf')
Draw the spatial kriging results on METR-LA
url_census = 'data/metr/Census_Road_2010_shapefile/Census_Road_2010.shp'
meta_locations = pd.read_csv('data/metr/graph_sensor_locations.csv')
map_metr = gp.read_file(url_census, encoding="utf-8")
fig, axes = plt.subplots(2, 2, figsize=(10, 5))
lng_div = 0.01
lat_div = 0.01
crowd = [127, 160]  # indices of a crowded and an uncrowded time slice in the test period
ylbs = ['Crowded', 'Uncrowded']
for row in range(2):
    for col in range(2):
        ax = axes[row, col]
        map_metr.plot(ax=ax, color='black')
        ax.set_xlim((np.min(meta_locations['longitude']) - lng_div, np.max(meta_locations['longitude']) + lng_div))
        ax.set_ylim((np.min(meta_locations['latitude']) - lat_div, np.max(meta_locations['latitude']) + lat_div))
        ax.set_xticks([])
        ax.set_yticks([])
        if col == 0:
            # left column: ground truth for both known and unknown sensors
            cax = ax.scatter(meta_locations['longitude'][list(know_set)], meta_locations['latitude'][list(know_set)], s=100, cmap=plt.cm.RdYlGn, c=test_set[crowd[row], list(know_set)],
                             norm=mlt.colors.Normalize(vmin=X.min(), vmax=X.max()), alpha=0.6, label='Known nodes')
            cax2 = ax.scatter(meta_locations['longitude'][list(unknow_set)], meta_locations['latitude'][list(unknow_set)], s=250, cmap=plt.cm.RdYlGn, c=test_set[crowd[row], list(unknow_set)],
                              norm=mlt.colors.Normalize(vmin=X.min(), vmax=X.max()), alpha=1, marker='*', label='Unknown nodes')
            ax.set_ylabel(ylbs[row], fontsize=20)
            if row == 0:
                ax.set_title('True', fontsize=18)
        else:
            # right column: IGNNK reconstructions for the unknown sensors
            ax.scatter(meta_locations['longitude'][list(know_set)], meta_locations['latitude'][list(know_set)], s=100, cmap=plt.cm.RdYlGn, c=test_set[crowd[row], list(know_set)],
                       norm=mlt.colors.Normalize(vmin=X.min(), vmax=X.max()), alpha=0.6)
            ax.scatter(meta_locations['longitude'][list(unknow_set)], meta_locations['latitude'][list(unknow_set)], s=250, cmap=plt.cm.RdYlGn, c=metr_ignnk_res[crowd[row], list(unknow_set)],
                       norm=mlt.colors.Normalize(vmin=X.min(), vmax=X.max()), alpha=1, marker='*')
            if row == 0:
                ax.set_title('IGNNK', fontsize=18)
fig.tight_layout()
fig.subplots_adjust(right=0.9, hspace=0, wspace=0, bottom=0, top=1)
left, bottom, width, height = 0.92, 0.03, 0.015, 0.8  # colorbar position (renamed to avoid shadowing the hyperparameter h)
rect = [left, bottom, width, height]
cbar_ax = fig.add_axes(rect)
cbar = fig.colorbar(cax, cax=cbar_ax)
cbar.ax.tick_params(labelsize=16)
plt.figlegend(handles=(cax, cax2), labels=('Known nodes', 'Unknown nodes'), bbox_to_anchor=(1.2, 1), loc=1, borderaxespad=0., fontsize=16)
plt.savefig('fig/metr_ignnk_spatial_crowd{:}_uncrowd{:}.pdf'.format(crowd[0], crowd[1]))
plt.show()
Draw the temporal kriging results on METR-LA
fig, ax = plt.subplots(figsize=(16, 5))
s = int(6400 - 64)              # start index within the test period
e = int(s + 24 * 60 / 5 + 1)    # one day of 5-minute intervals
station = list(unknow_set)[24]  # pick one of the unsampled sensors
ax.plot(test_set[s:e, station], label='True', linewidth=3)
ax.plot(metr_ignnk_res[s:e, station], label='IGNNK', linewidth=3)
ax.set_ylabel('mile/h', fontsize=20)
ax.tick_params(axis="x", labelsize=14)
ax.tick_params(axis="y", labelsize=14)
ax.set_xticks(range(0, 350, 50))
ax.set_xticklabels(['0:00 \n Mar 3rd', '4:00', '8:00', '12:00', '16:00', '20:00', '0:00 \n Mar 4th'])
ax.legend(bbox_to_anchor=(1, 1), loc=0, borderaxespad=0, fontsize=16)
plt.tight_layout()
plt.savefig('fig/metr_ignnk_temporal.pdf')
plt.show()