import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.ticker import FuncFormatter
import numpy as np
import pandas as pd
import os
import seaborn as sns
import json
from matplotlib.colors import LinearSegmentedColormap

def read_data(file_name: str):
    """Load a JSON object-of-objects into a numeric DataFrame.

    The outer JSON keys become the DataFrame's columns and the inner keys become
    its index; both axes are converted with ``pd.to_numeric`` and every cell
    value with ``float()``.

    Args:
        file_name: Path of the JSON file to read.

    Returns:
        The transposed, fully numeric DataFrame.
    """
    with open(file_name, 'r') as handle:
        raw = json.load(handle)

    # Coerce every leaf to float (values may arrive as numbers or numeric strings).
    numeric = {
        outer_key: {inner_key: float(cell) for inner_key, cell in row.items()}
        for outer_key, row in raw.items()
    }

    # orient='index' puts outer keys on rows; transposing moves them to columns.
    frame = pd.DataFrame.from_dict(numeric, orient='index').T
    frame.index = pd.to_numeric(frame.index)
    frame.columns = pd.to_numeric(frame.columns)
    return frame

def draw_results(results, filename, title_benefit, annot_benefit=False, figure_size=(10, 10)):
    """Render a benefit/profit grid as a diverging heatmap with contours and save it.

    Rows are plotted on the y axis (labelled PV capacity) and columns on the x
    axis (labelled ESS capacity). Both axes' labels are divided by 1000 and
    truncated to int. Cell values are divided by 1000 for display, and the
    colour scale is symmetric around zero.

    Args:
        results: Numeric DataFrame (rows: PV capacity, columns: ESS capacity).
            It is not modified; ``astype`` works on a copy.
        filename: Output image path passed to ``plt.savefig``.
        title_benefit: Plot title.
        annot_benefit: Whether seaborn writes each cell's value into the cell.
        figure_size: ``(width, height)`` in inches of the new figure.
    """
    df = results.astype(float)  # astype copies, so the caller's frame is untouched
    df.index = (df.index / 1000).map(int)
    df.columns = (df.columns / 1000).map(int)

    # Symmetric limits keep zero at the centre of the diverging colormap.
    max_scale = max(abs(df.min().min() / 1000), abs(df.max().max() / 1000))

    # Repeat the last column and last row. The contour grid below samples one
    # point per cell, and the padding lets its lines reach the far plot edges.
    # The axis limits set further down clip the extra heatmap cells.
    df[df.columns[-1] + 1] = df.iloc[:, -1]
    new_row = pd.DataFrame(index=[df.index[-1] + 1], columns=df.columns)
    for col in df.columns:
        new_row[col] = df[col].iloc[-1]
    df = pd.concat([df, new_row])

    X, Y = np.meshgrid(np.arange(df.shape[1]), np.arange(df.shape[0]))

    def fmt(x, pos):
        # Contours are computed on raw values; label them in the same /1000 units.
        return '{:.0f}'.format(x / 1000)

    cmap = sns.color_palette("coolwarm", as_cmap=True)
    fig = plt.figure(figsize=figure_size)
    ax = sns.heatmap(df / 1000, fmt=".1f", cmap=cmap, vmin=-max_scale, vmax=max_scale,
                     annot=annot_benefit)
    CS = ax.contour(X, Y, df, colors='black', alpha=0.5)
    ax.clabel(CS, inline=True, fontsize=10, fmt=FuncFormatter(fmt))
    plt.title(title_benefit)
    # bottom < top: the first row is drawn at the bottom. The padded row/column are clipped.
    plt.xlim(0, df.shape[1] - 1)
    plt.ylim(0, df.shape[0] - 1)
    plt.xlabel('ESS Capacity (MWh)')
    plt.ylabel('PV Capacity (MW)')
    plt.savefig(filename)
    # Release the figure; callers draw many plots in a loop and never show them.
    plt.close(fig)

