Skip to content

Commit

Permalink
Fix invalidArguments to take kwargs and out into account (pytorch#397)
Browse files Browse the repository at this point in the history
  • Loading branch information
apaszke authored and soumith committed Jan 5, 2017
1 parent c976dd3 commit 0e345aa
Show file tree
Hide file tree
Showing 11 changed files with 353 additions and 105 deletions.
2 changes: 1 addition & 1 deletion tools/cwrap/plugins/CuDNNPlugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class CuDNNPlugin(CWrapPlugin):
$options
}
THPUtils_invalidArguments(args, "$readable_name", $num_options, $expected_args);
THPUtils_invalidArguments(args, kwargs, "$readable_name", $num_options, $expected_args);
return NULL;
END_HANDLE_TH_ERRORS
}
Expand Down
2 changes: 1 addition & 1 deletion tools/cwrap/plugins/StandaloneExtension.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class StandaloneExtension(CWrapPlugin):
int __argcount = args ? PyTuple_Size(args) : 0;
$options
} else {
THPUtils_invalidArguments(args, "$name", 1, $expected_args);
THPUtils_invalidArguments(args, NULL, "$name", 1, $expected_args);
return NULL;
}
END_HANDLE_TH_ERRORS
Expand Down
14 changes: 13 additions & 1 deletion tools/cwrap/plugins/THPPlugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class THPPlugin(CWrapPlugin):
$options
}
THPUtils_invalidArguments(args, "$readable_name", $num_options, $expected_args);
THPUtils_invalidArguments(args, kwargs, "$readable_name", $num_options, $expected_args);
return NULL;
END_HANDLE_TH_ERRORS
}
Expand Down Expand Up @@ -200,6 +200,18 @@ def format_args(args, var_args=False):
for arg in args
if not arg.get('ignore_check', False)
and not arg.get('output')]
output_args = list(filter(lambda a: a.get('output'), args))
if output_args:
if len(output_args) > 1:
out_type = 'tuple['
out_type += ', '.join(
self.TYPE_NAMES[arg['type']] for arg in output_args)
out_type += ']'
option_desc += ['#' + out_type + ' out']
else:
arg = output_args[0]
option_desc += ['#' + self.TYPE_NAMES[arg['type']] + ' out']

if option_desc:
return '({})'.format(', '.join(option_desc))
else:
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/autograd/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ PyObject * THPFunction_do_backward(THPFunction *self, PyObject *args)
PyObject *raw_grad_output = PyTuple_GET_ITEM(args, 0);
PyObject *retain_variables = PyTuple_GET_ITEM(args, 1);
if (!PyTuple_Check(raw_grad_output) || !PyBool_Check(retain_variables)) {
THPUtils_invalidArguments(args, "_do_backward", 1, "(tuple, bool)");
THPUtils_invalidArguments(args, NULL, "_do_backward", 1, "(tuple, bool)");
return NULL;
}

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/generic/Storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
}

