import os
os.environ['HDF5_USE_FILE_LOCKING']='FALSE'
from keras.layers import Input,Flatten,Dense
from keras import optimizers
from keras.models import Sequential,Model
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet50 import ResNet50
import keras.backend as K
import numpy as np
from glob import glob
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.switch_backend('agg')

# evaluation functions
def normalize_y_pred(y_pred):
    """Harden a batch of probability rows into one-hot rows.

    Each row of *y_pred* is replaced by a one-hot vector with a 1 at the
    position of its largest value (the predicted class).
    """
    num_classes = y_pred.shape[-1]
    predicted_indices = K.argmax(y_pred)
    return K.one_hot(predicted_indices, num_classes)

def class_true_positive(class_label, y_true, y_pred):
    """Per-sample true-positive indicator (1.0/0.0) for one class.

    A sample counts as a true positive when both the ground-truth one-hot
    entry and the hardened prediction entry for *class_label* are 1 —
    i.e. their sum equals 2.
    """
    hardened = normalize_y_pred(y_pred)
    both_active = K.equal(y_true[:, class_label] + hardened[:, class_label], 2)
    return K.cast(both_active, K.floatx())

def class_precision(class_label, y_true, y_pred):
    """Precision for a single class: TP / predicted-positives.

    The denominator is epsilon-guarded so a class that was never predicted
    yields 0 instead of dividing by zero.
    """
    hardened = normalize_y_pred(y_pred)
    true_positives = K.sum(class_true_positive(class_label, y_true, hardened))
    predicted_positives = K.sum(hardened[:, class_label]) + K.epsilon()
    return true_positives / predicted_positives

def macro_precision(y_true, y_pred):
    """Macro-averaged precision: unweighted mean of per-class precisions.

    Suitable as a Keras metric (operates on symbolic tensors).
    """
    # int() guards against the last axis size arriving as a TF Dimension
    # object (common in this Keras/TF era), which is fragile under range()
    # and K.cast. The static class count must be known at graph build time.
    class_count = int(y_pred.shape[-1])
    per_class = [class_precision(i, y_true, y_pred) for i in range(class_count)]
    # Python sum() of tensors is graph-level addition; dividing by the plain
    # int class_count keeps the result a tensor without an extra cast.
    return sum(per_class) / class_count

# ---------------------------------------------------------------------------
# Dataset locations
# ---------------------------------------------------------------------------
# path to your dataset directory
base_dir = './dataset_kiruna'

# subdirectories holding one folder per class; edit these to match your layout
train_dir = os.path.join(base_dir, 'train_kiruna')
validation_dir = os.path.join(base_dir, 'validation_kiruna')

# Count the images through the same path variables the generators use, so the
# counts and the generators can never point at different directories.
num_of_train = len(glob(os.path.join(train_dir, '*', '*.jpg')))
num_of_test = len(glob(os.path.join(validation_dir, '*', '*.jpg')))

# ---------------------------------------------------------------------------
# Hyper parameters
# ---------------------------------------------------------------------------
batch_size = 128
# size of image (pix); originals are bigger but are downsampled as preprocessing
img_size = 128

# normalization of 8 bit color values to 0-1
train_datagen = ImageDataGenerator(rescale=1. / 255)
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(img_size, img_size),
    color_mode='rgb',
    batch_size=batch_size,
    class_mode='categorical')
validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=(img_size, img_size),
    color_mode='rgb',
    batch_size=batch_size,
    class_mode='categorical')

# ---------------------------------------------------------------------------
# Model: ImageNet-pretrained ResNet50 body + small softmax head
# ---------------------------------------------------------------------------
# 3 is the number of color channels (RGB)
input_tensor = Input(shape=(img_size, img_size, 3))

RN50 = ResNet50(include_top=False, weights='imagenet', input_tensor=input_tensor)

# 8 is the number of classes
top_model = Sequential()
top_model.add(Flatten(input_shape=RN50.output_shape[1:]))
top_model.add(Dense(8, activation='softmax'))

model = Model(RN50.input, top_model(RN50.output))

model.compile(
    loss='categorical_crossentropy',
    optimizer=optimizers.SGD(lr=1e-3, momentum=0.9),
    metrics=['acc', macro_precision])
model.summary()


print(num_of_train)
print(num_of_test)
# max(1, ...) keeps Keras from receiving 0 steps when the dataset holds fewer
# images than one batch (steps_per_epoch=0 raises at fit time).
history = model.fit_generator(
    train_generator,
    steps_per_epoch=max(1, num_of_train // batch_size),
    epochs=10,
    validation_data=validation_generator,
    validation_steps=max(1, num_of_test // batch_size))
# model.save does not create parent directories; make sure ./model exists
os.makedirs('./model', exist_ok=True)
# file name of your model
model.save('./model/resnet50.h5')


# visualization of training progress
def _plot_history_metric(hist, metric, ylabel, out_path):
    """Plot the train/validation curves for one history metric.

    Draws ``hist.history[metric]`` and ``hist.history['val_' + metric]``,
    labels the axes, and saves the figure to *out_path*.
    """
    plt.plot(hist.history[metric], label=metric)
    plt.plot(hist.history['val_' + metric], label='val_' + metric)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.grid()
    plt.legend()
    plt.savefig(out_path, bbox_inches='tight')
    plt.close()


# One figure per tracked metric; file names match the original script.
_plot_history_metric(history, 'acc', 'accuracy', 'acc_rn50.png')
_plot_history_metric(history, 'loss', 'loss', 'loss_rn50.png')
_plot_history_metric(history, 'macro_precision', 'macro precision', 'mp_rn50.png')
