Compare commits


3 Commits

SHA1        Message               Date
91843816ff  feat: DPO             2024-06-11 21:38:18 +08:00
94c69a4c60  feat: add tinyllama   2024-06-11 17:37:18 +08:00
7d78b49815  docs: add .gitignore  2024-06-11 17:36:12 +08:00
3 changed files with 32 additions and 8 deletions

.gitignore (vendored, new file; 4 additions)

@@ -0,0 +1,4 @@
+**/__pycache__/
+outputs/
+submission/
+wandb/
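
These entries keep run artifacts out of version control: Python bytecode caches at any depth, training outputs, generated submissions, and Weights & Biases run logs. A small sketch of what the patterns match, using the third-party pathspec library (an assumption; it implements gitignore-style matching) rather than git itself:

import pathspec  # pip install pathspec

# Compile the new .gitignore entries as gitignore-style ("gitwildmatch") patterns.
spec = pathspec.PathSpec.from_lines(
    "gitwildmatch",
    ["**/__pycache__/", "outputs/", "submission/", "wandb/"],
)

print(spec.match_file("src/pkg/__pycache__/mod.cpython-311.pyc"))  # True
print(spec.match_file("outputs/checkpoint-50/model.safetensors"))  # True
print(spec.match_file("DPO.py"))                                   # False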

DPO.py (23 additions, 7 deletions)

@@ -5,7 +5,7 @@ import utils
 import torch
 import wandb
 from tqdm.auto import tqdm
-from trl import DPOTrainer
+from trl import DPOTrainer, DPOConfig
 from datasets import load_dataset
 from unsloth import FastLanguageModel
 from unsloth import is_bfloat16_supported
@@ -13,6 +13,7 @@ from transformers import TrainingArguments, TextStreamer
 def DPO_train(args, output_dir):
     wandb.login(key=args.wandb_token)
     wandb.init(project="hw6_rlhf",
@@ -29,21 +30,36 @@ def DPO_train(args, output_dir):
     dataset = load_dataset("Intel/orca_dpo_pairs", split="train")
     dataset = dataset.rename_column('question', 'prompt')
     dataset = dataset.train_test_split(test_size=0.01)
+    print(f"DATASET: {dataset}")
     with open("./test_prompt.json", 'r') as f:
         test_data = json.load(f)
     # ================================DO NOT CHANGE!================================
     # Model
-    # model, tokenizer = FastLanguageModel.from_pretrained(model_name=args.model_name,...)
-    utils.YOUR_CODE_HERE
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=args.model_name,
+        max_seq_length=args.max_length,
+        dtype=None,
+        load_in_4bit=True
+    )
     # Perform model patching and add fast LoRA weights
     # model = FastLanguageModel.get_peft_model(model,...)
-    utils.YOUR_CODE_HERE
+    model = FastLanguageModel.get_peft_model(
+        model,
+        r=16,
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                        "gate_proj", "up_proj", "down_proj"],
+        lora_alpha=16,
+        lora_dropout=0,
+        bias='none',
+        use_gradient_checkpointing="unsloth",
+        random_state=3407
+    )
     # Training arguments
-    training_args = TrainingArguments(
+    training_args = DPOConfig(
         per_device_train_batch_size=args.train_batch_size,
         per_device_eval_batch_size=args.eval_batch_size,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -71,8 +87,8 @@ def DPO_train(args, output_dir):
     dpo_trainer = DPOTrainer(
         model=model,
         tokenizer=tokenizer,
-        train_dataset=utils.YOUR_CODE_HERE,
-        eval_dataset=utils.YOUR_CODE_HERE,
+        train_dataset=dataset['train'],
+        eval_dataset=dataset['test'],
         args=training_args,
         beta=args.beta,
         max_length=args.max_length,
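
Taken together, the filled-in blanks follow the standard unsloth + TRL DPO recipe: load a 4-bit quantized base model, patch in LoRA adapters, then hand the preference pairs to DPOTrainer. Below is a minimal end-to-end sketch, not the assignment solution: it assumes a TRL version in the 0.9 to 0.11 range (DPOConfig exists and DPOTrainer still accepts a tokenizer kwarg, as the diff does), the Intel/orca_dpo_pairs schema ('system', 'question', 'chosen', 'rejected'), a CUDA GPU (unsloth requires one), and literal hyperparameters standing in for the argparse values:

from unsloth import FastLanguageModel  # import unsloth before trl/transformers

from datasets import load_dataset
from trl import DPOConfig, DPOTrainer

# DPOTrainer expects 'prompt', 'chosen', 'rejected' columns, hence the rename.
dataset = load_dataset("Intel/orca_dpo_pairs", split="train")
dataset = dataset.rename_column("question", "prompt")
dataset = dataset.train_test_split(test_size=0.01)

# Load the 4-bit quantized base model (any of the three supported names).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/tinyllama-bnb-4bit",
    max_seq_length=1024,  # stand-in for args.max_length
    dtype=None,           # let unsloth pick fp16/bf16
    load_in_4bit=True,
)

# Patch the model and attach LoRA adapters on attention and MLP projections.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# In recent TRL, beta and max_length belong on DPOConfig; the diff passes
# them to DPOTrainer instead, which the TRL versions above still accepted.
training_args = DPOConfig(
    output_dir="outputs",  # matches the new .gitignore entry
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    max_steps=50,
    beta=0.1,              # stand-in for args.beta
    max_length=1024,
    max_prompt_length=512,
)

dpo_trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
)
dpo_trainer.train()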

(file name not shown; 5 additions, 1 deletion)

@@ -19,7 +19,8 @@ def parse_args():
                         choices=["DPO", "ORPO"])
     parser.add_argument("--model_name", type=str,
                         choices=["unsloth/llama-3-8b-bnb-4bit",
-                                 "unsloth/mistral-7b-v0.3-bnb-4bit"],
+                                 "unsloth/mistral-7b-v0.3-bnb-4bit",
+                                 "unsloth/tinyllama-bnb-4bit"],
                         required=True)
     parser.add_argument("--train", action="store_true")
     parser.add_argument("--inference_base_model", action="store_true")
@@ -90,5 +91,8 @@ if __name__ == "__main__":
     elif args.model_name == "unsloth/mistral-7b-v0.3-bnb-4bit":
         print("Inference with base model: unsloth/mistral-7b-v0.3-bnb-4bit")
         inference.LLM_inference(args)
+    elif args.model_name == "unsloth/tinyllama-bnb-4bit":
+        print("Inference with base model: unsloth/tinyllama-bnb-4bit")
+        inference.LLM_inference(args)
     else:
         raise ValueError("Invalid model name")
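
The three base-model branches added above differ only in the printed name; a hedged refactor sketch (not part of this change set) that collapses them:

# Hedged sketch, not in the diff: since argparse already validates
# --model_name against its choices, a single membership test covers
# all three base models.
SUPPORTED_BASE_MODELS = {
    "unsloth/llama-3-8b-bnb-4bit",
    "unsloth/mistral-7b-v0.3-bnb-4bit",
    "unsloth/tinyllama-bnb-4bit",
}

if args.model_name in SUPPORTED_BASE_MODELS:
    print(f"Inference with base model: {args.model_name}")
    inference.LLM_inference(args)
else:
    raise ValueError("Invalid model name")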