diff --git a/onnx_tool/model.py b/onnx_tool/model.py index 6fef606..7a1bec9 100644 --- a/onnx_tool/model.py +++ b/onnx_tool/model.py @@ -1,4 +1,5 @@ import os +import pathlib import onnx @@ -6,10 +7,13 @@ class Model: - def __init__(self, m: [str, onnx.ModelProto], verbose=False, constant_folding: bool = True, + def __init__(self, m: [str, onnx.ModelProto, pathlib.Path], verbose=False, constant_folding: bool = True, noderename: bool = False): self.modelname = '' - if isinstance(m, str): + if isinstance(m, pathlib.Path): + self.modelname = m.name.stem + m = onnx.load_model(m) + elif isinstance(m, str): self.modelname = os.path.basename(m) self.modelname = os.path.splitext(self.modelname)[0] m = onnx.load_model(m)