import pandas as pd
import re
import cartopy.feature as cf
#from multiprocessing import Process
import pickle
from matplotlib.colors import ListedColormap
import multiprocessing as mp
import numpy as np
import sys
import os
import h5py
import re
# import eccodes
# import cfgrib
import xarray as xr #cite as: Hoyer, S. & Hamman, J., (2017). xarray: N-D labeled Arrays and Datasets in Python. Journal of Open Research Software. 5(1), p.10. DOI: http://doi.org/10.5334/jors.148
import gc
import glob
import copy
#from sklearn.neighbors import BallTree
import matplotlib.pyplot as plt
import netCDF4 as nv
from scipy.interpolate import interp1d,interp2d,griddata,UnivariateSpline
from scipy.interpolate import RectBivariateSpline as rbs
from scipy.interpolate import RegularGridInterpolator as rgi
from scipy.optimize import brentq, curve_fit, minimize_scalar, minimize, Bounds
from scipy.stats import binned_statistic_2d, binned_statistic_dd, binned_statistic, gaussian_kde

import cartopy.crs as ccrs
from matplotlib import rcParams

# Global matplotlib style applied to every figure made by this module:
# serif-looking sans font (cmr10), large 22 pt text, thin 1 pt axes/ticks
# pointing inwards, and Computer Modern mathtext.
rcParams.update({
    'ps.useafm': True,
    'pdf.use14corefonts': False,
    'text.usetex': False,
    'font.sans-serif': ['cmr10', 'Times-Roman'],
    'font.weight': 'normal',
    'font.size': 22,
    'xtick.labelsize': 22,
    'ytick.labelsize': 22,
    'axes.labelsize': 22,
    'axes.linewidth': 1,
    'axes.unicode_minus': False,
    'xtick.minor.width': 1,
    'ytick.minor.width': 1,
    'xtick.major.width': 1,
    'ytick.major.width': 1,
    'xtick.minor.size': 4,
    'ytick.minor.size': 4,
    'xtick.major.size': 5,
    'ytick.major.size': 5,
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'mathtext.fontset': 'cm',
})


import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as colors
from matplotlib import gridspec
from matplotlib import ticker
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.ticker import NullFormatter
from collections import OrderedDict, Counter
from matplotlib.lines import Line2D
from matplotlib.patches import Patch, Rectangle, Ellipse
import matplotlib.patheffects as pe


#----------Make it print------------------
class Unbuffered(object):
    """
    Proxy around a text stream that flushes after every write.

    Attributes not defined on the proxy (``fileno``, ``encoding``, ...) are
    looked up on the wrapped stream.
    """

    def __init__(self, stream):
        self.stream = stream

    def write(self, data):
        """Write ``data`` to the wrapped stream and flush it immediately."""
        target = self.stream
        target.write(data)
        target.flush()

    def writelines(self, datas):
        """Write an iterable of strings to the wrapped stream, then flush."""
        target = self.stream
        target.writelines(datas)
        target.flush()

    def __getattr__(self, attr):
        # Only reached for attributes the proxy itself does not define.
        return getattr(self.stream, attr)


# Flush stdout after every write so printed progress appears immediately.
sys.stdout = Unbuffered(sys.stdout)


class PartInfo():
  """
  Accessor for particle-position ("partposit") data stored in a netCDF file.

  The file ``<path_to_directory>/<partposit_nc_fname>`` is opened with xarray
  and each property returns one of its variables. Two variable-naming
  conventions are handled:
  - If the file contains a ``topo`` variable, the long names are used
    (``longitude``, ``latitude``, ``height``, ``temperature``).
  - Otherwise the short names are used (``lon``, ``lat``, ``z``, ``T``).

  Properties
  ---------------------------------
  lon  = longitude (degrees)
  lat  = latitude (degrees)
  z    = height (meters)
  topo = topography
  tro  = tropopause
  hmix = PBL height
  rho  = density
  qv   = specific humidity
  pv   = potential vorticity
  tt   = temperature
  prs  = 'pr' (or 'prs') variable, presumably pressure
  The ``*_av`` properties return the ``..._av`` variables of the file
  (e.g. ``longitude_av``, ``height_av``).

  ``topo``, ``tro`` and ``hmix`` can be interpolated from their grid to the
  particle positions (see ``interpolate2d``). The interpolated result is
  cached on the instance.
  """

  def __init__(self, path_to_directory="/home/lucie/SizeTesting/output_1e6_23",
    partposit_nc_fname = "partposit_average.nc",interpolatetopo=True):
    """
    Open the particle netCDF file and set up empty caches.

    Parameters
    ----------
    path_to_directory : str
      Directory containing the particle netCDF file.
    partposit_nc_fname : str
      Name of the netCDF file inside ``path_to_directory``.
    interpolatetopo : bool
      If True, ``topo`` returns the topography interpolated to every particle
      position instead of the raw gridded variable.
    """
    self.path = path_to_directory
    # Npart from the file header is not accurate when domain filling is
    # used, so it is left unset here.
    self.npart = None

    self.partposit_nc_fname = partposit_nc_fname

    # Open the netCDF particle data with xarray as self._pd
    self._read_partposit_nc()
    self.coords = None

    self.kdtree = {}

    # Caches for fields interpolated to the particle positions, filled on
    # first access of topo / tro / hmix.
    self.topography = None
    self.tropopause = None
    self.hmixing = None
    self.interpolatetopo = interpolatetopo

  def _read_partposit_nc(self):
    """Open ``<path>/<partposit_nc_fname>`` with xarray (netcdf4 engine) as ``self._pd``."""
    tempfile = self.path + '/' + self.partposit_nc_fname
    self._pd = xr.open_dataset(tempfile, engine="netcdf4")

  @property
  def lon(self):
    """Particle longitude ('longitude' or 'lon' variable)."""
    if 'topo' in self._pd.variables:
      return self._pd.longitude
    return self._pd.lon

  @property
  def lon_av(self):
    return self._pd.longitude_av

  @property
  def lat(self):
    """Particle latitude ('latitude' or 'lat' variable)."""
    if 'topo' in self._pd.variables:
      return self._pd.latitude
    return self._pd.lat

  @property
  def lat_av(self):
    return self._pd.latitude_av

  @property
  def z(self):
    """Particle height ('height' or 'z' variable)."""
    if 'topo' in self._pd.variables:
      return self._pd.height
    return self._pd.z

  @property
  def z_av(self):
    return self._pd.height_av

  @property
  def topo(self):
    """
    Topography ('topo' or 'to' variable).

    If ``interpolatetopo`` is True:
    - The gridded field (time slice 0 if it has a third axis) is interpolated
      to every particle position and time step.
    - The result is clipped to the field's range and cached in
      ``self.topography``.

    Otherwise the raw variable is returned.
    """
    has_topo = 'topo' in self._pd.variables
    if not self.interpolatetopo:
      return self._pd.topo if has_topo else self._pd.to
    if self.topography is None:
      field = self._pd.topo if has_topo else self._pd.to
      # Topography is treated as static: slice 0 is used for all time steps.
      self.topography = self._interpolate_to_particles(field, time_dependent=False)
    return self.topography

  @property
  def tro(self):
    """
    Tropopause ('tr' or 'tro' variable).

    A 3D field is interpolated to every particle position, using field time
    slice ``itime`` for particle column ``itime``. The result is clipped to
    the field's range and cached in ``self.tropopause``. Any other field is
    returned unchanged.
    """
    field = self._pd.tr if 'tr' in self._pd.variables else self._pd.tro
    if np.ndim(field) != 3:
      return field
    if self.tropopause is None:
      self.tropopause = self._interpolate_to_particles(field, time_dependent=True)
    return self.tropopause

  @property
  def hmix(self):
    """
    PBL height ('hmix' variable).

    A 3D field is interpolated to every particle position, using field time
    slice ``itime`` for particle column ``itime``. The result is clipped to
    the field's range and cached in ``self.hmixing``. Any other field is
    returned unchanged.
    """
    field = self._pd.hmix
    if np.ndim(field) != 3:
      return field
    if self.hmixing is None:
      self.hmixing = self._interpolate_to_particles(field, time_dependent=True)
    return self.hmixing

  @property
  def rho(self):
    return self._pd.rho

  @property
  def rho_av(self):
    return self._pd.rho_av

  @property
  def qv(self):
    return self._pd.qv

  @property
  def qv_av(self):
    return self._pd.qv_av

  @property
  def prs(self):
    """'pr' variable if present, otherwise 'prs'."""
    if 'pr' in self._pd.variables:
      return self._pd.pr
    return self._pd.prs

  @property
  def prs_av(self):
    return self._pd.pr_av

  @property
  def pv(self):
    return self._pd.pv

  @property
  def pv_av(self):
    return self._pd.pv_av

  @property
  def tt(self):
    """Temperature ('temperature' or 'T' variable)."""
    if 'topo' in self._pd.variables:
      return self._pd.temperature
    return self._pd.T

  @property
  def tt_av(self):
    return self._pd.temperature_av

  @property
  def mass(self):
    return self._pd.mass001

  @property
  def time(self):
    return np.array(self._pd.time)

  @property
  def particle(self):
    return np.array(self._pd.particle)

  @property
  def day(self):
    return self._pd.time.dt.day.values

  @property
  def month(self):
    return self._pd.time.dt.month.values

  @property
  def year(self):
    return self._pd.time.dt.year.values

  @property
  def hour(self):
    return self._pd.time.dt.hour.values

  @property
  def minute(self):
    return self._pd.time.dt.minute.values

  @property
  def second(self):
    return self._pd.time.dt.second.values

  @property
  def period(self):
    """Value of the first global attribute of the dataset."""
    return self._pd.attrs[list(self._pd.attrs.keys())[0]]

  def _interpolate_to_particles(self, field, time_dependent):
    """
    Interpolate a gridded field to every particle position.

    Parameters
    ----------
    field : xarray.DataArray or numpy array
      2D field, or 3D field with time on the last axis (see ``interpolate2d``).
    time_dependent : bool
      If True, field slice ``itime`` is used for particle column ``itime``.
      NOTE(review): this assumes the field's time axis is aligned with the
      particle time columns; confirm.
      If False, slice 0 is used for every column.

    Returns
    -------
    numpy.ndarray with shape ``(len(self.lon), len(self.lon[0, :]))``
      Values clipped to [min(field), max(field)], so spline overshoot cannot
      leave the range of the input data. NaNs are left untouched.
    """
    lon = self.lon
    lat = self.lat
    out = np.zeros((len(lon), len(lon[0, :])))
    # np.min/np.max dispatch to the object's own reduction (for xarray this
    # skips NaNs); asarray gives the 0-d array that ``.values`` would.
    mini = np.asarray(np.min(field))
    maxi = np.asarray(np.max(field))
    interp = None if time_dependent else self.interpolate2d(field, 0)
    for itime in range(out.shape[1]):
      if time_dependent:
        interp = self.interpolate2d(field, itime)
      out[:, itime] = interp(lat[:, itime], lon[:, itime], grid=False)
    out[out < mini] = mini
    out[out > maxi] = maxi
    return out

  def interpolate2d(self,field,itime):
    """
    Build a bicubic spline interpolator for one time slice of ``field``.

    The field must lie on a fixed global 0.5 degree grid:
    - first axis: latitude -90..90 (361 points);
    - second axis: longitude 0..360 (721 points);
    - optional third axis: indexed with ``itime``.
    NOTE(review): the grid is hard-coded rather than read from the field's
    coordinates. Confirm that it matches the input data, and how particle
    longitudes outside 0..360 should be treated.

    Parameters
    ----------
    field : xarray.DataArray or numpy array, 2D or 3D
    itime : int
      Index on the third axis (ignored for 2D fields).

    Returns
    -------
    scipy.interpolate.RectBivariateSpline
      Called as ``f(lat, lon, grid=False)``.
    """
    lon=np.arange(0.,360.5,0.5)
    lat=np.arange(-90,90.5,0.5)
    if np.ndim(field)==2:
      info = np.asarray(field[:,:])
    else:
      info = np.asarray(field[:,:,itime])
    return rbs(lat,lon,info)


