1
+ # Definition for a binary tree node.
2
+ import heapq
3
+ import sys
4
+ import unittest
5
+
6
+ # Read about enumerate in python
7
+ from collections import defaultdict
8
+ from typing import List
9
+
10
+ class TreeNode :
11
+ def __init__ (self , val = 0 , left = None , right = None ):
12
+ self .val = val
13
+ self .left = left
14
+ self .right = right
15
+
16
+ class MaximumSumBstInBinaryTree (unittest .TestCase ):
17
+
18
+ def maxSumBST2 (self , root : TreeNode ) -> int :
19
+ # returns {lo, hi, sum, valid}
20
+ def dfs (node ):
21
+ if not node : return (1e9 , - 1e9 , 0 , True )
22
+ ll , lh , ls , lv = dfs (node .left )
23
+ rl , rh , rs , rv = dfs (node .right )
24
+ v = lv and rv and node .val > lh and node .val < rl
25
+ s = ls + rs + node .val if v else - 1
26
+ self .ans = max (self .ans , s )
27
+ return (min (ll , node .val ), max (rh , node .val ), s , v )
28
+
29
+ self .ans = 0
30
+ dfs (root )
31
+ return self .ans
32
+
33
+ def maxSumBST (self , root : TreeNode ) -> int :
34
+ def helper (node : TreeNode ) -> (bool , int , int , int ): # isBST, sum, maxValue, minValue
35
+ if not node :
36
+ return True , 0 , - 1e9 , 1e9
37
+
38
+ isBSTLeft , leftSum , leftMaxValue , leftMinValue = helper (node .left )
39
+ isBSTRight , rightSum , rightMaxValue , rightMinValue = helper (node .right )
40
+
41
+ isBST = isBSTLeft and isBSTRight and leftMaxValue < root .val < rightMinValue
42
+ sumValue = leftSum + rightSum + node .val if isBST else - 1
43
+ maxValue = max (leftMaxValue , rightMaxValue , node .val )
44
+ minValue = min (leftMinValue , rightMinValue , node .val )
45
+ self .ans = max (self .ans , sumValue )
46
+ return isBST , sumValue , maxValue , minValue
47
+
48
+ self .ans = 0
49
+ isBST , sumValue , maxValue , minValue = helper (root )
50
+ return self .ans
51
+
52
+ def test_Leetcode (self ):
53
+ node1 = TreeNode (5 )
54
+ node2 = TreeNode (4 )
55
+ node3 = TreeNode (8 )
56
+ node4 = TreeNode (3 )
57
+ node5 = TreeNode (6 )
58
+ node6 = TreeNode (3 )
59
+ node1 .left = node2
60
+ node1 .right = node3
61
+ node2 .left = node4
62
+ node2 .right = None
63
+ node3 .left = node5
64
+ node3 .right = node6
65
+ self .assertEqual (7 , self .maxSumBST (node1 ))
66
+
67
+ if __name__ == '__main__' :
68
+ unittest .main ()
0 commit comments