import re
#from multiprocessing import Process
import pickle
import numpy as np
import sys
import os
import re
# import eccodes
# import cfgrib
import xarray as xr #cite as: Hoyer, S. & Hamman, J., (2017). xarray: N-D labeled Arrays and Datasets in Python. Journal of Open Research Software. 5(1), p.10. DOI: http://doi.org/10.5334/jors.148
import zarr
import gc
import glob
import copy
import netCDF4 as nv
#from sklearn.neighbors import BallTree
from scipy.interpolate import interp1d,interp2d,griddata
from scipy.interpolate import RectBivariateSpline as rbs
from scipy.interpolate import RegularGridInterpolator as rgi
from scipy.optimize import brentq, curve_fit, minimize_scalar, minimize, Bounds
from scipy.stats import binned_statistic_2d, binned_statistic_dd, binned_statistic


#----------Make it print------------------
class Unbuffered(object):
  """Stream proxy that flushes the wrapped stream after every write.

  Attributes that are not defined on this class are looked up on the
  wrapped stream instead.
  """

  def __init__(self, stream):
    # File-like object that receives everything written to this proxy.
    self.stream = stream

  def write(self, data):
    """Write ``data`` to the wrapped stream, then flush it."""
    target = self.stream
    target.write(data)
    target.flush()

  def writelines(self, datas):
    """Write all items of ``datas`` to the wrapped stream, then flush once."""
    target = self.stream
    target.writelines(datas)
    target.flush()

  def __getattr__(self, attr):
    # Only reached when normal attribute lookup fails: delegate to the stream.
    return getattr(self.stream, attr)

# Replace the process-wide stdout with the flushing wrapper, so every print is
# flushed to the underlying stream as soon as it is written.
sys.stdout = Unbuffered(sys.stdout)


