Compare commits

3 Commits

SHA1        Message                Date
91843816ff  feat: DPO              2024-06-11 21:38:18 +08:00
94c69a4c60  feat: add tinyllama    2024-06-11 17:37:18 +08:00
7d78b49815  docs: add .gitignore   2024-06-11 17:36:12 +08:00
3 changed files with 32 additions and 8 deletions

.gitignore vendored Normal file (4 changes)

@@ -0,0 +1,4 @@
+**/__pycache__/
+outputs/
+submission/
+wandb/

DPO.py (30 changes)

@@ -5,7 +5,7 @@ import utils
 import torch
 import wandb
 from tqdm.auto import tqdm
-from trl import DPOTrainer
+from trl import DPOTrainer, DPOConfig
 from datasets import load_dataset
 from unsloth import FastLanguageModel
 from unsloth import is_bfloat16_supported
@@ -13,6 +13,7 @@ from transformers import TrainingArguments, TextStreamer
 def DPO_train(args, output_dir):
     wandb.login(key=args.wandb_token)
     wandb.init(project="hw6_rlhf",
@@ -29,21 +30,36 @@ def DPO_train(args, output_dir):
     dataset = load_dataset("Intel/orca_dpo_pairs", split="train")
     dataset = dataset.rename_column('question', 'prompt')
     dataset = dataset.train_test_split(test_size=0.01)
     print(f"DATASET: {dataset}")
     with open("./test_prompt.json", 'r') as f:
         test_data = json.load(f)
     # ================================DO NOT CHANGE!================================
     # Model
     # model, tokenizer = FastLanguageModel.from_pretrained(model_name=args.model_name,...)
-    utils.YOUR_CODE_HERE
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=args.model_name,
+        max_seq_length=args.max_length,
+        dtype=None,
+        load_in_4bit=True
+    )
     # Perform model patching and add fast LoRA weights
     # model = FastLanguageModel.get_peft_model(model,...)
-    utils.YOUR_CODE_HERE
+    model = FastLanguageModel.get_peft_model(
+        model,
+        r=16,
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                        "gate_proj", "up_proj", "down_proj"],
+        lora_alpha=16,
+        lora_dropout=0,
+        bias='none',
+        use_gradient_checkpointing="unsloth",
+        random_state=3407
+    )
     # Training arguments
-    training_args = TrainingArguments(
+    training_args = DPOConfig(
         per_device_train_batch_size=args.train_batch_size,
         per_device_eval_batch_size=args.eval_batch_size,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -71,8 +87,8 @@ def DPO_train(args, output_dir):
     dpo_trainer = DPOTrainer(
         model=model,
         tokenizer=tokenizer,
-        train_dataset=utils.YOUR_CODE_HERE,
-        eval_dataset=utils.YOUR_CODE_HERE,
+        train_dataset=dataset['train'],
+        eval_dataset=dataset['test'],
         args=training_args,
         beta=args.beta,
         max_length=args.max_length,
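
In more recent trl releases the DPO-specific settings are carried by DPOConfig itself, so beta and max_length can live on the config instead of being passed to DPOTrainer as keyword arguments. A minimal sketch of that style, assuming a trl version that exposes these fields on DPOConfig and reusing the objects built in the diff above (output_dir is the value passed into DPO_train):

    # Sketch only: same objects as in the diff above, but with the DPO-specific
    # settings (beta, max_length) carried by DPOConfig rather than passed to
    # DPOTrainer directly. Assumes a trl version that exposes these on DPOConfig.
    from trl import DPOConfig, DPOTrainer

    training_args = DPOConfig(
        output_dir=output_dir,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        beta=args.beta,              # DPO preference-strength coefficient
        max_length=args.max_length,  # truncation length for prompt + completion
    )

    dpo_trainer = DPOTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
    )
    dpo_trainer.train()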

[third changed file: argument parsing and __main__ entry point]

@@ -19,7 +19,8 @@ def parse_args():
                         choices=["DPO", "ORPO"])
     parser.add_argument("--model_name", type=str,
                         choices=["unsloth/llama-3-8b-bnb-4bit",
-                                 "unsloth/mistral-7b-v0.3-bnb-4bit"],
+                                 "unsloth/mistral-7b-v0.3-bnb-4bit",
+                                 "unsloth/tinyllama-bnb-4bit"],
                         required=True)
     parser.add_argument("--train", action="store_true")
     parser.add_argument("--inference_base_model", action="store_true")
@@ -90,5 +91,8 @@ if __name__ == "__main__":
     elif args.model_name == "unsloth/mistral-7b-v0.3-bnb-4bit":
         print("Inference with base model: unsloth/mistral-7b-v0.3-bnb-4bit")
         inference.LLM_inference(args)
+    elif args.model_name == "unsloth/tinyllama-bnb-4bit":
+        print("Inference with base model: unsloth/tinyllama-bnb-4bit")
+        inference.LLM_inference(args)
     else:
         raise ValueError("Invalid model name")
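
Since every supported base model now takes the same inference path, the repeated elif branches shown above could be collapsed into a single membership check; a minimal sketch, assuming args comes from parse_args() and inference is imported as in the diff:

    # Sketch only: dispatch equivalent to the elif chain above. Assumes
    # args.model_name was restricted by the argparse choices and that the
    # inference module is already imported.
    SUPPORTED_BASE_MODELS = (
        "unsloth/llama-3-8b-bnb-4bit",
        "unsloth/mistral-7b-v0.3-bnb-4bit",
        "unsloth/tinyllama-bnb-4bit",
    )

    if args.model_name in SUPPORTED_BASE_MODELS:
        print(f"Inference with base model: {args.model_name}")
        inference.LLM_inference(args)
    else:
        raise ValueError("Invalid model name")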