Source code for claf.tokens.embedding.frequent_word_embedding


from overrides import overrides
import torch
import torch.nn as nn

import claf.modules.functional as f

from .base import TokenEmbedding
from .word_embedding import WordEmbedding


class FrequentTuningWordEmbedding(TokenEmbedding):
    """
    Frequent Word Fine-tuning Embedding

    Fine-tunes the embedding matrix according to 'threshold_index':
    word indices below the threshold (the most frequent words) are fine-tuned,
    while the remaining words use a fixed (detached) embedding.

    * Args:
        vocab: Vocab (claf.tokens.vocab)

    * Kwargs:
        dropout: dropout probability
        embed_dim: embedding dimension
        padding_idx: if given, pads the output with the embedding vector at
            padding_idx (initialized to zeros) whenever it encounters the index
        max_norm: if given, renormalizes the embedding vectors to have a norm
            less than this value before extracting.
            Note: this modifies weight in-place.
        norm_type: the p of the p-norm to compute for the max_norm option. Default 2.
        scale_grad_by_freq: if given, scales gradients by the inverse of the
            frequency of the words in the mini-batch. Default False.
        sparse: if True, the gradient w.r.t. weight will be a sparse tensor.
            See the Notes under torch.nn.Embedding for details on sparse gradients.
        pretrained_path: pretrained vector path (e.g. GloVe)
    """

    def __init__(
        self,
        vocab,
        dropout=0.2,
        embed_dim=100,
        padding_idx=None,
        max_norm=None,
        norm_type=2,
        scale_grad_by_freq=False,
        sparse=False,
        pretrained_path=None,
    ):
        super(FrequentTuningWordEmbedding, self).__init__(vocab)

        self.threshold_index = vocab.threshold_index
        self.embed_dim = embed_dim

        # Embedding that is fine-tuned for the most frequent words
        self.fine_tune_word_embedding = WordEmbedding(
            vocab,
            dropout=0,
            embed_dim=embed_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            pretrained_path=pretrained_path,
        )
        # Embedding that stays fixed (detached) for the remaining words
        self.fixed_word_embedding = WordEmbedding(
            vocab,
            dropout=0,
            embed_dim=embed_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            pretrained_path=pretrained_path,
        )

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = lambda x: x
    @overrides
    def forward(self, words, frequent_tuning=False):
        if frequent_tuning and self.training:
            padding_mask = words.eq(0).long()

            # Fine-tuned embedding for the N most frequent words (index < threshold_index)
            fine_tune_mask = torch.lt(words, self.threshold_index) * padding_mask.eq(0)
            fine_tune_words = words * fine_tune_mask.long()
            fine_tune_embedded = self.fine_tune_word_embedding(fine_tune_words)
            fine_tune_embedded = f.masked_zero(fine_tune_embedded, fine_tune_mask)

            # Fixed embedding for the less frequent words (index >= threshold_index)
            fixed_mask = torch.ge(words, self.threshold_index)
            fixed_embedded = self.fixed_word_embedding(words).detach()  # no gradient flows here
            fixed_embedded = f.masked_zero(fixed_embedded, fixed_mask)

            embedded_words = fine_tune_embedded + fixed_embedded
        else:
            embedded_words = self.fixed_word_embedding(words)

        return self.dropout(embedded_words)
    @overrides
    def get_output_dim(self):
        return self.embed_dim
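

# ---------------------------------------------------------------------------
# Usage illustration (not part of the claf source): a minimal, standalone
# PyTorch sketch of the frequent/fixed split that forward() performs above.
# The toy sizes, the plain nn.Embedding layers, and the local masked_zero
# helper are assumptions made for illustration only; they are not the claf
# Vocab / WordEmbedding / claf.modules.functional API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vocab_size, embed_dim, threshold_index = 10, 4, 5  # assumed toy sizes

    fine_tune_emb = nn.Embedding(vocab_size, embed_dim)  # receives gradients
    fixed_emb = nn.Embedding(vocab_size, embed_dim)      # used through .detach()

    def masked_zero(embedded, mask):
        # Zero out positions where mask is False (stand-in for f.masked_zero)
        return embedded * mask.unsqueeze(-1).float()

    words = torch.tensor([[1, 3, 7, 0]])  # 0 is the padding index
    padding_mask = words.eq(0)

    fine_tune_mask = torch.lt(words, threshold_index) & ~padding_mask  # frequent, non-pad
    fixed_mask = torch.ge(words, threshold_index)                      # infrequent

    fine_tune_embedded = masked_zero(
        fine_tune_emb(words * fine_tune_mask.long()), fine_tune_mask
    )
    fixed_embedded = masked_zero(fixed_emb(words).detach(), fixed_mask)

    # Shape (1, 4, embed_dim); gradients only reach fine_tune_emb's rows
    embedded_words = fine_tune_embedded + fixed_embedded
    print(embedded_words.shape)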