diff --git a/convert-to-safetensors.py b/convert-to-safetensors.py
index 9d3e3f56..60770843 100644
--- a/convert-to-safetensors.py
+++ b/convert-to-safetensors.py
@@ -23,6 +23,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
 parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
 parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).')
 parser.add_argument("--max-shard-size", type=str, default="2GB", help="Maximum size of a shard in GB or MB (default: %(default)s).")
+parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
 args = parser.parse_args()

 if __name__ == '__main__':
@@ -30,7 +31,7 @@ if __name__ == '__main__':
     model_name = path.name

     print(f"Loading {model_name}...")
-    model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16)
+    model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if args.bf16 else torch.float16)
     tokenizer = AutoTokenizer.from_pretrained(path)

     out_folder = args.output or Path(f"models/{model_name}_safetensors")
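
Reviewer note: a minimal sketch of how the new flag resolves to a dtype, with an optional runtime guard matching the "Requires NVIDIA Ampere GPU" note in the help text. The resolve_dtype helper is hypothetical and not part of this patch; it assumes torch.cuda.is_bf16_supported(), which is available in recent PyTorch releases.

import torch

def resolve_dtype(bf16: bool) -> torch.dtype:
    # Mirrors the patched ternary: bfloat16 when --bf16 is set, float16 otherwise.
    if bf16 and torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
        # is_bf16_supported() returns False on pre-Ampere GPUs.
        raise SystemExit('--bf16 was requested, but this GPU lacks bfloat16 support.')
    return torch.bfloat16 if bf16 else torch.float16

Loading in bfloat16 rather than float16 avoids overflow for checkpoints whose weights were trained in bf16, at the same 2-bytes-per-parameter memory cost.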