####################
# import libraries #
####################

import sys 
import numpy as np 
from netCDF4 import Dataset
import matplotlib.pyplot as plt

############
# settings #
############

# Experiment and input locations
EXP = 'noresm2-mm-seaclim_hindcast'
IDIR = '/nird/datalake/NS9039K/users/ingo/Projects/SEACLIM/data/noresm/output/raw'
SDATE = 19931101              # start-date tag used in the NorESM directory and file names
MEMBERS = list(range(1, 11))  # ensemble members 1..10

# Variable names and climatology files (NorESM, then ERA5)
VNAME_NORESM = 'TREFHT'
CLIM_NORESM = 'noresm2-mm-seaclim_hindcast.cam.h0.clim.1993-2024.startmonth11.leadmonth1-64.mem1-10.nc'
VNAME_ERA5 = 't2m'
CLIM_ERA5 = 't2m_ERA5_f09_1993-2024_clim.nc'

# Unit conversion: added to a Kelvin value to obtain degrees Celsius
OFFSET_K2C = -273.15

# Analysis period, inclusive: YEAR1-MONTH1 through YEAR2-MONTH2
YEAR1, MONTH1 = 1993, 11
YEAR2, MONTH2 = 1999, 2

# 1-based grid indices of the analysed point (the readers subtract 1)
LONIND = 25   # corresponding to 30E
LATIND = 176  # corresponding to 75N

#############
# functions #
#############

def read_noresm_lonlat1(idir, exp, vname, sdate, members, year1, year2, month1, month2, lonind, latind):
    """Read a monthly NorESM time series at one grid point for every ensemble member.

    One monthly history file per member and month is opened from
    ``{idir}/{exp}/{exp}_{sdate}/{exp}_{sdate}_memNNN/atm/hist/``.

    Parameters
    ----------
    idir : str
        Root directory of the raw model output.
    exp : str
        Experiment name used in directory and file names.
    vname : str
        Variable to read (e.g. 'TREFHT').
    sdate : int or str
        Start-date tag used in directory and file names (e.g. 19931101).
    members : sequence of int
        Ensemble member numbers, formatted zero-padded to three digits.
    year1, month1 : int
        First year and month of the period (inclusive).
    year2, month2 : int
        Last year and month of the period (inclusive).
    lonind, latind : int
        1-based longitude and latitude indices of the grid point.

    Returns
    -------
    numpy.ndarray, shape (len(members), nmonth)
        Values with OFFSET_K2C added. Row i belongs to ``members[i]`` and
        column j to the j-th month of the period.
    """
    def fpath(member, year, month):
        member_dir = f'{idir}/{exp}/{exp}_{sdate}/{exp}_{sdate}_mem{member:0>3}/atm/hist/'
        return member_dir + f'{exp}_{sdate}_mem{member:0>3}.cam.h0.{year}-{month:0>2}.nc'

    nmonth = (year2 - year1 + 1) * 12 - (month1 - 1) - (12 - month2)
    data = np.zeros((len(members), nmonth))
    for imember, member in enumerate(members):
        imonth = 0  # 0-based column index within the period
        for year in range(year1, year2 + 1):
            m1 = month1 if year == year1 else 1
            m2 = month2 if year == year2 else 12
            for month in range(m1, m2 + 1):
                # 'with' closes the file even if the read raises.
                with Dataset(fpath(member, year, month), 'r') as nc:
                    # Time index 0 only: presumably each h0 file holds one monthly mean (TODO confirm).
                    # imember is already 0-based (from enumerate); the former
                    # 'imember-1' put member 1 into the last row.
                    data[imember, imonth] = nc[vname][0, latind - 1, lonind - 1] + OFFSET_K2C
                imonth += 1
    return data

