import pandas as pd
import re
import cartopy.feature as cf
#from multiprocessing import Process
import pickle
from matplotlib.colors import ListedColormap
import multiprocessing as mp
import numpy as np
import sys
import os
import h5py
import re
# import eccodes
# import cfgrib
import xarray as xr #cite as: Hoyer, S. & Hamman, J., (2017). xarray: N-D labeled Arrays and Datasets in Python. Journal of Open Research Software. 5(1), p.10. DOI: http://doi.org/10.5334/jors.148
import gc
import glob
import copy
#from sklearn.neighbors import BallTree
import matplotlib.pyplot as plt
import netCDF4 as nv
from scipy.interpolate import interp1d,interp2d,griddata
from scipy.interpolate import RectBivariateSpline as rbs
from scipy.interpolate import RegularGridInterpolator as rgi
from scipy.optimize import brentq, curve_fit, minimize_scalar, minimize, Bounds
from scipy.stats import binned_statistic_2d, binned_statistic_dd, binned_statistic

import cartopy.crs as ccrs
from matplotlib import rcParams

# Global matplotlib styling for every figure produced by this script.
# RcParams.update assigns item by item (same validation as rcParams[key] = value),
# in the order listed here.
rcParams.update({
  # Font embedding / rendering backends
  'ps.useafm': True,
  'pdf.use14corefonts': False,
  'text.usetex': False,
  # Fonts and sizes
  'font.sans-serif': ['cmr10', 'Times-Roman'],
  'font.weight': 'normal',
  'font.size': 20,
  'xtick.labelsize': 22,
  'ytick.labelsize': 22,
  'axes.labelsize': 22,
  # Axes frame and tick geometry
  'axes.linewidth': 1,
  'axes.unicode_minus': False,
  'xtick.minor.width': 1,
  'ytick.minor.width': 1,
  'xtick.major.width': 1,
  'ytick.major.width': 1,
  'xtick.minor.size': 4,
  'ytick.minor.size': 4,
  'xtick.major.size': 5,
  'ytick.major.size': 5,
  'xtick.direction': 'in',
  'ytick.direction': 'in',
  # Math text rendered with Computer Modern
  'mathtext.fontset': 'cm',
})


import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as colors
from matplotlib import gridspec
from matplotlib import ticker
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.ticker import NullFormatter
from collections import OrderedDict, Counter
from matplotlib.lines import Line2D
from matplotlib.patches import Patch, Rectangle, Ellipse
import matplotlib.patheffects as pe


#----------Make it print------------------
class Unbuffered(object):
    """
    Proxy around a text stream that flushes after every write.

    ``write`` and ``writelines`` forward to the wrapped stream and flush it
    immediately; every other attribute lookup is delegated to the stream.
    """

    def __init__(self, stream):
        # Wrapped stream; __getattr__ below delegates to it.
        self.stream = stream

    def write(self, data):
        """Write *data* to the wrapped stream and flush it."""
        self.stream.write(data)
        self.stream.flush()

    def writelines(self, datas):
        """Write an iterable of strings to the wrapped stream and flush it."""
        self.stream.writelines(datas)
        self.stream.flush()

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. If 'stream' itself is
        # missing (e.g. an instance created through __new__ by copy/pickle),
        # delegating would call __getattr__('stream') again and recurse forever;
        # raise AttributeError instead so hasattr()/copy behave normally.
        if attr == 'stream':
            raise AttributeError(attr)
        return getattr(self.stream, attr)

# Route all print output through the flushing proxy so progress appears immediately.
sys.stdout = Unbuffered(stream=sys.stdout)


