Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for iterative pruning based on user-specified discard conditions #214

Draft
wants to merge 28 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
99f90c6
Adds functionality to iteratively prune paths
ns-rse Jul 20, 2023
64e8593
Pruning loops and linear skeletons to retain longest path
ns-rse Jul 28, 2023
fd02995
Loops and linear skeletons retain longest path
ns-rse Jul 28, 2023
ae0bf51
Adding examples for pruning
ns-rse Aug 14, 2023
b2fb6cc
Adding examples for pruning
ns-rse Aug 14, 2023
010bda3
Merge branch 'main' into ns-rse/iterative-pruning
ns-rse Oct 16, 2023
5b6d5a3
Add function to make nx Graph from Skeleton
jni Oct 16, 2023
482d2d1
Add function to compute main branches in nx directly
jni Oct 16, 2023
eac5da3
Initial work on networkx-based pruning
jni Oct 16, 2023
19a61cc
Tweak to text
ns-rse Oct 16, 2023
0ad90df
Networkx <--> Skeleton development and tests
ns-rse Oct 18, 2023
fbc0988
Adds array_to_nx() and nx_to_array() functions
ns-rse Nov 8, 2023
e2d5110
Correcting doc string
ns-rse Nov 8, 2023
9eb836b
Merge branch 'main' into ns-rse/iterative-pruning
ns-rse Nov 22, 2023
db48a41
Correcting import order
ns-rse Nov 22, 2023
b1c4127
Correcting test, had chopped part of key fixture resolving conflicts
ns-rse Nov 22, 2023
a44dfbd
Merge branch 'main' into ns-rse/iterative-pruning
ns-rse May 1, 2024
3854de6
Rebase on main
ns-rse Nov 27, 2023
643cea7
Add full data dict to graph attributes
jni Jun 29, 2024
a62d9bb
Use multi-graph edges with keys everywhere
jni Jun 29, 2024
4c2f37f
Add example predicate with is_endpoint
jni Jun 29, 2024
8969838
Handle self-edges / redundant edges
jni Jun 29, 2024
20d6470
Updates to iterative pruning example
jni Jun 29, 2024
cfa9d36
Merge branch 'main' into ns-rse/iterative-pruning
jni Jun 29, 2024
b642c29
Use minimum buffer size of 1 to support empty ims
jni Jun 29, 2024
01f338f
Use keep_images=True when creating skeletons
jni Jun 29, 2024
141ccb4
Include values an indices in merged edges
jni Jun 29, 2024
669eeba
Add todo note for remaining attributes
jni Jun 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Initial work on networkx-based pruning
  • Loading branch information
jni committed Oct 16, 2023
commit eac5da3ad96714f64ebdb26f5d6acbdd00329b08
238 changes: 142 additions & 96 deletions src/skan/csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy.typing as npt
import numba
import warnings
from typing import Tuple
from typing import Tuple, Callable
from .nputil import _raveled_offsets_and_distances
from .summary_utils import find_main_branches

Expand Down Expand Up @@ -1220,15 +1220,60 @@ def skeleton_to_nx(skeleton: Skeleton, summary: pd.DataFrame | None = None):
return g


def _merge_paths(p1: npt.NDArray, p2: npt.NDArray):
"""Join two paths together that have a common endpoint."""
return np.concatenate([p1[:-1], p2], axis=0)


