feat: DPO

parent 94c69a4c60
commit 91843816ff

DPO.py (30 changed lines)
@@ -5,7 +5,7 @@ import utils
 import torch
 import wandb
 from tqdm.auto import tqdm
-from trl import DPOTrainer
+from trl import DPOTrainer, DPOConfig
 from datasets import load_dataset
 from unsloth import FastLanguageModel
 from unsloth import is_bfloat16_supported
@@ -13,6 +13,7 @@ from transformers import TrainingArguments, TextStreamer
 
 
 
+
 def DPO_train(args, output_dir):
     wandb.login(key=args.wandb_token)
     wandb.init(project="hw6_rlhf",
@@ -29,21 +30,36 @@ def DPO_train(args, output_dir):
     dataset = load_dataset("Intel/orca_dpo_pairs", split="train")
     dataset = dataset.rename_column('question', 'prompt')
     dataset = dataset.train_test_split(test_size=0.01)
+    print(f"DATASET: {dataset}")
 
     with open("./test_prompt.json", 'r') as f:
         test_data = json.load(f)
     # ================================DO NOT CHANGE!================================
 
     # Model
-    # model, tokenizer = FastLanguageModel.from_pretrained(model_name=args.model_name,...)
-    utils.YOUR_CODE_HERE
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=args.model_name,
+        max_seq_length=args.max_length,
+        dtype=None,
+        load_in_4bit=True
+    )
 
     # Perform model patching and add fast LoRA weights
     # model = FastLanguageModel.get_peft_model(model,...)
-    utils.YOUR_CODE_HERE
+    model = FastLanguageModel.get_peft_model(
+        model,
+        r=16,
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                        "gate_proj", "up_proj", "down_proj"],
+        lora_alpha=16,
+        lora_dropout=0,
+        bias='none',
+        use_gradient_checkpointing="unsloth",
+        random_state=3407
+    )
 
     # Training arguments
-    training_args = TrainingArguments(
+    training_args = DPOConfig(
         per_device_train_batch_size=args.train_batch_size,
         per_device_eval_batch_size=args.eval_batch_size,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -71,8 +87,8 @@ def DPO_train(args, output_dir):
     dpo_trainer = DPOTrainer(
         model=model,
         tokenizer=tokenizer,
-        train_dataset=utils.YOUR_CODE_HERE,
-        eval_dataset=utils.YOUR_CODE_HERE,
+        train_dataset=dataset['train'],
+        eval_dataset=dataset['test'],
         args=training_args,
         beta=args.beta,
         max_length=args.max_length,
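
For context, a minimal sketch of how the updated DPO_train might be driven once this commit lands. Only the attribute names on args are taken from the diff (model_name, max_length, wandb_token, the batch sizes, gradient_accumulation_steps, beta); the checkpoint id, the numeric values, and the output directory are illustrative assumptions, and the repository's real argument parser likely defines further fields that fall outside this diff.

# Hypothetical driver for DPO_train; the args field names come from the diff,
# every concrete value below is an assumption.
from types import SimpleNamespace

from DPO import DPO_train  # DPO.py, the file changed in this commit

args = SimpleNamespace(
    model_name="unsloth/zephyr-sft-bnb-4bit",   # assumed 4-bit Unsloth checkpoint
    max_length=1024,                            # feeds max_seq_length and the DPO max_length
    train_batch_size=2,
    eval_batch_size=2,
    gradient_accumulation_steps=4,
    beta=0.1,                                   # DPO KL-penalty coefficient
    wandb_token="<wandb-api-key>",              # placeholder; DPO_train passes it to wandb.login
)

DPO_train(args, output_dir="./outputs/dpo")

On the switch from TrainingArguments to DPOConfig: DPOConfig subclasses TrainingArguments, so the existing training arguments keep working, and on newer trl releases it also carries the DPO-specific settings such as beta and max_length. The diff still passes beta and max_length to DPOTrainer directly, which matches older trl versions; on newer ones those arguments would move onto the config instead.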