feat: add tinyllama
This commit is contained in:
parent
7d78b49815
commit
94c69a4c60
6
main.py
6
main.py
@ -19,7 +19,8 @@ def parse_args():
|
||||
choices=["DPO", "ORPO"])
|
||||
parser.add_argument("--model_name", type=str,
|
||||
choices=["unsloth/llama-3-8b-bnb-4bit",
|
||||
"unsloth/mistral-7b-v0.3-bnb-4bit"],
|
||||
"unsloth/mistral-7b-v0.3-bnb-4bit",
|
||||
"unsloth/tinyllama-bnb-4bit"],
|
||||
required=True)
|
||||
parser.add_argument("--train", action="store_true")
|
||||
parser.add_argument("--inference_base_model", action="store_true")
|
||||
@ -90,5 +91,8 @@ if __name__ == "__main__":
|
||||
elif args.model_name == "unsloth/mistral-7b-v0.3-bnb-4bit":
|
||||
print("Inference with base model: unsloth/mistral-7b-v0.3-bnb-4bit")
|
||||
inference.LLM_inference(args)
|
||||
elif args.model_name == "unsloth/tinyllama-bnb-4bit":
|
||||
print("Inference with base model: unsloth/tinyllama-bnb-4bit")
|
||||
inference.LLM_inference(args)
|
||||
else:
|
||||
raise ValueError("Invalid model name")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user