Commit 3f1a2e18 by xlwang

Add dual_random_walk filter type.

DCGRUCell now takes a filter_type argument ("laplacian", "random_walk", or
"dual_random_walk"; any other value falls back to the scaled Laplacian) and
builds one diffusion support per filter. The option is threaded through
DCRNNEncoder, DCGRUDecoder, and DCRNNModel, and DiffusionGraphConv now sizes
its weights for multiple supports (num_matrices = len(supports) * max_diffusion_step + 1).

parent 49c8f142
@@ -13,7 +13,8 @@
             "num_rnn_layers": 2,
             "rnn_units": 64,
             "seq_len": 12,
-            "output_dim": 1
+            "output_dim": 1,
+            "filter_type": "dual_random_walk"
         }
     },
     "dataloader": {
......
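Note: per the DCGRUCell constructor further down, filter_type accepts "laplacian" (scaled Laplacian), "random_walk" (forward random-walk matrix, transposed), or "dual_random_walk" (forward and reverse random-walk matrices); any other value falls back to the scaled Laplacian.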
@@ -13,12 +13,12 @@ from base import BaseModel
 class DiffusionGraphConv(BaseModel):
     def __init__(self, supports, input_dim, hid_dim, num_nodes, max_diffusion_step, output_dim, bias_start=0.0):
         super(DiffusionGraphConv, self).__init__()
-        num_matrices = max_diffusion_step + 1
+        self.num_matrices = len(supports) * max_diffusion_step + 1  # Don't forget to add for x itself.
         input_size = input_dim + hid_dim
         self._num_nodes = num_nodes
         self._max_diffusion_step = max_diffusion_step
         self._supports = supports
-        self.weight = nn.Parameter(torch.FloatTensor(size=(input_size*num_matrices, output_dim)))
+        self.weight = nn.Parameter(torch.FloatTensor(size=(input_size*self.num_matrices, output_dim)))
         self.biases = nn.Parameter(torch.FloatTensor(size=(output_dim,)))
         nn.init.xavier_normal_(self.weight.data, gain=1.414)
         nn.init.constant_(self.biases.data, val=bias_start)
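Why len(supports) enters the count: each support contributes max_diffusion_step diffusion terms, and the +1 is the single identity term for x itself, shared across supports. For example, with filter_type "dual_random_walk" (two supports) and max_diffusion_step=2, num_matrices = 2 * 2 + 1 = 5, so self.weight has shape (input_size * 5, output_dim).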
@@ -54,17 +54,17 @@ class DiffusionGraphConv(BaseModel):
         if self._max_diffusion_step == 0:
             pass
         else:
-            x1 = torch.sparse.mm(self._supports, x0)
-            x = self._concat(x, x1)
-            for k in range(2, self._max_diffusion_step + 1):
-                x2 = 2 * torch.sparse.mm(self._supports, x1) - x0
-                x = self._concat(x, x2)
-                x1, x0 = x2, x1
-        num_matrices = self._max_diffusion_step + 1  # Adds for x itself.
-        x = torch.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size])
+            for support in self._supports:
+                x1 = torch.sparse.mm(support, x0)
+                x = self._concat(x, x1)
+                for k in range(2, self._max_diffusion_step + 1):
+                    x2 = 2 * torch.sparse.mm(support, x1) - x0
+                    x = self._concat(x, x2)
+                    x1, x0 = x2, x1
+        x = torch.reshape(x, shape=[self.num_matrices, self._num_nodes, input_size, batch_size])
         x = torch.transpose(x, dim0=0, dim1=3)  # (batch_size, num_nodes, input_size, order)
-        x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])
+        x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * self.num_matrices])
         x = torch.matmul(x, self.weight)  # (batch_size * self._num_nodes, output_size)
         x = torch.add(x, self.biases)
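For reference, the step the new loop runs per support is the Chebyshev-style K-step recurrence below. This is a minimal standalone sketch with toy shapes (illustrative names, an identity stand-in for the transition matrix), not repo code:

import torch

# One support T, collecting [x, T x, 2 T (T x) - x, ...] on top of the
# k = 0 identity term, exactly the recurrence in the loop above.
N, F, K = 4, 3, 2                         # nodes, features, max_diffusion_step
T = torch.eye(N).to_sparse()              # stand-in for a transition matrix
x0 = torch.randn(N, F)
terms = [x0]                              # the "x itself" term
x1 = torch.sparse.mm(T, x0)
terms.append(x1)
for _ in range(2, K + 1):
    x2 = 2 * torch.sparse.mm(T, x1) - x0  # Chebyshev-style recurrence
    terms.append(x2)
    x1, x0 = x2, x1
stacked = torch.stack(terms)              # (K + 1, N, F) for this support
print(stacked.shape)                      # torch.Size([3, 4, 3])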
@@ -77,7 +77,7 @@ class DCGRUCell(BaseModel):
     Graph Convolution Gated Recurrent Unit Cell.
     """
     def __init__(self, input_dim, num_units, adj_mat, max_diffusion_step, num_nodes,
-                 num_proj=None, activation=torch.tanh, use_gc_for_ru=True):
+                 num_proj=None, activation=torch.tanh, use_gc_for_ru=True, filter_type='laplacian'):
         """
         :param num_units: the hidden dim of rnn
         :param adj_mat: the (weighted) adjacency matrix of the graph, in numpy ndarray form
@@ -94,13 +94,22 @@
         self._max_diffusion_step = max_diffusion_step
         self._num_proj = num_proj
         self._use_gc_for_ru = use_gc_for_ru
-        supports = utils.calculate_scaled_laplacian(adj_mat, lambda_max=None)  # scipy coo matrix
-        self._supports = self._build_sparse_matrix(supports).cuda()  # to pytorch sparse tensor
-        # self.register_parameter('weight', None)
-        # self.register_parameter('biases', None)
-        # temp_inputs = torch.FloatTensor(torch.rand((batch_size, num_nodes * input_dim)))
-        # temp_state = torch.FloatTensor(torch.rand((batch_size, num_nodes * num_units)))
-        # self.forward(temp_inputs, temp_state)
+        self._supports = []
+        supports = []
+        if filter_type == "laplacian":
+            supports.append(utils.calculate_scaled_laplacian(adj_mat, lambda_max=None))
+        elif filter_type == "random_walk":
+            supports.append(utils.calculate_random_walk_matrix(adj_mat).T)
+        elif filter_type == "dual_random_walk":
+            supports.append(utils.calculate_random_walk_matrix(adj_mat))
+            supports.append(utils.calculate_random_walk_matrix(adj_mat.T))
+        else:
+            supports.append(utils.calculate_scaled_laplacian(adj_mat))
+        for support in supports:
+            self._supports.append(self._build_sparse_matrix(support).cuda())  # to PyTorch sparse tensor
+        # supports = utils.calculate_scaled_laplacian(adj_mat, lambda_max=None)  # scipy coo matrix
+        # self._supports = self._build_sparse_matrix(supports).cuda()  # to pytorch sparse tensor
         self.dconv_gate = DiffusionGraphConv(supports=self._supports, input_dim=input_dim,
                                              hid_dim=num_units, num_nodes=num_nodes,
                                              max_diffusion_step=max_diffusion_step,
......
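The helpers called in the DCGRUCell branches above live in utils, which this diff does not touch. As a reference point, here is a minimal sketch of what calculate_random_walk_matrix presumably computes, the row-normalized transition matrix D^-1 W, following the reference DCRNN implementation (names and details assumed, not taken from this repo):

import numpy as np
import scipy.sparse as sp

def calculate_random_walk_matrix(adj_mx):
    """Random-walk transition matrix D^-1 W, with D the diagonal
    out-degree matrix (sketch; this repo's utils are not in the diff)."""
    adj_mx = sp.coo_matrix(adj_mx)
    d = np.array(adj_mx.sum(1))
    d_inv = np.power(d, -1).flatten()
    d_inv[np.isinf(d_inv)] = 0.          # zero out 1/0 for isolated nodes
    d_mat_inv = sp.diags(d_inv)
    return d_mat_inv.dot(adj_mx).tocoo()

Under this reading, "dual_random_walk" appends both calculate_random_walk_matrix(adj_mat) and calculate_random_walk_matrix(adj_mat.T), i.e. forward and reverse diffusion (normalized by out- and in-degree respectively), the bidirectional random walk of the DCRNN paper (Li et al., ICLR 2018).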
@@ -10,7 +10,8 @@ from base import BaseModel
 class DCRNNEncoder(BaseModel):
-    def __init__(self, input_dim, adj_mat, max_diffusion_step, hid_dim, num_nodes, num_rnn_layers):
+    def __init__(self, input_dim, adj_mat, max_diffusion_step, hid_dim, num_nodes,
+                 num_rnn_layers, filter_type):
         super(DCRNNEncoder, self).__init__()
         self.hid_dim = hid_dim
         self._num_rnn_layers = num_rnn_layers
@@ -19,21 +20,23 @@ class DCRNNEncoder(BaseModel):
         encoding_cells = list()
         # the first layer has different input_dim
         encoding_cells.append(DCGRUCell(input_dim=input_dim, num_units=hid_dim, adj_mat=adj_mat,
-                                        max_diffusion_step=max_diffusion_step, num_nodes=num_nodes))
+                                        max_diffusion_step=max_diffusion_step,
+                                        num_nodes=num_nodes, filter_type=filter_type))
         # construct multi-layer rnn
         for _ in range(1, num_rnn_layers):
             encoding_cells.append(DCGRUCell(input_dim=hid_dim, num_units=hid_dim, adj_mat=adj_mat,
-                                            max_diffusion_step=max_diffusion_step, num_nodes=num_nodes))
+                                            max_diffusion_step=max_diffusion_step,
+                                            num_nodes=num_nodes, filter_type=filter_type))
         self.encoding_cells = nn.ModuleList(encoding_cells)

     def forward(self, inputs, initial_hidden_state):
-        # inputs shape is (seq_length, batch, num_nodes, input_dim) (12, 50, 207, 2)
+        # inputs shape is (seq_length, batch, num_nodes, input_dim) (12, 64, 207, 2)
         # inputs to cell is (batch, num_nodes * input_dim)
-        # init_hidden_state should be (num_layers, batch_size, num_nodes*num_units) (2, 50, 207*64)
+        # init_hidden_state should be (num_layers, batch_size, num_nodes*num_units) (2, 64, 207*64)
         seq_length = inputs.shape[0]
         batch_size = inputs.shape[1]
-        inputs = torch.reshape(inputs, (seq_length, batch_size, -1))  # (12, 50, 207*2)
+        inputs = torch.reshape(inputs, (seq_length, batch_size, -1))  # (12, 64, 207*2)
         current_inputs = inputs
         output_hidden = []  # the output hidden states, shape (num_layers, batch, outdim)
@@ -59,7 +62,7 @@
 class DCGRUDecoder(BaseModel):
     def __init__(self, input_dim, adj_mat, max_diffusion_step, num_nodes,
-                 hid_dim, output_dim, num_rnn_layers):
+                 hid_dim, output_dim, num_rnn_layers, filter_type):
         super(DCGRUDecoder, self).__init__()
         self.hid_dim = hid_dim
         self._num_nodes = num_nodes  # 207
@@ -68,16 +71,16 @@ class DCGRUDecoder(BaseModel):
         cell = DCGRUCell(input_dim=hid_dim, num_units=hid_dim,
                          adj_mat=adj_mat, max_diffusion_step=max_diffusion_step,
-                         num_nodes=num_nodes)
+                         num_nodes=num_nodes, filter_type=filter_type)
         cell_with_projection = DCGRUCell(input_dim=hid_dim, num_units=hid_dim,
                                          adj_mat=adj_mat, max_diffusion_step=max_diffusion_step,
-                                         num_nodes=num_nodes, num_proj=output_dim)
+                                         num_nodes=num_nodes, num_proj=output_dim, filter_type=filter_type)

         decoding_cells = list()
         # first layer of the decoder
         decoding_cells.append(DCGRUCell(input_dim=input_dim, num_units=hid_dim,
                                         adj_mat=adj_mat, max_diffusion_step=max_diffusion_step,
-                                        num_nodes=num_nodes))
+                                        num_nodes=num_nodes, filter_type=filter_type))
         # construct multi-layer rnn
         for _ in range(1, num_rnn_layers - 1):
             decoding_cells.append(cell)
@@ -102,11 +105,11 @@ class DCGRUDecoder(BaseModel):
         # if rnn has only one layer
         # if self._num_rnn_layers == 1:
         #     # first input to the decoder is the GO Symbol
-        #     current_inputs = inputs[0]  # (50, 207*1)
+        #     current_inputs = inputs[0]  # (64, 207*1)
         #     hidden_state = prev_hidden_state[0]
         #     for t in range(1, seq_length):
         #         output, hidden_state = self.decoding_cells[0](current_inputs, hidden_state)
-        #         outputs[t] = output  # (50, 207*1)
+        #         outputs[t] = output  # (64, 207*1)
         #         teacher_force = random.random() < teacher_forcing_ratio
         #         current_inputs = (inputs[t] if teacher_force else output)
@@ -130,7 +133,7 @@
 class DCRNNModel(BaseModel):
     def __init__(self, adj_mat, batch_size, enc_input_dim, dec_input_dim, max_diffusion_step, num_nodes,
-                 num_rnn_layers, rnn_units, seq_len, output_dim):
+                 num_rnn_layers, rnn_units, seq_len, output_dim, filter_type):
         super(DCRNNModel, self).__init__()
         # scaler for data normalization
         # self._scaler = scaler
@@ -150,17 +153,17 @@
         self.encoder = DCRNNEncoder(input_dim=enc_input_dim, adj_mat=adj_mat,
                                     max_diffusion_step=max_diffusion_step,
                                     hid_dim=rnn_units, num_nodes=num_nodes,
-                                    num_rnn_layers=num_rnn_layers)
+                                    num_rnn_layers=num_rnn_layers, filter_type=filter_type)
         self.decoder = DCGRUDecoder(input_dim=dec_input_dim,
                                     adj_mat=adj_mat, max_diffusion_step=max_diffusion_step,
                                     num_nodes=num_nodes, hid_dim=rnn_units,
                                     output_dim=self._output_dim,
-                                    num_rnn_layers=num_rnn_layers)
+                                    num_rnn_layers=num_rnn_layers, filter_type=filter_type)
         assert self.encoder.hid_dim == self.decoder.hid_dim, \
             "Hidden dimensions of encoder and decoder must be equal!"

     def forward(self, source, target, teacher_forcing_ratio):
-        # the size of source/target would be (50, 12, 207, 2)
+        # the size of source/target would be (64, 12, 207, 2)
         source = torch.transpose(source, dim0=0, dim1=1)
         target = torch.transpose(target[..., :self._output_dim], dim0=0, dim1=1)
         target = torch.cat([self.GO_Symbol, target], dim=0)
@@ -173,7 +176,7 @@
         outputs = self.decoder(target, context, teacher_forcing_ratio=teacher_forcing_ratio)
         # the elements of the first time step of the outputs are all zeros.
-        return outputs[1:, :, :]  # (seq_length, batch_size, num_nodes*output_dim) (12, 50, 207*1)
+        return outputs[1:, :, :]  # (seq_length, batch_size, num_nodes*output_dim) (12, 64, 207*1)

     @property
     def batch_size(self):
......
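Finally, wiring the option end to end: a sketch of constructing DCRNNModel with the values from the config hunk above. The adjacency shape (207 nodes, per the shape comments) and max_diffusion_step=2 are assumptions not shown in this diff, and DCGRUCell's .cuda() call means this expects a GPU:

import numpy as np

# Hypothetical wiring, with DCRNNModel imported from the model module.
# 207 nodes and max_diffusion_step=2 are assumed; batch_size, rnn_units,
# seq_len, and output_dim come from the config hunk above.
adj_mat = np.random.rand(207, 207).astype(np.float32)  # stand-in adjacency
model = DCRNNModel(adj_mat=adj_mat, batch_size=64,
                   enc_input_dim=2, dec_input_dim=1,
                   max_diffusion_step=2, num_nodes=207,
                   num_rnn_layers=2, rnn_units=64, seq_len=12,
                   output_dim=1, filter_type='dual_random_walk')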