Skip to content

Commit

Permalink
add api symbol.get_children (apache#5141)
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong authored Feb 25, 2017
1 parent c9e252f commit dc47e06
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 46 deletions.
8 changes: 8 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,14 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
*/
MXNET_DLL int MXSymbolGetInternals(SymbolHandle symbol,
SymbolHandle *out);
/*!
 * \brief Get a grouped symbol whose outputs are the direct children
 *  (i.e. the inputs) of this symbol's output nodes.
 * \param symbol The symbol to query.
 * \param out The output symbol whose outputs are the direct children.
 *  NOTE(review): the implementation allocates a fresh symbol for *out,
 *  so presumably the caller releases it via MXSymbolFree — confirm.
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXSymbolGetChildren(SymbolHandle symbol,
SymbolHandle *out);
/*!
* \brief Get index-th outputs of the symbol.
* \param symbol The symbol
Expand Down
2 changes: 1 addition & 1 deletion nnvm
2 changes: 1 addition & 1 deletion python/mxnet/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def _init_weight(self, _, arr):

@register
class Constant(Initializer):
    """Initialize the weight to a scalar value.

    Parameters
    ----------
    value : float
        The scalar used to fill the weight. It is also forwarded to the
        base ``Initializer`` as a keyword argument so the configuration
        is recorded there.
    """
    def __init__(self, value):
        super(Constant, self).__init__(value=value)
        # Keep the fill value on the instance as well; the base class only
        # receives it as a recorded kwarg.
        self.value = value
Expand Down
7 changes: 5 additions & 2 deletions python/mxnet/module/executor_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_
self.batch_size = None
self.slices = None
self.execs = []
self._default_execs = None
self.data_arrays = None
self.label_arrays = None
self.param_arrays = None
Expand Down Expand Up @@ -272,8 +273,8 @@ def bind_exec(self, data_shapes, label_shapes, shared_group=None, reshape=False)
label_shapes_i = []

if reshape:
self.execs[i] = self.execs[i].reshape(allow_up_sizing=True,
**dict(data_shapes_i + label_shapes_i))
self.execs[i] = self._default_execs[i].reshape(
allow_up_sizing=True, **dict(data_shapes_i + label_shapes_i))
else:
self.execs.append(self._bind_ith_exec(i, data_shapes_i, label_shapes_i,
shared_group))
Expand All @@ -292,6 +293,8 @@ def reshape(self, data_shapes, label_shapes):
"""
if data_shapes == self.data_shapes and label_shapes == self.label_shapes:
return
if self._default_execs is None:
self._default_execs = [i for i in self.execs]
self.bind_exec(data_shapes, label_shapes, reshape=True)

def set_params(self, arg_params, aux_params):
Expand Down
18 changes: 14 additions & 4 deletions python/mxnet/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import logging
import warnings

from .. import context as ctx
from .. import ndarray as nd
Expand Down Expand Up @@ -402,10 +403,12 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
(kvstore, update_on_kvstore) = \
_create_kvstore(kvstore, len(self._context), self._arg_params)

batch_size = self._exec_group.batch_size
if kvstore and 'dist' in kvstore.type and '_sync' in kvstore.type:
batch_size *= kvstore.num_workers
rescale_grad = 1.0/batch_size

if isinstance(optimizer, str):
batch_size = self._exec_group.batch_size
if kvstore and kvstore.type == 'dist_sync':
batch_size *= kvstore.num_workers
idx2name = {}
if update_on_kvstore:
idx2name.update(enumerate(self._exec_group.param_names))
Expand All @@ -415,12 +418,19 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
for i, n in enumerate(self._exec_group.param_names)})
optimizer_params = dict(optimizer_params)
if 'rescale_grad' not in optimizer_params:
optimizer_params['rescale_grad'] = 1.0/batch_size
optimizer_params['rescale_grad'] = rescale_grad
optimizer = opt.create(optimizer,
sym=self.symbol, param_idx2name=idx2name,
**optimizer_params)
else:
assert isinstance(optimizer, opt.Optimizer)
if optimizer.rescale_grad != rescale_grad:
#pylint: disable=no-member
warnings.warn(
"Optimizer created manually outside Module but rescale_grad " +
"is not normalized to 1.0/batch_size/num_workers (%s vs. %s). "%(
optimizer.rescale_grad, rescale_grad) +
"Is this intended?", stacklevel=2)

self._optimizer = optimizer
self._kvstore = kvstore
Expand Down
74 changes: 42 additions & 32 deletions python/mxnet/rnn/rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings

from .. import symbol, init, ndarray
from ..base import numeric_types, string_types
from ..base import string_types

class RNNParams(object):
"""Container for holding variables.
Expand Down Expand Up @@ -59,9 +59,14 @@ def __init__(self, prefix='', params=None):
self._own_params = False
self._prefix = prefix
self._params = params
self._modified = False

self.reset()

def reset(self):
    """Reset the cell so it can be re-used to build another graph."""
    # Restore both name counters and the modifier flag to their
    # pristine (pre-use) values.
    self._counter = -1
    self._init_counter = -1
    self._modified = False

def __call__(self, inputs, states):
"""Construct symbol for one step of RNN.
Expand Down Expand Up @@ -93,16 +98,18 @@ def state_shape(self):
"""shape(s) of states"""
raise NotImplementedError()

def begin_state(self, init_sym=symbol.zeros, **kwargs):
def begin_state(self, func=symbol.zeros, **kwargs):
"""Initial state for this cell.
Parameters
----------
init_sym : Symbol, default symbol.zeros
Symbol for generating initial state. Can be zeros,
ones, uniform, normal, etc.
func : callable, default symbol.zeros
Function for creating initial state. Can be symbol.zeros,
symbol.uniform, symbol.Variable etc.
Use symbol.Variable if you want to directly
feed input as states.
**kwargs :
more keyword arguments passed to init_sym. For example
more keyword arguments passed to func. For example
mean, std, dtype, etc.
Returns
Expand All @@ -113,19 +120,17 @@ def begin_state(self, init_sym=symbol.zeros, **kwargs):
assert not self._modified, \
"After applying modifier cells (e.g. DropoutCell) the base " \
"cell cannot be called directly. Call the modifier cell instead."
state_shape = self.state_shape
def recursive(shape):
"""Recursively construct input states"""
if isinstance(shape, tuple):
assert len(shape) == 0 or isinstance(shape[0], numeric_types)
self._init_counter += 1
return init_sym(name='%sinit_%d'%(self._prefix, self._init_counter),
shape=shape, **kwargs)
states = []
for shape in self.state_shape:
self._init_counter += 1
if shape is None:
state = func(name='%sbegin_state_%d'%(self._prefix, self._init_counter),
**kwargs)
else:
assert isinstance(shape, list)
return [recursive(i) for i in shape]

return recursive(state_shape)
state = func(name='%sbegin_state_%d'%(self._prefix, self._init_counter),
shape=shape, **kwargs)
states.append(state)
return states

def unpack_weights(self, args):
"""Unpack fused weight matrices into separate
Expand Down Expand Up @@ -208,6 +213,8 @@ def unroll(self, length, inputs=None, begin_state=None,
states : Symbol or nested list of Symbol
has the same structure as begin_state()
"""
self.reset()

axis = layout.find('T')
if inputs is None:
inputs = [symbol.Variable('%st%d_data'%(input_prefix, i))
Expand Down Expand Up @@ -271,7 +278,7 @@ def __init__(self, num_hidden, activation='tanh', prefix='rnn_', params=None):
@property
def state_shape(self):
"""shape(s) of states"""
return (0, self._num_hidden)
return [(0, self._num_hidden)]

def __call__(self, inputs, states):
"""Construct symbol for one step of RNN.
Expand All @@ -295,7 +302,7 @@ def __call__(self, inputs, states):
i2h = symbol.FullyConnected(data=inputs, weight=self._iW, bias=self._iB,
num_hidden=self._num_hidden,
name='%si2h'%name)
h2h = symbol.FullyConnected(data=states, weight=self._hW, bias=self._hB,
h2h = symbol.FullyConnected(data=states[0], weight=self._hW, bias=self._hB,
num_hidden=self._num_hidden,
name='%sh2h'%name)
output = self._get_activation(i2h + h2h, self._activation,
Expand Down Expand Up @@ -471,11 +478,8 @@ def __init__(self, num_hidden, num_layers=1, mode='lstm', bidirectional=False,
def state_shape(self):
"""shape(s) of states"""
b = self._bidirectional + 1
if self._mode == 'lstm':
return [(b*self._num_layers, 0, self._num_hidden),
(b*self._num_layers, 0, self._num_hidden)]
else:
return (b*self._num_layers, 0, self._num_hidden)
n = (self._mode == 'lstm') + 1
return [(b*self._num_layers, 0, self._num_hidden)]*n

def _slice_weights(self, arr, li, lh):
"""slice fused rnn weights"""
Expand Down Expand Up @@ -616,6 +620,8 @@ def unroll(self, length, inputs=None, begin_state=None,
states : Symbol or nested list of Symbol
has the same structure as begin_state()
"""
self.reset()

axis = layout.find('T')
if inputs is None:
inputs = symbol.Variable('%sdata'%input_prefix)
Expand All @@ -631,8 +637,8 @@ def unroll(self, length, inputs=None, begin_state=None,
assert axis == 0, "Unsupported layout %s"%layout
else:
assert len(inputs) == length
inputs = [symbol.expand_dims(i, axis=0) for i in inputs]
inputs = symbol.Concat(inputs, dim=0)
inputs = [symbol.expand_dims(i, axis=1) for i in inputs]
inputs = symbol.Concat(*inputs, dim=1)
if begin_state is None:
begin_state = self.begin_state()

Expand Down Expand Up @@ -699,7 +705,7 @@ def add(self, cell):
@property
def state_shape(self):
    """shape(s) of states

    Returns a single flat list that concatenates every child cell's
    state shapes, so indices line up with the flat state list that
    begin_state() produces and that __call__ slices with
    ``states[p:p+n]``.
    """
    # sum(..., []) flattens the per-cell lists; a plain list of lists
    # would break the flat-index slicing done by callers.
    return sum([c.state_shape for c in self._cells], [])

def begin_state(self, **kwargs):
"""Initial state for this cell.
Expand All @@ -721,7 +727,7 @@ def begin_state(self, **kwargs):
assert not self._modified, \
"After applying modifier cells (e.g. DropoutCell) the base " \
"cell cannot be called directly. Call the modifier cell instead."
return [c.begin_state(**kwargs) for c in self._cells]
return sum([c.begin_state(**kwargs) for c in self._cells], [])

def unpack_weights(self, args):
for cell in self._cells:
Expand Down Expand Up @@ -752,10 +758,14 @@ def __call__(self, inputs, states):
"""
self._counter += 1
next_states = []
for cell, state in zip(self._cells, states):
p = 0
for cell in self._cells:
n = len(cell.state_shape)
state = states[p:p+n]
p += n
inputs, state = cell(inputs, state)
next_states.append(state)
return inputs, next_states
return inputs, sum(next_states, [])

class ModifierCell(BaseRNNCell):
"""Base class for modifier cells. A modifier
Expand Down
40 changes: 38 additions & 2 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import absolute_import as _abs

import ctypes
import warnings
from numbers import Number

import os as _os
Expand Down Expand Up @@ -368,7 +369,8 @@ def _set_attr(self, **kwargs):
self.handle, c_str(key), c_str(str(value))))

def get_internals(self):
"""Get a new grouped symbol whose output contains all the internal outputs of this symbol.
"""Get a new grouped symbol whose output contains
internal outputs of this symbol.
Returns
-------
Expand All @@ -380,6 +382,24 @@ def get_internals(self):
self.handle, ctypes.byref(handle)))
return Symbol(handle=handle)

def get_children(self):
    """Get a new grouped symbol whose outputs are the inputs
    to the output nodes of the original symbol.

    Returns
    -------
    sgroup : Symbol or None
        The children of the head node, or None when the symbol
        has no inputs.
    """
    child_handle = SymbolHandle()
    check_call(_LIB.MXSymbolGetChildren(
        self.handle, ctypes.byref(child_handle)))
    children = Symbol(handle=child_handle)
    # A symbol with no inputs yields an empty group; surface that as None.
    return children if children.list_outputs() else None

def list_arguments(self):
"""List all the arguments in the symbol.
Expand Down Expand Up @@ -539,7 +559,23 @@ def infer_shape(self, *args, **kwargs):
The order is in the same order as list_auxiliary()
"""
try:
return self._infer_shape_impl(False, *args, **kwargs)
res = self._infer_shape_impl(False, *args, **kwargs)
if res[1] is None:
arg_shapes, _, _ = self._infer_shape_impl(True, *args, **kwargs)
arg_names = self.list_arguments()
unknowns = []
for name, shape in zip(arg_names, arg_shapes):
if not shape or not _numpy.prod(shape):
if len(unknowns) >= 10:
unknowns.append('...')
break
unknowns.append('%s: %s'%(name, str(shape)))
warnings.warn(
"Cannot decide shape for the following arguments " +
"(0s in shape means unknown dimensions). " +
"Consider providing them as input:\n\t" +
"\n\t".join(unknowns), stacklevel=2)
return res
except MXNetError:
print("infer_shape error. Arguments:")
for i, arg in enumerate(args):
Expand Down
15 changes: 14 additions & 1 deletion src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,20 @@ int MXSymbolGetOutput(SymbolHandle symbol,

// Return a grouped symbol exposing all internal outputs of `symbol`.
// The early `return NNSymbolGetInternals(...)` left over from the old
// implementation made the body below unreachable; it is removed so the
// function mirrors MXSymbolGetChildren (allocate, fill inside the
// API_BEGIN/API_END_HANDLE_ERROR guard, free on failure).
int MXSymbolGetInternals(SymbolHandle symbol,
                         SymbolHandle *out) {
  nnvm::Symbol *s = new nnvm::Symbol();
  API_BEGIN();
  *s = static_cast<nnvm::Symbol*>(symbol)->GetInternals();
  *out = s;
  API_END_HANDLE_ERROR(delete s);
}

// Return a grouped symbol whose outputs are the direct children
// (inputs of the output nodes) of `symbol`.
int MXSymbolGetChildren(SymbolHandle symbol,
                        SymbolHandle *out) {
  // Allocate first so the error path can unconditionally delete it.
  nnvm::Symbol *children = new nnvm::Symbol();
  API_BEGIN();
  *children = static_cast<nnvm::Symbol*>(symbol)->GetChildren();
  *out = children;
  API_END_HANDLE_ERROR(delete children);
}

int MXSymbolFree(SymbolHandle symbol) {
Expand Down
Loading

0 comments on commit dc47e06

Please sign in to comment.