Skip to content

Commit

Permalink
Merge pull request #18 from alvarorga/master
Browse files Browse the repository at this point in the history
Update sdf.py in order not to read mps
  • Loading branch information
juanjosegarciaripoll authored Feb 22, 2017
2 parents 5e4ab18 + 6e86121 commit ca0ab53
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions python/sdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ def combine_dataset(set):
output = {}
for name in list(set[0].keys()):
objs = [data[name] for data in set]
output[name] = np.stack(objs, axis=objs[0].ndim)
if type(objs[0]) is list:
output[name] = objs
else:
output[name] = np.stack(objs, axis=objs[0].ndim)
return output

class SDF:
Expand All @@ -50,13 +53,14 @@ def __init__(self, filename):
self.interpret = False # is the same endian?

def load(self):
self.f = self.filename.open("rb")
self.f = self.filename.open('rb')
output = {}
while self.f.readable():
obj, name = self.load_record()
if not name:
if len(name):
output[name] = obj
else:
break
output[name] = obj
self.f.close()
f = []
return output
Expand All @@ -68,20 +72,25 @@ def set_endian(self, newendian):
def load_record(self):
name, code = self.load_tag()
obj = []
if name:
if code == -1:
name = '';
obj = [];
elif code == 0:
obj = self.load_tensor(False)
elif code == 1:
obj = self.load_tensor(True)
else:
raise Error('Unknown SDF tag')
if code == -1:
name = '';
obj = [];
elif code == 0:
obj = self.load_tensor(False)
elif code == 1:
obj = self.load_tensor(True)
elif code == 2:
obj = self.load_mp(False)
elif code == 3:
obj = self.load_mp(True)
else:
raise Error('Unknown SDF tag')
return obj, name

def load_mp(self, iscomplex):
[ self.load_tensor(iscomplex) for i in range(sefl.read_longs(1)[0])]
L = self.read_longs(1)[0]
[self.load_record()[0] for i in range(L)]
return []

def load_tensor(self, iscomplex):
rank = self.read_longs(1)[0]
Expand Down Expand Up @@ -133,4 +142,6 @@ def load_tag(self):
name, sep, rest = name.partition(b'\x00')
name = str(name,'utf-8')
code = self.read_longs(1)[0]
else:
code = -1
return name, code

0 comments on commit ca0ab53

Please sign in to comment.