Python Algorithms


from typing import Optional

class TreeNode:
    """Represents a node in a binary tree.

    Attributes:
        val: The value of the node.
        left: The left child node.
        right: The right child node.
    """
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def maxPathSum(root: Optional[TreeNode]) -> int:
    """Finds the maximum path sum in a binary tree.

    A path is any sequence of nodes in which each adjacent pair is joined by a
    parent-child edge and no node appears more than once. The path must contain
    at least one node, and it does not need to pass through the root.

    Args:
        root: The root node of the binary tree.

    Returns:
        The maximum sum of node values over all paths in the tree. Returns
        negative infinity (float('-inf')) if the tree is empty.
    """

    ans = float('-inf') # Initialize the maximum path sum to negative infinity.  This handles cases with all negative values.

    def dfs(node: Optional[TreeNode]) -> int:
        """Returns the best downward path sum that starts at `node`.

        This recursive helper visits the tree in post-order. As a side effect,
        it updates `ans` with the best path that bends at `node`, i.e. the path
        that joins the best left branch, `node` itself, and the best right branch.

        Args:
            node: The current node being visited.

        Returns:
            The maximum sum of a path that starts at `node` and goes down through
            at most one child (left or right, but not both). Only such a
            one-branch path can be extended upward by the parent; a path that
            used both branches would fork and would no longer be a valid path.
        """
        nonlocal ans  # Rebind the 'ans' variable from the enclosing scope.

        if node is None:  # Base case: an empty subtree contributes nothing.
            return 0

        # Recursively compute the best downward path sums of the two subtrees.
        # max(..., 0) drops a branch whose best sum is negative, because
        # extending the path into that branch could only lower the total.
        left = max(dfs(node.left), 0)
        right = max(dfs(node.right), 0)

        # Update the best path found so far with the path that passes
        # through the current node and uses both branches.
        ans = max(ans, node.val + left + right)

        # Return the best path sum starting at the current node that goes
        # down only one branch, so the parent can extend it.
        return node.val + max(left, right)

    dfs(root)  # Run the search; without this call, ans would stay at -inf.
    return ans
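To exercise the listing above, here is a minimal, self-contained driver sketch. It repeats the same post-order recursion under the name `max_path_sum` (with the helper actually invoked on the root before the result is returned) and adds `build_tree`, a hypothetical helper that turns a LeetCode-style level-order list, with `None` marking missing children, into `TreeNode` objects. Both names are assumptions made for this example, not part of the original code.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Binary tree node with the same shape as the class above."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list.

    None entries mark missing children, as in LeetCode's serialization.
    """
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # Attach the left child, if present.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # Attach the right child, if present.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def max_path_sum(root: Optional[TreeNode]) -> float:
    """Same algorithm as maxPathSum; returns -inf for an empty tree."""
    best = float("-inf")

    def gain(node: Optional[TreeNode]) -> int:
        nonlocal best
        if node is None:
            return 0
        # Drop a branch whose best downward sum is negative.
        left = max(gain(node.left), 0)
        right = max(gain(node.right), 0)
        # Best path that bends at this node (uses both branches).
        best = max(best, node.val + left + right)
        # Only one branch can be extended upward to the parent.
        return node.val + max(left, right)

    gain(root)
    return best


if __name__ == "__main__":
    print(max_path_sum(build_tree([1, 2, 3])))                        # 6: 2 -> 1 -> 3
    print(max_path_sum(build_tree([-10, 9, 20, None, None, 15, 7])))  # 42: 15 -> 20 -> 7
    print(max_path_sum(build_tree([-3])))                             # -3
```

The second call shows why the path does not need to include the root: the best path, 15 -> 20 -> 7, bends at node 20, and adding the root (-10) and its left child (9) would only lower the sum.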


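  • Docstrings for the class and both functions clearly describe their purpose, parameters, and return values.
  • Comments in the code explain the reasoning behind each step, highlighting how the algorithm avoids double-counting node values.
  • Variable names were improved for readability.
  • An explicit is None check improves clarity (although Python's implicit truthiness would also work here).
  • A more robust initialization of ans (negative infinity) handles trees in which every node value is negative.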
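To see why the negative-infinity initialization in the last point matters, the sketch below repeats the same recursion but makes the starting value of the running maximum a parameter. The `initial` argument and the name `max_path_sum` are hypothetical additions for illustration only. Starting from 0 reports a sum that no non-empty path achieves on an all-negative tree, whereas starting from float('-inf') returns the largest single node value. The floor of 0 on branch gains never makes the path empty, because the current node's own value is always added.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def max_path_sum(root: Optional[TreeNode], initial: float = float("-inf")) -> float:
    """Same recursion as maxPathSum; `initial` is a hypothetical knob for comparison."""
    best = initial

    def gain(node: Optional[TreeNode]) -> int:
        nonlocal best
        if node is None:  # explicit check; `if not node:` would behave the same here
            return 0
        left = max(gain(node.left), 0)
        right = max(gain(node.right), 0)
        # node.val is always included, so every candidate is a non-empty path.
        best = max(best, node.val + left + right)
        return node.val + max(left, right)

    gain(root)
    return best


# All values negative: the best path is the single node -1.
tree = TreeNode(-2, TreeNode(-1), TreeNode(-3))
print(max_path_sum(tree))             # -1 (correct)
print(max_path_sum(tree, initial=0))  # 0  (wrong: no non-empty path sums to 0)
```

The `if not node:` remark holds because TreeNode defines neither __bool__ nor __len__, so every instance is truthy and only None is falsy.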


