Improve save_for_mobile cxx binary (pytorch#43721)
Summary:
Pull Request resolved: pytorch#43721

We can combine the optimization pass and save_for_mobile into a single step to reduce friction. Since a lite interpreter model can also be used by the full JIT, I don't think we need the option to save it as a full JIT model.
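
As a rough illustration of "can also be used by the full JIT" (a sketch, not part of this diff; the path is a placeholder): a model written out with `_save_for_mobile` carries bytecode for the lite interpreter alongside the TorchScript code, so both loaders accept the same file.

```
#include <string>
#include <torch/csrc/jit/mobile/import.h>        // torch::jit::_load_for_mobile
#include <torch/csrc/jit/serialization/import.h> // torch::jit::load

int main() {
  const std::string path = "sparkspot_mobile_optimized.bc"; // placeholder path
  // Lite interpreter: loads the bytecode section of the archive.
  auto mobile_module = torch::jit::_load_for_mobile(path);
  // Full JIT: the same archive still contains TorchScript code, so this works too.
  auto jit_module = torch::jit::load(path);
  return 0;
}
```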

Also:
- improved the usage message
- print the op list before and after the optimization pass

Test Plan:
```
buck run //xplat/caffe2:optimize_for_mobile -- --model=/home/linbin/sparkspot.pt

Building: finished in 12.4 sec (100%) 2597/2597 jobs, 2 updated
  Total time: 12.5 sec

pt_operator_library(
        name = "old_op_library",
        ops = [
                "aten::_convolution",
                "aten::adaptive_avg_pool2d",
                "aten::add_.Tensor",
                "aten::batch_norm",
                "aten::mul.Tensor",
                "aten::relu_",
                "aten::softplus",
                "aten::sub.Tensor",
        ],
)

pt_operator_library(
        name = "new_op_library",
        ops = [
                "aten::adaptive_avg_pool2d",
                "aten::add_.Tensor",
                "aten::batch_norm",
                "aten::mul.Tensor",
                "aten::relu_",
                "aten::softplus",
                "aten::sub.Tensor",
                "prepacked::conv2d_clamp_run",
        ],
)

The optimized model for lite interpreter was saved to /home/linbin/sparkspot_mobile_optimized.bc
```
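
The `pt_operator_library` blocks above are printed by the binary itself via `torch::jit::export_opnames`, in a form that can be dropped into a BUCK file for selective op registration. A minimal sketch of how such a list is produced (assuming the headers used in the diff below; the model path is the one from this test plan):

```
#include <iostream>
#include <torch/csrc/jit/serialization/export.h> // torch::jit::export_opnames
#include <torch/csrc/jit/serialization/import.h> // torch::jit::load

int main() {
  auto module = torch::jit::load("/home/linbin/sparkspot.pt");
  // export_opnames returns the root operator names the module calls.
  for (const auto& op : torch::jit::export_opnames(module)) {
    std::cout << op << std::endl; // e.g. "aten::batch_norm"
  }
  return 0;
}
```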

```
buck run //xplat/caffe2:optimize_for_mobile -- --model=/home/linbin/sparkspot.pt --backend=vulkan
```
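
(With `--backend=vulkan` the same flow runs, but the binary dispatches to `torch::jit::vulkanOptimizeForMobile` instead of the default `torch::jit::optimizeForMobile` pass; see the `FLAGS_backend` handling in the diff below.)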

Reviewed By: kimishpatel

Differential Revision: D23363533

fbshipit-source-id: f7fd61aaeda5944de5bf198e7f93cacf8368babd
linbinyu authored and facebook-github-bot committed Aug 27, 2020
1 parent 3830998 commit bff741a
Showing 1 changed file with 36 additions and 19 deletions.
55 changes: 36 additions & 19 deletions binaries/optimize_for_mobile.cc
@@ -16,54 +16,71 @@

 #include <string>

-#include "torch/script.h"
+#include "torch/csrc/jit/api/module.h"
+#include "torch/csrc/jit/passes/vulkan_rewrite.h"
+#include "torch/csrc/jit/passes/xnnpack_rewrite.h"
+#include "torch/csrc/jit/serialization/import.h"
+#include "torch/csrc/jit/serialization/export.h"

-C10_DEFINE_string(model, "", "The given torch script model to transform.");
+C10_DEFINE_string(model, "", "The torch script model to optimize.");
 C10_DEFINE_string(
     output,
     "",
     "Name of the output model to be saved.");
-C10_DEFINE_bool(
-    save_for_mobile,
-    false,
-    "Save the model with bytecode format compatible with lite inteprter.");
-C10_DEFINE_bool(vulkan, false, "Vulkan optimize_for_mobile");
+C10_DEFINE_string(backend, "", "The backend to be optimized");

 int main(int argc, char** argv) {
   c10::SetUsageMessage(
-      "Run speed benchmark for pytorch model.\n"
-      "Example usage:\n"
+      "\nRun optimization pass for pytorch model. Example usage:\n"
       "./optimize_for_mobile"
       " --model=<model_file>"
-      " --output=<output_file_name>");
+      " [--output=<output_file_name>]"
+      " [--backend=<cpu|vulkan>]"
+  );

   if (!c10::ParseCommandLineFlags(&argc, &argv)) {
     std::cerr << "Failed to parse command line flags!" << std::endl;
     std::cout << c10::UsageMessage() << std::endl;
     return 1;
   }

-  CAFFE_ENFORCE(FLAGS_model != "", "Valid input must be provided.");
+  CAFFE_ENFORCE(FLAGS_model != "", c10::UsageMessage());

   std::string output_model_name =
-      FLAGS_model.substr(0, FLAGS_model.find(".")) + "_mobile_optimized.pt";
+      FLAGS_model.substr(0, FLAGS_model.find(".")) + "_optimized.bc";

   if (FLAGS_output != "") {
     output_model_name = FLAGS_output;
   }

   auto module = torch::jit::load(FLAGS_model);
+  auto ops = torch::jit::export_opnames(module);
+  std::cout << "\npt_operator_library(" << std::endl;
+  std::cout << "\tname = \"old_op_library\"," << std::endl;
+  std::cout << "\tops = [" << std::endl;
+  for (auto const& op: ops) {
+    std::cout << "\t\t\"" << op << "\"," << std::endl;
+  }
+  std::cout << "\t],\n)\n" << std::endl;

-  auto optimized_module = FLAGS_vulkan
-      ? torch::jit::vulkanOptimizeForMobile(module)
-      : torch::jit::optimizeForMobile(module);
-
-  if (FLAGS_save_for_mobile) {
-    optimized_module._save_for_mobile(output_model_name);
+  torch::jit::Module optimized_module;
+  if (FLAGS_backend == "" || FLAGS_backend == "cpu") {
+    optimized_module = torch::jit::optimizeForMobile(module);
+  } else if (FLAGS_backend == "vulkan") {
+    optimized_module = torch::jit::vulkanOptimizeForMobile(module);
   } else {
-    optimized_module.save(output_model_name);
+    CAFFE_ENFORCE(false, "Unknown backend: " + FLAGS_backend);
   }

+  auto new_ops = torch::jit::export_opnames(optimized_module);
+  std::cout << "\npt_operator_library(" << std::endl;
+  std::cout << "\tname = \"new_op_library\"," << std::endl;
+  std::cout << "\tops = [" << std::endl;
+  for (auto const& op: new_ops) {
+    std::cout << "\t\t\"" << op << "\"," << std::endl;
+  }
+  std::cout << "\t],\n)\n" << std::endl;
+  optimized_module._save_for_mobile(output_model_name);
   std::cout << "The optimized model for lite interpreter was saved to " << output_model_name << std::endl;
   return 0;
 }
