From 3e1ab18da9816d0460206a832d9033187e70fee5 Mon Sep 17 00:00:00 2001 From: lemonade1258 <1779716932@qq.com> Date: Tue, 22 Oct 2024 23:01:04 +0800 Subject: [PATCH 1/2] feat: add QA_zh class in flare.py and update __init__.py --- .../.ipynb_checkpoints/__init__-checkpoint.py | 132 ++ .../.ipynb_checkpoints/flare-checkpoint.py | 1552 +++++++++++++++++ src/tasks/__init__.py | 1 + src/tasks/flare.py | 133 ++ 4 files changed, 1818 insertions(+) create mode 100644 src/tasks/.ipynb_checkpoints/__init__-checkpoint.py create mode 100644 src/tasks/.ipynb_checkpoints/flare-checkpoint.py diff --git a/src/tasks/.ipynb_checkpoints/__init__-checkpoint.py b/src/tasks/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000..c580290 --- /dev/null +++ b/src/tasks/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,132 @@ +from pprint import pprint +from typing import List, Union + +import json +import lm_eval.base + +from . import flare + +TASK_REGISTRY = { + "flare_es_financees": flare.ESFINANCEES, + "flare_es_multifin": flare.ESMultiFin, + "flare_es_efp": flare.ESEFP, + "flare_es_efpa": flare.ESEFPA, + "flare_es_fns": flare.ESFNS, + "flare_es_tsa": flare.ESTSA, + "flare_fpb": flare.FPB, + "flare_fiqasa": flare.FIQASA, + "flare_ner": flare.NER, + "flare_finqa": flare.FinQA, + "flare_auditqa_zh": flare.AuditQA, + "flare_convfinqa": flare.ConvFinQA, + "flare_headlines": flare.Headlines, + "flare_finer_ord": flare.FinerOrd, + "flare_fomc": flare.FOMC, + "flare_german": flare.German, + "flare_australian": flare.Australian, + "flare_fomc": flare.FOMC, + "flare_ectsum": flare.ECTSUM, + "flare_edtsum": flare.EDTSUM, + "flare_finarg_ecc_auc": flare.FinargECCAUC, + "flare_finarg_ecc_arc": flare.FinargECCARC, + "flare_cd": flare.CD, + "flare_multifin_en": flare.MultiFinEN, + "flare_tsa": flare.TSA, + "flare_cfa": flare.CFA, + "flare_ma": flare.MA, + "flare_causal20_sc": flare.Causal20SC, + "flare_finarg_ecc_arc": flare.FINARGECCARC, + "flare_finarg_ecc_auc": flare.FINARGECCAUC, + "flare_mlesg": flare.MLESG, + "flare_fnxl": flare.FNXL, + "flare_fsrl": flare.FSRL, + "flare_tatqa": flare.TATQA, + "flare_finred": flare.FinRED, + "flare_cra_lendingclub": flare.lendingclub, + "flare_cra_ccf": flare.ccf, + "flare_cra_ccfraud": flare.ccfraud, + "flare_cra_polish": flare.polish, + "flare_cra_taiwan": flare.taiwan, + "flare_cra_portoseguro": flare.portoseguro, + "flare_cra_travelinsurace": flare.travelinsurace, + "flare_sm_bigdata": flare.StockMovementBigData, + "flare_sm_acl": flare.StockMovementACL, + "flare_sm_cikm": flare.StockMovementCIKM, + "flare_en_finterm": flare.FINTERM, + "flare_en_acronym": flare.ACRONYM, + **flare.SM_TASKS, + "flare_finarg_ecc_auc_test": flare.FINARGECCAUC_test, + "flare_edtsum_test": flare.EDTSUM_test, +} + +ALL_TASKS = sorted(list(TASK_REGISTRY)) + +_EXAMPLE_JSON_PATH = "split:key:/absolute/path/to/data.json" + + +def add_json_task(task_name): + """Add a JSON perplexity task if the given task name matches the + JSON task specification. + + See `json.JsonPerplexity`. 
+ """ + if not task_name.startswith("json"): + return + + def create_json_task(): + splits = task_name.split("=", 1) + if len(splits) != 2 or not splits[1]: + raise ValueError( + "json tasks need a path argument pointing to the local " + "dataset, specified like this: json=" + + _EXAMPLE_JSON_PATH + + ' (if there are no splits, use "train")' + ) + + json_path = splits[1] + if json_path == _EXAMPLE_JSON_PATH: + raise ValueError( + "please do not copy the example path directly, but substitute " + "it with a path to your local dataset" + ) + return lambda: json.JsonPerplexity(json_path) + + TASK_REGISTRY[task_name] = create_json_task() + + +def get_task(task_name): + try: + add_json_task(task_name) + return TASK_REGISTRY[task_name] + except KeyError: + print("Available tasks:") + pprint(TASK_REGISTRY) + raise KeyError(f"Missing task {task_name}") + + +def get_task_name_from_object(task_object): + for name, class_ in TASK_REGISTRY.items(): + if class_ is task_object: + return name + + # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting + return ( + task_object.EVAL_HARNESS_NAME + if hasattr(task_object, "EVAL_HARNESS_NAME") + else type(task_object).__name__ + ) + + +def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]): + task_name_dict = { + task_name: get_task(task_name)() + for task_name in task_name_list + if isinstance(task_name, str) + } + task_name_from_object_dict = { + get_task_name_from_object(task_object): task_object + for task_object in task_name_list + if not isinstance(task_object, str) + } + assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys())) + return {**task_name_dict, **task_name_from_object_dict} diff --git a/src/tasks/.ipynb_checkpoints/flare-checkpoint.py b/src/tasks/.ipynb_checkpoints/flare-checkpoint.py new file mode 100644 index 0000000..1f2849b --- /dev/null +++ b/src/tasks/.ipynb_checkpoints/flare-checkpoint.py @@ -0,0 +1,1552 @@ +""" +FLARE +""" +from lm_eval.base import Task, rf +from lm_eval.metrics import mean, bleu, chrf, ter +import numpy as np +from .utils import process_text +from .zhutils import process_zhtext +from seqeval.metrics import f1_score as entity_score +from sklearn.metrics import f1_score, matthews_corrcoef, mean_squared_error +from bart_score import BARTScorer +import evaluate +import re +from factscore_package.factscorer import FactScorer +import os +import jieba +from rouge_chinese import Rouge +#from comet import download_model, load_from_checkpoint + +_CITATION = """ +@misc{xie2023pixiu, + title={PIXIU: A Large Language Model, Instruction Data and Evaluation Benchmark for Finance}, + author={Qianqian Xie and Weiguang Han and Xiao Zhang and Yanzhao Lai and Min Peng and Alejandro Lopez-Lira and Jimin Huang}, + year={2023}, + eprint={2306.05443}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +""" + + +class Classification(Task): + CALCULATE_MCC = True + LOWER_CASE = True + VERSION = 1 + EVAL_LAST_TURN = True + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return True + + def has_validation_docs(self): + return True + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the 
LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + cont_request = rf.greedy_until(ctx, {"until": None}) + return cont_request + + def doc_to_decontamination_query(self, doc): + return doc["text"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def doc_to_target(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["answer"] + + def process_results(self, doc, results): + gold: str = doc["choices"][doc["gold"]] + if self.LOWER_CASE: + gold = gold.lower() + ini_result = results[0].strip() + if self.LOWER_CASE: + ini_result = ini_result.lower() + + result = None + for choice in doc["choices"]: + if self.LOWER_CASE: + choice = choice.lower() + if choice in ini_result: + result = choice + break + if result is None: + result = "missing" + + acc = 1.0 if gold == result else 0.0 + + results = { + "acc": acc, + "missing": int(result == "missing"), + "f1": (result, gold), + "macro_f1": (result, gold), + } + + if self.CALCULATE_MCC: + results["mcc"] = (result, gold) + + return results + + def higher_is_better(self): + metrics = { + "acc": True, + "f1": True, + "macro_f1": True, + "missing": False, + } + if self.CALCULATE_MCC: + metrics["mcc"] = True + return metrics + + def weighted_f1(self, items): + preds, golds = zip(*items) + labels = list(set(golds)) + preds = np.array(preds) + golds = np.array(golds) + f1 = f1_score(golds, preds, average="weighted", labels=labels) + return f1 + + def macro_f1(self, items): + preds, golds = zip(*items) + labels = list(set(golds)) + preds = np.array(preds) + golds = np.array(golds) + f1 = f1_score(golds, preds, average="macro", labels=labels) + return f1 + + def matthews_corrcoef(self, items): + preds, golds = zip(*items) + labels = {label: i for i, label in enumerate(list(set(golds)))} + preds = [labels.get(pred, -1) for pred in preds] + golds = [labels.get(gold, -1) for gold in golds] + return matthews_corrcoef(golds, preds) + + def aggregation(self): + metrics = { + "acc": mean, + "missing": mean, + "f1": self.weighted_f1, + "macro_f1": self.macro_f1, + } + if self.CALCULATE_MCC: + metrics["mcc"] = self.matthews_corrcoef + return metrics + + +class SequentialLabeling(Task): + VERSION = 1 + DATASET_NAME = None + LMAP = {"O": 0} + EVAL_LAST_TURN = True + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. 
+ return doc["query"] + + def doc_to_target(self, doc): + return "\nAnswer: " + doc["answer"] + + def process_results(self, doc, results): + return { + "entity_f1": (doc["label"], results[0], doc["token"]), + "f1": (doc["label"], results[0], doc["token"]), + } + + def higher_is_better(self): + return { + "f1": True, + "entity_f1": True, + } + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + cont_request = rf.greedy_until(ctx, {"until": None}) + return cont_request + + def process_result(self, pred, gold, tokens): + format_pred = ["O"] * len(gold) + for index, pre in enumerate(pred.split("\n")[: len(tokens)]): + try: + word, label = pre.split(":") + except: + continue + if word == tokens[index] and label in self.LMAP.keys(): + format_pred[index] = label + return format_pred + + def entity_f1(self, items): + golds, preds, tokens = zip(*items) + + list_preds = [ + self.process_result(pred, gold, token) + for pred, gold, token in zip(preds, golds, tokens) + ] + f1 = entity_score(golds, list_preds) + return f1 + + def process_label_result(self, pred, gold, tokens): + format_pred = [-1] * len(gold) + for index, pre in enumerate(pred.split("\n")[: len(tokens)]): + try: + word, label = pre.split(":") + except: + continue + if word == tokens[index]: + format_pred[index] = self.LMAP.get(label, -1) + return format_pred + + def label_f1(self, items): + golds, preds, tokens = zip(*items) + + list_preds = [ + self.process_label_result(pred, gold, token) + for pred, gold, token in zip(preds, golds, tokens) + ] + list_preds = [item for sublist in list_preds for item in sublist] + golds = [self.LMAP[item] for sublist in golds for item in sublist] + f1 = f1_score(golds, list_preds, average="weighted") + return f1 + + def aggregation(self): + return { + "entity_f1": self.entity_f1, + "f1": self.label_f1, + } + + +class AbstractiveSummarization(Task): + VERSION = 1 + DATASET_NAME = None + EVAL_LAST_TURN = True + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def doc_to_target(self, doc): + return doc["answer"] + + def process_results(self, doc, results): + return { + "rouge1": (doc["answer"], results[0]), + "rouge2": (doc["answer"], results[0]), + "rougeL": (doc["answer"], results[0]), + "bert_score_f1": (doc["answer"], results[0]), + "bart_score": (doc["answer"], results[0]), + } + + def higher_is_better(self): + return { + "rouge1": True, + "rouge2": True, + "rougeL": True, + "bert_score_f1": True, + "bart_score": True, + } + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. 
+ + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + cont_request = rf.greedy_until(ctx, {"until": None}) + return cont_request + + def rouge_score(self, items): + golds, preds = zip(*items) + rouge = evaluate.load("rouge") + results = rouge.compute(predictions=preds, references=golds) + return results + + def rouge1(self, items): + results = self.rouge_score(items) + return results["rouge1"] + + def rouge2(self, items): + results = self.rouge_score(items) + return results["rouge2"] + + def rougeL(self, items): + results = self.rouge_score(items) + return results["rougeL"] + + def bert_score(self, items): + if getattr(self, "_cache_bertscore", None) is None: + golds, preds = zip(*items) + bertscore = evaluate.load("evaluate-metric/bertscore") + self._cache_bertscore = bertscore.compute( + predictions=preds, + references=golds, + model_type="bert-base-multilingual-cased", + ) + return self._cache_bertscore + else: + return self._cache_bertscore + + def bert_score_f1(self, items): + res = self.bert_score(items) + return sum(res["f1"]) / len(res["f1"]) + + def bart_score(self, items): + golds, preds = zip(*items) + bart_scorer = BARTScorer(device="cuda", checkpoint="facebook/bart-large-cnn") + bart_scorer.load(path="src/metrics/BARTScore/bart_score.pth") + res = bart_scorer.score(srcs=preds, tgts=golds, batch_size=8) + return sum(res) / len(res) + + def aggregation(self): + return { + "rouge1": self.rouge1, + "rouge2": self.rouge2, + "rougeL": self.rougeL, + "bert_score_f1": self.bert_score_f1, + "bart_score": self.bart_score, + } + + +class ExtractiveSummarization(Task): + VERSION = 1 + DATASET_NAME = None + EVAL_LAST_TURN = True + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def doc_to_target(self, doc): + return doc["answer"] + + def process_results(self, doc, results): + return { + "rouge1": (doc["label"], doc["text"], results[0]), + "rouge2": (doc["label"], doc["text"], results[0]), + "rougeL": (doc["label"], doc["text"], results[0]), + "bert_score_f1": (doc["label"], doc["text"], results[0]), + "bart_score": (doc["label"], doc["text"], results[0]), + } + + def higher_is_better(self): + return { + "rouge1": True, + "rouge2": True, + "rougeL": True, + "bert_score_f1": True, + "bart_score": True, + } + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. 
+ """ + cont_request = rf.greedy_until(ctx, {"until": None}) + return cont_request + + def get_sum(self, labels, texts): + summ = [] + for label, text in zip(labels, texts): + text = text.split("\n") + new_text = "\n".join( + [ + text[index] + for index in range(len(text)) + if index < len(label) and label[index] == 1 + ] + ) + summ.append(new_text) + return summ + + def rouge_score(self, items): + golds, texts, preds = zip(*items) + golds = self.get_sum(golds, texts) + preds = self.get_sum([val.split("\n") for val in preds], texts) + rouge = evaluate.load("rouge") + results = rouge.compute(predictions=preds, references=golds) + return results + + def rouge1(self, items): + results = self.rouge_score(items) + return results["rouge1"] + + def rouge2(self, items): + results = self.rouge_score(items) + return results["rouge2"] + + def rougeL(self, items): + results = self.rouge_score(items) + return results["rougeL"] + + def bert_score(self, items): + if getattr(self, "_cache_bertscore", None) is None: + golds, texts, preds = zip(*items) + golds = self.get_sum(golds, texts) + preds = self.get_sum([val.split("\n") for val in preds], texts) + + bertscore = evaluate.load("evaluate-metric/bertscore") + self._cache_bertscore = bertscore.compute( + predictions=preds, + references=golds, + model_type="bert-base-multilingual-cased", + ) + return self._cache_bertscore + else: + return self._cache_bertscore + + def bert_score_f1(self, items): + res = self.bert_score(items) + return sum(res["f1"]) / len(res["f1"]) + + def bart_score(self, items): + golds, texts, preds = zip(*items) + golds = self.get_sum(golds, texts) + preds = self.get_sum([val.split("\n") for val in preds], texts) + + bart_scorer = BARTScorer(device="cuda:0", checkpoint="facebook/bart-large-cnn") + bart_scorer.load(path="src/metrics/BARTScore/bart_score.pth") + res = bart_scorer.score(srcs=preds, tgts=golds, batch_size=8) + return sum(res) / len(res) + + def aggregation(self): + return { + "rouge1": self.rouge1, + "rouge2": self.rouge2, + "rougeL": self.rougeL, + "bert_score_f1": self.bert_score_f1, + "bart_score": self.bart_score, + } + + +class RelationExtraction(Task): + VERSION = 1 + DATASET_NAME = None + EVAL_LAST_TURN = True + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def doc_to_target(self, doc): + return doc["answer"] + + def process_results(self, doc, results): + return { + "precision": (doc["label"], results[0]), + "recall": (doc["label"], results[0]), + "f1": (doc["label"], results[0]), + } + + def higher_is_better(self): + return { + "precision": True, + "recall": True, + "f1": True, + } + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. 
+ """ + cont_request = rf.greedy_until(ctx, {"until": None}) + return cont_request + + def process(self, items): + golds, preds = zip(*items) + + all_golds = [] + all_preds = [] + + for gold, pred in zip(golds, preds): + all_golds.extend(gold) + pred = pred.split("\n") + all_preds.extend(pred) + + return set(all_golds), set(all_preds) + + def precision(self, items): + golds, preds = self.process(items) + tp = golds & preds + prec = len(tp) / len(preds) + return prec + + def recall(self, items): + golds, preds = self.process(items) + tp = golds & preds + rec = len(tp) / len(golds) + return rec + + def cal_f1(self, items): + prec = self.precision(items) + rec = self.recall(items) + if prec + rec == 0.0: + return 0.0 + return 2 * (prec * rec) / (prec + rec) + + def aggregation(self): + return { + "precision": self.precision, + "recall": self.recall, + "f1": self.cal_f1, + } + + +class QA(Task): + VERSION = 1 + DATASET_NAME = None + EVAL_LAST_TURN = True + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return True + + def has_validation_docs(self): + return True + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def should_decontaminate(self): + return True + + def doc_to_decontamination_query(self, doc): + return doc["text"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + cont_request = rf.greedy_until(ctx, {"until": None}) + return cont_request + + def doc_to_target(self, doc): + return doc["answer"] + + def process_results(self, doc, results): + gold = doc["answer"] + + acc = 1.0 if results[0].strip() == gold else 0.0 + + return { + "acc": acc, + } + + def higher_is_better(self): + return { + "acc": True, + } + + def aggregation(self): + return { + "acc": mean, + } + + +class QA_zh(Task): + VERSION = 1 + DATASET_NAME = None + EVAL_LAST_TURN = True + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. 
+ return doc["query"] + + def doc_to_target(self, doc): + return doc["answer"] + + def process_results(self, doc, results): + return { + "rouge1": (doc["answer"], results[0]), + "rouge2": (doc["answer"], results[0]), + "rougeL": (doc["answer"], results[0]), + "bert_score_f1": (doc["answer"], results[0]), + } + + def higher_is_better(self): + return { + "rouge1": True, + "rouge2": True, + "rougeL": True, + "bert_score_f1": True, + } + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + cont_request = rf.greedy_until(ctx, {"until": None}) + return cont_request + + def rouge_score_zh(self, items): + # Path to the dictionary for Chinese word segmentation + # Specifically designed for Chinese evaluation + # Set this to your local dictionary path. + USER_DICT_PATH = "/path/to/vocab.txt" + hyps, refs = map(list, zip(*[[' '.join(jieba.cut(d[0])), ' '.join(jieba.cut(d[1]))] for d in items])) + filter_hyps = [] + filter_refs = [] + for i in range(len(hyps)): + hyp = hyps[i] + ref = refs[i] + if self.is_whitespace_string(hyp) or self.is_whitespace_string(ref): + continue + if hyp != '' and ref != '': + filter_hyps.append(hyp) + filter_refs.append(ref) + rouge = Rouge() + scores = rouge.get_scores(filter_hyps, filter_refs, avg=True, ignore_empty=True) + return scores + + def rouge1(self, items): + results = self.rouge_score_zh(items) + return results["rouge-1"]['f'] + + def rouge2(self, items): + results = self.rouge_score_zh(items) + return results["rouge-2"]['f'] + + def rougeL(self, items): + results = self.rouge_score_zh(items) + return results["rouge-l"]['f'] + + def is_whitespace_string(self, s): + return s.isspace() + + def bert_score(self, items): + if getattr(self, "_cache_bertscore", None) is None: + golds, preds = zip(*items) + bertscore = evaluate.load("evaluate-metric/bertscore") + self._cache_bertscore = bertscore.compute( + predictions=preds, + references=golds, + model_type="bert-base-chinese", + ) + return self._cache_bertscore + else: + return self._cache_bertscore + + def bert_score_f1(self, items): + res = self.bert_score(items) + return sum(res["f1"]) / len(res["f1"]) + + def aggregation(self): + return { + "rouge1": self.rouge1, + "rouge2": self.rouge2, + "rougeL": self.rougeL, + "bert_score_f1": self.bert_score_f1, + } + + +class FPB(Classification): + DATASET_PATH = "chancefocus/flare-fpb" + + +class FIQASA(Classification): + DATASET_PATH = "chancefocus/flare-fiqasa" + + +class NER(Task): + VERSION = 1 + DATASET_PATH = "chancefocus/flare-ner" + DATASET_NAME = None + EVAL_LAST_TURN = True + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return True + + def has_validation_docs(self): + return True + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def should_decontaminate(self): + return True + + def doc_to_decontamination_query(self, doc): + return doc["text"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the 
document example. + return doc["query"] + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + cont_request = rf.greedy_until(ctx, {"until": None}) + return cont_request + + def doc_to_target(self, doc): + return doc["answer"] + + def process_results(self, doc, results): + text = doc["text"] + pred = process_text(results[0], text) + + return {"entity_f1": (pred, doc["label"], results[0])} + + def higher_is_better(self): + return { + "entity_f1": True, + } + + @classmethod + def entity_f1(cls, items): + preds, golds, _ = zip(*items) + f1 = entity_score(golds, preds) + return f1 + + def aggregation(self): + return { + "entity_f1": self.entity_f1, + } + + +class FinQA(QA): + DATASET_PATH = "chancefocus/flare-finqa" + + +class AuditQA(QA_zh): + # Path to the dataset for the AuditQA class + # Specifically designed for a Chinese dataset + # Set this to the appropriate path for the dataset + DATASET_PATH = "/path/to/dataset" + + +class StockMovement(Classification): + DATASET_NAME = None + CALCULATE_MCC = True + CHOICE_DICT = { + "rise": ["yes", "positive"], + "fall": ["no", "negative", "neutral"], + } + DEFAULT = "fall" + + def process_results(self, doc, results): + gold: str = doc["choices"][doc["gold"]] + if self.LOWER_CASE: + gold = gold.lower() + ini_result = results[0].strip() + if self.LOWER_CASE: + ini_result = ini_result.lower() + + result = None + for choice in doc["choices"]: + if self.LOWER_CASE: + choice = choice.lower() + if choice in ini_result or any( + [val in ini_result for val in self.CHOICE_DICT[choice]] + ): + result = choice + break + if result is None: + result = self.DEFAULT + + acc = 1.0 if gold == result else 0.0 + + results = { + "acc": acc, + "missing": int(result == "missing"), + "f1": (result, gold), + "macro_f1": (result, gold), + } + + if self.CALCULATE_MCC: + results["mcc"] = (result, gold) + + return results + + +class StockMovementBigData(StockMovement): + DATASET_PATH = "chancefocus/flare-sm-bigdata" + + +class StockMovementACL(StockMovement): + DATASET_PATH = "chancefocus/flare-sm-acl" + + +class StockMovementCIKM(StockMovement): + DATASET_PATH = "chancefocus/flare-sm-cikm" + + +SM_TASKS = { + "flare_sm_bigdata": StockMovementBigData, + "flare_sm_acl": StockMovementACL, + "flare_sm_cikm": StockMovementCIKM, +} + + +class Headlines(Classification): + DATASET_PATH = "chancefocus/flare-headlines" + + def process_results(self, doc, results): + gold = doc["gold"] + + return { + "avg_f1": (doc["label_type"], int(results[0].strip() != "Yes"), gold, results), + } + + def higher_is_better(self): + return { + "avg_f1": True, + } + + @classmethod + def label_avg(cls, items): + labels, preds, golds, rels = zip(*items) + label_set = set(labels) + labels = np.array(labels) + preds = np.array(preds) + golds = np.array(golds) + all_f1s = [] + for l in label_set: + pds = preds[labels == l] + gds = golds[labels == l] + f1 = f1_score(gds, pds, average="weighted", labels=[0, 1]) + all_f1s.append(f1) + return np.mean(all_f1s) + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the 
LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + cont_request = rf.greedy_until(ctx, {"until": None}) + return cont_request + + def aggregation(self): + return { + "avg_f1": self.label_avg, + } + + +class FinerOrd(SequentialLabeling): + DATASET_PATH = "chancefocus/flare-finer-ord" + LMAP = { + "O": 0, + "B-PER": 1, + "I-PER": 2, + "B-LOC": 3, + "I-LOC": 4, + "B-ORG": 5, + "I-ORG": 6, + } + + +class FOMC(Classification): + DATASET_PATH = "chancefocus/flare-fomc" + + +class German(StockMovement): + DATASET_PATH = "chancefocus/flare-german" + CHOICE_DICT = { + "good": ["yes", "positive"], + "bad": ["no", "negative", "neutral"], + } + DEFAULT = "good" + + +class Australian(StockMovement): + DATASET_PATH = "chancefocus/flare-australian" + CHOICE_DICT = { + "good": ["yes", "positive"], + "bad": ["no", "negative", "neutral"], + } + DEFAULT = "good" + + +class ECTSUM(ExtractiveSummarization): + DATASET_PATH = "chancefocus/flare-ectsum" + + +class EDTSUM(AbstractiveSummarization): + DATASET_PATH = "chancefocus/flare-edtsum" + + +class EDTSUM_test(AbstractiveSummarization): + DATASET_PATH = "TheFinAI/flare-edtsum_test" + + +class ConvFinQA(QA): + DATASET_PATH = "chancefocus/flare-convfinqa" + + def reformulate_turn_req(self, req, turn_request, turn): + if turn == 0: + return req + pre_answers = {f"answer{i}": turn_request[i][0] for i in range(turn)} + if pre_answers: + req.args = tuple([req.args[0].format(**pre_answers)] + list(req.args[1:])) + return req + + +class TSA(Task): + VERSION = 1 + DATASET_PATH = "chancefocus/flare-tsa" + DATASET_NAME = None + EVAL_LAST_TURN = True + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def doc_to_target(self, doc): + return "\nAnswer: " + str(doc["answer"]) + + def process_results(self, doc, results): + pred = results[0].split("\n")[0] + pred = re.findall(r'[0-9]+(?:\.[0-9]+)?', pred) + missing = 0 + if not pred: + pred = -100.0 + missing = 1 + else: + pred = pred[0] + pred = float(pred) + return { + "rmse": (doc["answer"], pred), + "missing": missing + } + + def higher_is_better(self): + return { + "rmse": False, + } + + def construct_requests(self, doc, ctx): + """ + Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. 
+ """ + cont_request = rf.greedy_until(ctx, {"until": "Answer:"}) + return cont_request + + def rmse(self, items): + golds, preds = zip(*items) + fgolds, fpreds = [], [] + for gold, pred in zip(golds, preds): + if pred == -100.0: + continue + fgolds.append(gold) + fpreds.append(max(min(pred, 1.0), -1.0)) + rmse = mean_squared_error(fgolds, fpreds, squared=True) + + return rmse + + def aggregation(self): + return { + "rmse": self.rmse, + "missing": mean, + } + + + +class CFA(Classification): + DATASET_PATH = "chancefocus/flare-cfa" + LOWER_CASE = False + + +class FINARGECCARC(Classification): + DATASET_PATH = "chancefocus/flare-finarg-ecc-arc" + + +class FINARGECCAUC(Classification): + DATASET_PATH = "chancefocus/flare-finarg-ecc-auc" + + +class FINARGECCAUC_test(Classification): + DATASET_PATH = "TheFinAI/flare-finarg-ecc-auc_test" + + +class MLESG(Classification): + DATASET_PATH = "chancefocus/flare-mlesg" + + +class FSRL(SequentialLabeling): + DATASET_PATH = "chancefocus/flare-fsrl" + LMAP = {key: index for index, key in enumerate(['O', 'I-QUANT', 'B-QUANT', 'I-TIME', 'B-TIME', 'I-MANNER', 'B-MANNER', 'I-THEME', 'B-THEME', 'I-VALUE', 'B-VALUE', 'I-WHOLE', 'B-WHOLE', 'I-LOCATION', 'B-LOCATION', 'I-AGENT', 'B-AGENT', 'I-CAUSE', 'B-CAUSE', 'I-SOURCE', 'B-SOURCE', 'I-REF_TIME', 'B-REF_TIME', 'I-CONDITION', 'B-CONDITION'])} + +class CFA(Classification): + DATASET_PATH = "chancefocus/flare-cfa" + +class FinargECCAUC(Classification): + DATASET_PATH = "chancefocus/flare-finarg-ecc-auc" + +class FinargECCARC(Classification): + DATASET_PATH = "chancefocus/flare-finarg-ecc-arc" + +class CD(SequentialLabeling): + DATASET_PATH = "chancefocus/flare-cd" + LMAP = {key: index for index, key in enumerate(['O', 'I-CAUSE', 'B-CAUSE', 'I-EFFECT', 'B-EFFECT'])} + +class MultiFinEN(Classification): + DATASET_PATH = "chancefocus/flare-multifin-en" + +class MA(Classification): + DATASET_PATH = "chancefocus/flare-ma" + +class Causal20SC(Classification): + DATASET_PATH = "chancefocus/flare-causal20-sc" + +class FNXL(SequentialLabeling): + DATASET_PATH = "chancefocus/flare-fnxl" + LMAP = {'B-BusinessCombinationContingentConsiderationArrangementsRangeOfOutcomesValueHigh': 140, 'B-VariableInterestEntityOwnershipPercentage': 646, 'B-GainLossOnDispositionOfAssets1': 119, 'B-IndefiniteLivedIntangibleAssetsExcludingGoodwill': 46, 'B-MarketingAndAdvertisingExpense': 269, 'B-ReportingUnitPercentageOfFairValueInExcessOfCarryingAmount': 142, 'B-CapitalizedComputerSoftwareNet': 91, 'B-BusinessCombinationConsiderationTransferredEquityInterestsIssuedAndIssuable': 183, 'B-LitigationSettlementExpense': 115, 'B-DefinedBenefitPlanExpectedAmortizationOfGainLossNextFiscalYear': 639, 'B-DeferredCompensationArrangementWithIndividualCompensationExpense': 15, 'B-ReclassificationFromAociCurrentPeriodTax': 152, 'B-OtherComprehensiveIncomeLossBeforeReclassificationsTax': 694, 'B-PreferredStockDividendsPerShareDeclared': 236, 'B-CapitalExpendituresIncurredButNotYetPaid': 344, 'B-DeferredCompensationArrangementWithIndividualContributionsByEmployer': 560, 'B-SeveranceCosts1': 311, 'B-InterestExpense': 784, 'B-SaleOfStockConsiderationReceivedOnTransaction': 76, 'B-LineOfCreditFacilityInterestRateAtPeriodEnd': 822, 'B-SharesIssuedPricePerShare': 137, 'B-EquityMethodInvestmentDifferenceBetweenCarryingAmountAndUnderlyingEquity': 63, 'B-EquitySecuritiesFvNi': 30, 'B-RightOfUseAssetObtainedInExchangeForOperatingLeaseLiability': 118, 'B-DefinedBenefitPlanFundedStatusOfPlan': 547, 
'B-SharebasedCompensationArrangementBySharebasedPaymentAwardPurchasePriceOfCommonStockPercent': 323, 'B-TaxCutsAndJobsActOf2017IncomeTaxExpenseBenefit': 256, 'B-LongtermDebtWeightedAverageInterestRate': 364, 'B-ImpairmentOfIntangibleAssetsFinitelived': 71, 'B-ProceedsFromLinesOfCredit': 496, 'B-LongTermPurchaseCommitmentAmount': 701, 'B-DebtInstrumentFairValue': 335, 'B-RestructuringAndRelatedCostCostIncurredToDate1': 52, 'B-ShareBasedCompensationArrangementByShareBasedPaymentAwardEquityInstrumentsOtherThanOptionsVestedInPeriod': 581, 'B-FiniteLivedIntangibleAssetsAccumulatedAmortization': 143, 'B-StockRepurchasedAndRetiredDuringPeriodValue': 330, 'B-BusinessCombinationProFormaInformationRevenueOfAcquireeSinceAcquisitionDateActual': 77, 'B-ClassOfWarrantOrRightExercisePriceOfWarrantsOrRights1': 361, 'B-BusinessAcquisitionPurchasePriceAllocationGoodwillExpectedTaxDeductibleAmount': 550, 'B-OperatingLossCarryforwardsValuationAllowance': 173, 'B-BusinessAcquisitionEquityInterestsIssuedOrIssuableNumberOfSharesIssued': 32, 'B-DefinedContributionPlanMaximumAnnualContributionsPerEmployeePercent': 45, 'B-ContractWithCustomerLiabilityCurrent': 2, 'B-IncomeLossFromContinuingOperationsBeforeIncomeTaxesForeign': 474, 'B-FiniteLivedIntangibleAssetsAmortizationExpenseYearThree': 1306, 'B-DefinedBenefitPlanUltimateHealthCareCostTrendRate1': 62, 'B-DefinedBenefitPlanRecognizedNetGainLossDueToSettlements1': 317, 'B-UnrecognizedTaxBenefitsInterestOnIncomeTaxesExpense': 448, 'B-ForeignCurrencyTransactionGainLossRealized': 132, 'B-DeferredTaxAssetsOperatingLossCarryforwardsSubjectToExpiration': 262, 'B-RetainedEarningsAccumulatedDeficit': 174, 'B-ProceedsFromIssuanceOfCommonStock': 209, 'B-EmployeeServiceShareBasedCompensationAllocationOfRecognizedPeriodCostsCapitalizedAmount': 29, 'B-OtherComprehensiveIncomeLossPensionAndOtherPostretirementBenefitPlansTax': 284, 'B-InventoryWriteDown': 465, 'B-RestructuringReserve': 234, 'B-LitigationSettlementAmountAwardedToOtherParty': 42, 'B-DerivativeGainLossOnDerivativeNet': 87, 'B-SharebasedCompensationArrangementBySharebasedPaymentAwardEquityInstrumentsOtherThanOptionsAggregateIntrinsicValueVested': 241, 'B-DerivativeFixedInterestRate': 589, 'B-CashAndCashEquivalentsAtCarryingValue': 257, 'B-ContractWithCustomerAssetNet': 245, 'B-RestructuringAndRelatedCostExpectedCost1': 107, 'B-IncomeTaxHolidayAggregateDollarAmount': 347, 'B-OperatingLeaseCost': 248, 'B-AllowanceForDoubtfulAccountsReceivable': 146, 'B-RepaymentsOfDebt': 416, 'B-InterestPaid': 110, 'B-DeferredFinanceCostsNet': 28, 'B-IncomeTaxExaminationPenaltiesAndInterestAccrued': 271, 'B-ShareBasedCompensationArrangementByShareBasedPaymentAwardEquityInstrumentsOtherThanOptionsNonvestedNumber': 92, 'B-CapitalizedContractCostNet': 155, 'B-CumulativeEffectOfNewAccountingPrincipleInPeriodOfAdoption': 17, 'B-IncomeTaxesPaid': 495, 'B-EquityMethodInvestmentOtherThanTemporaryImpairment': 22, 'B-InterestPaidNet': 225, 'B-EquitySecuritiesWithoutReadilyDeterminableFairValueAmount': 175, 'B-ImpairmentOfLongLivedAssetsHeldForUse': 313, 'B-GoodwillAcquiredDuringPeriod': 156, 'B-DecreaseInUnrecognizedTaxBenefitsIsReasonablyPossible': 363, 'B-RestructuringAndRelatedCostIncurredCost': 75, 'B-StockRepurchasedDuringPeriodValue': 254, 'B-IncomeTaxExaminationPenaltiesAndInterestExpense': 525, 'B-ImpairmentOfIntangibleAssetsIndefinitelivedExcludingGoodwill': 55, 'B-PreferredStockLiquidationPreference': 157, 'B-ImpairmentOfIntangibleAssetsExcludingGoodwill': 158, 'B-IncomeTaxesPaidNet': 456, 
'B-DefinedContributionPlanEmployerMatchingContributionPercent': 332, 'B-CostOfGoodsAndServicesSold': 274, 'B-DepreciationDepletionAndAmortization': 338, 'B-InterestExpenseDebt': 191, 'B-LineOfCreditFacilityUnusedCapacityCommitmentFeePercentage': 442, 'B-DisposalGroupIncludingDiscontinuedOperationConsideration': 6, 'B-UnrecognizedTaxBenefitsInterestOnIncomeTaxesAccrued': 14, 'B-SaleOfStockPricePerShare': 278, 'B-DefinedContributionPlanEmployerMatchingContributionPercentOfMatch': 267, 'B-FinitelivedIntangibleAssetsAcquired1': 202, 'B-PaymentsForRepurchaseOfCommonStock': 486, 'B-BusinessCombinationContingentConsiderationLiability': 103, 'B-RelatedPartyTransactionAmountsOfTransaction': 180, 'O': 0} + +class TATQA(QA): + DATASET_PATH = "chancefocus/flare-tatqa" + + +class FinRED(RelationExtraction): + DATASET_PATH = "chancefocus/flare-finred" + + +class lendingclub(Classification): + DATASET_PATH = "chancefocus/cra-lendingclub" + CALCULATE_MCC = True + + +class ccf(Classification): + DATASET_PATH = "chancefocus/cra-ccf" + CALCULATE_MCC = True + + +class ccfraud(Classification): + DATASET_PATH = "chancefocus/cra-ccfraud" + CALCULATE_MCC = True + + +class polish(Classification): + DATASET_PATH = "chancefocus/cra-polish" + CALCULATE_MCC = True + + +class taiwan(Classification): + DATASET_PATH = "chancefocus/cra-taiwan" + CALCULATE_MCC = True + + +class portoseguro(Classification): + DATASET_PATH = "chancefocus/cra-portoseguro" + CALCULATE_MCC = True + + +class travelinsurace(Classification): + DATASET_PATH = "chancefocus/cra-travelinsurace" + CALCULATE_MCC = True + + + +class LongFormFactuality(Task): + VERSION = 1 + DATASET_NAME = None + EVAL_LAST_TURN = True + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. + return doc["query"] + + def doc_to_target(self, doc): + return doc["answer"] + + def process_results(self, doc, results): + return { + "factscore": (doc["answer"], doc["text"], results[0]), + } + + def higher_is_better(self): + return { + "factscore": True + } + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. 
+ """ + cont_request = rf.greedy_until(ctx, {"until": None}) + return cont_request + + def factscore(self, items): + golds, texts, preds = zip(*items) + texts = list(texts) + preds = list(preds) + + fs = FactScorer("retrieval+ChatGPT", openai_key=os.environ["OPENAI_API_KEY"]) + + fs.register_knowledge_source("finterms", data_path="./src/factscore_package/.cache/finterms.jsonl", db_path="./src/factscore_package/.cache/fin_terms.db") + + score = 0 + num_facts = 0 + # now, when you compute a score, specify knowledge source to use + for i in range(len(texts)): + try: + out = fs.get_score([texts[i]], [preds[i]], knowledge_source="finterms") + score += out["score"] * out["num_facts_per_response"] + num_facts += out["num_facts_per_response"] + except: + out = fs.get_score([texts[i]], [preds[i]], knowledge_source="finterms") + score += out["score"] * out["num_facts_per_response"] + num_facts += out["num_facts_per_response"] + + #try: + #out = fs.get_score(texts, preds, knowledge_source="finterms") + #except: + #out = fs.get_score(texts, preds, knowledge_source="finterms") + + + #return out["score"] # FActScore + return score/num_facts + + def aggregation(self): + return { + "factscore": self.factscore, + } + + +class FINTERM(LongFormFactuality): + DATASET_PATH = "PIXIU-fin/en-finterm" + +class ACRONYM(QA): + DATASET_PATH = "PIXIU-fin/en-acronym" + + + + + +class ZHFinFE(Classification): + DATASET_PATH = "ChanceFocus/flare-zh-fe" + + +class ZHFinNL(Classification): + DATASET_PATH = "ChanceFocus/flare-zh-nl" + + +class ZHFinNL2(Classification): + DATASET_PATH = "ChanceFocus/flare-zh-nl2" + + +class ZHFinNSP(Classification): + DATASET_PATH = "ChanceFocus/flare-zh-nsp" + + +class ZHFinRE(Classification): + DATASET_PATH = "ChanceFocus/flare-zh-re" + + +class ZHAFQMC(Classification): + DATASET_PATH = "ChanceFocus/flare-zh-afqmc" + + +class ZHAstock(Classification): + DATASET_PATH = "ChanceFocus/flare-zh-stocka" + + +class ZHBQcourse(Classification): + DATASET_PATH = "ChanceFocus/flare-zh-corpus" + + +class ZHFinEval(Classification): + DATASET_PATH = "ChanceFocus/flare-zh-fineval" + + +class ZHstock11(Classification): + DATASET_PATH = "ChanceFocus/flare-zh-stockb" + +class ZHFinQA(QA): + DATASET_PATH = "ChanceFocus/flare-zh-qa" + + +class ZHFinNA(AbstractiveSummarization): + DATASET_PATH = "ChanceFocus/flare-zh-na" + + +class ZH21CCKS(RelationExtraction): + DATASET_PATH = "ChanceFocus/flare-zh-21ccks" + + def process_results(self, doc, results): + return { + "precision": (doc["answer"], results), + "recall": (doc["answer"], results), + "f1": (doc["answer"], results), + } + + def process_string_list(self, string_list): + processed_list = [] + + for item in string_list: + processed_item = item.strip() + processed_list.append(processed_item) + + return processed_list + + def process(self, items): + golds, preds = zip(*items) + + all_golds = [] + all_preds = [] + + for gold, pred in zip(golds, preds): + gold = str(gold).split("\n") + all_golds.extend(gold) + pred = self.process_string_list(pred) + all_preds.extend(pred) + + return set(all_golds), set(all_preds) + + +class ZH19CCKS(RelationExtraction): + VERSION = 1 + DATASET_PATH = "ChanceFocus/flare-zh-19ccks" + + def process_results(self, doc, results): + return { + "precision": (doc["answer"], results[0]), + "recall": (doc["answer"], results[0]), + "f1": (doc["answer"], results[0]), + } + + +class ZH20CCKS(ZH19CCKS): + DATASET_PATH = "ChanceFocus/flare-zh-20ccks" + + +class ZH22CCKS(ZH19CCKS): + DATASET_PATH = "ChanceFocus/flare-zh-22ccks" + + +class 
ZHNER(NER): + DATASET_PATH = "ChanceFocus/flare-zh-ner" + + def process_results(self, doc, results): + text = ' '.join(doc["text"]) + pred = process_zhtext(results[0], text) + + return {"entity_f1": (pred, doc["label"], results[0])} + + +class ZHFPB(Classification): + DATASET_PATH = "ChanceFocus/flare-zh-fpb" + + +class ZHFIQASA(Classification): + DATASET_PATH = "ChanceFocus/flare-zh-fiqasa" + + +class ZHHeadlines(Headlines): + DATASET_PATH = "ChanceFocus/flare-zh-headlines" + + def process_results(self, doc, results): + gold = doc["gold"] + + return { + "avg_f1": (doc["answer"], int(results[0] != "是"), gold, results), + } + + +class ZHBigData(StockMovement): + DATASET_PATH = "ChanceFocus/flare-zh-bigdata" + CHOICE_DICT = { + "上涨": ["是", "正面", "积极", "肯定的"], + "下跌": ["否", "负面", "消极"], + } + DEFAULT = "下跌" + + +class ZHACL(ZHBigData): + DATASET_PATH = "ChanceFocus/flare-zh-acl" + + +class ZHCIKM(ZHBigData): + DATASET_PATH = "ChanceFocus/flare-zh-cikm" + + +class ZHFinQAE(QA): + DATASET_PATH = "ChanceFocus/flare-zh-finqa" + + +class ZHConvFinQA(ConvFinQA): + DATASET_PATH = "ChanceFocus/flare-zh-convfinqa" + +class ESMultiFin(Classification): + DATASET_PATH = "chancefocus/flare-es-multifin" + +class ESEFP(Classification): + DATASET_PATH = "chancefocus/flare-es-efp" + +class ESEFPA(Classification): + DATASET_PATH = "chancefocus/flare-es-efpa" + +class ESTSA(Classification): + DATASET_PATH = "chancefocus/flare-es-tsa" + +class ESFINANCEES(Classification): + DATASET_PATH = "chancefocus/flare-es-financees" + +class ESFNS(AbstractiveSummarization): + DATASET_PATH = "chancefocus/flare-es-fns" diff --git a/src/tasks/__init__.py b/src/tasks/__init__.py index e4bd6a1..c580290 100644 --- a/src/tasks/__init__.py +++ b/src/tasks/__init__.py @@ -17,6 +17,7 @@ "flare_fiqasa": flare.FIQASA, "flare_ner": flare.NER, "flare_finqa": flare.FinQA, + "flare_auditqa_zh": flare.AuditQA, "flare_convfinqa": flare.ConvFinQA, "flare_headlines": flare.Headlines, "flare_finer_ord": flare.FinerOrd, diff --git a/src/tasks/flare.py b/src/tasks/flare.py index 474c1d0..1f2849b 100644 --- a/src/tasks/flare.py +++ b/src/tasks/flare.py @@ -13,6 +13,8 @@ import re from factscore_package.factscorer import FactScorer import os +import jieba +from rouge_chinese import Rouge #from comet import download_model, load_from_checkpoint _CITATION = """ @@ -698,6 +700,130 @@ def aggregation(self): } +class QA_zh(Task): + VERSION = 1 + DATASET_NAME = None + EVAL_LAST_TURN = True + + def reformulate_turn_req(self, req, turn_request, turn): + return req + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def training_docs(self): + return self.dataset["train"] + + def validation_docs(self): + return self.dataset["validation"] + + def test_docs(self): + return self.dataset["test"] + + def doc_to_text(self, doc): + # TODO: Format the query prompt portion of the document example. 
+ return doc["query"] + + def doc_to_target(self, doc): + return doc["answer"] + + def process_results(self, doc, results): + return { + "rouge1": (doc["answer"], results[0]), + "rouge2": (doc["answer"], results[0]), + "rougeL": (doc["answer"], results[0]), + "bert_score_f1": (doc["answer"], results[0]), + } + + def higher_is_better(self): + return { + "rouge1": True, + "rouge2": True, + "rougeL": True, + "bert_score_f1": True, + } + + def construct_requests(self, doc, ctx): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + """ + cont_request = rf.greedy_until(ctx, {"until": None}) + return cont_request + + def rouge_score_zh(self, items): + # Path to the dictionary for Chinese word segmentation + # Specifically designed for Chinese evaluation + # Set this to your local dictionary path. + USER_DICT_PATH = "/path/to/vocab.txt" + hyps, refs = map(list, zip(*[[' '.join(jieba.cut(d[0])), ' '.join(jieba.cut(d[1]))] for d in items])) + filter_hyps = [] + filter_refs = [] + for i in range(len(hyps)): + hyp = hyps[i] + ref = refs[i] + if self.is_whitespace_string(hyp) or self.is_whitespace_string(ref): + continue + if hyp != '' and ref != '': + filter_hyps.append(hyp) + filter_refs.append(ref) + rouge = Rouge() + scores = rouge.get_scores(filter_hyps, filter_refs, avg=True, ignore_empty=True) + return scores + + def rouge1(self, items): + results = self.rouge_score_zh(items) + return results["rouge-1"]['f'] + + def rouge2(self, items): + results = self.rouge_score_zh(items) + return results["rouge-2"]['f'] + + def rougeL(self, items): + results = self.rouge_score_zh(items) + return results["rouge-l"]['f'] + + def is_whitespace_string(self, s): + return s.isspace() + + def bert_score(self, items): + if getattr(self, "_cache_bertscore", None) is None: + golds, preds = zip(*items) + bertscore = evaluate.load("evaluate-metric/bertscore") + self._cache_bertscore = bertscore.compute( + predictions=preds, + references=golds, + model_type="bert-base-chinese", + ) + return self._cache_bertscore + else: + return self._cache_bertscore + + def bert_score_f1(self, items): + res = self.bert_score(items) + return sum(res["f1"]) / len(res["f1"]) + + def aggregation(self): + return { + "rouge1": self.rouge1, + "rouge2": self.rouge2, + "rougeL": self.rougeL, + "bert_score_f1": self.bert_score_f1, + } + + class FPB(Classification): DATASET_PATH = "chancefocus/flare-fpb" @@ -787,6 +913,13 @@ class FinQA(QA): DATASET_PATH = "chancefocus/flare-finqa" +class AuditQA(QA_zh): + # Path to the dataset for the AuditQA class + # Specifically designed for a Chinese dataset + # Set this to the appropriate path for the dataset + DATASET_PATH = "/path/to/dataset" + + class StockMovement(Classification): DATASET_NAME = None CALCULATE_MCC = True From 0c1dbd80b95a31568f7523b407c4c6c019f22105 Mon Sep 17 00:00:00 2001 From: lemonade1258 <1779716932@qq.com> Date: Wed, 23 Oct 2024 00:08:46 +0800 Subject: [PATCH 2/2] Remove .ipynb_checkpoints directory --- .../.ipynb_checkpoints/__init__-checkpoint.py | 132 -- .../.ipynb_checkpoints/flare-checkpoint.py | 1552 ----------------- 2 files changed, 1684 deletions(-) delete mode 100644 
src/tasks/.ipynb_checkpoints/__init__-checkpoint.py delete mode 100644 src/tasks/.ipynb_checkpoints/flare-checkpoint.py diff --git a/src/tasks/.ipynb_checkpoints/__init__-checkpoint.py b/src/tasks/.ipynb_checkpoints/__init__-checkpoint.py deleted file mode 100644 index c580290..0000000 --- a/src/tasks/.ipynb_checkpoints/__init__-checkpoint.py +++ /dev/null @@ -1,132 +0,0 @@ -from pprint import pprint -from typing import List, Union - -import json -import lm_eval.base - -from . import flare - -TASK_REGISTRY = { - "flare_es_financees": flare.ESFINANCEES, - "flare_es_multifin": flare.ESMultiFin, - "flare_es_efp": flare.ESEFP, - "flare_es_efpa": flare.ESEFPA, - "flare_es_fns": flare.ESFNS, - "flare_es_tsa": flare.ESTSA, - "flare_fpb": flare.FPB, - "flare_fiqasa": flare.FIQASA, - "flare_ner": flare.NER, - "flare_finqa": flare.FinQA, - "flare_auditqa_zh": flare.AuditQA, - "flare_convfinqa": flare.ConvFinQA, - "flare_headlines": flare.Headlines, - "flare_finer_ord": flare.FinerOrd, - "flare_fomc": flare.FOMC, - "flare_german": flare.German, - "flare_australian": flare.Australian, - "flare_fomc": flare.FOMC, - "flare_ectsum": flare.ECTSUM, - "flare_edtsum": flare.EDTSUM, - "flare_finarg_ecc_auc": flare.FinargECCAUC, - "flare_finarg_ecc_arc": flare.FinargECCARC, - "flare_cd": flare.CD, - "flare_multifin_en": flare.MultiFinEN, - "flare_tsa": flare.TSA, - "flare_cfa": flare.CFA, - "flare_ma": flare.MA, - "flare_causal20_sc": flare.Causal20SC, - "flare_finarg_ecc_arc": flare.FINARGECCARC, - "flare_finarg_ecc_auc": flare.FINARGECCAUC, - "flare_mlesg": flare.MLESG, - "flare_fnxl": flare.FNXL, - "flare_fsrl": flare.FSRL, - "flare_tatqa": flare.TATQA, - "flare_finred": flare.FinRED, - "flare_cra_lendingclub": flare.lendingclub, - "flare_cra_ccf": flare.ccf, - "flare_cra_ccfraud": flare.ccfraud, - "flare_cra_polish": flare.polish, - "flare_cra_taiwan": flare.taiwan, - "flare_cra_portoseguro": flare.portoseguro, - "flare_cra_travelinsurace": flare.travelinsurace, - "flare_sm_bigdata": flare.StockMovementBigData, - "flare_sm_acl": flare.StockMovementACL, - "flare_sm_cikm": flare.StockMovementCIKM, - "flare_en_finterm": flare.FINTERM, - "flare_en_acronym": flare.ACRONYM, - **flare.SM_TASKS, - "flare_finarg_ecc_auc_test": flare.FINARGECCAUC_test, - "flare_edtsum_test": flare.EDTSUM_test, -} - -ALL_TASKS = sorted(list(TASK_REGISTRY)) - -_EXAMPLE_JSON_PATH = "split:key:/absolute/path/to/data.json" - - -def add_json_task(task_name): - """Add a JSON perplexity task if the given task name matches the - JSON task specification. - - See `json.JsonPerplexity`. 
- """ - if not task_name.startswith("json"): - return - - def create_json_task(): - splits = task_name.split("=", 1) - if len(splits) != 2 or not splits[1]: - raise ValueError( - "json tasks need a path argument pointing to the local " - "dataset, specified like this: json=" - + _EXAMPLE_JSON_PATH - + ' (if there are no splits, use "train")' - ) - - json_path = splits[1] - if json_path == _EXAMPLE_JSON_PATH: - raise ValueError( - "please do not copy the example path directly, but substitute " - "it with a path to your local dataset" - ) - return lambda: json.JsonPerplexity(json_path) - - TASK_REGISTRY[task_name] = create_json_task() - - -def get_task(task_name): - try: - add_json_task(task_name) - return TASK_REGISTRY[task_name] - except KeyError: - print("Available tasks:") - pprint(TASK_REGISTRY) - raise KeyError(f"Missing task {task_name}") - - -def get_task_name_from_object(task_object): - for name, class_ in TASK_REGISTRY.items(): - if class_ is task_object: - return name - - # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting - return ( - task_object.EVAL_HARNESS_NAME - if hasattr(task_object, "EVAL_HARNESS_NAME") - else type(task_object).__name__ - ) - - -def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]): - task_name_dict = { - task_name: get_task(task_name)() - for task_name in task_name_list - if isinstance(task_name, str) - } - task_name_from_object_dict = { - get_task_name_from_object(task_object): task_object - for task_object in task_name_list - if not isinstance(task_object, str) - } - assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys())) - return {**task_name_dict, **task_name_from_object_dict} diff --git a/src/tasks/.ipynb_checkpoints/flare-checkpoint.py b/src/tasks/.ipynb_checkpoints/flare-checkpoint.py deleted file mode 100644 index 1f2849b..0000000 --- a/src/tasks/.ipynb_checkpoints/flare-checkpoint.py +++ /dev/null @@ -1,1552 +0,0 @@ -""" -FLARE -""" -from lm_eval.base import Task, rf -from lm_eval.metrics import mean, bleu, chrf, ter -import numpy as np -from .utils import process_text -from .zhutils import process_zhtext -from seqeval.metrics import f1_score as entity_score -from sklearn.metrics import f1_score, matthews_corrcoef, mean_squared_error -from bart_score import BARTScorer -import evaluate -import re -from factscore_package.factscorer import FactScorer -import os -import jieba -from rouge_chinese import Rouge -#from comet import download_model, load_from_checkpoint - -_CITATION = """ -@misc{xie2023pixiu, - title={PIXIU: A Large Language Model, Instruction Data and Evaluation Benchmark for Finance}, - author={Qianqian Xie and Weiguang Han and Xiao Zhang and Yanzhao Lai and Min Peng and Alejandro Lopez-Lira and Jimin Huang}, - year={2023}, - eprint={2306.05443}, - archivePrefix={arXiv}, - primaryClass={cs.CL} -} -""" - - -class Classification(Task): - CALCULATE_MCC = True - LOWER_CASE = True - VERSION = 1 - EVAL_LAST_TURN = True - - def reformulate_turn_req(self, req, turn_request, turn): - return req - - def has_training_docs(self): - return True - - def has_validation_docs(self): - return True - - def has_test_docs(self): - return True - - def training_docs(self): - return self.dataset["train"] - - def validation_docs(self): - return self.dataset["validation"] - - def test_docs(self): - return self.dataset["test"] - - def construct_requests(self, doc, ctx): - """Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to 
the LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - cont_request = rf.greedy_until(ctx, {"until": None}) - return cont_request - - def doc_to_decontamination_query(self, doc): - return doc["text"] - - def doc_to_text(self, doc): - # TODO: Format the query prompt portion of the document example. - return doc["query"] - - def doc_to_target(self, doc): - # TODO: Format the query prompt portion of the document example. - return doc["answer"] - - def process_results(self, doc, results): - gold: str = doc["choices"][doc["gold"]] - if self.LOWER_CASE: - gold = gold.lower() - ini_result = results[0].strip() - if self.LOWER_CASE: - ini_result = ini_result.lower() - - result = None - for choice in doc["choices"]: - if self.LOWER_CASE: - choice = choice.lower() - if choice in ini_result: - result = choice - break - if result is None: - result = "missing" - - acc = 1.0 if gold == result else 0.0 - - results = { - "acc": acc, - "missing": int(result == "missing"), - "f1": (result, gold), - "macro_f1": (result, gold), - } - - if self.CALCULATE_MCC: - results["mcc"] = (result, gold) - - return results - - def higher_is_better(self): - metrics = { - "acc": True, - "f1": True, - "macro_f1": True, - "missing": False, - } - if self.CALCULATE_MCC: - metrics["mcc"] = True - return metrics - - def weighted_f1(self, items): - preds, golds = zip(*items) - labels = list(set(golds)) - preds = np.array(preds) - golds = np.array(golds) - f1 = f1_score(golds, preds, average="weighted", labels=labels) - return f1 - - def macro_f1(self, items): - preds, golds = zip(*items) - labels = list(set(golds)) - preds = np.array(preds) - golds = np.array(golds) - f1 = f1_score(golds, preds, average="macro", labels=labels) - return f1 - - def matthews_corrcoef(self, items): - preds, golds = zip(*items) - labels = {label: i for i, label in enumerate(list(set(golds)))} - preds = [labels.get(pred, -1) for pred in preds] - golds = [labels.get(gold, -1) for gold in golds] - return matthews_corrcoef(golds, preds) - - def aggregation(self): - metrics = { - "acc": mean, - "missing": mean, - "f1": self.weighted_f1, - "macro_f1": self.macro_f1, - } - if self.CALCULATE_MCC: - metrics["mcc"] = self.matthews_corrcoef - return metrics - - -class SequentialLabeling(Task): - VERSION = 1 - DATASET_NAME = None - LMAP = {"O": 0} - EVAL_LAST_TURN = True - - def reformulate_turn_req(self, req, turn_request, turn): - return req - - def has_training_docs(self): - return False - - def has_validation_docs(self): - return False - - def has_test_docs(self): - return True - - def training_docs(self): - return self.dataset["train"] - - def validation_docs(self): - return self.dataset["validation"] - - def test_docs(self): - return self.dataset["test"] - - def doc_to_text(self, doc): - # TODO: Format the query prompt portion of the document example. 
- return doc["query"] - - def doc_to_target(self, doc): - return "\nAnswer: " + doc["answer"] - - def process_results(self, doc, results): - return { - "entity_f1": (doc["label"], results[0], doc["token"]), - "f1": (doc["label"], results[0], doc["token"]), - } - - def higher_is_better(self): - return { - "f1": True, - "entity_f1": True, - } - - def construct_requests(self, doc, ctx): - """Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - cont_request = rf.greedy_until(ctx, {"until": None}) - return cont_request - - def process_result(self, pred, gold, tokens): - format_pred = ["O"] * len(gold) - for index, pre in enumerate(pred.split("\n")[: len(tokens)]): - try: - word, label = pre.split(":") - except: - continue - if word == tokens[index] and label in self.LMAP.keys(): - format_pred[index] = label - return format_pred - - def entity_f1(self, items): - golds, preds, tokens = zip(*items) - - list_preds = [ - self.process_result(pred, gold, token) - for pred, gold, token in zip(preds, golds, tokens) - ] - f1 = entity_score(golds, list_preds) - return f1 - - def process_label_result(self, pred, gold, tokens): - format_pred = [-1] * len(gold) - for index, pre in enumerate(pred.split("\n")[: len(tokens)]): - try: - word, label = pre.split(":") - except: - continue - if word == tokens[index]: - format_pred[index] = self.LMAP.get(label, -1) - return format_pred - - def label_f1(self, items): - golds, preds, tokens = zip(*items) - - list_preds = [ - self.process_label_result(pred, gold, token) - for pred, gold, token in zip(preds, golds, tokens) - ] - list_preds = [item for sublist in list_preds for item in sublist] - golds = [self.LMAP[item] for sublist in golds for item in sublist] - f1 = f1_score(golds, list_preds, average="weighted") - return f1 - - def aggregation(self): - return { - "entity_f1": self.entity_f1, - "f1": self.label_f1, - } - - -class AbstractiveSummarization(Task): - VERSION = 1 - DATASET_NAME = None - EVAL_LAST_TURN = True - - def reformulate_turn_req(self, req, turn_request, turn): - return req - - def has_training_docs(self): - return False - - def has_validation_docs(self): - return False - - def has_test_docs(self): - return True - - def training_docs(self): - return self.dataset["train"] - - def validation_docs(self): - return self.dataset["validation"] - - def test_docs(self): - return self.dataset["test"] - - def doc_to_text(self, doc): - # TODO: Format the query prompt portion of the document example. - return doc["query"] - - def doc_to_target(self, doc): - return doc["answer"] - - def process_results(self, doc, results): - return { - "rouge1": (doc["answer"], results[0]), - "rouge2": (doc["answer"], results[0]), - "rougeL": (doc["answer"], results[0]), - "bert_score_f1": (doc["answer"], results[0]), - "bart_score": (doc["answer"], results[0]), - } - - def higher_is_better(self): - return { - "rouge1": True, - "rouge2": True, - "rougeL": True, - "bert_score_f1": True, - "bart_score": True, - } - - def construct_requests(self, doc, ctx): - """Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. 
- - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - cont_request = rf.greedy_until(ctx, {"until": None}) - return cont_request - - def rouge_score(self, items): - golds, preds = zip(*items) - rouge = evaluate.load("rouge") - results = rouge.compute(predictions=preds, references=golds) - return results - - def rouge1(self, items): - results = self.rouge_score(items) - return results["rouge1"] - - def rouge2(self, items): - results = self.rouge_score(items) - return results["rouge2"] - - def rougeL(self, items): - results = self.rouge_score(items) - return results["rougeL"] - - def bert_score(self, items): - if getattr(self, "_cache_bertscore", None) is None: - golds, preds = zip(*items) - bertscore = evaluate.load("evaluate-metric/bertscore") - self._cache_bertscore = bertscore.compute( - predictions=preds, - references=golds, - model_type="bert-base-multilingual-cased", - ) - return self._cache_bertscore - else: - return self._cache_bertscore - - def bert_score_f1(self, items): - res = self.bert_score(items) - return sum(res["f1"]) / len(res["f1"]) - - def bart_score(self, items): - golds, preds = zip(*items) - bart_scorer = BARTScorer(device="cuda", checkpoint="facebook/bart-large-cnn") - bart_scorer.load(path="src/metrics/BARTScore/bart_score.pth") - res = bart_scorer.score(srcs=preds, tgts=golds, batch_size=8) - return sum(res) / len(res) - - def aggregation(self): - return { - "rouge1": self.rouge1, - "rouge2": self.rouge2, - "rougeL": self.rougeL, - "bert_score_f1": self.bert_score_f1, - "bart_score": self.bart_score, - } - - -class ExtractiveSummarization(Task): - VERSION = 1 - DATASET_NAME = None - EVAL_LAST_TURN = True - - def reformulate_turn_req(self, req, turn_request, turn): - return req - - def has_training_docs(self): - return False - - def has_validation_docs(self): - return False - - def has_test_docs(self): - return True - - def training_docs(self): - return self.dataset["train"] - - def validation_docs(self): - return self.dataset["validation"] - - def test_docs(self): - return self.dataset["test"] - - def doc_to_text(self, doc): - # TODO: Format the query prompt portion of the document example. - return doc["query"] - - def doc_to_target(self, doc): - return doc["answer"] - - def process_results(self, doc, results): - return { - "rouge1": (doc["label"], doc["text"], results[0]), - "rouge2": (doc["label"], doc["text"], results[0]), - "rougeL": (doc["label"], doc["text"], results[0]), - "bert_score_f1": (doc["label"], doc["text"], results[0]), - "bart_score": (doc["label"], doc["text"], results[0]), - } - - def higher_is_better(self): - return { - "rouge1": True, - "rouge2": True, - "rougeL": True, - "bert_score_f1": True, - "bart_score": True, - } - - def construct_requests(self, doc, ctx): - """Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. 
- """ - cont_request = rf.greedy_until(ctx, {"until": None}) - return cont_request - - def get_sum(self, labels, texts): - summ = [] - for label, text in zip(labels, texts): - text = text.split("\n") - new_text = "\n".join( - [ - text[index] - for index in range(len(text)) - if index < len(label) and label[index] == 1 - ] - ) - summ.append(new_text) - return summ - - def rouge_score(self, items): - golds, texts, preds = zip(*items) - golds = self.get_sum(golds, texts) - preds = self.get_sum([val.split("\n") for val in preds], texts) - rouge = evaluate.load("rouge") - results = rouge.compute(predictions=preds, references=golds) - return results - - def rouge1(self, items): - results = self.rouge_score(items) - return results["rouge1"] - - def rouge2(self, items): - results = self.rouge_score(items) - return results["rouge2"] - - def rougeL(self, items): - results = self.rouge_score(items) - return results["rougeL"] - - def bert_score(self, items): - if getattr(self, "_cache_bertscore", None) is None: - golds, texts, preds = zip(*items) - golds = self.get_sum(golds, texts) - preds = self.get_sum([val.split("\n") for val in preds], texts) - - bertscore = evaluate.load("evaluate-metric/bertscore") - self._cache_bertscore = bertscore.compute( - predictions=preds, - references=golds, - model_type="bert-base-multilingual-cased", - ) - return self._cache_bertscore - else: - return self._cache_bertscore - - def bert_score_f1(self, items): - res = self.bert_score(items) - return sum(res["f1"]) / len(res["f1"]) - - def bart_score(self, items): - golds, texts, preds = zip(*items) - golds = self.get_sum(golds, texts) - preds = self.get_sum([val.split("\n") for val in preds], texts) - - bart_scorer = BARTScorer(device="cuda:0", checkpoint="facebook/bart-large-cnn") - bart_scorer.load(path="src/metrics/BARTScore/bart_score.pth") - res = bart_scorer.score(srcs=preds, tgts=golds, batch_size=8) - return sum(res) / len(res) - - def aggregation(self): - return { - "rouge1": self.rouge1, - "rouge2": self.rouge2, - "rougeL": self.rougeL, - "bert_score_f1": self.bert_score_f1, - "bart_score": self.bart_score, - } - - -class RelationExtraction(Task): - VERSION = 1 - DATASET_NAME = None - EVAL_LAST_TURN = True - - def reformulate_turn_req(self, req, turn_request, turn): - return req - - def has_training_docs(self): - return False - - def has_validation_docs(self): - return False - - def has_test_docs(self): - return True - - def training_docs(self): - return self.dataset["train"] - - def validation_docs(self): - return self.dataset["validation"] - - def test_docs(self): - return self.dataset["test"] - - def doc_to_text(self, doc): - # TODO: Format the query prompt portion of the document example. - return doc["query"] - - def doc_to_target(self, doc): - return doc["answer"] - - def process_results(self, doc, results): - return { - "precision": (doc["label"], results[0]), - "recall": (doc["label"], results[0]), - "f1": (doc["label"], results[0]), - } - - def higher_is_better(self): - return { - "precision": True, - "recall": True, - "f1": True, - } - - def construct_requests(self, doc, ctx): - """Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. 
- """ - cont_request = rf.greedy_until(ctx, {"until": None}) - return cont_request - - def process(self, items): - golds, preds = zip(*items) - - all_golds = [] - all_preds = [] - - for gold, pred in zip(golds, preds): - all_golds.extend(gold) - pred = pred.split("\n") - all_preds.extend(pred) - - return set(all_golds), set(all_preds) - - def precision(self, items): - golds, preds = self.process(items) - tp = golds & preds - prec = len(tp) / len(preds) - return prec - - def recall(self, items): - golds, preds = self.process(items) - tp = golds & preds - rec = len(tp) / len(golds) - return rec - - def cal_f1(self, items): - prec = self.precision(items) - rec = self.recall(items) - if prec + rec == 0.0: - return 0.0 - return 2 * (prec * rec) / (prec + rec) - - def aggregation(self): - return { - "precision": self.precision, - "recall": self.recall, - "f1": self.cal_f1, - } - - -class QA(Task): - VERSION = 1 - DATASET_NAME = None - EVAL_LAST_TURN = True - - def reformulate_turn_req(self, req, turn_request, turn): - return req - - def has_training_docs(self): - return True - - def has_validation_docs(self): - return True - - def has_test_docs(self): - return True - - def training_docs(self): - return self.dataset["train"] - - def validation_docs(self): - return self.dataset["validation"] - - def test_docs(self): - return self.dataset["test"] - - def should_decontaminate(self): - return True - - def doc_to_decontamination_query(self, doc): - return doc["text"] - - def doc_to_text(self, doc): - # TODO: Format the query prompt portion of the document example. - return doc["query"] - - def construct_requests(self, doc, ctx): - """Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - cont_request = rf.greedy_until(ctx, {"until": None}) - return cont_request - - def doc_to_target(self, doc): - return doc["answer"] - - def process_results(self, doc, results): - gold = doc["answer"] - - acc = 1.0 if results[0].strip() == gold else 0.0 - - return { - "acc": acc, - } - - def higher_is_better(self): - return { - "acc": True, - } - - def aggregation(self): - return { - "acc": mean, - } - - -class QA_zh(Task): - VERSION = 1 - DATASET_NAME = None - EVAL_LAST_TURN = True - - def reformulate_turn_req(self, req, turn_request, turn): - return req - - def has_training_docs(self): - return False - - def has_validation_docs(self): - return False - - def has_test_docs(self): - return True - - def training_docs(self): - return self.dataset["train"] - - def validation_docs(self): - return self.dataset["validation"] - - def test_docs(self): - return self.dataset["test"] - - def doc_to_text(self, doc): - # TODO: Format the query prompt portion of the document example. 
- return doc["query"] - - def doc_to_target(self, doc): - return doc["answer"] - - def process_results(self, doc, results): - return { - "rouge1": (doc["answer"], results[0]), - "rouge2": (doc["answer"], results[0]), - "rougeL": (doc["answer"], results[0]), - "bert_score_f1": (doc["answer"], results[0]), - } - - def higher_is_better(self): - return { - "rouge1": True, - "rouge2": True, - "rougeL": True, - "bert_score_f1": True, - } - - def construct_requests(self, doc, ctx): - """Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - cont_request = rf.greedy_until(ctx, {"until": None}) - return cont_request - - def rouge_score_zh(self, items): - # Path to the dictionary for Chinese word segmentation - # Specifically designed for Chinese evaluation - # Set this to your local dictionary path. - USER_DICT_PATH = "/path/to/vocab.txt" - hyps, refs = map(list, zip(*[[' '.join(jieba.cut(d[0])), ' '.join(jieba.cut(d[1]))] for d in items])) - filter_hyps = [] - filter_refs = [] - for i in range(len(hyps)): - hyp = hyps[i] - ref = refs[i] - if self.is_whitespace_string(hyp) or self.is_whitespace_string(ref): - continue - if hyp != '' and ref != '': - filter_hyps.append(hyp) - filter_refs.append(ref) - rouge = Rouge() - scores = rouge.get_scores(filter_hyps, filter_refs, avg=True, ignore_empty=True) - return scores - - def rouge1(self, items): - results = self.rouge_score_zh(items) - return results["rouge-1"]['f'] - - def rouge2(self, items): - results = self.rouge_score_zh(items) - return results["rouge-2"]['f'] - - def rougeL(self, items): - results = self.rouge_score_zh(items) - return results["rouge-l"]['f'] - - def is_whitespace_string(self, s): - return s.isspace() - - def bert_score(self, items): - if getattr(self, "_cache_bertscore", None) is None: - golds, preds = zip(*items) - bertscore = evaluate.load("evaluate-metric/bertscore") - self._cache_bertscore = bertscore.compute( - predictions=preds, - references=golds, - model_type="bert-base-chinese", - ) - return self._cache_bertscore - else: - return self._cache_bertscore - - def bert_score_f1(self, items): - res = self.bert_score(items) - return sum(res["f1"]) / len(res["f1"]) - - def aggregation(self): - return { - "rouge1": self.rouge1, - "rouge2": self.rouge2, - "rougeL": self.rougeL, - "bert_score_f1": self.bert_score_f1, - } - - -class FPB(Classification): - DATASET_PATH = "chancefocus/flare-fpb" - - -class FIQASA(Classification): - DATASET_PATH = "chancefocus/flare-fiqasa" - - -class NER(Task): - VERSION = 1 - DATASET_PATH = "chancefocus/flare-ner" - DATASET_NAME = None - EVAL_LAST_TURN = True - - def reformulate_turn_req(self, req, turn_request, turn): - return req - - def has_training_docs(self): - return True - - def has_validation_docs(self): - return True - - def has_test_docs(self): - return True - - def training_docs(self): - return self.dataset["train"] - - def validation_docs(self): - return self.dataset["validation"] - - def test_docs(self): - return self.dataset["test"] - - def should_decontaminate(self): - return True - - def doc_to_decontamination_query(self, doc): - return doc["text"] - - def doc_to_text(self, doc): - # TODO: Format the query prompt portion of the 
document example. - return doc["query"] - - def construct_requests(self, doc, ctx): - """Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - cont_request = rf.greedy_until(ctx, {"until": None}) - return cont_request - - def doc_to_target(self, doc): - return doc["answer"] - - def process_results(self, doc, results): - text = doc["text"] - pred = process_text(results[0], text) - - return {"entity_f1": (pred, doc["label"], results[0])} - - def higher_is_better(self): - return { - "entity_f1": True, - } - - @classmethod - def entity_f1(cls, items): - preds, golds, _ = zip(*items) - f1 = entity_score(golds, preds) - return f1 - - def aggregation(self): - return { - "entity_f1": self.entity_f1, - } - - -class FinQA(QA): - DATASET_PATH = "chancefocus/flare-finqa" - - -class AuditQA(QA_zh): - # Path to the dataset for the AuditQA class - # Specifically designed for a Chinese dataset - # Set this to the appropriate path for the dataset - DATASET_PATH = "/path/to/dataset" - - -class StockMovement(Classification): - DATASET_NAME = None - CALCULATE_MCC = True - CHOICE_DICT = { - "rise": ["yes", "positive"], - "fall": ["no", "negative", "neutral"], - } - DEFAULT = "fall" - - def process_results(self, doc, results): - gold: str = doc["choices"][doc["gold"]] - if self.LOWER_CASE: - gold = gold.lower() - ini_result = results[0].strip() - if self.LOWER_CASE: - ini_result = ini_result.lower() - - result = None - for choice in doc["choices"]: - if self.LOWER_CASE: - choice = choice.lower() - if choice in ini_result or any( - [val in ini_result for val in self.CHOICE_DICT[choice]] - ): - result = choice - break - if result is None: - result = self.DEFAULT - - acc = 1.0 if gold == result else 0.0 - - results = { - "acc": acc, - "missing": int(result == "missing"), - "f1": (result, gold), - "macro_f1": (result, gold), - } - - if self.CALCULATE_MCC: - results["mcc"] = (result, gold) - - return results - - -class StockMovementBigData(StockMovement): - DATASET_PATH = "chancefocus/flare-sm-bigdata" - - -class StockMovementACL(StockMovement): - DATASET_PATH = "chancefocus/flare-sm-acl" - - -class StockMovementCIKM(StockMovement): - DATASET_PATH = "chancefocus/flare-sm-cikm" - - -SM_TASKS = { - "flare_sm_bigdata": StockMovementBigData, - "flare_sm_acl": StockMovementACL, - "flare_sm_cikm": StockMovementCIKM, -} - - -class Headlines(Classification): - DATASET_PATH = "chancefocus/flare-headlines" - - def process_results(self, doc, results): - gold = doc["gold"] - - return { - "avg_f1": (doc["label_type"], int(results[0].strip() != "Yes"), gold, results), - } - - def higher_is_better(self): - return { - "avg_f1": True, - } - - @classmethod - def label_avg(cls, items): - labels, preds, golds, rels = zip(*items) - label_set = set(labels) - labels = np.array(labels) - preds = np.array(preds) - golds = np.array(golds) - all_f1s = [] - for l in label_set: - pds = preds[labels == l] - gds = golds[labels == l] - f1 = f1_score(gds, pds, average="weighted", labels=[0, 1]) - all_f1s.append(f1) - return np.mean(all_f1s) - - def construct_requests(self, doc, ctx): - """Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the 
LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. - """ - cont_request = rf.greedy_until(ctx, {"until": None}) - return cont_request - - def aggregation(self): - return { - "avg_f1": self.label_avg, - } - - -class FinerOrd(SequentialLabeling): - DATASET_PATH = "chancefocus/flare-finer-ord" - LMAP = { - "O": 0, - "B-PER": 1, - "I-PER": 2, - "B-LOC": 3, - "I-LOC": 4, - "B-ORG": 5, - "I-ORG": 6, - } - - -class FOMC(Classification): - DATASET_PATH = "chancefocus/flare-fomc" - - -class German(StockMovement): - DATASET_PATH = "chancefocus/flare-german" - CHOICE_DICT = { - "good": ["yes", "positive"], - "bad": ["no", "negative", "neutral"], - } - DEFAULT = "good" - - -class Australian(StockMovement): - DATASET_PATH = "chancefocus/flare-australian" - CHOICE_DICT = { - "good": ["yes", "positive"], - "bad": ["no", "negative", "neutral"], - } - DEFAULT = "good" - - -class ECTSUM(ExtractiveSummarization): - DATASET_PATH = "chancefocus/flare-ectsum" - - -class EDTSUM(AbstractiveSummarization): - DATASET_PATH = "chancefocus/flare-edtsum" - - -class EDTSUM_test(AbstractiveSummarization): - DATASET_PATH = "TheFinAI/flare-edtsum_test" - - -class ConvFinQA(QA): - DATASET_PATH = "chancefocus/flare-convfinqa" - - def reformulate_turn_req(self, req, turn_request, turn): - if turn == 0: - return req - pre_answers = {f"answer{i}": turn_request[i][0] for i in range(turn)} - if pre_answers: - req.args = tuple([req.args[0].format(**pre_answers)] + list(req.args[1:])) - return req - - -class TSA(Task): - VERSION = 1 - DATASET_PATH = "chancefocus/flare-tsa" - DATASET_NAME = None - EVAL_LAST_TURN = True - - def reformulate_turn_req(self, req, turn_request, turn): - return req - - def has_training_docs(self): - return False - - def has_validation_docs(self): - return False - - def has_test_docs(self): - return True - - def training_docs(self): - return self.dataset["train"] - - def validation_docs(self): - return self.dataset["validation"] - - def test_docs(self): - return self.dataset["test"] - - def doc_to_text(self, doc): - # TODO: Format the query prompt portion of the document example. - return doc["query"] - - def doc_to_target(self, doc): - return "\nAnswer: " + str(doc["answer"]) - - def process_results(self, doc, results): - pred = results[0].split("\n")[0] - pred = re.findall(r'[0-9]+(?:\.[0-9]+)?', pred) - missing = 0 - if not pred: - pred = -100.0 - missing = 1 - else: - pred = pred[0] - pred = float(pred) - return { - "rmse": (doc["answer"], pred), - "missing": missing - } - - def higher_is_better(self): - return { - "rmse": False, - } - - def construct_requests(self, doc, ctx): - """ - Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. - - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. 
- """ - cont_request = rf.greedy_until(ctx, {"until": "Answer:"}) - return cont_request - - def rmse(self, items): - golds, preds = zip(*items) - fgolds, fpreds = [], [] - for gold, pred in zip(golds, preds): - if pred == -100.0: - continue - fgolds.append(gold) - fpreds.append(max(min(pred, 1.0), -1.0)) - rmse = mean_squared_error(fgolds, fpreds, squared=True) - - return rmse - - def aggregation(self): - return { - "rmse": self.rmse, - "missing": mean, - } - - - -class CFA(Classification): - DATASET_PATH = "chancefocus/flare-cfa" - LOWER_CASE = False - - -class FINARGECCARC(Classification): - DATASET_PATH = "chancefocus/flare-finarg-ecc-arc" - - -class FINARGECCAUC(Classification): - DATASET_PATH = "chancefocus/flare-finarg-ecc-auc" - - -class FINARGECCAUC_test(Classification): - DATASET_PATH = "TheFinAI/flare-finarg-ecc-auc_test" - - -class MLESG(Classification): - DATASET_PATH = "chancefocus/flare-mlesg" - - -class FSRL(SequentialLabeling): - DATASET_PATH = "chancefocus/flare-fsrl" - LMAP = {key: index for index, key in enumerate(['O', 'I-QUANT', 'B-QUANT', 'I-TIME', 'B-TIME', 'I-MANNER', 'B-MANNER', 'I-THEME', 'B-THEME', 'I-VALUE', 'B-VALUE', 'I-WHOLE', 'B-WHOLE', 'I-LOCATION', 'B-LOCATION', 'I-AGENT', 'B-AGENT', 'I-CAUSE', 'B-CAUSE', 'I-SOURCE', 'B-SOURCE', 'I-REF_TIME', 'B-REF_TIME', 'I-CONDITION', 'B-CONDITION'])} - -class CFA(Classification): - DATASET_PATH = "chancefocus/flare-cfa" - -class FinargECCAUC(Classification): - DATASET_PATH = "chancefocus/flare-finarg-ecc-auc" - -class FinargECCARC(Classification): - DATASET_PATH = "chancefocus/flare-finarg-ecc-arc" - -class CD(SequentialLabeling): - DATASET_PATH = "chancefocus/flare-cd" - LMAP = {key: index for index, key in enumerate(['O', 'I-CAUSE', 'B-CAUSE', 'I-EFFECT', 'B-EFFECT'])} - -class MultiFinEN(Classification): - DATASET_PATH = "chancefocus/flare-multifin-en" - -class MA(Classification): - DATASET_PATH = "chancefocus/flare-ma" - -class Causal20SC(Classification): - DATASET_PATH = "chancefocus/flare-causal20-sc" - -class FNXL(SequentialLabeling): - DATASET_PATH = "chancefocus/flare-fnxl" - LMAP = {'B-BusinessCombinationContingentConsiderationArrangementsRangeOfOutcomesValueHigh': 140, 'B-VariableInterestEntityOwnershipPercentage': 646, 'B-GainLossOnDispositionOfAssets1': 119, 'B-IndefiniteLivedIntangibleAssetsExcludingGoodwill': 46, 'B-MarketingAndAdvertisingExpense': 269, 'B-ReportingUnitPercentageOfFairValueInExcessOfCarryingAmount': 142, 'B-CapitalizedComputerSoftwareNet': 91, 'B-BusinessCombinationConsiderationTransferredEquityInterestsIssuedAndIssuable': 183, 'B-LitigationSettlementExpense': 115, 'B-DefinedBenefitPlanExpectedAmortizationOfGainLossNextFiscalYear': 639, 'B-DeferredCompensationArrangementWithIndividualCompensationExpense': 15, 'B-ReclassificationFromAociCurrentPeriodTax': 152, 'B-OtherComprehensiveIncomeLossBeforeReclassificationsTax': 694, 'B-PreferredStockDividendsPerShareDeclared': 236, 'B-CapitalExpendituresIncurredButNotYetPaid': 344, 'B-DeferredCompensationArrangementWithIndividualContributionsByEmployer': 560, 'B-SeveranceCosts1': 311, 'B-InterestExpense': 784, 'B-SaleOfStockConsiderationReceivedOnTransaction': 76, 'B-LineOfCreditFacilityInterestRateAtPeriodEnd': 822, 'B-SharesIssuedPricePerShare': 137, 'B-EquityMethodInvestmentDifferenceBetweenCarryingAmountAndUnderlyingEquity': 63, 'B-EquitySecuritiesFvNi': 30, 'B-RightOfUseAssetObtainedInExchangeForOperatingLeaseLiability': 118, 'B-DefinedBenefitPlanFundedStatusOfPlan': 547, 
'B-SharebasedCompensationArrangementBySharebasedPaymentAwardPurchasePriceOfCommonStockPercent': 323, 'B-TaxCutsAndJobsActOf2017IncomeTaxExpenseBenefit': 256, 'B-LongtermDebtWeightedAverageInterestRate': 364, 'B-ImpairmentOfIntangibleAssetsFinitelived': 71, 'B-ProceedsFromLinesOfCredit': 496, 'B-LongTermPurchaseCommitmentAmount': 701, 'B-DebtInstrumentFairValue': 335, 'B-RestructuringAndRelatedCostCostIncurredToDate1': 52, 'B-ShareBasedCompensationArrangementByShareBasedPaymentAwardEquityInstrumentsOtherThanOptionsVestedInPeriod': 581, 'B-FiniteLivedIntangibleAssetsAccumulatedAmortization': 143, 'B-StockRepurchasedAndRetiredDuringPeriodValue': 330, 'B-BusinessCombinationProFormaInformationRevenueOfAcquireeSinceAcquisitionDateActual': 77, 'B-ClassOfWarrantOrRightExercisePriceOfWarrantsOrRights1': 361, 'B-BusinessAcquisitionPurchasePriceAllocationGoodwillExpectedTaxDeductibleAmount': 550, 'B-OperatingLossCarryforwardsValuationAllowance': 173, 'B-BusinessAcquisitionEquityInterestsIssuedOrIssuableNumberOfSharesIssued': 32, 'B-DefinedContributionPlanMaximumAnnualContributionsPerEmployeePercent': 45, 'B-ContractWithCustomerLiabilityCurrent': 2, 'B-IncomeLossFromContinuingOperationsBeforeIncomeTaxesForeign': 474, 'B-FiniteLivedIntangibleAssetsAmortizationExpenseYearThree': 1306, 'B-DefinedBenefitPlanUltimateHealthCareCostTrendRate1': 62, 'B-DefinedBenefitPlanRecognizedNetGainLossDueToSettlements1': 317, 'B-UnrecognizedTaxBenefitsInterestOnIncomeTaxesExpense': 448, 'B-ForeignCurrencyTransactionGainLossRealized': 132, 'B-DeferredTaxAssetsOperatingLossCarryforwardsSubjectToExpiration': 262, 'B-RetainedEarningsAccumulatedDeficit': 174, 'B-ProceedsFromIssuanceOfCommonStock': 209, 'B-EmployeeServiceShareBasedCompensationAllocationOfRecognizedPeriodCostsCapitalizedAmount': 29, 'B-OtherComprehensiveIncomeLossPensionAndOtherPostretirementBenefitPlansTax': 284, 'B-InventoryWriteDown': 465, 'B-RestructuringReserve': 234, 'B-LitigationSettlementAmountAwardedToOtherParty': 42, 'B-DerivativeGainLossOnDerivativeNet': 87, 'B-SharebasedCompensationArrangementBySharebasedPaymentAwardEquityInstrumentsOtherThanOptionsAggregateIntrinsicValueVested': 241, 'B-DerivativeFixedInterestRate': 589, 'B-CashAndCashEquivalentsAtCarryingValue': 257, 'B-ContractWithCustomerAssetNet': 245, 'B-RestructuringAndRelatedCostExpectedCost1': 107, 'B-IncomeTaxHolidayAggregateDollarAmount': 347, 'B-OperatingLeaseCost': 248, 'B-AllowanceForDoubtfulAccountsReceivable': 146, 'B-RepaymentsOfDebt': 416, 'B-InterestPaid': 110, 'B-DeferredFinanceCostsNet': 28, 'B-IncomeTaxExaminationPenaltiesAndInterestAccrued': 271, 'B-ShareBasedCompensationArrangementByShareBasedPaymentAwardEquityInstrumentsOtherThanOptionsNonvestedNumber': 92, 'B-CapitalizedContractCostNet': 155, 'B-CumulativeEffectOfNewAccountingPrincipleInPeriodOfAdoption': 17, 'B-IncomeTaxesPaid': 495, 'B-EquityMethodInvestmentOtherThanTemporaryImpairment': 22, 'B-InterestPaidNet': 225, 'B-EquitySecuritiesWithoutReadilyDeterminableFairValueAmount': 175, 'B-ImpairmentOfLongLivedAssetsHeldForUse': 313, 'B-GoodwillAcquiredDuringPeriod': 156, 'B-DecreaseInUnrecognizedTaxBenefitsIsReasonablyPossible': 363, 'B-RestructuringAndRelatedCostIncurredCost': 75, 'B-StockRepurchasedDuringPeriodValue': 254, 'B-IncomeTaxExaminationPenaltiesAndInterestExpense': 525, 'B-ImpairmentOfIntangibleAssetsIndefinitelivedExcludingGoodwill': 55, 'B-PreferredStockLiquidationPreference': 157, 'B-ImpairmentOfIntangibleAssetsExcludingGoodwill': 158, 'B-IncomeTaxesPaidNet': 456, 
'B-DefinedContributionPlanEmployerMatchingContributionPercent': 332, 'B-CostOfGoodsAndServicesSold': 274, 'B-DepreciationDepletionAndAmortization': 338, 'B-InterestExpenseDebt': 191, 'B-LineOfCreditFacilityUnusedCapacityCommitmentFeePercentage': 442, 'B-DisposalGroupIncludingDiscontinuedOperationConsideration': 6, 'B-UnrecognizedTaxBenefitsInterestOnIncomeTaxesAccrued': 14, 'B-SaleOfStockPricePerShare': 278, 'B-DefinedContributionPlanEmployerMatchingContributionPercentOfMatch': 267, 'B-FinitelivedIntangibleAssetsAcquired1': 202, 'B-PaymentsForRepurchaseOfCommonStock': 486, 'B-BusinessCombinationContingentConsiderationLiability': 103, 'B-RelatedPartyTransactionAmountsOfTransaction': 180, 'O': 0} - -class TATQA(QA): - DATASET_PATH = "chancefocus/flare-tatqa" - - -class FinRED(RelationExtraction): - DATASET_PATH = "chancefocus/flare-finred" - - -class lendingclub(Classification): - DATASET_PATH = "chancefocus/cra-lendingclub" - CALCULATE_MCC = True - - -class ccf(Classification): - DATASET_PATH = "chancefocus/cra-ccf" - CALCULATE_MCC = True - - -class ccfraud(Classification): - DATASET_PATH = "chancefocus/cra-ccfraud" - CALCULATE_MCC = True - - -class polish(Classification): - DATASET_PATH = "chancefocus/cra-polish" - CALCULATE_MCC = True - - -class taiwan(Classification): - DATASET_PATH = "chancefocus/cra-taiwan" - CALCULATE_MCC = True - - -class portoseguro(Classification): - DATASET_PATH = "chancefocus/cra-portoseguro" - CALCULATE_MCC = True - - -class travelinsurace(Classification): - DATASET_PATH = "chancefocus/cra-travelinsurace" - CALCULATE_MCC = True - - - -class LongFormFactuality(Task): - VERSION = 1 - DATASET_NAME = None - EVAL_LAST_TURN = True - - def reformulate_turn_req(self, req, turn_request, turn): - return req - - def has_training_docs(self): - return False - - def has_validation_docs(self): - return False - - def has_test_docs(self): - return True - - def training_docs(self): - return self.dataset["train"] - - def validation_docs(self): - return self.dataset["validation"] - - def test_docs(self): - return self.dataset["test"] - - def doc_to_text(self, doc): - # TODO: Format the query prompt portion of the document example. - return doc["query"] - - def doc_to_target(self, doc): - return doc["answer"] - - def process_results(self, doc, results): - return { - "factscore": (doc["answer"], doc["text"], results[0]), - } - - def higher_is_better(self): - return { - "factscore": True - } - - def construct_requests(self, doc, ctx): - """Uses RequestFactory to construct Requests and returns an iterable of - Requests which will be sent to the LM. - :param doc: - The document as returned from training_docs, validation_docs, or test_docs. - :param ctx: str - The context string, generated by fewshot_context. This includes the natural - language description, as well as the few shot examples, and the question - part of the document for `doc`. 
- """ - cont_request = rf.greedy_until(ctx, {"until": None}) - return cont_request - - def factscore(self, items): - golds, texts, preds = zip(*items) - texts = list(texts) - preds = list(preds) - - fs = FactScorer("retrieval+ChatGPT", openai_key=os.environ["OPENAI_API_KEY"]) - - fs.register_knowledge_source("finterms", data_path="./src/factscore_package/.cache/finterms.jsonl", db_path="./src/factscore_package/.cache/fin_terms.db") - - score = 0 - num_facts = 0 - # now, when you compute a score, specify knowledge source to use - for i in range(len(texts)): - try: - out = fs.get_score([texts[i]], [preds[i]], knowledge_source="finterms") - score += out["score"] * out["num_facts_per_response"] - num_facts += out["num_facts_per_response"] - except: - out = fs.get_score([texts[i]], [preds[i]], knowledge_source="finterms") - score += out["score"] * out["num_facts_per_response"] - num_facts += out["num_facts_per_response"] - - #try: - #out = fs.get_score(texts, preds, knowledge_source="finterms") - #except: - #out = fs.get_score(texts, preds, knowledge_source="finterms") - - - #return out["score"] # FActScore - return score/num_facts - - def aggregation(self): - return { - "factscore": self.factscore, - } - - -class FINTERM(LongFormFactuality): - DATASET_PATH = "PIXIU-fin/en-finterm" - -class ACRONYM(QA): - DATASET_PATH = "PIXIU-fin/en-acronym" - - - - - -class ZHFinFE(Classification): - DATASET_PATH = "ChanceFocus/flare-zh-fe" - - -class ZHFinNL(Classification): - DATASET_PATH = "ChanceFocus/flare-zh-nl" - - -class ZHFinNL2(Classification): - DATASET_PATH = "ChanceFocus/flare-zh-nl2" - - -class ZHFinNSP(Classification): - DATASET_PATH = "ChanceFocus/flare-zh-nsp" - - -class ZHFinRE(Classification): - DATASET_PATH = "ChanceFocus/flare-zh-re" - - -class ZHAFQMC(Classification): - DATASET_PATH = "ChanceFocus/flare-zh-afqmc" - - -class ZHAstock(Classification): - DATASET_PATH = "ChanceFocus/flare-zh-stocka" - - -class ZHBQcourse(Classification): - DATASET_PATH = "ChanceFocus/flare-zh-corpus" - - -class ZHFinEval(Classification): - DATASET_PATH = "ChanceFocus/flare-zh-fineval" - - -class ZHstock11(Classification): - DATASET_PATH = "ChanceFocus/flare-zh-stockb" - -class ZHFinQA(QA): - DATASET_PATH = "ChanceFocus/flare-zh-qa" - - -class ZHFinNA(AbstractiveSummarization): - DATASET_PATH = "ChanceFocus/flare-zh-na" - - -class ZH21CCKS(RelationExtraction): - DATASET_PATH = "ChanceFocus/flare-zh-21ccks" - - def process_results(self, doc, results): - return { - "precision": (doc["answer"], results), - "recall": (doc["answer"], results), - "f1": (doc["answer"], results), - } - - def process_string_list(self, string_list): - processed_list = [] - - for item in string_list: - processed_item = item.strip() - processed_list.append(processed_item) - - return processed_list - - def process(self, items): - golds, preds = zip(*items) - - all_golds = [] - all_preds = [] - - for gold, pred in zip(golds, preds): - gold = str(gold).split("\n") - all_golds.extend(gold) - pred = self.process_string_list(pred) - all_preds.extend(pred) - - return set(all_golds), set(all_preds) - - -class ZH19CCKS(RelationExtraction): - VERSION = 1 - DATASET_PATH = "ChanceFocus/flare-zh-19ccks" - - def process_results(self, doc, results): - return { - "precision": (doc["answer"], results[0]), - "recall": (doc["answer"], results[0]), - "f1": (doc["answer"], results[0]), - } - - -class ZH20CCKS(ZH19CCKS): - DATASET_PATH = "ChanceFocus/flare-zh-20ccks" - - -class ZH22CCKS(ZH19CCKS): - DATASET_PATH = "ChanceFocus/flare-zh-22ccks" - - -class 
ZHNER(NER): - DATASET_PATH = "ChanceFocus/flare-zh-ner" - - def process_results(self, doc, results): - text = ' '.join(doc["text"]) - pred = process_zhtext(results[0], text) - - return {"entity_f1": (pred, doc["label"], results[0])} - - -class ZHFPB(Classification): - DATASET_PATH = "ChanceFocus/flare-zh-fpb" - - -class ZHFIQASA(Classification): - DATASET_PATH = "ChanceFocus/flare-zh-fiqasa" - - -class ZHHeadlines(Headlines): - DATASET_PATH = "ChanceFocus/flare-zh-headlines" - - def process_results(self, doc, results): - gold = doc["gold"] - - return { - "avg_f1": (doc["answer"], int(results[0] != "是"), gold, results), - } - - -class ZHBigData(StockMovement): - DATASET_PATH = "ChanceFocus/flare-zh-bigdata" - CHOICE_DICT = { - "上涨": ["是", "正面", "积极", "肯定的"], - "下跌": ["否", "负面", "消极"], - } - DEFAULT = "下跌" - - -class ZHACL(ZHBigData): - DATASET_PATH = "ChanceFocus/flare-zh-acl" - - -class ZHCIKM(ZHBigData): - DATASET_PATH = "ChanceFocus/flare-zh-cikm" - - -class ZHFinQAE(QA): - DATASET_PATH = "ChanceFocus/flare-zh-finqa" - - -class ZHConvFinQA(ConvFinQA): - DATASET_PATH = "ChanceFocus/flare-zh-convfinqa" - -class ESMultiFin(Classification): - DATASET_PATH = "chancefocus/flare-es-multifin" - -class ESEFP(Classification): - DATASET_PATH = "chancefocus/flare-es-efp" - -class ESEFPA(Classification): - DATASET_PATH = "chancefocus/flare-es-efpa" - -class ESTSA(Classification): - DATASET_PATH = "chancefocus/flare-es-tsa" - -class ESFINANCEES(Classification): - DATASET_PATH = "chancefocus/flare-es-financees" - -class ESFNS(AbstractiveSummarization): - DATASET_PATH = "chancefocus/flare-es-fns"
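The checkpoint copies deleted above duplicate the live src/tasks modules, so the registry and task classes themselves are unchanged by this patch. A minimal usage sketch of how the classes shown here are resolved, assuming the live src/tasks/__init__.py keeps the same TASK_REGISTRY / get_task / get_task_dict helpers as the deleted checkpoint copy (the "src.tasks" import path below is an assumption about how the package is laid out, not something this patch establishes):

# Illustrative sketch only; not part of the diff above.
# Assumes src/tasks/__init__.py mirrors the deleted checkpoint file:
# a TASK_REGISTRY dict plus get_task()/get_task_dict() helpers.
from src.tasks import get_task, get_task_dict

# Resolve a registered name to its class (e.g. flare.ESMultiFin) and instantiate it;
# instantiation loads the dataset referenced by the class's DATASET_PATH.
task_cls = get_task("flare_es_multifin")
task = task_cls()

# get_task_dict() accepts a mix of registered names and already-built Task objects.
tasks = get_task_dict(["flare_es_tsa", task])

for name, t in tasks.items():
    # Tasks built on Classification / AbstractiveSummarization expose the
    # has_*_docs() / test_docs() interface defined in flare.py.
    print(name, t.has_test_docs())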