Source code for claf.tokens.vocabulary


from collections import defaultdict
import json

from claf.data.data_handler import CachePath, DataHandler


class VocabDict(defaultdict):
    """
    Vocab DefaultDict Class

    * Kwargs:
        oov_value: out-of-vocabulary token value (e.g. <unk>)
    """

    def __init__(self, oov_value):
        self.oov_value = oov_value

    def __missing__(self, key):
        return self.oov_value
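
# A quick sketch of VocabDict's lookup behavior (the keys below are
# illustrative): unlike a plain defaultdict, __missing__ returns the
# out-of-vocabulary value for unknown keys without inserting them.
_d = VocabDict("[UNK]")
_d["hello"] = 3
assert _d["hello"] == 3
assert _d["never-seen"] == "[UNK]"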

class Vocab:
    """
    Vocabulary Class

    Vocab consists of token_to_index and index_to_token.

    * Args:
        token_name: Token name (Token and Vocab have a one-to-one relationship)

    * Kwargs:
        pad_token: padding token value (e.g. <pad>)
        oov_token: out-of-vocabulary token value (e.g. <unk>)
        start_token: start token value (e.g. <s>, <bos>)
        end_token: end token value (e.g. </s>, <eos>)
        cls_token: CLS token value for BERT (e.g. [CLS])
        sep_token: SEP token value for BERT (e.g. [SEP])
        min_count: the minimum token frequency. When min_count is defined,
            tokens with a count below it are dropped from the vocabulary.
        max_vocab_size: the vocabulary's maximum size. When max_vocab_size
            is defined, tokens are kept in descending order of frequency
            until the limit is reached.
        frequent_count: frequency threshold used to set threshold_index.
            (e.g. with frequent_count=1000, threshold_index is the vocab
            size at the first token whose count falls below 1000.)
        pretrained_path: pretrained vocab file path (format: A\nB\nC\nD\n...)
        pretrained_token: how to apply a predefined vocab, either 'all'
            (use the predefined vocab as-is) or 'intersect' (keep only
            corpus tokens that appear in the predefined vocab).
    """

    DEFAULT_PAD_INDEX, DEFAULT_PAD_TOKEN = 0, "[PAD]"
    DEFAULT_OOV_INDEX, DEFAULT_OOV_TOKEN = 1, "[UNK]"

    # pretrained_vocab handling methods
    PRETRAINED_ALL = "all"  # use the predefined vocab as-is (e.g. fixed embedding matrix)
    PRETRAINED_INTERSECT = "intersect"  # add tokens included in predefine_vocab; others map to the OOV token

    def __init__(
        self,
        token_name,
        pad_token=None,
        oov_token=None,
        start_token=None,
        end_token=None,
        cls_token=None,
        sep_token=None,
        min_count=None,
        max_vocab_size=None,
        frequent_count=None,
        pretrained_path=None,
        pretrained_token=None,
    ):
        self.token_name = token_name

        # basic tokens (pad and oov)
        self.pad_index = self.DEFAULT_PAD_INDEX
        self.pad_token = pad_token
        if pad_token is None:
            self.pad_token = self.DEFAULT_PAD_TOKEN

        self.oov_index = self.DEFAULT_OOV_INDEX
        self.oov_token = oov_token
        if oov_token is None:
            self.oov_token = self.DEFAULT_OOV_TOKEN

        # special tokens
        self.start_token = start_token
        self.end_token = end_token
        self.cls_token = cls_token
        self.sep_token = sep_token

        self.min_count = min_count
        self.max_vocab_size = max_vocab_size

        self.token_counter = None
        self.frequent_count = frequent_count
        self.threshold_index = None

        self.pretrained_path = pretrained_path
        self.pretrained_token = pretrained_token
        self.pretrained_token_methods = [self.PRETRAINED_ALL, self.PRETRAINED_INTERSECT]

    def init(self):
        self.token_to_index = VocabDict(self.oov_index)
        self.index_to_token = VocabDict(self.oov_token)

        # add default tokens (pad, oov)
        self.add(self.pad_token)
        self.add(self.oov_token)

        special_tokens = [self.start_token, self.end_token, self.cls_token, self.sep_token]
        for token in special_tokens:
            if token is not None:
                self.add(token)

    def build(self, token_counter, predefine_vocab=None):
        """
        Build the vocabulary from a token counter.
        (A usage sketch follows the class definition.)

        * Args:
            token_counter: (collections.Counter) Counter of token frequencies.

        * Kwargs:
            predefine_vocab: (list) predefined token list, applied according
                to the pretrained_token method ('all' or 'intersect').
        """
        if predefine_vocab is not None:
            if (
                self.pretrained_token is None
                or self.pretrained_token not in self.pretrained_token_methods
            ):
                raise ValueError(
                    f"When using 'predefine_vocab', 'pretrained_token' must be one of {self.pretrained_token_methods}"
                )

        if predefine_vocab:
            if self.pretrained_token == self.PRETRAINED_ALL:
                self.from_texts(predefine_vocab)
                return
            else:
                predefine_vocab = set(predefine_vocab)

        self.token_counter = token_counter
        self.init()

        token_counts = list(token_counter.items())
        token_counts.sort(key=lambda x: x[1], reverse=True)  # order: DESC

        if self.max_vocab_size is not None:
            token_counts = token_counts[: self.max_vocab_size]

        for token, count in token_counts:
            if self.min_count is not None:
                if count >= self.min_count:
                    self.add(token, predefine_vocab=predefine_vocab)
            else:
                self.add(token, predefine_vocab=predefine_vocab)

            if self.threshold_index is None and self.frequent_count is not None:
                if count < self.frequent_count:
                    self.threshold_index = len(self.token_to_index)

    def build_with_pretrained_file(self, token_counter):
        data_handler = DataHandler(CachePath.VOCAB)
        vocab_texts = data_handler.read(self.pretrained_path)

        if self.pretrained_path.endswith(".txt"):
            predefine_vocab = vocab_texts.split("\n")
        elif self.pretrained_path.endswith(".json"):
            vocab_texts = json.loads(vocab_texts)  # {token: id}
            predefine_vocab = [item[0] for item in sorted(vocab_texts.items(), key=lambda x: x[1])]
        else:
            raise ValueError("Unsupported vocab file extension. Supported: .txt or .json")

        self.build(token_counter, predefine_vocab=predefine_vocab)

    def __len__(self):
        return len(self.token_to_index)

    def add(self, token, predefine_vocab=None):
        if token in self.token_to_index:
            return  # already added

        if predefine_vocab:
            if self.pretrained_token == self.PRETRAINED_INTERSECT and token not in predefine_vocab:
                return

        index = len(self.token_to_index)
        self.token_to_index[token] = index
        self.index_to_token[index] = token

    def get_index(self, token):
        return self.token_to_index[token]

    def get_token(self, index):
        return self.index_to_token[index]

    def get_all_tokens(self):
        return list(self.token_to_index.keys())

    def dump(self, path):
        with open(path, "w", encoding="utf-8") as out_file:
            out_file.write(self.to_text())

    def load(self, path):
        with open(path, "r", encoding="utf-8") as in_file:
            texts = in_file.read()
            self.from_texts(texts)

    def to_text(self):
        return "\n".join(self.get_all_tokens())

    def from_texts(self, texts):
        if isinstance(texts, list):
            tokens = texts
        else:
            tokens = texts.split("\n")
        tokens = [token for token in tokens if token]  # filter out empty strings

        # basic tokens (pad and oov)
        if self.pad_token in tokens:
            self.pad_index = tokens.index(self.pad_token)
        else:
            self.pad_index = len(tokens)
            tokens.append(self.pad_token)

        if self.oov_token in tokens:
            self.oov_index = tokens.index(self.oov_token)
        else:
            self.oov_index = len(tokens)
            tokens.append(self.oov_token)

        self.token_to_index = VocabDict(self.oov_index)
        self.index_to_token = VocabDict(self.oov_token)

        for token in tokens:
            self.add(token)
        return self
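

# --- Usage sketches (illustrative; not part of the original module) ---

# Building a vocabulary from token frequencies. With min_count=2, tokens
# seen fewer than two times are dropped and map back to the OOV index.
# Indices 0 and 1 are reserved for [PAD] and [UNK] by init().
from collections import Counter

vocab = Vocab("word", min_count=2)
vocab.build(Counter({"the": 5, "cat": 2, "xyzzy": 1}))
assert vocab.get_index("the") == 2                  # most frequent token
assert vocab.get_index("xyzzy") == vocab.oov_index  # dropped by min_count
assert vocab.get_token(vocab.oov_index) == "[UNK]"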
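
# The two pretrained_token modes, passing predefine_vocab directly to
# build() (file loading via build_with_pretrained_file is skipped; the
# token lists are illustrative). 'intersect' keeps a corpus token only if
# it also appears in the predefined vocab; 'all' adopts the predefined
# vocab as-is, e.g. to stay aligned with a fixed embedding matrix.
intersected = Vocab("word", pretrained_token=Vocab.PRETRAINED_INTERSECT)
intersected.build(Counter({"apple": 3, "durian": 1}), predefine_vocab=["apple", "banana"])
assert intersected.get_index("durian") == intersected.oov_index  # filtered out

fixed = Vocab("word", pretrained_token=Vocab.PRETRAINED_ALL)
fixed.build(Counter(), predefine_vocab=["apple", "banana"])
assert fixed.get_index("banana") == 1  # order follows the predefined list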
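
# Round-trip through a plain-text file: dump() writes one token per line
# (the to_text() format) and load() rebuilds both mappings. The path is
# illustrative.
vocab.dump("/tmp/word_vocab.txt")

restored = Vocab("word")
restored.load("/tmp/word_vocab.txt")
assert restored.get_index("the") == vocab.get_index("the")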