#%%
import numpy as np
import pandas as pd
import os
import pathlib
import datetime as dt

def get_time_intervall(start_date, end_date, path):
    """Collect the files under *path* dated within [start_date, end_date].

    The directory tree is searched recursively. A file's date is taken from
    its location and name: the name of its parent directory is parsed as the
    year, and the last three characters of its stem as the day of year (DOY).

    Files that do not follow this layout (``.DS_Store`` or any other stray
    file whose parent directory or stem suffix is not an integer) are skipped
    instead of aborting the whole scan.

    Parameters
    ----------
    start_date, end_date : datetime.date or datetime.datetime
        Inclusive bounds of the interval; only the year and the day of year
        are used.
    path : str or os.PathLike
        Root directory to search.

    Returns
    -------
    list of pathlib.Path
        Matching files, in the order the directory traversal yields them
        (the list is not sorted).
    """
    # (year, doy) tuples compare lexicographically, which is exactly the
    # "same year / first year / last year / years in between" case analysis.
    first = (start_date.year, start_date.timetuple().tm_yday)
    last = (end_date.year, end_date.timetuple().tm_yday)

    res = []
    for entry in pathlib.Path(path).rglob('*'):
        if not entry.is_file():
            continue
        try:
            stamp = (int(entry.parts[-2]), int(entry.stem[-3:]))
        except (ValueError, IndexError):
            # Not a <year>/<...DOY>.<ext> data file (e.g. '.DS_Store').
            continue
        if first <= stamp <= last:
            res.append(entry)

    return res

def load_lvl1_counts(path):
    """Read a level-1 count file into a DataFrame.

    The file is parsed as single-space separated text without a header line;
    the 101 column names are assigned here, in file order. Rows in which
    every field equals zero are removed from the result.

    Parameters
    ----------
    path : str or os.PathLike
        Location of the level-1 file.

    Returns
    -------
    pandas.DataFrame
        The parsed rows that contain at least one non-zero entry.
    """
    column_names = (
        ['year', 'doy', 'msod', 'min95', 'SC_EPOCH', 'G0']
        + ['A00', 'A01', 'A02', 'A03', 'A04', 'A05']
        + ['B00', 'B01', 'B02', 'B03', 'B04', 'B05']
        + ['C0', 'D0', 'E0', 'F0']
        + ['P4GM', 'P4GR', 'P4S', 'P8GM', 'P8GR', 'P8S']
        + ['H4GM', 'H4GR', 'H4S1', 'H4S23', 'H8GM', 'H8GR', 'H8S1', 'H8S23']
        + ['E150', 'E300', 'E1300', 'E3000', 'INT']
        + ['P25GM', 'P25GR', 'P25S', 'P41GM', 'P41GR', 'P41S']
        + ['H25GM', 'H25GR', 'H25S1', 'H25S23']
        + ['H41GM', 'H41GR', 'H41S1', 'H41S23']
        + ['CT0:', 'CT1:', 'CT2:', 'CT3:', 'CT4:', 'CT5:']
        + ['Bin' + str(number) for number in range(32)]
        + ['stat1', 'stat2', 'stat3', 'stat4', 'stat5']
        + ['lionDatL', 'EPHINdatL', 'flag']
    )

    table = pd.read_csv(path, sep=' ', names=column_names)

    # Keep a row as soon as one of its fields differs from zero.
    has_nonzero_entry = table.ne(0).any(axis=1)
    return table.loc[has_nonzero_entry]

def _invert_single_rate(measured_rate, x_vals, y_vals):
    """Look up the corrected rate for one single-counter reading.

    ``y_vals`` is the fitted curve sampled at ``x_vals``. The result is the
    midpoint of the first ``x_vals`` interval in which the curve crosses
    ``measured_rate``, with these exceptions:

    * no crossing found             -> ``measured_rate`` is returned unchanged
    * midpoint below the input      -> ``measured_rate`` (never correct downwards)
    * otherwise, input above 1e5    -> ``nan``
    """
    crossings = np.argwhere(np.diff(np.sign(measured_rate - y_vals)))
    if crossings.size == 0:
        return measured_rate

    idx = crossings[0][0]
    corrected = (x_vals[idx] + x_vals[idx + 1]) / 2
    if corrected < measured_rate:
        return measured_rate
    if measured_rate > 1e5:
        return np.nan
    return corrected


