feat: DPO
parent 94c69a4c60
commit 91843816ff
DPO.py (30 changed lines)
@@ -5,7 +5,7 @@ import utils
 import torch
 import wandb
 from tqdm.auto import tqdm
-from trl import DPOTrainer
+from trl import DPOTrainer, DPOConfig
 from datasets import load_dataset
 from unsloth import FastLanguageModel
 from unsloth import is_bfloat16_supported
@@ -13,6 +13,7 @@ from transformers import TrainingArguments, TextStreamer
 
 
 
+
 def DPO_train(args, output_dir):
     wandb.login(key=args.wandb_token)
     wandb.init(project="hw6_rlhf",
@@ -29,21 +30,36 @@ def DPO_train(args, output_dir):
     dataset = load_dataset("Intel/orca_dpo_pairs", split="train")
     dataset = dataset.rename_column('question', 'prompt')
     dataset = dataset.train_test_split(test_size=0.01)
     print(f"DATASET: {dataset}")
 
     with open("./test_prompt.json", 'r') as f:
         test_data = json.load(f)
     # ================================DO NOT CHANGE!================================
     # Model
     # model, tokenizer = FastLanguageModel.from_pretrained(model_name=args.model_name,...)
-    utils.YOUR_CODE_HERE
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=args.model_name,
+        max_seq_length=args.max_length,
+        dtype=None,
+        load_in_4bit=True
+    )
 
     # Perform model patching and add fast LoRA weights
     # model = FastLanguageModel.get_peft_model(model,...)
-    utils.YOUR_CODE_HERE
+    model = FastLanguageModel.get_peft_model(
+        model,
+        r=16,
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                        "gate_proj", "up_proj", "down_proj"],
+        lora_alpha=16,
+        lora_dropout=0,
+        bias='none',
+        use_gradient_checkpointing="unsloth",
+        random_state=3407
+    )
 
     # Training arguments
-    training_args = TrainingArguments(
+    training_args = DPOConfig(
         per_device_train_batch_size=args.train_batch_size,
         per_device_eval_batch_size=args.eval_batch_size,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -71,8 +87,8 @@ def DPO_train(args, output_dir):
     dpo_trainer = DPOTrainer(
         model=model,
         tokenizer=tokenizer,
-        train_dataset=utils.YOUR_CODE_HERE,
-        eval_dataset=utils.YOUR_CODE_HERE,
+        train_dataset=dataset['train'],
+        eval_dataset=dataset['test'],
         args=training_args,
         beta=args.beta,
         max_length=args.max_length,
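
For context, the sketch below shows how the pieces patched in this commit are typically wired together end to end with Unsloth and TRL. It is illustrative only, not code from this repository: the model name, hyperparameter values, output paths, and the final train/save calls are assumptions standing in for the repo's args.* values and the rest of DPO_train.

# Minimal DPO fine-tuning sketch (illustrative; the concrete values below are
# assumptions; the actual script reads them from args and output_dir).
from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import load_dataset
from trl import DPOTrainer, DPOConfig

# Load a 4-bit base model and attach LoRA adapters, as in the diff above.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/zephyr-sft-bnb-4bit",  # placeholder for args.model_name
    max_seq_length=1024,                       # placeholder for args.max_length
    dtype=None,                                # let Unsloth pick fp16/bf16
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# DPO expects "prompt", "chosen", "rejected" columns; orca_dpo_pairs names the prompt "question".
dataset = load_dataset("Intel/orca_dpo_pairs", split="train")
dataset = dataset.rename_column("question", "prompt")
dataset = dataset.train_test_split(test_size=0.01)

# DPOConfig extends TrainingArguments with DPO-specific options such as beta and max_length.
training_args = DPOConfig(
    output_dir="outputs",                 # placeholder for output_dir
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=5e-6,
    num_train_epochs=1,
    logging_steps=10,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    beta=0.1,                             # the diff instead passes beta=args.beta to DPOTrainer
    max_length=1024,                      # likewise passed as max_length=args.max_length above
)

dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None,                       # with LoRA adapters, the frozen base serves as the implicit reference model
    args=training_args,
    tokenizer=tokenizer,                  # newer TRL releases rename this to processing_class
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)
dpo_trainer.train()
model.save_pretrained("outputs/dpo_lora")  # saves only the LoRA adapter weights

The switch from TrainingArguments to DPOConfig in this commit follows TRL moving DPO-specific options (beta, max_length, and related settings) into a dedicated config class; older TRL releases also accept them as DPOTrainer keyword arguments, which is what the diff still does.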