def read_noresm_lonlat1_clim(exp, vname, lonind, latind, fpath_clim=None):
    """Read the NorESM climatology time series at one grid point.

    Parameters
    ----------
    exp : str
        Experiment name. Not used to locate the file; kept so existing calls
        keep working.
    vname : str
        Variable to read (e.g. 'TREFHT').
    lonind, latind : int
        1-based longitude and latitude indices of the grid point.
    fpath_clim : str, optional
        Climatology file to read. Defaults to the module setting CLIM_NORESM,
        which is the file name this function previously hard-coded.

    Returns
    -------
    array-like, shape (ntime,)
        All entries along the file's first dimension with OFFSET_K2C added, as
        returned by netCDF4 (possibly a masked array). According to the file
        name these are presumably lead months 1-64 — confirm against the file.
    """
    if fpath_clim is None:
        fpath_clim = CLIM_NORESM
    with Dataset(fpath_clim, 'r') as nc:
        clim = nc[vname][:, latind - 1, lonind - 1] + OFFSET_K2C
    return clim

def read_era5_lonlat1(vname, year1, year2, month1, month2, lonind, latind,
                      fpath_era5='t2m_ERA5_f09_199301-202412.nc', file_year1=1993):
    """Read an ERA5 monthly time series at one grid point.

    The input file is treated as a continuous monthly record whose first time
    step is January of ``file_year1``, so the requested period is one
    contiguous slice of the time axis.

    Parameters
    ----------
    vname : str
        Variable to read (e.g. 't2m').
    year1, month1 : int
        First year and month of the period (inclusive).
    year2, month2 : int
        Last year and month of the period (inclusive).
    lonind, latind : int
        1-based longitude and latitude indices of the grid point.
    fpath_era5 : str, optional
        Monthly ERA5 file to read.
    file_year1 : int, optional
        Year of the first (January) time step in ``fpath_era5``.

    Returns
    -------
    numpy.ndarray, shape (nmonth,)
        Float64 values with OFFSET_K2C added. Masked (missing) values become NaN.

    Raises
    ------
    ValueError
        If the period is reversed or is not fully covered by the file.
    """
    nmonth = (year2 - year1 + 1) * 12 - (month1 - 1) - (12 - month2)
    if nmonth < 0:
        raise ValueError(f'period end {year2}-{month2:02d} precedes start {year1}-{month1:02d}')
    start = (year1 - file_year1) * 12 + (month1 - 1)
    if start < 0:
        # A negative index would silently wrap around to the end of the record.
        raise ValueError(f'period starts before January {file_year1}, the first month in {fpath_era5}')
    with Dataset(fpath_era5, 'r') as nc:
        var = nc[vname]
        if start + nmonth > var.shape[0]:
            raise ValueError(f'period ends after the last month in {fpath_era5}')
        # One slice read instead of one read per month.
        series = var[start:start + nmonth, latind - 1, lonind - 1]
    return np.ma.filled(np.ma.asarray(series, dtype=float), np.nan) + OFFSET_K2C

def read_era5_lonlat1_clim(vname, year1, year2, month1, month2, lonind, latind, fpath_era5_clim=None):
    """Map the ERA5 monthly climatology at one grid point onto a period.

    The climatology file is indexed by calendar month (time index 0 = January,
    ..., 11 = December). Element i of the result is the climatological value
    for the i-th month of the period starting at (year1, month1).

    Parameters
    ----------
    vname : str
        Variable to read (e.g. 't2m').
    year1, month1 : int
        First year and month of the period (inclusive).
    year2, month2 : int
        Last year and month of the period (inclusive).
    lonind, latind : int
        1-based longitude and latitude indices of the grid point.
    fpath_era5_clim : str, optional
        Climatology file to read. Defaults to the module setting CLIM_ERA5,
        which is the file name this function previously hard-coded.

    Returns
    -------
    numpy.ndarray, shape (nmonth,)
        Float64 values with OFFSET_K2C added. Masked (missing) values become NaN.

    Raises
    ------
    ValueError
        If the period end precedes its start.
    """
    if fpath_era5_clim is None:
        fpath_era5_clim = CLIM_ERA5
    nmonth = (year2 - year1 + 1) * 12 - (month1 - 1) - (12 - month2)
    if nmonth < 0:
        raise ValueError(f'period end {year2}-{month2:02d} precedes start {year1}-{month1:02d}')
    with Dataset(fpath_era5_clim, 'r') as nc:
        # Read the 12 calendar-month values once instead of once per period month.
        clim12 = nc[vname][:12, latind - 1, lonind - 1]
    clim12 = np.ma.filled(np.ma.asarray(clim12, dtype=float), np.nan)
    # Calendar-month index (0 = January) of each month in the period.
    calendar_index = (month1 - 1 + np.arange(nmonth)) % 12
    return clim12[calendar_index] + OFFSET_K2C

