import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import os
import glob
import re
import datetime
import cartopy.crs as ccrs
import sys
import matplotlib.colors as mcolors
from matplotlib import rcParams

# Usage: python plot_travel_time.py <output_path> <h_thresh> <startmonth YYYYmm> <endmonth YYYYmm>
#   h_thresh: only vertical levels with z <= h_thresh are summed

def colormap_vintage(cmap='RdYlBu_r', n_white=50, gamma=1/0.8, darken=0.05):
    """Return a matplotlib colormap that fades in from transparent white.

    The base colormap ``cmap`` is sampled over 5 %-100 % of its range
    (256 colours). Its RGB channels are darkened along a half-sine profile,
    by up to ``darken`` in the middle of the range and not at all at the
    ends. Then ``n_white`` extra colours are prepended that blend from
    transparent white (RGBA 1, 1, 1, 0) into the first base colour.

    Parameters
    ----------
    cmap : str
        Name of the matplotlib colormap to start from.
    n_white : int
        Number of blended colours prepended at the low end.
    gamma : float
        Exponent applied to the linear white weight; values > 1 make the
        white fade out faster.
    darken : float
        Maximum relative darkening of the RGB channels (0 disables it).

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        Colormap named ``cmap + '_white100'``.
    """
    base = np.array(plt.get_cmap(cmap)(np.linspace(0.05, 1., 256)), dtype=float)

    # darken cmap_values: scale only the RGB columns so the colours stay
    # opaque (the alpha column is left untouched).
    n = len(base)
    profile = 1.0 - darken * np.sin(np.arange(n) * np.pi / (n - 1))
    base[:, :3] = np.clip(base[:, :3] * profile[:, np.newaxis], 0., 1.)

    # Blend from transparent white into the first base colour.
    white_weight = np.linspace(1., 0., n_white)[:, np.newaxis] ** gamma
    white_values = (np.array([1.0, 1.0, 1.0, 0.0]) * white_weight
                    + base[0, :] * (1 - white_weight))

    new_cmap_values = np.vstack([white_values, base])
    return mcolors.LinearSegmentedColormap.from_list(
        cmap + '_white100', new_cmap_values)

# read commandline arguments (see the usage line at the top of the file);
# exit with the usage text instead of an IndexError when any are missing.
if len(sys.argv) < 5:
    sys.exit('Usage: python plot_travel_time.py <output_path> <h_thresh> '
             '<startmonth YYYYmm> <endmonth YYYYmm>')
output_path = sys.argv[1]
h_thresh = float(sys.argv[2])
startmonth = datetime.datetime.strptime(sys.argv[3], '%Y%m')
endmonth = datetime.datetime.strptime(sys.argv[4], '%Y%m')

# The period is evaluated as (end file - start file) below, so a reversed
# period would silently yield negative values.
if startmonth > endmonth:
    sys.exit('startmonth must not be later than endmonth.')

# find the LARA directories: every entry of output_path whose name contains a
# 'YYYYmm-YYYYmm' period label (only the first label of each name is used).
directories = os.listdir(output_path)
dir_period_tmp = [re.findall(r'\d{6}-\d{6}', dir1) for dir1 in directories]
# Latest start first, so the most recent covering run wins below. A plain
# list (rather than a 2-D numpy array) also copes with zero matches.
dir_period = sorted((labels[0] for labels in dir_period_tmp if labels),
                    reverse=True)

# find the right directory: the first (latest-starting) run whose period
# covers [startmonth, endmonth]
selected_dir = None
for dir1 in dir_period:
    dir1_startmonth = datetime.datetime.strptime(dir1.split('-')[0], '%Y%m')
    # NOTE(review): only the year of the end label is used; the run is assumed
    # to extend to December of that year. Confirm this is intended.
    dir1_endmonth = datetime.datetime(int(dir1.split('-')[1][-6:-2]), 12, 1)
    if dir1_startmonth <= startmonth and dir1_endmonth >= endmonth:
        selected_dir = dir1
        break

# Without a matching run there is nothing to load: stop here (message goes to
# stderr, exit status 1) instead of failing later on the None path.
if selected_dir is None:
    sys.exit('No directory found. Check your dates.')

