| import argparse
|
| import torch
|
| import os
|
| import re
|
| import onnx
|
| from spandrel import ImageModelDescriptor, ModelLoader
|
| from onnxsim import simplify
|
|
|
def _name_has_scale(pth_path: str, scale: int) -> bool:
    """Return True if the model file name already encodes ``scale``.

    The scale counts as present when ``{scale}x`` or ``x{scale}`` (case-insensitive)
    appears as a token delimited by ``_``/``-`` or the ends of the name. The
    extension is stripped first so that names such as ``model_x4.pth`` match.
    """
    stem = os.path.splitext(os.path.basename(pth_path))[0].upper()
    pattern = rf"(^|[_-])({scale}X|X{scale})([_-]|$)"
    return re.search(pattern, stem) is not None


def _default_onnx_path(pth_path: str, scale: int, channel: int, use_fp16: bool, output_folder: str = None) -> str:
    """Derive the output .onnx path from the input .pth path.

    The .pth extension is dropped and ``-x{scale}`` is appended unless the name
    already encodes the scale. ``-Grayscale`` is added for single-channel input
    and ``-fp16`` for half precision. When ``output_folder`` is given, the file
    is placed there instead of next to the .pth file.
    """
    base_path = os.path.splitext(pth_path)[0]
    if output_folder:
        base_path = os.path.join(output_folder, os.path.basename(base_path))

    if _name_has_scale(pth_path, scale):
        print(f'File name contains scale info: {os.path.basename(pth_path).upper()} ')
    else:
        base_path = f"{base_path}-x{scale}"

    grayscale_tag = "-Grayscale" if channel == 1 else ""
    extension = "-fp16.onnx" if use_fp16 else ".onnx"
    return base_path + grayscale_tag + extension


def _build_dynamic_axes(dynamic_axes: bool) -> dict:
    """Return the ``dynamic_axes`` mapping for ``torch.onnx.export``.

    Height and width (dims 2 and 3) of the input and output are made symbolic.
    The output uses its own dim names on purpose: for an upscaling model the
    output size is a multiple of the input size. Reusing ``height``/``width``
    would declare the two equal in the exported graph.
    """
    if not dynamic_axes:
        return {}
    return {
        "input": {2: "height", 3: "width"},
        "output": {2: "out_height", 3: "out_width"},
    }


def _simplify_in_place(onnx_path: str) -> None:
    """Simplify the ONNX model at ``onnx_path`` with onnxsim, overwriting the file.

    Simplification is best-effort. If onnxsim raises, or reports that the
    simplified model failed its validation check, a warning is printed and the
    exported (unsimplified) model is left untouched.
    """
    try:
        model_simplified, check_ok = simplify(onnx.load(onnx_path))
    except Exception as e:
        print(f"Warning: ONNX model simplification failed, keeping unsimplified model: {e}")
        return
    if not check_ok:
        print("Warning: simplified ONNX model failed validation, keeping unsimplified model")
        return
    onnx.save(model_simplified, onnx_path)
    print(f"ONNX model simplified successfully: {onnx_path}")


