feat: DPO

Ting-Jun Wang 2024-06-11 21:38:18 +08:00
parent 94c69a4c60
commit 91843816ff
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

DPO.py (30 changed lines)

@@ -5,7 +5,7 @@ import utils
 import torch
 import wandb
 from tqdm.auto import tqdm
-from trl import DPOTrainer
+from trl import DPOTrainer, DPOConfig
 from datasets import load_dataset
 from unsloth import FastLanguageModel
 from unsloth import is_bfloat16_supported
@@ -13,6 +13,7 @@ from transformers import TrainingArguments, TextStreamer
 def DPO_train(args, output_dir):
     wandb.login(key=args.wandb_token)
     wandb.init(project="hw6_rlhf",
@@ -29,21 +30,36 @@ def DPO_train(args, output_dir):
     dataset = load_dataset("Intel/orca_dpo_pairs", split="train")
     dataset = dataset.rename_column('question', 'prompt')
     dataset = dataset.train_test_split(test_size=0.01)
+    print(f"DATASET: {dataset}")

     with open("./test_prompt.json", 'r') as f:
         test_data = json.load(f)

     # ================================DO NOT CHANGE!================================
     # Model
-    # model, tokenizer = FastLanguageModel.from_pretrained(model_name=args.model_name,...)
-    utils.YOUR_CODE_HERE
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=args.model_name,
+        max_seq_length=args.max_length,
+        dtype=None,
+        load_in_4bit=True
+    )

     # Perform model patching and add fast LoRA weights
     # model = FastLanguageModel.get_peft_model(model,...)
-    utils.YOUR_CODE_HERE
+    model = FastLanguageModel.get_peft_model(
+        model,
+        r=16,
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                        "gate_proj", "up_proj", "down_proj"],
+        lora_alpha=16,
+        lora_dropout=0,
+        bias='none',
+        use_gradient_checkpointing="unsloth",
+        random_state=3407
+    )

     # Training arguments
-    training_args = TrainingArguments(
+    training_args = DPOConfig(
         per_device_train_batch_size=args.train_batch_size,
         per_device_eval_batch_size=args.eval_batch_size,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
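
Side note on the config switch above: recent TRL releases let the DPO-specific hyperparameters (beta, sequence-length limits) be set on DPOConfig itself rather than passed to DPOTrainer, as this diff still does further down. A minimal sketch, assuming a current TRL version; the values are placeholders, not the ones used in this repo:

# Sketch only: assumes a recent TRL release where DPOConfig carries the
# DPO-specific hyperparameters; the values below are placeholders.
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="outputs",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=5e-6,
    beta=0.1,               # DPO temperature (older TRL took this on DPOTrainer)
    max_length=1024,        # prompt + completion token budget
    max_prompt_length=512,
)
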
@@ -71,8 +87,8 @@ def DPO_train(args, output_dir):
     dpo_trainer = DPOTrainer(
         model=model,
         tokenizer=tokenizer,
-        train_dataset=utils.YOUR_CODE_HERE,
-        eval_dataset=utils.YOUR_CODE_HERE,
+        train_dataset=dataset['train'],
+        eval_dataset=dataset['test'],
         args=training_args,
         beta=args.beta,
         max_length=args.max_length,
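
The diff is cut off here; a typical continuation once the trainer is constructed would start training and persist the LoRA adapters. A minimal sketch, assuming the dpo_trainer, model, tokenizer, and output_dir defined above (not taken from this commit):

# Sketch only: the usual follow-up after building the trainer.
dpo_trainer.train()                    # run DPO fine-tuning with the config above
model.save_pretrained(output_dir)      # write the LoRA adapter weights
tokenizer.save_pretrained(output_dir)
wandb.finish()                         # close the run opened by wandb.init()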