import csv
import os
import sqlite3
from typing import List, Optional, TextIO

def run_query(db: str, db_query: str,
              args: Optional[tuple] = None) -> List[tuple]:
    """Return the results of running query db_query on database with name db.

    Optional query arguments, bound to the query's ``?`` placeholders, can
    be included in args. The cursor and connection are closed even when the
    query raises.
    """
    con = sqlite3.connect(db)
    try:
        cur = con.cursor()
        try:
            if args is None:
                cur.execute(db_query)
            else:
                cur.execute(db_query, args)
            return cur.fetchall()
        finally:
            cur.close()
    finally:
        # Without this, a failing execute() would leak the connection.
        con.close()

def create_precipitation_table(db: str, data_file: TextIO) -> None:
    """Populate the database db with the contents of data_file
    as follows:
    Create a table called precipitation with four columns:
    city (text), snow (real), total (integer), days (integer).

    data_file (csv file): every line contains one city,
    snowfall amount, total precipitation amount,
    and number of days per line. Blank lines are skipped. Extra fields
    are ignored, and a row with fewer than four fields raises IndexError.

    If any row fails to insert, none of the rows are committed. The
    connection is closed in every case.
    """
    con = sqlite3.connect(db)
    try:
        cur = con.cursor()

        # Create precipitation table
        cur.execute(
            """CREATE TABLE precipitation (
            city TEXT,
            snow REAL,
            total INTEGER,
            days INTEGER)""")

        # csv.reader yields [] for blank lines (e.g. a trailing empty line).
        # Skip those rather than failing with IndexError. Parse everything
        # up front so a malformed row aborts before any insert runs.
        rows = [(row[0], row[1], row[2], row[3])
                for row in csv.reader(data_file) if row]
        cur.executemany(
            "INSERT INTO precipitation VALUES (?, ?, ?, ?)", rows)

        # commit (save) changes
        con.commit()
        cur.close()
    finally:
        con.close()
      
    
def create_temp_table(db: str, temperature_file: TextIO) -> None:
    """Create a temperature table in the database db and populate it with the
    contents of temperature_file.

    The table has three columns: city (text), avgHigh (real), avgLow (real).
    temperature_file is read as CSV. The first three fields of each row are
    stored and blank lines are skipped. A row with fewer than three fields
    raises IndexError.

    If any row fails to insert, none of the rows are committed. The
    connection is closed in every case.
    """
    con = sqlite3.connect(db)
    try:
        cur = con.cursor()

        # Create the temperature table
        cur.execute('''CREATE TABLE temperature (
            city TEXT,
            avgHigh REAL,
            avgLow REAL)''')

        # csv.reader yields [] for blank lines; skip them rather than
        # failing with IndexError.
        rows = [(row[0], row[1], row[2])
                for row in csv.reader(temperature_file) if row]
        cur.executemany('INSERT INTO temperature VALUES (?, ?, ?)', rows)

        # commit the changes
        con.commit()
        cur.close()
    finally:
        con.close()

def create_geography_table(db: str, geography_file: TextIO) -> None:
    """Create a geography table in the database db and populate it with the
    contents of geography_file.

    The table has two columns: city (text), province (text).
    geography_file is read as CSV. The first two fields of each row are
    stored and blank lines are skipped. A row with fewer than two fields
    raises IndexError.

    If any row fails to insert, none of the rows are committed. The
    connection is closed in every case.
    """
    con = sqlite3.connect(db)
    try:
        cur = con.cursor()

        # Create the geography table
        cur.execute('''CREATE TABLE geography (
            city TEXT, province TEXT)''')

        # csv.reader yields [] for blank lines; skip them rather than
        # failing with IndexError.
        rows = [(row[0], row[1])
                for row in csv.reader(geography_file) if row]
        cur.executemany('INSERT INTO geography VALUES (?, ?)', rows)

        # commit the changes
        con.commit()
        cur.close()
    finally:
        con.close()

## Query functions below
def warm_and_dry(db: str) -> List[tuple]:
    """Return the cities from database db that have less than 150cm of snow
    and an average high above 10.

    Cities are matched by name across the precipitation and temperature
    tables. Each result row is a 1-tuple ``(city,)``.
    """
    high_floor = 10
    snow_ceiling = 150
    sql = """SELECT precipitation.city FROM 
               precipitation JOIN temperature
               ON precipitation.city = temperature.city
               WHERE temperature.avgHigh > ? AND precipitation.snow < ?
               """
    params = (high_floor, snow_ceiling)
    return run_query(db, sql, params)

def average_snow(db: str) -> List[tuple]:
    """Return the mean snowfall over the precipitation table of database db.

    The result is a single-row list of the form ``[(average,)]``.
    """
    avg_sql = """SELECT AVG(snow)
            FROM precipitation
            """
    rows = run_query(db, avg_sql)
    return rows

def precip_totals(db: str) -> List[tuple]:
    """Return the summed precipitation total and summed precipitation days
    across every row of the precipitation table in database db.

    The result is a single-row list of the form ``[(total_sum, days_sum)]``.
    """
    totals_sql = """SELECT SUM(total), SUM(days)
            FROM precipitation
            """
    result = run_query(db, totals_sql)
    return result
    
def avg_days_by_province(db: str) -> List[tuple]:
    """Return each province and the average number of precipitation days
    across its cities, as ``(province, avg_days)`` tuples sorted by province.

    Cities are matched by name between the precipitation and geography
    tables.
    """
    # SQL does not guarantee any row order for GROUP BY output, so sort
    # explicitly to make the result deterministic.
    query = """SELECT geography.province,
            AVG(precipitation.days)
            FROM precipitation JOIN geography
            ON precipitation.city = geography.city
            GROUP BY geography.province
            ORDER BY geography.province
            """

    return run_query(db, query)

if __name__ == '__main__':
    db = 'weather.db'
    # When True, rebuild weather.db from the CSV files before querying.
    make_tables = True
    if make_tables:
        if os.path.exists(db):
            os.remove(db)  # start from a fresh database file

        # The with-blocks close each file even if table creation raises.
        # newline='' is what the csv module requires for files it reads.
        # Set up the precipitation table.
        with open('precipitation.csv', newline='') as precipitation_file:
            create_precipitation_table(db, precipitation_file)

        # Create the temperature table in the same database.
        with open('temperature.csv', newline='') as temperature_file:
            create_temp_table(db, temperature_file)

        # Create the geography table in the same database.
        with open('geography.csv', newline='') as geography_file:
            create_geography_table(db, geography_file)

    print(warm_and_dry(db))
    print(average_snow(db))
    print(precip_totals(db))
    print(avg_days_by_province(db))
   