# Source code for s3prl.metric.slot_filling

"""
Metrics for the slot filling SLU task

Authors:
  * Yung-Sung Chuang 2021
  * Heng-Jui Chang 2022
"""

import re
from typing import Callable, Dict, List, Tuple

from .common import cer, wer

# Public metric functions exported by ``from <this module> import *``.
# The parsing helpers (clean, parse, get_slot_dict) and the
# slot_edit_f1_full / slot_edit_f1_part wrappers are not listed here.
__all__ = ["slot_type_f1", "slot_value_cer", "slot_value_wer", "slot_edit_f1"]


def clean(ref: str) -> str:
    """Remove inline slot boundary tags from a tagged transcription.

    Every opening tag ``B-<type>`` is dropped together with the space that
    follows it, and every closing tag ``E-<type>`` together with the space
    that precedes it. Tags without that adjacent space are left untouched.

    Args:
        ref: Tagged text (used for both references and hypotheses).

    Returns:
        The text with the boundary tags stripped.
    """
    without_openers = re.sub(r"B\-(\S+) ", "", ref)
    return re.sub(r" E\-(\S+)", "", without_openers)


def parse(hyp: str, ref: str) -> Tuple[str, str, str, str]:
    """Extract plain text and serialized slots from a hypothesis/reference pair.

    A slot is written inline as ``B-<type> <value> E-<type>``. The closing tag
    must repeat the opening type (regex backreference), and the value is the
    shortest span that satisfies this (non-greedy match).

    Args:
        hyp: Tagged hypothesis transcription.
        ref: Tagged reference transcription.

    Returns:
        ``(ref_text, hyp_text, ref_slots, hyp_slots)``, where:

        * The texts have runs of spaces collapsed and boundary tags removed
          via ``clean``.
        * The slot strings are ``"value:type"`` items joined by ``";"``, or
          ``""`` when no slot is found.
        * Hypothesis values are additionally passed through ``clean``;
          reference values are kept verbatim.
    """
    slot_pattern = re.compile(r"B\-(\S+) (.+?) E\-\1")

    hyp = re.sub(r" +", " ", hyp)
    ref = re.sub(r" +", " ", ref)

    # findall yields (type, value) tuples; serialize them as "value:type".
    ref_slots = ";".join(
        f"{value}:{slot_type}" for slot_type, value in slot_pattern.findall(ref)
    )
    hyp_slots = ";".join(
        f"{clean(value)}:{slot_type}"
        for slot_type, value in slot_pattern.findall(hyp)
    )

    return clean(ref), clean(hyp), ref_slots, hyp_slots


def get_slot_dict(
    hyp: str, ref: str
) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """Group the slot values of a hypothesis/reference pair by slot type.

    Args:
        hyp: Hypothesis transcription with inline ``B-<type> ... E-<type>`` tags.
        ref: Reference transcription in the same format.

    Returns:
        ``(ref_dict, hyp_dict)``: each maps a slot type to the list of its
        values in order of appearance; a sentence without slots gives ``{}``.
        Note the order: reference first, although ``hyp`` is the first argument.
    """
    _, _, ref_slots, hyp_slots = parse(hyp, ref)
    return _slots_to_dict(ref_slots), _slots_to_dict(hyp_slots)


def _slots_to_dict(serialized: str) -> Dict[str, List[str]]:
    """Turn a ``"value:type;value:type"`` string (as built by ``parse``) into a dict."""
    grouped: Dict[str, List[str]] = {}
    if not serialized:  # "" means the sentence had no slots
        return grouped
    for item in serialized.split(";"):
        # Split on the LAST ':' only. Slot values are free text and may contain
        # ':' themselves (e.g. a time such as "10:30"), so a plain two-way
        # split would fail to unpack. Slot type labels are assumed to be
        # ':'-free. Values containing ';' still cannot survive this encoding.
        value, slot_type = item.rsplit(":", 1)
        grouped.setdefault(slot_type, []).append(value)
    return grouped


