LC 654. 最大二叉树

题目描述

这是 LeetCode 上的 654. 最大二叉树 ,难度为 中等

给定一个不重复的整数数组 nums

最大二叉树 可以用下面的算法从 nums 递归地构建:

  1. 创建一个根节点,其值为 nums 中的最大值。
  2. 递归地在最大值 左边 的 子数组前缀上 构建左子树。
  3. 递归地在最大值 右边 的 子数组后缀上 构建右子树。

返回 nums 构建的最大二叉树。

示例 1:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
输入:nums = [3,2,1,6,0,5]

输出:[6,3,5,null,2,0,null,null,1]

解释:递归调用如下所示:
- [3,2,1,6,0,5] 中的最大值是 6 ,左边部分是 [3,2,1] ,右边部分是 [0,5]
- [3,2,1] 中的最大值是 3 ,左边部分是 [] ,右边部分是 [2,1]
- 空数组,无子节点。
- [2,1] 中的最大值是 2 ,左边部分是 [] ,右边部分是 [1]
- 空数组,无子节点。
- 只有一个元素,所以子节点是一个值为 1 的节点。
- [0,5] 中的最大值是 5 ,左边部分是 [0] ,右边部分是 []
- 只有一个元素,所以子节点是一个值为 0 的节点。
- 空数组,无子节点。

示例 2:

1
2
3
输入:nums = [3,2,1]

输出:[3,null,2,null,1]

提示:

  • $1 <= nums.length <= 1000$
  • $0 <= nums[i] <= 1000$
  • nums 中的所有整数 互不相同

基本分析

根据题目描述,可知该问题本质是「区间求最值」问题(RMQ)。

而求解 RMQ 有多种方式:递归分治、有序集合/ST/线段树 和 单调栈。

其中递归分治做法复杂度为 $O(n^2)$,对本题来说可过;而其余诸如线段树的方式需要 $O(n\log{n})$ 的建树和单次 $O(\log{n})$ 的查询,整体复杂度为 $O(n\log{n})$;单调栈解法则是整体复杂度为 $O(n)$。


递归分治

设置递归函数 TreeNode build(int[] nums, int l, int r) 含义为从 nums 中的 $[l, r]$ 下标范围进行构建,返回构建后的头结点。

当 $l > r$ 时,返回空节点,否则在 $[l, r]$ 中进行扫描,找到最大值对应的下标 idx 并创建对应的头结点,递归构建 $[l, idx - 1]$ 和 $[idx + 1, r]$ 作为头节点的左右子树。

Java 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class Solution {
public TreeNode constructMaximumBinaryTree(int[] nums) {
return build(nums, 0, nums.length - 1);
}
TreeNode build(int[] nums, int l, int r) {
if (l > r) return null;
int idx = l;
for (int i = l; i <= r; i++) {
if (nums[i] > nums[idx]) idx = i;
}
TreeNode ans = new TreeNode(nums[idx]);
ans.left = build(nums, l, idx - 1);
ans.right = build(nums, idx + 1, r);
return ans;
}
}

C++ 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Solution {
public:
TreeNode* constructMaximumBinaryTree(vector<int>& nums) {
return build(nums, 0, nums.size() - 1);
}
TreeNode* build(vector<int>& nums, int l, int r) {
if (l > r) return nullptr;
int idx = l;
for (int i = l; i <= r; ++i) {
if (nums[i] > nums[idx]) idx = i;
}
TreeNode* ans = new TreeNode(nums[idx]);
ans->left = build(nums, l, idx - 1);
ans->right = build(nums, idx + 1, r);
return ans;
}
};

Python 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Solution:
def constructMaximumBinaryTree(self, nums: List[int]) -> Optional[TreeNode]:
return self.build(nums, 0, len(nums) - 1)

def build(self, nums, l, r):
if l > r:
return None
idx = l
for i in range(l, r + 1):
if nums[i] > nums[idx]:
idx = i
node = TreeNode(nums[idx])
node.left = self.build(nums, l, idx - 1)
node.right = self.build(nums, idx + 1, r)
return node

