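# Module-level functions that operate on binary trees made of BTNode objects.
# Several of them (find_max, insert, delete, find, ...) assume the tree is a
# binary search tree.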
import operator

from bt import BTNode


def height(node):
    '''(BTNode) -> int

    Return the height of the tree rooted at node: the number of nodes
    on a longest root-to-leaf path.  An empty tree (None) has height 0.

    >>> height(None)
    0
    >>> height(BTNode(5))
    1
    >>> height(BTNode(5, BTNode(3), BTNode(7)))
    2
    '''
    if not node:
        return 0
    left_height = height(node.left)
    right_height = height(node.right)
    # Count this node plus the taller of its two subtrees.
    return max(left_height, right_height) + 1


def find_max(node):
    ''' (BTNode) -> BTNode

    Return the node holding the maximum data in the BST rooted at node.

    Larger values sit to the right in a binary search tree, so the
    maximum is the node reached by following right children until
    there are none.  node must not be None.

    >>> find_max(BTNode(5, BTNode(3), BTNode(7)))
    BTNode(7, None, None)
    '''
    current = node
    while current.right:
        current = current.right
    return current


def insert(node, data):
    ''' (BTNode, object) -> BTNode

    Insert data into the BST rooted at node unless it is already present,
    and return the root of the resulting tree.

    When node is None a new single-node tree is created and returned;
    otherwise the existing root is returned.  Data equal to a value
    already in the tree is ignored.

    >>> b = BTNode(5)
    >>> b1 = insert(b, 3)
    >>> print(b1)
    5
        3
    <BLANKLINE>
    '''
    if not node:
        return BTNode(data)
    if data < node.data:
        node.left = insert(node.left, data)
    elif data > node.data:
        node.right = insert(node.right, data)
    # Neither smaller nor larger: data is already here, nothing to add.
    return node


def contains(node, value):
    ''' (BTNode, object) -> bool

    Return whether the binary tree rooted at node contains value.

    Every node may be examined, so this works for any binary tree,
    not only a binary search tree.

    >>> contains(None, 5)
    False
    >>> contains(BTNode(5, BTNode(7), BTNode(9)), 7)
    True
    '''
    # Depth-first search with an explicit stack: each node is checked
    # before its left subtree, and the left subtree before the right.
    to_visit = [node]
    while to_visit:
        current = to_visit.pop()
        if current is None:
            continue
        matched = current.data == value
        if matched:
            return matched
        # Push right first so the left child is popped (checked) first.
        to_visit.append(current.right)
        to_visit.append(current.left)
    return False


def bst_contains(node, value):
    ''' (BTNode, object) -> bool

    Return whether the binary search tree rooted at node contains value.

    Only the single branch that could hold value is followed at each
    level, so at most one node per level of the tree is examined.

    Assume node is None or the root of a BST.

    >>> bst_contains(None, 5)
    False
    >>> bst_contains(BTNode(7, BTNode(5), BTNode(9)), 5)
    True
    >>> bst_contains(BTNode(7, BTNode(5), BTNode(9)), 8)
    False
    '''
    assert node is None or isinstance(node, BTNode), (
        'Not a BTNode: {}'.format(node))
    if node is None:
        return False
    elif node.data > value:
        # value is smaller than this node, so it can only be on the left.
        return bst_contains(node.left, value)
    elif node.data < value:
        # value is larger than this node, so it can only be on the right.
        return bst_contains(node.right, value)
    else:
        return True


def delete(node, data):
    ''' (BTNode, object) -> BTNode

    Delete the node containing data from the BST rooted at node, if such
    a node exists, and return the root of the resulting tree.

    >>> b = BTNode(8)
    >>> b = insert(b, 4)
    >>> b = insert(b, 2)
    >>> b = insert(b, 6)
    >>> b = insert(b, 12)
    >>> b = insert(b, 14)
    >>> b = insert(b, 10)
    >>> b = delete(b, 12)
    >>> print(b)
            14
        10
    8
            6
        4
            2
    <BLANKLINE>
    >>> b = delete(b, 14)
    >>> print(b)
        10
    8
            6
        4
            2
    <BLANKLINE>
    '''
    # Empty tree: nothing to delete.
    if not node:
        return node
    # data, if present, lies in one subtree; delete it there and keep
    # this node as the root.
    if data < node.data:
        node.left = delete(node.left, data)
        return node
    if data > node.data:
        node.right = delete(node.right, data)
        return node
    # This node holds data.  With at most one child, that child (possibly
    # None) takes its place.
    if not node.left:
        return node.right
    if not node.right:
        return node.left
    # Two children: copy in the largest value of the left subtree (the
    # rightmost node there), then delete that value from the left subtree.
    predecessor = node.left
    while predecessor.right:
        predecessor = predecessor.right
    node.data = predecessor.data
    node.left = delete(node.left, node.data)
    return node


def find(node, data):
    ''' (BTNode, object) -> BTNode

    Return the node of the BST rooted at node whose data equals data,
    or None if there is no such node.

    >>> b = BTNode(5, BTNode(4))
    >>> find(b, 7) is None
    True
    >>> find(b, 4)
    BTNode(4, None, None)
    '''
    current = node
    # Walk down a single branch: left for smaller data, right otherwise,
    # until a match is found or the branch runs out.
    while current and not current.data == data:
        current = current.left if data < current.data else current.right
    return current


# Maps each supported operator symbol to the function that applies it.
# '+', '-', '*' and '/' are the documented operators; '//', '%' and '**'
# are accepted as well.
_BINARY_OPERATORS = {
    '+': operator.add,
    '-': operator.sub,
    '*': operator.mul,
    '/': operator.truediv,
    '//': operator.floordiv,
    '%': operator.mod,
    '**': operator.pow,
}


def evaluate(b):
    ''' (BTNode) -> float

    Evaluate the expression rooted at b.  If b is a leaf,
    return its float data.  Otherwise, evaluate b.left and
    b.right and combine them with the operator in b.data.

    Assume:  -- b is a binary tree
             -- interior nodes contain data in {'+', '-', '*', '/'}
             -- interior nodes always have two children
             -- leaves contain float data

    Raise ValueError if an interior node holds an unsupported operator.

    >>> b = BTNode(3.0)
    >>> evaluate(b)
    3.0
    >>> b = BTNode('*', BTNode(3.0), BTNode(4.0))
    >>> evaluate(b)
    12.0
    >>> evaluate(BTNode('-', BTNode(1.0), BTNode('-', BTNode(0.0), BTNode(2.0))))
    3.0
    '''
    if not b.left and not b.right:  # leaf (same test as is_leaf)
        return b.data
    # Apply the operator directly instead of building a string and calling
    # eval() on it.  eval() would run arbitrary code stored in b.data, fails
    # on results whose str() is not a literal (e.g. 'inf'), and mis-groups
    # '**' with a negative left operand ('-2.0**2.0' == -4.0).
    left_value = evaluate(b.left)
    right_value = evaluate(b.right)
    try:
        apply = _BINARY_OPERATORS[b.data]
    except KeyError:
        raise ValueError(
            'Unsupported operator: {!r}'.format(b.data)) from None
    return apply(left_value, right_value)


def parenthesize(b):
    ''' (BTNode) -> str

    Return a fully parenthesized string for the expression rooted at b.
    Leaf data is written as-is, and each operator together with its two
    operands is wrapped in exactly one pair of parentheses.

    Assume:  -- b is a binary tree
             -- interior nodes contain data in {'+', '-', '*', '/'}
             -- interior nodes always have two children
             -- leaves contain float data

    >>> b = BTNode(3.0)
    >>> print(parenthesize(b))
    3.0
    >>> b = BTNode('+', BTNode('*', BTNode(3.0), BTNode(4.0)), BTNode(7.0))
    >>> print(parenthesize(b))
    ((3.0*4.0)+7.0)
    '''
    if not b.left and not b.right:
        # Leaf: just the operand's text, no parentheses.
        return str(b.data)
    # Interior node: build "(left op right)" from the rendered subtrees.
    left_text = parenthesize(b.left)
    operator_text = str(b.data)
    right_text = parenthesize(b.right)
    return '(' + left_text + operator_text + right_text + ')'


def list_between(node, start, end):
    ''' (BTNode, object, object) -> list

    Return a Python list of all values in the binary search tree
    rooted at node that lie between start and end (inclusive),
    listed in the tree's in-order sequence.

    >>> list_between(None, 3, 13)
    []
    >>> b = BTNode(8)
    >>> b = insert(b, 4)
    >>> b = insert(b, 2)
    >>> b = insert(b, 6)
    >>> b = insert(b, 12)
    >>> b = insert(b, 14)
    >>> b = insert(b, 10)
    >>> list_between(b, 2, 3)
    [2]
    >>> L = list_between(b, 3, 11)
    >>> L.sort()
    >>> L
    [4, 6, 8, 10]
    '''
    collected = []

    def gather(subtree):
        # In-order walk that skips any subtree the BST ordering says
        # cannot hold a value inside [start, end].
        if subtree is None:
            return
        if subtree.data > start:
            gather(subtree.left)
        if start <= subtree.data <= end:
            collected.append(subtree.data)
        if subtree.data < end:
            gather(subtree.right)

    gather(node)
    return collected


def list_internal_between(node, start, end):
    ''' (BTNode, object, object) -> list

    Return a Python list of the data from all internal (non-leaf) nodes
    of the binary search tree rooted at node whose data lies between
    start and end, inclusive.

    >>> list_internal_between(None, 3, 13)
    []
    >>> b = BTNode(8)
    >>> b = insert(b, 4)
    >>> b = insert(b, 2)
    >>> b = insert(b, 6)
    >>> b = insert(b, 12)
    >>> b = insert(b, 14)
    >>> b = insert(b, 10)
    >>> L = list_internal_between(b, 3, 13)
    >>> L.sort()
    >>> L
    [4, 8, 12]
    '''
    found = []
    ancestors = []
    current = node
    # Iterative in-order traversal that prunes subtrees outside the range.
    while ancestors or current is not None:
        # Slide down the left spine, descending only where values >= start
        # may still exist.
        while current is not None:
            ancestors.append(current)
            current = current.left if current.data > start else None
        current = ancestors.pop()
        if start <= current.data <= end and (current.left or current.right):
            found.append(current.data)
        # Step right only where values <= end may still exist.
        current = current.right if current.data < end else None
    return found
            

def list_longest_path(node):
    ''' (BTNode) -> list

    Return the data along a longest root-to-leaf path of the tree rooted
    at node.  When both subtrees give paths of equal length, the path
    through the right subtree is used.

    >>> list_longest_path(None)
    []
    >>> list_longest_path(BTNode(5))
    [5]
    >>> list_longest_path(BTNode(5, BTNode(3, BTNode(2), None), BTNode(7)))
    [5, 3, 2]
    '''
    if node is None:
        return []
    branches = [list_longest_path(child) for child in (node.left, node.right)]
    # max() returns the first maximal item, so scanning right-then-left
    # keeps the right path on ties and takes the left only when longer.
    return [node.data] + max(reversed(branches), key=len)


def is_leaf(node):
    ''' (BTNode) -> bool

    Return whether node is a leaf, i.e. has neither a left nor a
    right child.

    >>> b = BTNode(1, BTNode(2))
    >>> is_leaf(b)
    False
    >>> is_leaf(b.left)
    True
    '''
    return not (node.left or node.right)


def inorder_visit(root, perform):
    ''' (BTNode, function) -> NoneType

    Call perform on every node of the binary tree rooted at root,
    visiting left subtree, then node, then right subtree.

    >>> b = BTNode(8)
    >>> b = insert(b, 4)
    >>> b = insert(b, 2)
    >>> b = insert(b, 6)
    >>> b = insert(b, 12)
    >>> b = insert(b, 14)
    >>> b = insert(b, 10)
    >>> def f(node): print(node.data)
    >>> inorder_visit(b, f)
    2
    4
    6
    8
    10
    12
    14
    '''
    pending = []
    current = root
    while pending or current is not None:
        if current is not None:
            # Defer this node until its whole left subtree is done.
            pending.append(current)
            current = current.left
        else:
            current = pending.pop()
            perform(current)
            current = current.right


def preorder_visit(root, perform):
    ''' (BTNode, function) -> NoneType

    Call perform on every node of the binary tree rooted at root,
    visiting node, then left subtree, then right subtree.

    >>> b = BTNode(8)
    >>> b = insert(b, 4)
    >>> b = insert(b, 2)
    >>> b = insert(b, 6)
    >>> b = insert(b, 12)
    >>> b = insert(b, 14)
    >>> b = insert(b, 10)
    >>> def f(node): print(node.data)
    >>> preorder_visit(b, f)
    8
    4
    2
    6
    12
    10
    14
    '''
    if root is None:
        return
    perform(root)
    preorder_visit(root.left, perform)
    preorder_visit(root.right, perform)


def postorder_visit(root, perform):
    ''' (BTNode, function) -> NoneType

    Call perform on every node of the binary tree rooted at root,
    visiting left subtree, then right subtree, then node.

    >>> b = BTNode(8)
    >>> b = insert(b, 4)
    >>> b = insert(b, 2)
    >>> b = insert(b, 6)
    >>> b = insert(b, 12)
    >>> b = insert(b, 14)
    >>> b = insert(b, 10)
    >>> def f(node): print(node.data)
    >>> postorder_visit(b, f)
    2
    6
    4
    10
    14
    12
    8
    '''
    if root is None:
        return
    postorder_visit(root.left, perform)
    postorder_visit(root.right, perform)
    # Both subtrees are finished, so this node comes last.
    perform(root)


if __name__ == '__main__':
    # Run the examples embedded in this module's docstrings as tests.
    from doctest import testmod
    testmod()
