import numpy as np
import sys
import os
import h5py
import re
#import eccodes
#import cfgrib
import xarray as xr #cite as: Hoyer, S. & Hamman, J., (2017). xarray: N-D labeled Arrays and Datasets in Python. Journal of Open Research Software. 5(1), p.10. DOI: http://doi.org/10.5334/jors.148
import gc
import glob
import copy
#from sklearn.neighbors import BallTree
import matplotlib.pyplot as plt
import netCDF4 as nv
from scipy.interpolate import interp1d,interp2d,griddata
from scipy.interpolate import RectBivariateSpline as rbs
from scipy.interpolate import RegularGridInterpolator as rgi
from scipy.optimize import brentq, curve_fit, minimize_scalar, minimize, Bounds
from scipy.stats import binned_statistic_2d, binned_statistic_dd, binned_statistic
from sklearn.neighbors import BallTree
import pandas as pa
import multiprocessing as mp
from mpl_toolkits.axes_grid1 import make_axes_locatable

import cartopy.crs as ccrs
from matplotlib import rcParams

# Global matplotlib styling applied at import time: Computer-Modern mathtext,
# 20 pt base text (22 pt axis labels), unit-width axes/ticks and inward ticks.
rcParams.update({
  'ps.useafm': True,
  'pdf.use14corefonts': False,
  'text.usetex': False,
  'font.sans-serif': ['cmr10', 'Times-Roman'],
  'font.weight': 'normal',
  'font.size': 20,
  'xtick.labelsize': 20,
  'ytick.labelsize': 20,
  'axes.labelsize': 22,
  'axes.linewidth': 1,
  'axes.unicode_minus': False,
  'xtick.minor.width': 1,
  'ytick.minor.width': 1,
  'xtick.major.width': 1,
  'ytick.major.width': 1,
  'xtick.minor.size': 4,
  'ytick.minor.size': 4,
  'xtick.major.size': 5,
  'ytick.major.size': 5,
  'xtick.direction': 'in',
  'ytick.direction': 'in',
  'mathtext.fontset': 'cm',
})


import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as colors
from matplotlib import gridspec
from matplotlib import ticker
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.ticker import NullFormatter
from collections import OrderedDict, Counter
from matplotlib.lines import Line2D
from matplotlib.patches import Patch, Rectangle, Ellipse
import matplotlib.patheffects as pe
import pickle 


def _print_region_stats(label, deta, ec_eta, zpos, trop_height, ipbl=1,
                        z_strat=5e4, show_peak=False):
  """Print relative density-difference statistics for one latitude band.

  The relative difference ``deta / ec_eta`` is summarised over two layers:
  a "tropo" layer from level ``ipbl`` up to (excluding) the level closest to
  ``trop_height``, and a "strat" layer from that level up to (excluding) the
  level closest to ``z_strat``.  For each layer the ``zpos``-weighted mean and
  the maximum are printed in percent, followed by the minimum over all levels
  and the heights of the three bounding levels.

  Parameters
  ----------
  label : str
      Prefix of the printed lines, e.g. ``'poles'``.
  deta : ndarray
      Density difference on the ``zpos`` levels.
  ec_eta : ndarray
      Reference (ERA5) density on the same levels.
  zpos : ndarray
      Level heights in m; used as weights and to locate the layer bounds.
  trop_height : float
      Tropopause height in m.
  ipbl : int, optional
      First level index of the "tropo" layer.
  z_strat : float, optional
      Height in m closing the "strat" layer.
  show_peak : bool, optional
      Also print the relative difference at the "strat" level where the
      absolute difference ``deta`` is largest.

  Returns
  -------
  dict
      The printed values (in percent) plus the layer bound indices
      ``'itr'`` and ``'istr'``.
  """
  rel = deta / ec_eta
  itr = np.abs(zpos - trop_height).argmin()
  istr = np.abs(zpos - z_strat).argmin()
  tropo = slice(ipbl, itr)
  strato = slice(itr, istr)
  stats = {'itr': itr, 'istr': istr}

  stats['tropo'] = np.sum(rel[tropo] * zpos[tropo]) / np.sum(zpos[tropo]) * 100
  print(label + ' tropo:', stats['tropo'])
  stats['tropo_max'] = np.max(rel[tropo]) * 100
  print(label + ' tropo max:', stats['tropo_max'])

  stats['strat'] = np.sum(rel[strato] * zpos[strato]) / np.sum(zpos[strato]) * 100
  print(label + ' strat:', stats['strat'])
  stats['strat_max'] = np.max(rel[strato]) * 100
  print(label + ' strat max:', stats['strat_max'])
  if show_peak:
    # Level of the largest *absolute* difference, reported as relative change.
    stats['strat_peak'] = rel[strato][np.argmax(deta[strato])] * 100
    print(label + ' strat peak meter:', stats['strat_peak'])

  stats['total'] = np.min(rel) * 100
  print(label + ' total:', stats['total'])
  print('PBL', 'Tropo', 'Strato', zpos[ipbl], zpos[itr], zpos[istr])
  return stats


