Source code for claf.modules.layer.residual

import torch
import torch.nn as nn

from claf.modules.layer.normalization import LayerNorm


class ResidualConnection(nn.Module):
    """
    ResidualConnection
        in Deep Residual Learning for Image Recognition
        (https://arxiv.org/abs/1512.03385)

    => f(x) + x, optionally with layer normalization applied to the
       sub-layer input and stochastic depth applied during training.

    * Args:
        dim: the number of dimension (passed to LayerNorm when
            ``layernorm`` is True)

    * Kwargs:
        layer_dropout: stochastic-depth setting. As implemented, this value
            is the Bernoulli probability with which the sub-layer is KEPT
            during training (despite the name). ``None`` or any value >= 1
            disables stochastic depth, so the sub-layer is always applied.
        layernorm: if True, apply LayerNorm(dim) to ``x`` before it is fed
            to the sub-layer; otherwise the input is passed through as-is.
    """

    def __init__(self, dim, layer_dropout=None, layernorm=False):
        super(ResidualConnection, self).__init__()

        # `None` (the default) must not be compared with `<`: on Python 3
        # that raises TypeError. Treat it as "stochastic depth disabled".
        self.survival = None
        if layer_dropout is not None and layer_dropout < 1:
            # Plain tensor attribute (not a buffer); only used as the
            # probability argument of torch.bernoulli in forward().
            self.survival = torch.FloatTensor([layer_dropout])

        if layernorm:
            self.norm = LayerNorm(dim)
        else:
            self.norm = lambda x: x

    def forward(self, x, sub_layer_fn):
        """
        Apply the residual connection around ``sub_layer_fn``.

        * Args:
            x: input tensor
            sub_layer_fn: callable applied to the (optionally normalized)
                input; its output is added to ``x``

        * Returns:
            ``x + sub_layer_fn(self.norm(x))``, or just ``x`` when the
            sub-layer is dropped by stochastic depth in training mode.
        """
        # Stochastic depth is only active in training mode and only when a
        # survival probability was configured.
        if not self.training or self.survival is None:
            return x + sub_layer_fn(self.norm(x))

        # One Bernoulli draw per call: 1 -> keep the sub-layer, 0 -> skip it.
        keep = torch.bernoulli(self.survival).item()
        if keep == 1:
            return x + sub_layer_fn(self.norm(x))
        return x