def slot_type_f1(hypothesis: List[str], groundtruth: List[str], **kwargs) -> float:
    """Mean per-utterance F1 over the *set* of slot types.

    For each (hypothesis, groundtruth) pair, the slot types occurring in each
    sentence are compared as sets. Values and repeated occurrences are ignored.

    * Neither sentence has slots: F1 = 1.0.
    * Exactly one of them has slots: F1 = 0.0.
    * Otherwise: precision = |common| / |hyp types|,
      recall = |common| / |ref types|, and F1 is their harmonic mean
      (0.0 when both are 0).

    Args:
        hypothesis: Tagged predicted transcriptions.
        groundtruth: Tagged reference transcriptions, aligned with
            ``hypothesis``. Pairs are formed with ``zip``, so surplus items
            in the longer list are ignored.
        **kwargs: Accepted for a uniform metric signature; unused.

    Returns:
        The arithmetic mean of the per-pair F1 scores.

    Raises:
        ZeroDivisionError: If there are no pairs to score (empty input).
    """
    f1_scores = []
    for hyp, ref in zip(hypothesis, groundtruth):
        ref_dict, hyp_dict = get_slot_dict(hyp, ref)

        if not ref_dict and not hyp_dict:
            f1 = 1.0
        elif not ref_dict or not hyp_dict:
            f1 = 0.0
        else:
            n_common = len(ref_dict.keys() & hyp_dict.keys())
            precision = n_common / len(hyp_dict)
            recall = n_common / len(ref_dict)
            f1 = (
                2 * precision * recall / (precision + recall)
                if precision + recall > 0
                else 0.0
            )
        f1_scores.append(f1)

    return sum(f1_scores) / len(f1_scores)
def _closest_slot_values(
    hypothesis: List[str],
    groundtruth: List[str],
    error_fn: Callable[[List[str], List[str]], float],
) -> Tuple[List[str], List[str]]:
    """Align every reference slot value with its closest hypothesis value.

    For each slot value in each reference, pick the hypothesis value of the
    same slot type with the lowest ``error_fn([candidate], [reference])``
    (the first one on ties), or ``""`` when the hypothesis has no slot of
    that type. Hypothesis slot types absent from the reference are ignored.

    Returns:
        ``(value_hyps, value_refs)``: two parallel lists of slot values.
    """
    value_hyps: List[str] = []
    value_refs: List[str] = []
    for hyp, ref in zip(hypothesis, groundtruth):
        ref_dict, hyp_dict = get_slot_dict(hyp, ref)
        for slot, ref_values in ref_dict.items():
            candidates = hyp_dict.get(slot, [])
            for ref_v in ref_values:
                if candidates:
                    # No magic upper bound: the lowest-error candidate is always
                    # chosen, however large its error is.
                    best = min(
                        candidates,
                        key=lambda cand, target=ref_v: error_fn([cand], [target]),
                    )
                else:
                    best = ""
                value_refs.append(ref_v)
                value_hyps.append(best)
    return value_hyps, value_refs


def slot_value_cer(hypothesis: List[str], groundtruth: List[str], **kwargs) -> float:
    """Character error rate between aligned slot values.

    Each reference slot value is paired with the closest (by ``cer``)
    hypothesis value of the same slot type, or ``""`` if that type is missing
    from the hypothesis. ``cer`` is then computed over all pairs.

    Args:
        hypothesis: Tagged predicted transcriptions.
        groundtruth: Tagged reference transcriptions, aligned with ``hypothesis``.
        **kwargs: Accepted for a uniform metric signature; unused.

    Returns:
        The value of ``cer`` from ``.common`` over the aligned slot values.
        NOTE(review): behaviour when no reference slots exist depends on
        ``cer([], [])`` — confirm in ``.common``.
    """
    value_hyps, value_refs = _closest_slot_values(hypothesis, groundtruth, cer)
    return cer(value_hyps, value_refs)


