# Offline tests script
#
# Script version: 1.0 
#
# Ver1.0: created by Mariko, 08.06.2026

##--- Import modules ---##
import xarray as xr
import sys
import os
import copy
import numpy as np


##--- Path to modules ---##
base = os.path.dirname(__file__)
sys.path.append(os.path.join(base, "rook/src/rook/utils"))


##--- Import decadal_fixes tool ---##
import decadal_fixes


##--- Functions ---##
def build_ds_id(ds):
    """Build a dot-separated dataset ID from a dataset's global attributes.

    The ID has the form
    ``project.activity.institution.source_id.experiment.member.table.variable.grid``,
    for example
    ``CMIP6.DCPP.NCC.NorCPM1.dcppB-forecast.s2025-r1i1p1f1.APmon.tas.gn``.

    Parameters
    ----------
    ds : dataset-like
        Object exposing ``attrs`` (mapping of global attributes) and
        ``data_vars`` (iterable of variable names). In this script it is
        called with an ``xarray.Dataset``.

    Returns
    -------
    str
        The assembled dataset ID.

    Raises
    ------
    ValueError
        If no data variable other than ``*_bnds`` variables exists, or if
        any of ``source_id``, ``experiment_id``, ``variant_label``,
        ``table_id`` or ``grid_label`` is missing.
    """
    attrs = ds.attrs

    # Attributes with a fallback value (example values in the comments).
    project = attrs.get("mip_era", "CMIP6")              # CMIP6
    activity = attrs.get("activity_id", "UNKNOWN")       # DCPP
    institution = attrs.get("institution_id", "UNKNOWN")  # NCC

    # Attributes without a fallback; validated below.
    source_id = attrs.get("source_id")        # NorCPM1
    experiment = attrs.get("experiment_id")   # dcppB-forecast
    subexp = attrs.get("sub_experiment_id")   # s2025
    variant = attrs.get("variant_label")      # r1i1p1f1
    table = attrs.get("table_id")             # APmon
    grid = attrs.get("grid_label")            # gn

    # Main variables: everything that is not a bounds variable.
    # str() keeps the filter working for non-string variable names.
    variables = [v for v in ds.data_vars if not str(v).endswith("_bnds")]

    if not variables:
        raise ValueError("No valid data variables found (only bounds?)")

    if len(variables) > 1:
        print("Warning: multiple variables found:", variables)

    variable = variables[0]

    # Validate variant_label itself rather than the composed member string:
    # with a sub-experiment present, a missing variant would otherwise be
    # formatted into e.g. "s2025-None" and slip through the check.
    required = {
        "source_id": source_id,
        "experiment_id": experiment,
        "variant_label": variant,
        "table_id": table,
        "grid_label": grid,
    }
    missing = [key for key, value in required.items() if value is None]
    if missing:
        raise ValueError(
            "Missing required CMIP attributes: " + ", ".join(missing)
        )

    # The sub-experiment is optional; without it the member is the variant.
    member = f"{subexp}-{variant}" if subexp else variant

    parts = [
        project,
        activity,
        institution,
        source_id,
        experiment,
        member,
        table,
        variable,
        grid,
    ]
    return ".".join(str(part) for part in parts)


