Skip to content

Commit

Permalink
Initial implementation of make_pipeline utility (apple#1803)
Browse files Browse the repository at this point in the history
* initial implementation of make pipeline utility

* Fix unit test

* Set weights_dir parameter

* minor cleanups

* Fix typo

* When mapping input shapes to output shapes, don't override shapes
  • Loading branch information
TobyRoseman authored Mar 21, 2023
1 parent 32b1ee0 commit 230f1e5
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 17 deletions.
31 changes: 16 additions & 15 deletions coremltools/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
from ..proto import Model_pb2 as _Model_pb2
from .utils import (_MLMODEL_EXTENSION, _MLPACKAGE_AUTHOR_NAME,
_MLPACKAGE_EXTENSION, _WEIGHTS_DIR_NAME, _create_mlpackage,
_has_custom_layer, _is_macos, _macos_version)
from .utils import load_spec as _load_spec
from .utils import save_spec as _save_spec
_has_custom_layer, _is_macos, _macos_version,
load_spec as _load_spec, save_spec as _save_spec,
)

if _HAS_TORCH:
import torch
import torch as _torch

if _HAS_TF_1 or _HAS_TF_2:
import tensorflow as tf
import tensorflow as _tf


try:
Expand Down Expand Up @@ -258,7 +258,7 @@ def __init__(
i.e. a spec object.
is_temp_package: bool
Set to true if the input model package dir is temporary and can be deleted upon interpreter termination.
Set to True if the input model package dir is temporary and can be deleted upon interpreter termination.
mil_program: coremltools.converters.mil.Program
Set to the MIL program object, if available.
Expand Down Expand Up @@ -326,9 +326,8 @@ def cleanup(package_path):
self.is_temp_package = False
self.package_path = None
self._weights_dir = None
if mil_program is not None:
if not isinstance(mil_program, _Program):
raise ValueError("mil_program must be of type 'coremltools.converters.mil.Program'")
if mil_program is not None and not isinstance(mil_program, _Program):
raise ValueError('"mil_program" must be of type "coremltools.converters.mil.Program"')
self._mil_program = mil_program

if isinstance(model, str):
Expand All @@ -342,8 +341,9 @@ def cleanup(package_path):
model, compute_units, skip_model_load=skip_model_load,
)
elif isinstance(model, _Model_pb2.Model):
if model.WhichOneof('Type') == "mlProgram":
if weights_dir is None:
model_type = model.WhichOneof('Type')
if model_type in ("mlProgram", 'pipelineClassifier', 'pipelineRegressor', 'pipeline'):
if model_type == "mlProgram" and weights_dir is None:
raise Exception('MLModel of type mlProgram cannot be loaded just from the model spec object. '
'It also needs the path to the weights file. Please provide that as well, '
'using the \'weights_dir\' argument.')
Expand Down Expand Up @@ -443,6 +443,7 @@ def save(self, save_path: str):
loaded_model = MLModel('my_model_file.mlmodel')
"""
save_path = _os.path.expanduser(save_path)

# Clean up existing file or directory.
if _os.path.exists(save_path):
if _os.path.isdir(save_path):
Expand Down Expand Up @@ -489,7 +490,7 @@ def predict(self, data):
Returns
-------
out: dict[str, value]
dict[str, value]
Predictions as a dictionary where each key is the output feature
name.
Expand Down Expand Up @@ -648,10 +649,10 @@ def _convert_tensor_to_numpy(self, input_dict):
def convert(given_input):
if isinstance(given_input, _numpy.ndarray):
sanitized_input = given_input
elif _HAS_TORCH and isinstance(given_input, torch.Tensor):
elif _HAS_TORCH and isinstance(given_input, _torch.Tensor):
sanitized_input = given_input.detach().numpy()
elif (_HAS_TF_1 or _HAS_TF_2) and isinstance(given_input, tf.Tensor):
sanitized_input = given_input.eval(session=tf.compat.v1.Session())
elif (_HAS_TF_1 or _HAS_TF_2) and isinstance(given_input, _tf.Tensor):
sanitized_input = given_input.eval(session=_tf.compat.v1.Session())
else:
sanitized_input = _numpy.array(given_input)
return sanitized_input
Expand Down
101 changes: 99 additions & 2 deletions coremltools/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,26 @@
"""
Utilities for the entire package.
"""

from collections.abc import Iterable as _Iterable
from functools import lru_cache as _lru_cache
import math as _math
import os as _os
import shutil as _shutil
import subprocess as _subprocess
import sys as _sys
import tempfile as _tempfile
import warnings as _warnings
from functools import lru_cache as _lru_cache
from typing import Optional as _Optional
import warnings as _warnings

import numpy as _np

import coremltools as _ct
from coremltools import ComputeUnit as _ComputeUnit
from coremltools.converters.mil.mil.passes.name_sanitization_utils import \
NameSanitizer as _NameSanitizer
from coremltools.proto import Model_pb2 as _Model_pb2
import coremltools.proto.MIL_pb2 as _mil_proto

from .._deps import _HAS_SCIPY

Expand Down Expand Up @@ -1000,3 +1004,96 @@ def _convert_to_float(feature):
if spec.WhichOneof("Type") == "pipeline":
for model_spec in spec.pipeline.models:
convert_double_to_float_multiarray_type(model_spec)


def make_pipeline(*models):
    """
    Makes a pipeline with the given models.

    The pipeline's inputs are taken from the first model's spec and its
    outputs from the last model's spec. Every ``mlProgram`` spec has its
    blob file references rewritten to a per-model file name
    (``@model_path/weights/<index>-weight.bin``) and the matching weight
    file is copied into the new package under that name, so that weights of
    different models do not collide.

    Parameters
    ----------
    *models :
        Two or more instances of ``ct.models.MLModel``, in the order in
        which they should run.

    Returns
    -------
    ct.models.MLModel
        A model wrapping the pipeline spec, backed by a newly created
        mlpackage.

    Raises
    ------
    ValueError
        If fewer than two models are given.

    Examples
    --------
    my_model1 = ct.models.MLModel('/tmp/m1.mlpackage')
    my_model2 = ct.models.MLModel('/tmp/m2.mlmodel')
    my_pipeline_model = ct.utils.make_pipeline(my_model1, my_model2)
    """
    # Validate with a real exception: an ``assert`` would be stripped when
    # Python runs with ``-O`` and a one-model "pipeline" would slip through.
    if len(models) < 2:
        raise ValueError(
            f"make_pipeline requires at least two models, got {len(models)}."
        )

    def _update_blob_file_name(proto_message, new_path):
        # Recursively walk a protobuf structure and point every blob file
        # reference at ``new_path``.
        if isinstance(proto_message, _mil_proto.Value):
            # Value protobuf message. This is what might need to be updated.
            if proto_message.WhichOneof('value') == 'blobFileValue':
                # Invariant check: only the default weight file name is
                # expected here before renaming.
                assert proto_message.blobFileValue.fileName == "@model_path/weights/weight.bin"
                proto_message.blobFileValue.fileName = new_path
        elif hasattr(proto_message, 'ListFields'):
            # Normal protobuf message
            for _, field_value in proto_message.ListFields():
                _update_blob_file_name(field_value, new_path)
        elif hasattr(proto_message, 'values'):
            # Protobuf map
            for map_value in proto_message.values():
                _update_blob_file_name(map_value, new_path)
        elif isinstance(proto_message, _Iterable) and not isinstance(proto_message, str):
            # Repeated protobuf message
            for element in proto_message:
                _update_blob_file_name(element, new_path)

    input_specs = [m.get_spec() for m in models]

    pipeline_spec = _ct.proto.Model_pb2.Model()
    # The pipeline must be at least as new as its newest member.
    pipeline_spec.specificationVersion = max(
        spec.specificationVersion for spec in input_specs
    )

    # Set pipeline input
    pipeline_spec.description.input.MergeFrom(
        input_specs[0].description.input
    )

    # Set pipeline output
    pipeline_spec.description.output.MergeFrom(
        input_specs[-1].description.output
    )

    # Map input shapes to output shapes: an output without a shape adopts
    # the type of the same-named input of the following model.
    var_name_to_type = {}
    for cur_spec, next_spec in zip(input_specs, input_specs[1:]):
        for next_input in next_spec.description.input:
            var_name_to_type[next_input.name] = next_input.type

        for cur_output in cur_spec.description.output:
            # If shape is already present, don't override it
            if (
                cur_output.type.WhichOneof('Type') == 'multiArrayType'
                and len(cur_output.type.multiArrayType.shape) != 0
            ):
                continue

            if cur_output.name in var_name_to_type:
                cur_output.type.CopyFrom(var_name_to_type[cur_output.name])

    # Update each model's spec to have a unique weight filename
    for i, cur_spec in enumerate(input_specs):
        if cur_spec.WhichOneof("Type") == "mlProgram":
            new_file_path = f"@model_path/weights/{i}-weight.bin"
            _update_blob_file_name(cur_spec.mlProgram, new_file_path)
        pipeline_spec.pipeline.models.append(cur_spec)

    mlpackage_path = _create_mlpackage(pipeline_spec)
    dst = _os.path.join(
        mlpackage_path, "Data", _MLPACKAGE_AUTHOR_NAME, _WEIGHTS_DIR_NAME
    )
    # Tolerate a weights directory that already exists instead of crashing
    # with FileExistsError.
    _os.makedirs(dst, exist_ok=True)

    # Copy and rename each model's weight file
    for i, cur_model in enumerate(models):
        if cur_model.weights_dir is None:
            continue
        weight_file_path = _os.path.join(cur_model.weights_dir, _WEIGHTS_FILE_NAME)
        if _os.path.exists(weight_file_path):
            _shutil.copyfile(weight_file_path, _os.path.join(dst, f"{i}-weight.bin"))

    return _ct.models.MLModel(pipeline_spec, weights_dir=dst)
1 change: 1 addition & 0 deletions coremltools/test/api/test_api_visibilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_utils(self):
"evaluate_classifier_with_probabilities",
"evaluate_regressor",
"evaluate_transformer",
"make_pipeline",
"load_spec",
"rename_feature",
"save_spec",
Expand Down
65 changes: 65 additions & 0 deletions coremltools/test/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,17 @@
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import itertools
import tempfile
import unittest

import numpy as np
import pytest

import coremltools as ct
from coremltools._deps import _HAS_LIBSVM, _HAS_SKLEARN
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import Function, Program
from coremltools.models.pipeline import PipelineClassifier, PipelineRegressor

if _HAS_SKLEARN:
Expand Down Expand Up @@ -210,3 +218,60 @@ def test_conversion_bad_inputs(self):
with self.assertRaises(TypeError):
model = OneHotEncoder()
spec = converter.convert(model, "data", "out", "regressor")


class TestMakePipeline:
    """End-to-end checks for ``ct.utils.make_pipeline``."""

    @staticmethod
    def _make_model(input_name, input_length,
                    output_name, output_length,
                    convert_to):
        """Build a MIL program with one ``linear`` op and convert it.

        The weights are the integers 0, 1, 2, ... as float32, reshaped to
        ``(output_length, input_length)``, so every call with the same sizes
        yields the same model.
        """
        num_weights = input_length * output_length
        flat_weights = np.arange(num_weights, dtype='float32')
        weights = flat_weights.reshape(output_length, input_length)

        program = Program()
        placeholders = {input_name: mb.placeholder(shape=(input_length,))}
        with Function(placeholders) as main_func:
            linear_out = mb.linear(
                x=main_func.inputs[input_name],
                weight=weights,
                name=output_name,
            )
            main_func.set_outputs([linear_out])
            program.add_function("main", main_func)

        return ct.convert(program, convert_to=convert_to)

    @staticmethod
    @pytest.mark.parametrize(
        "model1_backend, model2_backend",
        itertools.product(["mlprogram", "neuralnetwork"], repeat=2),
    )
    def test_simple(model1_backend, model2_backend):
        """A two-stage pipeline must match running both models by hand."""
        # Create models
        first = TestMakePipeline._make_model("x", 20, "y1", 10, model1_backend)
        second = TestMakePipeline._make_model("y1", 10, "y2", 2, model2_backend)

        # Get non-pipeline result
        x = np.random.rand(20)
        intermediate = first.predict({"x": x})["y1"]
        expected = second.predict({"y1": intermediate})

        def check_matches_reference(model):
            # Run the candidate on the same input and compare final outputs.
            actual = model.predict({"x": x})
            np.testing.assert_allclose(expected["y2"], actual["y2"])

        pipeline_model = ct.utils.make_pipeline(first, second)
        check_matches_reference(pipeline_model)

        # Check save/load
        with tempfile.TemporaryDirectory() as save_dir:
            # Save pipeline
            save_path = save_dir + "/test.mlpackage"
            pipeline_model.save(save_path)

            # Check loading from a mlpackage path
            reloaded = ct.models.MLModel(save_path)
            check_matches_reference(reloaded)

            # Check loading from spec and weight dir
            from_spec = ct.models.MLModel(
                reloaded.get_spec(), weights_dir=reloaded.weights_dir
            )
            check_matches_reference(from_spec)

0 comments on commit 230f1e5

Please sign in to comment.