/*
   This file is part of the RELXILL model code.

   RELXILL is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   RELXILL is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.
   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.

    Copyright 2022 Thomas Dauser, Remeis Observatory & ECAP
*/

#include "Relphysics.h"
#include "Relreturn_Table.h"

extern "C" {
#include "relutility.h"
#include "xilltable_reading.h"
}

// return-radiation table shared by all calls; loaded on first use by get_returnrad_table()
// and released (and reset to nullptr) by free_cached_returnTable()
returnTable* cached_retTable = nullptr;

// if non-zero, calculate_interpolated_tfr() rescales the transfer function of the innermost
// and outermost radial bins by the ratio of model to tabulated ring area
int global_rr_do_interpolation = 1;

/** create a new, empty return table.
 *  nspin, spin and retFrac are set to 0/nullptr so that the table can be released safely
 *  with free_returnTable() even if loading fails before init_returnTable() runs.
 *  Returns nullptr if status is already set or the allocation fails. */
static returnTable* new_returnTable(int* status)
{
    CHECK_STATUS_RET(*status, nullptr);

    auto tab = (returnTable*)malloc(sizeof(returnTable));
    CHECK_MALLOC_RET_STATUS(tab, status, tab)

    // malloc does not initialize memory: give every member read during cleanup a defined value
    tab->nspin = 0;
    tab->spin = nullptr;
    tab->retFrac = nullptr;

    return tab;
}

/** initialize an already allocated return table for nspin spin values.
 *  Allocates the per-spin array retFrac and marks every slot as not loaded (nullptr), so that
 *  a table whose loading stopped part way can still be released safely. The spin axis itself
 *  is attached by the caller (tab->spin is reset to nullptr here). */
static void init_returnTable(returnTable* tab, int nspin, int* status)
{
    CHECK_STATUS_VOID(*status);

    assert(tab != nullptr);
    tab->nspin = nspin;
    tab->spin = nullptr;

    tab->retFrac = (tabulatedReturnFractions**)malloc(nspin * sizeof(tabulatedReturnFractions*));
    CHECK_MALLOC_VOID_STATUS(tab->retFrac, status)

    // slots that are never filled (e.g. a later FITS extension fails to load) must not hold
    // garbage, as the cleanup code frees every slot
    for (int ii = 0; ii < nspin; ii++)
    {
        tab->retFrac[ii] = nullptr;
    }
}

/** release a malloc'ed array of n1 malloc'ed rows and reset *vals to nullptr.
 *  Does nothing if *vals is already nullptr. */
void free_2d(double*** vals, int n1)
{
    double** rows = *vals;
    if (rows == nullptr)
    {
        return;
    }

    for (int irow = 0; irow < n1; irow++)
    {
        free(rows[irow]);
    }
    free(rows);

    *vals = nullptr;
}

/** release a new[]'ed array of n1 new[]'ed rows and reset *vals to nullptr.
 *  Does nothing if *vals is already nullptr (counterpart of free_2d for new[] allocations). */
void delete_2d(double*** vals, int n1)
{
    double** rows = *vals;
    if (rows == nullptr)
    {
        return;
    }

    for (int irow = 0; irow < n1; irow++)
    {
        delete[] rows[irow];
    }
    delete[] rows;

    *vals = nullptr;
}


/** release a tabulated return-fraction structure together with every array it owns.
 *  The 2d arrays have nrad rows; frac_g has nrad planes of nrad rows each. nullptr is a no-op. */
static void free_returnFracData(tabulatedReturnFractions* dat)
{
    if (dat == nullptr)
    {
        return;
    }

    const int nr = dat->nrad;

    // 1d arrays
    free(dat->f_bh);
    free(dat->f_inf);
    free(dat->f_ret);
    free(dat->rlo);
    free(dat->rhi);

    // 2d arrays [nrad][...]
    free_2d(&dat->frac_e, nr);
    free_2d(&dat->tf_r, nr);
    free_2d(&dat->gmin, nr);
    free_2d(&dat->gmax, nr);

    // 3d array [nrad][nrad][...]
    if (dat->frac_g != nullptr)
    {
        for (int irad = 0; irad < nr; irad++)
        {
            free_2d(&dat->frac_g[irad], nr);
        }
        free(dat->frac_g);
    }

    free(dat);
}