def check_diff(ds_before, ds_after, name, var="tas"):
    """Print a sectioned report of the differences between two datasets.

    The report covers, in order: whole-dataset comparison, global
    attributes, coordinates, dimensions, data variables, indexes, the
    selected variable (attributes, shape, mean, dtype), the calendar stored
    in the time encoding, and finally the encodings/attributes of every
    variable and coordinate present in both datasets.

    Parameters
    ----------
    ds_before, ds_after : dataset-like
        Datasets before and after a fix (``xarray`` datasets in this script).
    name : str
        Label of the fix, printed in the report header.
    var : str, optional
        Name of the variable inspected in section 7 (default ``"tas"``).

    Returns
    -------
    None
        Everything is written to standard output.
    """

    def differs(a, b):
        """Return True when two attribute/encoding values are not equal."""
        try:
            if a == b:
                return False
            # NaN never equals itself; two NaNs are reported as "no change".
            return not (a != a and b != b)
        except ValueError:
            # Multi-element arrays cannot be used in a boolean context.
            return not np.array_equal(a, b)

    def union_keys(first, second):
        """Return the keys of both mappings in a stable, sorted order."""
        return sorted(set(first) | set(second), key=str)

    def time_encoding(ds):
        """Return the encoding of ``time``, or an empty dict if absent."""
        return ds["time"].encoding if "time" in ds else {}

    print(f"\n---Decadal fix: {name} ---")

    # 1. Whole dataset
    print('## 1. Whole dataset ##')
    print("changed (identical):", not ds_before.identical(ds_after))
    print("changed (equals):", not ds_before.equals(ds_after))
    print('')

    # 2. Global attributes
    print('## 2. Global attributes ##')
    before_attrs = ds_before.attrs
    after_attrs = ds_after.attrs

    ## check number of attributes ##
    if len(before_attrs) != len(after_attrs):
        print("changed: number of attributes")
        print("  attrs count:", len(before_attrs), "->", len(after_attrs))
    else:
        print('no change: number of attributes')
    print('')

    ## check attributes ##
    keys_before = set(before_attrs)
    keys_after = set(after_attrs)

    # added (sorted so that repeated runs produce the same log)
    added = sorted(keys_after - keys_before, key=str)
    if added:
        print(" + added attrs:")
        for k in added:
            print(f"    {k} = {after_attrs[k]}")
    else:
        print('no attributes added')
    print('')

    # removed
    removed = sorted(keys_before - keys_after, key=str)
    if removed:
        print(" - removed attrs:")
        for k in removed:
            print(f"    {k}")
    else:
        print('no attributes removed')
    print('')

    # changed
    changed = [
        k for k in before_attrs
        if k in after_attrs and differs(before_attrs[k], after_attrs[k])
    ]
    if changed:
        print("global attributes: changed")
        for k in changed:
            print(f"    {k}: {before_attrs[k]} -> {after_attrs[k]}")
    else:
        print('no attributes changed')
    print('')

    # 3. Coordinates
    print('## 3. Coordinates ##')
    before_coords = set(ds_before.coords)
    after_coords = set(ds_after.coords)
    print("coords BEFORE:", list(ds_before.coords))
    print("coords AFTER :", list(ds_after.coords))
    print('')

    for c in sorted(after_coords - before_coords, key=str):
        print(f"+ added coord: {c}")

    for c in sorted(before_coords - after_coords, key=str):
        print(f"- removed coord: {c}")

    if before_coords != after_coords:
        print("coords changed:", before_coords, "->", after_coords)
    else:
        print('no change')
    print('')

    # 4. Dimension
    print('## 4. Dimension ##')
    print("dims BEFORE:", list(ds_before.dims))
    print("dims AFTER :", list(ds_after.dims))
    print('')

    # 5. data_vars
    print('## 5. Data var ##')
    before_vars = set(ds_before.data_vars)
    after_vars = set(ds_after.data_vars)

    print("var BEFORE:", list(ds_before.data_vars))
    print("var AFTER :", list(ds_after.data_vars))

    for v in sorted(after_vars - before_vars, key=str):
        print(f"+ added var: {v}")

    for v in sorted(before_vars - after_vars, key=str):
        print(f"- removed var: {v}")
    print('')

    # 6. Index
    print('## 6. Indexes ##')
    before_indexes = ds_before.indexes
    after_indexes = ds_after.indexes
    before_index_keys = set(before_indexes)
    after_index_keys = set(after_indexes)

    for k in before_indexes:
        # A fix may drop an index; report it instead of failing on lookup.
        if k not in after_index_keys:
            print(f"Indexes: {k} removed")
            continue

        b = before_indexes[k]
        a = after_indexes[k]

        # Compared through their printed form, as shown in the report.
        if str(b) != str(a):
            print(f"Indexes: {k} changed")
            print("BEFORE:", b)
            print("AFTER :", a)
        else:
            print(f"{k}: no change")

    for k in after_indexes:
        if k not in before_index_keys:
            print(f"Indexes: {k} added")
    print('')

    # 7. Selected variable (default: tas)
    print(f'## 7. Variable: {var} ##')
    in_both = var in ds_before and var in ds_after
    if in_both:
        before_var = ds_before[var]
        after_var = ds_after[var]
        v_before = before_var.attrs
        v_after = after_var.attrs
        vb_keys = set(v_before)
        va_keys = set(v_after)

        for k in sorted(va_keys - vb_keys, key=str):
            print(f"+ added var attr: {k} = {v_after[k]}")

        for k in sorted(vb_keys - va_keys, key=str):
            print(f"- removed var attr: {k}")

        for k in sorted(vb_keys & va_keys, key=str):
            if differs(v_before[k], v_after[k]):
                print(f"* changed var attr: {k}")
                print(f"    {v_before[k]} -> {v_after[k]}")

        print(f"{var} shape:", before_var.shape, "->", after_var.shape)
        if before_var.shape != after_var.shape:
            print(f'{var} shape changed')
        else:
            print(f'{var} shape: no change')
        print('')

        # Best effort: the mean may not be computable for every variable.
        try:
            before_mean = float(before_var.mean())
            after_mean = float(after_var.mean())
        except Exception as exc:
            print(f"{var} mean: cannot compute ({exc!r})")
        else:
            print(f"{var} mean:", before_mean, "->", after_mean)
            if differs(before_mean, after_mean):
                print(f"* {var} mean changed: {before_mean} -> {after_mean}")
            else:
                print(f'{var} mean: no change')
            print('')
    else:
        print(f"{var}: not present in both datasets, skipping variable checks")
    print('')

    if in_both:
        print(f'{var} dtype:')
        if before_var.dtype != after_var.dtype:
            print('BEFORE: ', before_var.dtype)
            print('AFTER: ', after_var.dtype)
        else:
            print('no change')
        print('')

        print(f'{var} attributes:')
        attrs_changed = vb_keys != va_keys or any(
            differs(v_before[k], v_after[k]) for k in vb_keys
        )
        if attrs_changed:
            print('BEFORE: ', v_before)
            print('AFTER: ', v_after)
        else:
            print('no change')
        print('')

    # 8. Calendar
    print('## 8. Calendar ##')
    b_time_enc = time_encoding(ds_before)
    a_time_enc = time_encoding(ds_after)
    cal_before = b_time_enc.get("calendar")
    cal_after = a_time_enc.get("calendar")
    if cal_before != cal_after:
        print(f"* calendar changed: {cal_before} -> {cal_after}")
    else:
        print('no change')
        print('Calendar: ', cal_after)
    print('')

    # 9. encoding
    print('## 9. encoding ##')
    for v in ds_before.data_vars:
        if v in ds_after.data_vars:
            b_enc = ds_before[v].encoding
            a_enc = ds_after[v].encoding

            for k in union_keys(b_enc, a_enc):
                if differs(b_enc.get(k), a_enc.get(k)):
                    print(f"[ENCODING] {v}.{k}: {b_enc.get(k)} -> {a_enc.get(k)}")

    # time encoding
    for k in union_keys(b_time_enc, a_time_enc):
        if differs(b_time_enc.get(k), a_time_enc.get(k)):
            print(f"[TIME ENCODING] {k}: {b_time_enc.get(k)} -> {a_time_enc.get(k)}")

    # coords encoding
    for c in ds_before.coords:
        if c in ds_after.coords:
            b_enc = ds_before[c].encoding
            a_enc = ds_after[c].encoding

            for k in union_keys(b_enc, a_enc):
                if differs(b_enc.get(k), a_enc.get(k)):
                    print(f"[COORD ENCODING] {c}.{k}: {b_enc.get(k)} -> {a_enc.get(k)}")

    # attributes of every data variable present in both datasets
    for v in ds_before.data_vars:
        if v in ds_after.data_vars:
            b_attr = ds_before[v].attrs
            a_attr = ds_after[v].attrs

            for k in union_keys(b_attr, a_attr):
                if differs(b_attr.get(k), a_attr.get(k)):
                    print(f"[VAR ATTR] {v}.{k}: {b_attr.get(k)} -> {a_attr.get(k)}")


