claf.model.token_classification package

Submodules

class claf.model.token_classification.mixin.TokenClassification[source]

Bases: object

Token Classification Mixin Class

make_metrics(predictions)[source]

Make metrics with prediction dictionary

  • Args:
    predictions: prediction dictionary consisting of
    • key: ‘id’ (sequence id)

    • value: dictionary consisting of
      • tag_idxs

  • Returns:
    metrics: metric dictionary consisting of
    • ‘accuracy’: sequence level accuracy

    • ‘tag_accuracy’: tag level accuracy

    • ‘macro_f1’: tag prediction macro (unweighted mean) F1

    • ‘macro_precision’: tag prediction macro (unweighted mean) precision

    • ‘macro_recall’: tag prediction macro (unweighted mean) recall
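
For example, the macro metrics are unweighted means over the per-tag scores. A minimal illustration, independent of claf (the per-tag F1 values are hypothetical):

    # "macro" averaging: per-tag scores are averaged without weighting
    # by tag frequency. The per-tag F1 values below are made up.
    per_tag_f1 = {"O": 0.95, "B-LOC": 0.60, "B-PER": 0.70}
    macro_f1 = sum(per_tag_f1.values()) / len(per_tag_f1)
    print(macro_f1)  # 0.75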

make_predictions(output_dict)[source]

Make predictions with model’s output_dict

  • Args:
    output_dict: model’s output dictionary consisting of
    • sequence_embed: embedding vector of the sequence

    • tag_logits: unnormalized log probabilities of the tags

    • tag_idxs: target tag idxs

    • data_idx: data idx

    • loss: a scalar loss to be optimized

  • Returns:
    predictions: prediction dictionary consisting of
    • key: ‘id’ (sequence id)

    • value: dictionary consisting of
      • tag_idxs
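
A shape sketch, assuming model is an instance of a model that mixes in TokenClassification (e.g. BertForTokCls); all tensor shapes and contents below are hypothetical:

    import torch

    # Hypothetical output_dict as produced by forward(); shapes are
    # illustrative: batch of 2 sequences, 8 tokens, 5 tags.
    output_dict = {
        "sequence_embed": torch.randn(2, 768),
        "tag_logits": torch.randn(2, 8, 5),
        "tag_idxs": torch.zeros(2, 8, dtype=torch.long),
        "data_idx": torch.tensor([0, 1]),
        "loss": torch.tensor(0.42),
    }

    predictions = model.make_predictions(output_dict)
    # predictions: {sequence_id: {"tag_idxs": ...}, ...}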

predict(**kwargs)

print_examples(index, inputs, predictions)[source]

Print evaluation examples

  • Args:

    index: data index
    inputs: mini-batch inputs
    predictions: prediction dictionary consisting of

    • key: ‘id’ (sequence id)

    • value: dictionary consisting of
      • class_idx

  • Returns:

    print(Sequence, Target Tags, Target Slots, Predicted Tags, Predicted Slots)

write_predictions(predictions, file_path=None, is_dict=True, pycm_obj=None)[source]

Override write_predictions() in ModelBase to log confusion matrix
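
The pycm_obj parameter name suggests a pycm.ConfusionMatrix instance. A hedged sketch (the tag vectors and file path are made up; model and predictions come from an earlier evaluation step):

    # Sketch only: the vectors are hypothetical flattened target and
    # predicted tag indices, not real evaluation output.
    from pycm import ConfusionMatrix

    cm = ConfusionMatrix(actual_vector=[0, 1, 2, 1], predict_vector=[0, 1, 1, 1])
    model.write_predictions(predictions, file_path="predictions.json", pycm_obj=cm)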

Module contents

class claf.model.token_classification.BertForTokCls(token_makers, num_tags, ignore_tag_idx, pretrained_model_name=None, dropout=0.2)[source]

Bases: claf.model.token_classification.mixin.TokenClassification, claf.model.base.ModelWithoutTokenEmbedder

Implementation of Single Sentence Tagging model presented in BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (https://arxiv.org/abs/1810.04805)

  • Args:

    token_makers: used to embed the sequence
    num_tags: number of classified tags
    ignore_tag_idx: index of the tag to ignore when calculating loss (tag pad value)

  • Kwargs:

    pretrained_model_name: the name of a pre-trained model
    dropout: classification layer dropout
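
A hedged construction sketch; token_makers is normally built by claf's data pipeline and is only stubbed here, and the hyperparameter values are illustrative:

    from claf.model.token_classification import BertForTokCls

    model = BertForTokCls(
        token_makers=token_makers,   # assumed: supplied by claf's data reader
        num_tags=10,                 # hypothetical tag-set size
        ignore_tag_idx=0,            # tag pad value, ignored in the loss
        pretrained_model_name="bert-base-uncased",
        dropout=0.2,
    )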

forward(features, labels=None)[source]

  • Args:

    features: feature dictionary like below (token ids are illustrative):

      {
          "bert_input": {
              "feature": [
                  [100, 576, 21, 45, 7, 91, 101, 0, 0, ...],
                  ...
              ]
          },
          "token_type": {
              "feature": [
                  [0, 0, 0, 0, 0, 0, 0, 0, 0, ...],
                  ...
              ]
          },
          "tagged_sub_token_idxs": {
              "feature": [
                  [1, 3, 4, 0, 0, 0, 0, 0, 0, ...],
                  ...
              ]
          }
      }

  • Kwargs:

    labels: label dictionary like below (values are illustrative):

      {
          "class_idx": [2, 1, 0, 4, 5, ...],
          "data_idx": [2, 4, 5, 7, 2, 1, ...]
      }

      Loss is not calculated when no label is given (inference/predict mode).

  • Returns: output_dict (dict) consisting of
    • sequence_embed: embedding vector of the sequence

    • tag_logits: unnormalized log probabilities of the tags

    • tag_idxs: target tag idxs

    • data_idx: data idx

    • loss: a scalar loss to be optimized
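
A minimal call sketch mirroring the documented feature layout; the token ids are hypothetical and model is a constructed BertForTokCls:

    import torch

    features = {
        "bert_input": {"feature": torch.tensor([[100, 576, 21, 45, 7, 91, 101, 0, 0]])},
        "token_type": {"feature": torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0]])},
        "tagged_sub_token_idxs": {"feature": torch.tensor([[1, 3, 4, 0, 0, 0, 0, 0, 0]])},
    }

    output_dict = model(features)           # no labels -> loss is not computed
    tag_logits = output_dict["tag_logits"]  # unnormalized log probabilities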

print_examples(index, inputs, predictions)[source]

Print evaluation examples

  • Args:

    index: data index
    inputs: mini-batch inputs
    predictions: prediction dictionary consisting of

    • key: ‘id’ (sequence id)

    • value: dictionary consisting of
      • class_idx

  • Returns:

    print(Sequence, Sequence Tokens, Target Tags, Target Slots, Predicted Tags, Predicted Slots)