class PartInfo():
  """
  Accessor for particle-position ("partposit") data stored in a netCDF file.

  The file ``<path_to_directory>/<partposit_nc_fname>`` is opened with xarray
  and kept as ``self._pd``; the properties below expose its variables.
  Two variable-naming layouts are supported. When the file contains a
  ``topo`` variable, the long names are used (``longitude``, ``latitude``,
  ``height``, ``temperature``, ``topo``); otherwise the short names are used
  (``lon``, ``lat``, ``z``, ``T``, ``to``). The tropopause is read from ``tr``
  if present, else ``tro``; pressure from ``pr`` if present, else ``prs``.

  Fields in partposit files
  ---------------------------------
  xlon = longitude (degrees), exposed as ``lon``
  ylat = latitude (degrees), exposed as ``lat``
  z = height (meters)
  topo = topography
  tro = tropopause
  hmix = PBL height
  rho = density
  qv = specific humidity
  pv = potential vorticity
  tt = temperature
  uu, vv = not exposed by this class (presumably wind components -- TODO confirm)

  Gridded fields can be interpolated to the particle positions on first
  access and are then cached: topography in ``self.topography`` (only when
  ``interpolatetopo`` is True), tropopause in ``self.tropopause`` and PBL
  height in ``self.hmixing`` (only when the variable has 3 dimensions).
  """

  def __init__(self, path_to_directory="/home/lucie/SizeTesting/output_1e6_23",
    partposit_nc_fname = "partposit_average.nc",interpolatetopo=True):
    """
    Open the particle netCDF file and initialise the interpolation caches.

    path_to_directory : directory containing the netCDF file.
    partposit_nc_fname : file name inside that directory.
    interpolatetopo : if True, ``topo`` returns topography interpolated to
        the particle positions instead of the raw gridded field.
    """
    self.path = path_to_directory
    #Npart from the header is not accurate when domainfilling is used, so this will be set when data is read instead
    self.npart = None

    self.partposit_nc_fname = partposit_nc_fname

    #Open netcdf particledata using xarray and naming it self._pd
    self._read_partposit_nc()
    self.coords = None

    self.kdtree = {}

    #Caches for gridded fields interpolated to particle positions (filled lazily)
    self.topography = None
    self.tropopause = None
    self.hmixing = None
    self.interpolatetopo = interpolatetopo

  def _read_partposit_nc(self):
    """Open the netCDF file with xarray (netcdf4 engine) and store it as self._pd."""
    tempfile = self.path + '/' + self.partposit_nc_fname
    self._pd = xr.open_dataset(tempfile, engine="netcdf4")

  @property
  def lon(self):
    """Particle longitudes (``longitude`` or ``lon`` depending on the layout)."""
    if ('topo' in self._pd.variables):
      return self._pd.longitude
    else:
      return self._pd.lon

  @property
  def lon_av(self):
    return self._pd.longitude_av

  @property
  def lat(self):
    """Particle latitudes (``latitude`` or ``lat`` depending on the layout)."""
    if ('topo' in self._pd.variables):
      return self._pd.latitude
    else:
      return self._pd.lat

  @property
  def lat_av(self):
    return self._pd.latitude_av

  @property
  def z(self):
    """Particle heights (``height`` or ``z`` depending on the layout)."""
    if ('topo' in self._pd.variables):
      return self._pd.height
    else:
      return self._pd.z

  @property
  def z_av(self):
    return self._pd.height_av

  @property
  def topo(self):
    """
    Topography: interpolated to the particles if ``interpolatetopo`` is True
    (cached in ``self.topography``), otherwise the raw gridded variable.
    """
    name = 'topo' if ('topo' in self._pd.variables) else 'to'
    if not self.interpolatetopo:
      return getattr(self._pd, name)
    if self.topography is None:
      # Topography is treated as static: grid slice 0 is used for every output time.
      self.topography = self._interpolate_to_particles(
        getattr(self._pd, name), time_varying=False)
    return self.topography

  @property
  def tro(self):
    """
    Tropopause (``tr`` if present, else ``tro``). A 3-D variable is
    interpolated to the particles per output time and cached in
    ``self.tropopause``; otherwise the raw variable is returned.
    """
    field = self._pd.tr if ('tr' in self._pd.variables) else self._pd.tro
    if len(field.dims) != 3:
      return field
    if self.tropopause is None:
      self.tropopause = self._interpolate_to_particles(field, time_varying=True)
    return self.tropopause

  @property
  def hmix(self):
    """
    PBL height. A 3-D variable is interpolated to the particles per output
    time and cached in ``self.hmixing``; otherwise the raw variable is returned.
    """
    field = self._pd.hmix
    if len(field.dims) != 3:
      return field
    if self.hmixing is None:
      self.hmixing = self._interpolate_to_particles(field, time_varying=True)
    return self.hmixing

  @property
  def rho(self):
    return self._pd.rho

  @property
  def rho_av(self):
    return self._pd.rho_av

  @property
  def qv(self):
    return self._pd.qv

  @property
  def qv_av(self):
    return self._pd.qv_av

  @property
  def prs(self):
    """Pressure (``pr`` if present, else ``prs``)."""
    if ('pr' in self._pd.variables):
      return self._pd.pr
    else:
      return self._pd.prs

  @property
  def prs_av(self):
    return self._pd.pr_av

  @property
  def pv(self):
    return self._pd.pv

  @property
  def pv_av(self):
    return self._pd.pv_av

  @property
  def tt(self):
    """Temperature (``temperature`` or ``T`` depending on the layout)."""
    if ('topo' in self._pd.variables):
      return self._pd.temperature
    else:
      return self._pd.T

  @property
  def tt_av(self):
    return self._pd.temperature_av

  @property
  def mass(self):
    return self._pd.mass001

  @property
  def time(self):
    return np.array(self._pd.time)

  @property
  def particle(self):
    return np.array(self._pd.particle)

  @property
  def day(self):
    return self._pd.time.dt.day.values

  @property
  def month(self):
    return self._pd.time.dt.month.values

  @property
  def year(self):
    return self._pd.time.dt.year.values

  @property
  def hour(self):
    return self._pd.time.dt.hour.values

  @property
  def minute(self):
    return self._pd.time.dt.minute.values

  @property
  def second(self):
    return self._pd.time.dt.second.values

  @property
  def period(self):
    """Value of the first global attribute of the file (presumably the output period -- TODO confirm)."""
    return self._pd.attrs[list(self._pd.attrs.keys())[0]]

  def _interpolate_to_particles(self, field, time_varying):
    """
    Interpolate a gridded field to every particle position at every output time.

    field : gridded variable accepted by ``interpolate2d`` (2-D, or 3-D with
        time as the last axis).
    time_varying : if True, grid slice ``itime`` is used for output time
        ``itime``; otherwise slice 0 is used for all times and the spline is
        built only once.

    Returns an ndarray shaped ``(len(self.lon), len(self.lon[0, :]))``, with
    values limited to the [min, max] range of ``field`` (presumably to remove
    spline overshoot).
    """
    lon = self.lon
    lat = self.lat
    out = np.zeros((len(lon), len(lon[0, :])))
    mini = np.min(field).values
    maxi = np.max(field).values
    inter = None if time_varying else self.interpolate2d(field, 0)
    for itime in range(out.shape[1]):
      if time_varying:
        inter = self.interpolate2d(field, itime)
      out[:, itime] = inter(lat[:, itime], lon[:, itime], grid=False)
    out[out < mini] = mini
    out[out > maxi] = maxi
    return out

  def interpolate2d(self,field,itime):
    """
    Build a spline interpolator (scipy RectBivariateSpline, default cubic)
    for one 2-D slice of a gridded field.

    The grid is hard-coded as a global 0.5-degree grid: latitudes -90..90
    (361 points) and longitudes 0..360 (721 points), so the slice must have
    shape (361, 721). For a 3-D field the slice ``field[:, :, itime]`` is
    used; for a 2-D field ``itime`` is ignored.

    The returned spline is evaluated as ``spline(lat, lon, grid=False)``.
    """
    lon=np.arange(0.,360.5,0.5)
    lat=np.arange(-90,90.5,0.5)
    if len(field.dims)==2:
      info = field[:,:].values
    else:
      info = field[:,:,itime].values
    return rbs(lat,lon,info)