def convert_pth_to_onnx(
    pth_path: str,
    onnx_path: str = None,
    channel: int = 0,
    tilesize: int = 64,
    use_fp16: bool = False,
    simplify_model: bool = False,
    min_size: int = 1024 * 1024,
    output_folder: str = None,
    opset: int = 11,
    dynamic_axes: bool = True,
):
    """Load a PyTorch image model from a .pth file with Spandrel and export it to ONNX.

    Args:
        pth_path: Path to the input .pth model file.
        onnx_path: Path for the output .onnx file. When None, a name is derived
            from ``pth_path`` (scale, grayscale and fp16 tags appended). When
            ``output_folder`` is also given, the path is joined onto it.
        channel: Number of input channels of the dummy export input. 0 uses the
            model's own ``input_channels``.
        tilesize: Height and width of the dummy export input. Values < 1 fall
            back to 64.
        use_fp16: Convert model and dummy input to half precision. They are moved
            to CUDA when a CUDA device is available.
        simplify_model: Run onnxsim on the exported model (best-effort).
        min_size: Exported files of at most this many bytes are treated as
            invalid and deleted.
        output_folder: Optional output directory; it is created if missing.
        opset: ONNX opset version passed to ``torch.onnx.export``.
        dynamic_axes: Export input/output height and width as dynamic dimensions.

    Returns:
        The ONNX path on success. On failure, a falsy value: False if the model
        could not be loaded, "" if export failed or the output file was rejected.
    """
    print(f"Loading model from: {pth_path}")
    try:
        model_descriptor = ModelLoader().load_from_file(pth_path)
        if not isinstance(model_descriptor, ImageModelDescriptor):
            print(f"Error: Expected ImageModelDescriptor, but got {type(model_descriptor)}")
            print("Please ensure the .pth file is compatible with Spandrel's loading mechanism.")
            return False
        torch_model = model_descriptor.model
        torch_model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

    if channel == 0:
        channel = model_descriptor.input_channels
    if tilesize < 1:
        tilesize = 64
    example_input = torch.randn(1, channel, tilesize, tilesize)
    print("Model input channels:", channel, "tile size:", tilesize)

    if use_fp16:
        if torch.cuda.is_available():
            torch_model.cuda()
            example_input = example_input.cuda()
        else:
            # No GPU: the half-precision export is still attempted on CPU.
            print("Warning: no CUDA device")
        torch_model.half()
        example_input = example_input.half()
    print(f"Model loaded successfully: {type(torch_model).__name__}")

    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    if onnx_path is None:
        onnx_path = _default_onnx_path(pth_path, model_descriptor.scale, channel, use_fp16, output_folder)
    elif output_folder:
        onnx_path = os.path.join(output_folder, onnx_path)

    print("ONNX model exporting...")
    try:
        torch.onnx.export(
            torch_model,
            example_input,
            onnx_path,
            export_params=True,
            opset_version=opset,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes=_build_dynamic_axes(dynamic_axes),
        )
        print(f"ONNX model export successful: {onnx_path}")

        if simplify_model:
            _simplify_in_place(onnx_path)

        if not os.path.exists(onnx_path):
            return ""
        file_size = os.path.getsize(onnx_path)
        if file_size > min_size:
            return onnx_path

        # A suspiciously small file is treated as a broken export and removed.
        os.remove(onnx_path)
        print(f"ONNX model has unexpected file size ({file_size} bytes), deleted invalid file")
        return ""
    except Exception as e:
        print(f"ONNX model export error: {e}")
        return ""
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, run the conversion and exit with
    # a non-zero status when it fails. argparse is imported at module level.
    parser = argparse.ArgumentParser(description='Convert PyTorch model to ONNX model.')
    parser.add_argument('--pthpath', type=str, required=True, help='Path to the PyTorch model file.')
    parser.add_argument('--onnxpath', type=str, default=None, help='Path to save the ONNX model file.')
    parser.add_argument('--channel', type=int, default=0,
                        help="Number of input channels (0 = use the model's own).")
    parser.add_argument('--tilesize', type=int, default=0,
                        help='Height/width of the dummy export input (values < 1 use 64).')
    parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
    parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
    parser.add_argument('--opset', type=int, default=11, help='ONNX opset version.')
    parser.add_argument('--fixed_axes', action='store_true',
                        help='Export with fixed input/output shapes instead of dynamic height/width axes.')
    parser.add_argument('--output_folder', type=str, default=None,
                        help='Directory for the ONNX model file (created if missing).')
    parser.add_argument('--min_size', type=int, default=1024 * 1024,
                        help='Exported files of at most this many bytes are treated as invalid and deleted.')
    args = parser.parse_args()

    success = convert_pth_to_onnx(
        pth_path=args.pthpath,
        onnx_path=args.onnxpath,
        channel=args.channel,
        tilesize=args.tilesize,
        use_fp16=args.fp16,
        simplify_model=args.simplify,
        min_size=args.min_size,
        output_folder=args.output_folder,
        opset=args.opset,
        dynamic_axes=not args.fixed_axes,
    )

    if success:
        print("Conversion process finished.")
    else:
        print("Conversion process failed.")
        # SystemExit works without the site module, unlike the interactive exit() helper.
        raise SystemExit(1)
|
|
|