import numpy as np
import sys
import os
import re
import xarray as xr #cite as: Hoyer, S. & Hamman, J., (2017). xarray: N-D labeled Arrays and Datasets in Python. Journal of Open Research Software. 5(1), p.10. DOI: http://doi.org/10.5334/jors.148
import gc
import glob
import copy
#from sklearn.neighbors import BallTree
import matplotlib.pyplot as plt
import netCDF4 as nv
from scipy.interpolate import interp1d,interp2d,griddata
from scipy.interpolate import RectBivariateSpline as rbs
from scipy.interpolate import RegularGridInterpolator as rgi
from scipy.optimize import brentq, curve_fit, minimize_scalar, minimize, Bounds
from scipy.stats import binned_statistic_2d, binned_statistic_dd, binned_statistic
import pickle
import cartopy.crs as ccrs
from matplotlib import rcParams

# Global matplotlib styling shared by every figure this module produces.
rcParams.update({
  'ps.useafm': True,
  'pdf.use14corefonts': False,
  'text.usetex': False,
  'font.sans-serif': ['cmr10', 'Times-Roman'],
  'font.weight': 'normal',
  'font.size': 20,
  'xtick.labelsize': 20,
  'ytick.labelsize': 20,
  'axes.labelsize': 22,
  'axes.linewidth': 1,
  'axes.unicode_minus': False,
  'xtick.minor.width': 1,
  'ytick.minor.width': 1,
  'xtick.major.width': 1,
  'ytick.major.width': 1,
  'xtick.minor.size': 4,
  'ytick.minor.size': 4,
  'xtick.major.size': 5,
  'ytick.major.size': 5,
  'xtick.direction': 'in',
  'ytick.direction': 'in',
  'mathtext.fontset': 'cm',
})


import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as colors
from matplotlib import gridspec
from matplotlib import ticker
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.ticker import NullFormatter
from collections import OrderedDict, Counter
from matplotlib.lines import Line2D
from matplotlib.patches import Patch, Rectangle, Ellipse
import matplotlib.patheffects as pe

from math import radians, cos, sin, asin, sqrt


#----------Make it print------------------
class Unbuffered(object):
   """File-like wrapper that flushes the wrapped stream after every write.

   Any attribute not defined on the wrapper (``flush``, ``encoding``,
   ``fileno``, ...) is looked up on the wrapped stream instead.
   """

   def __init__(self, stream):
       self.stream = stream

   def write(self, data):
       self.stream.write(data)
       # Resolved through __getattr__ to the wrapped stream's flush().
       self.flush()

   def writelines(self, datas):
       self.stream.writelines(datas)
       self.flush()

   def __getattr__(self, attr):
       # Python only calls this for names missing on the wrapper itself.
       return getattr(self.stream, attr)

# Replace stdout with the auto-flushing wrapper so every print() is written
# out immediately rather than held in the stream's buffer.
sys.stdout = Unbuffered(sys.stdout)


def haversine(lon1,lat1,lon2,lat2,radius=6371.):
  """Great-circle distance between two points given in decimal degrees.

  Inputs may be scalars or numpy arrays (combined with numpy broadcasting).

  Parameters
  ----------
  lon1, lat1, lon2, lat2 : float or array_like
      Longitude/latitude of the two points, in degrees.
  radius : float, optional
      Sphere radius. The default 6371. is the Earth's mean radius in km,
      so the result is in km; pass another value for other units.

  Returns
  -------
  float or ndarray
      Distance in the units of ``radius``.
  """
  lo1, la1, lo2, la2 = (np.radians(v) for v in (lon1, lat1, lon2, lat2))

  # haversine formula
  dlon = lo2 - lo1
  dlat = la2 - la1
  a = np.sin(dlat/2)**2 + np.cos(la1) * np.cos(la2) * np.sin(dlon/2)**2
  # Floating-point rounding can push `a` marginally outside [0, 1] for
  # (near-)antipodal points, which would turn arcsin/sqrt into NaN.
  a = np.clip(a, 0., 1.)
  return 2. * np.arcsin(np.sqrt(a)) * radius
"""
------------------------------------------

------------------------------------------

-----------Notes--------------------------------------------------------
"""