def convert_fields2load(pdata, fields_to_load):
  """
  Translate long-layout variable names to the short layout when needed.

  If the file behind ``pdata`` has a ``topo`` variable (long layout), the
  input list itself is returned unchanged. Otherwise a new list is returned
  where longitude/latitude/height/pr/temperature are mapped to
  lon/lat/z/prs/T; any other name is passed through as is.
  """
  if 'topo' in pdata._pd.variables:
    return fields_to_load
  short_names = {
    'longitude': 'lon',
    'latitude': 'lat',
    'height': 'z',
    'pr': 'prs',
    'temperature': 'T',
  }
  return [short_names.get(name, name) for name in fields_to_load]

def haversine(lon1,lat1,lon2,lat2,radius=6371.229):
  """
  Great-circle distance between points given in decimal degrees.

  Works element-wise on scalars or broadcastable numpy arrays.

  radius : sphere radius; the default 6371.229 is the Earth radius in km,
      so the result is in km unless a different radius/unit is passed.
  """
  # convert decimal degrees to radians
  lo1 = np.radians(lon1)
  lo2 = np.radians(lon2)
  la1 = np.radians(lat1)
  la2 = np.radians(lat2)

  # haversine formula
  dlon = lo2 - lo1
  dlat = la2 - la1
  a = np.sin(dlat/2)**2 + np.cos(la1) * np.cos(la2) * np.sin(dlon/2)**2
  # Rounding can push a marginally outside [0, 1] for (near-)antipodal
  # points, which would make sqrt/arcsin return NaN.
  a = np.clip(a, 0.0, 1.0)
  c = 2. * np.arcsin(np.sqrt(a))
  return c * radius