def sci_dt_correction(path_to_file,
                      single_fit_path="/data/etph/hoerloeck/SOHO/scripts/single_count_fit.npy",
                      coinc_fit_path='/data/etph/hoerloeck/SOHO/scripts/coinc_dead_time_fit.npy'):
    """Load a level-1 file and apply the dead-time corrections row by row.

    Two corrections are applied to every row returned by
    ``load_lvl1_counts``:

    1. Single counters (columns ``A00`` .. ``F0``): each positive value is
       divided by 60, mapped through the single-counter fit curve (see
       ``_invert_single_rate``) and multiplied by 60 again. Values that are
       not positive are written back as 0.
    2. Coincidence channels (columns ``P4GM`` .. ``CT5:``): divided by
       ``1 - coinc_dead_time``, where the dead-time fraction is derived from
       the summed corrected rate of ``A00`` .. ``A05`` and the coincidence
       fit. Negative fractions are clipped to 0.

    Parameters
    ----------
    path_to_file : str or os.PathLike
        Level-1 file, passed on to ``load_lvl1_counts``.
    single_fit_path : str or os.PathLike, optional
        ``.npy`` file with the polynomial coefficients (``numpy.polyval``
        order) of the single-counter curve in log10-log10 space.
    coinc_fit_path : str or os.PathLike, optional
        ``.npy`` file with the polynomial coefficients of the coincidence
        dead-time curve in log10-log10 space.

    Returns
    -------
    pandas.DataFrame
        The loaded data with the corrected columns stored as floats.
    """
    data = load_lvl1_counts(path_to_file)
    cols = ['P4GM','P4GR','P4S', 'P8GM','P8GR','P8S', 'H4GM','H4GR','H4S1','H4S23', 'H8GM','H8GR','H8S1','H8S23',
            'E150','E300','E1300','E3000',  'INT', 'P25GM','P25GR','P25S', 'P41GM','P41GR','P41S', 'H25GM','H25GR',
            'H25S1','H25S23', 'H41GM', 'H41GR','H41S1','H41S23','CT0:','CT1:','CT2:','CT3:','CT4:','CT5:','A00',
            'A01','A02','A03','A04','A05', 'B00', 'B01','B02', 'B03', 'B04', 'B05', 'C0', 'D0', 'E0', 'F0', 'Bin0',
            'Bin1', 'Bin2', 'Bin3', 'Bin4', 'Bin5', 'Bin6', 'Bin7', 'Bin8', 'Bin9','Bin10', 'Bin11', 'Bin12', 'Bin13',
            'Bin14', 'Bin15', 'Bin16', 'Bin17', 'Bin18', 'Bin19','Bin20', 'Bin21', 'Bin22', 'Bin23', 'Bin24', 'Bin25',
            'Bin26', 'Bin27', 'Bin28', 'Bin29', 'Bin30', 'Bin31']
    # The corrections produce fractional values, so store counts as floats.
    data[cols] = data[cols].astype(float)

    # function of single counter dead time
    fit = np.load(single_fit_path)
    # function of coincidence counter dead time
    coinc_dt_fit = np.load(coinc_fit_path)

    # Sample the single-counter curve once on a log grid from 1 to 1e6.
    x_vals = np.logspace(0, 6, int(1e5))
    y_vals = 10**np.polyval(fit, np.log10(x_vals))

    # Order must match the column order of the "A00":"F0" slice written below.
    single_detectors = ['A00', 'A01', 'A02', 'A03', 'A04', 'A05',
                        'B00', 'B01', 'B02', 'B03', 'B04', 'B05',
                        'C0', 'D0', 'E0', 'F0']

    for index, row in data.iterrows():
        # correction of single counters
        single_counters = np.zeros(len(single_detectors))
        for i, det in enumerate(single_detectors):
            if row[det] > 0:
                # NOTE(review): the /60 presumably converts per-minute counts
                # to a per-second rate -- confirm against the data format.
                single_counters[i] = _invert_single_rate(row[det] / 60, x_vals, y_vals)
        data.loc[index, "A00":"F0"] = single_counters * 60

        # correction of coincidence dead time
        a_ges = single_counters[0:6].sum()
        if a_ges == 0:
            # log10(0) and the division by a_ges are undefined here and could
            # turn the whole coincidence block into NaN; with no counts in
            # A00..A05 no correction is applied. (A NaN sum still propagates.)
            coinc_dead_time = 0
        else:
            coinc_dead_time = (a_ges - 10**np.polyval(coinc_dt_fit, np.log10(a_ges)))/a_ges
            if coinc_dead_time < 0:
                coinc_dead_time = 0
        data.loc[index, 'P4GM':'CT5:'] /= (1 - coinc_dead_time)

    return data


    

