Skip to content

Commit

Permalink
Support for DataArray.expand_dims() (pydata#1347)
Browse files Browse the repository at this point in the history
* Started implementation.

* An empty commit to trigger Travis rebuild.

* Expand_dims implemented.

* Use Variables.expand_dims

* Fix the default naming scheme of xr.DataArray.dims

* Deprecate default argument for expand_dims(dim)

* Move expand_dims from DataArray into Dataset

* Rename Variable.expand_dims -> Variable.set_dims

* Negative axis support

* Add IndexError. Slight code cleaning.

* Small fix in docstrings.

* Raise exception specific for xr.Dataset

* Small cleanup. New section in reshaping.rst.
  • Loading branch information
fujiisoup authored and shoyer committed Apr 10, 2017
1 parent 61bb71d commit 444fce8
Show file tree
Hide file tree
Showing 12 changed files with 315 additions and 23 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ Dataset contents
Dataset.merge
Dataset.rename
Dataset.swap_dims
Dataset.expand_dims
Dataset.drop
Dataset.set_coords
Dataset.reset_coords
Expand Down Expand Up @@ -223,6 +224,7 @@ DataArray contents
DataArray.pipe
DataArray.rename
DataArray.swap_dims
DataArray.expand_dims
DataArray.drop
DataArray.reset_coords
DataArray.copy
Expand Down
22 changes: 22 additions & 0 deletions doc/reshaping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,28 @@ on a :py:class:`~xarray.Dataset`, use :py:meth:`~xarray.DataArray.transpose` or
ds.transpose('y', 'z', 'x')
ds.T
Expand and squeeze dimensions
-----------------------------

To expand a :py:class:`~xarray.DataArray` or all
variables on a :py:class:`~xarray.Dataset` along a new dimension,
use :py:meth:`~xarray.DataArray.expand_dims`

.. ipython:: python
expanded = ds.expand_dims('w')
expanded
This method attaches a new dimension with size 1 to all data variable.

To remove such a size-1 dimension from the py:class:`~xarray.DataArray`
or :py:class:`~xarray.Dataset`,
use :py:meth:`~xarray.DataArray.squeeze`

.. ipython:: python
expanded.squeeze('w')
Converting between datasets and arrays
--------------------------------------

Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ The minor release includes bug-fixes and backwards compatible enhancements.

Enhancements
~~~~~~~~~~~~
- `expand_dims` on DataArray is newly supported (:issue:`1326`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

- ``rolling`` on Dataset is now supported (:issue:`859`).

- ``.rolling()`` on Dataset is now supported (:issue:`859`).
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
Expand Down
8 changes: 4 additions & 4 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,26 +489,26 @@ def broadcast(*args, **kwargs):
if dim in arg.coords:
common_coords[dim] = arg.coords[dim].variable

def _expand_dims(var):
def _set_dims(var):
# Add excluded dims to a copy of dims_map
var_dims_map = dims_map.copy()
for dim in exclude:
with suppress(ValueError):
# ignore dim not in var.dims
var_dims_map[dim] = var.shape[var.dims.index(dim)]

return var.expand_dims(var_dims_map)
return var.set_dims(var_dims_map)

def _broadcast_array(array):
data = _expand_dims(array.variable)
data = _set_dims(array.variable)
coords = OrderedDict(array.coords)
coords.update(common_coords)
return DataArray(data, coords, data.dims, name=array.name,
attrs=array.attrs, encoding=array.encoding)

def _broadcast_dataset(ds):
data_vars = OrderedDict(
(k, _expand_dims(ds.variables[k]))
(k, _set_dims(ds.variables[k]))
for k in ds.data_vars)
coords = OrderedDict(ds.coords)
coords.update(common_coords)
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def ensure_common_dims(vars):
if var.dims != common_dims:
common_shape = tuple(non_concat_dims.get(d, dim_len)
for d in common_dims)
var = var.expand_dims(common_dims, common_shape)
var = var.set_dims(common_dims, common_shape)
yield var

# stack up each variable to fill-out the dataset (in order)
Expand Down
27 changes: 27 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,33 @@ def swap_dims(self, dims_dict):
ds = self._to_temp_dataset().swap_dims(dims_dict)
return self._from_temp_dataset(ds)

def expand_dims(self, dim, axis=None):
"""Return a new object with an additional axis (or axes) inserted at the
corresponding position in the array shape.
If dim is already a scalar coordinate, it will be promoted to a 1D
coordinate consisting of a single value.
Parameters
----------
dim : str or sequence of str.
Dimensions to include on the new variable.
dimensions are inserted with length 1.
axis : integer, list (or tuple) of integers, or None
Axis position(s) where new axis is to be inserted (position(s) on
the result array). If a list (or tuple) of integers is passed,
multiple axes are inserted. In this case, dim arguments should be
same length list. If axis=None is passed, all the axes will be
inserted to the start of the result array.
Returns
-------
expanded : same type as caller
This object, but with an additional dimension(s).
"""
ds = self._to_temp_dataset().expand_dims(dim, axis)
return self._from_temp_dataset(ds)

def set_index(self, append=False, inplace=False, **indexes):
"""Set DataArray (multi-)indexes using one or more existing coordinates.
Expand Down
88 changes: 86 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,6 +1603,90 @@ def swap_dims(self, dims_dict, inplace=False):
return self._replace_vars_and_dims(variables, coord_names,
inplace=inplace)

def expand_dims(self, dim, axis=None):
"""Return a new object with an additional axis (or axes) inserted at the
corresponding position in the array shape.
If dim is already a scalar coordinate, it will be promoted to a 1D
coordinate consisting of a single value.
Parameters
----------
dim : str or sequence of str.
Dimensions to include on the new variable.
dimensions are inserted with length 1.
axis : integer, list (or tuple) of integers, or None
Axis position(s) where new axis is to be inserted (position(s) on
the result array). If a list (or tuple) of integers is passed,
multiple axes are inserted. In this case, dim arguments should be
the same length list. If axis=None is passed, all the axes will
be inserted to the start of the result array.
Returns
-------
expanded : same type as caller
This object, but with an additional dimension(s).
"""
if isinstance(dim, int):
raise ValueError('dim should be str or sequence of strs or dict')

if isinstance(dim, basestring):
dim = [dim]
if axis is not None and not isinstance(axis, (list, tuple)):
axis = [axis]

if axis is None:
axis = list(range(len(dim)))

if len(dim) != len(axis):
raise ValueError('lengths of dim and axis should be identical.')
for d in dim:
if d in self.dims:
raise ValueError(
'Dimension {dim} already exists.'.format(dim=d))
if (d in self._variables and
not utils.is_scalar(self._variables[d])):
raise ValueError(
'{dim} already exists as coordinate or'
' variable name.'.format(dim=d))

if len(dim) != len(set(dim)):
raise ValueError('dims should not contain duplicate values.')

variables = OrderedDict()
for k, v in iteritems(self._variables):
if k not in dim:
if k in self._coord_names: # Do not change coordinates
variables[k] = v
else:
result_ndim = len(v.dims) + len(axis)
for a in axis:
if a < -result_ndim or result_ndim - 1 < a:
raise IndexError(
'Axis {a} is out of bounds of the expanded'
' dimension size {dim}.'.format(
a=a, v=k, dim=result_ndim))

axis_pos = [a if a >= 0 else result_ndim + a
for a in axis]
if len(axis_pos) != len(set(axis_pos)):
raise ValueError('axis should not contain duplicate'
' values.')
# We need to sort them to make sure `axis` equals to the
# axis positions of the result array.
zip_axis_dim = sorted(zip(axis_pos, dim))

all_dims = list(v.dims)
for a, d in zip_axis_dim:
all_dims.insert(a, d)
variables[k] = v.set_dims(all_dims)
else:
# If dims includes a label of a non-dimension coordinate,
# it will be promoted to a 1D coordinate with a single value.
variables[k] = v.set_dims(k)

return self._replace_vars_and_dims(variables, self._coord_names)

def set_index(self, append=False, inplace=False, **indexes):
"""Set Dataset (multi-)indexes using one or more existing coordinates or
variables.
Expand Down Expand Up @@ -1704,7 +1788,7 @@ def _stack_once(self, dims, new_dim):
add_dims = [d for d in dims if d not in var.dims]
vdims = list(var.dims) + add_dims
shape = [self.dims[d] for d in vdims]
exp_var = var.expand_dims(vdims, shape)
exp_var = var.set_dims(vdims, shape)
stacked_var = exp_var.stack(**{new_dim: dims})
variables[name] = stacked_var
else:
Expand Down Expand Up @@ -2245,7 +2329,7 @@ def to_array(self, dim='variable', name=None):

def _to_dataframe(self, ordered_dims):
columns = [k for k in self if k not in self.dims]
data = [self._variables[k].expand_dims(ordered_dims).values.reshape(-1)
data = [self._variables[k].set_dims(ordered_dims).values.reshape(-1)
for k in columns]
index = self.coords.to_index(ordered_dims)
return pd.DataFrame(OrderedDict(zip(columns, data)), index=index)
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def unique_variable(name, variables, compat='broadcast_equals'):

if compat == 'broadcast_equals':
dim_lengths = broadcast_dimension_size(variables)
out = out.expand_dims(dim_lengths)
out = out.set_dims(dim_lengths)

if compat == 'no_conflicts':
combine_method = 'fillna'
Expand Down
16 changes: 12 additions & 4 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,8 +706,16 @@ def transpose(self, *dims):
data = ops.transpose(self.data, axes)
return type(self)(dims, data, self._attrs, self._encoding, fastpath=True)

def expand_dims(self, dims, shape=None):
"""Return a new variable with expanded dimensions.
def expand_dims(self, *args):
import warnings
warnings.warn('Variable.expand_dims is deprecated: use '
'Variable.set_dims instead', DeprecationWarning,
stacklevel=2)
return self.expand_dims(*args)

def set_dims(self, dims, shape=None):
"""Return a new variable with given set of dimensions.
This method might be used to attach new dimension(s) to variable.
When possible, this operation does not copy this variable's data.
Expand Down Expand Up @@ -1336,7 +1344,7 @@ def _unified_dims(variables):

def _broadcast_compat_variables(*variables):
dims = tuple(_unified_dims(variables))
return tuple(var.expand_dims(dims) if var.dims != dims else var
return tuple(var.set_dims(dims) if var.dims != dims else var
for var in variables)


Expand All @@ -1352,7 +1360,7 @@ def broadcast_variables(*variables):
"""
dims_map = _unified_dims(variables)
dims_tuple = tuple(dims_map)
return tuple(var.expand_dims(dims_map) if var.dims != dims_tuple else var
return tuple(var.set_dims(dims_map) if var.dims != dims_tuple else var
for var in variables)


Expand Down
94 changes: 94 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,100 @@ def test_swap_dims(self):
actual = array.swap_dims({'x': 'y'})
self.assertDataArrayIdentical(expected, actual)

def test_expand_dims_error(self):
array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'],
coords={'x': np.linspace(0.0, 1.0, 3.0)},
attrs={'key': 'entry'})

with self.assertRaisesRegexp(ValueError, 'dim should be str or'):
array.expand_dims(0)
with self.assertRaisesRegexp(ValueError, 'lengths of dim and axis'):
# dims and axis argument should be the same length
array.expand_dims(dim=['a', 'b'], axis=[1, 2, 3])
with self.assertRaisesRegexp(ValueError, 'Dimension x already'):
# Should not pass the already existing dimension.
array.expand_dims(dim=['x'])
# raise if duplicate
with self.assertRaisesRegexp(ValueError, 'duplicate values.'):
array.expand_dims(dim=['y', 'y'])
with self.assertRaisesRegexp(ValueError, 'duplicate values.'):
array.expand_dims(dim=['y', 'z'], axis=[1, 1])
with self.assertRaisesRegexp(ValueError, 'duplicate values.'):
array.expand_dims(dim=['y', 'z'], axis=[2, -2])

# out of bounds error, axis must be in [-4, 3]
with self.assertRaises(IndexError):
array.expand_dims(dim=['y', 'z'], axis=[2, 4])
with self.assertRaises(IndexError):
array.expand_dims(dim=['y', 'z'], axis=[2, -5])
# Does not raise an IndexError
array.expand_dims(dim=['y', 'z'], axis=[2, -4])
array.expand_dims(dim=['y', 'z'], axis=[2, 3])

def test_expand_dims(self):
array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'],
coords={'x': np.linspace(0.0, 1.0, 3)},
attrs={'key': 'entry'})
# pass only dim label
actual = array.expand_dims(dim='y')
expected = DataArray(np.expand_dims(array.values, 0),
dims=['y', 'x', 'dim_0'],
coords={'x': np.linspace(0.0, 1.0, 3)},
attrs={'key': 'entry'})
self.assertDataArrayIdentical(expected, actual)
roundtripped = actual.squeeze('y', drop=True)
self.assertDatasetIdentical(array, roundtripped)

# pass multiple dims
actual = array.expand_dims(dim=['y', 'z'])
expected = DataArray(np.expand_dims(np.expand_dims(array.values, 0),
0),
dims=['y', 'z', 'x', 'dim_0'],
coords={'x': np.linspace(0.0, 1.0, 3)},
attrs={'key': 'entry'})
self.assertDataArrayIdentical(expected, actual)
roundtripped = actual.squeeze(['y', 'z'], drop=True)
self.assertDatasetIdentical(array, roundtripped)

# pass multiple dims and axis. Axis is out of order
actual = array.expand_dims(dim=['z', 'y'], axis=[2, 1])
expected = DataArray(np.expand_dims(np.expand_dims(array.values, 1),
2),
dims=['x', 'y', 'z', 'dim_0'],
coords={'x': np.linspace(0.0, 1.0, 3)},
attrs={'key': 'entry'})
self.assertDataArrayIdentical(expected, actual)
# make sure the attrs are tracked
self.assertTrue(actual.attrs['key'] == 'entry')
roundtripped = actual.squeeze(['z', 'y'], drop=True)
self.assertDatasetIdentical(array, roundtripped)

# Negative axis and they are out of order
actual = array.expand_dims(dim=['y', 'z'], axis=[-1, -2])
expected = DataArray(np.expand_dims(np.expand_dims(array.values, -1),
-1),
dims=['x', 'dim_0', 'z', 'y'],
coords={'x': np.linspace(0.0, 1.0, 3)},
attrs={'key': 'entry'})
self.assertDataArrayIdentical(expected, actual)
self.assertTrue(actual.attrs['key'] == 'entry')
roundtripped = actual.squeeze(['y', 'z'], drop=True)
self.assertDatasetIdentical(array, roundtripped)

def test_expand_dims_with_scalar_coordinate(self):
array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'],
coords={'x': np.linspace(0.0, 1.0, 3), 'z': 1.0},
attrs={'key': 'entry'})
actual = array.expand_dims(dim='z')
expected = DataArray(np.expand_dims(array.values, 0),
dims=['z', 'x', 'dim_0'],
coords={'x': np.linspace(0.0, 1.0, 3),
'z': np.ones(1)},
attrs={'key': 'entry'})
self.assertDataArrayIdentical(expected, actual)
roundtripped = actual.squeeze(['z'], drop=False)
self.assertDatasetIdentical(array, roundtripped)

def test_set_index(self):
indexes = [self.mindex.get_level_values(n) for n in self.mindex.names]
coords = {idx.name: ('x', idx) for idx in indexes}
Expand Down
Loading

0 comments on commit 444fce8

Please sign in to comment.