class PartInfo():
  """
  Lazy accessor for one month of particle output stored as zarr stores.

  Each variable lives in its own zarr store under ``<basepath>/<year>/<month>``
  and is opened with xarray on first access, then cached on the instance.

  Fields in partposit files
  ---------------------------------
  xlon = longitude (degrees)
  ylat = latitude (degrees)
  z = height (meters)
  topo = topography
  tro = tropopause
  hmix = PBL height
  rho = density
  qv = specific humidity
  pv = potential vorticity
  tt = temperature
  uu = (not documented here; presumably zonal wind -- TODO confirm)
  vv = (not documented here; presumably meridional wind -- TODO confirm)
  """

  # Cached xarray datasets, opened lazily by the properties below. Declared
  # at class level so every property can test ``is None`` even on instances
  # where __init__ did not run.
  _lon = None
  _lat = None
  _z = None
  _prs = None
  _qv = None
  _pv = None
  _tt = None

  def __init__(self, basepath="/jetfs/scratch/lbakels/LARA_zarr/194001-194703/",
    year="1940",month="01",interpolatetopo=True):
    # NOTE(review): `interpolatetopo` is accepted but currently unused.

    self.path_topo = basepath
    self.path = f"{basepath}/{year}/{month}"

    #2D and 3D fields interpolation (if required)
    self._lon = None
    self._lat = None
    self._z = None
    self._prs = None
    self._qv = None
    self._pv = None
    self._tt = None

    # Accessing these properties opens the lon/lat/z/prs stores right away.
    self.ds = {}
    self.ds['lon'] = self.lon
    self.ds['lat'] = self.lat
    self.ds['z'] = self.z
    self.ds['prs'] = self.prs

  @property
  def lon(self):
    if self._lon is None:
      self._lon = xr.open_dataset(self.path+'/lon',engine="zarr")
    return self._lon.lon

  @property
  def lat(self):
    if self._lat is None:
      self._lat = xr.open_dataset(self.path+'/lat',engine="zarr")
    return self._lat.lat

  @property
  def z(self):
    if self._z is None:
      self._z = xr.open_dataset(self.path+'/z',engine="zarr")
    return self._z.z

  @property
  def prs(self):
    if self._prs is None:
      self._prs = xr.open_dataset(self.path+'/prs',engine="zarr")
    return self._prs.prs

  @property
  def qv(self):
    if self._qv is None:
      self._qv = xr.open_dataset(self.path+'/sh',engine="zarr")
    return self._qv.sh

  @property
  def pv(self):
    if self._pv is None:
      self._pv = xr.open_dataset(self.path+'/pv',engine="zarr")
    return self._pv.pv

  @property
  def tt(self):
    if self._tt is None:
      self._tt = xr.open_dataset(self.path+'/T',engine="zarr")
    return self._tt.T

  def _time_coord(self):
    """Return the ``time`` coordinate of the lon store (keeps its ``.dt`` accessor)."""
    if self._lon is None:
      self._lon = xr.open_dataset(self.path+'/lon',engine="zarr")
    return self._lon.time

  @property
  def time(self):
    """Output times as a plain numpy array."""
    return np.array(self._time_coord())

  # The calendar components below use the coordinate object directly: the
  # numpy array returned by `time` has no `.dt` accessor.
  @property
  def day(self):
    return self._time_coord().dt.day.values

  @property
  def year(self):
    return self._time_coord().dt.year.values

  @property
  def month(self):
    return self._time_coord().dt.month.values

  @property
  def hour(self):
    return self._time_coord().dt.hour.values


def convert_time(times):
  """Pack a [year, month, day, hour] sequence into one YYYYMMDDHH number.

  Only the first four entries of ``times`` are read.
  """
  year, month, day, hour = times[0], times[1], times[2], times[3]
  # Horner-style packing: each step shifts the previous fields two digits left.
  return ((year*100 + month)*100 + day)*100 + hour

def find_hours(start_time=(1970,1,1,0),end_time=(1971,2,2,2)):
  """Return the number of hours from ``start_time`` to ``end_time``.

  Both times are [year, month, day, hour] in the (proleptic) Gregorian
  calendar, so leap years are handled for any year, not just 1940-2020.
  Day/hour values beyond their usual range (e.g. hour 24, as produced by
  ``find_end_time`` when it clamps) are accepted and simply added on.
  The result is negative when ``end_time`` precedes ``start_time``.
  """
  import datetime

  def hour_index(t):
    # Hours elapsed since a fixed epoch (date ordinals count whole days);
    # the epoch cancels in the difference below.
    first_of_month = datetime.date(int(t[0]), int(t[1]), 1).toordinal()
    return (first_of_month + t[2] - 1) * 24 + t[3]

  return hour_index(end_time) - hour_index(start_time)

