import copy
import gc
import glob
import os
import pickle
import re
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pa
import xarray as xr #cite as: Hoyer, S. & Hamman, J., (2017). xarray: N-D labeled Arrays and Datasets in Python. Journal of Open Research Software. 5(1), p.10. DOI: http://doi.org/10.5334/jors.148
import zarr
from scipy.interpolate import interp1d,interp2d,griddata
from scipy.interpolate import RectBivariateSpline as rbs
from scipy.interpolate import RegularGridInterpolator as rgi
from scipy.stats import binned_statistic_2d, binned_statistic_dd, binned_statistic
"""
------------------------------------------

------------------------------------------

-----------Notes--------------------------------------------------------
"""

class Unbuffered(object):
  """Proxy around a writable stream that flushes after every write.

  Attributes that are not defined on the proxy itself are looked up on the
  wrapped stream, so the proxy can stand in for the stream it wraps.
  """

  def __init__(self, stream):
    # The wrapped stream; all output is forwarded to it.
    self.stream = stream

  def __getattr__(self, attr):
    # Only called when normal attribute lookup fails: delegate to the stream.
    return getattr(self.stream, attr)

  def write(self, data):
    """Write `data` to the wrapped stream and flush it straight away."""
    target = self.stream
    target.write(data)
    target.flush()

  def writelines(self, datas):
    """Write all items of `datas` to the wrapped stream, then flush it."""
    target = self.stream
    target.writelines(datas)
    target.flush()

# Route everything written to stdout through the flushing wrapper, so each
# write/print is flushed to the underlying stream immediately.
sys.stdout = Unbuffered(sys.stdout)

