Skip to content

Commit

Permalink
Make NNCFNetwork deepcopy-able (openvinotoolkit#1096)
Browse files Browse the repository at this point in the history
* Make NNCFNetwork deepcopy-able

Deep-copying is only supported outside of a `forward` call of the
same NNCFNetwork, which is probably good enough for most purposes.

* Another version of the fix
  • Loading branch information
vshampor authored Feb 8, 2022
1 parent b07ca14 commit af19a57
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 68 deletions.
113 changes: 52 additions & 61 deletions nncf/torch/dynamic_graph/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ def __str__(self):
def __hash__(self):
return hash(str(self))

class CopySafeThreadingVars:
    """Container for threading primitives that cannot be deep-copied.

    ``threading.local`` and ``threading.Condition`` objects do not support
    ``copy.deepcopy``. Keeping them inside this wrapper lets the owning
    object be deep-copied anyway: the copy simply receives a brand-new,
    independent set of primitives instead of clones of the originals.
    """

    def __init__(self):
        # Per-thread attribute storage; no attributes are set initially.
        self.thread_local = threading.local()
        # Condition variable used by the owner for wait/notify coordination.
        self.cond = threading.Condition()

    def __deepcopy__(self, memo):
        """Return a fresh instance instead of copying the held primitives.

        ``memo`` is the usual ``copy.deepcopy`` memo dictionary; it is not
        consulted because nothing from ``self`` is actually copied.
        """
        # Use type(self) rather than a hard-coded class name so that
        # deep-copying a subclass instance yields the same subclass and is
        # not silently downgraded to the base class.
        return type(self)()

# pylint: disable=too-many-public-methods
class TracingContext:
Expand All @@ -60,10 +70,10 @@ def __init__(self):
self._pre_hooks = {} # type: Dict[PreHookId, List[Callable]]
self._num_nested_hooks = 0

self._thread_local = threading.local()
self._threading = CopySafeThreadingVars()

self._n_instances_searching_graph = 0
self._cond = threading.Condition()

self._is_tracing = True
self._is_forwarding = False
self._may_add_nodes = True
Expand All @@ -76,28 +86,29 @@ def __enter__(self):
global _CURRENT_CONTEXT
self._save_context = _CURRENT_CONTEXT
_CURRENT_CONTEXT = self
self._init_thread_local()
self._reset_thread_local()
if is_debug():
self.reset_node_call_counters()

return self

def __exit__(self, *args):
self.reset_scope_operator_call_counters()
self.relative_scopes_stack.clear()
self.module_call_stack.clear()
self.leave()
self._reset_thread_local()

global _CURRENT_CONTEXT
_CURRENT_CONTEXT = self._save_context
self._save_context = None

def find_operator_node(self, tensor_metas: List[Optional[TensorMeta]],
op_address: OperationAddress) -> Optional[DynamicGraphNode]:
with self._cond:
with self._threading.cond:
self._n_instances_searching_graph += 1

node = self.graph.find_node(op_address, tensor_metas, self._input_comparators_per_scope)

with self._cond:
with self._threading.cond:
self._n_instances_searching_graph -= 1
self._cond.notify_all()
self._threading.cond.notify_all()
return node

def register_global_buffer(self, name: str, buffer):
Expand All @@ -109,9 +120,9 @@ def maybe_add_node(self, inputs: OperatorInput, tensor_metas: List[Optional[Tens
ignored_algorithms: List[str] = None) -> Optional[DynamicGraphNode]:
if not self._may_add_nodes:
return None
with self._cond:
with self._threading.cond:
while self._n_instances_searching_graph > 0:
self._cond.wait()
self._threading.cond.wait()
# Another thread may have added a node inside this block,
# so we need to check again if a node is already added.
node = self.graph.find_node(op_address, tensor_metas, self._input_comparators_per_scope)
Expand Down Expand Up @@ -144,41 +155,30 @@ def reset_scope_operator_call_counters(self):
Must be called after each "forward" operation of the model that is made
within this context
"""
self._thread_local.operator_counters = {}
self._threading.thread_local.operator_counters = {}

@staticmethod
def _get_operator_counter_key(operator_name: str, scope: Scope):
return "{}_{}".format(str(scope), operator_name)

def register_operator_call(self, operator_name: str, scope: Scope):
key = self._get_operator_counter_key(operator_name, scope)
if key in self._thread_local.operator_counters:
self._thread_local.operator_counters[key] += 1
if key in self._threading.thread_local.operator_counters:
self._threading.thread_local.operator_counters[key] += 1
else:
self._thread_local.operator_counters[key] = 1
self._threading.thread_local.operator_counters[key] = 1

def get_operator_call_count_in_scope(self, operator_name: str, scope: Scope):
key = self._get_operator_counter_key(operator_name, scope)
if key in self._thread_local.operator_counters:
return self._thread_local.operator_counters[key]
if key in self._threading.thread_local.operator_counters:
return self._threading.thread_local.operator_counters[key]
return 0

def reset_operator_call_count_in_scope(self, scope):
scoped_op_name = str(scope)
for key in self._thread_local.operator_counters.keys():
for key in self._threading.thread_local.operator_counters.keys():
if scoped_op_name in key:
self._thread_local.operator_counters[key] = 0

def enter(self):
global _CURRENT_CONTEXT
self._save_context = _CURRENT_CONTEXT
_CURRENT_CONTEXT = self
self._init_thread_local()

def leave(self):
global _CURRENT_CONTEXT
_CURRENT_CONTEXT = self._save_context
self._save_context = None
self._threading.thread_local.operator_counters[key] = 0

def push_scope(self, called_module: torch.nn.Module):
relative_scopes_list = self._get_scope_relative_to_last_registered_module_call(called_module)
Expand All @@ -200,7 +200,7 @@ def execute_pre_hooks(self, op_address: OperationAddress,
op_inputs: OperatorInput) -> OperatorInput:
in_op = getattr(self, 'in_operator', False)
self.in_operator = False
self._thread_local.num_nested_hooks += 1
self._threading.thread_local.num_nested_hooks += 1

pre_hook_ids_for_curr_op = [x for x in self._pre_hooks if x.op_address == op_address]
pre_hook_ids_for_curr_op = sorted(pre_hook_ids_for_curr_op, key=lambda x: x.input_port_id)
Expand All @@ -209,7 +209,7 @@ def execute_pre_hooks(self, op_address: OperationAddress,
input_arg_to_process = pre_hook_id.input_port_id
for hook in hook_list_for_current_input_port:
op_inputs[input_arg_to_process] = hook(op_inputs[input_arg_to_process])
self._thread_local.num_nested_hooks -= 1
self._threading.thread_local.num_nested_hooks -= 1
self.in_operator = in_op
return op_inputs

Expand All @@ -221,11 +221,11 @@ def register_post_hooks(self, fn_list: List[Callable], op_address: OperationAddr
def execute_post_hooks(self, op_address: OperationAddress, outputs):
in_op = getattr(self, 'in_operator', False)
self.in_operator = False
self._thread_local.num_nested_hooks += 1
self._threading.thread_local.num_nested_hooks += 1
if op_address in self._post_hooks:
for hook in self._post_hooks[op_address]:
outputs = hook(outputs)
self._thread_local.num_nested_hooks -= 1
self._threading.thread_local.num_nested_hooks -= 1
self.in_operator = in_op
return outputs

Expand Down Expand Up @@ -260,29 +260,24 @@ def add_node_comparators(self, scopes_to_apply: List[str],
self._input_comparators_per_scope.append((node_input_comparator, scopes_to_apply))

@property
def base_module_thread_local_replica(self):
self._init_thread_local()
return self._thread_local.base_module_replica
def base_module_thread_local_replica(self) -> torch.nn.Module:
return self._threading.thread_local.base_module_replica

@base_module_thread_local_replica.setter
def base_module_thread_local_replica(self, value):
self._init_thread_local()
self._thread_local.base_module_replica = value
def base_module_thread_local_replica(self, value: torch.nn.Module):
self._threading.thread_local.base_module_replica = value

@property
def in_operator(self):
self._init_thread_local()
return self._thread_local.in_operator
return self._threading.thread_local.in_operator

@in_operator.setter
def in_operator(self, val):
self._init_thread_local()
self._thread_local.in_operator = val
self._threading.thread_local.in_operator = val

@property
def module_call_stack(self) -> List[torch.nn.Module]:
self._init_thread_local()
return self._thread_local.module_call_stack
return self._threading.thread_local.module_call_stack

def get_current_module(self) -> Optional[torch.nn.Module]:
if self.module_call_stack:
Expand All @@ -291,8 +286,7 @@ def get_current_module(self) -> Optional[torch.nn.Module]:

@property
def relative_scopes_stack(self) -> List[Scope]:
self._init_thread_local()
return self._thread_local.scopes
return self._threading.thread_local.scopes

@property
def trace_dynamic_graph(self) -> bool:
Expand All @@ -305,12 +299,8 @@ def disable_trace_dynamic_graph(self):
def enable_trace_dynamic_graph(self):
self._trace_dynamic_graph = True

def _init_thread_local(self):
# todo: primary node part!
tl = self._thread_local
if getattr(tl, 'ready', False):
return
tl.ready = True
def _reset_thread_local(self):
tl = self._threading.thread_local
tl.scopes = []
tl.module_call_stack = []
tl.in_operator = False
Expand All @@ -319,18 +309,19 @@ def _init_thread_local(self):
tl.operator_counters = {}
tl.node_call_tracker = {}


def register_node_call(self, node: DynamicGraphNode):
if node.node_id in self._thread_local.node_call_tracker:
self._thread_local.node_call_tracker[node.node_id] += 1
if node.node_id in self._threading.thread_local.node_call_tracker:
self._threading.thread_local.node_call_tracker[node.node_id] += 1
else:
self._thread_local.node_call_tracker[node.node_id] = 1
self._threading.thread_local.node_call_tracker[node.node_id] = 1

def reset_node_call_counters(self):
for k, _ in self._thread_local.node_call_tracker.items():
self._thread_local.node_call_tracker[k] = 0
for k, _ in self._threading.thread_local.node_call_tracker.items():
self._threading.thread_local.node_call_tracker[k] = 0

def get_node_call_counter_dict(self):
return self._thread_local.node_call_tracker
return self._threading.thread_local.node_call_tracker

def _get_scope_relative_to_last_registered_module_call(self, module) -> Scope:
module_class = module.__class__.__name__
Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/dynamic_graph/graph_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def trace_graph(self, model: torch.nn.Module, context_to_use: Optional['TracingC

context_to_use.enable_trace_dynamic_graph()
from nncf.torch.utils import training_mode_switcher
context_to_use.base_module_thread_local_replica = model
with context_to_use as _ctx:
_ctx.base_module_thread_local_replica = model
with torch.no_grad():
if as_eval:
with training_mode_switcher(model, is_training=False):
Expand Down
6 changes: 3 additions & 3 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ class LoadStateListener:
restores model state by calling this method.
"""

def __init__(self, model, all_quantizations):
def __init__(self, model: 'NNCFNetwork', all_quantizations: Dict[str, torch.nn.Module]):
# pylint: disable=protected-access
self.hook = model._register_load_state_dict_pre_hook(
functools.partial(self.hook_fn, quantize_modules=all_quantizations.values()))
functools.partial(self.hook_fn, quantize_modules=list(all_quantizations.values())))

def hook_fn(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs,
quantize_modules):
quantize_modules: List[torch.nn.Module]):
for module in quantize_modules:
module.initialized = False

Expand Down
5 changes: 3 additions & 2 deletions nncf/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
limitations under the License.
"""
from collections import OrderedDict
from typing import Dict, Any
from typing import Dict, Any, List

import warnings
import numpy as np
Expand Down Expand Up @@ -76,7 +76,8 @@ def get_all_modules_by_type(model, module_types=None, current_scope=None,
return found


def get_state_dict_names_with_modules(model, str_types=None, prefix=''):
def get_state_dict_names_with_modules(model: 'NNCFNetwork',
str_types: List[str] = None, prefix='') -> Dict[str, torch.nn.Module]:
found = OrderedDict()
for name, module in model.named_children():
full_node_name = "{}{}".format(prefix, name)
Expand Down
2 changes: 1 addition & 1 deletion tests/torch/test_graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,4 +574,4 @@ def test_scope_and_call_counters_are_reset_on_exceptions():
assert not ctx.module_call_stack
assert not ctx.relative_scopes_stack
#pylint:disable=protected-access
assert not ctx._thread_local.operator_counters
assert not ctx._threading.thread_local.operator_counters
7 changes: 7 additions & 0 deletions tests/torch/test_nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,3 +945,10 @@ def get_ip_graph_for_test(nncf_graph: NNCFGraph,
allowed_pre_hook_insertion_points=pre_hooks,
allowed_post_hook_insertion_points=post_hooks)
return ip_graph

def test_deepcopy_nncf_network():
    """Smoke test: deep-copying a compressed model must complete without raising."""
    base_model = TwoConvTestModelWithUserModule()
    compression_config = get_basic_sparsity_plus_quantization_config()
    register_bn_adaptation_init_args(compression_config)
    compressed_model, _ = create_compressed_model_and_algo_for_test(
        base_model, compression_config)
    # Only successful completion matters here; the copy itself is discarded.
    deepcopy(compressed_model)

0 comments on commit af19a57

Please sign in to comment.