def find_end_time(start_time=(1970,1,1,0),hours=24):
  """Return the [year, month, day, hour] reached ``hours`` after ``start_time``.

  ``start_time`` is [year, month, day, hour] in the Gregorian calendar.
  ``hours`` may be negative, which moves backwards in time. The returned
  hour is in [0, 24).

  Results later than 2024-03-31 24:00 are clamped to [2024, 3, 31, 24] and a
  warning is printed.
  """
  import datetime

  # divmod floors, so negative offsets roll back into the previous day(s)
  # and the hour of day always stays within [0, 24).
  days, leftover_hours = divmod(hours + start_time[3], 24)
  first_of_month = datetime.date(int(start_time[0]), int(start_time[1]), 1)
  end_date = first_of_month + datetime.timedelta(
    days=int(start_time[2]) - 1 + int(days))

  if end_date > datetime.date(2024, 3, 31):
    print('Warning: number of hours exceeds 2024-03-31:24')
    return [2024,3,31,24]

  return [end_date.year, end_date.month, end_date.day, leftover_hours]

def select_files_LARA(path_to_directory='/jetfs/scratch/lbakels/LARA/197901-198503/',
  year='1979'):
  
  pd = {}

  months_dir = os.listdir(path_to_directory+'/'+year)
  
  for month in months_dir:
    pd[f'{path_to_directory}/{month}'] = PartInfo(basepath=path_to_directory,
      year="%i" %s_time[0], month="%02i" %s_time[1])
    pd_temp = pd[f'{path_to_directory}/{month}']
    pd_times[f'{path_to_directory}/{month}'] = pd_temp.year*1000000+pd_temp.month*10000+pd_temp.day*100+pd_temp.hour

  return pd,pd_times

def conservation_grid(path_to_directory='/jetfs/scratch/lbakels/LARA/197901-198503/output/',
  start_time=[1980,1,1,0],end_time=[1980,1,2,0]):
  """Accumulate particle changes of qv, pv and theta on a lat/lon/z grid.

  Every ``t_assimilation`` (12) hours from ``start_time`` up to ``end_time``,
  each monthly store that has exactly two output times inside the window
  [t, t+1h] contributes. Particles are binned by their position at the first
  of those two times. Per bin, the absolute changes between the two times of
  specific humidity, potential vorticity and potential temperature are summed,
  and the particles are counted.

  Returns
  -------
  (grid_qv, grid_pv, grid_th, histdata, grid_ylat, grid_xlon, grid_z)
      The four grids have shape
      (len(grid_ylat)-1, len(grid_xlon)-1, len(grid_z)-1). grid_ylat,
      grid_xlon and grid_z are the bin edges. Non-finite sums are set to 0.
      The grids are all zero if no window matched.
      Returns None if the file selection returned None.
  """
  grid_ylat=np.arange(-90.25,90.26,0.5)
  grid_xlon=np.arange(-0.25,360,0.5)
  grid_z=np.array([0.,1000.,5000.,10000.,20000.])
  bins = [grid_ylat, grid_xlon, grid_z]

  pd,pd_times = select_files_LARA(path_to_directory=path_to_directory,
    start_time=start_time,
    end_time=end_time)

  if pd is None:
    return

  t_assimilation=12 #hours

  # Accumulators start at zero so an empty selection yields zero grids
  # instead of failing later on None.
  shape = tuple(len(edges) - 1 for edges in bins)
  sums = np.zeros((3,) + shape)  # |dqv|, |dpv|, |dtheta| stacked on axis 0
  histdata = np.zeros(shape)

  total_hours = find_hours(start_time,end_time)
  for itime in range(0, total_hours, t_assimilation):
    # Window [t, t+1h] in YYYYMMDDHH encoding.
    s_time=convert_time(find_end_time(np.array(start_time),itime))
    e_time=convert_time(find_end_time(np.array(start_time),itime+1))
    for file, part in pd.items():
      #Select the two timesteps
      itimes = np.where((pd_times[file]>=s_time)&(pd_times[file]<=e_time))[0]
      if len(itimes)!=2:
        continue

      # The first particle with an invalid longitude (<0, >360 or NaN) at
      # the first output time bounds the particles used. This presumably
      # assumes unused slots are padded at the end; confirm.
      lons = part.lon[:,0].values
      wnumpart = np.where((lons<0)|(np.isnan(lons))|(lons>360))[0]
      numpart = wnumpart[0] if len(wnumpart)>0 else len(lons)

      # Taking the first of the two positions to grid the data
      lons = part.lon[:numpart,itimes[0]].values
      lats = part.lat[:numpart,itimes[0]].values
      zs = part.z[:numpart,itimes[0]].values

      qv = part.qv[:numpart,itimes].values
      pv = part.pv[:numpart,itimes].values
      print(file, numpart, qv.shape)
      # Potential temperature (assumes p in Pa, reference 1e5 -- confirm)
      t = part.tt[:numpart,itimes].values
      p = part.prs[:numpart,itimes].values
      theta = t*(100000./p)**(0.286)

      changes = [np.abs(qv[:,1]-qv[:,0]),
        np.abs(pv[:,1]-pv[:,0]),
        np.abs(theta[:,1]-theta[:,0])]
      # One binning pass; the statistic is computed per entry of `changes`.
      sums += binned_statistic_dd([lats,lons,zs], changes,
        statistic='sum', bins=bins)[0]
      histdata += np.histogramdd([lats,lons,zs], bins=bins)[0]

  # Bins that received NaN/inf contributions are zeroed (counts untouched).
  sums[~np.isfinite(sums)] = 0.
  grid_qv, grid_pv, grid_th = sums

  # Close every zarr store that the PartInfo objects opened lazily.
  for part in pd.values():
    for attr in ('_lon', '_lat', '_z', '_prs', '_qv', '_pv', '_tt'):
      dataset = getattr(part, attr, None)
      if dataset is not None:
        dataset.close()
  pd.clear()
  del pd
  gc.collect()

  return grid_qv, grid_pv, grid_th, histdata, grid_ylat, grid_xlon, grid_z

