LC 749. 隔离病毒
题目描述
这是 LeetCode 上的 749. 隔离病毒 ,难度为 困难。
病毒扩散得很快,现在你的任务是尽可能地通过安装防火墙来隔离病毒。
假设世界由 $m \times n$ 的二维矩阵 isInfected 组成,isInfected[i][j] == 0 表示该区域未感染病毒,而  isInfected[i][j] == 1 表示该区域已感染病毒。可以在任意 $2$ 个相邻单元之间的共享边界上安装一个防火墙(并且只有一个防火墙)。
每天晚上,病毒会从被感染区域向相邻未感染区域扩散,除非被防火墙隔离。现由于资源有限,每天你只能安装一系列防火墙来隔离其中一个被病毒感染的区域(一个区域或连续的一片区域),且该感染区域对未感染区域的威胁最大且 保证唯一 。
你需要努力使得最后有部分区域不被病毒感染,如果可以成功,那么返回需要使用的防火墙个数; 如果无法实现,则返回在世界被病毒全部感染时已安装的防火墙个数。
示例 1:

| 1 |  | 

示例 2:

| 1 |  | 
示例 3:1
2
3
4
5输入: isInfected = [[1,1,1,0,0,0,0,0,0],[1,0,1,0,1,1,1,1,1],[1,1,1,0,0,0,0,0,0]]
输出: 13
解释: 在隔离右边感染区域后,隔离左边病毒区域只需要 2 个防火墙。
提示:
- $m = isInfected.length$
- $n = isInfected[i].length$
- $1 <= m, n <= 50$
- isInfected[i][j]is either- 0or- 1
- 在整个描述的过程中,总有一个相邻的病毒区域,它将在下一轮严格地感染更多未受污染的方块
搜索模拟
根据题意,我们可以按天进行模拟,设计函数 getCnt 用于返回当天会被安装的防火墙数量,在 getCnt 内部我们会进行如下操作:
- 找出当天「对未感染区域的威胁最大」的区域,并将该区域进行隔离(将 $1$ 设置为 $-1$);
- 对其他区域,进行步长为 $1$ 的感染操作。
考虑如何实现 getCnt:我们需要以「连通块」为单位进行处理,因此每次的 getCnt 操作,我们先重建一个与矩阵等大的判重数组 vis,对于每个 $g[i][j] = 1$ 且未被 $vis[i][j]$ 标记为 True 的位置进行搜索,搜索过程使用 BFS 实现。
在 BFS 过程中,我们除了统计该连通块所需要的防火墙数量 $b$ 以外,还需要额外记录当前连通块中 $1$ 的点集 s1(简称为原集,含义为连通块的格子集合),以及当前连通块相邻的 $0$ 的点集 s2(简称为扩充集,含义为将要被感染的格子集合)。
根据题意,在单次的 getCnt 中,我们需要在所有连通块中取出其 s2 大小最大(对未感染区域的威胁最大)的连通块进行隔离操作,而其余连通块则进行扩充操作。
因此我们可以使用两个变量 max 和 ans 分别记录所有 s2 中的最大值,以及取得最大 s2 所对应连通块所需要的防火墙数量,同时需要使用两个数组 l1 和 l2 分别记录每个连通块对应的「原集」和「扩充集」,方便我们后续进行「隔离」和「感染」。
Java 代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66class Solution {
    int[][] g;
    int n, m, ans;
    int[][] dirs = new int[][]{{1,0},{-1,0},{0,1},{0,-1}};
    boolean[][] vis;
    int search(int _x, int _y, Set<Integer> s1, Set<Integer> s2) {
        int ans = 0;
        Deque<int[]> d = new ArrayDeque<>();
        vis[_x][_y] = true;
        d.addLast(new int[]{_x, _y});
        s1.add(_x * m + _y);
        while (!d.isEmpty()) {
            int[] info = d.pollFirst();
            int x = info[0], y = info[1];
            for (int[] di : dirs) {
                int nx = x + di[0], ny = y + di[1], loc = nx * m + ny;
                if (nx < 0 || nx >= n || ny < 0 || ny >= m || vis[nx][ny]) continue;
                if (g[nx][ny] == 1) {
                    s1.add(loc);
                    vis[nx][ny] = true;
                    d.addLast(new int[]{nx, ny});
                } else if (g[nx][ny] == 0) {
                    s2.add(loc);
                    ans++;
                }
            }
        }
        return ans;
    }
    int getCnt() {
        vis = new boolean[n][m];
        int max = 0, ans = 0;
        // l1: 每个连通块的点集 s2: 每个连通块的候选感染点集
        List<Set<Integer>> l1 = new ArrayList<>(), l2 = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                if (g[i][j] == 1 && !vis[i][j]) {
                    // s1: 当前连通块的点集 s2: 当前连通块的候选感染点集
                    Set<Integer> s1 = new HashSet<>(), s2 = new HashSet<>();
                    int b = search(i, j, s1, s2), a = s2.size();
                    if (a > max) {
                        max = a; ans = b;
                    }
                    l1.add(s1); l2.add(s2);
                }
            }
        }
        for (int i = 0; i < l2.size(); i++) {
            for (int loc : l2.get(i).size() == max ? l1.get(i) : l2.get(i)) {
                int x = loc / m, y = loc % m;
                g[x][y] = l2.get(i).size() == max ? -1 : 1;
            }
        }
        return ans;
    }
    public int containVirus(int[][] _g) {
        g = _g;
        n = g.length; m = g[0].length;
        while (true) {
            int cnt = getCnt();
            if (cnt == 0) break;
            ans += cnt;
        }
        return ans;
    }
}
C++ 代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
class Solution {
public:
    vector<vector<int>> g;
    int n, m, ans;
    vector<vector<int>> dirs = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}};
    vector<vector<bool>> vis;
    int search(int _x, int _y, unordered_set<int>& s1, unordered_set<int>& s2) {
        int ans = 0;
        deque<pair<int, int>> d;
        vis[_x][_y] = true;
        d.push_back({_x, _y});
        s1.insert(_x * m + _y);
        while (!d.empty()) {
            auto [x, y] = d.front(); d.pop_front();
            for (auto& di : dirs) {
                int nx = x + di[0], ny = y + di[1], loc = nx * m + ny;
                if (nx < 0 || nx >= n || ny < 0 || ny >= m || vis[nx][ny]) continue;
                if (g[nx][ny] == 1) {
                    s1.insert(loc);
                    vis[nx][ny] = true;
                    d.push_back({nx, ny});
                } else if (g[nx][ny] == 0) {
                    s2.insert(loc);
                    ans++;
                }
            }
        }
        return ans;
    }
    int getCnt() {
        vis = vector<vector<bool>>(n, vector<bool>(m, false));
        int maxv = 0, ans = 0;
        vector<unordered_set<int>> l1, l2;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                if (g[i][j] == 1 && !vis[i][j]) {
                    unordered_set<int> s1, s2;
                    int b = search(i, j, s1, s2), a = s2.size();
                    if (a > maxv) {
                        maxv = a; ans = b;
                    }
                    l1.push_back(s1); l2.push_back(s2);
                }
            }
        }
        for (int i = 0; i < l2.size(); i++) {
            for (int loc : (l2[i].size() == maxv ? l1[i] : l2[i])) {
                int x = loc / m, y = loc % m;
                g[x][y] = l2[i].size() == maxv ? -1 : 1;
            }
        }
        return ans;
    }
    int containVirus(vector<vector<int>>& _g) {
        g = _g;
        n = g.size(); m = g[0].size();
        ans = 0;
        while (true) {
            int cnt = getCnt();
            if (cnt == 0) break;
            ans += cnt;
        }
        return ans;
    }
};
Python 代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55class Solution:
    def search(self, x, y, s1, s2):
        ans = 0
        d = deque()
        self.vis[x][y] = True
        d.append((x, y))
        s1.add(x * self.m + y)
        while d:
            nx, ny = d.popleft()
            for dx, dy in self.dirs:
                nx_, ny_ = nx + dx, ny + dy
                loc = nx_ * self.m + ny_
                if nx_ < 0 or nx_ >= self.n or ny_ < 0 or ny_ >= self.m or self.vis[nx_][ny_]:
                    continue
                if self.g[nx_][ny_] == 1:
                    s1.add(loc)
                    self.vis[nx_][ny_] = True
                    d.append((nx_, ny_))
                elif self.g[nx_][ny_] == 0:
                    s2.add(loc)
                    ans += 1
        return ans
    def getCnt(self):
        self.vis = [[False] * self.m for _ in range(self.n)]
        max_area = 0
        ans = 0
        l1, l2 = [], []
        for i in range(self.n):
            for j in range(self.m):
                if self.g[i][j] == 1 and not self.vis[i][j]:
                    s1, s2 = set(), set()
                    b = self.search(i, j, s1, s2), len(s2)
                    if b[1] > max_area:
                        max_area = b[1]
                        ans = b[0]
                    l1.append(s1)
                    l2.append(s2)
        for i in range(len(l2)):
            for loc in (l1[i] if len(l2[i]) == max_area else l2[i]):
                x, y = divmod(loc, self.m)
                self.g[x][y] = -1 if len(l2[i]) == max_area else 1
        return ans
    def containVirus(self, g: List[List[int]]) -> int:
        self.g = g
        self.n = len(g)
        self.m = len(g[0])
        self.ans = 0
        self.dirs = [[1, 0], [-1, 0], [0, 1], [0, -1]]
        while True:
            cnt = self.getCnt()
            if cnt == 0: break
            self.ans += cnt
        return self.ans