def _merge_edges(g: nx.Graph, e1: tuple[int, int], e2: tuple[int, int]):
    """Replace two edges that share a node with one edge spanning both.

    The node common to ``e1`` and ``e2`` is removed from ``g`` (which also
    removes both original edges) and a single edge is added between the two
    outer nodes. ``g`` is modified in place and nothing is returned.

    Parameters
    ----------
    g : nx.Graph
        Graph containing both edges. Each edge's data must provide the keys
        'skeleton_id', 'branch_distance', 'branch_type', 'mean_pixel_value',
        'stdev_pixel_value' and 'path'.
    e1, e2 : tuple of int
        The two edges, given as node pairs, sharing exactly one node.

    Raises
    ------
    ValueError
        If the two edges do not share exactly one node, or do not have two
        distinct outer nodes (e.g. one of them is a self-loop).

    Notes
    -----
    If an edge between the two outer nodes already exists in a simple
    ``nx.Graph``, ``add_edge`` updates that edge's attributes rather than
    adding a parallel edge.
    """
    nodes1, nodes2 = set(e1), set(e2)
    shared = nodes1 & nodes2
    outer = (nodes1 | nodes2) - shared
    if len(shared) != 1 or len(outer) != 2:
        raise ValueError(
            f'edges {e1} and {e2} must share exactly one node and have '
            'two distinct outer nodes'
        )
    # Unpack the single shared node; using the set itself as a node is
    # unhashable and never compares equal to a node id.
    (middle_node,) = shared
    # Stable sort on "belongs to e2": the outer node of e1 comes first.
    new_edge = sorted(outer, key=lambda n: n in nodes2)

    def _oriented(data, edge, start):
        """Return data['path'] oriented so that it begins at node `start`."""
        # Edges written by this function store their path running from
        # node_id_src to node_id_dst, so prefer that metadata when it
        # matches the edge. Otherwise fall back to the order of the tuple
        # passed in. NOTE(review): assumes paths on the input graph follow
        # the same src -> dst convention -- confirm against skeleton_to_nx.
        ends = (data.get('node_id_src'), data.get('node_id_dst'))
        if set(ends) != set(edge):
            ends = edge
        path = data['path']
        return path if ends[0] == start else path[::-1]

    d1 = g.edges[e1]
    d2 = g.edges[e2]
    p1 = _oriented(d1, e1, new_edge[0])  # outer node of e1 -> middle
    p2 = _oriented(d2, e2, middle_node)  # middle -> outer node of e2
    n1 = len(d1['path'])
    n2 = len(d2['path'])
    new_edge_values = {
        'skeleton_id': d1['skeleton_id'],
        'node_id_src': new_edge[0],
        'node_id_dst': new_edge[1],
        'branch_distance': d1['branch_distance'] + d2['branch_distance'],
        'branch_type': min(d1['branch_type'], d2['branch_type']),
        # Length-weighted mean; the pixel at the shared node is counted
        # once in each path, so this is an approximation.
        'mean_pixel_value': (
            n1 * d1['mean_pixel_value'] + n2 * d2['mean_pixel_value']
        ) / (n1 + n2),
        # Combines the two variances weighted by (n - 1); it does not
        # account for any difference between the two means.
        'stdev_pixel_value': np.sqrt((
            d1['stdev_pixel_value']**2 * (n1 - 1)
            + d2['stdev_pixel_value']**2 * (n2 - 1)
        ) / (n1 + n2 - 1)),
        'path': _merge_paths(p1, p2),
    }
    g.add_edge(new_edge[0], new_edge[1], **new_edge_values)
    g.remove_node(middle_node)


def _remove_simple_path_nodes(g):
    """Remove nodes of degree 2 by merging their two incident edges.

    Candidates are collected up front, then each one is re-checked just
    before merging because earlier merges can change the graph. ``g`` is
    modified in place and nothing is returned.

    A candidate is left untouched when:

    - it no longer has degree 2;
    - it does not have exactly two distinct neighbours (for example a node
      whose only edge is a self-loop has degree 2 but a single neighbour);
    - its two neighbours are already joined by an edge, since adding the
      merged edge to a simple ``nx.Graph`` would overwrite that edge's data.

    Parameters
    ----------
    g : nx.Graph
        Graph to simplify.
    """
    to_remove = [n for n in g.nodes if g.degree(n) == 2]
    for u in to_remove:
        if u not in g or g.degree(u) != 2:
            continue
        neighbors = list(g[u])
        if len(neighbors) != 2:
            continue
        v, w = neighbors
        if g.has_edge(v, w):
            continue
        _merge_edges(g, (u, v), (u, w))


