Number of interpolation points in the Born-Mermin Chapman interpolation

The Chapman interpolation of the Born-Mermin approximation saves computation time by evaluating the structure factor only at a given number of frequencies, rather than at every frequency, and interpolating between them afterwards. This example investigates how many interpolation points are required.
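
To illustrate the idea, the following minimal sketch (plain Python, not jaxrts code) samples a toy Lorentzian on a coarse grid, interpolates linearly onto a dense grid, and reports the maximum error for the same point counts used in this example. The function `toy_spectrum`, the grid limits, and the choice of linear interpolation are assumptions made purely for illustration; the interpolation scheme used inside jaxrts may differ.

```python
import bisect


def toy_spectrum(x):
    """Hypothetical stand-in for an expensive function: a Lorentzian peak."""
    return 1.0 / (1.0 + x * x)


def linspace(start, stop, n):
    """Return n equally spaced points from start to stop (inclusive)."""
    step = (stop - start) / (n - 1)
    return [start + i * step for i in range(n)]


def interp_linear(x, xs, ys):
    """Linearly interpolate the samples (xs, ys) at position x."""
    # Index of the interval containing x, clamped to the valid range.
    i = min(max(bisect.bisect_right(xs, x) - 1, 0), len(xs) - 2)
    t = (x - xs[i]) / (xs[i + 1] - xs[i])
    return ys[i] + t * (ys[i + 1] - ys[i])


def max_interp_error(n_points, n_dense=300):
    """Maximum absolute error of an n_points interpolation on a dense grid."""
    coarse_x = linspace(-5.0, 5.0, n_points)
    coarse_y = [toy_spectrum(x) for x in coarse_x]  # the "expensive" calls
    dense_x = linspace(-5.0, 5.0, n_dense)
    return max(
        abs(interp_linear(x, coarse_x, coarse_y) - toy_spectrum(x))
        for x in dense_x
    )


for n in (2, 4, 20, 100):
    print(f"{n:3d} points: max error = {max_interp_error(n):.2e}")
```

The error shrinks as points are added, while the number of expensive evaluations stays far below the 300 points of the dense grid. With JAX arrays, `jnp.interp` would be the natural replacement for the hand-written `interp_linear`.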

We also compare the number of points required when the imaginary part of the collision frequency is obtained by solving the integral directly with the number required when it is obtained via a Kramers-Kronig relation (KKT = True). With KKT, more grid points are needed to reach a result of comparable quality.
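
Why a Kramers-Kronig transform is sensitive to the grid can be seen in a toy setting. The sketch below is plain Python and is not how jaxrts implements its KKT option. It takes the real part of the analytic test function f(w) = 1 / (w + i*GAMMA), whose imaginary part is known exactly, and recovers the imaginary part from a principal-value integral evaluated with the midpoint rule. The test function, `GAMMA`, the cutoff `w_max`, the quadrature rule, and the choice of the real part as input are all assumptions for illustration.

```python
import math

GAMMA = 1.0  # damping of the toy response (hypothetical value)


def re_f(w):
    """Real part of the toy response f(w) = 1 / (w + 1j * GAMMA)."""
    return w / (w * w + GAMMA * GAMMA)


def im_f_exact(w):
    """Analytic imaginary part of the same function, for comparison."""
    return -GAMMA / (w * w + GAMMA * GAMMA)


def im_f_kk(w, n_intervals, w_max=50.0):
    """Approximate Im f(w) = -(1/pi) P int Re f(w') / (w' - w) dw'.

    Midpoint rule on a window centred on w. With an even number of
    intervals the midpoints lie symmetrically around w' = w and never hit
    it, so the singular contributions cancel in pairs.
    """
    if n_intervals % 2:
        raise ValueError("n_intervals must be even")
    h = 2.0 * w_max / n_intervals
    total = 0.0
    for k in range(n_intervals):
        wp = w - w_max + (k + 0.5) * h
        total += re_f(wp) / (wp - w) * h
    return -total / math.pi


for n in (10, 100, 1000):
    err = abs(im_f_kk(0.0, n) - im_f_exact(0.0))
    print(f"{n:5d} grid points: |error| = {err:.3f}")
```

In this toy case the recovered imaginary part improves as the grid is refined, and a residual error remains from truncating the integral at `w_max`. Both effects add to the interpolation error itself, which is consistent with the observation that KKT needs more points.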

Note

The times printed by this script are measured after the first compilation, which usually takes a considerable amount of time.
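
The warm-up-then-time pattern behind this note can be wrapped in a small helper. The following is a minimal sketch in plain Python; `time_after_warmup` is a hypothetical name and not part of jaxrts or JAX. It assumes the result is either an ordinary Python object or a single JAX array, which exposes `block_until_ready()`; for nested results, `jax.block_until_ready(result)` as used in the script below would be needed instead.

```python
import time


def time_after_warmup(fn, *args):
    """Call fn once to warm up, then time a second call.

    Returns the result of the second call and its duration in seconds.
    """

    def run():
        result = fn(*args)
        # JAX dispatches work asynchronously; wait for the result if the
        # object supports it, so the timing covers the actual computation.
        if hasattr(result, "block_until_ready"):
            result.block_until_ready()
        return result

    run()  # warm-up: for a jitted function this includes compilation
    t0 = time.perf_counter()
    result = run()
    return result, time.perf_counter() - t0


result, seconds = time_after_warmup(sum, range(1000))
print(f"sum = {result}, took {seconds:.6f}s")
```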

plot BM ChapmanInterp
Full BMA: 2.365185499191284s
2 interp points: 0.1532890796661377s       Mean deviation from full RPA:  -6.417660737971308e-20 Max deviation from full RPA:  3.3269295227579567e-18
4 interp points: 0.1525421142578125s       Mean deviation from full RPA:  5.842183106337609e-20 Max deviation from full RPA:  1.1313588367551087e-18
20 interp points: 0.20197486877441406s       Mean deviation from full RPA:  -2.8298423586592834e-20 Max deviation from full RPA:  4.2172178767420766e-20
100 interp points: 0.3237178325653076s       Mean deviation from full RPA:  -3.277086057580894e-20 Max deviation from full RPA:  2.137215271671695e-20
2 interp points: 0.1518878936767578s       Mean deviation from full RPA:  2.200652021343034e-19 Max deviation from full RPA:  2.9810113737122587e-18
4 interp points: 0.15661859512329102s       Mean deviation from full RPA:  1.1844918959575883e-19 Max deviation from full RPA:  1.0306087100343865e-18
20 interp points: 0.20629048347473145s       Mean deviation from full RPA:  2.278627298005906e-20 Max deviation from full RPA:  1.2710680301020966e-18
100 interp points: 0.3339109420776367s       Mean deviation from full RPA:  2.0669797057135996e-20 Max deviation from full RPA:  1.2838363569820362e-18
RPA: 0.05400681495666504s

import os
import time
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import jaxrts

# Allow jax to use 6 CPUs, see
# https://astralord.github.io/posts/exploring-parallel-strategies-with-jax/
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=6"

ureg = jaxrts.units.ureg

plt.style.use("science")

# Probing energies: 300 points between 295 eV and 305 eV
measured_energy = jnp.linspace(295, 305, 300) * ureg.electron_volt

setup = jaxrts.setup.Setup(
    ureg("60°"),
    energy=ureg("300eV"),
    measured_energy=measured_energy,
    instrument=partial(
        jaxrts.instrument_function.instrument_gaussian,
        sigma=(0.01 * ureg.electron_volt) / ureg.hbar,
    ),
)
state = jaxrts.PlasmaState(
    [jaxrts.Element("H")],
    jnp.array([1.0]),
    jnp.array([0.0017]) * ureg.gram / ureg.centimeter**3,
    jnp.array([2]) * ureg.electron_volt / ureg.k_B,
    jnp.array([2]) * ureg.electron_volt / ureg.k_B,
)

state["free-free scattering"] = jaxrts.models.BornMermin_Full()
# This is required for the S_ii in the collision frequency
state["ionic scattering"] = jaxrts.models.ArkhipovIonFeat()

# This is required for the V_eiS in the collision frequency
state["BM V_eiS"] = jaxrts.models.DebyeHueckel_BM_V()
state["BM S_ii"] = jaxrts.models.Sum_Sii()
# Warm-up call, so that the timing below excludes the compilation time
state.evaluate("free-free scattering", setup).m_as(ureg.second)
t0 = time.time()
BM_free_free_scatter = state.evaluate("free-free scattering", setup).m_as(
    ureg.second
)
jax.block_until_ready(BM_free_free_scatter)
print(f"Full BMA: {time.time()-t0}s")
state["free-free scattering"] = jaxrts.models.BornMermin()
state["free-free scattering"].set_guessed_E_cutoffs(state, setup)


for ls, KKT in zip(["solid", "dotted"], [False, True], strict=False):
    for i, no_of_freq in enumerate([2, 4, 20, 100]):
        state["free-free scattering"].no_of_freq = no_of_freq
        state["free-free scattering"].KKT = KKT
        # Warm-up call, so that the timing below excludes the compilation time
        state.evaluate("free-free scattering", setup).m_as(ureg.second)
        t0 = time.time()
        free_free_scatter = state.evaluate("free-free scattering", setup).m_as(
            ureg.second
        )
        jax.block_until_ready(free_free_scatter)
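        # The reference here is the full Born-Mermin result
        # (BM_free_free_scatter), although the labels below say "RPA".
        print(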
            f"{no_of_freq} interp points: {time.time()-t0}s      ",
            "Mean deviation from full RPA: ",
            jnp.mean(free_free_scatter - BM_free_free_scatter),
            "Max deviation from full RPA: ",
            jnp.max(free_free_scatter - BM_free_free_scatter),
        )
        plt.plot(
            setup.measured_energy.m_as(ureg.electron_volt),
            free_free_scatter,
            label=f"{no_of_freq} interpolation points",
            linestyle=ls,
            color=f"C{i}",
            alpha=0.8,
        )

plt.plot(
    setup.measured_energy.m_as(ureg.electron_volt),
    BM_free_free_scatter,
    label="Full BMA",
    color="black",
)

state["free-free scattering"] = jaxrts.models.RPA()
# Warm-up call, so that the timing below excludes the compilation time
state.evaluate("free-free scattering", setup).m_as(ureg.second)
t0 = time.time()
free_free_scatter = state.evaluate("free-free scattering", setup).m_as(
    ureg.second
)
jax.block_until_ready(free_free_scatter)
print(f"RPA: {time.time()-t0}s")
plt.plot(
    setup.measured_energy.m_as(ureg.electron_volt),
    free_free_scatter,
    label="RPA",
    linestyle="dashed",
    color="gray",
)

plt.xlabel("Energy [eV]")
plt.ylabel("Scattering intensity")
plt.legend(loc="upper left", bbox_to_anchor=(1.05, 1.00))
plt.show()

Total running time of the script: (1 minutes 5.446 seconds)
