import os
import json
import shutil
import subprocess

# Must be set before keras/h5py first touch the HDF5 library.
os.environ['HDF5_USE_FILE_LOCKING']='FALSE'

from tensorflow import keras
from keras.models import load_model
from keras.preprocessing import image
from keras import backend as K  # required by the custom metric functions below

import requests
from bs4 import BeautifulSoup
import datetime as dt
import numpy as np
import pandas as pd
from PIL import Image
from glob import glob
import cv2

import matplotlib as mpl
mpl.use('Agg')  # headless backend: must be chosen before importing pyplot
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
plt.switch_backend('agg')

def normalize_y_pred(y_pred):
    """Turn a batch of class probabilities into one-hot argmax predictions."""
    num_classes = y_pred.shape[-1]
    winner = K.argmax(y_pred)
    return K.one_hot(winner, num_classes)

def class_true_positive(class_label, y_true, y_pred):
    """Per-sample 1.0/0.0 indicator that `class_label` is a true positive.

    A sample counts when both the label column and the one-hot prediction
    column are 1, i.e. their sum equals 2.
    """
    one_hot_pred = normalize_y_pred(y_pred)
    both_one = K.equal(y_true[:, class_label] + one_hot_pred[:, class_label], 2)
    return K.cast(both_one, K.floatx())

def class_precision(class_label, y_true, y_pred):
    """Precision for one class: true positives / predicted positives.

    K.epsilon() in the denominator guards against division by zero when the
    class is never predicted in the batch.
    """
    one_hot_pred = normalize_y_pred(y_pred)
    true_pos = K.sum(class_true_positive(class_label, y_true, one_hot_pred))
    pred_pos = K.sum(one_hot_pred[:, class_label]) + K.epsilon()
    return true_pos / pred_pos

def macro_precision(y_true, y_pred):
    """Macro-averaged precision: unweighted mean of the per-class precisions.

    Used as a custom metric when loading the trained model.
    """
    n_classes = y_pred.shape[-1]
    per_class = [class_precision(c, y_true, y_pred) for c in range(n_classes)]
    return K.sum(per_class) / K.cast(n_classes, K.floatx())

def make_keogram(szn,daystr):
    """Build and save the daily two-panel summary figure as a PNG.

    Upper panel: keogram — the vertical pixel column x=200 (image rows
    30..379, 350 pixels) of each cropped all-sky image, stacked
    left-to-right in time order.
    Lower panel: the six plot categories stacked as cumulative
    probability bands over the night.

    Parameters
    ----------
    szn : int
        Auroral season year (currently unused inside this function).
    daystr : str
        'YYYYMMDD' string locating the probability CSV and image directory.
    """
    # One row per image: UT timestamp plus per-class probabilities.
    df=pd.read_csv('/yourdirectory/'+daystr+'_pct.csv',header=0)
    path=sorted(glob('/yourdirectory/'+daystr+'/*_crop.jpg'))
    # df.UT holds YYYYMMDDHHMMSS timestamps; convert for the date x-axis.
    dt_arr=[dt.datetime.strptime(str(df.UT[i]),'%Y%m%d%H%M%S') for i in range(0,len(df))]
    x_lims=mdates.date2num(dt_arr)
    
    # keogram pixels: 350 image rows x one column per image x RGB
    keo_arr=np.zeros((350,len(path),3)).astype(np.uint8)
    # six stacked plot categories, filled below
    pct_arr=np.zeros((len(path),6))

    # Collapse the eight model classes into six plot bands.
    # NOTE(review): reads df.apm / df.apc, but make_json() writes columns
    # named 'am' / 'ac' — confirm producer/consumer column names agree.
    pct_arr[:,0]=df.arc
    pct_arr[:,1]=df.discrete
    pct_arr[:,2]=df.diffuse
    pct_arr[:,3]=np.array(df.apm)+np.array(df.apc)
    pct_arr[:,4]=np.array(df.clear)+np.array(df.cloud)
    pct_arr[:,5]=df.dd

    # Extract the central vertical slice of every image.
    # Assumes len(path) == len(df) — TODO confirm CSV always matches the dir.
    for i in range(0,len(path),1):
        path_i=path[i]
        im=np.array(Image.open(path_i))
        keo_arr[:,i,:]=im[30:380,200,:]

    fig=plt.figure(figsize=(9,5))
    plt.rcParams['font.size']=10
    plt.subplots_adjust(hspace=0.1)