class PartInfo():
  """
  Access to the particle fields of one month, stored as zarr datasets.

  Fields
  ---------------------------------
  lon = longitude (degrees)
  lat = latitude (degrees)
  z = height (meters)
  lon_av, lat_av, z_av = "_av" variants of the fields above
      (presumably averaged positions - TODO confirm with the data producer)
  topo = topography
  tro = tropopause
  hmix = PBL height

  Each field is read from ``<basepath>/<period>/<year>/<month>/<name>``,
  except the topography, which is read from ``<basepath>/to``.
  The position fields are opened while the object is constructed (see
  ``__init__``); ``topo``, ``tro`` and ``hmix`` are opened on first access.
  """

  def __init__(self, basepath="/jetfs/scratch/lbakels/LARA_zarr/",
    period="194001-194703",year="1940",month="01",
    interpolatetopo=True):
    """
    Parameters
    ----------
    basepath : str
        Root directory of the zarr stores.
    period, year, month : str
        Sub-directory names that are appended to ``basepath``.
    interpolatetopo : bool
        If True, ``topo`` returns the topography interpolated to every
        particle position; otherwise it returns the gridded field.
    """
    self.path_topo = basepath
    self.path = f"{basepath}/{period}/{year}/{month}"

    #2D and 3D fields interpolation (if required)
    self.topography = None   # cache of the per-particle topography
    self.tropopause = None   # kept for compatibility; not filled by this class
    self.hmixing = None      # kept for compatibility; not filled by this class
    self.interpolatetopo = interpolatetopo

    # Opened datasets / arrays; None until the matching property is accessed.
    self._lon = None
    self._lat = None
    self._z = None
    self._lon_av = None
    self._lat_av = None
    self._z_av = None
    self._topo = None
    self._tro = None
    self._hmix = None

    # Name -> field lookup. Reading the properties here opens the six
    # position datasets immediately.
    self.ds = {}
    self.ds['lon'] = self.lon
    self.ds['lon_av'] = self.lon_av
    self.ds['lat'] = self.lat
    self.ds['lat_av'] = self.lat_av
    self.ds['z'] = self.z
    self.ds['z_av'] = self.z_av

  @property
  def lon(self):
    """Variable ``lon`` of the ``lon`` store."""
    if self._lon is None:
      self._lon = xr.open_dataset(self.path+'/lon',engine="zarr")
    return self._lon.lon

  @property
  def lon_av(self):
    """Variable ``lon_av`` of the ``lon_av`` store."""
    if self._lon_av is None:
      self._lon_av = xr.open_dataset(self.path+'/lon_av',engine="zarr")
    return self._lon_av.lon_av

  @property
  def lat(self):
    """Variable ``lat`` of the ``lat`` store."""
    if self._lat is None:
      self._lat = xr.open_dataset(self.path+'/lat',engine="zarr")
    return self._lat.lat

  @property
  def lat_av(self):
    """Variable ``lat_av`` of the ``lat_av`` store."""
    if self._lat_av is None:
      self._lat_av = xr.open_dataset(self.path+'/lat_av',engine="zarr")
    return self._lat_av.lat_av

  @property
  def z(self):
    """Variable ``z`` of the ``z`` store."""
    if self._z is None:
      self._z = xr.open_dataset(self.path+'/z',engine="zarr")
    return self._z.z

  @property
  def z_av(self):
    """Variable ``z_av`` of the ``z_av`` store."""
    if self._z_av is None:
      self._z_av = xr.open_dataset(self.path+'/z_av',engine="zarr")
    return self._z_av.z_av

  @property
  def _topo1(self):
    """Gridded topography (variable ``to`` of ``<basepath>/to``)."""
    if self._topo is None:
      self._topo = xr.open_dataset(self.path_topo+'/to',engine="zarr").to
    return self._topo

  @property
  def topo(self):
    """
    Topography.

    With ``interpolatetopo`` set, a numpy array with the same two axes as
    ``lon`` holding the gridded topography interpolated to each particle
    position, limited to the value range of the gridded field. The array is
    computed once and cached. Otherwise the gridded field itself.
    """
    if self._topo is None:
      self._topo = self._topo1
    if not self.interpolatetopo:
      return self._topo

    if self.topography is None:
      # Build the interpolator once and reuse it for every timestep.
      topo_inter = self.interpolate2d(self._topo,0)
      # Value range of the gridded field; used to cut off spline overshoot.
      mini = np.min(self._topo).values
      maxi = np.max(self._topo).values

      lon = self.lon
      lat = self.lat
      n_part, n_time = len(lon), len(lon[0,:])
      #allocate memory for topography per particle
      topography = np.zeros((n_part,n_time))
      #For each timestep, interpolate the position of the particle to the grid
      for itime in range(n_time):
        column = topo_inter(lat[:,itime],lon[:,itime],grid=False)
        topography[:,itime] = np.clip(column,mini,maxi)
      # Publish the cache only once it is complete, so a failure above cannot
      # leave a partially filled array behind.
      self.topography = topography
    return self.topography

  @property
  def tro(self):
    """Variable ``tro`` of the ``tro`` store, returned as stored."""
    if self._tro is None:
      self._tro = xr.open_dataset(self.path+'/tro',engine="zarr").tro
    return self._tro

  @property
  def hmix(self):
    """Variable ``hmix`` of the ``hmix`` store, returned as stored."""
    if self._hmix is None:
      self._hmix = xr.open_dataset(self.path+'/hmix',engine="zarr").hmix
    return self._hmix

  def interpolate2d(self,field,itime):
    """
    Return a RectBivariateSpline of ``field``, callable as ``f(lat, lon)``.

    The grid is hard-coded: latitude -90..90 and longitude 0..360 in steps
    of 0.5 degrees, so ``field`` must have 361 x 721 values per time slice.
    A field with two dimensions is used as is; otherwise slice ``itime`` of
    the third axis is used.
    """
    lon=np.arange(0.,360.5,0.5)
    lat=np.arange(-90,90.5,0.5)
    if len(field.dims)==2:
      info = field[:,:].values
    else:
      info = field[:,:,itime].values
    return rbs(lat,lon,info)