def plot_density_pickle(z=(0, 1e5), fname='density_profile',
                        data_dir='/home/lucie/LARA/data_density',
                        out_dir='/home/lucie/LARA'):
  """Plot period-averaged density profiles for three latitude bands.

  For every period the pickles ``density_<year>_7.pickle`` are read from
  ``data_dir``.  Entries are grouped by their year index within the period
  and averaged over all periods containing that index.  For the polar,
  mid-latitude and tropical bands the figure shows the density of the first
  and last year index next to the rho_part - rho_ERA5 difference of every
  year index, split into an upper (>= 100 m, log height) and a lower
  (0-100 m) row.  Relative-difference statistics are printed to stdout.

  Parameters
  ----------
  z : sequence of float, optional
      Only ``z[1]`` is used: top of the vertical grid in m.
  fname : str, optional
      Output file stem.
  data_dir : str, optional
      Directory holding the density pickles.
  out_dir : str, optional
      The PNG is written to ``<out_dir>/Density/<fname>.png`` and the PDF to
      ``<out_dir>/<fname>.pdf``.
  """
  both = True  # layout switch; False selects an alternative (unused) variant

  # Vertical grid from sea level: four fine levels below 100 m, then 100
  # log-spaced levels up to z[1]; profiles live on the layer mid-points.
  grid_z = np.concatenate([np.array([0., 25., 50., 75.]),
                           np.logspace(2, np.log10(z[1]), 100)])
  zpos = grid_z[:-1] + (grid_z[1:] - grid_z[:-1]) / 2.

  # Seven-year periods; the years 2003-2009 are not part of any period.
  periods = [np.arange(start, start + 7).astype(int)
             for start in (1940, 1947, 1954, 1961, 1968, 1975, 1982, 1989,
                           1996, 2010, 2017)]

  # density[i][yr] holds the object pickled for year index yr of period i.
  # pickle executes arbitrary code on load: only point data_dir at trusted files.
  density = {}
  for i, years in enumerate(periods):
    density[i] = {}
    for yr in range(len(years)):
      year = yr + years[0]
      print(year)
      path = os.path.join(data_dir, 'density_%i_%i.pickle' % (year, 7))
      with open(path, 'rb') as dens:
        density[i][yr] = pickle.load(dens)

  # Average every year index over the periods that contain it.
  n_years = max(len(years) for years in periods)
  density_py = {}
  for yr in range(n_years):
    density_py[yr] = {}
    n_periods = 0
    for i in range(len(periods)):
      if yr not in density[i]:
        continue
      n_periods += 1
      for key, value in density[i][yr].items():
        if key not in density_py[yr]:
          density_py[yr][key] = np.zeros_like(value)
        density_py[yr][key] = density_py[yr][key] + value
    for key in density_py[yr]:
      density_py[yr][key] = density_py[yr][key] / n_periods

  cmap = plt.cm.viridis_r
  fig, ax = plt.subplots(1, figsize=(16, 9), sharex=True, sharey=True)
  ax.axis('off')

  if both:
    istep_p1, istep_p2 = 8, 6
  else:
    istep_p1, istep_p2 = 9, 5
  # GridSpecFromSubplotSpec requires a SubplotSpec, not the Axes itself.
  inner = gridspec.GridSpecFromSubplotSpec(
      20, 45, subplot_spec=ax.get_subplotspec(), wspace=0, hspace=0)
  # Each band gets a profile column (istep_p1 wide) followed by a difference
  # column (istep_p2 wide); neighbouring bands are separated by one column.
  spans = []
  for k in range(3):
    c0 = k * (istep_p1 + istep_p2 + 1)
    spans.append((slice(c0, c0 + istep_p1),
                  slice(c0 + istep_p1, c0 + istep_p1 + istep_p2)))
  tops = [(fig.add_subplot(inner[0:18, ca]), fig.add_subplot(inner[0:18, cb]))
          for ca, cb in spans]
  bottoms = [(fig.add_subplot(inner[18:, ca]), fig.add_subplot(inner[18:, cb]))
             for ca, cb in spans]
  (ax1a, ax1b), (ax2a, ax2b), (ax3a, ax3b) = tops
  (ax1ac, ax1bc), (ax2ac, ax2bc), (ax3ac, ax3bc) = bottoms

  # Prefixes of the pickled keys, e.g. 'polestr', 'polespbl', 'poleseta',
  # 'ec_poleseta', mapped to (profile, difference, lower profile, lower diff).
  panels = {
      'poles': (ax1a, ax1b, ax1ac, ax1bc),
      'midlat': (ax2a, ax2b, ax2ac, ax2bc),
      'trop': (ax3a, ax3b, ax3ac, ax3bc),
  }

  tr_sum = dict.fromkeys(panels, 0)
  pbl_sum = dict.fromkeys(panels, 0)
  year_keys = list(density_py)
  colour = plt.cm.tab20b(3 / 20.)
  for yr in year_keys:
    dens_yr = density_py[yr]
    colour = cmap((yr + 1) / (n_years + 1))
    for reg, (ax_a, ax_b, ax_ac, ax_bc) in panels.items():
      tr_sum[reg] += dens_yr[reg + 'tr']
      pbl_sum[reg] += dens_yr[reg + 'pbl']
      eta = dens_yr[reg + 'eta']
      deta = eta - dens_yr['ec_' + reg + 'eta']
      # Absolute profiles are drawn for the first and last year index only.
      if yr in (year_keys[0], year_keys[-1]):
        ax_a.plot(eta, zpos, ls='-', lw=3, color=colour, zorder=1)
        ax_ac.plot(eta, zpos, ls='-', lw=3, color=colour, zorder=1)
      ax_b.plot(deta, zpos, color=colour, zorder=3)
      ax_bc.plot(deta, zpos, color=colour, zorder=3)

  # Layer heights are averaged over all year indices.
  n_avg = len(density_py)
  mean_tr = {reg: tr_sum[reg] / n_avg for reg in panels}
  mean_pbl = {reg: pbl_sum[reg] / n_avg for reg in panels}

  # NOTE(review): the printed statistics and the plotted ERA5 reference use
  # the LAST year index only, while the layer heights are averaged over all
  # year indices - confirm this is intended.
  last = density_py[year_keys[-1]]
  for reg, label, show_peak in (('midlat', 'midlat', False),
                                ('poles', 'poles', True),
                                ('trop', 'tropics', False)):
    ec_eta = last['ec_' + reg + 'eta']
    _print_region_stats(label, last[reg + 'eta'] - ec_eta, ec_eta, zpos,
                        mean_tr[reg], show_peak=show_peak)

  drange = 0.11  # half-width of the difference panels' x-range (kg m^-3)
  miny = 100.    # height (m) separating the lower and upper rows
  text_kw = dict(fontsize=23, color='lavender', rotation=0,
                 fontweight='extra bold', ha='left', va='top')
  for reg, (ax_a, ax_b, ax_ac, ax_bc) in panels.items():
    ec_eta = last['ec_' + reg + 'eta']
    ax_a.plot(ec_eta, zpos, ls='--', lw=1, color='k', zorder=3)
    ax_ac.plot(ec_eta, zpos, ls='--', lw=1, color='k', zorder=3)
    # Mean boundary-layer height (solid) and tropopause height (dashed).
    for axis in (ax_a, ax_b):
      axis.axhline(mean_pbl[reg], ls='-', lw=5, color='lavender', zorder=0)
    for axis in (ax_a, ax_b):
      axis.axhline(mean_tr[reg], ls='--', lw=5, color='lavender', zorder=0)
    ax_a.text(0, mean_pbl[reg] * 0.9, 'ABL', **text_kw)
    ax_a.text(0, mean_tr[reg] * 0.9, 'Tropopause', zorder=0, **text_kw)

    ax_b.set_xlim([-drange, drange])
    ax_b.axvline(0, ls='--', color='k', lw=1)
    ax_bc.axvline(0, ls='--', color='k', lw=1)
    # Copy x-limits only after all data of the upper panels has been plotted.
    ax_bc.set_xlim(ax_b.get_xlim())
    ax_ac.set_xlim(ax_a.get_xlim())

    ax_a.set_yscale('log')
    ax_b.set_yscale('log')
    ax_a.set_ylim([miny, 5e4])
    ax_b.set_ylim([miny, 5e4])
    ax_ac.set_ylim([0, miny])
    ax_bc.set_ylim([0, miny])

  for (_, ax_b, _, _), tag in zip(panels.values(), ('(a)', '(b)', '(c)')):
    ax_b.text(0.99, .99, tag, horizontalalignment='right',
              verticalalignment='top', transform=ax_b.transAxes, color='k',
              fontsize=22,
              bbox=dict(facecolor='white', edgecolor='none', alpha=0.7,
                        boxstyle='round', pad=0.01))

  band_labels = ('$|\\rm{lat}|>66^{\\circ}$',
                 '$66^{\\circ}>|\\rm{lat}|>23^{\\circ}$',
                 '$|\\rm{lat}|<23^{\\circ}$')
  for (ax_a, _, _, _), band in zip(panels.values(), band_labels):
    ax_a.text(0.9, .98, band, horizontalalignment='right',
              verticalalignment='top', transform=ax_a.transAxes, color='k')

  line = [Line2D([0], [0], ls='-', lw=3, color=cmap((yr + 1) / (n_years + 1)),
                 label='year %i' % (yr + 1))
          for yr in range(n_years)]
  line3 = [Line2D([0], [0], ls='--', lw=1, color='k',
                  label='$\\rho_{\\rm ERA5}$')] + line

  if both:
    ax3a.legend(handles=line3, handlelength=2, frameon=False,
                loc='lower left', fontsize=20)
  else:
    ax2a.legend(handles=line3, handlelength=2, frameon=False,
                loc='lower left', fontsize=18)
    kleur3a = plt.cm.tab20b(14 / 20.)
    patch1 = Patch(facecolor=colour, edgecolor=colour,
                   label='$\\Delta_{\\eta}>\\Delta_{z}$')
    patch2 = Patch(facecolor=kleur3a, edgecolor=kleur3a,
                   label='$\\Delta_{\\eta}<\\Delta_{z}$')
    ax3a.legend(handles=[patch1, patch2], frameon=False, loc='lower left',
                fontsize=18)

  # Tick set-up happens after all set_yscale calls (which reset formatters).
  for ax_a, ax_b, ax_ac, ax_bc in panels.values():
    ax_b.set_yticklabels([])
    ax_bc.set_yticks([0, 50])
    ax_bc.set_yticklabels([])
    ax_ac.set_yticks([0, 50])
    ax_a.set_xticklabels([])
    ax_b.set_xticklabels([])
    ax_ac.set_xlabel('Density (kg m$^{-3}$)')
  # Only the left-most column keeps height labels; ticks are fixed above
  # before their labels are set so matplotlib pairs them correctly.
  ax1ac.set_yticklabels(['0', '50'])
  for axis in (ax2a, ax3a, ax2ac, ax3ac):
    axis.set_yticklabels([])
  ax1a.set_ylabel('Height (m)')

  if both:
    diff_label = '$\\rho_{\\rm part}-\\rho_{\\rm ERA5}$ \n(kg m$^{-3}$)'
  else:
    diff_label = '$\\Delta_{\\eta}-\\Delta_{z}$'
  for axis in (ax1bc, ax2bc, ax3bc):
    axis.set_xlabel(diff_label, fontsize=18)
  fig.subplots_adjust(top=0.96, bottom=0.14, left=0.08, right=0.98,
                      wspace=0.15, hspace=0.02)

  fig.savefig(os.path.join(out_dir, 'Density', fname + '.png'))
  fig.savefig(os.path.join(out_dir, fname + '.pdf'))
  plt.close(fig)


