Paper: https://arxiv.org/pdf/2407.06041
Follow installation instructions in notebooks/installation_instruction.ipynb
lcquad1:
python3 code/generate_train_csv.py \
-i datasets/lcquad1/train-data.json \
-o datasets/lcquad1/train-data \
-t lcquad1 \
-l all \
--linguistic_context \
--entity_knowledge \
--question_padding_length 128 \
--entity_padding_length 64 \
--train_split_percent 90
lcquad2:
python3 code/generate_train_csv.py \
-i datasets/lcquad2/train.json \
-o datasets/lcquad2/train-lc-ent \
-t lcquad2 \
-l all \
--linguistic_context \
--entity_knowledge \
--question_padding_length 128 \
--entity_padding_length 64 \
--train_split_percent 90
qald dbpedia:
python3 code/generate_train_csv.py \
-i datasets/qald9plus/dbpedia/qald_9_plus_train_dbpedia.json \
-o datasets/qald9plus/dbpedia/qald_9_plus_train_dbpedia-lc-ent \
-t qald \
-kg DBpedia \
-l all \
--linguistic_context \
--entity_knowledge \
--question_padding_length 128 \
--entity_padding_length 64 \
--train_split_percent 90
qald wikidata:
python3 code/generate_train_csv.py \
-i datasets/qald9plus/wikidata/qald_9_plus_train_wikidata.json \
-o datasets/qald9plus/wikidata/qald_9_plus_train_wikidata-lc-ent \
-t qald \
-kg Wikidata \
-l all \
--linguistic_context \
--entity_knowledge \
--question_padding_length 128 \
--entity_padding_length 64 \
--train_split_percent 90
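The `--train_split_percent 90` option splits the input into a 90% training portion and a 10% dev portion (the sample training commands below reference the resulting `*_train_90pct.csv` and `*_dev_10pct.csv` files). A minimal sketch of that presumed split behaviour, assuming a simple shuffle-and-cut; the real `generate_train_csv.py` may differ in details:

```python
import random

def split_rows(rows, train_split_percent=90, seed=42):
    """Shuffle rows deterministically, then cut into train/dev portions.

    Sketch of the presumed --train_split_percent logic; check
    generate_train_csv.py for the actual implementation.
    """
    rows = list(rows)
    random.Random(seed).shuffle(rows)
    cut = len(rows) * train_split_percent // 100
    return rows[:cut], rows[cut:]

train, dev = split_rows(range(100), train_split_percent=90)
# 100 rows -> 90 train rows, 10 dev rows
```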
Note: The dev dataset is made noisy by default and is meant only for evaluating the loss. For evaluating the QA system, please see the Evaluation section.
train.sh is used to train with DeepSpeed (by default it uses deepspeed/ds_config_zero2.json; if you face a CUDA out-of-memory issue, try reducing the batch size and/or switching to deepspeed/ds_config_zero3.json).
Note: The gradient accumulation is set to 4, which means that for each optimizer step the model sees 4x the provided batch size.
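The effective batch size can be worked out directly from the accumulation setting. A small illustration (the `num_devices` factor is an assumption for multi-GPU DeepSpeed runs):

```python
def effective_batch_size(per_device_batch, grad_accum_steps=4, num_devices=1):
    """Number of samples contributing to a single optimizer step."""
    return per_device_batch * grad_accum_steps * num_devices

# With the default accumulation of 4 and BATCH_SIZE=32 on one device,
# each optimizer step covers 32 * 4 = 128 samples.
print(effective_batch_size(32))
```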
Please provide arguments in the following order to the training script:
- PORT : Port to be used by deepspeed
- MODEL_NAME : Name of the model to fine-tune
- TRAIN_FILE : Path to the training file
- EVAL_FILE : Path to the eval file (provide "false" to disable eval logic)
- OUTPUT_DIR : Output directory to save the fine-tuned model (and checkpoints)
- RUN_NAME: Name of the run to be used for wandb
- TRAIN_EPOCHS : Number of epochs to train
- BATCH_SIZE : Batch size per device
- SAVE_STEPS: Interval in training steps to save the model checkpoints
The following are sample usages of the training script:
bash train.sh 60020 "google/mt5-xl" datasets/lcquad1/train-data_train_90pct.csv datasets/lcquad1/train-data_dev_10pct.csv fine-tuned_models/lcquad1-finetune_mt5-base_lc-ent lcquad1-finetune_mt5-base_lc-ent 32 32 1000
bash train.sh 60030 fine-tuned_models/lcquad1-finetune_mt5-base_lc-ent datasets/qald9plus/dbpedia/qald_9_plus_train_dbpedia-lc-ent_train_90pct.csv datasets/qald9plus/dbpedia/qald_9_plus_train_dbpedia-lc-ent_dev_10pct.csv fine-tuned_models/qald9plus-finetune_lcquad1-ft-base_lc-ent qald9plus-finetune_lcquad1-ft-base_lc-ent 32 32 1000
bash train.sh 60000 "google/mt5-xl" datasets/lcquad2/train-lc-ent_train_90pct.csv datasets/lcquad2/train-lc-ent_dev_10pct.csv fine-tuned_models/lcquad2-finetune_mt5-base_lc-ent lcquad2-finetune_mt5-base_lc-ent 15 32 1000
bash train.sh 60010 fine-tuned_models/lcquad2-finetune_mt5-base_lc-ent datasets/qald9plus/wikidata/qald_9_plus_train_wikidata-lc-ent_train_90pct.csv datasets/qald9plus/wikidata/qald_9_plus_train_wikidata-lc-ent_dev_10pct.csv fine-tuned_models/qald9plus-finetune_lcquad2-ft-base_lc-ent qald9plus-finetune_lcquad2-ft-base_lc-ent 32 32 1000
Please provide arguments in the following order to the evaluation script:
- MODEL_ROOT_DIR : Directory where fine-tuned models are stored: fine-tuned_models
- MODEL_NAME : Name of the model: qald9plus-finetune_mt5-base_lc-ent
- TEST_FILE : Path to the qald test file
- OUTPUT_DIR : Root directory where a new directory with model_name will be created to store predictions and results
- LANGS : Comma-separated values, e.g. en,de,es
- LC : linguistic context : true/false
- EK : entity knowledge : true/false
- GE: Gerbil Evaluation : true/false
- KNOWLEDGE_GRAPH: Knowledge Graph : DBpedia/Wikidata
Sample usage:
bash eval.sh fine-tuned_models qald9plus-finetune_lcquad2-ft-base_lc-ent datasets/qald9plus/wikidata/qald_9_plus_test_wikidata.json predictions_qald9plus_test "en,de,ru,zh" true true true Wikidata
You can download the following fine-tuned models that can be used out-of-the-box with the deployment script:
- Wikidata-based model trained on LC-QuAD2.0 and Qald_9_Plus (train & test):
wget -r -nH --cut-dirs=3 --no-parent --reject="index.html*" https://files.dice-research.org/projects/MST5/fine-tuned-models/qald9plus-finetune_lcquad2-ft-base_lc-ent_testeval/
- Wikidata-based model trained on LC-QuAD2.0 and Qald_9_Plus (train):
wget -r -nH --cut-dirs=3 --no-parent --reject="index.html*" https://files.dice-research.org/projects/MST5/fine-tuned-models/qald9plus-finetune_lcquad2-ft-base_lc-ent/
- DBpedia-based model trained on LC-QuAD1.0 and Qald_9_Plus (train):
wget -r -nH --cut-dirs=3 --no-parent --reject="index.html*" https://files.dice-research.org/projects/MST5/fine-tuned-models/qald9plus-finetune_lcquad1-ft-base_lc-ent/
To deploy the model as a RESTful service, deploy_model.py can be used:
python deploy_model.py --model fine-tuned_models/qald9plus-finetune_lcquad2-ft-base_lc-ent \
--knowledge_graph Wikidata \
--linguistic_context \
--entity_knowledge \
--question_padding_length 128 \
--entity_padding_length 64 \
--port 8181 \
--log_file logs/server-mst5-wiki.log
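Once the server is running, it can be queried over HTTP. The endpoint path (`/ask`) and payload field names (`question`, `lang`) in the sketch below are hypothetical placeholders, not the actual API of deploy_model.py; check the script for the real interface:

```python
import json
import urllib.request

# NOTE: the endpoint path ("/ask") and the field names ("question", "lang")
# are assumptions for illustration -- see deploy_model.py for the real API.
SERVER_URL = "http://localhost:8181/ask"

def build_request(question, lang="en"):
    """Build an HTTP POST request carrying the question as a JSON payload."""
    payload = json.dumps({"question": question, "lang": lang}).encode("utf-8")
    return urllib.request.Request(
        SERVER_URL,
        data=payload,
        headers={"Content-Type": "application/json"},
    )

req = build_request("Who is the mayor of Berlin?", lang="en")
# urllib.request.urlopen(req)  # uncomment once the server is up
```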
Note: For GPU-based hardware acceleration, set the relevant device in Text_Generator.py. To enable CPU-only mode, set the device value to -1.
If you use this code or data in your research, please cite our work:
@misc{srivastava2024mst5,
title={MST5 -- Multilingual Question Answering over Knowledge Graphs},
author={Nikit Srivastava and Mengshi Ma and Daniel Vollmers and Hamada Zahera and Diego Moussallem and Axel-Cyrille Ngonga Ngomo},
year={2024},
eprint={2407.06041},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2407.06041},
}