import pandas as pd
import re
import cartopy.feature as cf
#from multiprocessing import Process
import pickle
from matplotlib.colors import ListedColormap
import multiprocessing as mp
import numpy as np
import sys
import os
import h5py
import re
# import eccodes
# import cfgrib
import xarray as xr #cite as: Hoyer, S. & Hamman, J., (2017). xarray: N-D labeled Arrays and Datasets in Python. Journal of Open Research Software. 5(1), p.10. DOI: http://doi.org/10.5334/jors.148
import gc
import glob
import copy
#from sklearn.neighbors import BallTree
import matplotlib.pyplot as plt
import netCDF4 as nv
from scipy.interpolate import interp1d,interp2d,griddata
from scipy.interpolate import RectBivariateSpline as rbs
from scipy.interpolate import RegularGridInterpolator as rgi
from scipy.optimize import brentq, curve_fit, minimize_scalar, minimize, Bounds
from scipy.stats import binned_statistic_2d, binned_statistic_dd, binned_statistic

import cartopy.crs as ccrs
from matplotlib import rcParams

# Matplotlib defaults for every figure produced by this module: Computer Modern
# math text (mathtext.fontset='cm'), 24 pt text and tick labels, 1 pt axes and
# tick widths, and short inward-pointing ticks. Keys are applied in this order.
rcParams.update({
  'ps.useafm': True,
  'pdf.use14corefonts': False,
  'text.usetex': False,
  'font.sans-serif': ['cmr10', 'Times-Roman'],
  'font.weight': 'normal',
  'font.size': 24,
  'xtick.labelsize': 24,
  'ytick.labelsize': 24,
  'axes.labelsize': 24,
  'axes.linewidth': 1,
  'axes.unicode_minus': False,
  'xtick.minor.width': 1,
  'ytick.minor.width': 1,
  'xtick.major.width': 1,
  'ytick.major.width': 1,
  'xtick.minor.size': 4,
  'ytick.minor.size': 4,
  'xtick.major.size': 5,
  'ytick.major.size': 5,
  'xtick.direction': 'in',
  'ytick.direction': 'in',
  'mathtext.fontset': 'cm',
})


import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as colors
from matplotlib import gridspec
from matplotlib import ticker
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.ticker import NullFormatter
from collections import OrderedDict, Counter
from matplotlib.lines import Line2D
from matplotlib.patches import Patch, Rectangle, Ellipse
import matplotlib.patheffects as pe


#----------Make it print------------------
class Unbuffered(object):
   """Wrap a stream so that every write is flushed immediately.

   Any attribute not defined here is looked up on the wrapped stream.
   """
   def __init__(self, stream):
       self.stream = stream
   def write(self, data):
       # Pass through the wrapped stream's return value (the character count
       # for text streams) so the wrapper honours the write() contract.
       written = self.stream.write(data)
       self.stream.flush()
       return written
   def writelines(self, datas):
       self.stream.writelines(datas)
       self.stream.flush()
   def __getattr__(self, attr):
       # __getattr__ only runs when normal lookup fails. If `stream` itself is
       # missing (instance created without __init__, e.g. by copy/pickle),
       # resolving self.stream here would recurse without end.
       if attr == 'stream':
           raise AttributeError(attr)
       return getattr(self.stream, attr)

# Replace the process-wide stdout with the auto-flushing wrapper so that print
# output appears immediately.
sys.stdout = Unbuffered(stream=sys.stdout)


def convert_fields2load(pdata, fields_to_load):
  """Map generic field names onto short variable names when needed.

  If 'topo' is present in `pdata._pd.variables`, `fields_to_load` is returned
  unchanged (the very same object). Otherwise a new list is returned in which
  'longitude', 'latitude', 'height', 'pr' and 'temperature' become 'lon',
  'lat', 'z', 'prs' and 'T'. Every other name is kept as is.
  """
  if 'topo' in pdata._pd.variables:
    return fields_to_load

  aliases = {
    'longitude': 'lon',
    'latitude': 'lat',
    'height': 'z',
    'pr': 'prs',
    'temperature': 'T',
  }
  return [aliases.get(name, name) for name in fields_to_load]

