diff --git a/main.py b/main.py index 238aa32..f4b4e42 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,8 @@ def parse_args(): choices=["DPO", "ORPO"]) parser.add_argument("--model_name", type=str, choices=["unsloth/llama-3-8b-bnb-4bit", - "unsloth/mistral-7b-v0.3-bnb-4bit"], + "unsloth/mistral-7b-v0.3-bnb-4bit", + "unsloth/tinyllama-bnb-4bit"], required=True) parser.add_argument("--train", action="store_true") parser.add_argument("--inference_base_model", action="store_true") @@ -90,5 +91,8 @@ if __name__ == "__main__": elif args.model_name == "unsloth/mistral-7b-v0.3-bnb-4bit": print("Inference with base model: unsloth/mistral-7b-v0.3-bnb-4bit") inference.LLM_inference(args) + elif args.model_name == "unsloth/tinyllama-bnb-4bit": + print("Inference with base model: unsloth/tinyllama-bnb-4bit") + inference.LLM_inference(args) else: raise ValueError("Invalid model name")