class PartInfo():
  """
  Lazy reader for particle output stored as zarr datasets.

  Every field is opened on first access and cached on the instance.

  Fields in partposit files
  ---------------------------------
  lon = longitude (degrees)
  lat = latitude (degrees)
  z = height (meters)
  topo = topography
  tro = tropopause
  hmix = PBL height
  lon_av, lat_av, z_av = stores of the same name (presumably averaged
                         positions -- TODO confirm)
  """

  def __init__(self, basepath="/jetfs/scratch/lbakels/LARA_zarr/",
    period="194001-194703",year="1940",month="01",
    interpolatetopo=True):
    """
    Parameters
    ----------
    basepath : str
        Root directory; the topography store is read from `basepath + '/to'`.
    period, year, month : str
        Sub-directories below `basepath` that hold the particle fields
        (`{basepath}/{period}/{year}/{month}`).
    interpolatetopo : bool
        If True, `topo` returns the topography interpolated to the particle
        positions; if False it returns the gridded topography itself.
    """
    self.path_topo = basepath
    self.path = f"{basepath}/{period}/{year}/{month}"

    #2D and 3D fields interpolation (if required)
    self.topography = None
    self.tropopause = None
    self.hmixing = None
    self.interpolatetopo = interpolatetopo

    # Caches filled by the properties below.
    self._lon = None
    self._lat = None
    self._z = None
    self._lon_av = None
    self._lat_av = None
    self._z_av = None
    self._topo = None
    self._tro = None
    self._hmix = None

  def _open(self, name, base=None):
    """Open the zarr store `name` below `base` (default: the particle path)."""
    root = self.path if base is None else base
    return xr.open_dataset(root+'/'+name,engine="zarr")

  @property
  def lon(self):
    """Variable `lon` of the `lon` store (opened on first access)."""
    if self._lon is None:
      self._lon = self._open('lon')
    return self._lon.lon

  @property
  def lon_av(self):
    """Variable `lon_av` of the `lon_av` store (opened on first access)."""
    if self._lon_av is None:
      self._lon_av = self._open('lon_av')
    return self._lon_av.lon_av

  @property
  def lat(self):
    """Variable `lat` of the `lat` store (opened on first access)."""
    if self._lat is None:
      self._lat = self._open('lat')
    return self._lat.lat

  @property
  def lat_av(self):
    """Variable `lat_av` of the `lat_av` store (opened on first access)."""
    if self._lat_av is None:
      self._lat_av = self._open('lat_av')
    return self._lat_av.lat_av

  @property
  def z(self):
    """Variable `z` of the `z` store (opened on first access)."""
    if self._z is None:
      self._z = self._open('z')
    return self._z.z

  @property
  def z_av(self):
    """Variable `z_av` of the `z_av` store (opened on first access)."""
    if self._z_av is None:
      self._z_av = self._open('z_av')
    return self._z_av.z_av

  @property
  def _topo1(self):
    """Gridded topography (variable `to`), opened lazily from `path_topo`."""
    if self._topo is None:
      self._topo = self._open('to',base=self.path_topo).to
    return self._topo

  @property
  def topo(self):
    """
    Topography.

    With `interpolatetopo` set, returns a numpy array with the same two
    leading dimensions as `lon`, holding the gridded topography interpolated
    to every particle position and limited to the range of the gridded
    field. The result is computed once and cached in `self.topography`.
    Otherwise the gridded topography is returned unchanged.
    """
    if self._topo is None:
      self._topo = self._topo1.load()
    if not self.interpolatetopo:
      return self._topo

    if self.topography is None:
      lon = self.lon
      lat = self.lat
      # The topography grid does not depend on time: build the spline and
      # its value range once instead of once per timestep.
      topo_inter = self.interpolate2d(self._topo,0)
      mini = float(np.min(self._topo))
      maxi = float(np.max(self._topo))
      #allocate memory for topography per particle
      topography = np.zeros((len(lon),len(lon[0,:])))
      #For each timestep, interpolate the position of the particle to the grid
      for itime in range(topography.shape[1]):
        column = topo_inter(lat[:,itime],lon[:,itime],grid=False)
        # The spline can overshoot; keep values inside the gridded range.
        column[column<mini] = mini
        column[column>maxi] = maxi
        topography[:,itime] = column
      # Cache only once complete, so a failure above cannot leave a
      # half-filled array behind.
      self.topography = topography
    return self.topography

  @property
  def tro(self):
    """Variable `tro` of the `tro` store, on its native grid.

    NOTE: interpolation to particle positions is not implemented;
    `self.tropopause` is an unused placeholder.
    """
    if self._tro is None:
      self._tro = self._open('tro').tro
    return self._tro

  @property
  def hmix(self):
    """Variable `hmix` of the `hmix` store, on its native grid.

    NOTE: interpolation to particle positions is not implemented;
    `self.hmixing` is an unused placeholder.
    """
    if self._hmix is None:
      self._hmix = self._open('hmix').hmix
    return self._hmix

  def interpolate2d(self,field,itime):
    """
    Build a bivariate spline of `field` on the fixed 0.5 degree grid.

    Parameters
    ----------
    field : array with `.dims` and `.values`
        Either 2-D (lat, lon) or 3-D with the time index last. The grid is
        hard-coded: 361 latitudes (-90..90) by 721 longitudes (0..360).
    itime : int
        Time index used for 3-D fields; ignored for 2-D fields.

    Returns
    -------
    scipy.interpolate.RectBivariateSpline
        Callable as `spline(lat, lon, grid=False)`.
    """
    lon=np.arange(0.,360.5,0.5)
    lat=np.arange(-90,90.5,0.5)
    if len(field.dims)==2:
      info = field[:,:].values
    else:
      info = field[:,:,itime].values
    return rbs(lat,lon,info)


