diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index fa69cad2b1a3..7ac32a8d4b9b 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -737,6 +737,14 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol, */ MXNET_DLL int MXSymbolGetInternals(SymbolHandle symbol, SymbolHandle *out); +/*! + * \brief Get a symbol that contains only direct children. + * \param symbol The symbol + * \param out The output symbol whose outputs are the direct children. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolGetChildren(SymbolHandle symbol, + SymbolHandle *out); /*! * \brief Get index-th outputs of the symbol. * \param symbol The symbol diff --git a/nnvm b/nnvm index 767f81898292..9d6b4e4f9ecb 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit 767f8189829252ac43687e1e0a85288dfd1be75c +Subproject commit 9d6b4e4f9ecbb0af8bc935d4b3ca7de0c0eb4147 diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index 26aa108df395..1d59bb9800c2 100755 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -269,7 +269,7 @@ def _init_weight(self, _, arr): @register class Constant(Initializer): - """Initialize the weight to 1""" + """Initialize the weight to a scalar value""" def __init__(self, value): super(Constant, self).__init__(value=value) self.value = value diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py index 391cca65a3ba..ffbe11b79b5b 100644 --- a/python/mxnet/module/executor_group.py +++ b/python/mxnet/module/executor_group.py @@ -170,6 +170,7 @@ def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_ self.batch_size = None self.slices = None self.execs = [] + self._default_execs = None self.data_arrays = None self.label_arrays = None self.param_arrays = None @@ -272,8 +273,8 @@ def bind_exec(self, data_shapes, label_shapes, shared_group=None, reshape=False) label_shapes_i = [] if reshape: - self.execs[i] = self.execs[i].reshape(allow_up_sizing=True, - **dict(data_shapes_i + label_shapes_i)) + self.execs[i] = self._default_execs[i].reshape( + allow_up_sizing=True, **dict(data_shapes_i + label_shapes_i)) else: self.execs.append(self._bind_ith_exec(i, data_shapes_i, label_shapes_i, shared_group)) @@ -292,6 +293,8 @@ def reshape(self, data_shapes, label_shapes): """ if data_shapes == self.data_shapes and label_shapes == self.label_shapes: return + if self._default_execs is None: + self._default_execs = [i for i in self.execs] self.bind_exec(data_shapes, label_shapes, reshape=True) def set_params(self, arg_params, aux_params): diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index a3522e169683..aa58dc935c3d 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -5,6 +5,7 @@ """ import logging +import warnings from .. import context as ctx from .. import ndarray as nd @@ -402,10 +403,12 @@ def init_optimizer(self, kvstore='local', optimizer='sgd', (kvstore, update_on_kvstore) = \ _create_kvstore(kvstore, len(self._context), self._arg_params) + batch_size = self._exec_group.batch_size + if kvstore and 'dist' in kvstore.type and '_sync' in kvstore.type: + batch_size *= kvstore.num_workers + rescale_grad = 1.0/batch_size + if isinstance(optimizer, str): - batch_size = self._exec_group.batch_size - if kvstore and kvstore.type == 'dist_sync': - batch_size *= kvstore.num_workers idx2name = {} if update_on_kvstore: idx2name.update(enumerate(self._exec_group.param_names)) @@ -415,12 +418,19 @@ def init_optimizer(self, kvstore='local', optimizer='sgd', for i, n in enumerate(self._exec_group.param_names)}) optimizer_params = dict(optimizer_params) if 'rescale_grad' not in optimizer_params: - optimizer_params['rescale_grad'] = 1.0/batch_size + optimizer_params['rescale_grad'] = rescale_grad optimizer = opt.create(optimizer, sym=self.symbol, param_idx2name=idx2name, **optimizer_params) else: assert isinstance(optimizer, opt.Optimizer) + if optimizer.rescale_grad != rescale_grad: + #pylint: disable=no-member + warnings.warn( + "Optimizer created manually outside Module but rescale_grad " + + "is not normalized to 1.0/batch_size/num_workers (%s vs. %s). "%( + optimizer.rescale_grad, rescale_grad) + + "Is this intended?", stacklevel=2) self._optimizer = optimizer self._kvstore = kvstore diff --git a/python/mxnet/rnn/rnn_cell.py b/python/mxnet/rnn/rnn_cell.py index 2427686a32c0..c09d4d0ea9ea 100644 --- a/python/mxnet/rnn/rnn_cell.py +++ b/python/mxnet/rnn/rnn_cell.py @@ -7,7 +7,7 @@ import warnings from .. import symbol, init, ndarray -from ..base import numeric_types, string_types +from ..base import string_types class RNNParams(object): """Container for holding variables. @@ -59,9 +59,14 @@ def __init__(self, prefix='', params=None): self._own_params = False self._prefix = prefix self._params = params + self._modified = False + + self.reset() + + def reset(self): + """Reset before re-using the cell for another graph""" self._init_counter = -1 self._counter = -1 - self._modified = False def __call__(self, inputs, states): """Construct symbol for one step of RNN. @@ -93,16 +98,18 @@ def state_shape(self): """shape(s) of states""" raise NotImplementedError() - def begin_state(self, init_sym=symbol.zeros, **kwargs): + def begin_state(self, func=symbol.zeros, **kwargs): """Initial state for this cell. Parameters ---------- - init_sym : Symbol, default symbol.zeros - Symbol for generating initial state. Can be zeros, - ones, uniform, normal, etc. + func : callable, default symbol.zeros + Function for creating initial state. Can be symbol.zeros, + symbol.uniform, symbol.Variable etc. + Use symbol.Variable if you want to directly + feed input as states. **kwargs : - more keyword arguments passed to init_sym. For example + more keyword arguments passed to func. For example mean, std, dtype, etc. Returns @@ -113,19 +120,17 @@ def begin_state(self, init_sym=symbol.zeros, **kwargs): assert not self._modified, \ "After applying modifier cells (e.g. DropoutCell) the base " \ "cell cannot be called directly. Call the modifier cell instead." - state_shape = self.state_shape - def recursive(shape): - """Recursively construct input states""" - if isinstance(shape, tuple): - assert len(shape) == 0 or isinstance(shape[0], numeric_types) - self._init_counter += 1 - return init_sym(name='%sinit_%d'%(self._prefix, self._init_counter), - shape=shape, **kwargs) + states = [] + for shape in self.state_shape: + self._init_counter += 1 + if shape is None: + state = func(name='%sbegin_state_%d'%(self._prefix, self._init_counter), + **kwargs) else: - assert isinstance(shape, list) - return [recursive(i) for i in shape] - - return recursive(state_shape) + state = func(name='%sbegin_state_%d'%(self._prefix, self._init_counter), + shape=shape, **kwargs) + states.append(state) + return states def unpack_weights(self, args): """Unpack fused weight matrices into separate @@ -208,6 +213,8 @@ def unroll(self, length, inputs=None, begin_state=None, states : Symbol or nested list of Symbol has the same structure as begin_state() """ + self.reset() + axis = layout.find('T') if inputs is None: inputs = [symbol.Variable('%st%d_data'%(input_prefix, i)) @@ -271,7 +278,7 @@ def __init__(self, num_hidden, activation='tanh', prefix='rnn_', params=None): @property def state_shape(self): """shape(s) of states""" - return (0, self._num_hidden) + return [(0, self._num_hidden)] def __call__(self, inputs, states): """Construct symbol for one step of RNN. @@ -295,7 +302,7 @@ def __call__(self, inputs, states): i2h = symbol.FullyConnected(data=inputs, weight=self._iW, bias=self._iB, num_hidden=self._num_hidden, name='%si2h'%name) - h2h = symbol.FullyConnected(data=states, weight=self._hW, bias=self._hB, + h2h = symbol.FullyConnected(data=states[0], weight=self._hW, bias=self._hB, num_hidden=self._num_hidden, name='%sh2h'%name) output = self._get_activation(i2h + h2h, self._activation, @@ -471,11 +478,8 @@ def __init__(self, num_hidden, num_layers=1, mode='lstm', bidirectional=False, def state_shape(self): """shape(s) of states""" b = self._bidirectional + 1 - if self._mode == 'lstm': - return [(b*self._num_layers, 0, self._num_hidden), - (b*self._num_layers, 0, self._num_hidden)] - else: - return (b*self._num_layers, 0, self._num_hidden) + n = (self._mode == 'lstm') + 1 + return [(b*self._num_layers, 0, self._num_hidden)]*n def _slice_weights(self, arr, li, lh): """slice fused rnn weights""" @@ -616,6 +620,8 @@ def unroll(self, length, inputs=None, begin_state=None, states : Symbol or nested list of Symbol has the same structure as begin_state() """ + self.reset() + axis = layout.find('T') if inputs is None: inputs = symbol.Variable('%sdata'%input_prefix) @@ -631,8 +637,8 @@ def unroll(self, length, inputs=None, begin_state=None, assert axis == 0, "Unsupported layout %s"%layout else: assert len(inputs) == length - inputs = [symbol.expand_dims(i, axis=0) for i in inputs] - inputs = symbol.Concat(inputs, dim=0) + inputs = [symbol.expand_dims(i, axis=1) for i in inputs] + inputs = symbol.Concat(*inputs, dim=1) if begin_state is None: begin_state = self.begin_state() @@ -699,7 +705,7 @@ def add(self, cell): @property def state_shape(self): """shape(s) of states""" - return [c.state_shape for c in self._cells] + return sum([c.state_shape for c in self._cells], []) def begin_state(self, **kwargs): """Initial state for this cell. @@ -721,7 +727,7 @@ def begin_state(self, **kwargs): assert not self._modified, \ "After applying modifier cells (e.g. DropoutCell) the base " \ "cell cannot be called directly. Call the modifier cell instead." - return [c.begin_state(**kwargs) for c in self._cells] + return sum([c.begin_state(**kwargs) for c in self._cells], []) def unpack_weights(self, args): for cell in self._cells: @@ -752,10 +758,14 @@ def __call__(self, inputs, states): """ self._counter += 1 next_states = [] - for cell, state in zip(self._cells, states): + p = 0 + for cell in self._cells: + n = len(cell.state_shape) + state = states[p:p+n] + p += n inputs, state = cell(inputs, state) next_states.append(state) - return inputs, next_states + return inputs, sum(next_states, []) class ModifierCell(BaseRNNCell): """Base class for modifier cells. A modifier diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 2d1744d3aff9..9c996e3db9a3 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -5,6 +5,7 @@ from __future__ import absolute_import as _abs import ctypes +import warnings from numbers import Number import os as _os @@ -368,7 +369,8 @@ def _set_attr(self, **kwargs): self.handle, c_str(key), c_str(str(value)))) def get_internals(self): - """Get a new grouped symbol whose output contains all the internal outputs of this symbol. + """Get a new grouped symbol whose output contains + internal outputs of this symbol. Returns ------- @@ -380,6 +382,24 @@ def get_internals(self): self.handle, ctypes.byref(handle))) return Symbol(handle=handle) + def get_children(self): + """Get a new grouped symbol whose output contains + inputs to output nodes of the original symbol + + Returns + ------- + sgroup : Symbol or None + The children of the head node. If the symbol has no + inputs None will be returned. + """ + handle = SymbolHandle() + check_call(_LIB.MXSymbolGetChildren( + self.handle, ctypes.byref(handle))) + ret = Symbol(handle=handle) + if len(ret.list_outputs()) == 0: + return None + return ret + def list_arguments(self): """List all the arguments in the symbol. @@ -539,7 +559,23 @@ def infer_shape(self, *args, **kwargs): The order is in the same order as list_auxiliary() """ try: - return self._infer_shape_impl(False, *args, **kwargs) + res = self._infer_shape_impl(False, *args, **kwargs) + if res[1] is None: + arg_shapes, _, _ = self._infer_shape_impl(True, *args, **kwargs) + arg_names = self.list_arguments() + unknowns = [] + for name, shape in zip(arg_names, arg_shapes): + if not shape or not _numpy.prod(shape): + if len(unknowns) >= 10: + unknowns.append('...') + break + unknowns.append('%s: %s'%(name, str(shape))) + warnings.warn( + "Cannot decide shape for the following arguments " + + "(0s in shape means unknown dimensions). " + + "Consider providing them as input:\n\t" + + "\n\t".join(unknowns), stacklevel=2) + return res except MXNetError: print("infer_shape error. Arguments:") for i, arg in enumerate(args): diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 37698125edfb..f7281c999e6a 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -142,7 +142,20 @@ int MXSymbolGetOutput(SymbolHandle symbol, int MXSymbolGetInternals(SymbolHandle symbol, SymbolHandle *out) { - return NNSymbolGetInternals(symbol, out); + nnvm::Symbol *s = new nnvm::Symbol(); + API_BEGIN(); + *s = static_cast(symbol)->GetInternals(); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + +int MXSymbolGetChildren(SymbolHandle symbol, + SymbolHandle *out) { + nnvm::Symbol *s = new nnvm::Symbol(); + API_BEGIN(); + *s = static_cast(symbol)->GetChildren(); + *out = s; + API_END_HANDLE_ERROR(delete s); } int MXSymbolFree(SymbolHandle symbol) { diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index ca4ceeea2325..2ae11fa4a87f 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -43,13 +43,29 @@ def test_symbol_internal(): data = mx.symbol.Variable('data') oldfc = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10) net1 = mx.symbol.FullyConnected(data=oldfc, name='fc2', num_hidden=100) - net1.list_arguments() == ['data', - 'fc1_weight', 'fc1_bias', - 'fc2_weight', 'fc2_bias'] + assert net1.list_arguments() == ['data', 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias'] + internal = net1.get_internals() fc1 = internal['fc1_output'] assert fc1.list_arguments() == oldfc.list_arguments() +def test_symbol_children(): + data = mx.symbol.Variable('data') + oldfc = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10) + net1 = mx.symbol.FullyConnected(data=oldfc, name='fc2', num_hidden=100) + + assert net1.get_children().list_outputs() == ['fc1_output', 'fc2_weight', 'fc2_bias'] + assert net1.get_children().get_children().list_outputs() == ['data', 'fc1_weight', 'fc1_bias'] + assert net1.get_children()['fc2_weight'].list_arguments() == ['fc2_weight'] + assert net1.get_children()['fc2_weight'].get_children() is None + + data = mx.sym.Variable('data') + sliced = mx.sym.SliceChannel(data, num_outputs=3, name='slice') + concat = mx.sym.Concat(*list(sliced)) + + assert concat.get_children().list_outputs() == \ + ['slice_output0', 'slice_output1', 'slice_output2'] + assert sliced.get_children().list_outputs() == ['data'] def test_symbol_pickle(): mlist = [models.mlp2(), models.conv()] @@ -166,6 +182,7 @@ def test_load_000800(): if __name__ == '__main__': + test_symbol_children() test_load_000800() test_symbol_infer_shape_var() test_symbol_infer_shape()