Source code for claf.tokens.token_embedder.basic_embedder


from overrides import overrides

import torch

from .base import TokenEmbedder


class BasicTokenEmbedder(TokenEmbedder):
    """
    Basic Token Embedder

    Takes tensors of indexed tokens, looks each one up in its Embedding
    module, and concatenates all embedded tensors into a single output.

    * Args:
        token_makers: dictionary of TokenMaker (claf.tokens.token_maker)
    """

    def __init__(self, token_makers):
        super(BasicTokenEmbedder, self).__init__(token_makers)
    @overrides
    def get_embed_dim(self, except_keys=[]):
        # Sum the dimensions of all embeddings, skipping any excluded keys.
        embed_dims = {
            name: dim for name, dim in self.embed_dims.items() if name not in except_keys
        }
        return sum(embed_dims.values())
    @overrides
    def forward(self, inputs, except_keys=[], params={}):
        token_names = [name for name in self.token_names if name not in except_keys]
        if set(token_names) != set(inputs.keys()):
            raise ValueError(
                f"Mismatched token names. inputs: {inputs.keys()}, embeddings: {self.token_names}"
            )

        # Embed each token tensor with the Embedding module registered under
        # its token name, then concatenate along the last (feature) dimension.
        embedded_tokens = []
        for token_name, tensors in inputs.items():
            embedding = getattr(self, token_name)
            embedded_token = embedding(tensors, **params)
            embedded_tokens.append(embedded_token)
        return torch.cat(embedded_tokens, dim=-1)
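
# For reference, a minimal sketch of what forward() computes, using plain
# torch.nn.Embedding modules in place of the token-maker-provided embeddings.
# The token names, vocabulary sizes, and dimensions below are hypothetical;
# the claf-specific TokenMaker setup is elided.
#
#     import torch
#     import torch.nn as nn
#
#     # Two stand-in embeddings, as TokenMaker would register on the embedder.
#     word_embedding = nn.Embedding(num_embeddings=100, embedding_dim=50)
#     pos_embedding = nn.Embedding(num_embeddings=20, embedding_dim=10)
#
#     # Indexed tokens: batch of 2 sequences, 7 tokens each.
#     inputs = {
#         "word": torch.randint(0, 100, (2, 7)),
#         "pos": torch.randint(0, 20, (2, 7)),
#     }
#
#     # Mirror of BasicTokenEmbedder.forward: embed each tensor, then
#     # concatenate along the last (feature) dimension.
#     embedded = [word_embedding(inputs["word"]), pos_embedding(inputs["pos"])]
#     output = torch.cat(embedded, dim=-1)
#
#     print(output.shape)  # torch.Size([2, 7, 60]); 50 + 10, as get_embed_dim() would report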