invalid_arguments:
THPUtils_invalidArguments(args, THPStorageStr " constructor", 6,
THPUtils_invalidArguments(args, kwargs, THPStorageStr " constructor", 6,
"no arguments",
"(int size)",
"(Sequence data)",
Expand Down
10 changes: 6 additions & 4 deletions torch/csrc/generic/StorageSharing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ static PyObject * THPStorage_(newSharedFilename)(PyObject *_unused, PyObject *ar
PyObject *_object_handle = PyTuple_GET_ITEM(args, 1);
PyObject *_size = PyTuple_GET_ITEM(args, 2);
if (!THPUtils_checkBytes(_manager_handle) || !THPUtils_checkBytes(_object_handle) || !THPUtils_checkLong(_size)) {
THPUtils_invalidArguments(args, "_new_shared in file system mode", 1, "a handle (string/bytes) and storage size (int)");
THPUtils_invalidArguments(args, NULL, "_new_shared in file system mode", 1,
"a handle (string/bytes) and storage size (int)");
return NULL;
}
const char *manager_handle = THPUtils_bytesAsString(_manager_handle);
Expand Down Expand Up @@ -163,7 +164,8 @@ static PyObject * THPStorage_(newSharedFd)(PyObject *_unused, PyObject *args)
PyObject *_tmp_fd = PyTuple_GET_ITEM(args, 0);
PyObject *_size = PyTuple_GET_ITEM(args, 1);
if (!THPUtils_checkLong(_tmp_fd) || !THPUtils_checkLong(_size)) {
THPUtils_invalidArguments(args, "_new_shared in file descriptor mode", 1, "a file descriptor (int) and storage size (int)");
THPUtils_invalidArguments(args, NULL, "_new_shared in file descriptor mode",
1, "a file descriptor (int) and storage size (int)");
return NULL;
}
int fd;
Expand Down Expand Up @@ -233,7 +235,7 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
if (!(THPUtils_checkLong(_device) && THPUtils_checkLong(_size)
&& (_handle == Py_None || PyBytes_Check(_handle))
&& THPUtils_checkLong(_offset) && THPUtils_checkLong(_view_size))) {
THPUtils_invalidArguments(args, "_new_shared in CUDA mode", 1,
THPUtils_invalidArguments(args, NULL, "_new_shared in CUDA mode", 1,
"(int device, bytes handle, int storage_size, int offset, int view_size");
return NULL;
}
Expand Down Expand Up @@ -357,7 +359,7 @@ static PyObject * THPStorage_(newView)(THPStorage *self, PyObject *args)
HANDLE_TH_ERRORS
if (PyTuple_Size(args) != 2 || !THPUtils_checkLong(PyTuple_GET_ITEM(args, 0))
|| ! THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
THPUtils_invalidArguments(args, "_new_view", 1, "(int offset, int size)");
THPUtils_invalidArguments(args, NULL, "_new_view", 1, "(int offset, int size)");
return NULL;
}
long offset = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/generic/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject
return (PyObject *)self.release();
}

THPUtils_invalidArguments(args, THPTensorStr " constructor", 6,
THPUtils_invalidArguments(args, kwargs, THPTensorStr " constructor", 6,
"no arguments",
"(int ...)",
"(" THPTensorStr " viewed_tensor)",
Expand Down
10 changes: 5 additions & 5 deletions torch/csrc/generic/methods/Tensor.cwrap
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ PyObject * THPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs)
return PyInt_FromLong(THTensor_(size)(LIBRARY_STATE tensor, dim));
}

THPUtils_invalidArguments(args, "size", 2, "(int dim)", "no arguments");
THPUtils_invalidArguments(args, kwargs, "size", 2, "(int dim)", "no arguments");
return NULL;
END_HANDLE_TH_ERRORS
}
Expand Down Expand Up @@ -280,7 +280,7 @@ PyObject * THPTensor_(stride)(PyObject *self, PyObject *args, PyObject *kwargs)
return PyInt_FromLong(THTensor_(stride)(LIBRARY_STATE tensor, dim));
}

THPUtils_invalidArguments(args, "stride", 2, "(int dim)", "no arguments");
THPUtils_invalidArguments(args, kwargs, "stride", 2, "(int dim)", "no arguments");
return NULL;
END_HANDLE_TH_ERRORS
}
Expand Down Expand Up @@ -650,9 +650,9 @@ static PyObject * THPTensor_stateless_(cat)(THPTensor *_unused, PyObject *args)
return (PyObject*)result.release();

invalid_arguments:
THPUtils_invalidArguments(args, "cat", 2,
"(sequence tensors)",
"(sequence tensors, int dim)");
THPUtils_invalidArguments(args, NULL, "cat", 2,
"(sequence[" THPTensorStr "] tensors)",
"(sequence[" THPTensorStr "] tensors, int dim)");
return NULL;
END_HANDLE_TH_ERRORS
}
Expand Down
5 changes: 3 additions & 2 deletions torch/csrc/generic/methods/TensorSerialization.cwrap
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
PyObject * THPTensor_(writeMetadata)(THPTensor *self, PyObject *args)
{
if (!args || PyTuple_Size(args) != 1) {
THPUtils_invalidArguments(args, "_write_metadata", 1, "a single file object");
THPUtils_invalidArguments(args, NULL, "_write_metadata", 1, "a single file object");
return NULL;
}
int fd = PyObject_AsFileDescriptor(PyTuple_GET_ITEM(args, 0));
Expand All @@ -28,7 +28,8 @@ PyObject * THPTensor_(newWithMetadataFile)(PyObject *_null, PyObject *args)
{
if (!args || PyTuple_Size(args) != 2 ||
!THPStorage_(Check)(PyTuple_GET_ITEM(args, 1))) {
THPUtils_invalidArguments(args, "_new_with_metadata_file", 1, "single file object and a storage object");
THPUtils_invalidArguments(args, NULL, "_new_with_metadata_file", 1,
"single file object and a storage object");
return NULL;
}
int fd = PyObject_AsFileDescriptor(PyTuple_GET_ITEM(args, 0));
Expand Down
Loading

0 comments on commit 0e345aa

Please sign in to comment.