def select_nc_files(years_all=np.arange(1940,2024,1).astype(int),data_path="/home/lucie/LARA/data_WCB/"):
  """
  Open the seasonal WCB netCDF files for every requested year.

  For each year, ``data_path + "WCB_JJA_<year>.nc"`` and
  ``data_path + "WCB_DJF_<year>.nc"`` are opened read-only with netCDF4.
  ``data_path`` is concatenated directly, so it must end with a separator.

  Returns {year: {'JJA': Dataset, 'DJF': Dataset}}. The datasets are left
  open; closing them is up to the caller.
  """
  datasets = {}
  for yr in years_all:
    summer = nv.Dataset(data_path+"WCB_JJA_%i.nc" %(yr),"r")
    winter = nv.Dataset(data_path+"WCB_DJF_%i.nc" %(yr),"r")
    datasets[yr] = {'JJA': summer, 'DJF': winter}
  return datasets

def compute_surface_area(grid_xlon,grid_ylat):
  """
  Approximate the area [m^2] of every cell of a lon/lat grid.

  grid_xlon, grid_ylat : 1-D arrays of cell-edge longitudes / latitudes
      in degrees.

  Returns an array of shape (len(grid_ylat)-1, len(grid_xlon)-1). Each cell
  is treated as a rectangle: its meridional side is the great-circle distance
  between its two latitude edges, and its zonal side is the great-circle
  distance between its two longitude edges measured along the cell's
  higher-index latitude edge (grid_ylat[j+1]).
  """
  # Only the shape/dtype of this cell-centre mesh is used, for the output array.
  centre_lon, _ = np.meshgrid((grid_xlon[1:]+grid_xlon[:-1])/2.,
    (grid_ylat[1:]+grid_ylat[:-1])/2)
  area = np.zeros_like(centre_lon)

  edges_lon = np.copy(grid_xlon)
  edges_lat = np.copy(grid_ylat)
  lower_lat = edges_lat[:-1]
  upper_lat = edges_lat[1:]

  for col, (west, east) in enumerate(zip(edges_lon[:-1], edges_lon[1:])):
    side_ns = haversine(west, lower_lat, west, upper_lat)*1000.
    side_we = haversine(west, upper_lat, east, upper_lat)*1000.
    area[:, col] = side_ns*side_we
  return area

def kg_to_mm_water(water_g, surface_area_m, masspart=5.09256513E18/6.0e6):
  """
  Convert a particle-weighted water amount into an equivalent water depth [mm].

  water_g : water amount that is divided by 1000 (g -> kg) and scaled by
      ``masspart``; presumably grams of water per kg of particle mass
      (e.g. specific humidity in g/kg) -- TODO confirm with callers.
  surface_area_m : area [m^2] over which the water is spread.
  masspart : mass represented by one particle [kg]. The default keeps the
      previous hard-coded value (5.09256513E18/6.0e6); pass the value that
      matches the particle count of the simulation being analysed.

  Returns the water depth in mm, assuming 1 kg of water = 1 L.
  """
  w_m3 = water_g/1000.*masspart/1000. #gram to kg; L to m3
  w_mm = w_m3/surface_area_m*1000.    #m to mm
  return w_mm


def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """
    Return a new colormap built from the [minval, maxval] fraction of *cmap*.

    *cmap* is sampled at *n* evenly spaced points between minval and maxval,
    and the samples are turned into a LinearSegmentedColormap named
    'trunc(<name>,<minval>,<maxval>)'.
    """
    label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)
    samples = cmap(np.linspace(minval, maxval, n))
    return colors.LinearSegmentedColormap.from_list(label, samples)

def running_mean(x, N):
    """
    N-point moving average of *x* ('valid' part only).

    Uses a cumulative sum with a leading zero, so the result has
    len(x) - N + 1 elements (empty when N > len(x)).
    """
    padded = np.insert(x, 0, 0)
    totals = np.cumsum(padded)
    window_sums = totals[N:] - totals[:-N]
    return window_sums / float(N)

