import pandas as pd
import re
#from multiprocessing import Process
import pickle
import numpy as np
import sys
import os
import re
# import eccodes
# import cfgrib
import xarray as xr #cite as: Hoyer, S. & Hamman, J., (2017). xarray: N-D labeled Arrays and Datasets in Python. Journal of Open Research Software. 5(1), p.10. DOI: http://doi.org/10.5334/jors.148
import gc
import glob
import copy
#from sklearn.neighbors import BallTree
import netCDF4 as nv
from scipy.interpolate import interp1d,interp2d,griddata
from scipy.interpolate import RectBivariateSpline as rbs
from scipy.interpolate import RegularGridInterpolator as rgi
from scipy.optimize import brentq, curve_fit, minimize_scalar, minimize, Bounds
from scipy.stats import binned_statistic_2d, binned_statistic_dd, binned_statistic
from sklearn.neighbors import BallTree

#---------- Unbuffered stdout: flush after every write ----------
class Unbuffered:
    """Proxy around a text stream that flushes after each write.

    Any attribute not defined on the proxy (``fileno``, ``encoding``, ...)
    is looked up on the wrapped stream instead.
    """

    def __init__(self, stream):
        self.stream = stream

    def write(self, data):
        out = self.stream
        out.write(data)
        out.flush()

    def writelines(self, datas):
        out = self.stream
        out.writelines(datas)
        out.flush()

    def __getattr__(self, attr):
        # Python only calls __getattr__ when normal lookup fails, so
        # write/writelines/stream are resolved on the proxy itself.
        return getattr(self.stream, attr)

# Replace stdout with a self-flushing wrapper so every print is written out
# immediately. NOTE: this happens as a side effect of importing this module.
sys.stdout = Unbuffered(sys.stdout)


class PartInfo():
  """
  Accessor for one month of LARA particle fields stored as zarr stores.

  Each exposed field is read from ``{basepath}/{year}/{month}/<field>`` with
  ``xarray.open_dataset(..., engine="zarr")`` and the opened Dataset is cached
  on the instance. Note that ``__init__`` fills ``self.ds`` by touching every
  property, so all four stores (lon, lat, z, prs) are opened at construction.

  Fields in partposit files
  ---------------------------------
  xlon = longitude (degrees)
  ylat = latitude (degrees)
  z = height (meters)
  topo = topography
  tro = tropopause
  hmix = PBL height
  rho = density
  qv = specific humidity
  pv = potential vorticity
  tt = temperature
  uu, vv = horizontal wind components (presumably u/v; not read by this class)

  Only ``lon``, ``lat``, ``z``, ``prs`` (and ``time``, taken from the lon
  store) are exposed here.

  Parameters
  ----------
  basepath : str
      Root directory containing ``<year>/<month>`` sub-directories.
  year : str
      Year directory name, e.g. "1940".
  month : str
      Month directory name, zero-padded, e.g. "01".
  interpolatetopo : bool
      Currently unused; kept for backward compatibility.
  """

  def __init__(self, basepath="/jetfs/scratch/lbakels/LARA_zarr/194001-194703/",
    year="1940",month="01",interpolatetopo=True):

    self.path_topo = basepath
    self.path = f"{basepath}/{year}/{month}"

    # Cached xarray Datasets, filled on first access of the matching property.
    self._lon = None
    self._lat = None
    self._z = None
    self._prs = None

    # Accessing the properties here opens all four stores immediately.
    self.ds = {name: getattr(self, name) for name in ('lon', 'lat', 'z', 'prs')}

  def _dataset(self, field):
    """Return the cached zarr Dataset for ``field``, opening it on first use."""
    cache_attr = f"_{field}"
    if getattr(self, cache_attr) is None:
      setattr(self, cache_attr,
              xr.open_dataset(f"{self.path}/{field}", engine="zarr"))
    return getattr(self, cache_attr)

  @property
  def lon(self):
    """Longitude variable of the ``lon`` store."""
    return self._dataset('lon').lon

  @property
  def lat(self):
    """Latitude variable of the ``lat`` store."""
    return self._dataset('lat').lat

  @property
  def z(self):
    """Height variable of the ``z`` store."""
    return self._dataset('z').z

  @property
  def prs(self):
    """Pressure variable of the ``prs`` store."""
    return self._dataset('prs').prs

  @property
  def time(self):
    """Time coordinate (as a numpy array) taken from the ``lon`` store."""
    return np.array(self._dataset('lon').time)