def compute_conservation_properties(sel_dict):
  """Grid conservation diagnostics for the window and for +/-1 h shifted starts.

  sel_dict keys used:
    'paths_to_directories' -- passed to conservation_grid as path_to_directory
    'period'               -- (start_time, end_time), each [year, month, day, hour]
  sel_dict is not modified.

  Returns four dicts (grid_qv, grid_pv, grid_th, histdata), each keyed by
  'assim' (unshifted start), 'assim+1' (start + 1 h) and 'assim-1'
  (start - 1 h).
  """
  grid_qv = {}
  grid_pv = {}
  grid_th = {}
  histdata = {}
  start_time = sel_dict['period'][0]
  end_time = sel_dict['period'][1]

  print('Find assimilation window')
  for label, shift in (('assim', 0), ('assim+1', 1), ('assim-1', -1)):
    # Work on a copy: shifting in place would corrupt sel_dict['period'],
    # which callers (e.g. when writing output) may read afterwards.
    s_time = list(start_time)
    s_time[3] += shift
    (grid_qv[label], grid_pv[label], grid_th[label], histdata[label],
      _, _, _) = conservation_grid(
      path_to_directory=sel_dict['paths_to_directories'],
      start_time=s_time,
      end_time=end_time)

  return grid_qv,grid_pv,grid_th,histdata

def save_netcdf(grid_qv,grid_pv,grid_th,histdata,file_name, sel_dict):
  """Write the gridded conservation diagnostics to a NetCDF4 file.

  Parameters
  ----------
  grid_qv, grid_pv, grid_th, histdata : dict
      Each keyed by 'assim', 'assim-1' and 'assim+1'. The values are 3-D
      arrays laid out as (lat, lon, z), 361 x 720 x 4.
  file_name : str
      Output path. An existing file is overwritten (mode 'w').
  sel_dict : dict
      Accepted for interface compatibility; its contents are not written.

  Variables written: qvA, qvA-1, qvA+1, pvA..., thA..., NA... on
  dimensions (lat, lon, z). The coordinates are 0.5 degree lat/lon values
  and the upper z bin edges.
  """
  grid_ylat=np.arange(-90,90.26,0.5)
  grid_xlon=np.arange(-0,360,0.5)
  grid_z=np.array([1000.,5000.,10000.,20000.])

  # The context manager closes the file even if one of the writes fails.
  with nv.Dataset(file_name, 'w', format='NETCDF4') as ds:
    coords = (('lon', grid_xlon), ('lat', grid_ylat), ('z', grid_z))
    for name, values in coords:
      ds.createDimension(name, len(values))
    for name, values in coords:
      ds.createVariable(name, 'f4', (name,))[:] = values

    # Same variable names and creation order as before: for each quantity,
    # the unshifted window first, then -1 h, then +1 h.
    quantities = (('qv', grid_qv), ('pv', grid_pv), ('th', grid_th),
      ('N', histdata))
    windows = (('A', 'assim'), ('A-1', 'assim-1'), ('A+1', 'assim+1'))
    for prefix, grids in quantities:
      for suffix, key in windows:
        var = ds.createVariable(prefix + suffix, 'f4', ('lat','lon','z'))
        var[:,:,:] = grids[key][:,:,:]