def _regrid_block_sum(field, nrows, ncols, res):
  """
  Aggregate a 2-D field onto a coarser grid by summing res x res blocks.

  With ``res == 1`` no aggregation is done and only the last row is dropped
  (``field[:-1, :]``). Otherwise an (nrows, ncols) array is returned whose
  element [i, j] is ``np.sum(field[i*res:(i+1)*res, j*res:(j+1)*res])``.
  """
  if res == 1:
    return field[:-1, :]
  out = np.zeros((nrows, ncols))
  for i in range(nrows):
    for j in range(ncols):
      out[i, j] = np.sum(field[i*res:(i+1)*res, j*res:(j+1)*res])
  return out


def plot_temporal_evolution_distribution(d_output,res=1,t_interval=42,
  basename='WCBtime_dist',N=20):
  """
  Plot the WCB time series and seasonal WCB maps for two consecutive periods.

  Panels:
    (a) yearly JJA and DJF WCB mass (dashed), their N-year running means
        (solid), long-term means (dotted) and markers for the years listed
        in ``yrs_elnino``;
    (b, c) mean JJA / DJF WCB mass per unit area over period 2;
    (d, e) period 2 minus period 1.
  Period k (k = 1, 2) covers ``yrs[(k-1)*t_interval : k*t_interval]``.

  d_output : {year: {'JJA': ds, 'DJF': ds}} where each dataset provides a
      2-D ``WCB_grid`` variable (e.g. as built by ``select_nc_files``).
  res : aggregation factor applied to the 0.5-degree data grid.
  t_interval : number of years per period. At least 2*t_interval years are
      required because the period labels index ``yrs[2*t_interval-1]``.
  basename : output path without extension; '.png' and '.pdf' are written.
  N : running-mean window in years.

  Raises ValueError if a period contains no years.
  """
  kleur1 = 'mediumseagreen'  # JJA colour
  kleur2 = 'mediumorchid'    # DJF colour

  grid_xlona = np.arange(-180.25,180.50,0.5)
  grid_ylata = np.arange(-90.25,90.50,0.5)
  grid_xlon = grid_xlona[np.arange(res/2,len(grid_xlona)-1,res).astype(int)]
  grid_ylat = grid_ylata[np.arange(res/2,len(grid_ylata)-1,res).astype(int)]
  nrows = len(grid_ylat)-1
  ncols = len(grid_xlon)-1

  surface_area = compute_surface_area(grid_xlon, grid_ylat)

  lon2d, lat2d = np.meshgrid((grid_xlon[1:]+grid_xlon[:-1])/2.,
    (grid_ylat[1:]+grid_ylat[:-1])/2)

  yrs = np.array(list(d_output.keys()))

  # Presumably total atmospheric mass [kg] / number of particles, i.e. the
  # mass carried by one particle -- TODO confirm.
  masspart = 5.09256513E18/5913999.
  # Threshold on the per-area WCB mass (units of grid*factor); weaker cells
  # are excluded from the totals and blanked in the difference maps.
  minmass=1

  # Hours per season: JJA = Jun(30)+Jul(31)+Aug(31), DJF = Dec(31)+Jan(31)+Feb(28).
  hours_jja = 24*(30+31+31)
  hours_djf = 24*(28+31+31)

  factorjja = masspart/surface_area/hours_jja #mass per surface area per hour [kg m-2 h-1]
  factordjf = masspart/surface_area/hours_djf

  fjja = masspart/hours_jja #mass per hour [kg h-1]
  fdjf = masspart/hours_djf

  fig = plt.figure(figsize=(18,15))
  ax3=plt.subplot(111,projection=ccrs.PlateCarree(central_longitude=0))
  ax3.axis('off')
  ax = {}
  inner = gridspec.GridSpecFromSubplotSpec(175,120, subplot_spec=ax3, wspace=0, hspace=0)
  ax_vert = fig.add_subplot(inner[100:102,15:105])   # colorbar of the absolute maps
  ax_vert2 = fig.add_subplot(inner[173:175,15:105])  # colorbar of the difference maps

  tags=np.array(['(a)', '(b)', '(c)', '(d)', '(e)'])
  ax[0,0] = fig.add_subplot(inner[0:36,3:])

  ax[0,0].text(0.01,.99,tags[0],horizontalalignment='left',verticalalignment='top',
    transform=ax[0,0].transAxes,color='k',
    bbox=dict(facecolor='white',edgecolor='none',alpha=0.7,boxstyle='round',pad=0.01))

  ax[1,0] = fig.add_subplot(inner[45:95,3:60],
    projection=ccrs.PlateCarree(central_longitude=0))
  ax[1,1] = fig.add_subplot(inner[45:95,-57:],
    projection=ccrs.PlateCarree(central_longitude=0))

  ax[2,0] = fig.add_subplot(inner[114:164,3:60],
    projection=ccrs.PlateCarree(central_longitude=0))
  ax[2,1] = fig.add_subplot(inner[114:164,-57:],
    projection=ccrs.PlateCarree(central_longitude=0))

  itag=1
  for i in range(1,3):
    for j in range(2):
      ax[i,j].set_xticklabels([])
      ax[i,j].coastlines(zorder=100)
      ax[i,j].set_global()
      ax[i,j].gridlines()

      ax[i,j].text(0.01,.99,tags[itag],horizontalalignment='left',verticalalignment='top',
        transform=ax[i,j].transAxes,color='k',
        bbox=dict(facecolor='white',edgecolor='none',alpha=0.7,boxstyle='round',pad=0.01))
      itag+=1

  #Plot time series
  years = list(d_output.keys())
  wcb_frac_jja = np.zeros(len(years))
  wcb_frac_djf = np.zeros(len(years))

  for i_year, year in enumerate(years):
    outtmpjja = _regrid_block_sum(d_output[year]['JJA']['WCB_grid'][:,:], nrows, ncols, res)
    outtmpdjf = _regrid_block_sum(d_output[year]['DJF']['WCB_grid'][:,:], nrows, ncols, res)
    # Hourly mass in cells whose per-area WCB mass exceeds minmass.
    wcb_frac_jja[i_year] = np.sum(outtmpjja[outtmpjja * factorjja>minmass])*fjja
    wcb_frac_djf[i_year] = np.sum(outtmpdjf[outtmpdjf * factordjf>minmass])*fdjf

  ax[0,0].plot(years,wcb_frac_jja,ls='--',color=kleur1,lw=3)
  ax[0,0].plot(running_mean(years,N),running_mean(wcb_frac_jja,N),
    color=kleur1, lw=3)
  ax[0,0].axhline(np.mean(wcb_frac_jja),color=kleur1,ls=':', lw=5,alpha=0.4)

  ax[0,0].plot(years,wcb_frac_djf,ls='--',color=kleur2,lw=3)
  ax[0,0].plot(running_mean(years,N),running_mean(wcb_frac_djf,N),
    color=kleur2, lw=3)
  ax[0,0].axhline(np.mean(wcb_frac_djf),color=kleur2,ls=':', lw=5,alpha=0.4)

  # Background markers for the hard-coded list of El Nino years.
  yrs_elnino = np.array([1941,1942,1958,1966,1973,1978,1980,1983,1987,1992,1995,1998,2003,2007,2010,2016])
  for yr in yrs_elnino:
    ax[0,0].axvline(yr,color='lavender',lw=3,zorder=0)

  # Proxy artists for the legend.
  line1=plt.Line2D([0],[0],color=kleur1,ls='--',lw=3)
  line2=plt.Line2D([0],[0],color=kleur2,ls='--',lw=3)
  line3=plt.Line2D([0],[0],color='k',lw=3)
  line4=plt.Line2D([0],[0],color='k',ls=':',lw=5,alpha=0.4)
  legends = ['JJA','DJF','%i year running mean' %N,'Mean total' ]
  ax[0,0].legend([line1,line2,line3,line4],legends,loc='lower right',ncol=2)

  ax[0,0].set_ylabel('Mass within WCB \n[kg h$^{-1}$]')
  ax[0,0].set_xlabel('Year')

  #Plot maps: per-period mean seasonal grids
  gridjja={}
  griddjf={}
  for ii in range(2):
    grida=None
    gridb=None
    nyears=0
    for i_yr in range(ii*t_interval,ii*t_interval+t_interval):
      if i_yr>=len(yrs):
        break
      # Roll by 360 columns: presumably re-centres a 0..360 longitude grid
      # on -180..180 to match grid_xlon -- TODO confirm.
      jja=np.roll(d_output[yrs[i_yr]]['JJA']['WCB_grid'][:,:],360,axis=1)
      djf=np.roll(d_output[yrs[i_yr]]['DJF']['WCB_grid'][:,:],360,axis=1)
      if grida is None:
        grida=jja
        gridb=djf
      else:
        grida=grida+jja
        gridb=gridb+djf
      nyears+=1

    if nyears==0:
      raise ValueError('No years available for period %i (t_interval=%i, %i years given)'
        %(ii+1, t_interval, len(yrs)))
    # Average over the years actually present in this period.
    grida=grida/nyears
    gridb=gridb/nyears

    gridjja[ii] = _regrid_block_sum(grida, nrows, ncols, res)
    griddjf[ii] = _regrid_block_sum(gridb, nrows, ncols, res)

  for i in range(1,3):
    ax[i,0].set_yticks([-60,-30, 0, 30,60], crs=ccrs.PlateCarree())
    ax[i,0].set_yticklabels(labels=['60$^{\\circ}$ S','30$^{\\circ}$ S','0$^{\\circ}$','30$^{\\circ}$ N','60$^{\\circ}$ N'])
    ax[i,0].set_ylabel('Latitude')

    ax[i,0].text(.5, .9, 'JJA',
      transform=ax[i,0].transAxes, color='k', bbox=(dict(facecolor='white',alpha=0.7,edgecolor='white')),
      position=(0.97, 0.97), fontweight='extra bold', ha='right', va='top',zorder=99999)
    ax[i,1].text(.5, .9, 'DJF',
      transform=ax[i,1].transAxes, color='k', bbox=(dict(facecolor='white',alpha=0.7,edgecolor='white')),
      position=(0.97, 0.97), fontweight='extra bold', ha='right', va='top',zorder=99999)

    for j in range(2):
      ax[i,j].set_xticks([-120,-60,0,60,120], crs=ccrs.PlateCarree())
      ax[i,j].set_xticklabels(labels=['120$^{\\circ}$ W','60$^{\\circ}$ W','0$^{\\circ}$','60$^{\\circ}$ E','120$^{\\circ}$ E'])
      if i>1:
        ax[i,j].set_xlabel('Longitude')

  ax[0,0].text(.5, .9, 'Period %i: %i-%i' %(1,yrs[0],yrs[t_interval-1]),
    transform=ax[0,0].transAxes, color='k',
    position=(0.052, 0.981), fontweight='extra bold', ha='left', va='top',zorder=99999)

  # Coloured strip at the top of panel (a) marking the two periods.
  ax[0,0].set_ylim([ax[0,0].get_ylim()[0]*0.97, ax[0,0].get_ylim()[1]])
  maxval = ax[0,0].get_ylim()[1]
  minval = ax[0,0].get_ylim()[1] - (ax[0,0].get_ylim()[1]-ax[0,0].get_ylim()[0])/10.
  ax[0,0].fill_betweenx(np.array([minval,maxval]),np.array([yrs[0],yrs[0]]),np.array([yrs[t_interval],yrs[t_interval]]),
    color='peachpuff',zorder=0)

  ax[0,0].fill_betweenx(np.array([minval,maxval]),np.array([yrs[t_interval],yrs[t_interval]]),np.array([yrs[2*t_interval-1],yrs[2*t_interval-1]]),
    color='lightsteelblue',zorder=0)

  ax[0,0].text(.5, .9, 'Period %i: %i-%i' %(2,yrs[t_interval],yrs[2*t_interval-1]),
    transform=ax[0,0].transAxes, color='k',
    position=(0.515, 0.981), fontweight='extra bold', ha='left', va='top',zorder=99999)

  levels=np.linspace(minmass,8,15)
  cmap = truncate_colormap(plt.cm.Greens, 0.1, 1)

  cs1=ax[1,0].contourf(lon2d,lat2d,gridjja[1]*factorjja,transform=ccrs.PlateCarree(),levels=levels,
    extend='both',cmap=cmap)
  ax[1,0].contour(lon2d,lat2d,gridjja[1]*factorjja,transform=ccrs.PlateCarree(),levels=levels,cmap=cmap)
  cs=ax[1,1].contourf(lon2d,lat2d,griddjf[1]*factordjf,transform=ccrs.PlateCarree(),extend='both',
    levels=levels,cmap=cmap)
  ax[1,1].contour(lon2d,lat2d,griddjf[1]*factordjf,transform=ccrs.PlateCarree(),
    levels=levels,cmap=cmap)

  cs1.cmap.set_under('white')
  cs1.cmap.set_over(cmap(0.99))
  cs.cmap.set_under('white')
  cs.cmap.set_over(cmap(0.99))

  # Copy so that set_under/set_over below do not modify the globally shared colormap.
  cmap2=copy.copy(plt.cm.seismic)
  levels2=np.linspace(-1.5,1.5,8)
  # Blank cells that are below the threshold in both periods.
  ww=np.where((gridjja[1]*factorjja<=minmass)&(gridjja[0]*factorjja<=minmass))
  gridjja[1][ww]=0
  gridjja[0][ww]=0
  ax[2,0].contourf(lon2d,lat2d,(gridjja[1]-gridjja[0])*factorjja,transform=ccrs.PlateCarree(),extend='both',
    levels=levels2,cmap=cmap2,alpha=0.8)
  ax[2,0].contour(lon2d,lat2d,(gridjja[1]-gridjja[0])*factorjja,transform=ccrs.PlateCarree(),extend='both',
    levels=levels2,cmap=cmap2,alpha=0.8)
  ww=np.where((griddjf[1]*factordjf<=minmass)&(griddjf[0]*factordjf<=minmass))
  griddjf[1][ww]=0
  griddjf[0][ww]=0
  cs2=ax[2,1].contourf(lon2d,lat2d,((griddjf[1]-griddjf[0])*factordjf),transform=ccrs.PlateCarree(),extend='both',
    levels=levels2,cmap=cmap2,alpha=0.8)
  ax[2,1].contour(lon2d,lat2d,((griddjf[1]-griddjf[0])*factordjf),transform=ccrs.PlateCarree(),extend='both',
    levels=levels2,cmap=cmap2,alpha=0.8)

  ax[1,0].text(.5, .9, 'Period 2', bbox=(dict(facecolor='lightsteelblue',alpha=0.9,edgecolor='white')),
    transform=ax[1,0].transAxes, color='k',
    position=(0.50, 0.98), fontweight='extra bold', ha='center', va='top',zorder=99999)
  ax[1,1].text(.5, .9, 'Period 2', bbox=(dict(facecolor='lightsteelblue',alpha=0.7,edgecolor='white')),
    transform=ax[1,1].transAxes, color='k',
    position=(0.50, 0.98), fontweight='extra bold', ha='center', va='top',zorder=99999)
  ax[2,0].text(.5, .9, 'Period 2 - Period 1', bbox=(dict(facecolor='white',alpha=0.7,edgecolor='white')),
    transform=ax[2,0].transAxes, color='k',
    position=(0.50, 0.98), fontweight='extra bold', ha='center', va='top',zorder=99999)
  ax[2,1].text(.5, .9, 'Period 2 - Period 1', bbox=(dict(facecolor='white',alpha=0.7,edgecolor='white')),
    transform=ax[2,1].transAxes, color='k',
    position=(0.50, 0.98), fontweight='extra bold', ha='center', va='top',zorder=99999)

  cb1=fig.colorbar(cs,cax=ax_vert,orientation='horizontal',extend='both',label='Air mass within WCB [kg m$^{-2}$ h$^{-1}$]')
  cb1.set_ticks(np.linspace(levels[0],levels[-1],6))
  cb1.update_ticks()
  ax_vert.xaxis.set_ticks_position('bottom')

  cb1.cmap.set_under('white')
  cb1.cmap.set_over(cmap(0.99))

  cb2=fig.colorbar(cs2,cax=ax_vert2,orientation='horizontal',extend='both',label='$\\Delta$WCB [kg m$^{-2}$ h$^{-1}$]')
  cb2.set_ticks(np.linspace(levels2[0],levels2[-1],6))
  cb2.update_ticks()
  ax_vert2.xaxis.set_ticks_position('bottom')
  cb2.cmap.set_under(cmap2(0.01))
  cb2.cmap.set_over(cmap2(0.99))

  plt.subplots_adjust(top=0.98, bottom=0.06, left=0.063, right=0.99, wspace=0.01, hspace=0)
  # The extension literals contain no format specifiers, so no '%' formatting here
  # (the previous "'.png' % (...)" raised TypeError before anything was saved).
  plt.savefig(basename+'.png')
  plt.savefig(basename+'.pdf')
  plt.close()