def select_files_LARA(path_to_directory='/jetfs/scratch/lbakels/LARA/197901-198503/',
  times=((1980, 1), (1980, 2))):
  """
  Build an ordered mapping from monthly LARA directories to PartInfo objects.

  Parameters
  ----------
  path_to_directory : str
      Root directory holding ``<year>/<month>`` sub-directories.
  times : sequence of (year, month) pairs
      Months to open, in chronological order. The returned dict keeps this
      insertion order, which find_WCB_LARA relies on to walk the months.

  Returns
  -------
  dict
      Keys are ``f'{path_to_directory}/{year}/{month}/'`` with the month NOT
      zero-padded (unchanged for backward compatibility); values are
      PartInfo objects opened with the zero-padded month ("%02i").

  Raises
  ------
  OSError (e.g. FileNotFoundError, NotADirectoryError)
      If ``path_to_directory`` cannot be listed.
  """
  # Fail fast if the root directory is missing; the listing itself is unused.
  os.listdir(path_to_directory)

  part_infos = {}
  for s_time in times:
    year, month = s_time[0], s_time[1]
    key = f'{path_to_directory}/{year}/{month}/'
    part_infos[key] = PartInfo(basepath=path_to_directory,
      year="%i" % year, month="%02i" % month)

  return part_infos