def volume_calc(grid_xlon, grid_ylat, grid_z,topo=None):
  """
  Volume of every cell of a longitude/latitude/height grid on a sphere.

  Parameters
  ----------
  grid_xlon, grid_ylat : sequence of float
      Cell edges in degrees.
  grid_z : sequence of float
      Ascending cell edges in height, measured from a sphere of radius
      6371229 (presumably metres, Earth radius -- TODO confirm units).
  topo : 2-D array-like, optional
      Surface height per column, indexed `[ilat, ilon]`. It may be larger
      than the grid; only the leading `(nlat, nlon)` part is used. Cells
      entirely below the surface get zero volume and the cell containing
      the surface only counts the part above it. Columns whose surface lies
      below `grid_z[0]` (or is NaN) are left untouched.

  Returns
  -------
  numpy.ndarray
      Shape `(len(grid_ylat)-1, len(grid_xlon)-1, len(grid_z)-1)`.

  Raises
  ------
  ValueError
      If `topo` does not cover the grid.

  Notes
  -----
  Earlier versions integrated the surface cell up to `grid_z[ll+1]` when
  the nearest level was above the surface, which overlapped the next cell;
  the surface cell now always ends at its own upper edge. A surface at or
  above the top edge now yields an all-zero column instead of an
  IndexError.
  """
  radius = 6371229.
  grid_z = np.asarray(grid_z, dtype=float)
  nlat = len(grid_ylat)-1
  nlon = len(grid_xlon)-1
  nlev = len(grid_z)-1

  grid_xrad = np.deg2rad(grid_xlon)
  # Shift latitude to an angle measured from the south pole (0..pi).
  grid_yrad = np.deg2rad(grid_ylat)+0.5*np.pi

  # Radial part of the volume integral for every full layer.
  shell = 1./3*((radius+grid_z[1:])**3-(radius+grid_z[:-1])**3)
  zfactor = np.tile(shell, (nlat, nlon, 1))

  if topo is not None:
    surface = np.asarray(topo, dtype=float)
    if surface.ndim != 2 or surface.shape[0] < nlat or surface.shape[1] < nlon:
      raise ValueError('topo must be 2-D and cover at least (%i, %i) columns'
        % (nlat, nlon))
    surface = surface[:nlat, :nlon]

    # Index of the layer that contains the surface; -1 if the surface is
    # below the lowest edge, nlev if it is at or above the highest edge.
    cell = np.searchsorted(grid_z, surface, side='right')-1
    cell[np.isnan(surface)] = -1

    # Layers completely below the surface hold no air.
    zfactor[np.arange(nlev) < cell[..., np.newaxis]] = 0.

    # The layer cut by the surface only counts from the surface upwards.
    jj, ii = np.nonzero((cell >= 0) & (cell < nlev))
    kk = cell[jj, ii]
    zfactor[jj, ii, kk] = 1./3*((radius+grid_z[kk+1])**3-(radius+surface[jj, ii])**3)

  dlon = np.diff(grid_xrad)
  band = np.cos(grid_yrad[:-1])-np.cos(grid_yrad[1:])
  return dlon[np.newaxis, :, np.newaxis]*band[:, np.newaxis, np.newaxis]*zfactor

def ERA5_density(ERA5basepath,EAfile, akz,bkz,topo,output):
  """
  Add air density and level heights of one ERA5 GRIB file to `output`.

  Reads specific humidity (q), temperature (t), surface pressure (sp), 2 m
  temperature (2t) and 2 m dew point (2d) from `ERA5basepath + EAfile` and
  derives, per model level, the pressure `akz + bkz*sp`, the virtual
  temperature, the density `p/(r_air*tv)` and the height obtained by
  integrating upwards from `topo`.

  Parameters
  ----------
  ERA5basepath, EAfile : str
      Directory (with trailing separator) and file name of the GRIB file.
  akz, bkz : 1-D arrays
      Coefficients giving the pressure of each level as `akz + bkz*sp`.
  topo : 2-D array
      Topography; its last column is dropped and it is flipped along axis 0
      and rolled by 360 points along axis 1 to match the GRIB layout
      (presumably a 0.5 degree lat/lon grid -- TODO confirm).
  output : dict
      Must hold arrays under 'heights' and 'rho'; both are replaced by the
      old value plus this file's contribution, stored as
      (axis 1, axis 2, level) of the GRIB fields, flipped and rolled back.

  Returns
  -------
  None. The datasets are closed even if the computation fails.
  """
  # (key in `data`, GRIB shortName, variable name in the decoded dataset)
  grib_fields = (
    ('qv','q','q'),
    ('tth','t','t'),
    ('sp','sp','sp'),
    ('tt2','2t','t2m'),
    ('td2','2d','d2m'),
  )
  r_air=287.05
  ga=9.81
  const=r_air/ga

  data={}
  try:
    for key,short_name,var_name in grib_fields:
      data[key] = xr.open_dataset(ERA5basepath+EAfile,
        engine="cfgrib",backend_kwargs={'filter_by_keys':{'shortName':short_name},'indexpath':''})[var_name]

    sp=data['sp'].values
    td2=data['td2'].values
    tt2=data['tt2'].values
    tth=data['tth'].values
    qv=data['qv'].values

    # Pressure on every level. Kept as a per-level loop so the scalar
    # coefficients combine with `sp` exactly as before.
    prs = np.zeros((len(akz),len(sp),len(sp[0])))
    for iz in range(len(akz)):
      prs[iz,:,:] = akz[iz]+bkz[iz]*sp

    #Flip z axis
    prs=np.flip(prs,axis=0)
    tth=np.flip(tth,axis=0)
    qv=np.flip(qv,axis=0)

    # Virtual temperature; index 0 after the flip is overwritten with the
    # value derived from the 2 m temperature and dew point.
    tv=tth*(1.+0.608*qv)
    tv[0,:,:]=tt2*(1.+0.378*ew(td2)/sp)

    #Density
    rhotmp = prs/(r_air*tv)

    # Heights: start at the topography and integrate upwards level by level.
    heightstmp=np.zeros_like(tv)
    heightstmp[0,:,:] = np.roll(np.flip(topo[:,:-1],axis=0),360,axis=1)
    for iz in range(1,len(heightstmp)):
      # Use the log-mean virtual temperature of the layer unless the two
      # levels are nearly isothermal (the log-mean is then ill-conditioned).
      heightstmp[iz,:,:] = np.where(abs(tv[iz,:,:]-tv[iz-1])>0.2,
        heightstmp[iz-1,:,:] + const*np.log(prs[iz-1,:,:]/prs[iz,:,:]) *
        (tv[iz,:,:]-tv[iz-1,:,:])/np.log(tv[iz,:,:]/tv[iz-1,:,:]),
        heightstmp[iz-1,:,:] + const*np.log(prs[iz-1,:,:]/prs[iz,:,:])*tv[iz,:,:])

    output['heights'] = output['heights']+np.roll(np.flip(np.transpose(heightstmp,(1,2,0)),axis=0),360,axis=1)
    output['rho'] = output['rho']+np.roll(np.flip(np.transpose(rhotmp,(1,2,0)),axis=0),360,axis=1)
  finally:
    # Release the GRIB handles whether or not the computation succeeded.
    for field in data.values():
      field.close()
    data.clear()
    gc.collect()
  print('done')

