Skip to content

Commit

Permalink
Merge pull request scikit-hep#99 from scikit-hep/issue-98
Browse files Browse the repository at this point in the history
Clean up any issues of losing table keys after masking or indexing
  • Loading branch information
jpivarski authored Mar 9, 2019
2 parents bd803ca + 181b5a6 commit c7c75d5
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 34 deletions.
81 changes: 49 additions & 32 deletions awkward/array/jagged.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ def fromindex(cls, index, content, validate=True):

@classmethod
def fromjagged(cls, jagged):
jagged = jagged._tojagged(copy=False)
return cls(jagged._starts, jagged._stops, jagged._content)

@classmethod
Expand Down Expand Up @@ -432,8 +431,9 @@ def parents(self, value):

@property
def index(self):
out = self.numpy.arange(len(self._content), dtype=self.INDEXTYPE)
return self.copy(content=(out - out[self._starts[self.parents]]))
tmp = self.compact()
out = self.numpy.arange(len(tmp._content), dtype=self.INDEXTYPE)
return self.copy(starts=tmp._starts, stops=tmp._stops, content=(out - tmp._starts[tmp.parents]))

def _getnbytes(self, seen):
if id(self) in seen:
Expand Down Expand Up @@ -529,7 +529,7 @@ def __getitem__(self, where):
headoffsets = self.counts2offsets(head.counts)
head = head._tojagged(headoffsets[:-1], headoffsets[1:], copy=False)

counts = head._broadcast(self.counts)._content
counts = head.tojagged(self.counts)._content

indexes = self.numpy.array(head._content[:headoffsets[-1]], copy=True)

Expand All @@ -539,7 +539,7 @@ def __getitem__(self, where):
if not self.numpy.bitwise_and(0 <= indexes, indexes < counts).all():
raise IndexError("jagged array used as index contains out-of-bounds values")

indexes += head._broadcast(self._starts)._content
indexes += head.tojagged(self._starts)._content

return self.copy(starts=head._starts, stops=head._stops, content=self._content[indexes])

Expand Down Expand Up @@ -753,34 +753,51 @@ def __getitem__(self, where):

def __setitem__(self, where, what):
if isinstance(where, awkward.util.string):
if isinstance(what, JaggedArray):
self._content[where] = what._tojagged(self._starts, self._stops, copy=False)._content
else:
self._content[where] = self._broadcast(what)._content
self._content[where] = self.tojagged(what)._content

elif self._util_isstringslice(where):
if len(where) != len(what):
raise ValueError("number of keys ({0}) does not match number of provided arrays ({1})".format(len(where), len(what)))
for x, y in zip(where, what):
if isinstance(y, JaggedArray):
self._content[x] = y._tojagged(self._starts, self._stops, copy=False)._content
else:
self._content[x] = self._broadcast(y)._content
self._content[x] = self.tojagged(y)._content

else:
raise TypeError("invalid index for assigning column to Table: {0}".format(where))

def _broadcast(self, data):
data = self._util_toarray(data, self._content.dtype)
good = (self.parents >= 0)
content = self.numpy.empty(len(self.parents), dtype=data.dtype)
if len(data.shape) == 0:
content[good] = data
def tojagged(self, data):
if isinstance(data, JaggedArray):
if not self.numpy.array_equal(self.counts, data.counts):
raise ValueError("cannot broadcast JaggedArray to match JaggedArray with a different counts")
if len(self._starts) == 0:
return self.copy(content=data._content)
data = data.compact()
return self.copy(content=data._content[self.IndexedArray.invert((self.index + self._starts)._content)])

elif isinstance(data, awkward.array.base.AwkwardArray):
if len(self._starts) != len(data):
raise ValueError("cannot broadcast AwkwardArray to match JaggedArray with a different length")
if len(self._starts) == 0:
return self.copy(content=data)
out = self.copy(content=data[self.parents])
out._parents = self.parents
return out

elif isinstance(data, self.numpy.ndarray):
content = self.numpy.empty(len(self.parents), dtype=data.dtype)
if len(data.shape) == 0 or (len(data.shape) == 1 and data.shape[0] == 1):
content[:] = data
else:
good = (self.parents >= 0)
content[good] = data[self.parents[good]]
out = self.copy(content=content)
out._parents = self.parents
return out

elif isinstance(data, Iterable):
return self.tojagged(self.numpy.array(data))

else:
content[good] = data[self.parents[good]]
out = self.copy(content=content)
out._parents = self.parents
return out
return self.tojagged(self.numpy.array([data]))

def _tojagged(self, starts=None, stops=None, copy=True):
if starts is None and stops is None:
Expand Down Expand Up @@ -1097,7 +1114,7 @@ def argcross(self, other, nested=False):
out["1"] = out["1"] - other._starts

if nested:
out = self.JaggedArray.fromcounts(self.counts, self.JaggedArray.fromcounts(self._broadcast(other.counts).flatten(), out._content))
out = self.JaggedArray.fromcounts(self.counts, self.JaggedArray.fromcounts(self.tojagged(other.counts).flatten(), out._content))

return out

Expand All @@ -1116,15 +1133,15 @@ def cross(self, other, nested=False):

