# Module-level functions for Tree. Alternatively, these could be implemented
# as Tree methods, but then they might not be appropriate for every Tree.
from tree import Tree
from csc148_queue import Queue

def arity(t):
    ''' (Tree) -> int

    Return the largest number of children held by any single node of
    Tree t (its maximum branching factor). A lone leaf has arity 0.

    >>> t = Tree(23)
    >>> arity(t)
    0
    >>> tn2 = Tree(2, [Tree(4), Tree(4.5), Tree(5), Tree(5.75)])
    >>> tn3 = Tree(3, [Tree(6), Tree(7)])
    >>> tn1 = Tree(1, [tn2, tn3])
    >>> arity(tn1)
    4
    '''
    # Start with this node's own fan-out, then let any wider subtree win.
    widest = len(t.children)
    for child in t.children:
        widest = max(widest, arity(child))
    return widest

def count(t):
    ''' (Tree) -> int

    Return how many nodes (the root plus all of its descendants) are in
    Tree t.

    >>> t = Tree(17)
    >>> count(t)
    1
    >>> t4 = descendents_from_list(Tree(17), [0, 2, 4, 6, 8, 10, 11], 4)
    >>> count(t4)
    8
    '''
    # Walk the tree with an explicit stack, tallying each node once.
    total = 0
    pending = [t]
    while pending:
        current = pending.pop()
        total += 1
        pending.extend(current.children)
    return total

def height(t):
    ''' (Tree) -> int

    Return the number of nodes on the longest root-to-leaf path of t,
    i.e. 1 + the number of edges on that path. A lone leaf has height 1.

    >>> t = Tree(13)
    >>> height(t)
    1
    >>> t = descendents_from_list(Tree(13), [0, 1, 3, 5, 7, 9, 11, 13], 3)
    >>> height(t)
    3
    '''
    # Base case: a leaf is a path of exactly one node.
    if not t.children:
        return 1
    # Otherwise this node sits on top of its tallest child subtree.
    return 1 + max(height(child) for child in t.children)

def leaf_count(t):
    ''' (Tree) -> int

    Return how many nodes of Tree t have no children. A lone root counts
    as a single leaf.

    >>> t = Tree(7)
    >>> leaf_count(t)
    1
    >>> t = descendents_from_list(Tree(7), [0, 1, 3, 5, 7, 9, 11, 13], 3)
    >>> leaf_count(t)
    6
    '''
    # A childless node is itself one leaf.
    if not t.children:
        return 1
    # An interior node contributes only the leaves found beneath it.
    return sum(leaf_count(child) for child in t.children)

def list_all(t):
    ''' (Tree) -> list

    Return a list of every value stored in t, in preorder (each node's
    value comes before the values of its subtrees).

    >>> t = Tree(0)
    >>> list_all(t)
    [0]
    >>> t = descendents_from_list(Tree(0), [1, 2, 3, 4, 5, 6, 7, 8], 3)
    >>> L = list_all(t)
    >>> L.sort()
    >>> L
    [0, 1, 2, 3, 4, 5, 6, 7, 8]
    '''
    collected = [t.value]
    # A leaf simply skips this loop, so no explicit base case is needed.
    for child in t.children:
        collected.extend(list_all(child))
    return collected

def list_leaves(t):
    ''' (Tree) -> list

    Return a list of the values held in the leaves (childless nodes) of
    t, ordered left to right.

    >>> t = Tree(0)
    >>> list_leaves(t)
    [0]
    >>> t = descendents_from_list(Tree(0), [1, 2, 3, 4, 5, 6, 7, 8], 3)
    >>> L = list_leaves(t)
    >>> L.sort() # so list is predictable to compare
    >>> L
    [3, 4, 5, 6, 7, 8]
    '''
    # Base case: a root with no children is itself a leaf.
    if not t.children:
        return [t.value]
    leaves = []
    for child in t.children:
        leaves.extend(list_leaves(child))
    return leaves