def slot_value_wer(hypothesis: List[str], groundtruth: List[str], **kwargs) -> float:
    """Word error rate between aligned slot values.

    Each reference slot value is paired with the closest (by ``wer``)
    hypothesis value of the same slot type, or ``""`` if that type is missing
    from the hypothesis. ``wer`` is then computed over all pairs.

    Args:
        hypothesis: Tagged predicted transcriptions.
        groundtruth: Tagged reference transcriptions, aligned with ``hypothesis``.
        **kwargs: Accepted for a uniform metric signature; unused.

    Returns:
        The value of ``wer`` from ``.common`` over the aligned slot values.
        NOTE(review): behaviour when no reference slots exist depends on
        ``wer([], [])`` — confirm in ``.common``.
    """
    value_hyps, value_refs = _closest_slot_values(hypothesis, groundtruth, wer)
    return wer(value_hyps, value_refs)
def slot_edit_f1(
    hypothesis: List[str], groundtruth: List[str], loop_over_all_slot: bool, **kwargs
) -> float:
    """Micro-averaged slot edit F1 based on exact slot-value matches.

    Counts are pooled over all (hypothesis, groundtruth) pairs. For each
    reference value of slot type ``t``:

    * The hypothesis has no slot of type ``t``: one false negative.
    * An identical string is among the hypothesis values of type ``t``:
      one true positive.
    * Otherwise: one false negative AND one false positive (a wrong value).

    Surplus hypothesis values of a type present in the reference are not
    counted. When ``loop_over_all_slot`` is true, every hypothesis value
    whose slot type never occurs in the reference also counts as a false
    positive.

    Args:
        hypothesis: Tagged predicted transcriptions.
        groundtruth: Tagged reference transcriptions, aligned with ``hypothesis``.
        loop_over_all_slot: Also penalise hypothesis-only slot types.
        **kwargs: Accepted for a uniform metric signature; unused.

    Returns:
        ``2*TP / (2*TP + FP + FN)``. If nothing is counted at all (no slots
        in scope, including empty input), 1.0 is returned, mirroring
        ``slot_type_f1``'s convention that two slot-free sentences score 1.0.
    """
    total_tp = total_fn = total_fp = 0
    for hyp, ref in zip(hypothesis, groundtruth):
        ref_dict, hyp_dict = get_slot_dict(hyp, ref)

        for slot, ref_values in ref_dict.items():
            hyp_values = hyp_dict.get(slot)
            for ref_v in ref_values:
                if hyp_values is None:
                    total_fn += 1
                elif ref_v in hyp_values:
                    total_tp += 1
                else:
                    total_fn += 1
                    total_fp += 1

        if loop_over_all_slot:
            for slot, hyp_values in hyp_dict.items():
                if slot not in ref_dict:
                    total_fp += len(hyp_values)

    denominator = 2 * total_tp + total_fp + total_fn
    if denominator == 0:
        return 1.0
    return 2 * total_tp / denominator
def slot_edit_f1_full(hypothesis: List[str], groundtruth: List[str], **kwargs) -> float:
    """Slot edit F1 that also counts hypothesis-only slot types.

    Equivalent to ``slot_edit_f1(..., loop_over_all_slot=True)``. Every
    hypothesis value whose slot type does not occur in the reference is
    counted as a false positive.
    """
    return slot_edit_f1(hypothesis, groundtruth, loop_over_all_slot=True, **kwargs)


def slot_edit_f1_part(hypothesis: List[str], groundtruth: List[str], **kwargs) -> float:
    """Slot edit F1 restricted to slot types present in the reference.

    Equivalent to ``slot_edit_f1(..., loop_over_all_slot=False)``. Hypothesis
    slots whose type does not occur in the reference are ignored.
    """
    return slot_edit_f1(hypothesis, groundtruth, loop_over_all_slot=False, **kwargs)