def show_diff(ds_before, ds_after, var="tas"):
    """Print a compact, tagged summary of the differences between two datasets.

    Only differences are printed, one line per finding, each prefixed with a
    tag such as ``[ATTR]``, ``[COORDS]``, ``[DIMS]``, ``[VARS]``, ``[INDEX]``,
    ``[CALENDAR]``, ``[ENCODING]`` or ``[VAR ATTR]``. Nothing is printed for
    parts that are equal.

    Parameters
    ----------
    ds_before, ds_after : dataset-like
        Datasets to compare (``xarray`` datasets in this script).
    var : str, optional
        Name of the variable whose mean is compared (default ``"tas"``).

    Returns
    -------
    None
        Everything is written to standard output.
    """

    def differs(a, b):
        """Return True when two attribute/encoding values are not equal."""
        try:
            if a == b:
                return False
            # NaN never equals itself; two NaNs are not reported as a change.
            return not (a != a and b != b)
        except ValueError:
            # Multi-element arrays cannot be used in a boolean context.
            return not np.array_equal(a, b)

    def union_keys(first, second):
        """Return the keys of both mappings in a stable, sorted order."""
        return sorted(set(first) | set(second), key=str)

    def time_encoding(ds):
        """Return the encoding of ``time``, or an empty dict if absent."""
        return ds["time"].encoding if "time" in ds else {}

    # 1. attrs
    for k in union_keys(ds_before.attrs, ds_after.attrs):
        b = ds_before.attrs.get(k)
        a = ds_after.attrs.get(k)
        if differs(b, a):
            print(f"[ATTR] {k}: {b} -> {a}")

    # 2. coords
    b_coords = set(ds_before.coords)
    a_coords = set(ds_after.coords)
    if b_coords != a_coords:
        print(f"[COORDS] keys: {b_coords} -> {a_coords}")

    # 3. dims (print the same name -> length mapping that is compared)
    if ds_before.sizes != ds_after.sizes:
        print(f"[DIMS] {dict(ds_before.sizes)} -> {dict(ds_after.sizes)}")

    # 4. data_vars
    b_vars = set(ds_before.data_vars)
    a_vars = set(ds_after.data_vars)
    if b_vars != a_vars:
        print(f"[VARS] {b_vars} -> {a_vars}")

    # 5. index keys
    b_idx = ds_before.indexes
    a_idx = ds_after.indexes
    if set(b_idx) != set(a_idx):
        print(f"[INDEX KEYS] {list(b_idx)} -> {list(a_idx)}")

    # 6. index values, compared through their printed form
    for k in sorted(set(b_idx) & set(a_idx), key=str):
        if str(b_idx[k]) != str(a_idx[k]):
            print(f"[INDEX] {k} changed")
            print("  BEFORE:", b_idx[k])
            print("  AFTER :", a_idx[k])

    # 7. mean of the selected variable (best effort)
    if var in ds_before and var in ds_after:
        try:
            b = float(ds_before[var].mean())
            a = float(ds_after[var].mean())
        except Exception as exc:
            # Report the failure instead of hiding it; KeyboardInterrupt and
            # SystemExit are deliberately not caught.
            print(f"[{var} mean] cannot compute ({exc!r})")
        else:
            if differs(b, a):
                print(f"[{var} mean] {b} -> {a}")

    # 8. calendar
    b_time_enc = time_encoding(ds_before)
    a_time_enc = time_encoding(ds_after)
    b_cal = b_time_enc.get("calendar")
    a_cal = a_time_enc.get("calendar")
    if b_cal != a_cal:
        print(f"[CALENDAR] {b_cal} -> {a_cal}")

    # 9. encoding
    for v in ds_before.data_vars:
        if v in ds_after.data_vars:
            b_enc = ds_before[v].encoding
            a_enc = ds_after[v].encoding

            for k in union_keys(b_enc, a_enc):
                if differs(b_enc.get(k), a_enc.get(k)):
                    print(f"[ENCODING] {v}.{k}: {b_enc.get(k)} -> {a_enc.get(k)}")

    # time encoding
    for k in union_keys(b_time_enc, a_time_enc):
        if differs(b_time_enc.get(k), a_time_enc.get(k)):
            print(f"[TIME ENCODING] {k}: {b_time_enc.get(k)} -> {a_time_enc.get(k)}")

    # coords encoding
    for c in ds_before.coords:
        if c in ds_after.coords:
            b_enc = ds_before[c].encoding
            a_enc = ds_after[c].encoding

            for k in union_keys(b_enc, a_enc):
                if differs(b_enc.get(k), a_enc.get(k)):
                    print(f"[COORD ENCODING] {c}.{k}: {b_enc.get(k)} -> {a_enc.get(k)}")

    # attributes of every data variable present in both datasets
    for v in ds_before.data_vars:
        if v in ds_after.data_vars:
            b_attr = ds_before[v].attrs
            a_attr = ds_after[v].attrs

            for k in union_keys(b_attr, a_attr):
                if differs(b_attr.get(k), a_attr.get(k)):
                    print(f"[VAR ATTR] {v}.{k}: {b_attr.get(k)} -> {a_attr.get(k)}")