def convert_time(times):
  """Pack the first four entries of ``times`` into one number.

  The entries are weighted by 1000000, 10000, 100 and 1, so a
  (year, month, day, hour) sequence becomes YYYYMMDDHH.
  """
  year, month, day, hour = times[0], times[1], times[2], times[3]
  return year*1000000 + month*10000 + day*100 + hour

def find_particles_LARA(path_to_directory='/jetfs/scratch/lbakels/LARA/197901-198503/output/',
  period="194001-194703",year=1940,month="03",hours_per_day=1,lat=np.array([10,15])):
  """
  Find the particles that enter the first latitude bin and when they first do.

  Parameters
  ----------
  path_to_directory, period, year, month :
      Forwarded to PartInfo to locate the data of one month.
  hours_per_day : int
      Accepted for interface compatibility; not used.
  lat : array
      Bin edges handed to np.digitize. Only positions falling into the first
      bin are selected, i.e. lat[0] <= latitude < lat[1] for increasing edges.

  Returns
  -------
  ipart : array of int
      Index along the first axis of the latitude array (the particle axis in
      PartInfo) of every particle found in the bin, each listed once.
  start_times : array of int
      For each entry of ipart, the smallest index along the second axis at
      which that particle is inside the bin.
  """
  part_info = PartInfo(basepath=path_to_directory,period=period,year=year,month=month)

  # Lookup table over the bins: only the first bin is flagged.
  first_bin_only = np.zeros(len(lat))
  first_bin_only[0] = 1

  latitudes = part_info.lat.load()

  # Selecting particles that are in the lat range
  bin_of_position = np.digitize(latitudes, lat)
  in_range = first_bin_only[bin_of_position-1] == 1
  ipart, start_times = np.where(in_range)

  # Only use the first instance that it is within the lat range.
  # np.where lists hits row by row, so the first hit per row is the earliest.
  _, first_hit = np.unique(ipart, return_index=True)
  first_hit = np.sort(first_hit)
  ipart, start_times = ipart[first_hit], start_times[first_hit]

  part_info.lat.close()
  del part_info
  gc.collect()
  return ipart, start_times

def select_particles(pdtemp,field2,i_file,output,start_times_n,i_output,i_t,n_t_prev=0):
  """
  Copy one value per selected particle from ``pdtemp`` into ``output``.

  For every position k in ``i_file`` the value at
  particle ``output['ipart'][k]`` and time ``start_times_n[k]+i_t-n_t_prev``
  is written to ``output[field2][k, i_output]``. ``pdtemp`` must have the
  dimensions "particle" and "time".
  """
  rows = output['ipart'][i_file].astype(int)
  cols = (start_times_n[i_file]+i_t-n_t_prev).astype(int)
  # Both indexers share the dimension name "z", so isel picks one
  # (particle, time) pair per entry instead of the full outer product.
  picked = pdtemp.isel(
    particle=xr.DataArray(rows,dims="z"),
    time=xr.DataArray(cols,dims="z"),
  )
  output[field2][i_file,i_output] = picked.values


def convert_fields2load(pdata, fields_to_load):
  """
  Translate long field names to the short names used in the data.

  If ``pdata._pd.variables`` contains 'topo', ``fields_to_load`` is returned
  unchanged (the same object). Otherwise a new list is returned in which
  'longitude', 'latitude', 'height', 'pr' and 'temperature' are replaced by
  'lon', 'lat', 'z', 'prs' and 'T'; all other names are kept.
  """
  if 'topo' in pdata._pd.variables:
    return fields_to_load

  short_names = {
    'longitude': 'lon',
    'latitude': 'lat',
    'height': 'z',
    'pr': 'prs',
    'temperature': 'T',
  }
  return [short_names.get(field, field) for field in fields_to_load]