TypeScript 代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64let g: number[][] = null
let n: number = 0, m: number = 0
let vis: boolean[][] = null
const dirs: number[][] = [[1,0],[-1,0],[0,1],[0,-1]]
function search(_x: number, _y: number, s1: Set<number>, s2: Set<number>): number {
    let he = 0, ta = 0, ans = 0
    let d: Array<number> = new Array<number>()
    s1.add(_x * m + _y)
    vis[_x][_y] = true
    d[ta++] = _x * m + _y
    while (he < ta) {
        const poll = d[he++]
        const x = Math.floor(poll / m), y = poll % m
        for (const di of dirs) {
            const nx = x + di[0], ny = y + di[1], loc = nx * m + ny
            if (nx < 0 || nx >= n || ny < 0 || ny >= m || vis[nx][ny]) continue
            if (g[nx][ny] == 1) {
                s1.add(loc)
                vis[nx][ny] = true
                d[ta++] = loc
            } else if (g[nx][ny] == 0) {
                s2.add(loc)
                ans++
            }
        }
    }
    return ans
}
function getCnt(): number {
    vis = new Array<Array<boolean>>(n)
    for (let i = 0; i < n; i++) vis[i] = new Array<boolean>(m).fill(false)
    let max = 0, ans = 0
    let l1: Array<Set<number>> = new Array<Set<number>>(), l2: Array<Set<number>> = new Array<Set<number>>()
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < m; j++) {
            if (g[i][j] == 1 && !vis[i][j]) {
                let s1 = new Set<number>(), s2 = new Set<number>()
                const b = search(i, j, s1, s2), a = s2.size
                if (a > max) {
                    max = a; ans = b
                }
                l1.push(s1); l2.push(s2)
            }
        }
    }
    for (let i = 0; i < l2.length; i++) {
        for (let loc of l2[i].size == max ? l1[i] : l2[i]) {
            const x = Math.floor(loc / m), y = loc % m
            g[x][y] = l2[i].size == max ? -1 : 1
        }
    }
    return ans
}
function containVirus(_g: number[][]): number {
    g = _g
    n = g.length; m = g[0].length
    let ans: number = 0
    while (true) {
        const cnt = getCnt()
        if (cnt == 0) break
        ans += cnt
    }
    return ans
};
- 时间复杂度:最多有 $n + m$ 天需要模拟,每天模拟复杂度 $O(n \times m)$,整体复杂度为 $O((n + m) \times nm)$
- 空间复杂度:$O(nm)$
最后
这是我们「刷穿 LeetCode」系列文章的第 No.749 篇,系列开始于 2021/01/01,截止于起始日 LeetCode 上共有 1916 道题目,部分是有锁题,我们将先把所有不带锁的题目刷完。
在这个系列文章里面,除了讲解解题思路以外,还会尽可能给出最为简洁的代码。如果涉及通解还会相应的代码模板。
为了方便各位同学能够电脑上进行调试和提交代码,我建立了相关的仓库:https://github.com/SharingSource/LogicStack-LeetCode 。
在仓库地址里,你可以看到系列文章的题解链接、系列文章的相应代码、LeetCode 原题链接和其他优选题解。
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!
