Source code for claf.metric.squad_v1_official

""" Official evaluation script for v1.1 of the SQuAD dataset. """
from __future__ import print_function
from collections import Counter
import string
import re
import argparse
import json
import sys


[docs]def normalize_answer(s): # pragma: no cover """Lower text and remove punctuation, articles and extra whitespace.""" def remove_articles(text): return re.sub(r"\b(a|an|the)\b", " ", text) def white_space_fix(text): return " ".join(text.split()) def remove_punc(text): exclude = set(string.punctuation) return "".join(ch for ch in text if ch not in exclude) def lower(text): return text.lower() return white_space_fix(remove_articles(remove_punc(lower(s))))
[docs]def f1_score(prediction, ground_truth): # pragma: no cover prediction_tokens = normalize_answer(prediction).split() ground_truth_tokens = normalize_answer(ground_truth).split() common = Counter(prediction_tokens) & Counter(ground_truth_tokens) num_same = sum(common.values()) if num_same == 0: return 0 precision = 1.0 * num_same / len(prediction_tokens) recall = 1.0 * num_same / len(ground_truth_tokens) f1 = (2 * precision * recall) / (precision + recall) return f1
[docs]def exact_match_score(prediction, ground_truth): # pragma: no cover return normalize_answer(prediction) == normalize_answer(ground_truth)
[docs]def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): # pragma: no cover scores_for_ground_truths = [] for ground_truth in ground_truths: score = metric_fn(prediction, ground_truth) scores_for_ground_truths.append(score) return max(scores_for_ground_truths)
[docs]def evaluate(dataset, predictions): f1 = exact_match = total = 0 for article in dataset: for paragraph in article["paragraphs"]: for qa in paragraph["qas"]: total += 1 if qa["id"] not in predictions: message = "Unanswered question " + qa["id"] + " will receive score 0." print(message, file=sys.stderr) continue ground_truths = list(map(lambda x: x["text"], qa["answers"])) prediction = predictions[qa["id"]] exact_match += metric_max_over_ground_truths( exact_match_score, prediction, ground_truths ) f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths) exact_match = 100.0 * exact_match / total f1 = 100.0 * f1 / total return {"em": exact_match, "f1": f1}
if __name__ == "__main__": # pragma: no cover expected_version = "1.1" parser = argparse.ArgumentParser(description="Evaluation for SQuAD " + expected_version) parser.add_argument("dataset_file", help="Dataset file") parser.add_argument("prediction_file", help="Prediction File") args = parser.parse_args() with open(args.dataset_file) as dataset_file: dataset_json = json.load(dataset_file) if dataset_json["version"] != expected_version: print( "Evaluation expects v-" + expected_version + ", but got dataset with v-" + dataset_json["version"], file=sys.stderr, ) dataset = dataset_json["data"] with open(args.prediction_file) as prediction_file: predictions = json.load(prediction_file) print(json.dumps(evaluate(dataset, predictions)))