Commit b617e1a2 by xlwang

Amend the decoder output (leave out the first all zero time step)

parent 9a662cce
......@@ -173,7 +173,8 @@ class DCRNNModel(BaseModel):
context, _ = self.encoder(source, init_hidden_state) # (num_layers, batch, outdim)
outputs = self.decoder(target, context, teacher_forcing_ratio=teacher_forcing_ratio)
return outputs # (seq_length+1, batch_size, num_nodes*output_dim) (13, 50, 207*1)
# the elements of the first time step of the outputs are all zeros.
return outputs[1:, :, :] # (seq_length, batch_size, num_nodes*output_dim) (12, 50, 207*1)
@property
def batch_size(self):
......
......@@ -122,8 +122,8 @@ class DCRNNTrainer(BaseTrainer):
data, target = data.to(self.device), target.to(self.device)
output = self.model(data, target, 0)
output = torch.transpose(output[1:].view(12, self.model.batch_size, self.model.num_nodes,
self.model.output_dim), 0, 1) # back to (50, 12, 207, 1)
output = torch.transpose(output.view(12, self.model.batch_size, self.model.num_nodes,
self.model.output_dim), 0, 1) # back to (50, 12, 207, 1)
loss = self.loss(output.cpu(), label)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment