Skip to content

Commit

Permalink
!9142 fix bug of addn not supported dynamic shape
Browse files Browse the repository at this point in the history
From: @M20211031
Reviewed-by: @zh_qh
Signed-off-by: @zh_qh
  • Loading branch information
mindspore-bot authored and gitee-org committed Dec 2, 2020
2 parents 4f032cf + 59f9ac9 commit 79fef0d
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 27 deletions.
106 changes: 79 additions & 27 deletions mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,87 @@ py::object BuildValue(const ValuePtr &value_ptr) {
return ValuePtrToPyData(value_ptr);
}
}

py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
size_t len = arg_tuple->size();
py::tuple shape_tuple(len);
py::tuple dtype_tuple(len);
py::tuple min_shape_tuple(len);
py::tuple max_shape_tuple(len);
auto dic = py::dict();
bool dyn_shape = false;

for (size_t i = 0; i < len; i++) {
auto arg = arg_tuple->elements()[i];
py::dict out = ConvertAbstractToPython(arg);
shape_tuple[i] = out[ATTR_SHAPE];
dtype_tuple[i] = out[ATTR_DTYPE];

// Elements in tuple is tensor, which shape is dynamic.
if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
min_shape_tuple[i] = out[ATTR_MIN_SHAPE];
max_shape_tuple[i] = out[ATTR_MAX_SHAPE];
dyn_shape = true;
} else {
min_shape_tuple[i] = out[ATTR_SHAPE];
max_shape_tuple[i] = out[ATTR_SHAPE];
}
}
dic[ATTR_SHAPE] = shape_tuple;
dic[ATTR_DTYPE] = dtype_tuple;
dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue());

if (dyn_shape) {
dic[ATTR_MIN_SHAPE] = min_shape_tuple;
dic[ATTR_MAX_SHAPE] = max_shape_tuple;
}

return dic;
}

py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
auto arg_list = dyn_cast<AbstractList>(abs_base);
size_t len = arg_list->size();
py::list shape_list(len);
py::list dtype_list(len);
py::list min_shape_list(len);
py::list max_shape_list(len);
auto dic = py::dict();
bool dyn_shape = false;

for (size_t i = 0; i < len; i++) {
py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
shape_list[i] = out[ATTR_SHAPE];
dtype_list[i] = out[ATTR_DTYPE];

// Elements in list is tensor, which shape is dynamic.
if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
min_shape_list[i] = out[ATTR_MIN_SHAPE];
max_shape_list[i] = out[ATTR_MAX_SHAPE];
dyn_shape = true;
} else {
min_shape_list[i] = out[ATTR_SHAPE];
max_shape_list[i] = out[ATTR_SHAPE];
}
}

if (dyn_shape) {
dic[ATTR_MIN_SHAPE] = min_shape_list;
dic[ATTR_MAX_SHAPE] = max_shape_list;
}

dic[ATTR_SHAPE] = shape_list;
dic[ATTR_DTYPE] = dtype_list;
dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue());

return dic;
}
} // end anonymous namespace

py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
MS_EXCEPTION_IF_NULL(abs_base);
py::dict dic;
auto dic = py::dict();
if (abs_base->isa<AbstractTensor>()) {
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
dic[ATTR_SHAPE] = arg_tensor->shape()->shape();
Expand Down Expand Up @@ -311,33 +387,9 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic[ATTR_DTYPE] = py::ellipsis();
dic[ATTR_VALUE] = py::ellipsis();
} else if (abs_base->isa<AbstractTuple>()) {
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
size_t len = arg_tuple->size();
py::tuple shape_tuple(len);
py::tuple dtype_tuple(len);

for (size_t i = 0; i < len; i++) {
py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
shape_tuple[i] = out[ATTR_SHAPE];
dtype_tuple[i] = out[ATTR_DTYPE];
}
dic[ATTR_SHAPE] = shape_tuple;
dic[ATTR_DTYPE] = dtype_tuple;
dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue());
return AbstractTupleToPython(abs_base);
} else if (abs_base->isa<AbstractList>()) {
auto arg_list = dyn_cast<AbstractList>(abs_base);
size_t len = arg_list->size();
py::list shape_list(len);
py::list dtype_list(len);

for (size_t i = 0; i < len; i++) {
py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
shape_list[i] = out[ATTR_SHAPE];
dtype_list[i] = out[ATTR_DTYPE];
}
dic[ATTR_SHAPE] = shape_list;
dic[ATTR_DTYPE] = dtype_list;
dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue());
return AbstractListToPython(abs_base);
} else if (abs_base->isa<AbstractNone>()) {
dic[ATTR_SHAPE] = py::none();
dic[ATTR_DTYPE] = py::none();
Expand Down
18 changes: 18 additions & 0 deletions tests/ut/python/ops/test_dynamic_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,21 @@ def construct(self, x, y):
y = Tensor(np.ones([8], dtype=np.int32))
net = Net()
net(x, y)


def test_addn():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.unq = P.Unique()
self.addn = P.AddN()

def construct(self, x):
u, _ = self.unq(x)
u = self.addn((u, u, u))
z = self.addn([u, u])
return z

y = Tensor(np.ones([8], dtype=np.int32))
net = Net()
net(y)

0 comments on commit 79fef0d

Please sign in to comment.