import numpy as np
import xarray as xr
import datetime
import os
import glob
import re
import pickle
import scipy.stats
import sys
import time
from global_land_mask import globe

def _bin_centers(edges):
    """Return the midpoints between consecutive entries of `edges`."""
    return (edges[:-1] + edges[1:]) / 2.


# Positional command-line arguments:
#   1. root directory holding the trajectory output
#   2. root directory under which the results are written
#   3. sub-directory (below both roots) selecting the run to process
#   4. height threshold that the particle height (z_av) is compared against
cli_args = sys.argv
lara_path = cli_args[1]
output_path = cli_args[2]
selected_dir = cli_args[3]
h_thresh = float(cli_args[4])

print(selected_dir)

# time increment added per trajectory step
timestep = 1 # in hours

# vertical bin edges and the matching bin centres
z_bins = np.array([0, 100, 500, 1000, 2000])
z = _bin_centers(z_bins)

# half-degree latitude bins, centred on -90, -89.5, ..., 90
lat_bins = np.arange(-90.25, 90.75, 0.5)
lat = _bin_centers(lat_bins)

# half-degree longitude bins, centred on -180, -179.5, ..., 180
lon_bins = np.arange(-180.25, 180.75, 0.5)
lon = _bin_centers(lon_bins)

# one input directory per month (<root>/<run>/<YYYY>/<MM>/), in sorted order
files = sorted(glob.glob(f'{lara_path}/{selected_dir}/????/??/'))
print(files)

# Main loop: one iteration per monthly input directory, in sorted order.
# The running travel time of every particle (`prev_time`) and the gridded
# accumulators are carried over from one iteration to the next, so the netCDF
# file written for a month holds the totals accumulated up to and including
# that month.
for f,file in enumerate(files):
    # directories look like .../<YYYY>/<MM>/ (the trailing slash comes from
    # the glob pattern), so year and month are the 3rd- and 2nd-to-last parts
    path_parts = file.split('/')
    filedate = datetime.datetime(int(path_parts[-3]), int(path_parts[-2]), 1)
    yearmonth = filedate.strftime('%Y%m')

    # load data
    print(f'Loading data for file {f} ' + filedate.strftime('%Y-%m'), end=' ')
    cpu_time0 = time.process_time()
    filepath = lara_path + '/' + selected_dir + filedate.strftime('/%Y/%m/')
    lon_traj = xr.open_zarr(filepath + '/lon_av')['lon_av'].values
    lon_traj = ((lon_traj + 180) % 360) - 180  # wrap longitudes into [-180, 180)
    lat_traj = xr.open_zarr(filepath + '/lat_av')['lat_av'].values
    # open the z_av store once; it provides the data and both coordinates
    z_store = xr.open_zarr(filepath + '/z_av')
    z_traj = z_store['z_av'].values
    time_traj = z_store.time
    num_traj = z_store.particle
    cpu_time1 = time.process_time()
    print('finished in', (cpu_time1-cpu_time0)/60., 'minutes.')

    # increment travel time
    if f == 0:
        # every particle enters the very first file with no accumulated time
        prev_time = np.zeros(len(num_traj), dtype=int)
    travel_time = np.zeros((len(num_traj), len(time_traj)))
    for n in range(len(time_traj)):
        print(n, end=' ')
        # True where the particle is over land or above the height threshold
        lsm = globe.is_land(lat_traj[:,n], lon_traj[:,n]) | (z_traj[:,n] > h_thresh)
        # keep counting while the flag is set, reset to zero as soon as it is not
        prev_time = lsm * (prev_time + timestep)
        travel_time[:,n] = prev_time

    # temporarily save prev_time
    out_dir = output_path + '/' + selected_dir
    year_dir = out_dir + '/' + str(filedate.year)
    os.makedirs(year_dir, exist_ok=True)
    checkpoint = out_dir + '/prev_time_' + yearmonth + '.pkl'
    # Checkpoints left behind by earlier months.  The file for the current
    # month is excluded explicitly so that a re-run for the same month can
    # never delete the checkpoint it has just written.
    stale_checkpoints = [
        path for path in glob.glob(glob.escape(out_dir) + '/prev_time*')
        if os.path.abspath(path) != os.path.abspath(checkpoint)
    ]
    with open(checkpoint, mode='wb') as fhandle:
        pickle.dump(prev_time, fhandle)

    # grid
    print('\nGridding...')
    if f == 0: # initialize
        travel_time_gridded = np.zeros((len(z), len(lat), len(lon)))
        counter_gridded = np.zeros((len(z), len(lat), len(lon)))
    sample = [z_traj.ravel(), lat_traj.ravel(), lon_traj.ravel()]
    grid_bins = [z_bins, lat_bins, lon_bins]
    summed = scipy.stats.binned_statistic_dd(
        sample=sample,
        values=travel_time.ravel(),
        statistic='sum',
        bins=grid_bins,
    )
    travel_time_gridded += summed.statistic
    # reuse the bin numbers of the first call instead of digitizing the
    # (identical) sample a second time
    counter_gridded += scipy.stats.binned_statistic_dd(
        sample=sample,
        values=None,
        statistic='count',
        bins=grid_bins,
        binned_statistic_result=summed,
    ).statistic

    # write to netcdf
    print('Writing to netcdf...')
    ds = xr.Dataset(
        data_vars=dict(
            travel_time=(['z', 'lat', 'lon'], travel_time_gridded),
            counter=(['z', 'lat', 'lon'], counter_gridded),
        ),
        coords=dict(
            lon=('lon', lon),
            lat=('lat', lat),
            z=('z',z),
        ),
    )
    ds.to_netcdf(year_dir + '/travel_time_' + yearmonth + '.nc')

    # remove old prev_time (only after the new checkpoint and the netCDF
    # output for this month have both been written)
    for stale in stale_checkpoints:
        os.remove(stale)

    print('Done!')