class BTNode:
    '''A binary tree node: one data value plus optional left/right subtrees.'''

    def __init__(self, data, left=None, right=None):
        ''' (BTNode, object, BTNode, BTNode) -> NoneType

        Set up this node to hold data, with left and right as its
        children. A child of None means there is no subtree on that side.
        '''
        self.data = data
        self.left = left
        self.right = right

    def __eq__(self, other):
        ''' (BTNode, object) -> bool

        Return whether other is a node of exactly the same type as self
        that holds equal data and has equal left and right children.

        >>> BTNode(7).__eq__('seven')
        False
        >>> b1 = BTNode(7, BTNode(5))
        >>> b1.__eq__(BTNode(7, BTNode(5), None))
        True
        '''
        # An object of a different type is never equal; bail out before
        # touching attributes that other may not have.
        same_type = type(self) == type(other)
        if not same_type:
            return False
        # Comparing the child pairs as tuples applies == to each child,
        # which recurses down both subtrees.
        own_children = (self.left, self.right)
        other_children = (other.left, other.right)
        return self.data == other.data and own_children == other_children

    def __repr__(self):
        ''' (BTNode) -> str

        Return a string which, when evaluated, builds a BTNode equivalent
        to this one.

        >>> BTNode(1, BTNode(2), BTNode(3))
        BTNode(1, BTNode(2, None, None), BTNode(3, None, None))
        '''
        # !r applies repr() to each field, so child nodes are rendered
        # recursively and a missing child shows up as None.
        return 'BTNode({!r}, {!r}, {!r})'.format(self.data,
                                                 self.left,
                                                 self.right)

    def __str__(self, indent=''):
        ''' (BTNode) -> str

        Return a human-readable, sideways picture of the tree rooted at
        this node: right subtree first, then this node's data, then the
        left subtree. Every line is prefixed with indent, and each level
        of depth adds four more spaces.

        >>> b = BTNode(1, BTNode(2, BTNode(3)), BTNode(4))
        >>> print(b)
            4
        1
            2
                3
        <BLANKLINE>
        '''
        child_indent = indent + '    '
        above = ''
        below = ''
        if self.right:
            above = self.right.__str__(child_indent)
        if self.left:
            below = self.left.__str__(child_indent)
        own_line = '{}{}\n'.format(indent, str(self.data))
        return above + own_line + below


def insert(node, data):
    ''' (BTNode, object) -> BTNode

    Add data to the binary search tree rooted at node, unless an equal
    value is already stored there, and return the root of the resulting
    tree. A falsy node (such as None) stands for an empty tree, in which
    case a new single-node tree is returned.

    >>> b = BTNode(8)
    >>> b = insert(b, 4)
    >>> b = insert(b, 2)
    >>> b = insert(b, 6)
    >>> b = insert(b, 12)
    >>> b = insert(b, 14)
    >>> b = insert(b, 10)
    >>> print(b)
            14
        12
            10
    8
            6
        4
            2
    <BLANKLINE>
    '''
    # Empty spot reached: the new value becomes a leaf here.
    if not node:
        return BTNode(data)

    # Smaller values go left, larger ones go right; the child link is
    # re-attached because the recursive call may have created the subtree.
    if data < node.data:
        node.left = insert(node.left, data)
    elif data > node.data:
        node.right = insert(node.right, data)
    # Otherwise data is already in this node, so the tree is left as it is.

    return node


def delete(node, data):
    ''' (BTNode, object) -> BTNode

    Remove the node holding data, if there is one, from the binary search
    tree rooted at node, and return the root of the resulting tree (None
    when the tree ends up empty).

    >>> b = BTNode(8)
    >>> b = insert(b, 4)
    >>> b = insert(b, 2)
    >>> b = insert(b, 6)
    >>> b = insert(b, 12)
    >>> b = insert(b, 14)
    >>> b = insert(b, 10)
    >>> b = delete(b, 12)
    >>> print(b)
            14
        10
    8
            6
        4
            2
    <BLANKLINE>
    '''
    # Empty tree: nothing to remove.
    if node is None:
        return node

    # The target, if present, lies in the left subtree.
    if node.data > data:
        node.left = delete(node.left, data)
        return node

    # The target, if present, lies in the right subtree.
    if node.data < data:
        node.right = delete(node.right, data)
        return node

    # This node holds data. With at most one child, that child (possibly
    # None) simply takes this node's place.
    if node.left is None:
        return node.right
    if node.right is None:
        return node.left

    # Two children: overwrite this node's data with the largest value in
    # its left subtree, then remove that value from the left subtree.
    predecessor = _find_max(node.left)
    node.data = predecessor.data
    node.left = delete(node.left, node.data)
    return node

def _find_max(node: BTNode) -> BTNode:
    '''Return the rightmost node of the tree rooted at node.

    In a binary search tree this is the node with the largest data.
    node must not be None.
    '''
    rightmost = node
    # Keep stepping to the right child until there is none left.
    while rightmost.right:
        rightmost = rightmost.right
    return rightmost


if __name__ == '__main__':
    # When run as a script, execute every doctest in this module's
    # docstrings.
    from doctest import testmod
    testmod()