def select_data_LARA(pd_n,path_to_directory_n,ipart_n,start_times_n,
  n_timesteps_n,t_interval,fields_to_load):
  """
  Extract the requested fields for the selected particles along their tracks.

  Tracking always starts in the first file of ``pd_n``; when a particle's time
  index runs past the end of that file, the value is taken from the following
  file(s) instead.

  Parameters
  ----------
  pd_n : dict
      Maps a file label to a PartInfo object, in chronological order.
      NOTE: the entries are deleted from this dict before returning.
  path_to_directory_n :
      Accepted for interface compatibility; not used.
  ipart_n : array of int
      Particle indices to follow.
  start_times_n : array of int
      Per particle, the time index in the first file at which tracking starts.
  n_timesteps_n, t_interval : int
      The tracked time offsets are np.arange(0, n_timesteps_n, t_interval).
  fields_to_load : list of str
      Fields to extract. 'to' is taken from ``PartInfo.topo``; every other
      name is looked up in ``PartInfo.ds``.

  Returns
  -------
  output : dict
      One array of shape (len(ipart_n), number of offsets) per field, plus
      'ipart' (= ipart_n) and 'start_time_index' (dict keyed by file label;
      the first file holds start_times_n, the others zeros).
  """
  list_of_files=list(pd_n.keys())

  timesteps=np.arange(0,n_timesteps_n,t_interval).astype(int)

  output = {}
  for field in fields_to_load:
    output[field] = np.zeros((len(ipart_n),len(timesteps)))

  output['ipart'] = ipart_n
  output['start_time_index'] = {}
  output['start_time_index'][list_of_files[0]] = start_times_n
  for i in range(1,len(list_of_files)):
    output['start_time_index'][list_of_files[i]] = np.zeros(len(ipart_n)).astype(int)

  # ipart_n and start_times_n have the same length: each particle has the
  # first time (in the first file) from which it needs to be considered.
  i_file=0
  file=list_of_files[i_file]
  print(file)

  if 'to' in fields_to_load:
    topo=pd_n[file].topo
    parts = pd_n[file].lat.particle
    time=pd_n[file].lat.time
    print('bla', len(topo),len(topo[0]), len(time), len(parts))
    # Topography of the FIRST file; must stay untouched for the whole loop.
    pdtemp = xr.DataArray(topo, coords=[parts,time], dims=["particle","time"])

  # Follow these particles in the current and next files
  n_t = len(pd_n[file].lon[0,:])
  for i_output, i_t in enumerate(timesteps):
    print('hour: ',i_t)

    #Find particles present in this file and how many are still being tracked in the next
    i_this_file = np.where(output['start_time_index'][file]+i_t<n_t)[0]
    i_next_file = np.where(output['start_time_index'][file]+i_t>=n_t)[0]

    #Write data to output dictionary
    for field in fields_to_load:
      if field=='to':
        select_particles(pdtemp,'to',i_this_file,output,
          output['start_time_index'][file],i_output,i_t)
      else:
        pd_n[file].ds[field].load()
        select_particles(pd_n[file].ds[field],field,
          i_this_file,output,output['start_time_index'][file],i_output,i_t)

    if len(i_next_file) == 0:
      continue

    #some particles have the next timestep in the next file
    n_t_next = n_t #counting timesteps in each file
    n_t_prev = n_t
    for i_file_next in range(i_file+1,len(list_of_files)):
      file_next = list_of_files[i_file_next]
      n_t_next += len(pd_n[file_next].lon[0,:])
      # Particles whose current time index falls inside this later file.
      i_next_file_temp = i_next_file[np.where((start_times_n[i_next_file]+i_t>=n_t_prev)&
        (start_times_n[i_next_file]+i_t<n_t_next))[0]]
      if len(i_next_file_temp) == 0:
        n_t_prev = n_t_next
        continue
      if 'to' in fields_to_load:
        topo_next=pd_n[file_next].topo
        parts_next = pd_n[file_next].lat.particle
        time_next=pd_n[file_next].lat.time
        # Kept in its own variable: overwriting pdtemp here would make all
        # later timesteps read first-file particles from the wrong file.
        pdtemp_next = xr.DataArray(topo_next, coords=[parts_next,time_next],
          dims=["particle","time"])
      for field in fields_to_load:
        if field=='to':
          select_particles(pdtemp_next,'to',i_next_file_temp,output,output['start_time_index'][file],
            i_output,i_t,n_t_prev)
        else:
          pd_n[file_next].ds[field].load()
          select_particles(pd_n[file_next].ds[field],field,
            i_next_file_temp,output,output['start_time_index'][file],i_output,i_t,n_t_prev)

      n_t_prev = n_t_next

  for file in list_of_files:
    for field in fields_to_load:
      if field=='to':
        # The topography is not part of PartInfo.ds; looking it up there
        # would raise KeyError after all the work is done.
        continue
      pd_n[file].ds[field].close()
    del pd_n[file]
  del pd_n
  gc.collect()

  return output