################
# prepare data #
################

# ERA5 and NorESM climatological time series for the forecast period
# (lead-time dependent for NorESM, calendar-month based for ERA5)
clim_era5 = read_era5_lonlat1_clim(VNAME_ERA5, YEAR1, YEAR2, MONTH1, MONTH2, LONIND, LATIND)
clim_noresm = read_noresm_lonlat1_clim(EXP, VNAME_NORESM, LONIND, LATIND)

# ERA5 and NorESM forecast time series for the forecast period
data_era5 = read_era5_lonlat1(VNAME_ERA5, YEAR1, YEAR2, MONTH1, MONTH2, LONIND, LATIND)
data_noresm = read_noresm_lonlat1(IDIR, EXP, VNAME_NORESM, SDATE, MEMBERS,
                                  YEAR1, YEAR2, MONTH1, MONTH2, LONIND, LATIND)

# Calibrate: swap the NorESM climatology for the ERA5 climatology in every member.
# The 1-D climatologies broadcast along the member axis (axis 0). np.array(...,
# dtype=float) yields a plain float64 ndarray, as the original zero-filled buffer did.
data_noresm_calibrated = np.array(data_noresm - clim_noresm + clim_era5, dtype=float)

########
# plot #
########

# Font sizes used throughout the figure
SMALL_SIZE = 16
MEDIUM_SIZE = 18
BIGGER_SIZE = 20
plt.rcParams.update({
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': SMALL_SIZE,     # axes titles
    'axes.labelsize': MEDIUM_SIZE,    # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_SIZE,    # y tick labels
    'legend.fontsize': SMALL_SIZE,    # legend text
    'figure.titlesize': BIGGER_SIZE,  # figure-level title (suptitle)
})

# make axes: left = scatter of monthly values, right = box plot of distributions
fig, ax = plt.subplots(nrows=1, ncols=2, sharex=False, sharey=False, figsize=[12, 5])

# All series were shifted by OFFSET_K2C when read, so values and labels are in degC.
TEMP_UNIT = '°C'

# Scatter plot of the NorESM ensemble mean versus ERA5, one point per month.
ax[0].plot([-20, 10], [-20, 10], color='gray')  # 1:1 reference line
# linestyle='none' draws markers only; linestyle=None would mean "use the default (solid) line".
hunc, = ax[0].plot(data_era5, data_noresm.mean(axis=0), marker='.', markersize=10,
                   linestyle='none', color='black', label='uncalibrated')
hcal, = ax[0].plot(data_era5, data_noresm_calibrated.mean(axis=0), marker='*', markersize=8,
                   linestyle='none', color='red', label='calibrated')
ax[0].set_xlim(-20, 10)
ax[0].set_ylim(-20, 10)
ax[0].set_xlabel(f'SAT ERA5 ({TEMP_UNIT})')
ax[0].set_ylabel(f'SAT NorESM ({TEMP_UNIT})')
ax[0].set_aspect('equal', 'box')
ax[0].legend([hcal, hunc], ['calibrated', 'uncalibrated'], loc='upper left', frameon=False)

# Box plot of ERA5 and of NorESM individual members (all members and months pooled).
# NOTE: the 'tick_labels' keyword requires Matplotlib >= 3.9 (older versions use 'labels').
ax[1].boxplot([data_era5, data_noresm_calibrated.flatten(), data_noresm.flatten()],
              tick_labels=['ERA5', 'calibrated', 'uncalibrated'])
ax[1].set_ylabel(f'SAT ({TEMP_UNIT})')

# Save the plot. The period in the title follows the settings; the location label is
# hand-written and must be kept in sync with LONIND/LATIND.
fig.suptitle(f'SAT calibrated vs uncalibrated ([30E,75N], '
             f'{YEAR1}-{MONTH1:02d} to {YEAR2}-{MONTH2:02d})')
fig.savefig('seaclim_calibration_example_monthly.png', format='png', dpi=300, bbox_inches='tight')