def find_hours(start_time=(1970, 1, 1, 0), end_time=(1971, 2, 2, 2)):
  """
  Return the number of hours from ``start_time`` to ``end_time``.

  Parameters
  ----------
  start_time, end_time : sequence of 4 ints
      [year, month, day, hour]. The hour may be >= 24; it is added as an
      offset to the given day.

  Returns
  -------
  int
      ``end - start`` in whole hours (negative if end precedes start).
      Leap years follow the Gregorian calendar for any year. (The previous
      hand-rolled version used the non-leap February length for leap-year end
      dates and only knew leap years between 1940 and 2020.)

  Raises
  ------
  ValueError
      If year/month/day do not form a valid calendar date.
  """
  from datetime import datetime, timedelta

  def _to_datetime(t):
    # timedelta keeps hour values >= 24 working like the old arithmetic did.
    return datetime(int(t[0]), int(t[1]), int(t[2])) + timedelta(hours=int(t[3]))

  delta = _to_datetime(end_time) - _to_datetime(start_time)
  return int(delta.total_seconds() // 3600)

def find_WCB_LARA(pd_n, n_timesteps_n,t_interval):
  """
  Flag warm-conveyor-belt (WCB)-like particle motions in LARA snapshots.

  Starting at the first time index of the first entry of ``pd_n``, each
  snapshot is compared with the one ``t_interval`` indices later, continuing
  into the next entry of ``pd_n`` when the end index runs past the current
  one. Between the two snapshots, a particle is flagged when all of these
  hold:
    * its start pressure is > 79000 and it drops by > 50000 (presumably Pa);
    * its longitude increases by > 10 deg, or changes by < -180 deg
      (wrap-around across 360/0);
    * it moves > 5 deg poleward (northward if it starts in the NH,
      southward if it starts in the SH).

  Parameters
  ----------
  pd_n : dict
      Mapping whose insertion order is the chronological order of the
      monthly files. Values are PartInfo-like objects exposing ``lon``,
      ``lat`` and ``prs`` indexed as ``[particle, time]``. Entries are
      deleted from ``pd_n`` once processing has moved past them.
  n_timesteps_n : int
      Number of time indices to cover.
  t_interval : int
      Index distance between the two compared snapshots (also the stride).

  Returns
  -------
  griddata : 2D ndarray (361, 720) or None
      Counts of flagged particles binned by start position on a 0.5-degree
      lat/lon grid; None if no particle was ever flagged.
  wcb_frac : 1D ndarray
      Per step: flagged particles divided by the number of valid particles
      (non-negative longitude at the first time of the current file). Steps
      that are not reached because the files ran out stay 0.
  """
  grid_counts = None
  list_of_files = list(pd_n.keys())

  # Bin edges of the 0.5-degree output grid, offset by 0.25 deg so that the
  # bin centres fall on -90..90 and 0..359.5.
  grid_ylat = np.arange(-90.25, 90.26, 0.5)
  grid_xlon = np.arange(-0.25, 360, 0.5)

  timesteps = np.arange(0, n_timesteps_n, t_interval).astype(int)
  wcb_frac = np.zeros_like(timesteps).astype(float)

  i_file = 0
  i_t_file = 0
  for i in range(len(timesteps)):
    nfile = list_of_files[i_file]

    if i == 0:
      pd_n[nfile].prs.load()
      pr1 = pd_n[nfile].prs[:, i_t_file].values
      # Particles with a valid (non-negative) longitude at the file's first time.
      numpart = len(np.where(pd_n[nfile].lon[:, 0] >= 0)[0])
    else:
      # The previous end snapshot becomes the new start snapshot, and its
      # particle count moves with it. (Previously numpart kept the first
      # file's count after switching files, skewing wcb_frac and maxlen.)
      i_t_file = i_t_file2
      pr1 = np.copy(pr2)
      numpart = numpart2

    lon1 = pd_n[nfile].lon[:, i_t_file].values
    lat1 = pd_n[nfile].lat[:, i_t_file].values
    nfilemid = None
    nt_file = len(pd_n[nfile].lon[0, :])
    # If the end snapshot lies beyond this file, continue in the next file.
    if i_t_file + t_interval >= nt_file:
      i_t_file2 = i_t_file + t_interval - nt_file

      i_file = i_file + 1
      if i_file >= len(list_of_files):
        return grid_counts, wcb_frac
      nfile2 = list_of_files[i_file]
      print(i_t_file2, len(pd_n[nfile2].lon[0, :]))

      # The step may skip over one whole file (at most one is skipped).
      if i_t_file2 >= len(pd_n[nfile2].lon[0, :]):
        i_t_file2 = i_t_file2 - len(pd_n[nfile2].lon[0, :])
        i_file = i_file + 1
        nfilemid = nfile2
        # NOTE(review): loads pressure of a file that is skipped; looks
        # unnecessary but kept to preserve behaviour.
        pd_n[nfilemid].prs.load()
        if i_file >= len(list_of_files):
          return grid_counts, wcb_frac
        nfile2 = list_of_files[i_file]
      numpart2 = len(np.where(pd_n[nfile2].lon[:, 0] >= 0)[0])
      print(i, nfile, nt_file, len(pd_n[nfile2].lon[0, :]))
    else:
      i_t_file2 = i_t_file + t_interval
      nfile2 = nfile
      numpart2 = numpart

    pr2 = pd_n[nfile2].prs[:, i_t_file2].values
    lon2 = pd_n[nfile2].lon[:, i_t_file2].values
    lat2 = pd_n[nfile2].lat[:, i_t_file2].values

    # Only compare particles that are valid in both snapshots.
    maxlen = np.min([numpart, numpart2])
    # Strong ascent starting low: p1 > 79000 and p1 - p2 > 50000.
    ww = np.where((pr1[:maxlen] > 79000) & ((pr1[:maxlen] - pr2[:maxlen]) > 50000))[0]

    dlon = lon2[ww] - lon1[ww]
    dlat = lat2[ww] - lat1[ww]
    # Longitude increase > 10 deg, or a jump < -180 deg (wrap across 360/0).
    eastward = (dlon > 10.) | (dlon < -180.)
    wwn = ww[(lat1[ww] > 0) & eastward & (dlat > 5)]
    wws = ww[(lat1[ww] < 0) & eastward & (dlat < -5)]
    sel = np.concatenate([wwn, wws])

    wcb_frac[i] = len(sel) / numpart
    # Accumulate the start positions of flagged particles on the grid.
    if len(sel) > 0:
      counts = np.histogram2d(lat1[sel].ravel(), lon1[sel].ravel(),
                              bins=[grid_ylat, grid_xlon])[0]
      grid_counts = counts if grid_counts is None else grid_counts + counts

    # Release files that will not be visited again.
    if nfile2 != nfile:
      del pd_n[nfile]
      gc.collect()
    if nfilemid is not None:
      del pd_n[nfilemid]
      gc.collect()

  return grid_counts, wcb_frac


def select_particles_grid(sel_dict):
  '''
  Detect warm-conveyor-belt (WCB) particles in the LARA dataset for a period.

  Opens the requested monthly LARA directories (select_files_LARA), converts
  the selection period to a number of hours (find_hours) and runs the WCB
  detection over it (find_WCB_LARA).

  Parameters
  ----------
  sel_dict : dictionary
      'paths_to_directories' : str
        Root LARA directory. It is passed as ``path_to_directory`` to
        select_files_LARA, so it must be a single path, not a list.
      'times' : sequence of [year, month]
        Months to open, in chronological order.
      'period' : 2D array of ints
        Selection period [[Year_i, Month_i, Day_i, Hour_i],[Year_f, Month_f, Day_f, Hour_f]]
      'time_interval' : int
        Distance between the compared snapshots, in time indices of the
        particle files (presumably hours; the total span is counted in hours).

  Returns
  -------
  griddata : 2D array or None
      Grid counts of flagged particles; None if none were flagged.
  wcb_frac : 1D array
      Fraction of flagged particles per step.
  '''
  print('Select files')
  pd_1 = select_files_LARA(path_to_directory=sel_dict['paths_to_directories'],
    times=sel_dict['times'])

  hours = find_hours(sel_dict['period'][0], sel_dict['period'][1])
  print('Hours:', hours)
  print('Select data')
  griddata, wcb_frac = find_WCB_LARA(pd_1, hours, sel_dict['time_interval'])

  return griddata, wcb_frac


def save_netcdf(griddata,wcb_frac, file_name, sel_dict):
  """
  Write the WCB grid and WCB-fraction time series to a NetCDF4 file.

  Parameters
  ----------
  griddata : 2D array (lat, lon) or None
      Counts on the 0.5-degree grid, shape (361, 720). None (no WCB particle
      was flagged) is written as an all-zero grid instead of failing.
  wcb_frac : 1D array
      One value per timestep of the period.
  file_name : str
      Path of the file to create (mode 'w' overwrites an existing file).
  sel_dict : dict
      Must contain 'period' and 'time_interval'; they rebuild the time axis
      exactly as select_particles_grid/find_WCB_LARA do.
  """
  hours = find_hours(sel_dict['period'][0], sel_dict['period'][1])
  timesteps = np.arange(0, hours, sel_dict['time_interval']).astype(int)
  # Cell centres of the 0.5-degree bins (bin edges are offset by 0.25 deg).
  grid_ylat = np.arange(-90, 90.26, 0.5)
  grid_xlon = np.arange(0, 360, 0.5)

  if griddata is None:
    griddata = np.zeros((len(grid_ylat), len(grid_xlon)))

  # The context manager closes the file even if one of the writes fails.
  with nv.Dataset(file_name, 'w', format='NETCDF4') as ds:
    ds.createDimension('time', len(timesteps))
    ds.createDimension('lon', len(grid_xlon))
    ds.createDimension('lat', len(grid_ylat))

    times = ds.createVariable('time', 'f4', ('time',))
    lons = ds.createVariable('lon', 'f4', ('lon',))
    lats = ds.createVariable('lat', 'f4', ('lat',))

    times[:] = timesteps
    lons[:] = grid_xlon
    lats[:] = grid_ylat

    values = ds.createVariable('WCB_grid', 'f4', ('lat', 'lon',))
    values[:, :] = griddata

    frac = ds.createVariable('WCB_frac', 'f4', ('time',))
    frac[:] = wcb_frac