def select_particles_grid(sel_dict):
  '''
  Select the particles that enter the latitude band [-0.1, 0.1) during the
  given month and extract their data from that month and the following one.

  Parameters
  ----------
  sel_dict : dictionary
      'paths_to_directories' : str
        Base directory of the dataset (passed to PartInfo as basepath)
      'period' : str
        Name of the period directory, e.g. "194001-194703"
      'year' : str or int
        Year directory of the selection month
      'month' : str
        Two-digit month of the selection, e.g. "03"
      'period_select_hours' : int
        Forwarded to find_particles_LARA as hours_per_day
      'period_analyse_hours' : int
        Number of hours particles are being tracked after they are selected
      'time_interval' : int
        Data selection interval
      'fields_to_load' : list of str
        Names of desired LARA fields to load

  Returns
  -------
  output_1 : dict
      Result of select_data_LARA: for each of 'fields_to_load' a 2D array
      with dim [particle, time], plus 'ipart' and 'start_time_index'
  '''
  lat=np.array([-0.1,0.1])

  print('Select particles')
  ipart_1,start_times_1=find_particles_LARA(path_to_directory=sel_dict['paths_to_directories'],
    period=sel_dict['period'],year=sel_dict['year'],month=sel_dict['month'],hours_per_day=sel_dict['period_select_hours'],
    lat=lat)

  pd_1 = {}
  pd_1[sel_dict['month']] = PartInfo(basepath=sel_dict['paths_to_directories'],
    period=sel_dict['period'],year=sel_dict['year'],month=sel_dict['month'])

  # Month that follows the selection month. December rolls over to January of
  # the next year instead of producing the non-existent month "13".
  month_number = [int(i) for i in sel_dict['month'].split() if i.isdigit()][0]
  year = sel_dict['year']
  if month_number >= 12:
    nmonth = "01"
    nyear = type(year)(int(year)+1)   # keep the caller's type (str or int)
  else:
    nmonth = "%02i" %(month_number+1)
    nyear = year
  # NOTE(review): the period directory is reused as is; if the following month
  # lives in a different period folder this still needs handling.
  pd_1[nmonth] = PartInfo(basepath=sel_dict['paths_to_directories'],
    period=sel_dict['period'],year=nyear,month=nmonth)

  print('Select data')
  output_1 = select_data_LARA(pd_1,sel_dict['paths_to_directories'],ipart_1,start_times_1,sel_dict['period_analyse_hours'],
    sel_dict['time_interval'],sel_dict['fields_to_load'])

  return output_1


