Commit 2474e414 by xlwang

correct activation function for r,u (from tanh to sigmoid)

parent c1120241
......@@ -66,11 +66,7 @@ class DiffusionGraphConv(BaseModel):
x = torch.transpose(x, dim0=0, dim1=3) # (batch_size, num_nodes, input_size, order)
x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])
# self.weights = torch.nn.parameter(torch.FloatTensor(size=(input_size * num_matrices, output_size)))
# nn.init.xavier_normal_(self.weights, gain=1.414)
x = torch.matmul(x, self.weight) # (batch_size * self._num_nodes, output_size)
# self.biases = nn.Parameter(torch.FloatTensor(size=(output_size,)))
# nn.init.constant_(self.biases, val=bias_start)
x = torch.add(x, self.biases)
# Reshape res back to 2D: (batch_size, num_node, state_dim) -> (batch_size, num_node * state_dim)
return torch.reshape(x, [batch_size, self._num_nodes * output_size])
......@@ -115,14 +111,6 @@ class DCGRUCell(BaseModel):
output_dim=num_units)
if num_proj is not None:
self.project = nn.Linear(self._num_units, self._num_proj)
# def reset_weight_parameters(self, dim1, dim2):
# # self.weight = nn.Parameter(inputs.new(inputs.size()).normal_(0, 1))
# self.weight = nn.Parameter(torch.FloatTensor(size=(dim1, dim2)))
# nn.init.xavier_normal_(self.weight, gain=1.414)
# def reset_bias_parameters(self, dim2, bias_start):
# self.biases = nn.Parameter(torch.FloatTensor(size=(dim2,)))
# nn.init.constant_(self.biases, val=bias_start)
@property
def output_size(self):
......@@ -143,7 +131,7 @@ class DCGRUCell(BaseModel):
fn = self.dconv_gate
else:
fn = self._fc
value = torch.tanh(fn(inputs, state, output_size, bias_start=1.0))
value = torch.sigmoid(fn(inputs, state, output_size, bias_start=1.0))
value = torch.reshape(value, (-1, self._num_nodes, output_size))
r, u = torch.split(value, split_size_or_sections=int(output_size/2), dim=-1)
r = torch.reshape(r, (-1, self._num_nodes * self._num_units))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment