diff --git a/src/tasks/__init__.py b/src/tasks/__init__.py
index e4bd6a1..c580290 100644
--- a/src/tasks/__init__.py
+++ b/src/tasks/__init__.py
@@ -17,6 +17,7 @@
     "flare_fiqasa": flare.FIQASA,
     "flare_ner": flare.NER,
     "flare_finqa": flare.FinQA,
+    "flare_auditqa_zh": flare.AuditQA,
     "flare_convfinqa": flare.ConvFinQA,
     "flare_headlines": flare.Headlines,
     "flare_finer_ord": flare.FinerOrd,
diff --git a/src/tasks/flare.py b/src/tasks/flare.py
index 474c1d0..1f2849b 100644
--- a/src/tasks/flare.py
+++ b/src/tasks/flare.py
@@ -13,6 +13,8 @@
 import re
 from factscore_package.factscorer import FactScorer
 import os
+import jieba
+from rouge_chinese import Rouge
 #from comet import download_model, load_from_checkpoint

 _CITATION = """
@@ -698,6 +700,130 @@ def aggregation(self):
         }


+class QA_zh(Task):
+    VERSION = 1
+    DATASET_NAME = None
+    EVAL_LAST_TURN = True
+
+    def reformulate_turn_req(self, req, turn_request, turn):
+        return req
+
+    def has_training_docs(self):
+        return False
+
+    def has_validation_docs(self):
+        return False
+
+    def has_test_docs(self):
+        return True
+
+    def training_docs(self):
+        return self.dataset["train"]
+
+    def validation_docs(self):
+        return self.dataset["validation"]
+
+    def test_docs(self):
+        return self.dataset["test"]
+
+    def doc_to_text(self, doc):
+        # The "query" field already holds the fully formatted prompt.
+        return doc["query"]
+
+    def doc_to_target(self, doc):
+        return doc["answer"]
+
+    def process_results(self, doc, results):
+        return {
+            "rouge1": (doc["answer"], results[0]),
+            "rouge2": (doc["answer"], results[0]),
+            "rougeL": (doc["answer"], results[0]),
+            "bert_score_f1": (doc["answer"], results[0]),
+        }
+
+    def higher_is_better(self):
+        return {
+            "rouge1": True,
+            "rouge2": True,
+            "rougeL": True,
+            "bert_score_f1": True,
+        }
+
+    def construct_requests(self, doc, ctx):
+        """Uses RequestFactory to construct Requests and returns an iterable of
+        Requests which will be sent to the LM.
+
+        :param doc:
+            The document as returned from training_docs, validation_docs, or test_docs.
+        :param ctx: str
+            The context string, generated by fewshot_context. This includes the natural
+            language description, as well as the few shot examples, and the question
+            part of the document for `doc`.
+        """
+        cont_request = rf.greedy_until(ctx, {"until": None})
+        return cont_request
+
+    def rouge_score_zh(self, items):
+        # Optional user dictionary for jieba's Chinese segmentation; set to your local path.
+        USER_DICT_PATH = "/path/to/vocab.txt"
+        if os.path.exists(USER_DICT_PATH):
+            jieba.load_userdict(USER_DICT_PATH)
+        golds, preds = zip(*items)  # items are (gold, prediction) pairs
+        refs = [' '.join(jieba.cut(gold)) for gold in golds]
+        hyps = [' '.join(jieba.cut(pred)) for pred in preds]
+        filter_hyps = []
+        filter_refs = []
+        for hyp, ref in zip(hyps, refs):
+            if self.is_whitespace_string(hyp) or self.is_whitespace_string(ref):
+                continue
+            if hyp != '' and ref != '':
+                filter_hyps.append(hyp)
+                filter_refs.append(ref)
+        rouge = Rouge()
+        scores = rouge.get_scores(filter_hyps, filter_refs, avg=True, ignore_empty=True)
+        return scores
+
+    def rouge1(self, items):
+        results = self.rouge_score_zh(items)
+        return results["rouge-1"]['f']
+
+    def rouge2(self, items):
+        results = self.rouge_score_zh(items)
+        return results["rouge-2"]['f']
+
+    def rougeL(self, items):
+        results = self.rouge_score_zh(items)
+        return results["rouge-l"]['f']
+
+    def is_whitespace_string(self, s):
+        return s.isspace()
+
+    def bert_score(self, items):
+        # Compute BERTScore once and cache it, since aggregation() calls
+        # this for every metric on the same items.
+        if getattr(self, "_cache_bertscore", None) is None:
+            golds, preds = zip(*items)
+            bertscore = evaluate.load("evaluate-metric/bertscore")
+            self._cache_bertscore = bertscore.compute(
+                predictions=preds,
+                references=golds,
+                model_type="bert-base-chinese",
+            )
+        return self._cache_bertscore
+
+    def bert_score_f1(self, items):
+        res = self.bert_score(items)
+        return sum(res["f1"]) / len(res["f1"])
+
+    def aggregation(self):
+        return {
+            "rouge1": self.rouge1,
+            "rouge2": self.rouge2,
+            "rougeL": self.rougeL,
+            "bert_score_f1": self.bert_score_f1,
+        }
+
+
 class FPB(Classification):
     DATASET_PATH = "chancefocus/flare-fpb"

@@ -787,6 +913,13 @@ class FinQA(QA):
     DATASET_PATH = "chancefocus/flare-finqa"


+class AuditQA(QA_zh):
+    # Path to the Chinese AuditQA dataset evaluated with the QA_zh pipeline.
+    # AuditQA documents provide a "query" prompt and an "answer" reference.
+    # Set DATASET_PATH to wherever the dataset lives locally or on the hub.
+    DATASET_PATH = "/path/to/dataset"
+
+
 class StockMovement(Classification):
     DATASET_NAME = None
     CALCULATE_MCC = True
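
For context, here is a minimal, standalone sketch of the Chinese ROUGE path that QA_zh.rouge_score_zh implements, using only the calls this diff introduces (jieba segmentation plus rouge_chinese scoring). The two (gold, prediction) pairs are invented purely for illustration; in the harness the pairs come from QA_zh.process_results.

# Standalone sketch of the QA_zh Chinese ROUGE scoring (illustrative data only).
import jieba
from rouge_chinese import Rouge

# Hypothetical (gold, prediction) pairs, mirroring what process_results emits.
items = [
    ("审计的目标是评价财务报表的公允性。", "审计目标是评价财务报表是否公允。"),
    ("内部控制缺陷应当及时整改。", "内部控制缺陷需要及时整改。"),
]

# Segment both sides with jieba and join on spaces, since rouge_chinese
# expects whitespace-tokenized input.
refs = [' '.join(jieba.cut(gold)) for gold, _ in items]
hyps = [' '.join(jieba.cut(pred)) for _, pred in items]

scores = Rouge().get_scores(hyps, refs, avg=True, ignore_empty=True)
print(scores["rouge-1"]["f"], scores["rouge-2"]["f"], scores["rouge-l"]["f"])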