/** release a return table, its spin axis and all per-spin return fractions, then set *tab
 *  to nullptr. Safe to call with *tab == nullptr and with a table whose retFrac allocation
 *  failed (retFrac == nullptr). */
static void free_returnTable(returnTable** tab)
{
    if (*tab != nullptr)
    {
        if ((*tab)->spin != nullptr)
        {
            // retFrac can be nullptr while spin is set if its allocation in init_returnTable failed
            if ((*tab)->retFrac != nullptr)
            {
                for (int ii = 0; ii < (*tab)->nspin; ii++)
                {
                    free_returnFracData((*tab)->retFrac[ii]);
                }
            }
            free((*tab)->spin);
            free((*tab)->retFrac);
        }
        free(*tab);
    }

    *tab = nullptr;
}

void free_cached_returnTable()
{
    free_returnTable(&cached_retTable);
}

/** allocate an empty return-fraction structure for nrad radial bins and ng values along the
 *  last axis of frac_g. All data arrays start as nullptr and are filled in by the FITS loader.
 *  Returns nullptr if status is already set or the allocation fails. */
static tabulatedReturnFractions* new_returnFracData(int nrad, int ng, int* status)
{
    CHECK_STATUS_RET(*status, nullptr);

    auto* frac = static_cast<tabulatedReturnFractions*>(malloc(sizeof(tabulatedReturnFractions)));
    CHECK_MALLOC_RET_STATUS(frac, status, frac)

    frac->nrad = nrad;
    frac->ng = ng;

    // radial grid
    frac->rlo = nullptr;
    frac->rhi = nullptr;

    // integrated fractions
    frac->f_ret = nullptr;
    frac->f_bh = nullptr;
    frac->f_inf = nullptr;

    // radius-resolved data
    frac->tf_r = nullptr;
    frac->frac_e = nullptr;
    frac->frac_g = nullptr;
    frac->gmin = nullptr;
    frac->gmax = nullptr;

    return frac;
}

/** move fptr to the binary-table extension called extname (any EXTVER).
 *  On failure a message is printed and the CFITSIO error is left in status. */
void fits_moveToExtension(char* extname, fitsfile* fptr, int* status)
{
    const int any_extver = 0;  // 0: CFITSIO ignores the EXTVER keyword when searching
    fits_movnam_hdu(fptr, BINARY_TBL, extname, any_extver, status);

    if (*status == EXIT_SUCCESS)
    {
        return;
    }
    printf(" *** error moving to extension %s\n", extname);
}

/** read one axis of the rel table from the FITS file.
 *  Reads all rows of column colname in extension extname. On success *val points to a newly
 *  malloc'ed array of *nval doubles owned by the caller. If the extension, the column or the
 *  row count cannot be accessed, *val is nullptr and *nval is 0, so the caller may always
 *  free(*val). If only the data read fails, *val and *nval are still set and status holds
 *  the error. */
static void get_reltable_axis_double(double** val,
                                     int* nval,
                                     char* extname,
                                     char* colname,
                                     fitsfile* fptr,
                                     int* status)
{
    CHECK_STATUS_VOID(*status);

    // give the outputs a defined value before anything can fail
    *val = nullptr;
    *nval = 0;

    fits_moveToExtension(extname, fptr, status);
    CHECK_STATUS_VOID(*status);

    // get the column id-number
    int colnum;
    if (fits_get_colnum(fptr, CASEINSEN, colname, &colnum, status)) return;

    long n;

    // get the number of rows
    if (fits_get_num_rows(fptr, &n, status)) return;

    // allocate memory for the array
    *val = (double*)malloc(n * sizeof(double));
    CHECK_MALLOC_VOID_STATUS(*val, status)

    int anynul = 0;
    double nullval = 0.0;
    auto nelem = (LONGLONG)n;
    fits_read_col(fptr, TDOUBLE, colnum, 1, 1, nelem, &nullval, *val, &anynul, status);

    *nval = (int)n;
}

/** determine the dimensions of the return-fraction data in the current HDU.
 *  nr and ng are taken from RETURNRAD_TABLE_NR / RETURNRAD_TABLE_NG; the number of rows of
 *  the table is only used to validate nr (an error is set in status on mismatch). */