def haversine(lon1,lat1,lon2,lat2):
  """Return the great-circle distance in km between points given in degrees.

  Inputs may be scalars or NumPy arrays; they are combined elementwise with
  NumPy broadcasting. The Earth is treated as a sphere of radius 6371.229 km.
  """
  # convert decimal degrees to radians
  lo1 = np.radians(lon1)
  lo2 = np.radians(lon2)
  la1 = np.radians(lat1)
  la2 = np.radians(lat2)

  # haversine formula
  dlon = lo2 - lo1
  dlat = la2 - la1
  a = np.sin(dlat/2)**2 + np.cos(la1) * np.cos(la2) * np.sin(dlon/2)**2
  # Floating-point rounding can push `a` just outside [0, 1] for (near-)
  # antipodal points, which would make sqrt/arcsin return NaN. Clamp it.
  a = np.clip(a, 0., 1.)
  c = 2. * np.arcsin(np.sqrt(a))
  r = 6371.229 # Radius of earth in km.
  return c * r

def select_nc_files(years_all=(np.arange(1940,2024,1)).astype(int),data_path="/home/lucie/LARA/data_assimilation/"):
  """Open the file ATCE_<year>.nc read-only for every year in `years_all`.

  Parameters
  ----------
  years_all : iterable of int
    Years to open; each one becomes a key of the returned dict.
  data_path : str
    Directory holding the files. It may be given with or without a trailing
    path separator.

  Returns
  -------
  dict
    Maps year -> open `netCDF4.Dataset`. The caller is responsible for
    closing them.

  Raises
  ------
  Whatever `netCDF4.Dataset` raises (e.g. for a missing file). Before
  re-raising, every dataset opened so far is closed so no handles leak.
  """
  d_output={}
  try:
    for year in years_all:
      # os.path.join gives the same path as plain concatenation when data_path
      # ends in a separator, and stays correct when it does not.
      d_output[year] = nv.Dataset(os.path.join(data_path, "ATCE_%i.nc" %(year)),"r")
  except BaseException:
    for opened in d_output.values():
      opened.close()
    raise

  return d_output

def compute_surface_area(grid_xlon,grid_ylat):
  """Approximate the area in m^2 of each cell of a regular lon/lat grid.

  `grid_xlon` and `grid_ylat` are NumPy arrays of cell edges in degrees. The
  result has shape (len(grid_ylat)-1, len(grid_xlon)-1) and the dtype of the
  edge midpoints. Cell (j, i) is treated as a rectangle:
    - its meridional side runs along longitude grid_xlon[i] from grid_ylat[j]
      to grid_ylat[j+1];
    - its zonal side runs along latitude grid_ylat[j+1] from grid_xlon[i] to
      grid_xlon[i+1].
  """
  # The midpoint mesh only fixes the output's shape and dtype.
  mid_lon, _mid_lat = np.meshgrid((grid_xlon[1:]+grid_xlon[:-1])/2.,
    (grid_ylat[1:]+grid_ylat[:-1])/2)
  areas = np.zeros_like(mid_lon)

  edges_lon = np.copy(grid_xlon)
  edges_lat = np.copy(grid_ylat)
  lat_low, lat_high = edges_lat[:-1], edges_lat[1:]

  for col, (lon_a, lon_b) in enumerate(zip(edges_lon[:-1], edges_lon[1:])):
    side_meridional = haversine(lon_a,lat_low,lon_a,lat_high)*1000.
    side_zonal = haversine(lon_a,lat_high,lon_b,lat_high)*1000.
    areas[:,col] = side_meridional*side_zonal
  return areas

