Compare commits

..

No commits in common. "91843816ff9bbf7bc474c4b0fadb25fa0795b967" and "becfb3e0f38355db1332c8e013e740fcc0c6cce3" have entirely different histories.

3 changed files with 8 additions and 32 deletions

4
.gitignore vendored
View File

@@ -1,4 +0,0 @@
**/__pycache__/
outputs/
submission/
wandb/

30
DPO.py
View File

@@ -5,7 +5,7 @@ import utils
import torch import torch
import wandb import wandb
from tqdm.auto import tqdm from tqdm.auto import tqdm
from trl import DPOTrainer, DPOConfig from trl import DPOTrainer
from datasets import load_dataset from datasets import load_dataset
from unsloth import FastLanguageModel from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported from unsloth import is_bfloat16_supported
@@ -13,7 +13,6 @@ from transformers import TrainingArguments, TextStreamer
def DPO_train(args, output_dir): def DPO_train(args, output_dir):
wandb.login(key=args.wandb_token) wandb.login(key=args.wandb_token)
wandb.init(project="hw6_rlhf", wandb.init(project="hw6_rlhf",
@@ -30,36 +29,21 @@ def DPO_train(args, output_dir):
dataset = load_dataset("Intel/orca_dpo_pairs", split="train") dataset = load_dataset("Intel/orca_dpo_pairs", split="train")
dataset = dataset.rename_column('question', 'prompt') dataset = dataset.rename_column('question', 'prompt')
dataset = dataset.train_test_split(test_size=0.01) dataset = dataset.train_test_split(test_size=0.01)
print(f"DATASET: {dataset}")
with open("./test_prompt.json", 'r') as f: with open("./test_prompt.json", 'r') as f:
test_data = json.load(f) test_data = json.load(f)
# ================================DO NOT CHANGE!================================ # ================================DO NOT CHANGE!================================
# Model # Model
model, tokenizer = FastLanguageModel.from_pretrained( # model, tokenizer = FastLanguageModel.from_pretrained(model_name=args.model_name,...)
model_name=args.model_name, utils.YOUR_CODE_HERE
max_seq_length=args.max_length,
dtype=None,
load_in_4bit=True
)
# Perform model patching and add fast LoRA weights # Perform model patching and add fast LoRA weights
# model = FastLanguageModel.get_peft_model(model,...) # model = FastLanguageModel.get_peft_model(model,...)
model = FastLanguageModel.get_peft_model( utils.YOUR_CODE_HERE
model,
r=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=16,
lora_dropout=0,
bias='none',
use_gradient_checkpointing="unsloth",
random_state=3407
)
# Training arguments # Training arguments
training_args = DPOConfig( training_args = TrainingArguments(
per_device_train_batch_size=args.train_batch_size, per_device_train_batch_size=args.train_batch_size,
per_device_eval_batch_size=args.eval_batch_size, per_device_eval_batch_size=args.eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -87,8 +71,8 @@ def DPO_train(args, output_dir):
dpo_trainer = DPOTrainer( dpo_trainer = DPOTrainer(
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
train_dataset=dataset['train'], train_dataset=utils.YOUR_CODE_HERE,
eval_dataset=dataset['test'], eval_dataset=utils.YOUR_CODE_HERE,
args=training_args, args=training_args,
beta=args.beta, beta=args.beta,
max_length=args.max_length, max_length=args.max_length,

View File

@@ -19,8 +19,7 @@ def parse_args():
choices=["DPO", "ORPO"]) choices=["DPO", "ORPO"])
parser.add_argument("--model_name", type=str, parser.add_argument("--model_name", type=str,
choices=["unsloth/llama-3-8b-bnb-4bit", choices=["unsloth/llama-3-8b-bnb-4bit",
"unsloth/mistral-7b-v0.3-bnb-4bit", "unsloth/mistral-7b-v0.3-bnb-4bit"],
"unsloth/tinyllama-bnb-4bit"],
required=True) required=True)
parser.add_argument("--train", action="store_true") parser.add_argument("--train", action="store_true")
parser.add_argument("--inference_base_model", action="store_true") parser.add_argument("--inference_base_model", action="store_true")
@@ -91,8 +90,5 @@ if __name__ == "__main__":
elif args.model_name == "unsloth/mistral-7b-v0.3-bnb-4bit": elif args.model_name == "unsloth/mistral-7b-v0.3-bnb-4bit":
print("Inference with base model: unsloth/mistral-7b-v0.3-bnb-4bit") print("Inference with base model: unsloth/mistral-7b-v0.3-bnb-4bit")
inference.LLM_inference(args) inference.LLM_inference(args)
elif args.model_name == "unsloth/tinyllama-bnb-4bit":
print("Inference with base model: unsloth/tinyllama-bnb-4bit")
inference.LLM_inference(args)
else: else:
raise ValueError("Invalid model name") raise ValueError("Invalid model name")