##--- Path to data ---##
#-----------------------------------------
# This var needs to be updated every test!
#-----------------------------------------
model_name = 'NorCPM1'
#model_name = 'CMCC-CM2-SR5'
#model_name = 'EC-Earth3'
#model_name = 'HadGEM3-GC31-MM'
#model_name = 'MPI-ESM1-2-LR'


input_dir = '/nird/datalake/NS11071K/www/shared/c3s2/wp1/C3S2_375_ESGF_output_sample/'


# Sample data for each supported model: (directory, file).
# Directory names keep their trailing '/' because the paths below are built
# by plain string concatenation.
SAMPLE_DATA = {
    'NorCPM1': (
        'NorCPM1_dcppB-forecast_s2025-r1i1p1f1_v20260201/',
        'tas_APmon_NorCPM1_dcppB-forecast_s2025-r1i1p1f1_gn_202511-203512.nc',
    ),
    'CMCC-CM2-SR5': (
        'CMCC-CM2-SR5_dcppB-forecast_s2021-r40i1p1f1_v20260223/',
        'tas_Amon_CMCC-CM2-SR5_dcppB-forecast_s2021-r40i1p1f1_gn_202111-203112.nc',
    ),
    'EC-Earth3': (
        'EC-Earth3_dcppB-forecast_s2020-r1i1p1f1_v20260320/',
        'tas_Amon_EC-Earth3_dcppB-forecast_s2020-r1i1p1f1_gr_202011-202110.nc',
    ),
    'HadGEM3-GC31-MM': (
        'HadGEM3-GC31-MM_dcppB-forecast_s2024-r1i1p1f3_v20260318/',
        'tas_Amon_HadGEM3-GC31-MM_dcppB-forecast_s2024-r1i1p1f3_gn_202411-202912.nc',
    ),
    'MPI-ESM1-2-LR': (
        'MPI-ESM1-2-LR_dcppA-hindcast_s1999-r3i1p1f1_v20260223/',
        'tas_APmon_MPI-ESM1-2-LR_dcppA-hindcast_s1999-r3i1p1f1_gn_199911-200912.nc',
    ),
}