static void get_returnRad_frac_dimensions(fitsfile* fptr, int* nr, int* ng, int* status)
{
    CHECK_STATUS_VOID(*status);

    long n = 0;

    // get the number of rows
    if (fits_get_num_rows(fptr, &n, status)) return;

    // nr and ng are fixed by RETURNRAD_TABLE_NR/NG for now, not read from the table header
    *nr = RETURNRAD_TABLE_NR;
    *ng = RETURNRAD_TABLE_NG;

    // one table row per radial bin
    if (*nr != (int)n)
    {
        RELXILL_ERROR("return rad table: mismatch in number of radial bins ", status);
        printf("     expecting %i bins, but found %i rows in the table \n", *nr, (int)n);
    }
}

/** read nval consecutive doubles of column colnum, starting at row firstrow and element
 *  firstelem (both 1-based), into val. FITS errors are checked via relxill_check_fits_error. */
static void fits_rr_read_col(fitsfile* fptr,
                             double* val,
                             int firstrow,
                             int firstelem,
                             int nval,
                             int colnum,
                             int* status)
{
    CHECK_STATUS_VOID(*status);

    double null_value = 0.0;
    int any_null = 0;
    fits_read_col(fptr, TDOUBLE, colnum,
                  static_cast<long>(firstrow), static_cast<long>(firstelem),
                  static_cast<LONGLONG>(nval),
                  &null_value, val, &any_null, status);

    relxill_check_fits_error(status);
}

/** look up the number of column colname (case-insensitive) in the current HDU.
 *  Errors are reported through status; after a failure the returned value is not a valid
 *  column number, and callers must check status before using it. */
int get_colnum(fitsfile* fptr, char* colname, int* status)
{
    // initialized so that a failed lookup never returns an indeterminate value
    int colnum = 0;
    fits_get_colnum(fptr, CASEINSEN, colname, &colnum, status);
    relxill_check_fits_error(status);
    return colnum;
}

/** read the first nval elements of row 1 of column colname into a newly malloc'ed array.
 *  Returns nullptr if status is already set, the column is missing or the allocation fails;
 *  read errors are reported through status (the array is still returned). */
static double* fits_rr_load_1d_data(fitsfile* fptr, char* colname, int nval, int* status)
{
    CHECK_STATUS_RET(*status, nullptr);

    const int col = get_colnum(fptr, colname, status);
    CHECK_STATUS_RET(*status, nullptr);

    auto* data = static_cast<double*>(malloc(nval * sizeof(double)));
    CHECK_MALLOC_RET_STATUS(data, status, nullptr)

    fits_rr_read_col(fptr, data, 1, 1, nval, col, status);
    return data;
}

/** read a 2d array [nval1][nval2] from column colname: table row ii+1 fills val[ii].
 *  Returns a malloc'ed array of malloc'ed rows (release with free_2d), or nullptr if status is
 *  already set, the column is missing or an allocation fails (nothing is leaked in that case).
 *  Read errors are reported through status. */
static double** fits_rr_load_2d_data(fitsfile* fptr, char* colname, int nval1, int nval2, int* status)
{
    CHECK_STATUS_RET(*status, nullptr);

    int colnum = get_colnum(fptr, colname, status);
    CHECK_STATUS_RET(*status, nullptr);

    double** val = (double**)malloc(nval1 * sizeof(double*));
    CHECK_MALLOC_RET_STATUS(val, status, nullptr)

    for (int ii = 0; ii < nval1; ii++)
    {
        val[ii] = (double*)malloc(nval2 * sizeof(double));
        if (val[ii] == nullptr)
        {
            // release the ii rows allocated so far (free_2d also resets val to nullptr),
            // then report the failed allocation with the same check used for the outer array
            free_2d(&val, ii);
            CHECK_MALLOC_RET_STATUS(val, status, nullptr)
        }

        fits_rr_read_col(fptr, val[ii], ii + 1, 1, nval2, colnum, status);
    }

    return val;
}

