Source code for claf.tokens.embedding.char_embedding


from overrides import overrides
import torch
import torch.nn as nn
import torch.nn.functional as F

from claf.modules.activation import get_activation_fn

from .base import TokenEmbedding


class CharEmbedding(TokenEmbedding):
    """
    Character Embedding (CharCNN)
    (https://arxiv.org/abs/1509.01626)

    * Args:
        vocab: Vocab (claf.tokens.vocab)

    * Kwargs:
        dropout: dropout probability
        embed_dim: character embedding dimension
        kernel_sizes: iterable of kernel sizes (each size = character n-gram width)
        num_filter: number of CNN filters per kernel size
        activation: activation function name (eg. ReLU)
    """

    def __init__(
        self, vocab, dropout=0.2, embed_dim=16, kernel_sizes=(5,), num_filter=100, activation="relu"
    ):
        # NOTE: kernel_sizes default is a tuple (was a mutable list default).
        super(CharEmbedding, self).__init__(vocab)
        self.embed_dim = embed_dim
        self.num_filter = num_filter
        self.weight = self._init_weight(trainable=True)

        # Each conv slides over the flattened (char_maxlen * embed_dim) axis:
        # stride=embed_dim advances one character per step, and
        # kernel_size=embed_dim * n spans an n-gram of characters.
        self.convs = [
            nn.Conv1d(
                in_channels=1,
                out_channels=num_filter,
                kernel_size=embed_dim * kernel_size,
                stride=embed_dim,
            )
            for kernel_size in kernel_sizes  # kernel_size = n-gram
        ]
        # Register under "conv_{i}" (not an nn.ModuleList) to keep the original
        # state_dict key layout, so existing checkpoints still load.
        for i, conv in enumerate(self.convs):
            self.add_module(f"conv_{i}", conv)

        self.activation_fn = get_activation_fn(activation)()
        self.dropout = nn.Dropout(p=dropout)

        # With multiple kernel sizes, the concatenated max-pooled features are
        # projected back down to num_filter so get_output_dim() stays constant.
        self.projection = None
        if len(kernel_sizes) > 1:
            maxpool_output_dim = len(kernel_sizes) * num_filter
            self.projection = nn.Linear(maxpool_output_dim, num_filter)

    def _init_weight(self, trainable=False):
        """Create the (vocab_size, embed_dim) character embedding matrix,
        Xavier-initialized, as an nn.Parameter."""
        weight = torch.FloatTensor(self.get_vocab_size(), self.embed_dim)
        weight = torch.nn.Parameter(weight, requires_grad=trainable)
        torch.nn.init.xavier_uniform_(weight)
        return weight

    @overrides
    def forward(self, chars):
        """Embed character ids and encode each word with CharCNN.

        * Args:
            chars: LongTensor of shape (batch_size, word_maxlen, char_maxlen),
                with 0 as the padding index.

        * Returns:
            Tensor of shape (batch_size, word_maxlen, num_filter)
        """
        mask_chars = (chars != 0).long()  # 0 is the pad index

        B, W_L, C_L = chars.size()  # (batch_size, word_maxlen, char_maxlen)
        chars = chars.view(B, W_L * C_L)

        char_embedds = F.embedding(chars, self.weight)
        char_embedds = char_embedds.view(B, W_L, C_L, -1)

        # Zero out embeddings of padding characters.
        char_embedds = char_embedds * mask_chars.unsqueeze(-1).float()

        # Flatten each word's characters into one channel for Conv1d.
        char_embedds = char_embedds.view(B * W_L, 1, -1)

        conv_outputs = []
        for conv in self.convs:  # same modules as conv_{i}; no getattr needed
            output = self.activation_fn(conv(char_embedds))
            # Max-over-time pooling: one feature per filter.
            pooled = F.max_pool1d(output, output.size(2)).squeeze(2)
            conv_outputs.append(pooled)

        encoded = conv_outputs[0]
        if len(conv_outputs) > 1:
            encoded = torch.cat(conv_outputs, dim=1)

        encoded = encoded.view(B, W_L, -1)
        if self.projection is not None:  # explicit None check, not Module truthiness
            encoded = self.projection(encoded)
        return self.dropout(encoded)

    @overrides
    def get_output_dim(self):
        """Output feature size per word (projection maps back to num_filter)."""
        return self.num_filter