feat: DPO

This commit is contained in:
Ting-Jun Wang 2024-06-11 21:38:18 +08:00
parent 94c69a4c60
commit 91843816ff
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

30
DPO.py
View File

@ -5,7 +5,7 @@ import utils
import torch
import wandb
from tqdm.auto import tqdm
from trl import DPOTrainer
from trl import DPOTrainer, DPOConfig
from datasets import load_dataset
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
@ -13,6 +13,7 @@ from transformers import TrainingArguments, TextStreamer
def DPO_train(args, output_dir):
wandb.login(key=args.wandb_token)
wandb.init(project="hw6_rlhf",
@ -29,21 +30,36 @@ def DPO_train(args, output_dir):
dataset = load_dataset("Intel/orca_dpo_pairs", split="train")
dataset = dataset.rename_column('question', 'prompt')
dataset = dataset.train_test_split(test_size=0.01)
print(f"DATASET: {dataset}")
with open("./test_prompt.json", 'r') as f:
test_data = json.load(f)
# ================================DO NOT CHANGE!================================
# Model
# model, tokenizer = FastLanguageModel.from_pretrained(model_name=args.model_name,...)
utils.YOUR_CODE_HERE
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=args.model_name,
max_seq_length=args.max_length,
dtype=None,
load_in_4bit=True
)
# Perform model patching and add fast LoRA weights
# model = FastLanguageModel.get_peft_model(model,...)
utils.YOUR_CODE_HERE
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=16,
lora_dropout=0,
bias='none',
use_gradient_checkpointing="unsloth",
random_state=3407
)
# Training arguments
training_args = TrainingArguments(
training_args = DPOConfig(
per_device_train_batch_size=args.train_batch_size,
per_device_eval_batch_size=args.eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
@ -71,8 +87,8 @@ def DPO_train(args, output_dir):
dpo_trainer = DPOTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=utils.YOUR_CODE_HERE,
eval_dataset=utils.YOUR_CODE_HERE,
train_dataset=dataset['train'],
eval_dataset=dataset['test'],
args=training_args,
beta=args.beta,
max_length=args.max_length,