Skip to content

Commit

Permalink
mask reporting
Browse files Browse the repository at this point in the history
  • Loading branch information
54rt1n committed Jan 26, 2024
1 parent 35afed5 commit 1c6cfd9
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 6 deletions.
4 changes: 3 additions & 1 deletion __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .components.dare_mbw import DareUnetMergerMBW
from .components.block import BlockUnetMerger
from .components.normalize import NormalizeUnet
from .components.mask_model import MagnitudeMasker, MaskOperations
from .components.mask_model import MagnitudeMasker, MaskOperations, MaskReporting


NODE_CLASS_MAPPINGS = {
Expand All @@ -14,6 +14,7 @@
"DM_NormalizeModel": NormalizeUnet,
"DM_MagnitudeMasker": MagnitudeMasker,
"DM_MaskOperations": MaskOperations,
"DM_MaskReporting": MaskReporting,
}

NODE_DISPLAY_NAME_MAPPINGS = {
Expand All @@ -24,6 +25,7 @@
"DM_NormalizeModel": "Normalize Model",
"DM_MagnitudeMasker": "Magnitude Masker",
"DM_MaskOperations": "Mask Operations",
"DM_MaskReporting": "Mask Reporting",
}

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
76 changes: 75 additions & 1 deletion components/mask_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# components/mask_model.py
from collections import defaultdict
from comfy.model_patcher import ModelPatcher
import torch
from typing import Dict, Tuple
Expand Down Expand Up @@ -120,4 +121,77 @@ def mask_ops(self, mask_a: ModelMask, mask_b: ModelMask, operation: str = "union
elif operation == "xor":
return (ModelMask.symmetric_distance(mask_a, mask_b),)
else:
raise ValueError("Unknown operation: {}".format(operation))
raise ValueError("Unknown operation: {}".format(operation))


class MaskReporting:
"""
Take two masks and perform a set operation. union, intersect, difference, xor
"""

@classmethod
def INPUT_TYPES(cls) -> Dict[str, tuple]:
"""
Defines the input types for the masking process.
Returns:
Dict[str, tuple]: A dictionary specifying the required model types and parameters.
"""
return {
"required": {
"mask": ("MODEL_MASK",),
"report": (["size"], {"default": "size"}),
}
}

RETURN_TYPES = ("STRING",)
FUNCTION = "mask_report"
CATEGORY = MASK_CATEGORY

def mask_report(self, mask: ModelMask, report: str = "size", **kwargs) -> Tuple[str]:
"""
Generate a report on the mask.
Args:
mask (ModelMask): The mask.
report (str): The report to generate.
Returns:
Tuple[str]: A tuple containing the report.
"""
if report == "size":
return (self.size_report(mask), )
else:
raise ValueError("Unknown report: {}".format(report))

def size_report(self, mask: ModelMask) -> Tuple[str]:
"""
Generate a report on the size of the mask.
Args:
mask (ModelMask): The mask.
Returns:
Tuple[str]: A tuple containing the report.
"""
sd = mask.state_dict
data = defaultdict(dict)
for k in sd.keys():
parts = k.split(".", 2)
if len(parts) == 2:
print("skipping", k)
else:
model, block, rest = parts
# our report is a tuple containing the number of elements, and the number of elements that are true
data[block][rest] = (sd[k].numel(), sd[k].sum().item())

report = ""
for block in data.keys():
total = 0
total_true = 0
for rest in data[block].keys():
total += data[block][rest][0]
total_true += data[block][rest][1]
report += f"{block}: {total_true} / {total} ({total_true / total * 100:.2f}%)\n"

return report
15 changes: 11 additions & 4 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
# Examples

The sample images in this folder contain ComfyUI workflows.
Some of the images in this directory have embedded workflows:
Workflows:

* `daremerge.png` - The above DARE merge workflow
![image](./daremerge.png)

![image](./daremergepic.png)

Images:
* `daremerge.png` - The above DARE merge workflow
* `maskedmerge.png` - The masked merge workflow
* `maskedmerge.png` - The masked merge workflow, which shows the difference of just using just the mask merge instead of the DARE merge
![image](./maskedmerge.png)

## Masking
You can see the basic mask operations below:
![image](./maskops.png)
Binary file added examples/maskops.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 1c6cfd9

Please sign in to comment.