diff --git a/DPO.py b/DPO.py
index 3666840..be2fa96 100644
--- a/DPO.py
+++ b/DPO.py
@@ -5,7 +5,7 @@ import utils
 import torch
 import wandb
 from tqdm.auto import tqdm
-from trl import DPOTrainer
+from trl import DPOTrainer, DPOConfig
 from datasets import load_dataset
 from unsloth import FastLanguageModel
 from unsloth import is_bfloat16_supported
@@ -13,6 +13,7 @@ from transformers import TrainingArguments, TextStreamer
 
 
+
 def DPO_train(args, output_dir):
     wandb.login(key=args.wandb_token)
     wandb.init(project="hw6_rlhf",
@@ -29,21 +30,36 @@ def DPO_train(args, output_dir):
     dataset = load_dataset("Intel/orca_dpo_pairs", split="train")
     dataset = dataset.rename_column('question', 'prompt')
     dataset = dataset.train_test_split(test_size=0.01)
+    print(f"DATASET: {dataset}")
     with open("./test_prompt.json", 'r') as f:
         test_data = json.load(f)
 
     # ================================DO NOT CHANGE!================================
     # Model
-    # model, tokenizer = FastLanguageModel.from_pretrained(model_name=args.model_name,...)
-    utils.YOUR_CODE_HERE
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=args.model_name,
+        max_seq_length=args.max_length,
+        dtype=None,
+        load_in_4bit=True
+    )
 
     # Perform model patching and add fast LoRA weights
-    # model = FastLanguageModel.get_peft_model(model,...)
-    utils.YOUR_CODE_HERE
+    model = FastLanguageModel.get_peft_model(
+        model,
+        r=16,
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                        "gate_proj", "up_proj", "down_proj"],
+        lora_alpha=16,
+        lora_dropout=0,
+        bias='none',
+        use_gradient_checkpointing="unsloth",
+        random_state=3407
+    )
 
     # Training arguments
-    training_args = TrainingArguments(
+    training_args = DPOConfig(
         per_device_train_batch_size=args.train_batch_size,
         per_device_eval_batch_size=args.eval_batch_size,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -71,8 +87,8 @@ def DPO_train(args, output_dir):
     dpo_trainer = DPOTrainer(
         model=model,
         tokenizer=tokenizer,
-        train_dataset=utils.YOUR_CODE_HERE,
-        eval_dataset=utils.YOUR_CODE_HERE,
+        train_dataset=dataset['train'],
+        eval_dataset=dataset['test'],
         args=training_args,
         beta=args.beta,
         max_length=args.max_length,
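
Note: the beta passed to DPOTrainer above scales the margin term of the DPO objective. As a point of reference, here is a minimal sketch of that loss, assuming summed per-response log-probabilities are already computed; the dpo_loss helper below is illustrative only, not TRL's internal implementation:

    import torch
    import torch.nn.functional as F

    def dpo_loss(policy_chosen_logps, policy_rejected_logps,
                 reference_chosen_logps, reference_rejected_logps,
                 beta=0.1):
        # DPO (Rafailov et al., 2023): minimize the negative log-sigmoid of the
        # beta-scaled gap between the policy's and the frozen reference model's
        # chosen-vs-rejected log-probability margins.
        policy_margin = policy_chosen_logps - policy_rejected_logps
        reference_margin = reference_chosen_logps - reference_rejected_logps
        return -F.logsigmoid(beta * (policy_margin - reference_margin)).mean()

A larger beta penalizes drift from the reference model more strongly, which is why it is exposed as a tunable hyperparameter (args.beta) in the trainer setup.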