import numpy as np
from scipy import stats as ss

from chemex.containers import plot as ccp
from chemex.parameters import settings as cps


def post_fit(experiments, params, path, plot=False, simulation=False):
    """Report, save and optionally plot the outcome of a fit.

    Prints the final chi-square values, writes the parameter, experiment and
    statistics files under *path*, and, when *plot* is true, delegates to
    ``ccp.write_plots`` (forwarding the *simulation* flag).
    """
    _print_chisqr(experiments, params)
    _write_files(experiments, params, path)
    if not plot:
        return
    ccp.write_plots(experiments, params, path, simulation)


def _write_files(experiments, params, path):
    """Write the results of the fit to output files.

    Creates *path* (and any missing parents) if needed, then writes the
    fitted parameters, the experiment outputs and the fit statistics there.
    *path* must support ``mkdir`` (e.g. a ``pathlib.Path``).
    """
    output_dir = path
    print(f'Writing results -> "{output_dir}/"\n')
    output_dir.mkdir(parents=True, exist_ok=True)
    cps.write_par(params, output_dir)
    experiments.write(params, output_dir)
    _write_statistics(experiments, params, path=output_dir)


def _write_statistics(experiments, params, path):
    """Write fitting statistics to ``<path>/statistics.toml``.

    One ``label = value`` line is written per statistic returned by
    ``calculate_statistics``. Counts are written as plain integers; the
    floating-point statistics use the ``" .5e"`` format (sign slot for
    positive values).
    """
    stats = calculate_statistics(experiments, params)
    # (label including its aligned "= " separator, statistics key, format spec)
    fields = (
        ("number of data points          = ", "ndata", ""),
        ("number of variables            = ", "nvarys", ""),
        ("chi-square                     = ", "chisqr", " .5e"),
        ("reduced-chi-square             = ", "redchi", " .5e"),
        ("chi-squared test               = ", "pvalue", " .5e"),
        ("Kolmogorov-Smirnov test        = ", "ks_pvalue", " .5e"),
        ("Akaike Information Criterion   = ", "aic", " .5e"),
        ("Bayesian Information Criterion = ", "bic", " .5e"),
    )
    text = "".join(
        f"{label}{format(stats[key], spec)}\n" for label, key, spec in fields
    )
    filename = path / "statistics.toml"
    with open(filename, "w") as handle:
        handle.write(text)


def _print_chisqr(experiments, params):
    """Print the final chi-square and reduced chi-square of the fit."""
    stats = calculate_statistics(experiments, params)
    report = (
        f"Final Chi2        : {stats['chisqr']:.3e}",
        f"Final Reduced Chi2: {stats['redchi']:.3e}\n",
    )
    for line in report:
        print(line)


def calculate_statistics(experiments, params):
    """Compute goodness-of-fit statistics for the current parameter values.

    Args:
        experiments: Object whose ``residuals(params)`` method returns the
            residual vector of the fit. These are presumably already
            normalised by the data uncertainties (TODO confirm); the
            chi-square and KS tests below are only meaningful if so.
        params: Mapping of parameter names to parameter objects exposing
            ``vary`` and ``expr`` attributes.

    Returns:
        dict with keys:
            "ndata"     -- number of residuals,
            "nvarys"    -- number of free parameters,
            "chisqr"    -- sum of squared residuals,
            "redchi"    -- chisqr / max(1, ndata - nvarys),
            "pvalue"    -- chi-squared test p-value (upper-tail probability),
            "ks_pvalue" -- Kolmogorov-Smirnov p-value of the residuals
                           against a standard normal distribution,
            "aic", "bic" -- Akaike / Bayesian information criteria.
    """
    residuals = np.asarray(experiments.residuals(params))
    ndata = len(residuals)
    # Fixed parameters and parameters tied by an expression do not consume
    # degrees of freedom; only truly free ones are counted.
    nvarys = sum(1 for param in params.values() if param.vary and not param.expr)
    chisqr = np.sum(residuals**2)
    # max(1, ...) avoids dividing by zero or a negative number when the model
    # has as many (or more) free parameters as data points.
    redchi = chisqr / max(1, ndata - nvarys)
    # -2 log-likelihood up to an additive constant, shared by AIC and BIC.
    _neg2_log_likel = ndata * np.log(chisqr / ndata)
    aic = _neg2_log_likel + 2 * nvarys
    bic = _neg2_log_likel + np.log(ndata) * nvarys
    _, ks_p_value = ss.kstest(residuals, "norm")
    # Use the survival function rather than ``1 - cdf``: for poor fits the cdf
    # rounds to 1.0 and the subtraction would report a p-value of exactly 0.
    pvalue = ss.chi2.sf(chisqr, ndata - nvarys)
    return {
        "ndata": ndata,
        "nvarys": nvarys,
        "chisqr": chisqr,
        "redchi": redchi,
        "pvalue": pvalue,
        "ks_pvalue": ks_p_value,
        "aic": aic,
        "bic": bic,
    }


def print_header(params, grid):
    """Return a ``#``-prefixed header line naming the parameters in *grid*.

    Each entry of *grid* is looked up in *params* and its
    ``user_data["pname"]`` is converted to ``str``. The names are joined by
    single spaces in *grid* order, and a final ``[χ²]`` column label is
    appended. The line is returned, not printed.
    """
    names = [str(params[fname].user_data["pname"]) for fname in grid]
    return "# " + " ".join(names) + " [χ²]\n"


def print_values(values, chisqr):
    """Return one table row: *values* then *chisqr*, each in ``.5e`` format.

    The row is indented by two spaces, columns are separated by single
    spaces, and the row ends with a newline.
    """
    body = " ".join(map("{:.5e}".format, values))
    return "  {} {:.5e}\n".format(body, chisqr)


def print_values_stat(params, fnames, chisqr):
    """Return one table row of parameter values followed by *chisqr*.

    For each name in *fnames*, the row shows ``params[name].value`` in a
    12-character ``.5e`` column. Names missing from *params* get a centred
    ``--`` placeholder of the same width. The chi-square comes last, in
    ``.5e`` format.
    """
    cells = [
        format(params[fname].value, "12.5e") if fname in params else format("--", "^12s")
        for fname in fnames
    ]
    return "  {} {:.5e}\n".format(" ".join(cells), chisqr)