static double*** fits_rr_load_3d_data(fitsfile* fptr, char* colname, int nval1, int nval2, int nval3, int* status)
{
    CHECK_STATUS_RET(*status, nullptr);

    int colnum = get_colnum(fptr, colname, status);
    CHECK_STATUS_RET(*status, nullptr);

    double*** val = (double***)malloc(nval1 * sizeof(double**));
    CHECK_MALLOC_RET_STATUS(val, status, nullptr)

    for (int ii = 0; ii < nval1; ii++)
    {
        val[ii] = (double**)malloc(nval2 * sizeof(double*));
        CHECK_MALLOC_RET_STATUS(val[ii], status, nullptr)

        for (int jj = 0; jj < nval2; jj++)
        {
            val[ii][jj] = (double*)malloc(nval3 * sizeof(double));
            CHECK_MALLOC_RET_STATUS(val[ii][jj], status, nullptr)

            fits_rr_read_col(fptr, val[ii][jj], ii + 1, jj * nval3 + 1, nval3, colnum, status);
        }
    }

    return val;
}

/** load the return fractions stored in FITS extension extname.
 *  Returns nullptr if the extension cannot be reached, the dimensions do not match or the
 *  structure cannot be allocated. Otherwise it returns the (possibly partially filled) data,
 *  with any read error reported through status; release with free_returnFracData(). */
static tabulatedReturnFractions* fits_rr_load_single_fractions(fitsfile* fptr, char* extname, int* status)
{
    CHECK_STATUS_RET(*status, nullptr);

    fits_moveToExtension(extname, fptr, status);
    CHECK_STATUS_RET(*status, nullptr);

    int nrad = 0;
    int ng = 0;
    get_returnRad_frac_dimensions(fptr, &nrad, &ng, status);
    // on a dimension mismatch new_returnFracData would return nullptr, which must not be dereferenced
    CHECK_STATUS_RET(*status, nullptr);

    tabulatedReturnFractions* dat = new_returnFracData(nrad, ng, status);
    if (dat == nullptr)
    {
        return nullptr;
    }

    dat->rlo = fits_rr_load_1d_data(fptr, (char*)"rlo", nrad, status);
    dat->rhi = fits_rr_load_1d_data(fptr, (char*)"rhi", nrad, status);

    dat->frac_e = fits_rr_load_2d_data(fptr, (char*)"frac_e", nrad, nrad, status);
    dat->tf_r = fits_rr_load_2d_data(fptr, (char*)"tf_r", nrad, nrad, status);

    dat->gmin = fits_rr_load_2d_data(fptr, (char*)"gmin", nrad, nrad, status);
    dat->gmax = fits_rr_load_2d_data(fptr, (char*)"gmax", nrad, nrad, status);

    dat->frac_g = fits_rr_load_3d_data(fptr, (char*)"frac_g", nrad, nrad, ng, status);

    // each integrated fraction is read from its own column
    dat->f_ret = fits_rr_load_1d_data(fptr, (char*)"f_ret", nrad, status);
    dat->f_bh = fits_rr_load_1d_data(fptr, (char*)"f_bh", nrad, status);
    dat->f_inf = fits_rr_load_1d_data(fptr, (char*)"f_inf", nrad, status);

    return dat;
}

/** load the return fractions of every spin value from the extensions FRAC01, FRAC02, ...
 *  into tab->retFrac. Loading stops at the first spin that fails (error in status). */
static void fits_rr_load_all_fractions(fitsfile* fptr, returnTable* tab, int* status)
{
    CHECK_STATUS_VOID(*status);

    // the extension naming scheme (two digits) only supports up to 99 spin values
    assert(tab->nspin <= 99);
    assert(tab->nspin > 0);

    char extname[50];
    for (int ispin = 0; ispin < tab->nspin; ispin++)
    {
        snprintf(extname, sizeof(extname), "FRAC%02i", ispin + 1);

        tabulatedReturnFractions* frac = fits_rr_load_single_fractions(fptr, extname, status);
        tab->retFrac[ispin] = frac;
        CHECK_STATUS_BREAK(*status);

        frac->a = tab->spin[ispin]; // TODO: verify if we really need this
    }
}

/** load the complete return-radiation table from fptr and store it in *inp_tab.
 *  If reading the spin axis or allocating the per-spin array fails, everything allocated here
 *  is released and *inp_tab is left untouched. If a later per-spin extension fails, the
 *  partially loaded table is still handed to the caller (error in status) for cleanup. */