def list_interior(t):
    ''' (Tree) -> list

    Return a list of the values held in interior nodes of t, i.e. nodes
    that have at least one child, in preorder.

    >>> t = Tree(0)
    >>> list_interior(t)
    []
    >>> t = descendents_from_list(Tree(0), [1, 2, 3, 4, 5, 6, 7, 8], 3)
    >>> L = list_interior(t)
    >>> L.sort()
    >>> L
    [0, 1, 2]
    '''
    # A leaf is never interior, and neither is anything below it.
    if not t.children:
        return []
    return [t.value] + [value
                        for child in t.children
                        for value in list_interior(child)]

def list_if(t, p):
    ''' (Tree, function) -> list

    Return a list, in preorder, of the values in Tree t for which
    p(value) is true.

    >>> def p(v): return v > 4
    >>> t = descendents_from_list(Tree(0), [1, 2, 3, 4, 5, 6, 7, 8], 3)
    >>> L = list_if(t, p)
    >>> L.sort()
    >>> L
    [5, 6, 7, 8]
    >>> def p(v): return v % 2 == 0
    >>> L = list_if(t, p)
    >>> L.sort()
    >>> L
    [0, 2, 4, 6, 8]
    '''
    # Test this node first, then append whatever each subtree contributes.
    selected = [t.value] if p(t.value) else []
    for child in t.children:
        selected.extend(list_if(child, p))
    return selected
#    Equivalent: return list(filter(p, list_all(t)))

def list_below(t, n):
    ''' (Tree, int) -> list

    Return a list, in preorder, of the values in t from nodes whose path
    from the root has at most n edges (the root itself is at depth 0).
    A negative n yields an empty list.

    >>> t = Tree(0)
    >>> list_below(t, 0)
    [0]
    >>> t = descendents_from_list(Tree(0), [1, 2, 3, 4, 5, 6, 7, 8], 3)
    >>> L = list_below(t, 1)
    >>> L.sort()
    >>> L
    [0, 1, 2, 3]
    '''
    # Out of depth budget: skip this node and everything under it.
    if n < 0:
        return []
    within_reach = [t.value]
    # Each step down to a child spends one edge of the remaining budget.
    for child in t.children:
        within_reach.extend(list_below(child, n - 1))
    return within_reach

def contains_test_passer(t: Tree, test: 'function') -> bool:
    '''Return whether t contains a value that test(value) returns True for

    >>> t = descendents_from_list(Tree(0), [1, 2, 3, 4.5, 5, 6, 7.5, 8, 8.5], 4)
    >>> def greater_than_nine(n): return n > 9
    >>> contains_test_passer(t, greater_than_nine)
    False
    >>> def even(n): return n % 2 == 0
    >>> contains_test_passer(t, even)
    True
    '''
    return test(t.value) or any([test(c.value) for c in t.children])

def gather_lists(L):
    ''' (list-of-lists) -> list

    Return a new list holding the items of every sublist of L, in order,
    one sublist after another.

    >>> gather_lists([[1, 2], [3, 4, 5]])
    [1, 2, 3, 4, 5]
    >>> gather_lists([[6, 7], [8], [9, 10, 11]])
    [6, 7, 8, 9, 10, 11]
    '''
    return [item for sublist in L for item in sublist]


def descendents_from_list(t, L, arity):
    ''' (Tree, list, int) -> Tree

    Populate t's descendents from the values in L, filling them in level
    order (breadth first), with up to arity children per node, and
    return t. L itself is not modified. If arity < 1, no children are
    added.

    >>> descendents_from_list(Tree(0), [1, 2, 3, 4], 2)
    Tree(0, [Tree(1, [Tree(3), Tree(4)]), Tree(2)])
    '''
    # Walk L with an index instead of popping from the front of a copy:
    # list.pop(0) costs O(len(L)) each time, which made the fill quadratic.
    next_index = 0
    q = Queue()
    q.enqueue(t)
    # Every child created is enqueued, so the queue only runs dry before L
    # is used up when arity < 1 (no node ever receives children).
    while not q.is_empty():
        parent = q.dequeue()
        for _ in range(arity):
            if next_index == len(L):
                return t  # every value in L has been placed
            child = Tree(L[next_index])
            next_index += 1
            parent.children.append(child)
            q.enqueue(child)
    return t


if __name__ == '__main__':
    # Running this file directly checks every docstring example in the module.
    from doctest import testmod
    testmod()