def ew(td):
  """
  Saturation vapour pressure over water for temperature `td`.

  Evaluates the Goff-Gratch formula with a reference temperature of
  373.16 and a reference pressure of 101324.6, so the result is in Pa for
  `td` in Kelvin. Works element-wise on numpy arrays.
  """
  ratio = 373.16/td

  # Linear and logarithmic terms (0.43429*ln(x) approximates log10(x)).
  base = -7.90298*(ratio-1.) + (5.02808*0.43429*np.log(ratio))

  # First correction term.
  growth = -1.+(10.**((1.-(1./ratio))*11.344))
  corr1 = -1.3816*growth/(10.**7)

  # Second correction term.
  decay = -1.+(10.**((1.-ratio)*3.49149))
  corr2 = 8.1328*decay/(10.**3)

  exponent = base+corr1+corr2
  return 101324.6*(10.**exponent)


def compute_masked_arrays(partdata,ecdata,lat1,lat2,ndata,gridvol,tropo,hmix,topo,topogrid):
  '''
  Volume-weighted averages over the rows `lat1:lat2` (first axis).

  Returned, in this order:
  avtr       - weighted mean of `tropo`, ignoring entries equal to 0
  avpbl      - weighted mean of `hmix`, ignoring entries equal to 0
  part_array - per level (last axis) weighted mean of `partdata`,
               ignoring NaN and infinite cells
  ec_array   - per level weighted mean of `ecdata`, ignoring cells equal to 0
  avtopo     - weighted mean of `topo`, ignoring entries equal to 0

  3-D fields are weighted with `gridvol`; the 1-D fields with `gridvol`
  summed over its last two axes. Cells or rows whose topography
  (`topogrid` / `topo`) exceeds 1e10 are ignored as well, which with
  realistic values never triggers. `ndata` is accepted but not used.
  '''
  mintop=1e10
  rows = slice(lat1,lat2)

  def weighted_mean(hidden,values,weights,axis=None):
    # Mask the same cells in the weighted values and in the weights, then
    # divide the two sums.
    numerator = np.ma.masked_where(hidden,values*weights)
    denominator = np.ma.masked_where(hidden,weights)
    return numerator.sum(axis=axis)/denominator.sum(axis=axis)

  part = partdata[rows]
  ec = ecdata[rows]
  vol = gridvol[rows]

  # Topography repeated along the level axis, in the shape of partdata.
  topo3d = np.ones_like(partdata)*np.asarray(topogrid)[:,:,np.newaxis]
  too_high = topo3d[rows]>mintop

  part_array = weighted_mean(~np.isfinite(part) | too_high,part,vol,axis=(0,1))
  ec_array = weighted_mean((ec==0) | too_high,ec,vol,axis=(0,1))

  #Compute hmix and tropopause
  colvol = np.sum(gridvol,axis=(1,2))[rows]
  surface = topo[rows]
  row_too_high = surface>mintop

  tr = tropo[rows]
  avtr = weighted_mean((tr==0) | row_too_high,tr,colvol)

  pbl = hmix[rows]
  avpbl = weighted_mean((pbl==0) | row_too_high,pbl,colvol)

  avtopo = weighted_mean((surface==0) | row_too_high,surface,colvol)

  return avtr,avpbl,part_array,ec_array,avtopo

