mirror of
https://github.com/WongKinYiu/yolov7.git
synced 2025-02-17 12:50:14 +08:00
Fixed CoreML export
This commit is contained in:
parent
0563c70705
commit
c51c13a23f
44
export.py
44
export.py
@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
|
||||
sys.path.append('./') # to run '$ python *.py' files in subdirectories
|
||||
|
||||
@ -31,6 +32,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
||||
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
|
||||
parser.add_argument('--include-nms', action='store_true', help='export end2end onnx')
|
||||
parser.add_argument('--fp16', action='store_true', help='CoreML FP16 half-precision export')
|
||||
parser.add_argument('--int8', action='store_true', help='CoreML INT8 quantization')
|
||||
opt = parser.parse_args()
|
||||
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
|
||||
opt.dynamic = opt.dynamic and not opt.end2end
|
||||
@ -66,6 +69,7 @@ if __name__ == '__main__':
|
||||
if opt.include_nms:
|
||||
model.model[-1].include_nms = True
|
||||
y = None
|
||||
|
||||
# TorchScript export
|
||||
try:
|
||||
print('\nStarting TorchScript export with torch %s...' % torch.__version__)
|
||||
@ -76,13 +80,35 @@ if __name__ == '__main__':
|
||||
except Exception as e:
|
||||
print('TorchScript export failure: %s' % e)
|
||||
|
||||
# CoreML export
|
||||
try:
|
||||
import coremltools as ct
|
||||
|
||||
print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
|
||||
# convert model from torchscript and apply pixel scaling as per detect.py
|
||||
ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
|
||||
bits, mode = (8, 'kmeans_lut') if opt.int8 else (16, 'linear') if opt.fp16 else (32, None)
|
||||
if bits < 32:
|
||||
if platform.system() == 'Darwin': # quantization only supported on macOS
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning
|
||||
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
|
||||
else:
|
||||
print(f'{prefix} quantization only supported on macOS, skipping...')
|
||||
|
||||
f = opt.weights.replace('.pt', '.mlmodel') # filename
|
||||
ct_model.save(f)
|
||||
print('CoreML export success, saved as %s' % f)
|
||||
except Exception as e:
|
||||
print('CoreML export failure: %s' % e)
|
||||
|
||||
# TorchScript-Lite export
|
||||
try:
|
||||
print('\nStarting TorchScript-Lite export with torch %s...' % torch.__version__)
|
||||
f = opt.weights.replace('.pt', '.torchscript.ptl') # filename
|
||||
ts = torch.jit.trace(model, img, strict=False)
|
||||
ts = optimize_for_mobile(ts)
|
||||
ts._save_for_lite_interpreter(f)
|
||||
tsl = torch.jit.trace(model, img, strict=False)
|
||||
tsl = optimize_for_mobile(tsl)
|
||||
tsl._save_for_lite_interpreter(f)
|
||||
print('TorchScript-Lite export success, saved as %s' % f)
|
||||
except Exception as e:
|
||||
print('TorchScript-Lite export failure: %s' % e)
|
||||
@ -171,18 +197,6 @@ if __name__ == '__main__':
|
||||
|
||||
except Exception as e:
|
||||
print('ONNX export failure: %s' % e)
|
||||
# CoreML export
|
||||
try:
|
||||
import coremltools as ct
|
||||
|
||||
print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
|
||||
# convert model from torchscript and apply pixel scaling as per detect.py
|
||||
model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
|
||||
f = opt.weights.replace('.pt', '.mlmodel') # filename
|
||||
model.save(f)
|
||||
print('CoreML export success, saved as %s' % f)
|
||||
except Exception as e:
|
||||
print('CoreML export failure: %s' % e)
|
||||
|
||||
# Finish
|
||||
print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t))
|
||||
|
Loading…
Reference in New Issue
Block a user