static void fits_rr_load_returnRadTable(fitsfile* fptr, returnTable** inp_tab, int* status)
{
    CHECK_STATUS_VOID(*status);

    returnTable* tab = new_returnTable(status);
    CHECK_STATUS_VOID(*status);

    /* get the number and values of the spin */
    double* spin = nullptr;  // stays nullptr if the axis cannot be read
    int nspin = 0;
    get_reltable_axis_double(&spin, &nspin, (char*)"SPIN", (char*)"a", fptr, status);

    /* initialize the table with them */
    init_returnTable(tab, nspin, status);

    // at this point the table is not in a state the generic cleanup can walk
    // (nspin/retFrac may be unset), so release what was allocated here directly
    if (*status != EXIT_SUCCESS)
    {
        free(spin);
        free(tab);
        return;
    }
    tab->spin = spin;

    fits_rr_load_all_fractions(fptr, tab, status);

    (*inp_tab) = tab;
}

/** open the FITS file filename (searched in the current directory or RELXILL_TABLE_PATH) and
 *  load the return-radiation table into *inp_tab, which must be empty (nullptr).
 *  On failure an error is printed and *inp_tab is released back to nullptr. */
static void fits_read_returnRadTable(char* filename, returnTable** inp_tab, int* status)
{
    CHECK_STATUS_VOID(*status);

    fitsfile* fptr = open_fits_table_stdpath(filename, status);

    // never overwrite a table that is already stored here
    assert(*inp_tab == nullptr);

    fits_rr_load_returnRadTable(fptr, inp_tab, status);

    const bool load_failed = (*status != EXIT_SUCCESS);
    if (load_failed)
    {
        printf(" *** error *** initializing of the RETURN RADIATION table %s failed \n", filename);
        free_returnTable(inp_tab);
    }

    if (fptr == nullptr)
    {
        return;
    }
    fits_close_file(fptr, status);
}

/** return the cached return-radiation table, reading it from RETURNRAD_TABLE_FILENAME the first
 *  time it is requested (or again if a previous attempt left the cache empty). May return
 *  nullptr if loading fails; the error is then in status. */
returnTable* get_returnrad_table(int* status)
{
    if (cached_retTable != nullptr)
    {
        return cached_retTable;
    }

    fits_read_returnRadTable((char*)RETURNRAD_TABLE_FILENAME, &cached_retTable, status);
    return cached_retTable;
}


/** select the index k of the tabulated spin to use for val_spin: the smallest tabulated spin
 *  with arr_spin[k] >= val_spin. If no valid index exists, an error is set in status and the
 *  (out-of-range) index is returned; it must not be used for indexing then. */
static int select_spinIndexForTable(double val_spin, double* arr_spin, int nspin, int* status)
{
    // arr[k]<=val<arr[k+1]
    int k = binary_search(arr_spin, nspin, val_spin);

    // the spin of the table needs to be spin[k]>=val_spin
    if (k >= 0 && k < nspin && arr_spin[k] < val_spin)
    {
        k++;
    }

    // check the range BEFORE reading arr_spin[k]: after k++ the index can be nspin
    if (k >= nspin || k < 0)
    {
        RELXILL_ERROR("failed determining index of spin for the return radiation table \n", status);
        printf("    spin=%f leads to not allowed index of %i\n", val_spin, k);
        return k;
    }

    assert(arr_spin[k] >= val_spin);
    return k;
}


/** select the contiguous range of tabulated radial bins covering [Rin, Rout] and set up the
 *  model radial grid of ipol from it:
 *   - ipol->irad[ii] : index of the ii-th selected bin in the table
 *   - ipol->rlo/rhi  : bin edges copied from the table, with the first lower edge replaced by
 *                      Rin and the last upper edge replaced by Rout
 *   - ipol->rad      : bin centers 0.5*(rlo+rhi)
 *  NOTE(review): assumes binary_search returns k with arr[k] <= val < arr[k+1] — TODO confirm. */