# load data: the start-month and end-month travel_time files of the selected
# run, read fully into memory so the underlying files can be closed.
run_dir = os.path.join(output_path, selected_dir)
with xr.open_dataset(os.path.join(
        run_dir, startmonth.strftime('%Y/travel_time_%Y%m.nc'))) as ds:
    ncfile1 = ds.load()
with xr.open_dataset(os.path.join(
        run_dir, endmonth.strftime('%Y/travel_time_%Y%m.nc'))) as ds:
    ncfile2 = ds.load()

# Start month of the selected run, derived from its directory label rather
# than from a variable left over from the search loop.
selected_startmonth = datetime.datetime.strptime(
    selected_dir.split('-')[0], '%Y%m')
# NOTE(review): the fields appear to be cumulative (end minus start is taken
# below); zeroing the start values when the period begins at the run start
# presumably makes the first month count as well. Confirm.
if startmonth == selected_startmonth:  # initial values to zero
    ncfile1['travel_time'].values = np.zeros_like(ncfile1['travel_time'].values)
    ncfile1['counter'].values = np.zeros_like(ncfile1['counter'].values)

# NOTE: h_thresh is always a float here (parsed from argv), so the else
# branch is currently unreachable; it is kept as a fallback configuration.
if h_thresh is not None:
    h_ind = ncfile1['z'].values <= h_thresh
    levels = np.arange(0, 52, 2)
else:
    h_ind = ncfile1['z'].values <= 1000.
    levels = np.arange(1, 10.5, 0.5)

# Sum the period increments over the selected z levels (assumes z is the
# leading dimension of both fields; confirm). The /24 presumably converts
# hours to days, matching the 'days' colorbar label.
travel_time = np.sum(ncfile2['travel_time'].values[h_ind]
                     - ncfile1['travel_time'].values[h_ind], axis=0) / 24.
counter = np.sum(ncfile2['counter'].values[h_ind]
                 - ncfile1['counter'].values[h_ind], axis=0)

# plot
rcParams['ps.useafm'] = True
rcParams['pdf.use14corefonts'] = False
rcParams['text.usetex'] = False
rcParams['font.sans-serif'] = ['cmr10', 'Times-Roman']
rcParams['font.size'] = 22

proj = ccrs.PlateCarree(central_longitude=0, globe=None)

fig = plt.figure(figsize=(20, 8))
ax = fig.add_subplot(111, projection=proj)
ax.coastlines()
ax.set_global()

# grid and labels
ax.gridlines(draw_labels=False, color='grey', linewidth=0.5, linestyle='--',
             ylocs=np.arange(-90, 91, 30), xlocs=np.arange(-180, 181, 60))
ax.set_xticks([-120, -60, 0, 60, 120], crs=ccrs.PlateCarree())
ax.set_xticklabels(labels=['120$^{\\circ}$ W', '60$^{\\circ}$ W', '0$^{\\circ}$',
                           '60$^{\\circ}$ E', '120$^{\\circ}$ E'])
ax.set_xlabel('Longitude', fontsize=24)
ax.set_yticks([-60, -30, 0, 30, 60], crs=ccrs.PlateCarree())
ax.set_yticklabels(labels=['60$^{\\circ}$ S', '30$^{\\circ}$ S', '0$^{\\circ}$',
                           '30$^{\\circ}$ N', '60$^{\\circ}$ N'])
ax.set_ylabel('Latitude', fontsize=24)

# values: mean travel time per counted event. Grid cells with counter == 0
# give NaN/inf, which contourf masks (leaves unfilled); the corresponding
# divide/invalid RuntimeWarnings are expected and silenced.
with np.errstate(divide='ignore', invalid='ignore'):
    mean_travel_time = travel_time / counter
cf = ax.contourf(ncfile1['lon'], ncfile1['lat'], mean_travel_time,
                 cmap=colormap_vintage(), levels=levels, extend='max',
                 transform=ccrs.PlateCarree())
fig.colorbar(cf, ax=ax, label='Travel time (days)')

# save; create the output directory first so savefig does not fail on a
# fresh checkout.
os.makedirs('figures', exist_ok=True)
fig.savefig(f'figures/travel_time_{h_thresh:.0f}m.png', bbox_inches='tight', dpi=120)