Source code for claf.modules.layer.highway


import torch
import torch.nn as nn

from claf.modules.activation import get_activation_fn


class Highway(nn.Module):
    """
    Highway Networks (https://arxiv.org/abs/1505.00387)
    https://github.com/allenai/allennlp/blob/master/allennlp/modules/highway.py

    * Args:
        input_size: The number of expected features in the input `x`
        num_layers: The number of Highway layers.
        activation: Activation function (ReLU is the default)
    """

    def __init__(self, input_size, num_layers=2, activation="relu"):
        super(Highway, self).__init__()

        self.activation_fn = activation
        if isinstance(activation, str):
            self.activation_fn = get_activation_fn(activation)()

        self._layers = torch.nn.ModuleList(
            [nn.Linear(input_size, input_size * 2) for _ in range(num_layers)]
        )
        for layer in self._layers:
            # should bias the highway layer to just carry its input forward.
            layer.bias[input_size:].data.fill_(1)
    def forward(self, x):
        current_input = x
        for layer in self._layers:
            projected_input = layer(current_input)
            linear_part = current_input
            # Each layer projects to 2 * input_size: the first half is the
            # transform output, the second half is the gate.
            nonlinear_part, gate = projected_input.chunk(2, dim=-1)
            nonlinear_part = self.activation_fn(nonlinear_part)
            gate = torch.sigmoid(gate)
            current_input = gate * linear_part + (1 - gate) * nonlinear_part
        return current_input
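

# A minimal usage sketch of the Highway block defined above: the tensor shapes,
# the __main__ guard, and the printed output are illustrative assumptions, not
# part of the library.
if __name__ == "__main__":
    highway = Highway(input_size=128, num_layers=2, activation="relu")
    x = torch.randn(4, 20, 128)  # (batch_size, sequence_length, input_size)
    out = highway(x)
    print(out.shape)  # torch.Size([4, 20, 128]); the feature size is preserved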