TypeScript 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
function constructMaximumBinaryTree(nums: number[]): TreeNode | null {
return build(nums, 0, nums.length - 1)
};
function build(nums: number[], l: number, r: number): TreeNode | null {
if (l > r) return null
let idx = l
for (let i = l; i <= r; i++) {
if (nums[i] > nums[idx]) idx = i
}
const ans = new TreeNode(nums[idx])
ans.left = build(nums, l, idx - 1)
ans.right = build(nums, idx + 1, r)
return ans
}

  • 时间复杂度:$O(n^2)$
  • 空间复杂度:忽略递归带来的额外空间开销,复杂度为 $O(1)$

线段树

抽象成区间求和问题后,涉及「单点修改」和「区间查询」,再结合节点数量为 $1e3$,可使用 build $4n$ 空间不带懒标记的线段树进行求解。

设计线段树节点 Node 包含属性:左节点下标 l、右节点下标 r 和当前区间 $[l, r]$ 所对应的最值 $val$。

构建线段树的过程为基本的线段树模板内容,而构建答案树的过程与递归分治过程类型(将线性找最值过程用线段树优化)。

Java 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Solution {
class Node {
int l, r, val;
Node (int _l, int _r) {
l = _l; r = _r;
}
}
void build(int u, int l, int r) {
tr[u] = new Node(l, r);
if (l == r) return ;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
void update(int u, int x, int v) {
if (tr[u].l == x && tr[u].r == x) {
tr[u].val = Math.max(tr[u].val, v);
return ;
}
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) update(u << 1, x, v);
else update(u << 1 | 1, x, v);
pushup(u);
}
int query(int u, int l, int r) {
if (l <= tr[u].l && tr[u].r <= r) return tr[u].val;
int mid = tr[u].l + tr[u].r >> 1, ans = 0;
if (l <= mid) ans = query(u << 1, l, r);
if (r > mid) ans = Math.max(ans, query(u << 1 | 1, l, r));
return ans;
}
void pushup(int u) {
tr[u].val = Math.max(tr[u << 1].val, tr[u << 1 | 1].val);
}
Node[] tr = new Node[4010];
int[] hash = new int[1010];
public TreeNode constructMaximumBinaryTree(int[] nums) {
int n = nums.length;
build(1, 1, n);
for (int i = 0; i < n; i++) {
hash[nums[i]] = i + 1;
update(1, i + 1, nums[i]);
}
return dfs(nums, 1, n);
}
TreeNode dfs(int[] nums, int l, int r) {
if (l > r) return null;
int val = query(1, l, r), idx = hash[val];
TreeNode ans = new TreeNode(val);
ans.left = dfs(nums, l, idx - 1);
ans.right = dfs(nums, idx + 1, r);
return ans;
}
}

C++ 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
struct Node {
int l, r, val;
Node(int _l, int _r) : l(_l), r(_r), val(0) {}
Node() : l(0), r(0), val(0) {}
};
class Solution {
public:
vector<Node> tr;
unordered_map<int, int> hash;
TreeNode* constructMaximumBinaryTree(vector<int>& nums) {
int n = nums.size();
tr.resize(4010);
build(1, 1, n);
for (int i = 0; i < n; ++i) {
hash[nums[i]] = i + 1;
update(1, i + 1, nums[i]);
}
return dfs(nums, 1, n);
}
void build(int u, int l, int r) {
tr[u] = Node(l, r);
if (l == r) return;
int mid = (l + r) >> 1;
build((u << 1), l, mid);
build((u << 1) + 1, mid + 1, r);
}
void update(int u, int x, int v) {
if (tr[u].l == x && tr[u].r == x) {
tr[u].val = max(tr[u].val, v);
return;
}
int mid = (tr[u].l + tr[u].r) >> 1;
if (x <= mid) update((u << 1), x, v);
else update((u << 1) + 1, x, v);
pushup(u);
}
int query(int u, int l, int r) {
if (l <= tr[u].l && tr[u].r <= r) return tr[u].val;
int mid = (tr[u].l + tr[u].r) >> 1, ans = 0;
if (l <= mid) ans = max(ans, query((u << 1), l, r));
if (r > mid) ans = max(ans, query((u << 1) + 1, l, r));
return ans;
}
void pushup(int u) {
tr[u].val = max(tr[(u << 1)].val, tr[(u << 1) + 1].val);
}
TreeNode* dfs(vector<int>& nums, int l, int r) {
if (l > r) return NULL;
int val = query(1, l, r), idx = hash[val];
TreeNode* ans = new TreeNode(val);
ans -> left = dfs(nums, l, idx - 1);
ans -> right = dfs(nums, idx + 1, r);
return ans;
}
};