#-------- upper panel: keogram -------------
    plt.subplot(2,1,1)
    X,Y=np.meshgrid(x_lims,np.arange(0,350,1))
    # pcolormesh has no direct RGB-image mode; pass per-cell RGB values
    # (scaled to 0-1) via the color= keyword instead of a scalar colormap.
    color_touple=keo_arr.reshape((keo_arr.shape[0]*keo_arr.shape[1],keo_arr.shape[2]))/255
    pc=plt.pcolormesh(X,Y,keo_arr[:,:,0],color=color_touple,shading='auto')
    pc.set_rasterized(True)
    plt.ylim([350,0])  # reversed axis: pixel row 0 is drawn at the top
    plt.ylabel('S-N [pix]')

    # Hour ticks: labelled every 2 h, minor ticks every hour.
    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
    plt.gca().xaxis.set_minor_locator(mdates.HourLocator(interval=1))
    plt.gca().xaxis.set_major_locator(mdates.HourLocator(interval=2))
#plt.gcf().autofmt_xdate()
    plt.tick_params(labelbottom=False)  # x labels only on the bottom panel

#-------- bottom panel: stacked class probabilities -------------
    plt.subplot(2,1,2)

    # Cumulative stacking: each band is drawn on top of the previous sums.
    plt.fill_between(dt_arr,np.zeros((len(path))),pct_arr[:,0],facecolor='rosybrown',label='Arc')    
    plt.fill_between(dt_arr,pct_arr[:,0],pct_arr[:,0]+pct_arr[:,1],facecolor='limegreen',label='Discrete')    
    plt.fill_between(dt_arr,pct_arr[:,0]+pct_arr[:,1],pct_arr[:,0]+pct_arr[:,1]+pct_arr[:,2],facecolor='green',label='Diffuse')    
    plt.fill_between(dt_arr,pct_arr[:,0]+pct_arr[:,1]+pct_arr[:,2],pct_arr[:,0]+pct_arr[:,1]+pct_arr[:,2]+pct_arr[:,3],facecolor='olive',label='Noisy aurora')    
    plt.fill_between(dt_arr,pct_arr[:,0]+pct_arr[:,1]+pct_arr[:,2]+pct_arr[:,3],pct_arr[:,0]+pct_arr[:,1]+pct_arr[:,2]+pct_arr[:,3]+pct_arr[:,4],facecolor='black',label='No aurora')    
    plt.fill_between(dt_arr,pct_arr[:,0]+pct_arr[:,1]+pct_arr[:,2]+pct_arr[:,3]+pct_arr[:,4],pct_arr[:,0]+pct_arr[:,1]+pct_arr[:,2]+pct_arr[:,3]+pct_arr[:,4]+pct_arr[:,5],facecolor='midnightblue',label='Dusk & Dawn')    

    plt.legend(loc='upper right',fontsize=8)

    plt.ylim([0,100])
    plt.xlim([dt_arr[0],dt_arr[-1]])
    # Night spans midnight, so both start and end day appear in the label.
    plt.xlabel('UT on '+dt_arr[0].strftime('%d')+' & '+dt_arr[-1].strftime('%d %b %Y')+' [HH:MM]')
    plt.ylabel('Probability [%]')

    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
    plt.gca().xaxis.set_minor_locator(mdates.HourLocator(interval=1))
    plt.gca().xaxis.set_major_locator(mdates.HourLocator(interval=2))
    plt.gcf().autofmt_xdate()

    # NOTE(review): no '/' before daystr and no 'keo_' prefix here, while the
    # main script copies '<data_dir>keo_<daystr>.png'. Paths were redacted for
    # publication — confirm against the deployed version.
    plt.savefig('/yourdirectory'+daystr+'.png',dpi=150,bbox_inches='tight')    

    plt.close()

def get_pct_arr(path, model):
    """Classify one image and return its class probabilities in percent.

    Returns a length-8 array ordered
    [arc, discrete, diffuse, apc, apm, clear, cloud, dd],
    i.e. the raw model outputs reindexed from their training order.
    """
    img = image.load_img(path, target_size=(128, 128))
    batch = np.expand_dims(image.img_to_array(img), axis=0) / 255.
    probs = np.array(model.predict(batch)[0] * 100)

    # Model output index for each position of the returned array:
    # arc=2, discrete=6, diffuse=5, apc=0, apm=1, clear=3, cloud=4, dd=7
    display_order = (2, 6, 5, 0, 1, 3, 4, 7)
    return np.array([probs[k] for k in display_order])