def save_netcdf(output, file_name, sel_dict):
  """
  Write the result of a particle selection to a NETCDF4 file.

  Parameters
  ----------
  output : dict
      Needs 'ipart', 'start_time_index' (dict keyed by file label) and one
      2D array per entry of sel_dict['fields_to_load'].
  file_name : str
      File to create; an existing file is overwritten.
  sel_dict : dict
      Needs 'period_analyse_hours', 'time_interval' and 'fields_to_load'.

  The file gets the dimensions 'time' and 'part', the variables 'time',
  'part' and 'start_time_index' (taken from the first entry of
  output['start_time_index']) and one ('part', 'time') variable per field.
  """
  timesteps=np.arange(0,sel_dict['period_analyse_hours'],sel_dict['time_interval']).astype(int)
  # Label of the first file; only its start indices are stored.
  first_key = list(output['start_time_index'])[0]

  # The context manager closes the file even if one of the writes fails.
  with nv.Dataset(file_name, 'w', format='NETCDF4') as ds:
    ds.createDimension('time', len(timesteps))
    ds.createDimension('part', len(output['ipart']))

    # NOTE(review): everything is stored as 'f4'; 32-bit floats hold integers
    # exactly only up to 2**24, which limits the particle indices.
    times = ds.createVariable('time', 'f4', ('time',))
    parts = ds.createVariable('part', 'f4', ('part',))
    tpart = ds.createVariable('start_time_index', 'f4', ('part',))

    times[:] = timesteps
    parts[:] = output['ipart']
    tpart[:] = output['start_time_index'][first_key]

    for field in sel_dict['fields_to_load']:
      variable = ds.createVariable(field, 'f4', ('part','time',))
      variable[:,:] = output[field]


def select_nc_files(years_all=(np.arange(1960,1996,1)).astype(int)):
  """
  Open "Energy_<year>.nc" read-only for every year in ``years_all``.

  Returns a dict mapping each year to its open netCDF4 Dataset; the caller
  is responsible for closing them.
  """
  return {year: nv.Dataset("Energy_%i.nc" %(year),"r") for year in years_all}

def compute_surface_area(grid_xlon,grid_ylat):
  """
  Approximate the area of every cell of a longitude/latitude grid.

  ``grid_xlon`` and ``grid_ylat`` are the cell edges. Each cell area is the
  product of two edge lengths obtained from ``haversine``: the edge along the
  cell's first longitude, and the edge along the cell's second latitude.
  Both lengths are multiplied by 1000 (presumably km to m - TODO confirm the
  unit returned by haversine).

  Returns an array of shape (len(grid_ylat)-1, len(grid_xlon)-1).
  """
  lon_edges = np.copy(grid_xlon)
  lat_edges = np.copy(grid_ylat)

  # The mesh of cell centres is only needed for the shape and dtype of the result.
  lon_centres = (grid_xlon[1:]+grid_xlon[:-1])/2.
  lat_centres = (grid_ylat[1:]+grid_ylat[:-1])/2
  centre_mesh = np.meshgrid(lon_centres, lat_centres)[0]
  surface_area = np.zeros_like(centre_mesh)

  lat_start, lat_end = lat_edges[:-1], lat_edges[1:]
  for i, (lon_start, lon_end) in enumerate(zip(lon_edges[:-1], lon_edges[1:])):
    # Edge along the meridian at lon_start, for all latitude rows at once.
    side_a = haversine(lon_start,lat_start,lon_start,lat_end)*1000.
    # Edge along the parallel at lat_end, between the two longitudes.
    side_b = haversine(lon_start,lat_end,lon_end,lat_end)*1000.

    surface_area[:,i] = side_a*side_b
  return surface_area

def kg_to_mm_water(water_g, surface_area_m):
  """
  Convert a water amount carried by one particle to a water depth in mm.

  ``water_g`` is divided by 1000, multiplied by the particle mass, converted
  to a volume in m3 (1 kg of water = 1 L), spread over ``surface_area_m`` and
  expressed in mm. NOTE(review): water_g looks like grams of water per kg of
  particle mass - confirm with the caller.
  """
  #Assuming 1kg = 1L water
  # Mass of one particle in kg (presumably total atmospheric mass divided by
  # the number of particles - TODO confirm).
  particle_mass = 5.09256513E18/6.0e6
  # Same operation order as: g -> kg, L -> m3, m3 / m2 -> m, m -> mm.
  return water_g/1000.*particle_mass/1000./surface_area_m*1000.