Python 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Node:
def __init__(self, l=0, r=0, val=0):
self.l = l
self.r = r
self.val = val

class Solution:
def __init__(self):
self.tr = [Node() for _ in range(4010)]
self.hash = {}

def constructMaximumBinaryTree(self, nums):
n = len(nums)
self.build(1, 1, n)
for i, num in enumerate(nums):
self.hash[num] = i + 1
self.update(1, i + 1, num)
return self.dfs(nums, 1, n)

def build(self, u, l, r):
self.tr[u] = Node(l, r)
if l == r: return
mid = (l + r) >> 1
self.build(u << 1, l, mid)
self.build(u << 1 | 1, mid + 1, r)

def update(self, u, x, v):
if self.tr[u].l == x and self.tr[u].r == x:
self.tr[u].val = max(self.tr[u].val, v)
return
mid = (self.tr[u].l + self.tr[u].r) >> 1
if x <= mid: self.update(u << 1, x, v)
else: self.update(u << 1 | 1, x, v)
self.pushup(u)

def query(self, u, l, r):
if l <= self.tr[u].l and self.tr[u].r <= r: return self.tr[u].val
mid = (self.tr[u].l + self.tr[u].r) >> 1
ans = 0
if l <= mid: ans = max(ans, self.query(u << 1, l, r))
if r > mid: ans = max(ans, self.query(u << 1 | 1, l, r))
return ans

def pushup(self, u):
self.tr[u].val = max(self.tr[u << 1].val, self.tr[u << 1 | 1].val)

def dfs(self, nums, l, r):
if l > r: return None
val = self.query(1, l, r)
idx = self.hash[val]
node = TreeNode(val)
node.left = self.dfs(nums, l, idx - 1)
node.right = self.dfs(nums, idx + 1, r)
return node

TypeScript 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class TNode {
l = 0; r = 0; val = 0;
constructor (_l: number, _r: number) {
this.l = _l; this.r = _r;
}
}
const tr: TNode[] = new Array<TNode>(4010)
const hash: number[] = new Array<number>(1010)
function constructMaximumBinaryTree(nums: number[]): TreeNode | null {
const n = nums.length
build(1, 1, n)
for (let i = 0; i < n; i++) {
hash[nums[i]] = i + 1
update(1, i + 1, nums[i])
}
return dfs(nums, 1, n)
};
function build(u: number, l: number, r: number): void {
tr[u] = new TNode(l, r)
if (l == r) return
const mid = l + r >> 1
build(u << 1, l, mid)
build(u << 1 | 1, mid + 1, r)
}
function update(u: number, x: number, v: number): void {
if (tr[u].l == x && tr[u].r == x) {
tr[u].val = Math.max(tr[u].val, v)
return
}
const mid = tr[u].l + tr[u].r >> 1
if (x <= mid) update(u << 1, x, v)
else update(u << 1 | 1, x, v)
pushup(u)
}
function query(u: number, l: number, r: number): number {
if (l <= tr[u].l && tr[u].r <= r) return tr[u].val
let mid = tr[u].l + tr[u].r >> 1, ans = 0
if (l <= mid) ans = query(u << 1, l, r)
if (r > mid) ans = Math.max(ans, query(u << 1 | 1, l, r))
return ans
}
function pushup(u: number): void {
tr[u].val = Math.max(tr[u << 1].val, tr[u << 1 | 1].val)
}
function dfs(nums: number[], l: number, r: number): TreeNode {
if (l > r) return null
let val = query(1, l, r), idx = hash[val]
const ans = new TreeNode(val)
ans.left = dfs(nums, l, idx - 1)
ans.right = dfs(nums, idx + 1, r)
return ans
}

  • 时间复杂度:构建线段树复杂度为 $O(n\log{n})$;构造答案树复杂度为 $O(n\log{n})$。整体复杂度为 $O(n\log{n})$
  • 空间复杂度:$O(n)$

