Compare commits

3 commits: becfb3e0f3 ... 91843816ff

| Author | SHA1 | Date |
|---|---|---|
| | 91843816ff | |
| | 94c69a4c60 | |
| | 7d78b49815 | |
.gitignore (vendored, new file, +4)

@@ -0,0 +1,4 @@
+**/__pycache__/
+outputs/
+submission/
+wandb/
DPO.py (30 lines changed)
@@ -5,7 +5,7 @@ import utils
 import torch
 import wandb
 from tqdm.auto import tqdm
-from trl import DPOTrainer
+from trl import DPOTrainer, DPOConfig
 from datasets import load_dataset
 from unsloth import FastLanguageModel
 from unsloth import is_bfloat16_supported
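The notable change in this hunk is importing `DPOConfig` alongside `DPOTrainer`. `DPOConfig` subclasses `transformers.TrainingArguments`, and in more recent `trl` releases it also absorbs DPO-specific knobs such as `beta` and `max_length`. A minimal sketch of that newer style, with illustrative placeholder values (this diff itself still passes `beta` and `max_length` to `DPOTrainer` directly, which matches older `trl` releases):

```python
# Minimal sketch, assuming a recent trl release where DPO-specific
# hyperparameters live on DPOConfig. Values are illustrative placeholders.
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="outputs",            # inherited from TrainingArguments
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    beta=0.1,                        # DPO temperature
    max_length=1024,                 # prompt + completion truncation length
)
```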
@@ -13,6 +13,7 @@ from transformers import TrainingArguments, TextStreamer
 
 
+
 
 def DPO_train(args, output_dir):
     wandb.login(key=args.wandb_token)
     wandb.init(project="hw6_rlhf",
@@ -29,21 +30,36 @@ def DPO_train(args, output_dir):
     dataset = load_dataset("Intel/orca_dpo_pairs", split="train")
     dataset = dataset.rename_column('question', 'prompt')
     dataset = dataset.train_test_split(test_size=0.01)
+    print(f"DATASET: {dataset}")
 
     with open("./test_prompt.json", 'r') as f:
         test_data = json.load(f)
     # ================================DO NOT CHANGE!================================
 
     # Model
-    # model, tokenizer = FastLanguageModel.from_pretrained(model_name=args.model_name,...)
-    utils.YOUR_CODE_HERE
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=args.model_name,
+        max_seq_length=args.max_length,
+        dtype=None,
+        load_in_4bit=True
+    )
 
     # Perform model patching and add fast LoRA weights
     # model = FastLanguageModel.get_peft_model(model,...)
-    utils.YOUR_CODE_HERE
+    model = FastLanguageModel.get_peft_model(
+        model,
+        r=16,
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                        "gate_proj", "up_proj", "down_proj"],
+        lora_alpha=16,
+        lora_dropout=0,
+        bias='none',
+        use_gradient_checkpointing="unsloth",
+        random_state=3407
+    )
 
     # Training arguments
-    training_args = TrainingArguments(
+    training_args = DPOConfig(
         per_device_train_batch_size=args.train_batch_size,
         per_device_eval_batch_size=args.eval_batch_size,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
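A note on the dataset lines in the hunk above: `DPOTrainer` consumes preference data with `prompt`, `chosen`, and `rejected` columns, while Intel/orca_dpo_pairs ships its prompts under `question`, hence the `rename_column` call. A self-contained sketch of just that preparation step, assuming the dataset's published column layout (`system`, `question`, `chosen`, `rejected`):

```python
# Sketch of the dataset preparation step in isolation. Assumes
# Intel/orca_dpo_pairs exposes 'system', 'question', 'chosen', 'rejected'
# columns and that DPOTrainer consumes 'prompt'/'chosen'/'rejected'.
from datasets import load_dataset

dataset = load_dataset("Intel/orca_dpo_pairs", split="train")
dataset = dataset.rename_column("question", "prompt")
dataset = dataset.train_test_split(test_size=0.01)  # hold out 1% for eval

print(dataset["train"].column_names)
# expected: ['system', 'prompt', 'chosen', 'rejected']
```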
@@ -71,8 +87,8 @@ def DPO_train(args, output_dir):
     dpo_trainer = DPOTrainer(
         model=model,
         tokenizer=tokenizer,
-        train_dataset=utils.YOUR_CODE_HERE,
-        eval_dataset=utils.YOUR_CODE_HERE,
+        train_dataset=dataset['train'],
+        eval_dataset=dataset['test'],
         args=training_args,
         beta=args.beta,
         max_length=args.max_length,
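For reference, the `beta` passed above is the temperature β of the DPO objective from Rafailov et al. (2023); larger values penalize the policy more strongly for drifting from the reference model. With $y_w$ the chosen and $y_l$ the rejected completion, `DPOTrainer` minimizes

$$
\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) =
-\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}
\left[ \log \sigma\!\left(
\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)}
- \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}
\right) \right]
$$

No explicit `ref_model` is passed to `DPOTrainer` here; with a LoRA-adapted model, `trl` can derive the reference policy from the base weights with the adapters disabled.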
main.py (6 lines changed)
@@ -19,7 +19,8 @@ def parse_args():
                         choices=["DPO", "ORPO"])
     parser.add_argument("--model_name", type=str,
                         choices=["unsloth/llama-3-8b-bnb-4bit",
-                                 "unsloth/mistral-7b-v0.3-bnb-4bit"],
+                                 "unsloth/mistral-7b-v0.3-bnb-4bit",
+                                 "unsloth/tinyllama-bnb-4bit"],
                         required=True)
     parser.add_argument("--train", action="store_true")
     parser.add_argument("--inference_base_model", action="store_true")
@@ -90,5 +91,8 @@ if __name__ == "__main__":
     elif args.model_name == "unsloth/mistral-7b-v0.3-bnb-4bit":
         print("Inference with base model: unsloth/mistral-7b-v0.3-bnb-4bit")
         inference.LLM_inference(args)
+    elif args.model_name == "unsloth/tinyllama-bnb-4bit":
+        print("Inference with base model: unsloth/tinyllama-bnb-4bit")
+        inference.LLM_inference(args)
     else:
         raise ValueError("Invalid model name")
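Since all three base-model branches print a message and then call `inference.LLM_inference(args)`, the `elif` chain could be collapsed into a membership test, keeping future model additions to a one-line change in the `choices` list. A hedged sketch of that alternative (model names are taken from the diff; `BASE_MODELS` is an illustrative helper, not part of this repo, and any preceding training branches are omitted):

```python
# Illustrative refactor, not part of this diff: the three elif branches
# differ only in the model name they print, so one membership test suffices.
BASE_MODELS = {
    "unsloth/llama-3-8b-bnb-4bit",
    "unsloth/mistral-7b-v0.3-bnb-4bit",
    "unsloth/tinyllama-bnb-4bit",
}

if args.model_name in BASE_MODELS:
    print(f"Inference with base model: {args.model_name}")
    inference.LLM_inference(args)
else:
    raise ValueError("Invalid model name")
```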