def make_json(szn, daystr):
    """Classify every cropped image of the night and write two outputs:
    a per-image probability CSV and a JSON summary of image counts.

    jpg filename should be YYYYMMDDHHMMSS.jpg so the 14-character
    timestamp can be sliced out of each path.
    """
    file_list = sorted(glob('./yourdirectory/*_crop.jpg'))
    total_images = len(file_list)

    if total_images == 0:
        # Nothing to classify: emit an all-zero summary for the requested day.
        summary_dict = {
            'Date': daystr,
            'Total': 0,
            'Observable': 0,
            'Auroral': 0
        }
    else:
        # 14-character timestamp sliced from each filename.
        ut_arr = [path_i[-23:-9] for path_i in file_list]

        one_day_probs = np.zeros((total_images, 8), dtype=float)
        observable_images = 0
        auroral_images = 0

        model = load_model('./model/resnet50_ski.h5',
                           custom_objects={'macro_precision': macro_precision})

        for i, img_path in enumerate(file_list):
            probs = get_pct_arr(img_path, model)
            one_day_probs[i] = probs

            aurora_pct = np.sum(probs[0:3])   # arc + discrete + diffuse
            # "Observable" when aurora plus clear-sky probability dominates.
            if aurora_pct + probs[5] >= 80:
                observable_images += 1
                if aurora_pct >= 80:
                    auroral_images += 1

        # NOTE(review): make_keogram() reads columns named 'apm'/'apc';
        # confirm the 'ac'/'am' names below against that consumer.
        df = pd.concat(
            [
                pd.Series(ut_arr, name='UT'),
                pd.DataFrame(
                    one_day_probs,
                    columns=['arc', 'discrete', 'diffuse', 'ac', 'am', 'clear', 'cloud', 'dd']
                ),
            ],
            axis=1,
        )
        df.to_csv('./yourdirectory/pct.csv', index=False, float_format='%.1f')

        # Free the model and TF graph before the next night's run.
        del model
        keras.backend.clear_session()

        summary_dict = {
            'Date': ut_arr[0][0:8],
            'Total': int(total_images),
            'Observable': int(observable_images),
            'Auroral': int(auroral_images)
        }

    with open('./yourdirectory/NumOfImgs.json', 'w') as f:
        json.dump(summary_dict, f, indent=4)


def make_video(szn, daystr):
    """Assemble the night's JPEGs into a web-playable H.264 MP4.

    Parameters
    ----------
    szn : int
        Auroral season year (unused here; kept for interface consistency).
    daystr : str
        'YYYYMMDD' string naming the image directory and output file.

    Side effects: writes '<dir><daystr>.mp4' and deletes the temporary
    mp4v-encoded intermediate.
    """
    path = sorted(glob('/yourdirectory/' + daystr + '/*.jpg'))
    tmpvideo = '/yourdirectory/' + daystr + '/tmp.mp4'
    rewritevideo = '/yourdirectory/' + daystr + '.mp4'
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')

    video = cv2.VideoWriter(tmpvideo, fourcc, 15, (400, 400))
    try:
        for frame_path in path:
            img = cv2.imread(frame_path)
            if img is None:
                # Unreadable/corrupt frame: skip instead of crashing the run.
                continue
            video.write(img)
    finally:
        video.release()

    # Re-encode mp4v -> H.264 for browser playback. List-form argv (no
    # shell=True) avoids quoting/injection issues; '-y' prevents ffmpeg from
    # blocking on an overwrite prompt in unattended runs; '-an' because the
    # cv2-written file has no audio stream (the old '-acodec libfaac' encoder
    # was removed from modern ffmpeg builds).
    command = ['ffmpeg', '-y', '-i', tmpvideo,
               '-movflags', 'faststart', '-vcodec', 'libx264', '-an',
               rewritevideo]
    subprocess.call(command)

    os.remove(tmpvideo)

# main --------------
# Process yesterday's images (presumably run once daily, e.g. from cron).
analysis_date=dt.datetime.now()-dt.timedelta(days=1)
yesterday_str=analysis_date.strftime('%Y%m%d')

# Months Jan-Apr belong to the previous year's season.
if analysis_date.month<=4:
    szn=analysis_date.year-1 # szn means auroral season (year); 2024 season is from sep 2024 to apr 2025
else:
    szn=analysis_date.year

# directory name is deleted for security reason
data_dir='/yourdirectory/'
www_dir='/yourdirectory/'

# NOTE(review): make_keogram() saves '<dir>'+daystr+'.png' while the copy below
# expects 'keo_'+daystr+'.png', and make_json() is defined but never called in
# this chunk. Paths were redacted — confirm against the deployed script.
make_keogram(szn,yesterday_str)
shutil.copyfile(data_dir+'keo_'+yesterday_str+'.png',www_dir+'keo/keo_'+yesterday_str+'.png')
make_video(szn,yesterday_str)
shutil.copyfile(data_dir+yesterday_str+'.mp4',www_dir+'mov/'+yesterday_str+'.mp4')