def convert_fields2load(pdata, fields_to_load):
  """
  Translate long variable names into the short names used by files without 'topo'.

  Parameters
  ----------
  pdata : PartInfo-like object
    Only ``pdata._pd.variables`` is inspected.
  fields_to_load : iterable of str
    Requested variable names (long-name convention).

  Returns
  -------
  - If the file contains a 'topo' variable: ``fields_to_load`` itself, unchanged.
  - Otherwise: a new list where 'longitude', 'latitude', 'height', 'pr' and
    'temperature' become 'lon', 'lat', 'z', 'prs' and 'T'; other names are kept.
  """
  if 'topo' in pdata._pd.variables:
    return fields_to_load
  short_names = {
    'longitude': 'lon',
    'latitude': 'lat',
    'height': 'z',
    'pr': 'prs',
    'temperature': 'T',
  }
  return [short_names.get(name, name) for name in fields_to_load]

def save_netcdf(output, file_name, sel_dict):
  """
  Write per-particle time series to a NETCDF4 file.

  Parameters
  ----------
  output : dict
    Must contain:
    - 'ipart': one value per particle;
    - 'start_time_index': one value per particle;
    - 'start_file_name': one string per particle;
    - one (n_part, period_analyse_hours) array for every name in
      ``sel_dict['fields_to_load']``.
  file_name : str
    Path of the file to create (overwritten if it exists).
  sel_dict : dict
    Must contain 'period_analyse_hours' (int, length of the time dimension)
    and 'fields_to_load' (names of the variables to write).
  """
  nhours = sel_dict['period_analyse_hours']
  start_names = output['start_file_name']

  # The context manager closes the file even if a write below fails.
  with nv.Dataset(file_name, 'w', format='NETCDF4') as ds:
    ds.createDimension('time', nhours)
    ds.createDimension('part', len(output['ipart']))

    times = ds.createVariable('time', 'f4', ('time',))
    parts = ds.createVariable('part', 'f4', ('part',))
    tpart = ds.createVariable('start_time_index', 'f4', ('part',))

    # Fixed-width string variable, sized from the longest name (not just the
    # first one) plus a 5-character margin.
    lstr = max(len(name) for name in start_names)
    filepart = ds.createVariable('start_file_name', 'S%i' %(lstr+5), ('part',))
    filepart[:] = start_names

    times[:] = np.arange(nhours)
    parts[:] = output['ipart']
    tpart[:] = output['start_time_index']

    for field in sel_dict['fields_to_load']:
      var = ds.createVariable(field, 'f4', ('part','time',))
      var[:,:] = output[field]


def select_nc_files(data_path="/home/lucie/LARA/data_energy/",
  years_all=(np.arange(1940,2024,1)).astype(int)):
  """
  Open one ``Energy_<year>.nc`` file per year for reading.

  Parameters
  ----------
  data_path : str
    Directory containing the files; a trailing path separator is optional.
  years_all : iterable of int
    Years to open (default 1940..2023).

  Returns
  -------
  dict
    Maps each year to its open ``netCDF4.Dataset``. The caller is
    responsible for closing them.

  If opening any file fails, the datasets opened so far are closed before
  the error is re-raised.
  """
  d_output={}
  try:
    for year in years_all:
      d_output[year] = nv.Dataset(os.path.join(data_path, "Energy_%i.nc" %(year)),"r")
  except Exception:
    for opened in d_output.values():
      opened.close()
    raise

  return d_output

def compute_surface_area(grid_xlon,grid_ylat):
  """
  Approximate the area of every cell of a rectilinear lon/lat grid.

  Each cell is treated as a rectangle:
  - its meridional side is the haversine distance between its two latitude
    edges, measured along the cell's first longitude edge;
  - its zonal side is the distance between its two longitude edges,
    measured along the cell's second (higher-index) latitude edge.
  Both distances are multiplied by 1000. NOTE(review): ``haversine`` is not
  defined in this chunk; presumably it returns km, making the result m^2.
  Confirm.

  Parameters
  ----------
  grid_xlon, grid_ylat : 1-D numpy arrays
    Cell-edge longitudes and latitudes (degrees, presumably).

  Returns
  -------
  numpy.ndarray with shape (len(grid_ylat) - 1, len(grid_xlon) - 1)
  """
  lon_centres = (grid_xlon[1:] + grid_xlon[:-1]) / 2.
  lat_centres = (grid_ylat[1:] + grid_ylat[:-1]) / 2
  # The mesh is only needed for the shape and dtype of the result.
  template, _ = np.meshgrid(lon_centres, lat_centres)
  area = np.zeros_like(template)

  lon_edges = np.copy(grid_xlon)
  lat_edges = np.copy(grid_ylat)
  lat_lo = lat_edges[:-1]
  lat_hi = lat_edges[1:]
  for icol, (lon_left, lon_right) in enumerate(zip(lon_edges[:-1], lon_edges[1:])):
    side_meridional = haversine(lon_left, lat_lo, lon_left, lat_hi) * 1000.
    side_zonal = haversine(lon_left, lat_hi, lon_right, lat_hi) * 1000.
    area[:, icol] = side_meridional * side_zonal
  return area

def kg_to_mm_water(water_g, surface_area_m, total_mass=5.09256513E18, npart=6.0e6):
  """
  Convert water carried by model particles into an equivalent water depth (mm).

  Each particle represents ``total_mass / npart`` kg. The defaults reproduce
  the previously hard-coded values. NOTE(review): total_mass is presumably
  the total atmospheric mass in kg; confirm.

  Parameters
  ----------
  water_g : float or numpy array
    Water content that is divided by 1000 ("gram to kg") and multiplied by
    the particle mass. NOTE(review): presumably grams of water per kg of
    particle mass; confirm.
  surface_area_m : float or numpy array
    Area (m^2) over which the water is spread.
  total_mass : float, optional
    Total mass (kg) shared by all particles.
  npart : float, optional
    Number of particles sharing ``total_mass``.

  Returns
  -------
  float or numpy array
    Water depth in mm, assuming 1 kg of water = 1 L.
  """
  masspart = total_mass/npart  # kg per particle
  # g -> kg, times particle mass -> kg of water; 1 kg = 1 L = 1e-3 m3
  w_m3 = water_g/1000.*masspart/1000.
  # m3 / m2 -> m, then m -> mm
  w_mm = w_m3/surface_area_m*1000.
  return w_mm


# Named dash patterns, each stored as (offset, (on, off, on, off, ...)) in
# points. The values can be passed directly as a matplotlib linestyle, as
# plot_distribution does for 'densely dashdotdotted'.
linestylesdict = OrderedDict(
    (style_name, (0, dashes)) for style_name, dashes in (
        ('solid',                 ()),
        ('loosely dotted',        (1, 10)),
        ('dotted',                (1, 5)),
        ('densely dotted',        (1, 1)),

        ('loosely dashed',        (5, 10)),
        ('dashed',                (5, 5)),
        ('densely dashed',        (5, 1)),

        ('loosely dashdotted',    (3, 10, 1, 10)),
        ('dashdotted',            (3, 5, 1, 5)),
        ('densely dashdotted',    (3, 1, 1, 1)),

        ('loosely dashdotdotted', (3, 10, 1, 10, 1, 10)),
        ('dashdotdotted',         (3, 5, 1, 5, 1, 5)),
        ('densely dashdotdotted', (3, 1, 1, 1, 1, 1)),
    )
)

def fwhm(x,y):
  """
  Locate the half-maximum crossings on either side of the peak of a sampled curve.

  An interpolating cubic spline (s=0) is fitted to ``y - max(y)/2``, and its
  roots are the half-maximum crossings. The nearest root on each side of the
  peak sample is used. A curve with secondary maxima above half height
  (which gives more than two roots) is therefore handled, matching the
  convention of ``scipy.signal.peak_widths``.

  Parameters
  ----------
  x : 1-D array, increasing
  y : 1-D array of the same length as ``x``

  Returns
  -------
  (right, left, x_peak)
    right : half-maximum crossing above ``x_peak``
    left : half-maximum crossing below ``x_peak``
    x_peak : ``x`` at the largest sample of ``y``
    The full width at half maximum is ``right - left``.

  Raises
  ------
  ValueError
    If the curve does not cross half maximum on both sides of the peak.
  """
  x_peak = x[np.argmax(y)]
  spline = UnivariateSpline(x, y-np.max(y)/2,s=0)
  roots = np.asarray(spline.roots())
  left = roots[roots < x_peak]
  right = roots[roots > x_peak]
  if left.size == 0 or right.size == 0:
    raise ValueError('fwhm: half maximum is not crossed on both sides of the peak')
  return right.min(), left.max(), x_peak