def draw_cost(costs, filename, title_cost, annot_cost=False, figure_size=(10, 10)):
    """Render a cost grid as a heatmap with labelled contour lines and save it.

    Rows are plotted on the y axis (labelled PV capacity) and columns on the x
    axis (labelled ESS capacity). Both axes' labels are divided by 1000 and
    truncated to int. Cell values are divided by 1e6 for display.

    Args:
        costs: Numeric DataFrame (rows: PV capacity, columns: ESS capacity).
            It is not modified; ``astype`` works on a copy.
        filename: Output image path passed to ``plt.savefig``.
        title_cost: Plot title.
        annot_cost: Whether seaborn writes each cell's value into the cell.
        figure_size: ``(width, height)`` in inches of the new figure.
    """
    # Float rather than int: NaN cells don't raise, and large totals cannot
    # overflow a 32-bit default int (NumPy < 2 on Windows).
    df = costs.astype(float)
    df.index = (df.index / 1000).map(int)
    df.columns = (df.columns / 1000).map(int)

    # Repeat the last column and last row. The contour grid below samples one
    # point per cell, and the padding lets its lines reach the far plot edges.
    # The axis limits set further down clip the extra heatmap cells.
    df[df.columns[-1] + 1] = df.iloc[:, -1]
    new_row = pd.DataFrame(index=[df.index[-1] + 1], columns=df.columns)
    for col in df.columns:
        new_row[col] = df[col].iloc[-1]
    df = pd.concat([df, new_row])

    X, Y = np.meshgrid(np.arange(df.shape[1]), np.arange(df.shape[0]))

    def fmt(x, pos):
        # Contours are computed on raw values; label them in the same /1e6 units.
        return '{:.0f}'.format(x / 1000000)

    fig = plt.figure(figsize=figure_size)
    ax = sns.heatmap(df / 1000000, fmt=".1f", cmap='viridis', annot=annot_cost)
    CS = ax.contour(X, Y, df, colors='black', alpha=0.5)
    ax.clabel(CS, inline=True, fontsize=10, fmt=FuncFormatter(fmt))
    plt.title(title_cost)
    # bottom < top: the first row is drawn at the bottom. The padded row/column are clipped.
    plt.xlim(0, df.shape[1] - 1)
    plt.ylim(0, df.shape[0] - 1)
    plt.xlabel('ESS Capacity (MWh)')
    plt.ylabel('PV Capacity (MW)')
    plt.savefig(filename)
    # Release the figure; callers draw many plots in a loop and never show them.
    plt.close(fig)


def draw_overload(overload_cnt, filename, title_unmet, annot_unmet=False, figure_size=(10, 10),
                  total_steps=4 * 24 * 365):
    """Render the share of time steps without an overload as a heatmap and save it.

    Each cell count ``c`` becomes ``(total_steps - c) / total_steps``. The result
    is drawn on a white-to-blue scale fixed to [0, 1], with percentage labels on
    the colour bar and the contour lines.

    Args:
        overload_cnt: Numeric DataFrame of overload counts (rows: PV capacity,
            columns: ESS capacity). It is not modified.
        filename: Output image path passed to ``plt.savefig``.
        title_unmet: Plot title.
        annot_unmet: Whether seaborn writes each cell's percentage into the cell.
        figure_size: ``(width, height)`` in inches of the new figure.
        total_steps: Number of time steps the counts are taken over. The default
            4 * 24 * 365 = 35040 presumably means one year of 15-minute steps.
            NOTE(review): confirm this against whatever produces the counts.

    Raises:
        ValueError: If ``total_steps`` is not positive.
    """
    if total_steps <= 0:
        raise ValueError(f'total_steps must be positive, got {total_steps}')

    df = overload_cnt.astype(float)
    df = (total_steps - df) / total_steps
    df.index = (df.index / 1000).map(int)
    df.columns = (df.columns / 1000).map(int)

    # Repeat the last column and last row. The contour grid below samples one
    # point per cell, and the padding lets its lines reach the far plot edges.
    # The axis limits set further down clip the extra heatmap cells.
    df[df.columns[-1] + 1] = df.iloc[:, -1]
    new_row = pd.DataFrame(index=[df.index[-1] + 1], columns=df.columns)
    for col in df.columns:
        new_row[col] = df[col].iloc[-1]
    df = pd.concat([df, new_row])

    fig = plt.figure(figsize=figure_size)
    cmap = LinearSegmentedColormap.from_list("", ["white", "blue"])
    ax = sns.heatmap(df, fmt=".0%", cmap=cmap, vmin=0, vmax=1, annot=annot_unmet)

    cbar = ax.collections[0].colorbar
    cbar.set_ticks([0, 0.25, 0.5, 0.75, 1])
    # The formatter alone yields '0%', '25%', ... for the ticks set above.
    cbar.ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f'{x:.0%}'))

    X, Y = np.meshgrid(np.arange(df.shape[1]), np.arange(df.shape[0]))

    def fmt(x, pos):
        return '{:.0f}%'.format(x * 100)

    CS = ax.contour(X, Y, df, colors='black', alpha=0.5)
    ax.clabel(CS, inline=True, fontsize=10, fmt=FuncFormatter(fmt))

    # bottom < top: the first row is drawn at the bottom. The padded row/column are clipped.
    plt.xlim(0, df.shape[1] - 1)
    plt.ylim(0, df.shape[0] - 1)
    plt.title(title_unmet)
    plt.xlabel('ESS Capacity (MWh)')
    plt.ylabel('PV Capacity (MW)')
    plt.savefig(filename)
    # Release the figure; callers draw many plots in a loop and never show them.
    plt.close(fig)

