给定一个由 n
个节点组成的网络,用 n x n
个邻接矩阵 graph
表示。在节点网络中,只有当 graph[i][j] = 1
时,节点 i
能够直接连接到另一个节点 j
。
一些节点 initial
最初被恶意软件感染。只要两个节点直接连接,且其中至少一个节点受到恶意软件的感染,那么两个节点都将被恶意软件感染。这种恶意软件的传播将继续,直到没有更多的节点可以被这种方式感染。
假设 M(initial)
是在恶意软件停止传播之后,整个网络中感染恶意软件的最终节点数。
我们可以从 initial
中删除一个节点,并完全移除该节点以及从该节点到任何其他节点的任何连接。
请返回移除后能够使 M(initial)
最小化的节点。如果有多个节点满足条件,返回索引 最小的节点 。
示例 1:
输入:graph = [[1,1,0],[1,1,0],[0,0,1]], initial = [0,1] 输出:0
示例 2:
输入:graph = [[1,1,0],[1,1,1],[0,1,1]], initial = [0,1] 输出:1
示例 3:
输入:graph = [[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], initial = [0,1] 输出:1
提示:
n == graph.length
n == graph[i].length
2 <= n <= 300
graph[i][j]
是0
或1
.graph[i][j] == graph[j][i]
graph[i][i] == 1
1 <= initial.length < n
0 <= initial[i] <= n - 1
-
initial
中每个整数都不同
逆向思维并查集。对于本题,先遍历所有未被感染的节点(即不在 initial 列表的节点),构造并查集,并且在集合根节点维护 size,表示当前集合的节点数。
然后找到只被一个 initial 节点感染的集合,求得感染节点数的最小值。
被某个 initial 节点感染的集合,节点数越多,若移除此 initial 节点,感染的节点数就越少。
以下是并查集的几个常用模板。
模板 1——朴素并查集:
# 初始化,p存储每个点的父节点
p = list(range(n))
# 返回x的祖宗节点
def find(x):
if p[x] != x:
# 路径压缩
p[x] = find(p[x])
return p[x]
# 合并a和b所在的两个集合
p[find(a)] = find(b)
模板 2——维护 size 的并查集:
# 初始化,p存储每个点的父节点,size只有当节点是祖宗节点时才有意义,表示祖宗节点所在集合中,点的数量
p = list(range(n))
size = [1] * n
# 返回x的祖宗节点
def find(x):
if p[x] != x:
# 路径压缩
p[x] = find(p[x])
return p[x]
# 合并a和b所在的两个集合
if find(a) != find(b):
size[find(b)] += size[find(a)]
p[find(a)] = find(b)
模板 3——维护到祖宗节点距离的并查集:
# 初始化,p存储每个点的父节点,d[x]存储x到p[x]的距离
p = list(range(n))
d = [0] * n
# 返回x的祖宗节点
def find(x):
if p[x] != x:
t = find(p[x])
d[x] += d[p[x]]
p[x] = t
return p[x]
# 合并a和b所在的两个集合
p[find(a)] = find(b)
d[find(a)] = distance
class Solution:
def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
def find(x):
if p[x] != x:
p[x] = find(p[x])
return p[x]
def union(a, b):
pa, pb = find(a), find(b)
if pa != pb:
size[pb] += size[pa]
p[pa] = pb
n = len(graph)
p = list(range(n))
size = [1] * n
clean = [True] * n
for i in initial:
clean[i] = False
for i in range(n):
if not clean[i]:
continue
for j in range(i + 1, n):
if clean[j] and graph[i][j] == 1:
union(i, j)
cnt = Counter()
mp = {}
for i in initial:
s = {find(j) for j in range(n) if clean[j] and graph[i][j] == 1}
for root in s:
cnt[root] += 1
mp[i] = s
mx, ans = -1, 0
for i, s in mp.items():
t = sum(size[root] for root in s if cnt[root] == 1)
if mx < t or mx == t and i < ans:
mx, ans = t, i
return ans
class Solution {
private int[] p;
private int[] size;
public int minMalwareSpread(int[][] graph, int[] initial) {
int n = graph.length;
p = new int[n];
size = new int[n];
for (int i = 0; i < n; ++i) {
p[i] = i;
size[i] = 1;
}
boolean[] clean = new boolean[n];
Arrays.fill(clean, true);
for (int i : initial) {
clean[i] = false;
}
for (int i = 0; i < n; ++i) {
if (!clean[i]) {
continue;
}
for (int j = i + 1; j < n; ++j) {
if (clean[j] && graph[i][j] == 1) {
union(i, j);
}
}
}
int[] cnt = new int[n];
Map<Integer, Set<Integer>> mp = new HashMap<>();
for (int i : initial) {
Set<Integer> s = new HashSet<>();
for (int j = 0; j < n; ++j) {
if (clean[j] && graph[i][j] == 1) {
s.add(find(j));
}
}
for (int root : s) {
cnt[root] += 1;
}
mp.put(i, s);
}
int mx = -1;
int ans = 0;
for (Map.Entry<Integer, Set<Integer>> entry : mp.entrySet()) {
int i = entry.getKey();
int t = 0;
for (int root : entry.getValue()) {
if (cnt[root] == 1) {
t += size[root];
}
}
if (mx < t || (mx == t && i < ans)) {
mx = t;
ans = i;
}
}
return ans;
}
private int find(int x) {
if (p[x] != x) {
p[x] = find(p[x]);
}
return p[x];
}
private void union(int a, int b) {
int pa = find(a);
int pb = find(b);
if (pa != pb) {
size[pb] += size[pa];
p[pa] = pb;
}
}
}
class Solution {
public:
vector<int> p;
vector<int> size;
int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
int n = graph.size();
p.resize(n);
size.resize(n);
for (int i = 0; i < n; ++i) p[i] = i;
fill(size.begin(), size.end(), 1);
vector<bool> clean(n, true);
for (int i : initial) clean[i] = false;
for (int i = 0; i < n; ++i) {
if (!clean[i]) continue;
for (int j = i + 1; j < n; ++j)
if (clean[j] && graph[i][j] == 1) merge(i, j);
}
vector<int> cnt(n, 0);
unordered_map<int, unordered_set<int>> mp;
for (int i : initial) {
unordered_set<int> s;
for (int j = 0; j < n; ++j)
if (clean[j] && graph[i][j] == 1) s.insert(find(j));
for (int e : s) ++cnt[e];
mp[i] = s;
}
int mx = -1, ans = 0;
for (auto& [i, s] : mp) {
int t = 0;
for (int root : s)
if (cnt[root] == 1)
t += size[root];
if (mx < t || (mx == t && i < ans)) {
mx = t;
ans = i;
}
}
return ans;
}
int find(int x) {
if (p[x] != x) p[x] = find(p[x]);
return p[x];
}
void merge(int a, int b) {
int pa = find(a), pb = find(b);
if (pa != pb) {
size[pb] += size[pa];
p[pa] = pb;
}
}
};
func minMalwareSpread(graph [][]int, initial []int) int {
n := len(graph)
p := make([]int, n)
size := make([]int, n)
clean := make([]bool, n)
for i := 0; i < n; i++ {
p[i], size[i], clean[i] = i, 1, true
}
for _, i := range initial {
clean[i] = false
}
var find func(x int) int
find = func(x int) int {
if p[x] != x {
p[x] = find(p[x])
}
return p[x]
}
union := func(a, b int) {
pa, pb := find(a), find(b)
if pa != pb {
size[pb] += size[pa]
p[pa] = pb
}
}
for i := 0; i < n; i++ {
if !clean[i] {
continue
}
for j := i + 1; j < n; j++ {
if clean[j] && graph[i][j] == 1 {
union(i, j)
}
}
}
cnt := make([]int, n)
mp := make(map[int]map[int]bool)
for _, i := range initial {
s := make(map[int]bool)
for j := 0; j < n; j++ {
if clean[j] && graph[i][j] == 1 {
s[find(j)] = true
}
}
for root, _ := range s {
cnt[root]++
}
mp[i] = s
}
mx, ans := -1, 0
for i, s := range mp {
t := 0
for root, _ := range s {
if cnt[root] == 1 {
t += size[root]
}
}
if mx < t || (mx == t && i < ans) {
mx, ans = t, i
}
}
return ans
}