static void allocate_radial_grid(returningFractions* ipol, double Rin, double Rout)
{
    assert(ipol != nullptr);
    int klo_Rlo = binary_search(ipol->tabData->rlo, ipol->tabData->nrad, Rin);
    int khi_Rhi = binary_search(ipol->tabData->rhi, ipol->tabData->nrad, Rout);

    // Rout coinciding with the table's outer edge: take the last bin; otherwise step to the
    // bin following the one returned by the search (the bin whose upper edge lies beyond Rout)
    if (fabs(Rout - ipol->tabData->rhi[ipol->tabData->nrad - 1]) < 1e-6)
    {
        khi_Rhi = ipol->tabData->nrad - 1;
    }
    else
    {
        khi_Rhi++;
    }

    // Rin coinciding with the table's inner edge: start at the first bin
    if (fabs(Rin - ipol->tabData->rlo[0]) < 1e-6)
    {
        klo_Rlo = 0;
    }


    // number of table bins from klo_Rlo to khi_Rhi (inclusive)
    int nrad_trim = (khi_Rhi + 1) - klo_Rlo;

    assert(nrad_trim > 0);
    assert(nrad_trim <= ipol->tabData->nrad);

    ipol->nrad = nrad_trim;
    ipol->irad = VecI(nrad_trim);

    for (int ii = 0; ii < nrad_trim; ii++)
    {
        ipol->irad[ii] = klo_Rlo + ii;
    }

    ipol->rlo = VecD(nrad_trim);
    ipol->rhi = VecD(nrad_trim);
    ipol->rad = VecD(nrad_trim);

    for (int ii = 0; ii < nrad_trim; ii++)
    {
        ipol->rlo[ii] = ipol->tabData->rlo[ipol->irad[ii]];
        ipol->rhi[ii] = ipol->tabData->rhi[ipol->irad[ii]];
    }


    assert(Rin >= ipol->rlo[0] - 1e-4); // TODO: make this limit stronger with a better table

    // reset the lowest bin edge to Rin and the highest bin edge to Rout
    ipol->rlo[0] = Rin;
    ipol->rhi[nrad_trim - 1] = Rout;

    // bin centers (computed after the outer edges have been replaced)
    for (int ii = 0; ii < nrad_trim; ii++)
    {
        ipol->rad[ii] = 0.5 * (ipol->rlo[ii] + ipol->rhi[ii]);
    }
}


/** create a returningFractions object referring to the tabulated data tab for the given spin.
 *  The radial grid and the transfer function start out empty; release with
 *  delete_returningFractions(). Returns nullptr if status is already set. */
static returningFractions* new_returningFractions(tabulatedReturnFractions* tab, double spin, int* status)
{
    CHECK_STATUS_RET(*status, nullptr);

    auto* fractions = new returningFractions;
    CHECK_MALLOC_RET_STATUS(fractions, status, fractions)

    fractions->a = spin;
    fractions->tabData = tab;

    // empty radial grid, filled by allocate_radial_grid()
    fractions->nrad = 0;
    fractions->irad = VecI();
    fractions->rad = VecD();
    fractions->rlo = VecD();
    fractions->rhi = VecD();

    // empty transfer function, filled by calculate_interpolated_tfr()
    fractions->tf_r = VecD_2D();

    return fractions;
}

/** delete a returningFractions object created by new_returningFractions() and reset *dat to
 *  nullptr; does nothing if *dat is already nullptr */
void delete_returningFractions(returningFractions** dat)
{
    if (*dat == nullptr)
    {
        return;
    }

    delete *dat;
    *dat = nullptr;
}

/** extract the n x n sub-matrix of the tabulated transfer function tab_tfr selected by the
 *  table indices ind_arr (used for both rows and columns) */
static VecD_2D trim_transfun_to_radial_grid(double** tab_tfr, const VecI& ind_arr, int n)
{
    VecD_2D trimmed(n);
    for (int irow = 0; irow < n; irow++)
    {
        const double* tab_row = tab_tfr[ind_arr[irow]];

        trimmed[irow] = VecD(n);
        auto& row = trimmed[irow];
        for (int icol = 0; icol < n; icol++)
        {
            row[icol] = tab_row[ind_arr[icol]];
        }
    }
    return trimmed;
}

/** ratio of the model ring "area" to the tabulated ring "area" for model bin index_radius,
 *  using 0.5*(rlo+rhi)*(rhi-rlo) for both. For the innermost table bin the tabulated lower edge
 *  is clamped to the ISCO (kerr_rms) if it lies above it. */