if model_name not in SAMPLE_DATA:
    print('no data')
    print('')
    # Non-zero status so that callers can tell the run did not happen.
    sys.exit(1)

dir_name, file_name = SAMPLE_DATA[model_name]


input_data = input_dir + dir_name + file_name

output_dir = 'decadal_fixes_out/' + dir_name
# exist_ok avoids the check-then-create race of exists() + makedirs().
os.makedirs(output_dir, exist_ok=True)


##--- Read data ---##
print('=========================')
print('   Apply decadal fixes   ')
print('=========================')
print('')

# The context manager closes the input file even when a fix, a comparison or
# the final write raises.
with xr.open_dataset(input_data) as ds:

    ##--- Call functions  ---##
    ### Create dataset ID ###
    ds_id = build_ds_id(ds)
    print('Dataset ID:', ds_id)
    print('')

    ### Apply decadal fixes ###
    # Each fix receives a deep copy so that the dataset used as the
    # comparison baseline stays untouched (presumably the fixes may modify
    # their input in place -- confirm against decadal_fixes).
    ds_fixed = decadal_fixes.apply_decadal_fixes(ds_id, copy.deepcopy(ds))

    ### Test each decadal fix ###
    # The fixes are chained: every step starts from the previous result.
    fix_steps = [
        ("calendar", decadal_fixes.decadal_fix_calendar),
        ("fix_1", decadal_fixes.decadal_fix_1),
        ("fix_2", decadal_fixes.decadal_fix_2),
        ("fix_3", decadal_fixes.decadal_fix_3),
        ("fix_4", decadal_fixes.decadal_fix_4),
        ("fix_5", decadal_fixes.decadal_fix_5),
    ]

    ds_previous = ds
    for fix_name, fix_func in fix_steps:
        ds_current = fix_func(ds_id, copy.deepcopy(ds_previous))
        check_diff(ds_previous, ds_current, fix_name)
        ds_previous = ds_current

    print('')
    print('=============')
    print('   Summary   ')
    print('=============')
    show_diff(ds, ds_fixed)
    print('')

    ##--- Save fixed data as NetCDF ---##
    print('===============================')
    print('   Save fixed data as NetCDF   ')
    print('===============================')

    # Written inside the with-block, i.e. while the source file is still
    # open, because the data may not have been read into memory yet.
    output_file = output_dir + file_name
    ds_fixed.to_netcdf(output_file)

    print('Save new data:')
    print(output_file)
    print('')
    print('')

