import uproot
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import cholesky, inv
from scipy.stats import binned_statistic
import os

def get_correlations():
    """Compare the fitted total derivative dphi/dp with one rebuilt from partial derivatives.

    Reads the branches ``x, xp, y, yp, t, p, phi`` from tree ``"t"`` in
    ``BMADStored.root`` and then:

    1. Fits each branch linearly against ``p`` (the "full" derivatives
       d<var>/dp) and draws every fit into ``Images/BMADFullDeriv.png``.
    2. Regresses ``phi`` on all the other branches at once to obtain the
       partial derivatives dphi/d<var> (each holding the remaining branches fixed).
    3. Sums dphi/d<var> * d<var>/dp over the non-phi branches (chain rule) and
       prints the result next to the directly fitted dphi/dp.

    All results are printed; nothing is returned.
    """
    # Create the Images directory if it doesn't exist
    os.makedirs("Images", exist_ok=True)

    # Variable names; the last entry (phi) is the response variable in Stage 2
    var_names = ["x", "xp", "y", "yp", "t", "p", "phi"]

    # Injected Muons: Parameters for this input
    filename = "BMADStored.root"
    imagefilename = "BMADFullDeriv.png"
    # Plot axis ranges, one (low, high) pair per entry of var_names
    plot_range = [(-20, 20), (-0.01, 0.01), (-60, 60), (-0.01, 0.01), (-150, 150), (-0.006, 0.006), (-1, 2 * np.pi - 1)]

    vals = _load_branches(filename, var_names)

    #///////////////////////////////////////////////////////
    # STAGE 1: GET FULL DERIVATIVES
    #///////////////////////////////////////////////////////

    full_deriv, full_deriv_err = _full_derivatives(
        vals, var_names, plot_range, os.path.join("Images", imagefilename)
    )

    print("\nFull Derivative Values:")
    for var, deriv, err in zip(var_names, full_deriv, full_deriv_err):
        print(f"d{var}/dp = {deriv} +/- {err}")

    #///////////////////////////////////////////////////////
    # STAGE 2: GET PARTIAL DERIVATIVES
    #///////////////////////////////////////////////////////

    part_deriv = _partial_derivatives(vals, var_names)

    print("\nPartial Derivative Values:")
    for var, deriv in zip(var_names[:-1], part_deriv):
        print(f"∂phi/∂{var} = {deriv}")

    #///////////////////////////////////////////////////////
    # STAGE 3: PUT PARTIAL + FULL DERIVATIVES TOGETHER
    #///////////////////////////////////////////////////////

    # Chain rule: dphi/dp = sum_k (∂phi/∂v_k) * (dv_k/dp) over every non-phi
    # variable v_k (the p term has dp/dp = 1). Both stages are ordinary least
    # squares on the same events and p is itself a regressor, so this sum
    # should match the fitted dphi/dp up to floating-point rounding.
    contributions = part_deriv * full_deriv[:-1]
    calc_deriv = np.sum(contributions)

    # Print out the final summary:
    print("\nFinal Comparison:")
    print(f"Total dphi/dp (Fitted) = {full_deriv[-1]}")
    print(f"Total dphi/dp (Calc)   = {calc_deriv}")
    print("\nwhere the calculated value is the sum of:")
    for var, partial, full, term in zip(var_names[:-1], part_deriv, full_deriv[:-1], contributions):
        print(f"∂phi/∂{var} * d{var}/dp = {partial} * {full} = {term}")


def _load_branches(filename, var_names):
    """Read each named branch of tree ``"t"`` in ``filename`` into a float64 NumPy array."""
    with uproot.open(filename) as f:
        tree = f["t"]
        return {var: np.array(tree[var].array(), dtype=float) for var in var_names}


def _fit_line(x, y):
    """Least-squares fit ``y = slope * x + intercept``.

    Returns ``(slope, intercept, slope_err)``, where ``slope_err`` is the usual
    OLS standard error of the slope, sqrt(SSR / (n - 2) / Sxx). It is NaN when
    fewer than three points are given or ``x`` has no spread.
    """
    slope, intercept = np.polyfit(x, y, 1)
    n = x.size
    sxx = np.sum((x - x.mean()) ** 2)
    residuals = y - (slope * x + intercept)
    if n > 2 and sxx > 0:
        slope_err = np.sqrt(np.sum(residuals ** 2) / (n - 2) / sxx)
    else:
        slope_err = np.nan
    return slope, intercept, slope_err


def _full_derivatives(vals, var_names, plot_range, image_path, p_name="p"):
    """Fit every variable linearly against ``p_name``, plot the fits, and return the slopes.

    Returns two arrays aligned with ``var_names``: the fitted slopes d<var>/dp
    and their standard errors. One panel per variable is saved to ``image_path``.
    """
    n_dims = len(var_names)
    p_index = var_names.index(p_name)
    p = vals[p_name]
    p_range = plot_range[p_index]

    full_deriv = np.zeros(n_dims)
    full_deriv_err = np.zeros(n_dims)

    # Full derivatives canvas where we're going to draw all the fits
    fig, axes = plt.subplots(1, n_dims, figsize=(5 * n_dims, 6))

    for j, (var, ax) in enumerate(zip(var_names, axes)):
        y = vals[var]

        # Fit every event, not the binned profile. This keeps the slopes
        # consistent with the Stage 2 regression, which also uses every event.
        # Events outside the plot ranges are therefore still included in the fit.
        slope, intercept, slope_err = _fit_line(p, y)
        full_deriv[j] = slope
        full_deriv_err[j] = slope_err

        # Profile (mean of var in 10 p bins), used for display only; empty bins give NaN
        profile, bin_edges, _ = binned_statistic(p, y, statistic="mean", bins=10, range=p_range)
        bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
        valid = ~np.isnan(profile)

        ax.hist2d(p, y, bins=100, range=[p_range, plot_range[j]], cmap='viridis')
        ax.plot(bin_centers[valid], profile[valid], "o", color="red", label="Profile")
        ax.plot(
            bin_centers[valid],
            slope * bin_centers[valid] + intercept,
            color='green',
            label=f"Fit: {slope:.4f}x + {intercept:.4f}",
        )
        ax.set_title(f"{var} vs p")
        ax.set_xlabel("p")
        ax.set_ylabel(var)
        ax.legend()

    fig.tight_layout()
    fig.savefig(image_path)
    plt.close(fig)

    return full_deriv, full_deriv_err


def _partial_derivatives(vals, var_names):
    """Return the partial derivatives of the last variable with respect to all the others.

    These are the multiple-linear-regression coefficients of ``var_names[-1]``
    on ``var_names[:-1]``. Each one is the change in the response per unit
    change in one regressor while the remaining regressors are held fixed.
    """
    # Population covariance (divide by N), then the correlation matrix R
    data = np.vstack([vals[var] for var in var_names])
    cov_matrix = np.cov(data, bias=True)
    sigmas = np.sqrt(np.diag(cov_matrix))
    corr_matrix = cov_matrix / np.outer(sigmas, sigmas)

    # With R = L L^T and L = [[L11, 0], [l21^T, l22]], the standardized
    # coefficients b solve R11 b = r12 = L11 l21, i.e. b^T = l21^T L11^{-1}.
    L = cholesky(corr_matrix, lower=True)
    standardized = L[-1, :-1] @ inv(L[:-1, :-1])

    # Undo the standardization: beta_k * sigma_response / sigma_k
    return standardized * sigmas[-1] / sigmas[:-1]

# Run the analysis only when executed as a script, not when this module is imported
if __name__ == "__main__":
    get_correlations()