def kg_to_mm_water(water_g, surface_area_m):
  """Convert a per-particle water amount into an equivalent column depth in mm.

  Computes water_g / 1000 * (5.09256513E18 / 6.0e6) / 1000 / surface_area_m * 1000,
  assuming 1 kg of water occupies 1 L. Per the original inline notes, the two
  divisions by 1000 convert g -> kg and L -> m^3, and the final * 1000 converts
  m -> mm. Works elementwise on NumPy arrays.
  """
  # NOTE(review): 5.09256513E18 kg is close to the total mass of the
  # atmosphere and 6.0e6 is presumably the particle count, making this the
  # mass per particle; water_g is presumably in g per kg of air — confirm.
  kg_per_particle = 5.09256513E18/6.0e6
  volume_m3 = water_g/1000.*kg_per_particle/1000.
  depth_m = volume_m3/surface_area_m
  return depth_m*1000.

def _layer_sum(ds, name, layers):
  """Sum variable `name` of `ds` over vertical indices `layers` (last axis-0 row dropped)."""
  return np.sum(ds[name][:-1,:,layers],axis=2)


def _accumulate(total, term):
  """Return `term` when `total` is None, otherwise `total + term`."""
  return term if total is None else total + term


def _regional_ratios(num, den, lat):
  """Return sum(num)/sum(den) over the whole grid and over five latitude bands.

  Band edges are the indices of `lat` nearest to -66, -23, 23 and 66 degrees,
  applied to axis 0 of `num`/`den`. Assuming `lat` increases with index, the
  keys are:
    'global' : the whole grid
    'sp'     : below ~-66
    'sm'     : ~-66 to ~-23
    't'      : ~-23 to ~23
    'nm'     : ~23 to ~66
    'np'     : above ~66
  """
  ispole = np.abs(lat+66.).argmin()
  istrop = np.abs(lat+23.).argmin()
  introp = np.abs(lat-23.).argmin()
  inpole = np.abs(lat-66.).argmin()
  bands = (
    ('global', slice(None)),
    ('sp', slice(None, ispole)),
    ('sm', slice(ispole, istrop)),
    ('t', slice(istrop, introp)),
    ('nm', slice(introp, inpole)),
    ('np', slice(inpole, None)),
  )
  return {key: np.sum(num[rows,:])/np.sum(den[rows,:]) for key, rows in bands}


def _period_anomaly(d_output, period, field, layers):
  """Return |period-mean of A/N minus the mean of its '+1' and '-1' counterparts|.

  A is `field`+'A' (resp. 'A+1', 'A-1') and N is 'NA' (resp. 'NA+1', 'NA-1'),
  each summed over `layers` and rolled by 360 columns along axis 1.
  """
  # NOTE(review): the fixed roll of 360 columns presumably recentres a
  # 720-column 0..360 longitude grid onto -180..180 — confirm against the data.
  grida = grida1 = gridam1 = None
  for yr in period:
    ds = d_output[yr]
    NA = _layer_sum(ds,'NA',layers)
    NA1 = _layer_sum(ds,'NA+1',layers)
    NAm1 = _layer_sum(ds,'NA-1',layers)
    A = _layer_sum(ds,field+'A',layers)
    A1 = _layer_sum(ds,field+'A+1',layers)
    Am1 = _layer_sum(ds,field+'A-1',layers)
    grida = _accumulate(grida, np.roll(A/NA,360,axis=1))
    grida1 = _accumulate(grida1, np.roll(A1/NA1,360,axis=1))
    gridam1 = _accumulate(gridam1, np.roll(Am1/NAm1,360,axis=1))
  return np.abs((grida-(grida1+gridam1)/2.)/len(period))


def _block_mean(grid, nrows, ncols, res):
  """Average `grid` over non-overlapping res x res blocks into an (nrows, ncols) array."""
  out = np.zeros((nrows,ncols))
  for i in range(nrows):
    for j in range(ncols):
      out[i,j] = np.mean(grid[i*res:(i+1)*res,j*res:(j+1)*res])
  return out