def running_mean(x, N):
    """
    Return the moving average of ``x`` over windows of ``N`` consecutive values.

    The output has ``len(x) - N + 1`` elements; element ``i`` is the mean of
    ``x[i:i + N]``. It is computed from cumulative sums, so ``x`` is
    flattened first.
    """
    totals = np.insert(np.cumsum(x), 0, 0)
    window_sums = totals[N:] - totals[:-N]
    return window_sums / float(N)


def plot_distribution(d_output,t_interval,iday,basename='Energy',N=10):
  """
  Plot latitude/height particle distributions and their trends over the years.

  Left-hand panels (a1-a5) compare three periods of ``t_interval`` years
  (the first, middle and last years of ``d_output``):
  - (a1) contours of the 2D latitude-height PDF;
  - (a2) latitude PDF, mean +- std over the years of each period;
  - (a3) height PDF, mean +- std over the years of each period;
  - (a4) latitude PDF minus the middle period;
  - (a5) height PDF minus the middle period.

  Right-hand panels (b1-b3):
  - (b1) per-year median latitude, 1%/99% latitude quantiles and the
    half-maximum crossings of the latitude PDF (see ``fwhm``), plus their
    ``N``-year running means;
  - (b2) changes of those running means relative to their first value,
    together with running means of the March/April PDO and AMO indices;
  - (b3) change of the NH-SH spread of the 99% and FWHM measures.

  Parameters
  ----------
  d_output : dict
    Maps year to a dataset whose 'lat' and 'z' variables are indexed
    [particle, time].
  t_interval : int
    Number of years in each of the three compared periods.
  iday : int
    Time index at which particle positions are evaluated.
  basename : str
    Output files are '<basename><iday*5>days2d.png' and '.pdf'.
  N : int
    Running-mean window in years. The 99% curves in (b2) are only drawn
    when N > 10.

  'PDOindex.dat' (columns Year, Mar, Apr, ...) and 'AMOindex.dat'
  (columns Year, month, SSTA) are read from the current directory.
  """

  yrs_elnino = np.array([1941,1942,1958,1966,1973,1978,1980,1983,1987,1992,1995,1998,2003,2007,2010,2016])
  yrs = np.array(list(d_output.keys()))
  # Cumulative probability levels of the contours in (a1)
  percentages=np.array([0.1, 0.5,0.9])

  # Three periods of t_interval years: first, middle and last years
  yrsall=[yrs[:t_interval],yrs[int(len(yrs)/2-t_interval/2):int(len(yrs)/2+t_interval/2)],yrs[-t_interval:]]
  i_yrsall=([np.arange(t_interval),np.arange(int(len(yrs)/2-t_interval/2),int(len(yrs)/2+t_interval/2)),
    np.arange(len(yrs)-t_interval,len(yrs))])
  i_yrsnoelnino = [i for i in range(len(yrs)) if yrs[i] not in yrs_elnino]
  i_yrselnino = [i for i in range(len(yrs)) if yrs[i] in yrs_elnino]
  ls=['-','--','-.',linestylesdict['densely dashdotdotted'],':']
  alpha=[1, 0.8, 0.7, 0.6, 0.5]
  lw=[2, 2.5, 3, 3.5, 4]
  # Latitude range (degrees) of the histograms and plots, despite the name
  lonrange=[-60,60]
  latbins=np.arange(lonrange[0],lonrange[-1],1)
  heightbins=np.arange(1,20000,500)
  latnew,heightnew = np.meshgrid( latbins[:-1]+(latbins[1:]-latbins[:-1])/2.,heightbins[:-1]+(heightbins[1:]-heightbins[:-1])/2. )
  ls1='--'
  ls2='-'
  cmap=plt.cm.viridis

  # Climate indices, averaged over March and April of each year.
  # NOTE(review): assumes both files cover exactly the years of d_output
  # (starting in 1940), so the running means line up with yrs_rm; confirm.
  PDOindex = np.genfromtxt('PDOindex.dat',names=True)
  ww = np.where((PDOindex['Year']>=1940)&(PDOindex['Year']<=yrs[-1]))[0]
  PDOindex = PDOindex[ww]
  PDOmarapr = (np.array(PDOindex['Mar']) + np.array(PDOindex['Apr']))/2

  AMOindex = np.genfromtxt('AMOindex.dat',names=True)
  ww = np.where((AMOindex['Year']>=1940)&(AMOindex['Year']<=yrs[-1]))[0]
  AMOindex = AMOindex[ww]
  ww3 = np.where(AMOindex['month']==3)[0]
  ww4 = np.where(AMOindex['month']==4)[0]
  AMOmarapr = (np.array(AMOindex['SSTA'][ww3])+np.array(AMOindex['SSTA'][ww4]))/2.

  kleuren=['slateblue','darkorange','mediumseagreen']
  kleurenfill=['mediumslateblue','orange','mediumseagreen']

  kleuren2=['cornflowerblue','indigo','firebrick','mediumorchid', 'coral','navy']
  lab=['Median','99% NH', 'FWHM NH', '99% SH', 'FWHM SH']
  lab2=['99%','FWHM']

  # Window on the initial particle height (m) used to select particles
  zmin=np.array([0])
  zmax=np.array([17000])

  fig = plt.figure(figsize=(20,9))
  ax3=plt.subplot(111)
  ax3.axis('off')
  ax={}
  ax2={}
  inner = gridspec.GridSpecFromSubplotSpec(102,150, subplot_spec=ax3, wspace=0, hspace=0)

  # Left block: (a1) 2D contours, (a2) latitude PDF, (a3) height PDF,
  # (a4) latitude PDF difference, (a5) height PDF difference
  ax[0,0] = fig.add_subplot(inner[0:70,:42])
  ax[1,0] = fig.add_subplot(inner[70:86,:42])
  ax[2,0] = fig.add_subplot(inner[0:70,42:53])

  ax[3,0] = fig.add_subplot(inner[86:,:42])
  ax[4,0] = fig.add_subplot(inner[0:70,53:64])

  ax[0,0].set_xticklabels([])
  ax[2,0].set_yticklabels([])
  ax[4,0].set_yticklabels([])

  # Right block: (b1) yearly time series; (b2) running-mean changes, sharing
  # its box with the PDO/AMO axis (secax); (b3) change of NH-SH spread
  ax[0,1] = fig.add_subplot(inner[:,74:74+25])
  secax = fig.add_subplot(inner[:,74+25:74+25+30])
  ax[1,1] = secax.twiny()
  ax[2,1] = fig.add_subplot(inner[:,74+25+30:74+25+30+19])
  for i in range(1,3):
    ax[i,1].set_yticklabels([])

  lstd={}
  lhist={}
  hstd={}
  hhist={}
  lhhist={}
  median_lat=np.zeros(len(yrs))
  max_lat=np.zeros(len(yrs))
  median_99=np.zeros(len(yrs))
  median_fwhm=np.zeros(len(yrs))
  median_m99=np.zeros(len(yrs))
  median_mfwhm=np.zeros(len(yrs))

  latheighthist=np.zeros((len(yrs),len(latbins)-1,len(heightbins)-1))
  lathist=np.zeros((len(yrs),len(latbins)-1))
  heighthist=np.zeros((len(yrs),len(heightbins)-1))

  j=0
  for year in yrs:

    lat = d_output[year]['lat'][:]
    z = d_output[year]['z'][:]
    # Particles that start inside the height window and have a non-zero
    # height at time index iday
    ww=np.where((z[:,0]>zmin[0])&(z[:,0]<zmax[0])&(z[:,iday]!=0.))[0]
    print(year,len(ww), len(np.where(z[:,iday]==0.)[0]))
    lat = lat[ww,iday]
    # NOTE(review): NaN latitudes are mapped to -1 degree and so still enter
    # the histograms and quantiles; confirm this is intended.
    lat[np.isnan(lat)]=-1

    latheighthist[j,:,:] = np.histogram2d(lat,z[ww,iday],bins=[latbins,heightbins],density=True)[0]
    lathist[j,:] = np.histogram(lat,bins=latbins,density=True)[0]
    heighthist[j,:] = np.histogram(z[ww,iday],bins=heightbins,density=True)[0]

    # Compute time series of the median and the 1%/99% latitude quantiles
    n1pct = int(len(ww)/100)
    lat_sorted = np.sort(lat)
    median_lat[j] = np.median(lat)
    # With fewer than 100 particles n1pct is 0, and lat_sorted[-0] would be
    # the minimum; use the maximum instead (mirrors lat_sorted[0] below).
    median_99[j] = lat_sorted[-max(n1pct, 1)]
    median_m99[j] = lat_sorted[n1pct]
    median_fwhm[j],median_mfwhm[j],max_lat[j] = fwhm(latbins[:-1]+(latbins[1:]-latbins[:-1])/2.,lathist[j,:])
    j=j+1

  # Per period: mean and std over the years of the 1D PDFs. The mean 2D PDF
  # is also sorted in decreasing order with its cumulative fraction; this is
  # used below to choose contour levels by cumulative probability.
  for ii in range(len(i_yrsall)):

    lstd[ii]=np.std(lathist[i_yrsall[ii],:],axis=0)
    lhist[ii]=np.mean(lathist[i_yrsall[ii],:],axis=0)
    hstd[ii]=np.std(heighthist[i_yrsall[ii],:],axis=0)
    hhist[ii]=np.mean(heighthist[i_yrsall[ii],:],axis=0)

    counts = np.mean(latheighthist[i_yrsall[ii],:,:],axis=0)
    counts_sorted = np.sort(np.ravel(counts))[::-1]
    lhhist[ii] = {}
    lhhist[ii]['counts'] = counts
    lhhist[ii]['counts_sorted'] = counts_sorted
    lhhist[ii]['fraction'] = np.cumsum(counts_sorted)/np.sum(counts_sorted)

  # Plot time series

  tags=np.array(['(b1)', '(b2)', '(b3)'])
  for itag in range(3):
    ax[itag,1].text(0.99,.99,tags[itag],horizontalalignment='right',verticalalignment='top',
      transform=ax[itag,1].transAxes,color='k',
      bbox=dict(facecolor='white',edgecolor='none',alpha=0.7,boxstyle='round',pad=0.01))

  # (b1): yearly values as thin dashed lines
  ax[0,1].plot(median_lat, yrs, ls=ls1, lw=1, alpha=1, color=kleuren2[0])
  #ax[0,1].plot(yrs, max_lat, ls='-', lw=3, alpha=0.6, color=kleuren2[5])
  ax[0,1].plot(median_99, yrs, ls=ls1, lw=1, alpha=1, color=kleuren2[1])
  ax[0,1].plot(median_m99, yrs, ls=ls1, lw=1, alpha=1, color=kleuren2[3])
  ax[0,1].plot(median_fwhm, yrs, ls=ls1, lw=1, alpha=1, color=kleuren2[2])
  ax[0,1].plot(median_mfwhm, yrs, ls=ls1, lw=1, alpha=1, color=kleuren2[4])

  # (b1): N-year running means as thick solid lines
  yrs_rm = running_mean(yrs,N)
  ax[0,1].plot(running_mean(median_lat,N), yrs_rm, lw=3, ls=ls2, color=kleuren2[0])
  #ax[0,1].plot(yrs_rm,running_mean(max_lat,N), ls='--', color=kleuren2[5])
  ax[0,1].plot(running_mean(median_99,N), yrs_rm, lw=3, ls=ls2, color=kleuren2[1])
  ax[0,1].plot(running_mean(median_m99,N), yrs_rm, lw=3, ls=ls2, color=kleuren2[3])
  ax[0,1].plot(running_mean(median_fwhm,N), yrs_rm, lw=3, ls=ls2, color=kleuren2[2])
  ax[0,1].plot(running_mean(median_mfwhm,N), yrs_rm, lw=3, ls=ls2, color=kleuren2[4])
  # Background guides: El Nino years and zero lines
  for yr in yrs_elnino:
    ax[0,1].axhline(yr,color='lavender',lw=3, zorder=0)
    #ax[1,1].axhline(yr,color='lavender',lw=3, zorder=0)
  ax[2,1].axvline(0, color='lavender', lw=3,zorder=0)
  ax[0,1].axvline(0, color='lavender', lw=3,zorder=0)
  ax[1,0].axvline(0, color='lavender', lw=3,zorder=0)
  ax[0,0].axvline(0, color='lavender', lw=3,zorder=0)
  ax[3,0].axvline(0, color='lavender', lw=3,zorder=0)
  secax.axvline(0, color='lavender', lw=3,zorder=0)


  rm_temp = running_mean(median_lat,N)

  # (b2) background: running means of the PDO and AMO indices
  amokleur = (230/255,255/255,179/255)#'palegreen'
  amokleuredge = (51/255,102/255,0/255)
  PDOindex_rm = running_mean(PDOmarapr,N)
  AMOindex_rm = running_mean(AMOmarapr,N)
  secax.fill_betweenx(yrs_rm,PDOindex_rm,edgecolor='k',facecolor='lightgrey',zorder=1, ls=ls2,alpha=0.3)
  secax.fill_betweenx(yrs_rm,AMOindex_rm,edgecolor=amokleuredge,facecolor=amokleur,zorder=1, ls=ls2,alpha=0.3)
  secax.set_xlim(xmax=-secax.get_xlim()[0])
  #secax.axis('off')

  # ax[1,1].plot(PDOmarapr[:-1],yrs,color='peachpuff')
  # ax[1,1].plot(AMOmarapr,yrs,color='lightsteelblue')

  # (b2): running means relative to their first value; (a2): vertical lines
  # at the all-year means
  N2=5
  yrs_rm2 = running_mean(yrs[i_yrselnino],N2)
  rm_temp2 = running_mean(median_lat[i_yrselnino],N2)
  ax[1,1].plot(rm_temp-rm_temp[0], yrs_rm, ls=ls2,lw=3, color=kleuren2[0])
  ax[1,0].axvline(np.mean(median_lat), ls=ls2, lw=3, color=kleuren2[0])
  #ax[1,1].plot(rm_temp2-rm_temp[0], yrs_rm2, ls=':', color=kleuren2[0],alpha=0.5)
  # rm_temp = running_mean(max_lat,N)
  # ax[1,1].plot(yrs_rm,rm_temp-rm_temp[0], ls='-', color=kleuren2[5])
  m99_temp = running_mean(median_99,N)
  m99en_temp = running_mean(median_99[i_yrselnino],N2)
  # ax[1,1].plot(rm_temp1-m99_temp[0], yrs_rm2, ls=':', color=kleuren2[1])
  if N>10:
    ax[1,1].plot(m99_temp-m99_temp[0], yrs_rm, ls=ls2,lw=3, color=kleuren2[1])
  ax[1,0].axvline(np.mean(median_99), ls=ls2, lw=3, color=kleuren2[1])

  mm99_temp = running_mean(median_m99,N)
  mm99en_temp = running_mean(median_m99[i_yrselnino],N2)
  # ax[1,1].plot(rm_temp2-mm99_temp[0], yrs_rm2, ls=':', color=kleuren2[3])
  if N>10:
    ax[1,1].plot(mm99_temp-mm99_temp[0], yrs_rm, ls=ls2,lw=3, color=kleuren2[3])
  ax[1,0].axvline(np.mean(median_m99), ls=ls2, lw=3, color=kleuren2[3])

  # ax[2,1].plot(rm_temp1-rm_temp2- (m99_temp-mm99_temp)[0],yrs_rm2, kleuren2[0],ls=':')

  fwhm_temp = running_mean(median_fwhm,N)
  fwhmen_temp = running_mean(median_fwhm[i_yrselnino],N2)
  #ax[1,1].plot(rm_temp1-fwhm_temp[0], yrs_rm2, ls=':', color=kleuren2[2],alpha=0.5)
  ax[1,1].plot(fwhm_temp-fwhm_temp[0], yrs_rm, ls=ls2,lw=3, color=kleuren2[2])
  ax[1,0].axvline(np.mean(median_fwhm), ls=ls2, lw=3, color=kleuren2[2])

  mfwhm_temp = running_mean(median_mfwhm,N)
  mfwhmen_temp = running_mean(median_mfwhm[i_yrselnino],N2)
  # ax[1,1].plot(rm_temp2-mfwhm_temp[0], yrs_rm2, ls=':', color=kleuren2[4])
  ax[1,1].plot(mfwhm_temp-mfwhm_temp[0], yrs_rm, ls=ls2, lw=3, color=kleuren2[4])
  ax[1,0].axvline(np.mean(median_mfwhm), ls=ls2, lw=3, color=kleuren2[4])
  # ax[2,1].plot(rm_temp1-rm_temp2- (fwhm_temp-mfwhm_temp)[0],yrs_rm2, kleuren2[2],ls=':')

  # (b3): change of the NH-SH spread (FWHM and 99%) relative to its first value
  ax[2,1].plot(fwhm_temp-mfwhm_temp- (fwhm_temp-mfwhm_temp)[0],yrs_rm, kleuren2[2],lw=3,ls=ls2)
  ax[2,1].fill_betweenx(yrs_rm,0,fwhm_temp-mfwhm_temp- (fwhm_temp-mfwhm_temp)[0], color=kleuren2[4],alpha=0.4)
  ax[2,1].plot(m99_temp-mm99_temp- (m99_temp-mm99_temp)[0], yrs_rm,kleuren2[1],lw=3,ls=ls2)
  ax[2,1].fill_betweenx(yrs_rm,0,m99_temp-mm99_temp- (m99_temp-mm99_temp)[0], color= kleuren2[3],alpha=0.4)



  # ax[2,1].plot(fwhmen_temp-mfwhmen_temp- (fwhm_temp-mfwhm_temp)[0],yrs_rm2, kleuren2[2],lw=1,ls=ls2)
  # ax[2,1].fill_betweenx(yrs_rm2,0,fwhmen_temp-mfwhmen_temp- (fwhm_temp-mfwhm_temp)[0], color=kleuren2[4],alpha=0.2)
  # ax[2,1].plot(m99en_temp-mm99en_temp- (m99_temp-mm99_temp)[0], yrs_rm2,kleuren2[1],lw=1,ls=ls2)
  # ax[2,1].fill_betweenx(yrs_rm2,0,m99en_temp-mm99en_temp- (m99_temp-mm99_temp)[0], color= kleuren2[3],alpha=0.2)

  # rm_temp = running_mean(max_lat,N)
  # ax[1,1].plot(yrs_rm,rm_temp-rm_temp[0], ls='-', color=kleuren2[5])
  ax[0,1].text(0,yrs[-2],'Median',va='top',ha='right', color=kleuren2[0], rotation='vertical')
  ax[0,1].set_xlim(lonrange)
  # Difference in mean spread between non-El-Nino and El Nino years
  print('Delta99:', np.mean(median_99[i_yrsnoelnino]-median_m99[i_yrsnoelnino])-
    np.mean(median_99[i_yrselnino]-median_m99[i_yrselnino]))
  print('DeltaFWHM:', np.mean(median_fwhm[i_yrsnoelnino]-median_mfwhm[i_yrsnoelnino])-
    np.mean(median_fwhm[i_yrselnino]-median_mfwhm[i_yrselnino]))
  # (b1) annotation: horizontal bar spanning the FWHM at one year
  i_fwhm=-int(len(yrs)/3)
  ax[0,1].axhline(yrs[i_fwhm],xmin=0.5,
    xmax=(median_fwhm[i_fwhm]-latbins[0])/(latbins[-1]-latbins[0]), color=kleuren2[2])
  ax[0,1].axhline(yrs[i_fwhm],xmin=(median_mfwhm[i_fwhm]-latbins[0])/(latbins[-1]-latbins[0]),
    xmax=0.5, color=kleuren2[4])
  print(yrs[i_fwhm])
  ax[0,1].text(0,yrs[i_fwhm]-0.5,'FWHM',va='top',ha='center', color=kleuren2[2], rotation='horizontal')
  ax[0,1].text(median_mfwhm[-2],yrs[-2],'FWHM SH',va='top',ha='right', color=kleuren2[4], rotation='vertical')
  ax[0,1].text(median_fwhm[-2],yrs[-2],'FWHM NH',va='top',ha='right', color=kleuren2[2], rotation='vertical')

  # (b1) annotation: horizontal bar spanning the 1%-99% range at another year.
  # The NH half uses kleuren2[1], the colour of the 99% NH series and label.
  i_fwhm=-int(len(yrs)/2)
  ax[0,1].axhline(yrs[i_fwhm],xmin=0.5,
    xmax=(median_99[i_fwhm]-latbins[0])/(latbins[-1]-latbins[0]), color=kleuren2[1])
  ax[0,1].axhline(yrs[i_fwhm],xmin=(median_m99[i_fwhm]-latbins[0])/(latbins[-1]-latbins[0]),
    xmax=0.5, color=kleuren2[3])
  ax[0,1].text(0,yrs[i_fwhm]-0.5,'99%',va='top',ha='center', color=kleuren2[1], rotation='horizontal')
  ax[0,1].text(median_m99[-2],yrs[-2],'99% SH',va='top',ha='right', color=kleuren2[3], rotation='vertical')
  ax[0,1].text(median_99[-2],yrs[-2],'99% NH',va='top',ha='right', color=kleuren2[1], rotation='vertical')
  ax[0,1].set_ylim((yrs[0],yrs[-1]))
  ax[1,1].set_ylim((yrs[0],yrs[-1]))
  ax[2,1].set_ylim((yrs[0],yrs[-1]))
  ax[0,1].set_ylabel('Year')
  # ax[0,1].set_xticks([-60,-30, 0, 30,60])
  # ax[0,1].set_xticklabels(labels=['-60$^{\\circ}$ N','-30$^{\\circ}$ N','0$^{\\circ}$ N','30$^{\\circ}$ N','60$^{\\circ}$ N'])
  ax[0,1].set_xticks([-30, 0, 30])
  ax[0,1].set_xticklabels(labels=['30$^{\\circ}$ S','0$^{\\circ}$','30$^{\\circ}$ N'])
  ax[0,1].set_xlabel('Latitude')
  ax[1,1].set_xlabel('$\\Delta$ Lat ($\\degree$ N)')
  ax[2,1].set_xlabel('$\\Delta$ Lat ($\\degree$ N)')

  #Plot distribution

  tags=np.array(['(a1)', '(a2)', '(a3)', '(a4)', '(a5)'])
  for itag in range(5):
    ax[itag,0].text(0.01,.01,tags[itag],horizontalalignment='left',verticalalignment='bottom',
      transform=ax[itag,0].transAxes,color='k',
      bbox=dict(facecolor='white',edgecolor='none',alpha=0.7,boxstyle='round',pad=0.01))
  i=0
  for ii in lhhist.keys():
    jj=0
    # (a1): one contour per cumulative level; the level is the smallest
    # density among the bins whose cumulative fraction is below pc
    for pc in percentages:
      waartemp = np.where(lhhist[ii]['fraction'] < pc)[0]
      if len(waartemp) == 0:
        minval = np.min(lhhist[ii]['counts_sorted'])
      else:
        minval = np.min(lhhist[ii]['counts_sorted'][waartemp])
      ax[0,0].contour(latnew,heightnew,lhhist[ii]['counts'].transpose(),levels=[minval],
        colors=[kleuren[i],],linestyles=[ls[jj],],linewidths=lw[jj],alpha=alpha[jj])
      jj=jj+1

    # (a2)/(a3): period-mean PDFs with +- std bands
    ax[1,0].fill_between(latbins[:-1]+(latbins[1:]-latbins[:-1])/2., lhist[ii]-lstd[ii],lhist[ii]+lstd[ii],
      color=kleurenfill[i],alpha=0.2)
    ax[1,0].plot(latbins[:-1]+(latbins[1:]-latbins[:-1])/2., lhist[ii],
      color=kleuren[i])

    ax[2,0].fill_betweenx(heightbins[:-1]+(heightbins[1:]-heightbins[:-1])/2., hhist[ii]-hstd[ii],hhist[ii]+hstd[ii],
      color=kleurenfill[i],alpha=0.2)
    ax[2,0].plot(hhist[ii],heightbins[:-1]+(heightbins[1:]-heightbins[:-1])/2.,
      color=kleuren[i])
    # (a4)/(a5): differences relative to the middle period (index 1)
    if i!=1:
      ax[3,0].fill_between(latbins[:-1]+(latbins[1:]-latbins[:-1])/2., lhist[ii]-lstd[ii]-lhist[1],
        lhist[ii]+lstd[ii]-lhist[1],color=kleurenfill[i],alpha=0.2)
      ax[3,0].plot(latbins[:-1]+(latbins[1:]-latbins[:-1])/2., lhist[ii]-lhist[1],
        color=kleuren[i])
      ax[4,0].fill_betweenx(heightbins[:-1]+(heightbins[1:]-heightbins[:-1])/2., hhist[ii]-hstd[ii]-hhist[1],
        hhist[ii]+hstd[ii]-hhist[1],color=kleurenfill[i],alpha=0.2)
      ax[4,0].plot(hhist[ii]-hhist[1],heightbins[:-1]+(heightbins[1:]-heightbins[:-1])/2.,
        color=kleuren[i])
    else:
      ax[3,0].fill_between(latbins[:-1]+(latbins[1:]-latbins[:-1])/2., -lstd[ii], lstd[ii],
        color=kleurenfill[i],alpha=0.2)
      ax[3,0].plot(latbins[:-1]+(latbins[1:]-latbins[:-1])/2.,np.zeros(len(latbins)-1),
        color=kleuren[i])
      ax[4,0].fill_betweenx(heightbins[:-1]+(heightbins[1:]-heightbins[:-1])/2., -hstd[ii], hstd[ii],
        color=kleurenfill[i],alpha=0.2)
      ax[4,0].plot(np.zeros(len(heightbins)-1),heightbins[:-1]+(heightbins[1:]-heightbins[:-1])/2.,
        color=kleuren[i])

    i=i+1

  # Axis limits, ticks and labels of the left block
  xlim=ax[0,0].get_xlim()
  ylim=ax[0,0].get_ylim()
  ax[0,1].set_xlim(lonrange)
  for j in range(len(zmin)):
    ax[1,j].set_ylim([0,0.03])
    ax[2,j].set_xlim([0,1e-4])
    ax[0,j].set_xlim(xlim)
    ax[0,j].set_ylim(ylim)
    ax[1,j].set_xlim(xlim)
    ax[2,j].set_ylim(ylim)
    #ax[1,j].set_yscale('log')
    ax[3,j].set_xlim(xlim)
    ax[3,j].set_ylim([-0.0018,0.0018])
    ax[4,j].set_ylim(ylim)
    ax[4,j].set_xlim([-0.5e-5,0.5e-5])
    # ax[1,j].set_xlim([-80,80])
    # ax[0,j].set_ylim([0,0.02])
    # ax[1,j].set_ylim([-0.002,0.002])
    ax[3,j].set_xlabel('Latitude')
    ax[2,j].set_xlabel('PDF \n(10$^{-5}$)')
    #ax[2,j].xaxis.set_label_position('top')
    ax[4,j].set_xlabel('$\\Delta$PDF\n(10$^{-5}$)')
    #ax[4,j].xaxis.set_label_position('top')
    ax[0,j].text(.5, .9, '$%i<z_{\\rm init}<%i$m' %(zmin[j],zmax[j]),
      transform=ax[0,j].transAxes, color='k', bbox=(dict(facecolor='white',alpha=0.7,edgecolor='white')),
      position=(0.97, 0.97), fontweight='extra bold', ha='right', va='top',zorder=99999)

    # ax[0,j].set_xticks([-60,-30, 0, 30,60])
    # ax[1,j].set_xticks([-60,-30, 0, 30,60])
    ax[3,j].set_xticks([-60,-30, 0, 30,60])
    ax[0,j].set_xticks([-30, 0, 30])
    ax[0,j].set_xlim(lonrange)
    ax[1,j].set_xticks([-30, 0, 30])
    ax[1,j].set_xlim(lonrange)
    #ax[3,j].set_xticks([-30, 0, 30])
    ax[3,j].set_xlim(lonrange)
    ax[3,j].set_xticklabels(labels=['60$^{\\circ}$ S','30$^{\\circ}$ S','0$^{\\circ}$','30$^{\\circ}$ N','60$^{\\circ}$ N'])
    #ax[3,j].set_xticklabels(labels=['-30$^{\\circ}$ N','0$^{\\circ}$ N','30$^{\\circ}$ N'])

    ax[1,j].set_yticks([0,0.02])
    ax[1,j].set_yticklabels(labels=['0','20'])

    ax[3,j].set_yticks([-0.001,0.001])
    ax[3,j].set_yticklabels(labels=['-1', '1'])


    ax[2,j].set_xticks([5e-5])
    ax[2,j].set_xticklabels(labels=['5'])

    ax[4,j].set_xticks([-0.000005,0, 0.000005])
    ax[4,j].set_xticklabels(labels=['-0.5','0', '0.5'])

  print(ax[0,1].get_xlim())
  ax[1,0].set_ylabel('PDF\n (10$^{-3}$)')
  ax[3,0].set_ylabel('$\\Delta$PDF\n(10$^{-3}$)')

  # Legends: periods and contour levels in (a1)
  legendboxes1=[plt.Line2D([0],[0],ls=ls[i],lw=lw[i],alpha=alpha[i],color='k') for i in range(len(ls))]
  legendhandles1=['%i %%' %(pc*100) for pc in percentages]

  #ax[0,1].legend(legendboxes,legendhandles,loc='upper left',frameon=False)

  legendboxes=[plt.Line2D([0],[0],color=kleuren[i]) for i in range(len(yrsall))]
  legendhandles=['%i-%i' %(yrsall[i][0], yrsall[i][-1]) for i in range(len(yrsall))]

  ax[0,0].legend(legendboxes+legendboxes1,legendhandles+legendhandles1,
    loc='upper left',frameon=True)

  # legendboxes=[plt.Line2D([0],[0],color=kleuren2[i]) for i in range(len(lab))]
  # legendhandles=[lab[i] for i in range(len(lab))]
  # ax[1,1].legend(legendboxes,legendhandles, loc='lower left',fontsize=18
  # (b2): symmetric x-range around zero
  maxval = np.max(np.abs(ax[1,1].get_xlim()))
  ax[1,1].set_xlim(xmin=-maxval,xmax=maxval)
  line1=plt.Line2D([0],[0],color='k',ls=ls2,lw=3)
  ax[1,1].legend([line1],['%i yr running mean' %N],loc=(0,0), frameon=False)


  patch1= mpl.patches.Patch(edgecolor='k',facecolor='lightgray',linestyle=ls2,alpha=0.5)
  patch2= mpl.patches.Patch(edgecolor=amokleuredge,facecolor=amokleur,linestyle=ls2,alpha=0.5)
  labPDO = 'PDO'
  labAMO = 'AMO'
  secax.legend([patch1,patch2],[labPDO,labAMO],loc=(0.55,0.08),
    ncol=1,frameon=False,fontsize=18)

  secax.xaxis.set_label_position('top')
  secax.xaxis.set_ticks_position('top')
  secax.tick_params(axis="x",direction="in", pad=-28,labelsize=22,zorder=1)

  ax[1,1].xaxis.set_label_position('bottom')
  ax[1,1].xaxis.set_ticks_position('bottom')

  legendboxes=[mpl.patches.Patch(edgecolor=kleuren2[i],linewidth=3,facecolor=kleuren2[i+2],
    alpha=0.5,linestyle=ls2) for i in [1,2]]
  legendhandles=[lab2[i] for i in range(len(lab2))]

  # plt.legend draws on the current axes (the last one created, ax[2,1]).
  # The ax[2,1].legend call below replaces it, so it is re-added with
  # add_artist to keep both legends.
  legend1 = plt.legend([legendboxes[1]],[legendhandles[1]], loc=(0,0), frameon=False)

  ax[2,1].legend([legendboxes[0]],[legendhandles[0]], loc=(0,0.935), frameon=False)
  ax[2,1].add_artist(legend1)

  ax[0,0].set_ylabel('Height (m)')

  plt.subplots_adjust(top=0.98, bottom=0.1, left=0.068, right=1.0, wspace=0.1, hspace=0)
  #plt.yscale('log')
  plt.savefig(basename+'%idays2d.png' %(iday*5))
  plt.savefig(basename+'%idays2d.pdf' %(iday*5))
  plt.close()

