forked from Rajeevveera24/pytorch-copy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_digraph.py
174 lines (144 loc) · 5.53 KB
/
_digraph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# mypy: allow-untyped-defs
from collections import deque
from typing import List, Set
class DiGraph:
    """Really simple unweighted directed graph data structure to track dependencies.

    The API is pretty much the same as networkx so if you add something just
    copy their API.
    """

    def __init__(self):
        # Dict of node -> dict of arbitrary attributes
        self._node = {}
        # Nested dict of node -> successor node -> nothing.
        # (didn't implement edge data)
        self._succ = {}
        # Nested dict of node -> predecessor node -> nothing.
        self._pred = {}

        # Keep track of the order in which nodes are added to
        # the graph.
        self._node_order = {}
        self._insertion_idx = 0

    def add_node(self, n, **kwargs):
        """Add a node to the graph.

        Adding a node that already exists merges ``kwargs`` into its existing
        attributes and does not change its insertion order.

        Args:
            n: the node. Can be any object that is a valid dict key.
            **kwargs: any attributes you want to attach to the node.
        """
        if n not in self._node:
            self._node[n] = kwargs
            self._succ[n] = {}
            self._pred[n] = {}
            self._node_order[n] = self._insertion_idx
            self._insertion_idx += 1
        else:
            self._node[n].update(kwargs)

    def add_edge(self, u, v):
        """Add an edge to graph between nodes ``u`` and ``v``

        ``u`` and ``v`` will be created if they do not already exist.
        """
        # add nodes
        self.add_node(u)
        self.add_node(v)

        # add the edge
        self._succ[u][v] = True
        self._pred[v][u] = True

    def successors(self, n):
        """Returns an iterator over successor nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._succ[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    def predecessors(self, n):
        """Returns an iterator over predecessors nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._pred[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    @property
    def edges(self):
        """Returns an iterator over all edges (u, v) in the graph"""
        for n, successors in self._succ.items():
            for succ in successors:
                yield n, succ

    @property
    def nodes(self):
        """Returns a dictionary of all nodes to their attributes."""
        return self._node

    def __iter__(self):
        """Iterate over the nodes."""
        return iter(self._node)

    def __contains__(self, n):
        """Returns True if ``n`` is a node in the graph, False otherwise."""
        try:
            return n in self._node
        except TypeError:
            # Unhashable objects can never be nodes.
            return False

    def _reachable(self, src, neighbors) -> Set[str]:
        """Breadth-first walk from ``src``, expanding each node with ``neighbors(node)``.

        Returns the set of every visited node, including ``src`` itself.
        """
        # Wrap src explicitly: set(src)/deque(src) would iterate the
        # characters of a string node name instead of treating it as one node.
        result = {src}
        working_set = deque([src])
        while working_set:
            cur = working_set.popleft()
            for n in neighbors(cur):
                if n not in result:
                    result.add(n)
                    working_set.append(n)
        return result

    def forward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src (including src).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._reachable(src, self.successors)

    def backward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src in reverse direction (including src).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._reachable(src, self.predecessors)

    def all_paths(self, src: str, dst: str):
        """Returns the dot representation of the subgraph of all paths from src to dst.

        The subgraph holds every edge that lies on some path from ``src`` to
        ``dst``. If ``dst`` is not reachable from ``src``, the dot
        representation of an empty graph is returned.

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        result_graph = DiGraph()
        # First compute forward transitive closure of src (all things reachable from src).
        forward_reachable_from_src = self.forward_transitive_closure(src)

        if dst not in forward_reachable_from_src:
            # Same return type as the success path: a dot string.
            return result_graph.to_dot()

        # Second walk the reverse dependencies of dst, adding each edge to
        # the output graph iff its source is also in forward_reachable_from_src.
        # We don't use backward_transitive_closure for optimization purposes:
        # this walk is pruned to nodes reachable from src.
        visited = {dst}
        working_set = deque([dst])
        while working_set:
            cur = working_set.popleft()
            for n in self.predecessors(cur):
                if n in forward_reachable_from_src:
                    result_graph.add_edge(n, cur)
                    # Only explore further if it's reachable from src, and
                    # only once per node: the graph may contain cycles.
                    if n not in visited:
                        visited.add(n)
                        working_set.append(n)

        return result_graph.to_dot()

    def first_path(self, dst: str) -> List[str]:
        """Returns a list of nodes that show the first path that resulted in dst being added to the graph.

        Starting at ``dst``, repeatedly step to the predecessor with the
        smallest insertion index. Nodes already on the path are skipped, so a
        cycle cannot cause an infinite loop. The result is ordered from the
        starting (root-most) node to ``dst``.

        Raises:
            KeyError: if ``dst`` is not in the graph.
        """
        path = []
        on_path = set()
        cur = dst
        while True:
            path.append(cur)
            on_path.add(cur)
            candidates = [p for p in self._pred[cur] if p not in on_path]
            if not candidates:
                break
            # Heuristic: the earliest-inserted predecessor is taken as the
            # one that first pulled ``cur`` into the graph.
            cur = min(candidates, key=self._node_order.__getitem__)
        return list(reversed(path))

    def to_dot(self) -> str:
        """Returns the dot representation of the graph.

        Only edges are emitted; nodes without edges do not appear.

        Returns:
            A dot representation of the graph.
        """
        edges = "\n".join(f'"{f}" -> "{t}";' for f, t in self.edges)
        return f"""\
digraph G {{
rankdir = LR;
node [shape=box];
{edges}
}}
"""