Skip to content

Commit

Permalink
Add resnet inference jupyter notebook.
Browse files Browse the repository at this point in the history
This takes the example from torchscript_resnet18_e2e.py and puts it into
a slightly cleaned up notebook form.

It's still a little rough around the edges. Areas for improvement:
- Installation / setup.
- API usability.

Also,
- Add `npcomp-backend-to-iree-frontend-pipeline` since we will be adding
  more stuff there.
- Slight cleanups.
  • Loading branch information
silvasean committed Aug 9, 2021
1 parent f71845e commit 902c2e5
Show file tree
Hide file tree
Showing 8 changed files with 565 additions and 43 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
.vscode
.env
*.code-workspace
.ipynb_checkpoints

/build/
__pycache__
Expand Down
2 changes: 1 addition & 1 deletion frontends/pytorch/e2e_testing/torchscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _get_argparse():
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
''')
parser.add_argument('--filter', default='.*', help='''
parser.add_argument('-f', '--filter', default='.*', help='''
Regular expression specifying which tests to include in this run.
''')
parser.add_argument('-v', '--verbose',
Expand Down
500 changes: 500 additions & 0 deletions frontends/pytorch/examples/resnet_inference.ipynb

Large diffs are not rendered by default.

45 changes: 27 additions & 18 deletions frontends/pytorch/examples/torchscript_resnet18_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,24 @@
from npcomp.compiler.pytorch.backend import refjit, frontend_lowering, iree
from npcomp.compiler.utils import logging

logging.enable()

mb = torch_mlir.ModuleBuilder()


def load_and_preprocess_image(url: str):
img = Image.open(requests.get(url, stream=True).raw).convert("RGB")
headers = {
'User-Agent':
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'
}
img = Image.open(requests.get(url, headers=headers,
stream=True).raw).convert("RGB")
# preprocessing pipeline
preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
img_preprocessed = preprocess(img)
return torch.unsqueeze(img_preprocessed, 0)

Expand Down Expand Up @@ -78,6 +81,15 @@ def forward(self, x):
return self.s.forward(x)


image_url = (
"https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
)
import sys

print("load image from " + image_url, file=sys.stderr)
img = load_and_preprocess_image(image_url)
labels = load_labels()

test_module = TestModule()
class_annotator = torch_mlir.ClassAnnotator()
recursivescriptmodule = torch.jit.script(test_module)
Expand All @@ -88,7 +100,10 @@ def forward(self, x):
class_annotator.annotateArgs(
recursivescriptmodule._c._type(),
["forward"],
[None, ([-1, -1, -1, -1], torch.float32, True),],
[
None,
([-1, -1, -1, -1], torch.float32, True),
],
)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, class_annotator)
Expand All @@ -97,10 +112,4 @@ def forward(self, x):
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled)

image_url = (
"https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
)
print("load image from " + image_url)
img = load_and_preprocess_image(image_url)
labels = load_labels()
predictions(test_module.forward, jit_module.forward, img, labels)
4 changes: 4 additions & 0 deletions include/npcomp/Backend/IREE/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ namespace IREEBackend {
/// Registers all IREEBackend passes.
void registerIREEBackendPasses();

/// Create a pipeline that runs all passes needed to lower the npcomp backend
/// contract to IREE's frontend contract.
void createNpcompBackendToIreeFrontendPipeline(OpPassManager &pm);

std::unique_ptr<OperationPass<ModuleOp>> createLowerLinkagePass();

} // namespace IREEBackend
Expand Down
20 changes: 15 additions & 5 deletions lib/Backend/IREE/LowerLinkage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,20 @@ using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::IREEBackend;

namespace {
#define GEN_PASS_REGISTRATION
#include "npcomp/Backend/IREE/Passes.h.inc"
} // end namespace
// This pass lowers the public ABI of the module to the primitives exposed by
// the refbackrt dialect.
class LowerLinkagePass : public LowerLinkageBase<LowerLinkagePass> {
void runOnOperation() override {
ModuleOp module = getOperation();
for (auto func : module.getOps<FuncOp>()) {
if (func.getVisibility() == SymbolTable::Visibility::Public)
func->setAttr("iree.module.export", UnitAttr::get(&getContext()));
}
}
};
} // namespace

void mlir::NPCOMP::IREEBackend::registerIREEBackendPasses() {
::registerPasses();
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::IREEBackend::createLowerLinkagePass() {
return std::make_unique<LowerLinkagePass>();
}
31 changes: 16 additions & 15 deletions lib/Backend/IREE/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@ using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::IREEBackend;

namespace {
// This pass lowers the public ABI of the module to the primitives exposed by
// the refbackrt dialect.
class LowerLinkagePass : public LowerLinkageBase<LowerLinkagePass> {
void runOnOperation() override {
ModuleOp module = getOperation();
for (auto func : module.getOps<FuncOp>()) {
if (func.getVisibility() == SymbolTable::Visibility::Public)
func->setAttr("iree.module.export", UnitAttr::get(&getContext()));
}
}
};
} // namespace
#define GEN_PASS_REGISTRATION
#include "npcomp/Backend/IREE/Passes.h.inc"
} // end namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::IREEBackend::createLowerLinkagePass() {
return std::make_unique<LowerLinkagePass>();
void mlir::NPCOMP::IREEBackend::createNpcompBackendToIreeFrontendPipeline(
OpPassManager &pm) {
pm.addPass(createLowerLinkagePass());
}

void mlir::NPCOMP::IREEBackend::registerIREEBackendPasses() {
::registerPasses();

mlir::PassPipelineRegistration<>(
"npcomp-backend-to-iree-frontend-pipeline",
"Pipeline lowering the npcomp backend contract IR to IREE's frontend "
"contract.",
mlir::NPCOMP::IREEBackend::createNpcompBackendToIreeFrontendPipeline);
}
5 changes: 1 addition & 4 deletions python/npcomp/compiler/pytorch/backend/iree.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
"IreeNpcompBackend",
]

PREPARE_FOR_IREE_PASSES = (
"npcomp-iree-backend-lower-linkage",
)

class IreeModuleInvoker:
"""Wrapper around a native IREE module for calling functions."""
Expand Down Expand Up @@ -88,7 +85,7 @@ def compile(self, imported_module: Module):
if self._debug:
logging.debug("IR passed to IREE compiler backend:\n{}",
imported_module)
pipeline_str = ",".join(PREPARE_FOR_IREE_PASSES)
pipeline_str = "npcomp-backend-to-iree-frontend-pipeline"
if self._debug:
logging.debug("Running Prepare For IREE pipeline '{}'", pipeline_str)
pm = PassManager.parse(pipeline_str)
Expand Down

0 comments on commit 902c2e5

Please sign in to comment.