"""
Saha Ionization Map
===================

This example plots the mean ionization of carbon and hydrogen in a CH plasma,
computed with the Saha equation, as a two-dimensional map over electron
temperature and mass density. The choice of ionization potential depression
(IPD) model (here, :py:class:`jaxrts.models.StewartPyattIPD`) enters the
calculation and can shift the ionization considerably.
"""

import matplotlib.pyplot as plt
import matplotlib as mpl
import jax.numpy as jnp
import jaxrts
from scipy.interpolate import RectBivariateSpline

# Unit registry shared with jaxrts, so quantities built in this script are
# compatible with the PlasmaState below.
ureg = jaxrts.ureg

# Larger default font for a readable figure.
plt.rcParams["font.size"] = 15

# Colormap and colour scales. The carbon panel spans Z = 0..6 and the
# hydrogen panel Z = 0..1 (the atomic numbers of C and H, i.e. full
# ionization).
cmap = "turbo"
norm_C, norm_H = (mpl.colors.Normalize(vmin=0, vmax=z_max) for z_max in (6, 1))

# Initial PlasmaState for a CH plasma with equal numbers of C and H nuclei.
# T_e and the mass density set here are reassigned for every grid point of
# the (T, rho) scan below; rho_init / T_e_init only define the starting state.
rho_init: float = 25.0  # total mass density in g/cm^3
T_e_init: float = 80.0  # electron temperature in eV
Z_C_init: float = 4.0  # initial free charge of carbon (unitless)
Z_H_init: float = 1.0  # initial free charge of hydrogen (unitless)

# The species order (C, H) is the order used by every per-ion array below
# (number fractions, Z_free, mass fractions).
ions = [jaxrts.Element("C"), jaxrts.Element("H")]
number_fraction = jnp.array([0.5, 0.5])
Z_free = jnp.array([Z_C_init, Z_H_init])

# The mass density is passed per species: the total density times each
# species' mass fraction, derived here from the number fractions.
mass_fraction = jaxrts.helpers.mass_from_number_fraction(
    number_fractions=number_fraction, elements=ions
)

plasmastate = jaxrts.PlasmaState(
    ions=ions,
    Z_free=Z_free,
    mass_density=rho_init * mass_fraction * ureg.gram / ureg.cm**3,
    # Temperature is given as an energy in eV; dividing by k_B yields a
    # temperature quantity.
    T_e=T_e_init * ureg.electron_volt / ureg.k_B,
)

# Ionization potential depression (IPD) model, used by the Saha solver when
# ``use_ipd=True``. Swap in the alternative below to see how strongly the IPD
# model shifts the ionization:
# plasmastate["ipd"] = jaxrts.models.DebyeHueckelIPD()
plasmastate["ipd"] = jaxrts.models.StewartPyattIPD()

# Chemical-potential model, used by the Saha solver when ``use_chem_pot=True``.
plasmastate["chemical potential"] = jaxrts.models.IchimaruChemPotential()

# Resolution of the (T, rho) scan: nr_points_per_row logarithmically spaced
# values per axis, T_e from 1 eV to 1000 eV and rho from 0.01 to 1000 g/cm^3.
nr_points_per_row = 50
temperature_range = jnp.logspace(0, 3, nr_points_per_row)
mass_density_range = jnp.logspace(-2, 3, nr_points_per_row)

# Result maps, indexed as [temperature index, density index].
ionization_C = jnp.zeros((temperature_range.size, mass_density_range.size))
ionization_H = jnp.zeros_like(ionization_C)

# Units attached to the bare grid values inside the scan.
T_unit = ureg.electron_volt / ureg.k_B
rho_unit = ureg.gram / ureg.cm**3

# Iterate through all (T, rho) combinations and store the Saha mean charges.
for i_T, T_eV in enumerate(temperature_range):
    for i_rho, rho_gcc in enumerate(mass_density_range):
        plasmastate.T_e = T_eV * T_unit
        plasmastate.mass_density = rho_gcc * mass_fraction * rho_unit
        saha_output = jaxrts.ionization.calculate_mean_free_charge_saha(
            plasma_state=plasmastate,
            use_ipd=True,
            use_chem_pot=True,
            use_distribution=False,
        )
        # Entry [1] is unpacked into one value per ion species, presumably in
        # the order of the plasma's ions (C, H) -- confirm against jaxrts.
        Z_C, Z_H = saha_output[1]
        ionization_C = ionization_C.at[i_T, i_rho].set(Z_C)
        ionization_H = ionization_H.at[i_T, i_rho].set(Z_H)


# Interpolate the coarse Saha grid onto a fine one for a smoother colour plot.
#
# ``ionization_C`` / ``ionization_H`` are indexed [temperature, density], and
# RectBivariateSpline(x, y, z) requires z.shape == (x.size, y.size). So the
# spline's first axis must be temperature and its second density. This also
# keeps working when the two axes are scanned with different resolutions.
#
# The scan grid is logarithmic and the plot uses log-log axes, so we
# interpolate in log10 coordinates. The fine grid spans exactly the scanned
# range instead of repeating its hard-coded bounds.
nr_interp_points = 1000
log_T = jnp.log10(temperature_range)
log_rho = jnp.log10(mass_density_range)
interp_log_T = jnp.linspace(log_T[0], log_T[-1], nr_interp_points)
interp_log_rho = jnp.linspace(log_rho[0], log_rho[-1], nr_interp_points)
interp_temp = 10**interp_log_T
interp_rho = 10**interp_log_rho

interp_C = RectBivariateSpline(log_T, log_rho, ionization_C, kx=1, ky=1)
interp_H = RectBivariateSpline(log_T, log_rho, ionization_H, kx=1, ky=1)

# Evaluating on the (T, rho) grid yields shape (interp_temp.size,
# interp_rho.size) == (ny, nx), which is what pcolormesh(X=rho, Y=T, C) expects.
ioni_C = interp_C(interp_log_T, interp_log_rho)
ioni_H = interp_H(interp_log_T, interp_log_rho)

# Plotting: one panel per species, both sharing the temperature (y) axis.
fig, (ax, ax1) = plt.subplots(1, 2, figsize=(15, 7), sharey=True)
panel_axes = (ax, ax1)

# Draw both ionization maps first (carbon, then hydrogen) ...
pcm_C, pcm_H = (
    panel.pcolormesh(
        interp_rho,
        interp_temp,
        values,
        cmap=cmap,
        norm=scale,
        shading="auto",
        rasterized=True,
    )
    for panel, values, scale in zip(panel_axes, (ioni_C, ioni_H), (norm_C, norm_H))
)

# ... then attach a colour bar to the right of each panel.
cb, cb1 = (
    plt.colorbar(mesh, ax=panel, location="right")
    for mesh, panel in zip((pcm_C, pcm_H), panel_axes)
)
cb.set_label(label=r"Mean ionization Z$_C$")
cb1.set_label(label=r"Mean ionization Z$_H$")

# Axis labels, log scales and titles; the y label only on the left panel,
# since the y axis is shared.
panel_titles = ("CH plasma: Ionization Carbon", "CH plasma: Ionization Hydrogen")
for panel, title in zip(panel_axes, panel_titles):
    panel.set_xlabel("Mass Density [g/cc]")
    panel.set_yscale("log")
    panel.set_xscale("log")
    panel.set_title(title, weight="bold")
ax.set_ylabel("Electron Temperature [eV]")

fig.tight_layout(pad=0.3)
# fig.savefig("saha_ionization_plot.png", dpi=300)
plt.show()