def compute_density( path_to_LARA='/jetfs/scratch/lbakels/LARA/', 
  z=(0,1e5),period="194001-194703",year="1940",month="01",):
  """
  Compare the air density implied by the particle distribution with ERA5.

  The particles of one month are binned on a 0.5 degree lat/lon grid with
  104 height edges up to `z[1]`, converted to a density with the cell
  volumes from `volume_calc`, and averaged per height level over three
  latitude bands (poles, mid-latitudes, tropics) together with the mean
  ERA5 density of the same month. The averages are written to
  `density_<year>_<month>.pickle` in the current directory.

  Parameters
  ----------
  path_to_LARA : str
      Base path handed to `PartInfo`.
  z : sequence of two floats
      Only `z[1]`, the top of the height grid, is used.
  period, year, month : str
      Select the particle data (see `PartInfo`). `year` and `month` must be
      numeric strings; they also select the ERA5 directory and name the
      output file.

  Raises
  ------
  FileNotFoundError
      If the ERA5 directory holds no file starting with 'EA'.

  Notes
  -----
  Reads `table.csv` (columns 'a [Pa]' and 'b') from the current directory.
  """
  part = PartInfo(basepath=path_to_LARA,period=period,year=year,month=month)
  # (year, month) as integers for the ERA5 path and the output file name.
  start_time = (int(year),int(month))

  #grid from sea level
  # To get the correct value on the edges of the bins, we need to
  # substract 0.25 of the bins! The edges of the grid_vol are zero and are excluded from computations.
  grid_ylat=np.arange(-90.25,90.26,0.5)
  grid_xlon=np.arange(-0.25,360,0.5)
  grid_z = np.concatenate([np.array([0.,25.,50.,75.]),np.logspace(2,np.log10(z[1]),100)])

  # Mid-height of every layer.
  zpos=grid_z[:-1]+(grid_z[1:]-grid_z[:-1])/2.

  hmix = np.mean(part.hmix.values,axis=(1,2)) # double value at first and final lon (not a problem)
  part.hmix.close()
  print('close hmix')
  tropo = np.mean(part.tro.values,axis=(1,2))
  part.tro.close()
  print('close tro')
  #Mass per particle
  numpart = len(part.lon.particle)
  # 5.09256513E18 is presumably the total mass of the atmosphere in kg --
  # TODO confirm.
  massparteta = 5.09256513E18/numpart

  print('load topo')
  topo = part._topo1.values
  print('finished loading topo')

  topo2=np.mean(topo,axis=(1))

  grid_vol=volume_calc(grid_xlon,grid_ylat,grid_z,topo)

  #Sum up particles within range
  print('loading lon and lat')
  # Load the positions up front; they are sliced once per timestep when the
  # topography is interpolated below.
  part.lon.load()
  part.lat.load()
  lon1eta = part.lon.values.flatten(order='C')
  lat1eta = part.lat.values.flatten(order='C')
  print('loaded lon and lat')
  print('compute topo')
  topoeta = part.topo[:,:].flatten(order='C')
  print('computed topo')
  # Particle height plus local topography.
  z1eta = part.z.values.flatten(order='C')+topoeta
  print('loaded z')
  #Bin the results according to specified grid
  print('Bin results')
  # Number of particles per grid cell.
  grid_n1eta,edges = np.histogramdd([lat1eta,lon1eta,z1eta],
    bins=[grid_ylat,grid_xlon,grid_z])

  part.lon.close()
  part.lat.close()
  part.z.close()
  # Free the large particle arrays before the ERA5 fields are loaded.
  del part, lon1eta,lat1eta,topoeta,z1eta
  gc.collect()

  #From particle distribution
  #NaN below topography
  print('Compute density')
  rho_n1eta = grid_n1eta/grid_vol*massparteta
  print("Finished zarr file")

  #Average difference between era5 and particle distribution

  # Loading ERA-5 data
  ERA5basepath="/jetfs/shared-data/ECMWF/ERA5_glob_0.5deg_1h/%i/%02d/" %(start_time[0],start_time[1])
  EAfiles = sorted(ff for ff in os.listdir(ERA5basepath) if ff.startswith('EA'))
  if not EAfiles:
    raise FileNotFoundError('no ERA5 files starting with "EA" in '+ERA5basepath)

  # Computing model pressure levels
  df = pa.read_csv('table.csv')
  akm = df['a [Pa]'].values
  bkm = df['b'].values

  # Average neighbouring coefficients to get one value per layer.
  akz=0.5*(akm[1:]+akm[:-1])
  bkz=0.5*(bkm[1:]+bkm[:-1])

  output={}
  output['heights']=np.zeros((len(topo),len(topo[0])-1,137))
  output['rho'] =np.zeros((len(topo),len(topo[0])-1,137))
  for EAfile in EAfiles:
    ERA5_density(ERA5basepath,EAfile,akz,bkz,topo,output)
  # Mean over all files of the month.
  heights=output['heights']/len(EAfiles)
  rho=output['rho']/len(EAfiles)

  #Interpolating densities to fit the zpos levels
  rho2 = np.zeros_like(grid_vol)
  for jy in range(len(grid_ylat)-1):
    for ix in range(len(grid_xlon)-1):
      inter_tmp = interp1d(heights[jy,ix,:],rho[jy,ix,:])
      lmin = np.abs(zpos-heights[jy,ix,0]).argmin()
      lmax = np.abs(zpos-heights[jy,ix,-1]).argmin()-1
      # NOTE(review): `heights` is indexed here with positions in `zpos`;
      # comparing against heights[jy,ix,0] / heights[jy,ix,-1] looks like
      # what was meant. Kept as is: it only narrows the range further.
      if heights[jy,ix,lmin]>zpos[lmin]:
        lmin=lmin+1
      if heights[jy,ix,lmax]<zpos[lmax]:
        lmax=lmax-1
      rho2[jy,ix,lmin:lmax] = inter_tmp(zpos[lmin:lmax])

  def polar(field):
    # Rows of both polar caps stacked along the first axis.
    return np.concatenate((field[0:48],field[312:361]),axis=0)

  def band(lat1,lat2):
    # Averages over the rows lat1:lat2 of the full grid.
    return compute_masked_arrays(rho_n1eta,rho2,lat1,lat2,grid_n1eta,
      grid_vol,tropo,hmix,topo2,topo[:,:-1])

  pickled_data = {}

  # #Poles
  # NOTE(review): the row range 0:-1 leaves out the last stacked row --
  # confirm this is intended.
  polestr,polespbl,poleseta,ec_poleseta,polestopo=compute_masked_arrays(
    polar(rho_n1eta),
    polar(rho2),
    0,-1,
    polar(grid_n1eta),
    polar(grid_vol),
    polar(tropo),
    polar(hmix),
    polar(topo2),
    polar(topo[:,:-1])
    )
  pickled_data['polestr']=polestr
  pickled_data['polespbl']=polespbl
  pickled_data['poleseta']=poleseta
  pickled_data['ec_poleseta']=ec_poleseta
  pickled_data['polestopo']=polestopo

  # Midlatitudes
  # NOTE(review): the densities are the mean of both bands, but the
  # tropopause, PBL and topography values are taken from the second band
  # (rows 226:312) only -- confirm this is intended.
  _,_,tmp1,ec_tmp1,_=band(48,134)
  midlattr,midlatpbl,tmp2,ec_tmp2,midlattopo=band(226,312)
  midlateta=(tmp1+tmp2)/2.
  ec_midlateta=(ec_tmp1+ec_tmp2)/2.

  pickled_data['midlattr']=midlattr
  pickled_data['midlatpbl']=midlatpbl
  pickled_data['midlateta']=midlateta
  pickled_data['ec_midlateta']=ec_midlateta
  pickled_data['midlattopo']=midlattopo

  #Tropics
  troptr,troppbl,tropeta,ec_tropeta,troptopo=band(134,226)

  pickled_data['troptr']=troptr
  pickled_data['troppbl']=troppbl
  pickled_data['tropeta']=tropeta
  pickled_data['ec_tropeta']=ec_tropeta
  pickled_data['troptopo']=troptopo

  # Write only once everything is computed; the file is closed even if
  # pickling fails.
  with open('density_%i_%i.pickle' %(start_time[0],start_time[1]),'wb') as datafile:
    pickle.dump(pickled_data,datafile)