def iteratively_prune_paths(
        skeleton: nx.Graph,
        discard: Callable[[nx.Graph, dict], bool],
        ) -> nx.Graph:
    """Iteratively prune edges of a skeleton graph chosen by a predicate.

    On each pass every edge is offered to ``discard``; the edges it selects
    are removed together, and nodes of degree 2 are then removed by merging
    their incident edges. Passes repeat until one discards no edges.

    Parameters
    ----------
    skeleton : nx.Graph
        Graph representation of the skeleton. It is pruned in place: the
        returned graph is this same object.
    discard : Callable[[nx.Graph, dict], bool]
        A predicate that is True if the edge should be discarded. It is
        called with the graph in its current state and the dictionary of
        all the attributes of that edge, the same as the columns in the
        output of `summarize`.

    Returns
    -------
    nx.Graph
        The input graph with the selected edges pruned and the remaining
        paths merged.
    """
    pruned = skeleton  # no copy: pruning acts directly on the input graph

    num_pruned = 1
    while num_pruned > 0:
        # Collect first and remove afterwards, so the edge view is not
        # modified while it is being iterated.
        for_pruning = [
            (u, v) for u, v in pruned.edges
            if discard(pruned, pruned.edges[u, v])
        ]
        num_pruned = len(for_pruning)
        pruned.remove_edges_from(for_pruning)
        # Runs on every pass, including the last, so degree-2 nodes are
        # merged even if nothing was discarded.
        _remove_simple_path_nodes(pruned)
    return pruned


# TODO: the commented-out loop below is the previous pruning strategy; it still
# needs to be re-expressed as a `discard` predicate for iteratively_prune_paths().
# while branch_data.shape[0] > min_skeleton:
# # Remove branches that have endpoint (branch_type == 1)
# n_paths = branch_data.shape[0]
# pruned, branch_data = _remove_branch_type(
# pruned,
# branch_data,
# branch_type=1,
# find_main_branch=find_main_branch,
# **kwargs
# )
# # Check whether we have a looped path with a branch attached; if so, and the branch is shorter than the loop,
# # we remove it and break. We can either look for whether there are just two branches
# # if branch_data.shape[0] == 2:
# # length_branch_type1 = branch_data.loc[branch_data["branch-type"] ==
# # 1,
# # "branch-distance"].values[0]
# # length_branch_type3 = branch_data.loc[branch_data["branch-type"] ==
# # 3,
# # "branch-distance"].values[0]
# # if length_branch_type3 > length_branch_type1:
# # pruned, branch_data = _remove_branch_type(
# # pruned, branch_data, branch_type=1, find_main_branch=find_main_branch, **kwargs
# # )
# # ...or perhaps more generally whether we have just one loop left and if its length is less than other branches
# if branch_data.loc[branch_data["branch-type"] == 3].shape[0] == 1:
# # Extract the length of a loop
# length_branch_type3 = branch_data.loc[branch_data["branch-type"] ==
# 3,
# "branch-distance"].values[0]
# # Extract indices for branches lengths less than this and prune them
# pruned = pruned.prune_paths(
# branch_data.loc[branch_data["branch-distance"] <
# length_branch_type3].index
# )
# branch_data = summarize(pruned, find_main_branch=find_main_branch)
#
# # We now need to check whether we have the desired number of branches (often 1), have to check before removing
# # branches of type 3 in case this is the final, clean, loop.
# if branch_data.shape[0] == min_skeleton:
# break
# # If not we need to remove any small side loops (branch_type == 3)
# pruned, branch_data = _remove_branch_type(
# pruned,
# branch_data,
# branch_type=3,
# find_main_branch=find_main_branch,
# **kwargs
# )
# # We don't need to check if we have a single path as that is the control check for the while loop, however we do
# # need to check if we are removing anything as some skeletons of closed loops have internal branches that won't
# # ever get pruned. This happens when there are internal loops to the main one so we never observe a loop with a
# # single branch. The remaining branches ARE part of the main branch which is why they haven't (yet) been
# # removed. We now prune those and check whether we have reduced the number of paths, if not we're done pruning.
# if branch_data.shape[0] == n_paths:
# pruned, branch_data = _remove_branch_type(
# pruned,
# branch_data,
# branch_type=1,
# find_main_branch=False,
# **kwargs
# )
# # If this HASN'T removed any more branches we are done
# if branch_data.shape[0] == n_paths:
# break
# return pruned


def _remove_branch_type(
skeleton: Skeleton, branch_data: pd.DataFrame, branch_type: int,
find_main_branch: bool, **kwargs
Expand Down
Loading