单调栈

更进一步,根据题目对树的构建的描述可知,nums 中的任二节点所在构建树的水平截面上的位置仅由下标大小决定。

不难想到可抽象为找最近元素问题,可使用单调栈求解。

具体的,我们可以从前往后处理所有的 $nums[i]$,若存在栈顶元素并且栈顶元素的值比当前值要小,根据我们从前往后处理的逻辑,可确定栈顶元素可作为当前 $nums[i]$ 对应节点的左节点,同时为了确保最终 $nums[i]$ 的左节点为 $[0, i - 1]$ 范围的最大值,我们需要确保在构建 $nums[i]$ 节点与其左节点的关系时,$[0, i - 1]$ 中的最大值最后出队,此时可知容器栈具有「单调递减」特性。基于此,我们可以分析出,当处理完 $nums[i]$ 节点与其左节点关系后,可明确 $nums[i]$ 可作为未出栈的栈顶元素的右节点。

一些细节:Java 容易使用 ArrayDeque 充当容器,但为与 TS 保存一致,两者均使用数组充当容器。

Java 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
class Solution {
static TreeNode[] stk = new TreeNode[1010];
public TreeNode constructMaximumBinaryTree(int[] nums) {
int he = 0, ta = 0;
for (int x : nums) {
TreeNode node = new TreeNode(x);
while (he < ta && stk[ta - 1].val < x) node.left = stk[--ta];
if (he < ta) stk[ta - 1].right = node;
stk[ta++] = node;
}
return stk[0];
}
}

C++ 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Solution {
public:
TreeNode* stk[1010];
TreeNode* constructMaximumBinaryTree(vector<int>& nums) {
int he = 0, ta = 0;
for (int x : nums) {
TreeNode* node = new TreeNode(x);
while (he < ta && stk[ta - 1]->val < x) node->left = stk[--ta];
if (he < ta) stk[ta - 1]->right = node;
stk[ta++] = node;
}
return stk[0];
}
};

Python 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Solution:
def constructMaximumBinaryTree(self, nums: List[int]) -> Optional[TreeNode]:
stk = [None] * 1010
he, ta = 0, 0
for x in nums:
node = TreeNode(x)
while he < ta and stk[ta - 1].val < x:
node.left = stk[ta - 1]
ta -= 1
if he < ta:
stk[ta - 1].right = node
stk[ta] = node
ta += 1
return stk[0]

TypeScript 代码:
1
2
3
4
5
6
7
8
9
10
11
const stk = new Array<TreeNode>(1010)
function constructMaximumBinaryTree(nums: number[]): TreeNode | null {
let he = 0, ta = 0
for (const x of nums) {
const node = new TreeNode(x)
while (he < ta && stk[ta - 1].val < x) node.left = stk[--ta]
if (he < ta) stk[ta - 1].right = node
stk[ta++] = node
}
return stk[0]
};

  • 时间复杂度:$O(n)$
  • 空间复杂度:$O(n)$

最后

这是我们「刷穿 LeetCode」系列文章的第 No.654 篇,系列开始于 2021/01/01,截止于起始日 LeetCode 上共有 1916 道题目,部分是有锁题,我们将先把所有不带锁的题目刷完。

在这个系列文章里面,除了讲解解题思路以外,还会尽可能给出最为简洁的代码。如果涉及通解还会相应的代码模板。

为了方便各位同学能够电脑上进行调试和提交代码,我建立了相关的仓库:https://github.com/SharingSource/LogicStack-LeetCode

在仓库地址里,你可以看到系列文章的题解链接、系列文章的相应代码、LeetCode 原题链接和其他优选题解。