def plot_yearly_change(z=(0, 1e5), fname='density_yr',
                       data_dir='/home/lucie/LARA/data_density',
                       out_dir='/home/lucie/LARA'):
  """Plot the rho - rho_ERA5 density difference for every year index.

  The pickles ``density_<year>_7.pickle`` of every period are read from
  ``data_dir``, grouped by year index within the period and averaged over
  the periods that contain that index.  One panel per latitude band
  (|lat|>66, 66>|lat|>23, |lat|<23 degrees) shows one difference profile per
  year index, coloured along ``viridis``.

  Parameters
  ----------
  z : sequence of float, optional
      Only ``z[1]`` is used: top of the vertical grid in m.
  fname : str, optional
      Output file stem.
  data_dir : str, optional
      Directory holding the density pickles.
  out_dir : str, optional
      The PNG is written to ``<out_dir>/Density/<fname>.png``.
  """
  # Vertical grid from sea level: four fine levels below 100 m, then 100
  # log-spaced levels up to z[1]; profiles live on the layer mid-points.
  grid_z = np.concatenate([np.array([0., 25., 50., 75.]),
                           np.logspace(2, np.log10(z[1]), 100)])
  zpos = grid_z[:-1] + (grid_z[1:] - grid_z[:-1]) / 2.

  # Period boundaries are irregular (e.g. 1944 belongs to no period) and are
  # kept as-is.  The periods 2014-2018 and 2021-2023 are currently excluded.
  periods = [np.arange(1940, 1944).astype(int),
             np.arange(1945, 1952).astype(int),
             np.arange(1952, 1959).astype(int),
             np.arange(1959, 1966).astype(int),
             np.arange(1966, 1974).astype(int),
             np.arange(1974, 1982).astype(int),
             np.arange(1982, 1990).astype(int),
             np.arange(1990, 1998).astype(int),
             np.arange(1998, 2006).astype(int),
             np.arange(2006, 2013).astype(int)]

  # density[i][yr] holds the object pickled for year index yr of period i.
  # pickle executes arbitrary code on load: only point data_dir at trusted files.
  density = {}
  for i, years in enumerate(periods):
    density[i] = {}
    for yr in range(len(years)):
      year = yr + years[0]
      print(year)
      path = os.path.join(data_dir, 'density_%i_%i.pickle' % (year, 7))
      with open(path, 'rb') as dens:
        density[i][yr] = pickle.load(dens)

  # Average every year index over the periods that contain it; using the
  # longest period guarantees each index appears in at least one period.
  n_years = max(len(years) for years in periods)
  density_py = {}
  for yr in range(n_years):
    density_py[yr] = {}
    n_periods = 0
    for i in range(len(periods)):
      if yr not in density[i]:
        continue
      n_periods += 1
      for key, value in density[i][yr].items():
        if key not in density_py[yr]:
          density_py[yr][key] = np.zeros_like(value)
        density_py[yr][key] = density_py[yr][key] + value
    for key in density_py[yr]:
      density_py[yr][key] = density_py[yr][key] / n_periods

  cmap = plt.cm.viridis
  # Spread the year indices over the full colour range [0, 1].
  cnorm = max(n_years - 1, 1)
  fig, ax = plt.subplots(1, 3, figsize=(16, 9), sharex=True, sharey=True)

  # Pickled keys are '<prefix>eta' and 'ec_<prefix>eta' per band.
  bands = ('poles', 'midlat', 'trop')
  for yr, dens_yr in density_py.items():
    for panel, reg in zip(ax, bands):
      deta = dens_yr[reg + 'eta'] - dens_yr['ec_' + reg + 'eta']
      panel.plot(deta, zpos, color=cmap(yr / cnorm), zorder=3)

  band_labels = ('$|\\rm{lat}|>66^{\\circ}$',
                 '$66^{\\circ}>|\\rm{lat}|>23^{\\circ}$',
                 '$|\\rm{lat}|<23^{\\circ}$')
  for panel, band in zip(ax, band_labels):
    panel.text(0.02, .98, band, horizontalalignment='left',
               verticalalignment='top', transform=panel.transAxes, color='k')

  for panel in ax:
    panel.set_ylim([100, 50000])
    panel.set_xlim([-0.1, 0.1])
    panel.set_xlabel('$\\rho-\\rho_{\\rm ERA5}$ \n(kg m$^{-3}$)', fontsize=18)
    panel.axvline(x=0, color='lavender', ls='-', lw=3, zorder=0)
  ax[0].set_ylabel('Height (m)')

  handles = [Line2D([0], [0], ls='-', lw=3, color=cmap(yr / cnorm),
                    label='year %i' % (yr + 1))
             for yr in range(n_years)]
  ax[2].legend(handles=handles, handlelength=2, frameon=False,
               loc='upper right', fontsize=20)

  plt.tight_layout()
  fig.savefig(os.path.join(out_dir, 'Density', fname + '.png'))
  plt.close(fig)