def plot_distribution_map(d_output,t_interval,basename='Energy'):
  """Map the particle density at two time indices and save it as a PNG.

  For time index 3 and the last time index of the trajectories, the particle
  positions of the first ``t_interval`` entries of ``d_output`` (dict order)
  are binned on a regular 1x1 degree lon/lat grid, averaged over those
  entries, divided by the cell area on the unit sphere and drawn as filled
  contours on two stacked PlateCarree maps.  Cells without particles are
  drawn white.  The figure is written to ``basename+'distribution.png'``.

  Parameters
  ----------
  d_output : dict
    Mapping key (year) -> record whose ``'lat'`` and ``'lon'`` entries are
    2-D arrays indexed ``[particle, timestep]`` (as the slicing below
    implies).  Longitudes above 180 are wrapped into -180..180.
  t_interval : int
    Number of leading entries of ``d_output`` to average over.
  basename : str
    Prefix of the output file name.
  """
  # Bin edges include the closing meridian/parallel (180, 90) so that
  # histogram2d does not drop particles in the last column/row.
  grid_xlon = np.arange(-180,181,1)
  grid_ylat = np.arange(-90,91,1)

  lon2d_1, lat2d_1 = np.meshgrid((grid_xlon[1:]+grid_xlon[:-1])/2.,
    (grid_ylat[1:]+grid_ylat[:-1])/2.)
  # Cell area on the unit sphere: dlon * (sin(lat_north) - sin(lat_south)).
  grid_area = np.outer(np.diff(np.sin(np.deg2rad(grid_ylat))),
    np.diff(np.deg2rad(grid_xlon)))

  yrs = np.array(list(d_output.keys()))
  yrs_used = yrs[0:t_interval]

  # Private copy so set_under does not modify the globally registered colormap.
  cmap = copy.copy(plt.cm.jet)
  cmap.set_under('white')

  fig = plt.figure(figsize=(18,15))
  ax3=plt.subplot(111,projection=ccrs.PlateCarree(central_longitude=0))
  ax3.axis('off')
  inner = gridspec.GridSpecFromSubplotSpec(2,1, subplot_spec=ax3, wspace=0, hspace=0)
  ax = {}
  for i in range(2):
    # Exactly one map per cell: any axes stacked on the same cell afterwards
    # would paint its opaque background over this map.
    ax[i] = fig.add_subplot(inner[i:i+1],projection=ccrs.PlateCarree(central_longitude=0))
    ax[i].set_xticklabels([])
    ax[i].coastlines()

  ii=0
  for i in [3,len(d_output[yrs[0]]['lat'][0])-1]:
    timesteps=np.array([i])
    print(timesteps)
    grid_n = np.zeros_like(grid_area)
    for yr in yrs_used:
      print(yr,len(d_output[yr]['lat'][0]))
      lon1=d_output[yr]['lon'][:]
      lat1=d_output[yr]['lat'][:]
      lone1 = np.where(lon1>180., lon1-360. ,lon1)
      grid_n += np.histogram2d(lat1[:,timesteps].flatten(order='C'),
        lone1[:,timesteps].flatten(order='C'), bins=[grid_ylat, grid_xlon])[0]
    # Average over the entries actually used (fewer than t_interval when
    # d_output is short); max() guards against an empty selection.
    grid_n = grid_n/max(len(yrs_used),1)

    res=grid_n/grid_area
    res[np.isinf(res)]=0.
    res[np.isnan(res)]=0.
    if np.max(res)==0:
      continue
    levels=np.linspace(np.min(res[res>0]),np.max(res),80)
    # extend='min' sends empty cells (below the lowest level) to the 'under' colour.
    ax[ii].contourf(lon2d_1,lat2d_1,res,transform=ccrs.PlateCarree(),extend='min',
      levels=levels,cmap=cmap)
    ii+=1

  plt.savefig(basename+'distribution.png')
  # Release the figure so repeated calls do not accumulate open figures.
  plt.close(fig)