static double get_ring_area_correction_factor(int index_radius, returningFractions* ret_fractions)
{
    assert(!ret_fractions->irad.empty());

    const int itab = ret_fractions->irad[index_radius];
    double rlo_table = ret_fractions->tabData->rlo[itab];
    const double rhi_table = ret_fractions->tabData->rhi[itab];

    // the innermost table bin must start at the ISCO
    if (itab == 0)
    {
        const double r_isco = kerr_rms(ret_fractions->a);
        if (rlo_table > r_isco)
        {
            // only warn for significant deviations
            if (fabs(rlo_table - r_isco) > 1e-4)
            {
                printf(" *** warning: resetting Rin of the return radiation table to the ISCO at %e  (was %e before)\n",
                       r_isco, rlo_table);
            }
            rlo_table = r_isco;
        }
    }

    const double rlo_model = ret_fractions->rlo[index_radius];
    const double rhi_model = ret_fractions->rhi[index_radius];

    const double area_table = 0.5 * (rlo_table + rhi_table) * (rhi_table - rlo_table);
    const double area_model = 0.5 * (rlo_model + rhi_model) * (rhi_model - rlo_model);

    return area_model / area_table;
}

/** cut the tabulated transfer function tf_r down to the model's radial grid.
 *  If global_rr_do_interpolation is set, the columns of the innermost and outermost emitting
 *  bins are then rescaled by the ratio of model to tabulated ring area.
 *  Requires the radial grid (irad, nrad, rlo, rhi) to be set up and tf_r to still be empty. */
void calculate_interpolated_tfr(returningFractions* ret_fractions)
{
    assert(!ret_fractions->irad.empty());
    assert(ret_fractions->tf_r.empty());

    ret_fractions->tf_r = trim_transfun_to_radial_grid(ret_fractions->tabData->tf_r,
                                                       ret_fractions->irad,
                                                       ret_fractions->nrad);

    int i_rad_rin = 0;
    int i_rad_rout = ret_fractions->nrad - 1;

    double area_correction_rin = get_ring_area_correction_factor(i_rad_rin, ret_fractions);
    double area_correction_rout = get_ring_area_correction_factor(i_rad_rout, ret_fractions);

    // the trimmed edge bins may only be smaller than the tabulated ones (factor <= 1)
    assert(area_correction_rin - 1 < 1e-6);
    assert(area_correction_rout - 1 < 1e-6);

    if (global_rr_do_interpolation)
    {
        // correction of tf_r for photons emitted at i_rad_rin, with respect to the smaller area of emission there
        for (int i_rad_incident = i_rad_rin; i_rad_incident < i_rad_rout; i_rad_incident++)
        {
            ret_fractions->tf_r[i_rad_incident][i_rad_rin] *= area_correction_rin;
        }
        // same for photons emitted at i_rad_rout, over the same range of incident bins
        for (int i_rad_incident = i_rad_rin; i_rad_incident < i_rad_rout; i_rad_incident++)
        {
            ret_fractions->tf_r[i_rad_incident][i_rad_rout] *= area_correction_rout;
        }
    }
}


/** build the returning fractions for the given spin on the radial grid [rin, rout].
 *  Loads the return-radiation table on first use. On any error the status is passed to
 *  CHECK_STATUS_AND_THROW and nullptr is returned. Release the result with
 *  delete_returningFractions(). */
returningFractions* get_rrad_fractions(double spin, double rin, double rout, int* status)
{
    CHECK_STATUS_RET(*status, nullptr);
    assert(rin > 0);
    assert(rout > rin);

    returnTable* tab = get_returnrad_table(status); // table will only be loaded if it is not already done
    if (tab == nullptr || *status != EXIT_SUCCESS)
    {
        // the table could not be loaded: do not dereference it
        CHECK_STATUS_AND_THROW(*status);
        return nullptr;
    }

    int ind_spin = select_spinIndexForTable(spin, tab->spin, tab->nspin, status);
    if (*status != EXIT_SUCCESS)
    {
        // ind_spin is out of range and must not be used to index retFrac
        CHECK_STATUS_AND_THROW(*status);
        return nullptr;
    }
    tabulatedReturnFractions* tab_fractions = tab->retFrac[ind_spin];

    returningFractions* ret_fractions = new_returningFractions(tab_fractions, spin, status);

    if (ret_fractions == nullptr)
    {
        CHECK_STATUS_AND_THROW(*status);
        return ret_fractions;
    }

    allocate_radial_grid(ret_fractions, rin, rout);
    calculate_interpolated_tfr(ret_fractions);

    return ret_fractions;
}