if nested:
old = out
out = self.JaggedArray.fromcounts(thyself.counts, self.JaggedArray.fromcounts(thyself._broadcast(other.counts).flatten(), out._content))
out = self.JaggedArray.fromcounts(thyself.counts, self.JaggedArray.fromcounts(thyself.tojagged(other.counts).flatten(), out._content))
out._nestedcross = old

if hasattr(self, "_nestedcross"):
counts = out.counts.copy()
mask = (self.counts != 0)
counts[mask] //= self.counts[mask]
old = out
out = self.JaggedArray.fromcounts(self.counts, self.JaggedArray.fromcounts(self._broadcast(counts).flatten(), out._content))
out = self.JaggedArray.fromcounts(self.counts, self.JaggedArray.fromcounts(self.tojagged(counts).flatten(), out._content))
out._nestedcross = old

return out
Expand Down Expand Up @@ -1517,13 +1534,13 @@ def ready(x):
if isinstance(x, JaggedArray):
columns1[n] = x._content
elif isinstance(x, Iterable):
columns1[n] = first._broadcast(x)._content
columns1[n] = first.tojagged(x)._content
elif isinstance(x, (numbers.Number, numpy.number, numpy.bool, numpy.bool_)):
columns1[n] = JaggedArray(first._starts, first._stops, numpy.full(first._stops.max(), columns1, dtype=type(columns1)))._content
else:
raise TypeError("unrecognized type for JaggedArray.zip: {0}".format(type(x)))
elif isinstance(columns1, Iterable):
columns1 = first._broadcast(columns1)._content
columns1 = first.tojagged(columns1)._content
elif isinstance(columns1, (numbers.Number, numpy.number, numpy.bool, numpy.bool_)):
columns1 = JaggedArray(first._starts, first._stops, numpy.full(first._stops.max(), columns1, dtype=type(columns1)))._content
else:
Expand All @@ -1534,7 +1551,7 @@ def ready(x):
if isinstance(x, JaggedArray):
columns2[i] = x._content
elif not isinstance(x, dict) and isinstance(x, Iterable):
columns2[i] = first._broadcast(x)._content
columns2[i] = first.tojagged(x)._content
elif isinstance(x, (numbers.Number, numpy.number, numpy.bool, numpy.bool_)):
columns2[i] = JaggedArray(first._starts, first._stops, numpy.full(first._stops.max(), x, dtype=type(x)))._content
else:
Expand All @@ -1545,7 +1562,7 @@ def ready(x):
if isinstance(x, JaggedArray):
columns3[n] = x._content
elif not isinstance(x, dict) and isinstance(x, Iterable):
columns3[n] = first._broadcast(x)._content
columns3[n] = first.tojagged(x)._content
elif isinstance(x, (numbers.Number, numpy.number, numpy.bool, numpy.bool_)):
columns3[n] = JaggedArray(first._starts, first._stops, numpy.full(first._stops.max(), x, dtype=type(x)))._content
else:
Expand Down Expand Up @@ -1574,8 +1591,8 @@ def pandas(self):
out = self._content.pandas()

if isinstance(self._content, JaggedArray):
parents = self._content._broadcast(self.parents)._content
index = self._content._broadcast(self.index._content)._content
parents = self._content.tojagged(self.parents)._content
index = self._content.tojagged(self.index._content)._content
out.index = pandas.MultiIndex.from_arrays([parents, index] + out.index.labels[1:])
else:
out.index = pandas.MultiIndex.from_arrays([self.parents, self.index._content])
Expand Down
8 changes: 8 additions & 0 deletions awkward/array/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,12 +581,20 @@ def __setitem__(self, where, what):
raise ValueError("new columns can only be attached to the original Table, not a view (try table.base['col'] = array)")

if isinstance(where, awkward.util.string):
try:
len(what)
except TypeError:
what = self.numpy.full(len(self), what)
self._contents[where] = self._util_toarray(what, self.DEFAULTTYPE)

elif self._util_isstringslice(where):
if len(where) != len(what):
raise ValueError("number of keys ({0}) does not match number of provided arrays ({1})".format(len(where), len(what)))
for x, y in zip(where, what):
try:
len(y)
except TypeError:
y = self.numpy.full(len(self), y)
self._contents[x] = self._util_toarray(y, self.DEFAULTTYPE)

else:
Expand Down
2 changes: 1 addition & 1 deletion awkward/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def recurse(obj, mask):
elif isinstance(obj, awkward.array.jagged.JaggedArray):
obj = obj.compact()
if mask is not None:
mask = obj._broadcast(mask).flatten()
mask = obj.tojagged(mask).flatten()
return pyarrow.ListArray.from_arrays(obj.offsets, recurse(obj.content, mask))

elif isinstance(obj, awkward.array.masked.IndexedMaskedArray):
Expand Down
2 changes: 1 addition & 1 deletion awkward/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import re

__version__ = "0.8.8"
__version__ = "0.8.9"
version = __version__
version_info = tuple(re.split(r"[-\.]", __version__))

Expand Down

0 comments on commit c7c75d5

Please sign in to comment.