"""Simple n-ary Tree ADT"""


class TreeNode:
    """Node of an n-ary tree: a value plus an ordered list of child nodes."""

    def __init__(self: 'TreeNode',
                 value: object = None, children: list = None) -> None:
        """Create a node holding value with the given children.

        value -- payload stored at this node (default None)
        children -- optional iterable of child nodes. Its items are
            shallow-copied into a fresh list, so later changes to the
            caller's sequence do not affect this node. None or an empty
            iterable gives a leaf.
        """
        self.value = value
        # list() makes a private copy (no aliasing of the caller's list) and,
        # unlike slicing, also accepts tuples and generators while always
        # producing a list that callers can append to.
        self.children = list(children) if children else []

    def __repr__(self: 'TreeNode') -> str:
        """Return an unambiguous representation, e.g.
        TreeNode(1, [TreeNode(2, [])]).

        repr() is used for the value, not str(), so a string value shows its
        quotes and cannot be mistaken for a number or another identifier.
        """
        return 'TreeNode({!r}, {!r})'.format(self.value, self.children)


class Tree:
    """Bare-bones n-ary Tree ADT.

    A Tree is a reference to a root node, or None for an empty tree. Nodes
    must expose a ``value`` attribute and a ``children`` sequence, as
    TreeNode does. All traversals are iterative, so very deep trees do not
    hit Python's recursion limit. The node structure is assumed to be
    acyclic; a cycle would make the traversals loop forever.
    """

    def __init__(self: 'Tree', root: 'TreeNode' = None) -> None:
        """Create a new tree rooted at root (None means an empty tree)."""
        self.root = root

    def _nodes(self: 'Tree'):
        """Yield every node of the tree in pre-order: parent first, then its
        children left to right.

        Falsy entries (e.g. a None root or a None in a children list) are
        treated as empty subtrees and skipped.
        """
        stack = [self.root]
        while stack:
            node = stack.pop()
            if not node:
                continue
            yield node
            # Push right-to-left so the leftmost child is popped (visited) first.
            stack.extend(reversed(node.children))

    def __contains__(self: 'Tree', value: object) -> bool:
        """Return True if some node in the tree has a value equal to value.

        The search stops at the first match instead of visiting the rest of
        the tree.
        """
        # A generator (not a list) lets any() short-circuit on the first hit.
        return any(node.value == value for node in self._nodes())

    def arity(self: 'Tree') -> int:
        """Return the branching factor: the largest number of children of any
        node, or 0 for an empty tree.
        """
        return max((len(node.children) for node in self._nodes()), default=0)

    def count(self: 'Tree') -> int:
        """Return how many nodes are in this tree (0 for an empty tree)."""
        return sum(1 for _ in self._nodes())
            
if __name__ == '__main__':
    # Demo tree: 1 has children 2 and 3; 2 has four leaves, 3 has two.
    left = TreeNode(2, [TreeNode(v) for v in (4, 4.5, 5, 5.75)])
    right = TreeNode(3, [TreeNode(v) for v in (6, 7)])
    tree = Tree(TreeNode(1, [left, right]))

    # Each row pairs an output template with the query result it reports.
    report = [
        ("Arity: {}", tree.arity()),
        ("5 in T?: {}", 5 in tree),
        ("5.5 in T?: {}", 5.5 in tree),
        ("4 in T?: {}", 4 in tree),
        ("Node count T?: {}", tree.count()),
    ]
    for template, result in report:
        print(template.format(result))
    
