feat: add tinyllama

This commit is contained in:
Ting-Jun Wang 2024-06-11 17:37:18 +08:00
parent 7d78b49815
commit 94c69a4c60
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -19,7 +19,8 @@ def parse_args():
choices=["DPO", "ORPO"])
parser.add_argument("--model_name", type=str,
choices=["unsloth/llama-3-8b-bnb-4bit",
"unsloth/mistral-7b-v0.3-bnb-4bit"],
"unsloth/mistral-7b-v0.3-bnb-4bit",
"unsloth/tinyllama-bnb-4bit"],
required=True)
parser.add_argument("--train", action="store_true")
parser.add_argument("--inference_base_model", action="store_true")
@ -90,5 +91,8 @@ if __name__ == "__main__":
elif args.model_name == "unsloth/mistral-7b-v0.3-bnb-4bit":
print("Inference with base model: unsloth/mistral-7b-v0.3-bnb-4bit")
inference.LLM_inference(args)
elif args.model_name == "unsloth/tinyllama-bnb-4bit":
print("Inference with base model: unsloth/tinyllama-bnb-4bit")
inference.LLM_inference(args)
else:
raise ValueError("Invalid model name")