from pyspark.context import SparkContext

def apply_discount(t):
    if t[-3] == '25' and int(t[-2]) >= 2:
        t[-1] = float(t[-1]) * 0.95
    else:
        t[-1] = float(t[-1])
    return t

if __name__ == '__main__':
    sc = SparkContext('local')
    lines = sc.textFile('ch04_data_transactions.txt')
    
    gifts = []
    
    # 1. Klijentu koji je napravio najviše transakcija kao poklon dodjeljuje se proizvod ID=4
    
    transactions = lines.map(lambda line: line.split('#'))
    
    trans_by_cust = transactions.map(lambda t: (int(t[2]), t))
    count_by_cust = trans_by_cust.countByKey()
    
    best_cid = None
    best_count = None
    for cid in count_by_cust:
        if best_cid is None or best_count < count_by_cust[cid]:
            best_cid = cid
            best_count = count_by_cust[cid]
    
    gifts.append(['2015-03-30', '6:18 AM', best_cid, '4', '1', '0.0'])
    
    
    # 2. Klijentima koji su kupili >=2 proizvoda ID=25 obračunati 5% popusta
    
    trans_by_cust.mapValues(apply_discount)
    
    print(trans_by_cust.collect())
    
    
    # 3. Klijentima koji su kupili >=5 proizvoda ID=81 pokloniti proizvod ID=70
    # ne nuzno u jednoj transakciji
    
    tmp = trans_by_cust.mapValues(lambda t: int(t[4]) if t[3] == '81' else 0)
    # print(tmp.collect())
    num81_by_cust = tmp.reduceByKey(lambda x, y: x + y)
    # print(type(num81_by_cust))
    # print(num81_by_cust.collect())
    
    for cid, _ in num81_by_cust.filter(lambda p: p[1] >= 5).collect():
        gifts.append(['2015-03-30', '6:18 AM', cid, '70', '1', '0.0'])
    
    # na slajdu je varijanta ako je kupljeno bar 5 u jednoj istoj transakciji
    
    
    # 4. Klijentu koji je potrošio najviše novca poklon je proizvod ID=63
    
    tmp = trans_by_cust.mapValues(lambda t: float(t[-1]))
    price_by_cust = tmp.reduceByKey(lambda x, y: x + y)
    cid, price = sorted(price_by_cust.collect(), key=lambda p: p[1])[-1]
    print(cid, price)
    gifts.append(['2015-03-30', '6:18 AM', cid, '63', '1', '0.0'])
    
    print('Nagrade:')
    print(gifts)
    
    print('Transakcije nagradjenih klijenata:')
    for gift in gifts:
        cid = gift[2]
        print(cid, ':', trans_by_cust.lookup(cid))
        
    # upisivanje nagradnih transakcija
    
    gifts_rdd = sc.parallelize(gifts)
    transactions = transactions.union(gifts_rdd)
    print(transactions.count())