with open('config.json', 'r') as f:
    js_data = json.load(f)

# Not referenced by the plotting code below.
data = pd.read_csv('combined_data.csv')
time_interval = js_data["time_interval"]["numerator"] / js_data["time_interval"]["denominator"]

# Simulation parameters from config.json. The plotting below only uses the
# annotation, title and figure-size settings.
pv_loss = js_data["pv"]["loss"]
pv_cost_per_kW = js_data["pv"]["cost_per_kW"]
pv_lifetime = js_data["pv"]["lifetime"]

ess_loss = js_data["ess"]["loss"]
ess_cost_per_kW = js_data["ess"]["cost_per_kW"]
ess_lifetime = js_data["ess"]["lifetime"]

grid_loss = js_data["grid"]["loss"]
sell_price = js_data["grid"]["sell_price"]
grid_capacity = js_data["grid"]["capacity"]

pv_begin = js_data["pv_capacities"]["begin"]
pv_end = js_data["pv_capacities"]["end"]
pv_groups = js_data["pv_capacities"]["groups"]

ess_begin = js_data["ess_capacities"]["begin"]
ess_end = js_data["ess_capacities"]["end"]
ess_groups = js_data["ess_capacities"]["groups"]

annot_unmet = js_data["annotated"]["unmet_prob"]
annot_benefit = js_data["annotated"]["benefit"]
annot_cost = js_data["annotated"]["cost"]

title_unmet = js_data["plot_title"]["unmet_prob"]
title_cost = js_data["plot_title"]["cost"]
title_benefit = js_data["plot_title"]["benefit"]

figure_size = (js_data["figure_size"]["length"], js_data["figure_size"]["height"])

directory = 'data/'
output_directory = 'plots'
# savefig does not create missing directories.
os.makedirs(output_directory, exist_ok=True)

# os.listdir order is arbitrary; sort so runs are reproducible.
file_list = sorted(
    f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))
)

costs_files = [f for f in file_list if f.endswith('costs.json')]
print(f'find costs files: {costs_files}')
overload_files = [f for f in file_list if f.endswith('overload_cnt.json')]
print(f'find coverage/unmet files: {overload_files}')
results_files = [f for f in file_list if f.endswith('results.json')]
print(f'find profit/benefit files: {results_files}')

costs_dfs = [read_data(os.path.join(directory, f)) for f in costs_files]
overload_dfs = [read_data(os.path.join(directory, f)) for f in overload_files]
results_dfs = [read_data(os.path.join(directory, f)) for f in results_files]


def _range_tag(df):
    """Return 'ess-<first>-<last>-pv-<first>-<last>' built from a grid's column/index bounds."""
    return (f'ess-{int(df.columns[0])}-{int(df.columns[-1])}'
            f'-pv-{int(df.index[0])}-{int(df.index[-1])}')


# Plot each kind independently. Zipping the three lists together would silently
# skip files whenever their counts differ.
for costs_df in costs_dfs:
    draw_cost(costs_df,
              os.path.join(output_directory, f'costs-{_range_tag(costs_df)}.png'),
              title_cost=title_cost,
              annot_cost=annot_cost,
              figure_size=figure_size)

for overload_df in overload_dfs:
    draw_overload(overload_df,
                  os.path.join(output_directory, f'overload-{_range_tag(overload_df)}.png'),
                  title_unmet=title_unmet,
                  annot_unmet=annot_unmet,
                  figure_size=figure_size)

for results_df in results_dfs:
    draw_results(results_df,
                 os.path.join(output_directory, f'results-{_range_tag(results_df)}.png'),
                 title_benefit=title_benefit,
                 annot_benefit=annot_benefit,
                 figure_size=figure_size)