
Commit b418557

Replace model / data with public HF path, update readme (linkedin#53)
1 parent 95ac0af commit b418557

File tree: 3 files changed (+17, -10 lines)

Diff for: examples/medusa/README.md (+10)

@@ -1,3 +1,5 @@
+# Liger-Kernel Example with Medusa
+
 Medusa is a simple framework that democratizes the acceleration techniques for LLM generation with multiple decoding heads. [[repo](https://github.com/FasterDecoding/Medusa)], [[paper](https://arxiv.org/abs/2401.10774)]
 
 During training, Medusa requires adding \(k\) decoding heads to the hidden states right before the regular LM head \(h_t\). The \(k\)-th head is used to predict the token in the \((t + k + 1)\)-th position of the next tokens (the original language model head is used to predict the \((t + 1)\)-th position).
@@ -16,6 +18,14 @@ pip install -r requirements.txt
 sh scripts/llama3_8b_medusa.sh
 ```
 
+**Notes**
+1. This example uses an optional `use_liger` flag. If set to `True`, the script monkey patches the model to apply the Liger kernel together with the Medusa heads.
+2. The example uses a Llama 3 model, which requires accepting a community license agreement and logging in to the Hugging Face Hub. If you want to use Llama 3 in this example, please make sure you have done the following:
+   * Agree to the community license agreement at https://huggingface.co/meta-llama/Meta-Llama-3-8B
+   * Run `huggingface-cli login` and enter your Hugging Face token
+3. The default hyperparameters and configurations work on a single node with 8xA100 GPUs. To run on devices with less GPU RAM, consider reducing the per-GPU batch size and/or enabling `CPUOffload` in FSDP.
+
+
 # Memory Profiling Result
 
 > **Note:**
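
For context on the head structure the README above describes, here is a minimal, illustrative sketch of how \(k\) extra decoding heads can sit on top of the last hidden state. The class and function names are hypothetical and not taken from this repository; the actual implementation lives in the upstream Medusa project.

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    """Residual block: hidden state plus a SiLU-activated linear projection."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.act(self.linear(x))


def build_medusa_heads(hidden_size: int, vocab_size: int,
                       num_heads: int, num_layers: int) -> nn.ModuleList:
    """Build `num_heads` extra decoding heads on top of the last hidden state.

    Head k predicts the token at position (t + k + 1), while the original LM
    head keeps predicting position (t + 1).
    """
    return nn.ModuleList(
        nn.Sequential(
            *[ResBlock(hidden_size) for _ in range(num_layers)],
            nn.Linear(hidden_size, vocab_size, bias=False),
        )
        for _ in range(num_heads)
    )


# Usage: each head maps the hidden states h_t to its own set of logits.
heads = build_medusa_heads(hidden_size=4096, vocab_size=128256,
                           num_heads=3, num_layers=1)
h_t = torch.randn(1, 16, 4096)                    # (batch, seq_len, hidden_size)
logits_per_head = [head(h_t) for head in heads]   # each: (1, 16, 128256)
```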

Diff for: examples/medusa/scripts/llama3_8b_medusa.sh (+2, -5)

@@ -6,8 +6,7 @@ export NUM_NODES=$WORLD_SIZE
 export WORLD_SIZE=$((GPUS_PER_NODE * NUM_NODES))
 echo "Starting training... Num nodes: $NUM_NODES, Num workers: $WORLD_SIZE"
 
-export OUTPUT_DIR="/shared/user/Meta-Llama-3-70B-Instruct-code-act-3ep"
-export DATA_PATH="/shared/public/data/jaszhu/medusa/ShareGPT_V4.3_unfiltered_cleaned_split.json"
+export OUTPUT_DIR="./llama3-8b-medusa-liger"
 
 export LOCAL_TRAIN_BATCH_SIZE=4
 export GRADIENT_ACCUMULATION_STEPS=1
@@ -27,8 +26,6 @@ accelerate launch --config_file fsdp/acc-fsdp.conf \
   --main_process_port $MASTER_PORT \
   --machine_rank $RANK \
   train.py \
-  --model_name_or_path /shared/public/models/Meta-Llama-3-8B-Instruct \
-  --data_path $DATA_PATH \
   --bf16 True \
   --output_dir $OUTPUT_DIR \
   --num_train_epochs 10 \
@@ -56,4 +53,4 @@ accelerate launch --config_file fsdp/acc-fsdp.conf \
   --medusa_lr_multiplier $MEDUSA_LR_MULTIPLIER \
   --medusa_only_heads False \
   --medusa_return True \
-  --with_liger True
+  --use_liger True
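
Note 3 in the README mentions enabling `CPUOffload` in FSDP when GPU RAM is tight. The accelerate config the script points at (`fsdp/acc-fsdp.conf`) is not part of this diff, so the following is only a rough, hypothetical sketch of the underlying PyTorch mechanism, not this example's actual configuration.

```python
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_with_cpu_offload(model: nn.Module) -> FSDP:
    # Requires torch.distributed to be initialized (accelerate handles this
    # when the script is started with `accelerate launch`).
    # CPUOffload(offload_params=True) keeps sharded parameters in host memory
    # and streams them to the GPU only when needed, trading throughput for
    # lower GPU memory usage.
    return FSDP(model, cpu_offload=CPUOffload(offload_params=True))
```

When launching through accelerate, the corresponding offload option would go in the FSDP section of the config file rather than in Python code.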

Diff for: examples/medusa/train.py (+5, -5)

@@ -38,14 +38,14 @@
 @dataclass
 class ModelArguments:
     model_name_or_path: Optional[str] = field(
-        default="/shared/public/models/Meta-Llama-3-8B-Instruct"
+        default="meta-llama/Meta-Llama-3-8B"
     )
 
 
 @dataclass
 class DataArguments:
     data_path: str = field(
-        default="sharegpt_clean.json",
+        default="Aeala/ShareGPT_Vicuna_unfiltered",
         metadata={"help": "Path to the training data."},
     )
     eval_data_path: str = field(
@@ -99,7 +99,7 @@ class TrainingArguments(transformers.TrainingArguments):
             "help": "If train medusa heads only, default is False, the whole model will be trained"
         },
     )
-    with_liger: bool = field(
+    use_liger: bool = field(
         default=False,
         metadata={"help": "If apply liger kernel to the model."},
     )
@@ -331,7 +331,7 @@ def train():
         torch_dtype=torch.bfloat16,
     )
 
-    if training_args.with_liger is True:
+    if training_args.use_liger is True:
         apply_liger_kernel_to_llama()
 
     # Freeze the base model
@@ -344,7 +344,7 @@ def train():
         training_args.medusa_num_layers,
         training_args.medusa_return,
         training_args.medusa_only_heads,
-        training_args.with_liger,
+        training_args.use_liger,
     )
     # Format output dir
     training_args.output_dir = f"{training_args.output_dir}_medusa_mlp_{model_args.model_name_or_path.split('/')[-1]}_medusa_{training_args.medusa_num_heads}_lr_{training_args.learning_rate}_layers_{training_args.medusa_num_layers}"
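
The renamed `use_liger` flag gates the monkey patch described in the README note. A minimal sketch of that pattern follows, assuming the public `liger_kernel` API and the load-then-patch ordering shown in the diff; the flag value is hard-coded here only for illustration (in train.py it comes from `TrainingArguments.use_liger`).

```python
import torch
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

use_liger = True  # illustrative stand-in for training_args.use_liger

model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.bfloat16,
)

if use_liger:
    # Monkey patches the Hugging Face Llama modeling code in place so that
    # Liger's fused Triton kernels (RMSNorm, RoPE, SwiGLU, cross entropy)
    # are used instead of the stock implementations.
    apply_liger_kernel_to_llama()
```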
