Skip to content

Commit

Permalink
Merge pull request #113 from bwintermann/main
Browse files Browse the repository at this point in the history
Updated partitioning function for PartitionFromDict
  • Loading branch information
maltanar authored Dec 16, 2024
2 parents 2c91d6d + 5d246e8 commit f08a869
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/qonnx/transformation/create_generic_partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,14 @@ def __init__(self, partitioning={}, partition_dir=None):
def apply(self, model):
# prepare node -> int assignment fct.
def partitioning_func(node):
partition_id = -1
for key in self.partitioning:
if node in list(model.graph.node) and list(model.graph.node).index(node) in list(self.partitioning[key]):
assert partition_id == -1, """single node assigned to multiple partitions"""
partition_id = key

return partition_id
if node not in model.graph.node:
return -1
node_index = list(model.graph.node).index(node)
candidates = list(filter(lambda key_value: node_index in key_value[1], self.partitioning.items()))
if len(candidates) == 0:
return -1
assert len(candidates) == 1, f"single node assigned to multiple partitions: {candidates}"
return candidates[0][0] # partition_id

# apply partitioning
model = model.transform(PartitionFromLambda(partitioning_func, self.partition_dir))
Expand Down

0 comments on commit f08a869

Please sign in to comment.