def plot_timeseries(d_output,basename='ATCE_time',res=4):
  """Plot ATCE_c time series (panels a-c) and two theta maps (panels d-e).

  Panels (a)-(c), one per field 'qv', 'pv', 'th':
    - For each year the ratio sum|<field>A| / sum|NA| is computed globally and
      for latitude bands (see `_regional_ratios`).
    - The same ratio is computed from the mean of the 'A-1'/'A+1' and
      'NA-1'/'NA+1' variables, and this value is subtracted.
    - Both terms are scaled by a per-field factor before plotting.
  Panels (d)-(e):
    - Maps of `_period_anomaly` for field 'th' over 1959-1978 and 2000-2019.
    - The maps are block-averaged by `res` and drawn with cartopy.

  Parameters
  ----------
  d_output : dict
    Maps year -> dataset providing 'lat', 'NA', 'NA-1', 'NA+1' and
    '<field>A', '<field>A-1', '<field>A+1' for field in qv/pv/th (presumably
    the dict returned by `select_nc_files`). It must contain every year of
    the two map periods.
  basename : str
    Output path prefix; writes basename+'.png' and basename+'.pdf'.
  res : int
    Coarsening factor used for the map grid and block averaging.
  """
  yrs = list(d_output.keys())
  kleuren=['indigo','mediumslateblue','darkorange','mediumseagreen','firebrick','mediumorchid', 'coral','navy']
  heightlevels = np.array(['0', '1', '5','10','20'])

  fig = plt.figure(figsize=(18,16))
  # Invisible full-figure axes: anchors the inner grid and the background patches.
  ax3 = fig.add_subplot(111)
  ax3.set_ylim([0,1])
  ax3.set_xlim([0,1])
  ax3.axis('off')
  inner = gridspec.GridSpecFromSubplotSpec(190,120, subplot_spec=ax3, wspace=0.1, hspace=0.1)

  # ---- Time-series panels (a)-(c) ----
  ax = {}
  for i, tag in enumerate(['(a)', '(b)', '(c)']):
    ax[i]=fig.add_subplot(inner[i*40:(i+1)*40-2,:])
    ax[i].text(0.99,.97,tag,horizontalalignment='right',verticalalignment='top',
      transform=ax[i].transAxes,color='k',fontsize=22,
      bbox=dict(facecolor='white',edgecolor='none',alpha=0.7,boxstyle='round',pad=0.01))
    if i < 2:
      ax[i].set_xticklabels([])

  # field -> (panel index, scale factor, y label)
  ts_panels = {
    'qv': (0, 1e3/(8760)*1e5, 'ATCE$_{\\rm c}$(q) \n[$10^{-5}$ g kg$^{-1}$ h$^{-1}$]'),
    'pv': (1, 1./(8760)*1e5, 'ATCE$_{\\rm c}$(PV) \n[$10^{-5}$ pvu h$^{-1}$]'),
    'th': (2, 1./(8760)*1e5, 'ATCE$_{\\rm c}$($\\theta$) \n[$10^{-5}$ K h$^{-1}$]'),
  }
  bands = ('global','sp','sm','t','nm','np')

  for field, (iax, fjja, ylabel) in ts_panels.items():
    ax[iax].set_ylabel(ylabel)
    for jj, layers in enumerate([[0,1,2,3]]):
      cur = {key: np.zeros(len(yrs)) for key in bands}
      bkg = {key: np.zeros(len(yrs)) for key in bands}
      for i, yr in enumerate(yrs):
        ds = d_output[yr]
        # NOTE(review): roll of 360 columns presumably recentres longitudes — confirm.
        gridn = np.abs(np.roll(_layer_sum(ds,'NA',layers),360,axis=1))
        gridn1 = np.abs(np.roll(_layer_sum(ds,'NA-1',layers),360,axis=1))
        gridnm1 = np.abs(np.roll(_layer_sum(ds,'NA+1',layers),360,axis=1))
        grida = np.abs(np.roll(_layer_sum(ds,field+'A',layers),360,axis=1))
        grida1 = np.abs(np.roll(_layer_sum(ds,field+'A-1',layers),360,axis=1))
        gridam1 = np.abs(np.roll(_layer_sum(ds,field+'A+1',layers),360,axis=1))
        lat = ds['lat'][:]
        for key, val in _regional_ratios(grida, gridn, lat).items():
          cur[key][i] = val
        # Background: mean of the '-1' and '+1' variables.
        for key, val in _regional_ratios((grida1+gridam1)/2., (gridn1+gridnm1)/2., lat).items():
          bkg[key][i] = val

      cur = {key: val*fjja for key, val in cur.items()}
      bkg = {key: val*fjja for key, val in bkg.items()}

      ax[iax].plot(yrs,cur['global']-bkg['global'],color=kleuren[jj],lw=5,label='Global')
      # NOTE(review): the hemispheric curves add the per-band ratios of the
      # polar and mid-latitude bands instead of forming a single ratio over
      # |lat| > 23; confirm this is intended.
      ax[iax].plot(yrs,(cur['np']+cur['nm'])-(bkg['np']+bkg['nm']),color=kleuren[jj+1],lw=3.5,ls=':',label='$\\rm{lat} > 23^{\\circ}$')
      ax[iax].plot(yrs,(cur['sp']+cur['sm'])-(bkg['sp']+bkg['sm']),color=kleuren[jj+2],lw=3.5,ls='--',label='$\\rm{lat} < -23^{\\circ}$')
      ax[iax].plot(yrs,cur['t']-bkg['t'],color=kleuren[jj+3],lw=3.5,ls='-.',label='$|\\rm{lat}| < 23^{\\circ}$')

      ax[iax].set_ylim(bottom=0)
      ax[iax].set_xlim((yrs[0],yrs[-1]))

  ax[0].set_ylim(top=2)
  ax[1].set_ylim(top=3.3)
  ax[2].set_ylim(top=17)
  ax[0].legend(frameon=False,loc='upper left')
  ax[2].set_xlabel('Year')

  # ---- Map panels (d)-(e) ----
  field = 'th'
  layers=np.arange(0,4).astype(int)
  periods = [np.arange(1959,1979,1).astype(int),
    np.arange(2000,2020,1).astype(int)]

  grid_xlona = np.arange(-180.25,180.25,0.5)
  grid_ylata = np.arange(-90.25,90.25,0.5)
  grid_xlon = grid_xlona[np.arange(res/2,len(grid_xlona)-1,res).astype(int)]
  grid_ylat = grid_ylata[np.arange(res/2,len(grid_ylata)-1,res).astype(int)]

  lon2d, lat2d = np.meshgrid((grid_xlon[1:]+grid_xlon[:-1])/2.,
    (grid_ylat[1:]+grid_ylat[:-1])/2)
  # Per-area normalisation is disabled; compute_surface_area(grid_xlon, grid_ylat) was the alternative.
  surface_area = 1

  # field -> (map scale factor, contour levels, colorbar label)
  map_panels = {
    'qv': (1/surface_area/(8760)*1e15, np.linspace(0.3,1,6),
      'ATCE$_{\\rm c}$(q) [$10^{-12}$ g kg$^{-1}$ h$^{-1}$]'),
    'pv': (1/surface_area/(8760)*1e11, np.linspace(-2,1,6),
      'ATCE$_{\\rm c}$(PV) [$10^{-11}$ pvu h$^{-1}$]'),
    'th': (1/surface_area/(8760)*1e5, np.linspace(0.3,16,40),
      'ATCE$_{\\rm c}$($\\theta$) [$10^{-5}$ K h$^{-1}$]'),
  }
  factorjja, levels, label = map_panels[field]

  # Lavender backdrops behind the map panels (figure-fraction coordinates of ax3).
  ax3.fill_betweenx(np.array([0.048,0.311,0.38]), np.array([0.015,0.015,0.238]), np.array([0.477,0.477,0.458]), color='lavender',zorder=0)
  ax3.fill_betweenx(np.array([0.048,0.311,0.38]), np.array([0.5228,0.5228,0.73]), np.array([0.985,0.985,0.951]), color='lavender',zorder=0)

  # Work on a copy: set_under/set_over on plt.cm.magma_r itself would alter
  # the globally shared colormap for every later plot in the process.
  cmap = copy.copy(plt.cm.magma_r)
  cmap.set_under('white')
  cmap.set_over(cmap(0.99))

  ax_vert = fig.add_subplot(inner[188:190, 40:80])
  factor_time = 1./(365*24)
  map_tags = ['(d)', '(e)']

  axm = {}
  axm[0] = fig.add_subplot(inner[131:181,0:59],projection=ccrs.PlateCarree(central_longitude=0))
  axm[1] = fig.add_subplot(inner[131:181,61:],projection=ccrs.PlateCarree(central_longitude=0))
  gridjja = {}

  for ii, period in enumerate(periods):
    # Mark period in the theta time series.
    ax[2].fill_betweenx(np.arange(0,2000,100),period[0],period[-1],color='lavender',zorder=0)
    axm[ii].set_yticklabels([])
    axm[ii].coastlines(zorder=100)
    axm[ii].set_global()
    axm[ii].gridlines()
    axm[ii].text(0.01,.97,map_tags[ii],horizontalalignment='left',verticalalignment='top',
      transform=axm[ii].transAxes,color='k',fontsize=22,
      bbox=dict(facecolor='white',edgecolor='none',alpha=0.7,boxstyle='round',pad=0.01))

    grida = _period_anomaly(d_output, period, field, layers)
    gridjja[ii] = _block_mean(grida, len(grid_ylat)-1, len(grid_xlon)-1, res)
    scaled = gridjja[ii]*factorjja

    cs = axm[ii].contourf(lon2d,lat2d,scaled,transform=ccrs.PlateCarree(),extend='both',
      levels=levels,cmap=cmap)
    axm[ii].contour(lon2d,lat2d,scaled,transform=ccrs.PlateCarree(),levels=levels,cmap=cmap)

    text_kw = dict(horizontalalignment='left',verticalalignment='bottom',transform=axm[ii].transAxes,
      color='k',fontsize=22,bbox=dict(facecolor='white',edgecolor='none',alpha=0.8,boxstyle='round',pad=0.01),zorder=99999)
    if field == 'qv':
      axm[ii].text(0.03,.03,'Global: %2.2f g kg$^{-1}$ h$^{-1}$' %np.mean(gridjja[ii]*1e3*factor_time),**text_kw)
    elif field == 'pv':
      axm[ii].text(0.03,.03,'Global: %2.2f pvu h$^{-1}$' %np.mean(gridjja[ii]*factor_time),**text_kw)
    elif field == 'th':
      # NOTE(review): unlike qv/pv this averages the full-resolution grid, not the block means.
      axm[ii].text(0.03,.03,'Global: %2.2e K h$^{-1}$' %(np.mean(grida)*factor_time),**text_kw)

    axm[ii].set_xticks([-120,-60,0,60,120], crs=ccrs.PlateCarree())
    axm[ii].set_xticklabels(labels=['120$^{\\circ}$ W','60$^{\\circ}$ W','0$^{\\circ}$','60$^{\\circ}$ E','120$^{\\circ}$ E'],fontsize=24)
    axm[ii].set_xlabel('Longitude',fontsize=24)
    axm[ii].set_title('Period: %i-%i, Height: '%(period[0],period[-1])+heightlevels[layers[0]]+'-'+heightlevels[layers[-1]+1]+'km',fontsize=24)

  axm[0].set_yticks([-60,-30, 0, 30,60], crs=ccrs.PlateCarree())
  axm[0].set_yticklabels(labels=['60$^{\\circ}$ S','30$^{\\circ}$ S','0$^{\\circ}$','30$^{\\circ}$ N','60$^{\\circ}$ N'],fontsize=24)
  axm[0].set_ylabel('Latitude',fontsize=24)

  cb1=fig.colorbar(cs,cax=ax_vert,orientation='horizontal',extend='both',label=label)
  # Roughly five ticks; max(1, ...) avoids a zero step when there are < 5 levels.
  tick_step = max(1, int(len(levels)/5))
  cb1.set_ticks(levels[np.arange(0,len(levels),tick_step).astype(int)])
  cb1.update_ticks()
  ax_vert.xaxis.set_ticks_position('bottom')

  fig.subplots_adjust(top=0.98, bottom=0.08, left=0.08, right=0.99, wspace=0.05, hspace=0.1)
  fig.savefig(basename+'.png')
  fig.savefig(basename+'.pdf')
  plt.close(fig)
