From 94c69a4c606517caf10dd87755cccfef8ed3b262 Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Tue, 11 Jun 2024 17:37:18 +0800 Subject: [PATCH] feat: add tinyllama --- main.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 238aa32..f4b4e42 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,8 @@ def parse_args(): choices=["DPO", "ORPO"]) parser.add_argument("--model_name", type=str, choices=["unsloth/llama-3-8b-bnb-4bit", - "unsloth/mistral-7b-v0.3-bnb-4bit"], + "unsloth/mistral-7b-v0.3-bnb-4bit", + "unsloth/tinyllama-bnb-4bit"], required=True) parser.add_argument("--train", action="store_true") parser.add_argument("--inference_base_model", action="store_true") @@ -90,5 +91,8 @@ if __name__ == "__main__": elif args.model_name == "unsloth/mistral-7b-v0.3-bnb-4bit": print("Inference with base model: unsloth/mistral-7b-v0.3-bnb-4bit") inference.LLM_inference(args) + elif args.model_name == "unsloth/tinyllama-bnb-4bit": + print("Inference with base model: unsloth/tinyllama-bnb-4bit") + inference.LLM_inference(args) else: raise ValueError("Invalid model name")