Python function to extract information on mobile::Module from flatbuffer (pytorch#77328)

Includes the following refactors:
1. Common operator-loading/validation logic that was duplicated between the
   pickle and flatbuffer loaders is moved to function.h/cpp.
2. Allow loading a function without wiring up its operators.

This function will be used to implement get_bundled_input and friends
for the flatbuffer format.
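
As an illustration, a minimal usage sketch of the new API, mirroring the test added below. The module class `MyModule` is a placeholder; the entry points `torch.jit.save_jit_module_to_flatbuffer` and `torch.jit._serialization.get_flatbuffer_module_info` are taken from the test, not from documentation.

import io
import torch

class MyModule(torch.nn.Module):  # placeholder module for illustration
    def forward(self, x):
        return x + 1

scripted = torch.jit.script(MyModule())
buf = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(scripted, buf)
buf.seek(0)

# Returns a dict with 'bytecode_version', 'operator_version',
# 'function_names', 'type_names', and 'opname_to_num_args'.
info = torch.jit._serialization.get_flatbuffer_module_info(buf)
print(info['function_names'])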

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#77328
Approved by: https://github.com/cccclai
qihqi authored and pytorchmergebot committed May 16, 2022
1 parent b5bc954 commit 69fa49f
Showing 13 changed files with 215 additions and 75 deletions.
28 changes: 28 additions & 0 deletions test/jit/test_save_load.py
@@ -918,6 +918,34 @@ def forward(self) -> Optional[FooTuple]:
output = m_loaded()
self.assertEqual(output, None)

def test_module_info_flatbuffer(self):
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
self.foo = torch.nn.Linear(2, 2)
self.bar = torch.nn.Linear(2, 2)

def forward(self, x):
x = self.foo(x)
x = self.bar(x)
return x

first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(
first_script_module, first_saved_module)
first_saved_module.seek(0)
expected = {
'bytecode_version': 4,
'operator_version': 4,
'function_names': {'__torch__.___torch_mangle_0.Foo.forward'},
'type_names': set(),
'opname_to_num_args': {'aten::linear': 3}}
self.assertEqual(
torch.jit._serialization.get_flatbuffer_module_info(first_saved_module),
expected)


def test_save_load_params_buffers_submodules(self):
"""
Check that parameters, buffers, and submodules are the same after loading.
13 changes: 13 additions & 0 deletions torch/csrc/init_flatbuffer_module.cpp
@@ -99,5 +99,18 @@ extern "C"
reinterpret_cast<char*>(detached_buffer.data()),
detached_buffer.size());
});
pym.def("_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
py::gil_scoped_acquire acquire;
py::dict result;
mobile::ModuleInfo minfo = torch::jit::get_module_info_from_flatbuffer(
&flatbuffer_content[0]);
result["bytecode_version"] = minfo.bytecode_version;
result["operator_version"] = minfo.operator_version;
result["function_names"] = minfo.function_names;
result["type_names"] = minfo.type_names;
result["opname_to_num_args"] = minfo.opname_to_num_args;
return result;
});

return module;
}
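
For orientation, a hedged sketch of how the Python-side wrapper might call the `_get_module_info_from_flatbuffer` binding above. The wrapper's actual diff (torch/jit/_serialization.py) is not among the loaded hunks, so the extension-module name `torch._C_flatbuffer` and the handling of path vs. file-like inputs are assumptions.

# Hypothetical sketch of the Python wrapper around the pybind binding above;
# the actual torch/jit/_serialization.py change is not shown on this page.
def get_flatbuffer_module_info(path_or_file):
    # Assumption: the binding is exposed on the torch._C_flatbuffer extension module.
    from torch._C_flatbuffer import _get_module_info_from_flatbuffer
    if isinstance(path_or_file, str):
        with open(path_or_file, "rb") as f:
            data = f.read()
    else:
        data = path_or_file.read()  # file-like object, e.g. io.BytesIO
    return _get_module_info_from_flatbuffer(data)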
2 changes: 2 additions & 0 deletions torch/csrc/jit/mobile/code.h
@@ -30,6 +30,8 @@ struct Code {
// be done in parseMethods().
std::vector<mobile::Function*> functions_;
size_t register_size_ = 0; // Aggregated output size.
// initialized means operators_ array is filled with operators
bool initialized = false;
};

} // namespace mobile
24 changes: 12 additions & 12 deletions torch/csrc/jit/mobile/flatbuffer_loader.cpp
@@ -240,8 +240,6 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
function->append_constant(getIValue(i));
}

std::unordered_set<std::string> unsupported_op_names;

appendUpgraderFunctions(function.get());
// 2. Decides if upgrader is needed
const uint32_t operator_version = module_->operator_version();
@@ -254,19 +252,13 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
num_args = op->num_args_serialized();
}

auto op_found = function->append_operator(
function->append_operator(
op->name()->str(), op->overload_name()->str(), num_args);

if (!op_found) {
unsupported_op_names.emplace(
op->name()->str() + "/" + op->overload_name()->str());
}
}

TORCH_CHECK(
unsupported_op_names.empty(),
"Unsupported ops: ",
c10::Join(", ", unsupported_op_names));
if (should_load_operators_) {
function->initialize_operators(true);
}

for (const auto i : *method->type_annotations()) {
function->append_type(getOrCreateTypeAnnotations(i));
@@ -725,5 +717,13 @@ uint64_t get_bytecode_version(const std::string& filename) {
return flatbuffer_module->bytecode_version();
}

mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
auto* ff_module = mobile::serialization::GetMutableModule(flatbuffer_content);
FlatbufferLoader loader;
loader.setShouldLoadOperators(false);
mobile::Module m = loader.parseModule(ff_module);
return mobile::get_module_info(m);
}

} // namespace jit
} // namespace torch
17 changes: 17 additions & 0 deletions torch/csrc/jit/mobile/flatbuffer_loader.h
@@ -30,6 +30,11 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
// Parse a mobile::Module from flatbuffer's in-memory Module representation.
// The caller is assumed to manage the lifetimes of Module.
// This function does step 3 described above.
// If should_copy_tensor_memory is true, then the returned module will NOT
// hold references into flatbuffer_module, so flatbuffer_module can be discarded.
// If should_copy_tensor_memory is false, then the returned module will have
// tensors that point into flatbuffer_module; the caller needs to make
// sure that flatbuffer_module outlives the returned Module.
TORCH_API mobile::Module initialize_mobile_module(
mobile::serialization::Module* flatbuffer_module,
c10::optional<at::Device> device = c10::nullopt,
@@ -66,6 +71,9 @@ TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
TORCH_API uint64_t get_bytecode_version(std::istream& in);
TORCH_API uint64_t get_bytecode_version(const std::string& filename);

TORCH_API mobile::ModuleInfo get_module_info_from_flatbuffer(
char* flatbuffer_content);

class TORCH_API FlatbufferLoader {
public:
FlatbufferLoader();
@@ -118,6 +126,14 @@ class TORCH_API FlatbufferLoader {
should_copy_tensor_memory_ = should_copy_tensor_memory;
}

// Whether or not to load operators in functions.
// Not loading operators is useful because, if an operator is not found,
// loading throws an exception; sometimes we want to print out which
// operators are included before that happens, for debugging.
void setShouldLoadOperators(bool should_load_operators) {
should_load_operators_ = should_load_operators;
}

std::shared_ptr<mobile::CompilationUnit> mcu_;
std::shared_ptr<CompilationUnit> cu_;

@@ -141,6 +157,7 @@ class TORCH_API FlatbufferLoader {
mobile::serialization::Module* module_ = nullptr;
bool module_parsed_ = false;
bool should_copy_tensor_memory_ = false;
bool should_load_operators_ = true;
};

} // namespace jit
48 changes: 42 additions & 6 deletions torch/csrc/jit/mobile/function.cpp
@@ -4,6 +4,7 @@
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/mobile/prim_ops_registery.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>

@@ -49,16 +50,50 @@ bool Function::append_operator(
const c10::optional<int>& num_specified_args) {
// Keep the original opname in code_
code_.op_names_.emplace_back(name, overload_name);
const auto& opname = code_.op_names_.back();
code_.operator_input_sizes_.emplace_back(num_specified_args.value_or(-1));
auto func = makeOperatorFunction(opname, num_specified_args);
if (!func.has_value()) {
return false;
}
code_.operators_.emplace_back(*func);
return true;
}

void print_unsupported_ops_and_throw(
const std::unordered_set<std::string>& unsupported_ops) {}

std::string operator_str(const c10::OperatorName& opname) {
std::string result = opname.name;
if (!opname.overload_name.empty()) {
result += "." + opname.overload_name;
}
return result;
}

bool Function::initialize_operators(bool should_check_operators) {
if (code_.initialized) {
return true;
}
std::unordered_set<std::string> unsupported_op_names;
code_.operators_.clear();
bool all_ops_supported = true;
for (int i = 0; i < code_.op_names_.size(); i++) {
const auto& opname = code_.op_names_[i];
int num_args = code_.operator_input_sizes_[i];
c10::optional<int> num_specified_args =
num_args < 0 ? c10::nullopt : c10::optional<int>(num_args);
auto func = makeOperatorFunction(opname, num_specified_args);
    if (!func.has_value()) {
      unsupported_op_names.insert(operator_str(opname));
      all_ops_supported = false;
    } else {
      code_.operators_.emplace_back(*func);
    }
}
if (should_check_operators) {
TORCH_CHECK(
unsupported_op_names.empty(),
"Following ops cannot be found. Please check if the operator library is included in the build. If built with selected ops, check if these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post it in https://discuss.pytorch.org/",
c10::Join(", ", unsupported_op_names));
}
code_.initialized = all_ops_supported;
return all_ops_supported;
}

void Function::append_constant(const c10::IValue& constant) {
code_.constants_.push_back(constant);
}
@@ -96,6 +131,7 @@ const c10::FunctionSchema& Function::getSchema() const {
}

void Function::run(Stack& stack) {
initialize_operators(/* should_check_operators */ true);
if (hasSchema()) { // if we have a schema then resolve optional args if any
getSchema().checkAndNormalizeInputs<c10::DynamicType>(
stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
8 changes: 8 additions & 0 deletions torch/csrc/jit/mobile/function.h
@@ -63,6 +63,12 @@ class TORCH_API Function : public torch::jit::Function {
const std::vector<c10::TypePtr>& types,
const size_t register_size);

// If not yet initialized, initialize by loading operators.
// Returns true if all ops were loaded; returns false if some op is not
// found in the current runtime. The ops that were not found are collected
// in unsupported_op_names.
bool initialize_operators(bool should_check_operators);

private:
c10::QualifiedName name_;
Code code_;
@@ -73,6 +79,8 @@ c10::optional<std::function<void(Stack&)>> makeOperatorFunction(
c10::OperatorName opname,
c10::optional<int> num_specified_args);

TORCH_API std::string operator_str(const c10::OperatorName& opname);

} // namespace mobile
} // namespace jit
} // namespace torch
35 changes: 35 additions & 0 deletions torch/csrc/jit/mobile/module.cpp
@@ -3,6 +3,7 @@
#include <torch/csrc/jit/backends/backend_exception.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <exception>

@@ -263,6 +264,40 @@ c10::IValue Method::operator()(std::vector<c10::IValue> stack) const {
return stack.front();
}

c10::optional<std::string> print_type(const c10::Type& t) {
auto namedType = t.cast<c10::NamedType>();
if (namedType && namedType->name()) {
return namedType->name().value().qualifiedName();
}
if (auto dyn = t.castRaw<c10::DynamicType>()) {
return dyn->fallback()->annotation_str();
}
return c10::nullopt;
}

TORCH_API ModuleInfo get_module_info(const mobile::Module& module) {
ModuleInfo minfo;
minfo.operator_version = module.min_operator_version();
minfo.bytecode_version = module.bytecode_version();
std::vector<std::string> type_name_list;
for (const auto& func_ptr : module.compilation_unit().methods()) {
const auto& function = *func_ptr;
for (int i = 0; i < function.get_code().op_names_.size(); i++) {
const auto& op = function.get_code().op_names_[i];
minfo.opname_to_num_args[mobile::operator_str(op)] =
function.get_code().operator_input_sizes_[i];
}
for (const c10::TypePtr& tp : function.get_code().types_) {
type_name_list.push_back(tp->annotation_str(print_type));
}
minfo.function_names.insert(function.qualname().qualifiedName());
}
c10::TypeParser parser(type_name_list);
parser.parseList();
minfo.type_names = parser.getContainedTypes();
return minfo;
}

} // namespace mobile
} // namespace jit
} // namespace torch
10 changes: 10 additions & 0 deletions torch/csrc/jit/mobile/module.h
@@ -163,6 +163,16 @@ class TORCH_API Module {
// Extra handle for the module to delete when itself is deleted
std::shared_ptr<char> mem_to_delete_;
};

struct TORCH_API ModuleInfo {
uint64_t bytecode_version;
uint64_t operator_version;
std::unordered_map<std::string, int> opname_to_num_args;
std::unordered_set<std::string> function_names;
std::unordered_set<std::string> type_names;
};
TORCH_API ModuleInfo get_module_info(const mobile::Module& module);

} // namespace mobile
} // namespace jit
} // namespace torch
55 changes: 5 additions & 50 deletions torch/csrc/jit/mobile/parse_operators.cpp
@@ -5,27 +5,10 @@ namespace torch {
namespace jit {
namespace mobile {

std::string operator_str(
const std::string& name,
const std::string& overloadname) {
std::string result = name;
if (!overloadname.empty()) {
result += "." + overloadname;
}
return result;
}

/**
* Loads operators by looking them up in the Dispatcher and returns
* the set of operator names (with overload) that are not supported
* by the current runtime.
*/
std::unordered_set<std::string> load_and_find_unsupported_operator_names(
void parseOperators(
c10::ivalue::TupleElements&& ops_list,
const uint64_t& module_load_options,
mobile::Function* function) {
std::unordered_set<std::string> unsupported_op_names;
// ops_list is the list of operator names that were read in from
// bytecode.pkl for the method that is currently being processed.
for (auto& op : std::move(ops_list)) {
auto op_item = std::move(*std::move(op).toTuple()).elements();
TORCH_CHECK(
@@ -37,41 +20,13 @@ std::unordered_set<std::string> load_and_find_unsupported_operator_names(
if (op_item.size() > 2) {
num_args = op_item[2].toInt();
}
auto op_found = function->append_operator(
function->append_operator(
op_item[0].toString()->string(),
op_item[1].toString()->string(),
num_args);
if (!op_found) {
unsupported_op_names.emplace(operator_str(
op_item[0].toString()->string(), op_item[1].toString()->string()));
}
}
return unsupported_op_names;
}

void print_unsupported_ops_and_throw(
const std::unordered_set<std::string>& unsupported_ops) {
std::string error_message("{");
for (const auto& op_name : unsupported_ops) {
error_message += op_name + ", ";
}
error_message += "}";
TORCH_CHECK(
false,
"Following ops cannot be found. Please check if the operator library is included in the build. If built with selected ops, check if these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post it in https://discuss.pytorch.org/",
error_message);
}

void parseOperators(
c10::ivalue::TupleElements&& ops_list,
const uint64_t& module_load_options,
mobile::Function* function) {
std::unordered_set<std::string> unsupported_op_names =
load_and_find_unsupported_operator_names(std::move(ops_list), function);
if ((module_load_options & MobileModuleLoadOptions::OPERATOR_CHECK) &&
!unsupported_op_names.empty()) {
print_unsupported_ops_and_throw(unsupported_op_names);
}
function->initialize_operators(
(module_load_options & MobileModuleLoadOptions::OPERATOR_CHECK));
}

} // namespace mobile
1 change: 1 addition & 0 deletions torch/csrc/jit/mobile/upgrader_mobile.cpp
@@ -535,6 +535,7 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
op.overload_name,
op.num_specified_args);
}
upgrader_function.function.initialize_operators(true);
}
return upgrader_function_list;
};
(Remaining changed files not loaded.)
