make GRU/RNNStep blocks, issue 3084 to expose BlockArgumentsMapping
liqunfu committed Apr 13, 2018
1 parent 89db005 commit cd9285c
Showing 4 changed files with 16 additions and 13 deletions.
18 changes: 9 additions & 9 deletions Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp
@@ -363,18 +363,17 @@ FunctionPtr GRUComponent(Variable input,
     Constant &W, Constant &R, Constant &H1, Constant &B)
 {
     auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
+    auto dh2 = PlaceholderVariable(cellShape, input.DynamicAxes());
     auto inputPlaceholder = PlaceholderVariable(input.Shape(), input.DynamicAxes());
 
     auto gruCell = GRUCell(
         inputPlaceholder,
         fActivationOp, gActivationOp,
         dh, W, R, H1, B);
 
-    auto actualDh = recurrenceHookH(gruCell);
-
-    gruCell->ReplacePlaceholders({{dh, actualDh}});
-
-    auto gruBlock = AsBlock(std::move(gruCell), {{inputPlaceholder, input}}, L"GRU", L"");
+    auto actualDh = recurrenceHookH(dh2);
+    auto gruBlock = AsBlock(std::move(gruCell), {{inputPlaceholder, input}, {dh, actualDh}}, L"GRU", L"");
+    actualDh->ReplacePlaceholders({{dh2, gruBlock}});
     return gruBlock;
 }
 
@@ -385,17 +384,18 @@ FunctionPtr RNNComponent(Variable input,
     Constant &W, Constant &R, Constant &B)
 {
     auto dh = PlaceholderVariable(cellShape, input.DynamicAxes());
+    auto dh2 = PlaceholderVariable(cellShape, input.DynamicAxes());
     auto inputPlaceholder = PlaceholderVariable(input.Shape(), input.DynamicAxes());
 
     auto rnnCell = RNNCell(
         inputPlaceholder,
         activationOp,
         dh, W, R, B);
 
-    auto actualDh = recurrenceHookH(rnnCell);
-
-    rnnCell->ReplacePlaceholders({{inputPlaceholder, input}, {dh, actualDh}});
-    return rnnCell;
+    auto actualDh = recurrenceHookH(dh2);
+    auto rnnBlock = AsBlock(std::move(rnnCell), {{inputPlaceholder, input}, {dh, actualDh}}, L"RNNStep", L"");
+    actualDh->ReplacePlaceholders({{dh2, rnnBlock}});
+    return rnnBlock;
 }
 
 const std::vector<Variable> FindByNameHint(const std::vector<Variable> &inputs, const std::string &hint)
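Both hunks make the same change: instead of resolving the hidden-state recurrence inside the cell composite before wrapping it, the recurrence placeholder (dh) is now carried in the block's arguments mapping, the cell is exported as a named block op ("GRU" / "RNNStep"), and the loop is closed on the block's output. A minimal sketch of that wiring pattern in CNTK's C++ API follows; PastValue stands in for the recurrenceHookH delay hook, and the function and parameter names are illustrative, not part of this commit.

#include "CNTKLibrary.h"

using namespace CNTK;

// Sketch: wrap a cell composite (built over inputPlaceholder and dh) as a block
// whose hidden-state recurrence is expressed in the block arguments mapping.
FunctionPtr WrapCellAsRecurrentBlock(FunctionPtr cell,                  // cell built over inputPlaceholder and dh
                                     const Variable& inputPlaceholder,  // placeholder for the step input
                                     const Variable& dh,                // placeholder for the prior hidden state
                                     const Variable& input,             // actual input variable
                                     const NDShape& cellShape,
                                     const std::wstring& blockOpName)   // e.g. L"GRU" or L"RNNStep"
{
    // Placeholder standing for the block's own output, one step delayed.
    auto dh2 = PlaceholderVariable(cellShape, input.DynamicAxes());
    auto actualDh = PastValue(dh2);

    // The recurrence edge {dh, actualDh} is part of the block arguments mapping,
    // so it stays visible through Function::BlockArgumentsMapping() afterwards.
    auto block = AsBlock(std::move(cell),
                         { { inputPlaceholder, input }, { dh, actualDh } },
                         blockOpName, L"");

    // Close the loop: the delayed value now reads the block's output.
    actualDh->ReplacePlaceholders({ { dh2, block } });
    return block;
}

Because the {dh, actualDh} edge is recorded in the arguments mapping rather than hidden inside the composite, it can be recovered later through Function::BlockArgumentsMapping(), which is what the binding changes below touch.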
5 changes: 3 additions & 2 deletions bindings/common/CNTKManagedCommon.i
@@ -146,6 +146,8 @@ SWIG_STD_VECTOR_ENHANCED(CNTK::Learner)
 %template(UnorderedMapStreamInformationPairNDArrayViewPtrNDArrayViewPtr) std::unordered_map<CNTK::StreamInformation, std::pair<std::shared_ptr<CNTK::NDArrayView>, std::shared_ptr<CNTK::NDArrayView>>>;
 %template(ProgressWriterVector) std::vector<std::shared_ptr<CNTK::ProgressWriter>>;
 %template(LearnerVector) std::vector<std::shared_ptr<CNTK::Learner>>;
+%template(VariablePair) std::pair<CNTK::Variable, CNTK::Variable>;
+%template(VariablePairVector) std::vector<std::pair<CNTK::Variable, CNTK::Variable>>;
 %template(UnorderedMapStringDictionaryValue) std::unordered_map<std::wstring, CNTK::DictionaryValue>;
 %template(PairSizeTDouble) std::pair<size_t, double>;
 %template(VectorPairSizeTDouble) std::vector<std::pair<size_t, double>>;
@@ -196,7 +198,6 @@ IGNORE_FUNCTION CNTK::Function::Backward;
 IGNORE_FUNCTION CNTK::Function::Forward;
 IGNORE_FUNCTION CNTK::Function::Serialize;
 IGNORE_FUNCTION CNTK::Function::Deserialize;
-IGNORE_FUNCTION CNTK::Function::BlockArgumentsMapping;
 IGNORE_FUNCTION CNTK::Function::Function;
 IGNORE_FUNCTION CNTK::Function::RestoreFromCheckpoint;
 IGNORE_FUNCTION CNTK::Function::Gradients;
@@ -332,7 +333,6 @@ RENAME_AND_MAKE_PRIVATE(CNTK::DeviceDescriptor, AllDevices);
 MAKE_GETTER(CNTK::Axis, Name);
 
 // class Function
-IGNORE_FUNCTION CNTK::Function::BlockArgumentsMapping;
 IGNORE_FUNCTION CNTK::GetCorrespondingOutputVariableFromClone;
 IGNORE_FUNCTION CNTK::Function::RegisterUDFDeserializeCallback;
 IGNORE_FUNCTION CNTK::Function::GetUDFDeserializeCallback;
@@ -431,6 +431,7 @@ IGNORE_FUNCTION CNTK::Function::Placeholders;
 IGNORE_FUNCTION CNTK::Function::PrintGraph;
 IGNORE_FUNCTION CNTK::Function::Constants;
 IGNORE_FUNCTION CNTK::Function::Attributes;
+IGNORE_FUNCTION CNTK::Function::BlockArgumentsMapping;
 IGNORE_CLASS CNTK::Parameter;
 IGNORE_CLASS CNTK::Constant;
 IGNORE_ENUM_CLASS CNTK::PoolingType;
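The .i changes remove the two earlier IGNORE_FUNCTION entries for CNTK::Function::BlockArgumentsMapping, re-add one in a later section of the file, and register SWIG wrappers for std::pair<CNTK::Variable, CNTK::Variable> and its vector (the return type of Function::BlockArgumentsMapping()). A minimal sketch of what that mapping exposes, using the native C++ API; the traversal helper below is illustrative and not part of this commit.

#include "CNTKLibrary.h"
#include <iostream>
#include <unordered_set>
#include <vector>

// Sketch: enumerate block primitives in a composite and print their arguments
// mapping, i.e. the std::vector<std::pair<Variable, Variable>> that the new
// VariablePair/VariablePairVector SWIG templates wrap.
void PrintBlockArgumentsMappings(const CNTK::FunctionPtr& model)
{
    std::unordered_set<CNTK::FunctionPtr> visited;
    std::vector<CNTK::FunctionPtr> stack{ model->RootFunction() };

    while (!stack.empty())
    {
        auto f = stack.back();
        stack.pop_back();
        if (!f || !visited.insert(f).second)
            continue;

        if (f->IsBlock())
        {
            std::wcout << f->OpName() << L" block:" << std::endl;
            // Each pair maps a placeholder inside the block composite to the
            // actual variable it is bound to outside the block (for example
            // the {dh, actualDh} recurrence edge created in RNNHelper.cpp).
            for (const auto& mapping : f->BlockArgumentsMapping())
                std::wcout << L"  " << mapping.first.AsString()
                           << L" -> " << mapping.second.AsString() << std::endl;
        }

        // Walk upstream through the graph via each input's owning function.
        for (const auto& input : f->Inputs())
            if (input.IsOutput())
                stack.push_back(input.Owner());
    }
}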
2 changes: 2 additions & 0 deletions (C# bindings project file)
@@ -118,6 +118,8 @@
     <Compile Include="SwigProxyClasses\UnorderedMapStringDictionaryValue.cs" />
     <Compile Include="SwigProxyClasses\UnorderedMapVariableMinibatchData.cs" />
     <Compile Include="SwigProxyClasses\UnsignedCharVector.cs" />
+    <Compile Include="SwigProxyClasses\VariablePair.cs" />
+    <Compile Include="SwigProxyClasses\VariablePairVector.cs" />
     <Compile Include="SwigProxyClasses\VectorPairSizeTDouble.cs" />
     <Compile Include="SwigProxyClasses\Axis.cs" />
     <Compile Include="SwigProxyClasses\AxisVector.cs" />
4 changes: 2 additions & 2 deletions bindings/python/cntk/tests/onnx_op_test.py
@@ -834,7 +834,7 @@ def test_Reshape(tmpdir):
     verify_one_input(model, data, tmpdir, 'Reshape_1')
 
 #RNN
-def test_GRU(tmpdir):
+def test_RNN(tmpdir):
     def CreatRNN(cell_dim,
                  activation,
                  initial_state,
@@ -869,7 +869,7 @@ def CreatRNN(cell_dim,
                     return_full_state = False, go_backwards=go_backward)])])
 
     def MakeRNNNameFromConfig(direction, num_layers, initial_state, activition):
-        model_name = 'GRU.' + direction + '.'
+        model_name = 'RNN.' + direction + '.'
 
         if num_layers == 1:
             model_name += 'one_layer.'