def plot_manual(d_output,t_interval=84,iday=4,basename='Energy',N=10):
  """Draw the multi-panel summary figure of the particle latitude distribution.

  Panels (tags as drawn on the figure):

  * (a), (b): maps of particle number per unit area at time indices 0 and 5,
    averaged over the first ``t_interval`` years of ``d_output``.
  * (c): contours of the year-averaged latitude-height PDF at time index
    ``iday`` enclosing (roughly) 10 %, 50 % and 90 % of the probability mass.
  * (d), (e): year-averaged latitude and height PDFs with a +-1 std band.
  * (f): per-year median, 99 % and FWHM latitude bounds (thin dashed) and
    their ``N``-year running means (thick).
  * (g): those running means relative to their first value, on a twin x axis
    over the running-mean March/April PDO and AMO indices.
  * (h): change of the NH-minus-SH spread of the FWHM and 99 % bounds.

  Panels (c)-(h) only use particles whose initial height ``z[:, 0]`` lies
  strictly inside (``zmin``, ``zmax``) and whose height at ``iday`` is
  non-zero.  The figure is saved as ``basename+'manual.png'`` and ``.pdf``.

  Parameters
  ----------
  d_output : dict
    Mapping year -> record with ``'lat'``, ``'lon'`` and ``'z'`` arrays that
    are indexed ``[particle, timestep]`` below.  Keys are compared with
    integer years (El Nino list, index files), so presumably ints.
  t_interval : int
    Number of leading years averaged for the maps (a) and (b).
  iday : int
    Time index used for the distribution and time-series panels.
  basename : str
    Output file-name prefix.
  N : int
    Running-mean window in years; the 99 % curves of panel (g) are only
    drawn when ``N > 10``.

  Notes
  -----
  Reads ``PDOindex.dat`` (columns ``Year``, ``Mar``, ``Apr``) and
  ``AMOindex.dat`` (columns ``Year``, ``month``, ``SSTA``) from the working
  directory, and uses the file-level ``running_mean``, ``fwhm`` and
  ``linestylesdict``.
  NOTE(review): the index series are cut to 1940..yrs[-1] and plotted against
  ``running_mean(yrs, N)``, so ``yrs`` presumably has to be the consecutive
  years 1940..yrs[-1] for the lengths to match -- confirm.
  """
  yrs_elnino = np.array([1941,1942,1958,1966,1973,1978,1980,1983,1987,1992,1995,1998,2003,2007,2010,2016])
  yrs = np.array(list(d_output.keys()))
  # Probability-mass fractions enclosed by the contours of panel (c).
  percentages=np.array([0.1, 0.5,0.9])

  is_elnino = np.isin(yrs, yrs_elnino)
  i_yrselnino = np.flatnonzero(is_elnino)
  i_yrsnoelnino = np.flatnonzero(~is_elnino)

  # Line style / transparency / width per percentage level (contours + legend).
  ls=['-','--','-.',linestylesdict['densely dashdotdotted'],':']
  alpha=[1, 0.8, 0.7, 0.6, 0.5]
  lw=[2, 2.5, 3, 3.5, 4]
  lonrange=[-60,60]
  # Include the closing edge so the latitude bins span the whole lonrange.
  latbins=np.arange(lonrange[0],lonrange[-1]+1,1)
  heightbins=np.arange(1,20000,500)
  latcen = latbins[:-1]+np.diff(latbins)/2.
  heightcen = heightbins[:-1]+np.diff(heightbins)/2.
  latnew,heightnew = np.meshgrid(latcen, heightcen)
  ls1='--'
  ls2='-'

  PDOindex = np.genfromtxt('PDOindex.dat',names=True)
  ww = np.where((PDOindex['Year']>=1940)&(PDOindex['Year']<=yrs[-1]))[0]
  PDOindex = PDOindex[ww]
  PDOmarapr = (np.array(PDOindex['Mar']) + np.array(PDOindex['Apr']))/2

  AMOindex = np.genfromtxt('AMOindex.dat',names=True)
  ww = np.where((AMOindex['Year']>=1940)&(AMOindex['Year']<=yrs[-1]))[0]
  AMOindex = AMOindex[ww]
  ww3 = np.where(AMOindex['month']==3)[0]
  ww4 = np.where(AMOindex['month']==4)[0]
  AMOmarapr = (np.array(AMOindex['SSTA'][ww3])+np.array(AMOindex['SSTA'][ww4]))/2.

  kleuren=['k', 'slateblue','darkorange','mediumseagreen']
  kleurenfill=['mediumslateblue','orange','mediumseagreen']

  kleuren2=['cornflowerblue','indigo','firebrick','mediumorchid', 'coral','navy']
  lab2=['99%','FWHM']

  zmin=np.array([0])
  zmax=np.array([17000])

  # 1x1 degree map grid; closing edges (180, 90) are included so that
  # histogram2d does not drop particles in the last column/row.
  grid_xlon = np.arange(-180,181,1)
  grid_ylat = np.arange(-90,91,1)

  lon2d_1, lat2d_1 = np.meshgrid((grid_xlon[1:]+grid_xlon[:-1])/2.,
    (grid_ylat[1:]+grid_ylat[:-1])/2.)
  # Cell area on the unit sphere: dlon * (sin(lat_north) - sin(lat_south)).
  grid_area = np.outer(np.diff(np.sin(np.deg2rad(grid_ylat))),
    np.diff(np.deg2rad(grid_xlon)))


  fig = plt.figure(figsize=(20,12))
  ax3=plt.subplot(111)
  ax3.axis('off')
  ax={}
  inner = gridspec.GridSpecFromSubplotSpec(102,150, subplot_spec=ax3, wspace=0, hspace=0)

  ax[0,0] = fig.add_subplot(inner[5:25,:30],projection=ccrs.PlateCarree(central_longitude=0))
  ax[0,1] = fig.add_subplot(inner[5:25,30:60],projection=ccrs.PlateCarree(central_longitude=0))
  ax[0,2] = fig.add_subplot(inner[30:80,:45])
  ax[0,3] = fig.add_subplot(inner[80:100,:45])
  ax[0,4] = fig.add_subplot(inner[30:80,45:60])

  ax[0,2].set_xticklabels([])
  ax[0,3].set_yticklabels([])
  ax[0,4].set_yticklabels([])

  ax[1,0] = fig.add_subplot(inner[:100,72:95])
  secax = fig.add_subplot(inner[:100,95:125])
  ax[1,1] = secax.twiny()
  ax[1,2] = fig.add_subplot(inner[:100,125:145])
  for i in range(3):
    if i<2:
      ax[0,i].set_yticklabels([])
    if i>0:
      ax[1,i].set_yticklabels([])


  # Map figures

  ax[0,0].coastlines()
  ax[0,0].set_title('                 1) Selecting particles crossing the equator\n              t=0                         t=20 days')
  ax[0,1].coastlines()

  # Private copy so set_under does not modify the globally registered colormap.
  map_cmap = copy.copy(plt.cm.Blues)
  map_cmap.set_under('white')
  map_yrs = yrs[0:t_interval]

  ii=0
  for i in [0,5]:
    grid_n = np.zeros_like(grid_area)
    for yr in map_yrs:
      print(yr,len(d_output[yr]['lat'][0]))
      lon1=d_output[yr]['lon'][:]
      lat1=d_output[yr]['lat'][:]
      lone1 = np.where(lon1>180., lon1-360. ,lon1)
      grid_n += np.histogram2d(lat1[:,i].flatten(order='C'),
        lone1[:,i].flatten(order='C'), bins=[grid_ylat, grid_xlon])[0]
    # Average over the years actually used (fewer than t_interval when
    # d_output is short); max() guards against an empty selection.
    grid_n = grid_n/max(len(map_yrs),1)

    res=grid_n/grid_area
    res[np.isinf(res)]=0.
    res[np.isnan(res)]=0.
    if np.max(res)==0:
      continue
    levels=np.linspace(np.min(res[res>0]),np.max(res),80)
    # extend='min' sends empty cells (below the lowest level) to the 'under' colour.
    ax[0,ii].contourf(lon2d_1,lat2d_1,res,transform=ccrs.PlateCarree(),extend='min',
      levels=levels,cmap=map_cmap)
    ii+=1

  tags2=np.array(['(a)','(b)'])
  for itag in range(2):
    ax[0,itag].text(0.99,.01,tags2[itag],horizontalalignment='right',verticalalignment='bottom',
      transform=ax[0,itag].transAxes,color='k',
      bbox=dict(facecolor='white',edgecolor='none',alpha=0.8,boxstyle='round',pad=0.01))

  # Per-year statistics of the latitude distribution at time index iday

  median_lat=np.zeros(len(yrs))
  max_lat=np.zeros(len(yrs))
  median_99=np.zeros(len(yrs))
  median_fwhm=np.zeros(len(yrs))
  median_m99=np.zeros(len(yrs))
  median_mfwhm=np.zeros(len(yrs))

  latheighthist=np.zeros((len(yrs),len(latbins)-1,len(heightbins)-1))
  lathist=np.zeros((len(yrs),len(latbins)-1))
  heighthist=np.zeros((len(yrs),len(heightbins)-1))

  for j, year in enumerate(yrs):
    lat = d_output[year]['lat'][:]
    z = d_output[year]['z'][:]
    ww=np.where((z[:,0]>zmin[0])&(z[:,0]<zmax[0])&(z[:,iday]!=0.))[0]
    print(year,len(ww), len(np.where(z[:,iday]==0.)[0]))
    lat = lat[ww,iday]
    # NOTE(review): NaN latitudes are mapped to -1 and therefore counted in
    # the bin just south of the equator; dropping them may be preferable.
    lat[np.isnan(lat)]=-1
    z_day = z[ww,iday]

    latheighthist[j,:,:] = np.histogram2d(lat,z_day,bins=[latbins,heightbins],density=True)[0]
    lathist[j,:] = np.histogram(lat,bins=latbins,density=True)[0]
    heighthist[j,:] = np.histogram(z_day,bins=heightbins,density=True)[0]

    # Median and 1 %/99 % tails.  With fewer than 100 particles n_tail is 0
    # and lat_sorted[-0] would be the minimum, so the upper tail falls back
    # to the maximum.
    lat_sorted = np.sort(lat)
    n_tail = len(lat)//100
    median_lat[j] = np.median(lat)
    median_99[j] = lat_sorted[-max(n_tail,1)]
    median_m99[j] = lat_sorted[n_tail]
    median_fwhm[j],median_mfwhm[j],max_lat[j] = fwhm(latcen,lathist[j,:])


  lstd=np.std(lathist,axis=0)
  lhist=np.mean(lathist,axis=0)
  hstd=np.std(heighthist,axis=0)
  hhist=np.mean(heighthist,axis=0)

  # Year-averaged lat-height PDF and the cumulative mass fraction of its
  # cells sorted from densest to sparsest (used for the contours of (c)).
  counts = np.mean(latheighthist,axis=0)
  counts_sorted = np.sort(np.ravel(counts))[::-1]
  lhhist = {}
  lhhist['counts'] = counts
  lhhist['counts_sorted'] = counts_sorted
  lhhist['fraction'] = np.cumsum(counts_sorted)/np.sum(counts_sorted)

  # Plot time series

  ax[1,0].set_title('                  3) Time series of particle distribution')
  tags=np.array(['(f)', '(g)', '(h)'])
  for itag in range(3):
    ax[1,itag].text(0.99,.99,tags[itag],horizontalalignment='right',verticalalignment='top',
      transform=ax[1,itag].transAxes,color='k',
      bbox=dict(facecolor='white',edgecolor='none',alpha=0.7,boxstyle='round',pad=0.01))

  ax[1,0].plot(median_lat, yrs, ls=ls1, lw=1, alpha=1, color=kleuren2[0])
  ax[1,0].plot(median_99, yrs, ls=ls1, lw=1, alpha=1, color=kleuren2[1])
  ax[1,0].plot(median_m99, yrs, ls=ls1, lw=1, alpha=1, color=kleuren2[3])
  ax[1,0].plot(median_fwhm, yrs, ls=ls1, lw=1, alpha=1, color=kleuren2[2])
  ax[1,0].plot(median_mfwhm, yrs, ls=ls1, lw=1, alpha=1, color=kleuren2[4])

  yrs_rm = running_mean(yrs,N)
  ax[1,0].plot(running_mean(median_lat,N), yrs_rm, lw=3, ls=ls2, color=kleuren2[0])
  ax[1,0].plot(running_mean(median_99,N), yrs_rm, lw=3, ls=ls2, color=kleuren2[1])
  ax[1,0].plot(running_mean(median_m99,N), yrs_rm, lw=3, ls=ls2, color=kleuren2[3])
  ax[1,0].plot(running_mean(median_fwhm,N), yrs_rm, lw=3, ls=ls2, color=kleuren2[2])
  ax[1,0].plot(running_mean(median_mfwhm,N), yrs_rm, lw=3, ls=ls2, color=kleuren2[4])
  for yr in yrs_elnino:
    ax[1,0].axhline(yr,color='lavender',lw=3, zorder=0)
  for zero_ax in (ax[1,2], ax[1,0], ax[0,2], ax[0,3], secax):
    zero_ax.axvline(0, color='lavender', lw=3,zorder=0)

  rm_temp = running_mean(median_lat,N)

  amokleur = (230/255,255/255,179/255)#'palegreen'
  amokleuredge = (51/255,102/255,0/255)
  PDOindex_rm = running_mean(PDOmarapr,N)
  AMOindex_rm = running_mean(AMOmarapr,N)
  secax.fill_betweenx(yrs_rm,PDOindex_rm,edgecolor='k',facecolor='lightgrey',zorder=1, ls=ls2,alpha=0.3)
  secax.fill_betweenx(yrs_rm,AMOindex_rm,edgecolor=amokleuredge,facecolor=amokleur,zorder=1, ls=ls2,alpha=0.3)
  # Mirror the left limit so the index panel is symmetric about zero.
  secax.set_xlim(xmax=-secax.get_xlim()[0])

  # Running means relative to their first window (panel g) and the
  # all-year means as vertical lines in the latitude PDF (panel d).
  print(rm_temp)
  ax[1,1].plot(rm_temp-rm_temp[0], yrs_rm, ls=ls2,lw=3, color=kleuren2[0])
  ax[0,3].axvline(np.mean(median_lat), ls=ls2, lw=3, color=kleuren2[0])

  m99_temp = running_mean(median_99,N)
  if N>10:
    ax[1,1].plot(m99_temp-m99_temp[0], yrs_rm, ls=ls2,lw=3, color=kleuren2[1])
  ax[0,3].axvline(np.mean(median_99), ls=ls2, lw=3, color=kleuren2[1])

  mm99_temp = running_mean(median_m99,N)
  if N>10:
    ax[1,1].plot(mm99_temp-mm99_temp[0], yrs_rm, ls=ls2,lw=3, color=kleuren2[3])
  ax[0,3].axvline(np.mean(median_m99), ls=ls2, lw=3, color=kleuren2[3])

  fwhm_temp = running_mean(median_fwhm,N)
  ax[1,1].plot(fwhm_temp-fwhm_temp[0], yrs_rm, ls=ls2,lw=3, color=kleuren2[2])
  ax[0,3].axvline(np.mean(median_fwhm), ls=ls2, lw=3, color=kleuren2[2])

  mfwhm_temp = running_mean(median_mfwhm,N)
  ax[1,1].plot(mfwhm_temp-mfwhm_temp[0], yrs_rm, ls=ls2, lw=3, color=kleuren2[4])
  ax[0,3].axvline(np.mean(median_mfwhm), ls=ls2, lw=3, color=kleuren2[4])

  # Panel (h): change of the NH-minus-SH spread of both bounds.
  fwhm_width = fwhm_temp-mfwhm_temp
  p99_width = m99_temp-mm99_temp
  ax[1,2].plot(fwhm_width-fwhm_width[0], yrs_rm, color=kleuren2[2], lw=3, ls=ls2)
  ax[1,2].fill_betweenx(yrs_rm,0,fwhm_width-fwhm_width[0], color=kleuren2[4],alpha=0.4)
  ax[1,2].plot(p99_width-p99_width[0], yrs_rm, color=kleuren2[1], lw=3, ls=ls2)
  ax[1,2].fill_betweenx(yrs_rm,0,p99_width-p99_width[0], color= kleuren2[3],alpha=0.4)

  ax[1,0].text(0,yrs[-2],'Median',va='top',ha='right', color=kleuren2[0], rotation='vertical')
  ax[1,0].set_xlim(lonrange)
  print('Delta99:', np.mean(median_99[i_yrsnoelnino]-median_m99[i_yrsnoelnino])-
    np.mean(median_99[i_yrselnino]-median_m99[i_yrselnino]))
  print('DeltaFWHM:', np.mean(median_fwhm[i_yrsnoelnino]-median_mfwhm[i_yrsnoelnino])-
    np.mean(median_fwhm[i_yrselnino]-median_mfwhm[i_yrselnino]))

  def to_axfrac(lat_value):
    # axhline xmin/xmax are axes fractions and ax[1,0] spans lonrange in x.
    return (lat_value-lonrange[0])/(lonrange[-1]-lonrange[0])

  # Horizontal bars illustrating the FWHM and 99 % ranges in one year each.
  i_fwhm=-int(len(yrs)/3)
  ax[1,0].axhline(yrs[i_fwhm],xmin=to_axfrac(0),
    xmax=to_axfrac(median_fwhm[i_fwhm]), color=kleuren2[2])
  ax[1,0].axhline(yrs[i_fwhm],xmin=to_axfrac(median_mfwhm[i_fwhm]),
    xmax=to_axfrac(0), color=kleuren2[4])
  print(yrs[i_fwhm])
  ax[1,0].text(0,yrs[i_fwhm]-0.5,'FWHM',va='top',ha='center', color=kleuren2[2], rotation='horizontal')
  ax[1,0].text(median_mfwhm[-2],yrs[-2],'FWHM SH',va='top',ha='right', color=kleuren2[4], rotation='vertical')
  ax[1,0].text(median_fwhm[-2],yrs[-2],'FWHM NH',va='top',ha='right', color=kleuren2[2], rotation='vertical')

  i_fwhm=-int(len(yrs)/2)
  # NOTE(review): the NH 99 % bar uses kleuren2[5] while its curve and label
  # use kleuren2[1]; confirm the colour choice is intended.
  ax[1,0].axhline(yrs[i_fwhm],xmin=to_axfrac(0),
    xmax=to_axfrac(median_99[i_fwhm]), color=kleuren2[5])
  ax[1,0].axhline(yrs[i_fwhm],xmin=to_axfrac(median_m99[i_fwhm]),
    xmax=to_axfrac(0), color=kleuren2[3])
  ax[1,0].text(0,yrs[i_fwhm]-0.5,'99%',va='top',ha='center', color=kleuren2[1], rotation='horizontal')
  ax[1,0].text(median_m99[-2],yrs[-2],'99% SH',va='top',ha='right', color=kleuren2[3], rotation='vertical')
  ax[1,0].text(median_99[-2],yrs[-2],'99% NH',va='top',ha='right', color=kleuren2[1], rotation='vertical')
  ax[1,0].set_ylim((yrs[0],yrs[-1]))
  ax[1,1].set_ylim((yrs[0],yrs[-1]))
  ax[1,2].set_ylim((yrs[0],yrs[-1]))
  ax[1,0].set_ylabel('Year')
  ax[1,0].set_xticks([-30, 0, 30])
  ax[1,0].set_xticklabels(labels=['30$^{\\circ}$ S','0$^{\\circ}$','30$^{\\circ}$ N'])
  ax[1,0].set_xlabel('Latitude')
  ax[1,1].set_xlabel('$\\Delta$ Lat ($\\degree$ N)')
  ax[1,2].set_xlabel('$\\Delta$ Lat ($\\degree$ N)')

  #Plot distribution

  tags=np.array(['(a1)', '(a2)', '(c)', '(d)', '(e)'])
  for itag in range(2,5):
    ax[0,itag].text(0.01,.01,tags[itag],horizontalalignment='left',verticalalignment='bottom',
      transform=ax[0,itag].transAxes,color='k',
      bbox=dict(facecolor='white',edgecolor='none',alpha=0.7,boxstyle='round',pad=0.01))
  icol=0

  for jj, pc in enumerate(percentages):
    # Density threshold: the smallest cell value among the densest cells
    # whose cumulative mass fraction is still below pc.
    below = np.where(lhhist['fraction'] < pc)[0]
    if len(below) == 0:
      # The densest cell alone already holds >= pc of the mass.
      minval = lhhist['counts_sorted'][0]
    else:
      minval = np.min(lhhist['counts_sorted'][below])
    ax[0,2].contour(latnew,heightnew,lhhist['counts'].transpose(),levels=[minval],
      colors=[kleuren[icol],],linestyles=[ls[jj],],linewidths=[lw[jj],],alpha=alpha[jj])

  ax[0,3].fill_between(latcen, lhist-lstd,lhist+lstd,
    color=kleurenfill[icol],alpha=0.2)
  ax[0,3].plot(latcen, lhist,
    color=kleuren[icol])

  ax[0,4].fill_betweenx(heightcen, hhist-hstd,hhist+hstd,
    color=kleurenfill[icol],alpha=0.2)
  ax[0,4].plot(hhist,heightcen,
    color=kleuren[icol])

  # Share the (autoscaled) height range of panel (c) with panel (e).
  ylim=ax[0,2].get_ylim()
  ax[0,3].set_ylim([0,0.03])
  ax[0,4].set_xlim([0,1e-4])
  ax[0,4].set_ylim(ylim)

  ax[0,3].set_xlabel('Latitude')
  ax[0,4].set_xlabel('PDF \n(10$^{-5}$)')

  # NOTE(review): this label belongs to ax[0,2] but is positioned with
  # ax[0,3].transAxes, i.e. at the top right of the latitude-PDF panel (d);
  # confirm this placement is intended.
  ax[0,2].text(0.97, 0.97, '$%i<z_{\\rm init}<%i$m' %(zmin[0],zmax[0]),
    transform=ax[0,3].transAxes, color='k', bbox=(dict(facecolor='white',alpha=0.7,edgecolor='white')),
    fontweight='extra bold', ha='right', va='top',zorder=99999)

  ax[0,3].set_xticks([-60,-30, 0, 30,60])
  ax[0,2].set_xticks([-30, 0, 30])
  ax[0,2].set_xlim(lonrange)
  ax[0,3].set_xticklabels(labels=['60$^{\\circ}$ S','30$^{\\circ}$ S','0$^{\\circ}$','30$^{\\circ}$ N','60$^{\\circ}$ N'])

  ax[0,3].set_yticks([0,0.02])
  ax[0,3].set_yticklabels(labels=['0','20'])

  ax[0,4].set_xticks([5e-5])
  ax[0,4].set_xticklabels(labels=['5'])

  ax[0,3].set_ylabel('PDF\n (10$^{-3}$)')

  # One legend entry per contour level, styled exactly like the contours.
  legendboxes1=[plt.Line2D([0],[0],ls=ls[k],lw=lw[k],alpha=alpha[k],color='k') for k in range(len(percentages))]
  legendhandles1=['%i %%' %(pc*100) for pc in percentages]

  ax[0,2].legend(legendboxes1,legendhandles1,
    loc='upper left',frameon=True)
  ax[0,2].set_title('2) Distribution of particles at t=20 days')
  # Latitude PDF axis symmetric about the equator.
  maxval = np.max(np.abs(lonrange))
  ax[0,3].set_xlim(-maxval, maxval)
  line1=plt.Line2D([0],[0],color='k',ls=ls2,lw=3)
  ax[1,1].legend([line1],['%i yr running mean' %N],loc=(0,0), frameon=False)


  patch1= mpl.patches.Patch(edgecolor='k',facecolor='lightgray',linestyle=ls2,alpha=0.5)
  patch2= mpl.patches.Patch(edgecolor=amokleuredge,facecolor=amokleur,linestyle=ls2,alpha=0.5)
  labPDO = 'PDO'
  labAMO = 'AMO'
  secax.legend([patch1,patch2],[labPDO,labAMO],loc=(0.55,0.08),
    ncol=1,frameon=False,fontsize=18)

  secax.xaxis.set_label_position('top')
  secax.xaxis.set_ticks_position('top')
  secax.tick_params(axis="x",direction="in", pad=-28,labelsize=22,zorder=1)

  ax[1,1].xaxis.set_label_position('bottom')
  ax[1,1].xaxis.set_ticks_position('bottom')

  legendboxes=[mpl.patches.Patch(edgecolor=kleuren2[k],linewidth=3,facecolor=kleuren2[k+2],
    alpha=0.5,linestyle=ls2) for k in [1,2]]
  legendhandles=list(lab2)

  # Two legends on panel (h): the second .legend() call replaces the first
  # as the axes legend, so the first is re-attached with add_artist.
  legend1 = ax[1,2].legend([legendboxes[1]],[legendhandles[1]], loc=(0,0), frameon=False)
  ax[1,2].legend([legendboxes[0]],[legendhandles[0]], loc=(0,0.935), frameon=False)
  ax[1,2].add_artist(legend1)

  ax[0,2].set_ylabel('Height (m)')

  plt.subplots_adjust(top=0.95, bottom=0.05, left=0.068, right=1.0, wspace=0.1, hspace=0)
  plt.savefig(basename+'manual.png')
  plt.savefig(basename+'manual.pdf')
  plt.close()