Llama2 in PyTorch

Overview

This project implements the Llama2 transformer decoder architecture for self-supervised next-token prediction, which is at the core of LLMs. It aims to provide a simple and efficient implementation of the popular Llama model, which builds on the original transformer architecture and adds a few upgrades: rotary position embeddings (RoPE), grouped-query attention (a tradeoff between multi-head and multi-query attention), the SwiGLU feed-forward block, RMSNorm, and KV caching.
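
For orientation, two of these upgrades, RMSNorm and the SwiGLU feed-forward block, can be sketched in a few lines of PyTorch. This is a minimal illustration rather than a copy of this repo's implementation; module and parameter names are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root-mean-square norm: rescales by the RMS of the features, with no mean-centering."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x * rms

class SwiGLU(nn.Module):
    """Gated feed-forward block: silu(W1 x) * (W3 x), projected back down with W2."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```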

Llama2 Architecture

(Llama2 architecture diagram)

The Llama2 architecture is a stack of Transformer decoder layers, coupled with a few upgrades:

  • Rotary Embeddings (RoPE)
  • SwiGLU
  • Grouped Query Attention
  • RMSNorm
  • KV Caching

Decoder: Llama2 is decoder-only, so there is no encoder; the decoder operates directly on the embedded input tokens and autoregressively generates the output sequence. It consists of a stack of decoder layers, each containing a grouped-query multi-head self-attention mechanism and a feed-forward network, and it benefits from RoPE positional encodings, KV caching, and the other upgrades listed above.
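
For intuition, rotary position embeddings rotate each pair of query/key channels by a position-dependent angle, so relative positions are encoded directly in the attention dot products. Below is a minimal sketch using the complex-number formulation; function names and shapes are illustrative and may not match this repo's code.

```python
import torch

def rope_frequencies(head_dim, seq_len, base=10000.0):
    # One rotation frequency per pair of channels, one angle per position.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
    # Unit-magnitude complex numbers e^{i * angle}, shape (seq_len, head_dim // 2).
    return torch.polar(torch.ones_like(angles), angles)

def apply_rope(x, freqs):
    # x: (batch, seq_len, n_heads, head_dim); rotate consecutive channel pairs as complex numbers.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rotated = x_complex * freqs[None, :, None, :]
    return torch.view_as_real(rotated).flatten(-2).type_as(x)
```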

Grouped Query Attention: Grouped-query attention (GQA) is a modification of standard multi-head attention in which several query heads share a single key/value head. It sits between multi-head attention (one key/value head per query head) and multi-query attention (a single key/value head for all query heads), reducing the memory and bandwidth cost of attention and the KV cache while retaining most of the modeling quality.
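
A rough sketch of the idea: with fewer key/value heads than query heads, each key/value head is simply repeated so that a whole group of query heads attends to it. The sketch below omits RoPE, the causal mask, and KV caching, and is not the repo's actual attention layer.

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    # q: (batch, n_heads, seq_len, head_dim); k, v: (batch, n_kv_heads, seq_len, head_dim)
    n_heads, n_kv_heads = q.shape[1], k.shape[1]
    group_size = n_heads // n_kv_heads
    # Repeat each key/value head so every query head in its group shares the same K/V.
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    return F.softmax(scores, dim=-1) @ v
```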

For more details on the architecture, refer to the original paper: Llama.

Features

🔀 Self-Supervised Prediction: The training loop is built around self-supervised next-token prediction, enabling the model to learn from unlabeled text (see the sketch below).
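
Concretely, self-supervised next-token prediction uses the input sequence shifted by one position as its own target. The following is a hedged sketch of the loss computation; the actual training loop in this repo may differ in its details.

```python
import torch
import torch.nn.functional as F

def next_token_loss(model, tokens):
    # tokens: (batch, seq_len) integer token ids from the unlabeled corpus.
    inputs, targets = tokens[:, :-1], tokens[:, 1:]   # predict token t+1 from tokens <= t
    logits = model(inputs)                            # (batch, seq_len - 1, vocab_size)
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
```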

Setup

To get started with Llama2-Pytorch, follow these steps:

  1. Clone the repository:

    git clone https://github.com/paulilioaica/Llama2-Pytorch
    cd Llama2-Pytorch/
    
  2. Install the required dependencies:

    pip install -r requirements.txt

Usage

import torch
from llama import LLama2

decoder_layers_num = 2
num_hidden = 16
num_heads = 4
num_kv_heads = 2
seq_len = 256
vocab_size = 100

model = LLama2(decoder_layers_num, num_hidden, num_heads, num_kv_heads, seq_len, vocab_size)

# x: (batch_size, seq_len) tensor of token indices
x = torch.randint(0, vocab_size, (1, seq_len))

output = model(x)
print(output.shape)
# torch.Size([1, 256, 100])

Or, to train the model from the command line:

  1. Dataset: Make sure you have a dataset suitable for self-supervised prediction from Hugging Face (or use the default AG News dataset). Simply pass dataset_name to train on your dataset of choice.

  2. Configure the training parameters: Adjust the hyperparameters by passing your own arguments.

  3. Train the model: Run the training script to start the self-supervised prediction training loop.

  4. Evaluate the model: Use the trained model to make predictions on a held-out test set and evaluate its performance (a rough perplexity sketch follows this list).
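
For step 4, a common evaluation metric for next-token prediction is perplexity, the exponential of the average cross-entropy per token. Below is a minimal sketch; test_loader and the device argument are illustrative assumptions, not names taken from this repo's scripts.

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def perplexity(model, test_loader, device="cpu"):
    model.eval()
    total_loss, total_tokens = 0.0, 0
    for tokens in test_loader:                      # tokens: (batch, seq_len) token ids
        tokens = tokens.to(device)
        inputs, targets = tokens[:, :-1], tokens[:, 1:]
        logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               targets.reshape(-1), reduction="sum")
        total_loss += loss.item()
        total_tokens += targets.numel()
    return math.exp(total_loss / total_tokens)
```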

Example run

python main.py --num_layers 2 --n_heads 8 --num_kv_heads 2 --seq_len 128 --num_hidden 128 --num_epochs 10 --batch_size 32 --lr 0.001 --device cpu --dataset_name ag_news

License

This project is licensed under the MIT License.