"""Plot slope field and show analytical solution."""

# slope_field_analytical_soln.py
# With help from Copilot for the script and ChatGPT for setting the fonts
# Import the numerical, plotting, and local font-configuration libraries
import numpy as np
import matplotlib.pyplot as plt

from font_select import set_font_style

# Apply the sans-serif font configuration from the local font_select module
# before any figure is created further down in this script.
set_font_style("sans-serif")


def dv_dt(v):
    """Return the slope dv/dt = 10 - 0.001*v**2 at velocity ``v``.

    The right-hand side depends only on ``v`` (the ODE is autonomous), and
    every operation is elementwise, so ``v`` may be a scalar or a NumPy array.
    """
    drag = 0.001 * v**2  # quadratic drag term
    return 10 - drag


def exact_solution(t):
    """Analytical solution of dv/dt = 10 - 0.001*v**2 with v(0) = 0.

    Separating variables gives v(t) = 100*tanh(0.1*t). Here 100 equals
    sqrt(10/0.001), the terminal velocity, and 0.1 equals sqrt(10*0.001).
    ``t`` may be a scalar or a NumPy array.
    """
    terminal_velocity = 100
    rate = 0.1
    return terminal_velocity * np.tanh(rate * t)


# Parameters
t_start, t_end, dt = 0, 500, 1  # Time domain in seconds, with step size of 1 s
v_min, v_max = 0, 125  # Velocity domain in m/s
t = np.arange(t_start, t_end + dt, dt)  # Sample times, including t_end

# Compute the exact (analytical) solution on the time grid
v_exact = exact_solution(t)

# Generate the slope field on a 20 x 20 grid covering the plotted region
t_grid, v_grid = np.meshgrid(
    np.linspace(t_start, t_end, 20), np.linspace(v_min, v_max, 20)
)
dv = dv_dt(v_grid)  # The slope depends only on v (dv_dt takes no t argument)
dt_field = np.ones_like(dv)  # Each direction vector starts as (1, dv/dt)
# Normalize the direction vectors to unit length so every arrow is drawn the
# same size and only its direction carries information.
magnitude = np.sqrt(dt_field**2 + dv**2)
dt_field /= magnitude
dv /= magnitude

# Create the figure and plot the elements
plt.figure(figsize=(8, 6))

# Plot the slope field as quiver.
# angles="xy" orients each arrow along (dt, dv) in *data* coordinates, so the
# arrows stay tangent to the solution curve even though the t and v axes have
# very different ranges (0-500 s vs 0-125 m/s). The default angles="uv" draws
# the vector in screen space and would misrepresent the slopes.
plt.quiver(
    t_grid,
    v_grid,
    dt_field,
    dv,
    angles="xy",
    color="grey",
    alpha=0.6,
    scale=30,
    label="Slope Field",
)

# Plot the asymptote v = 100: the terminal velocity, where dv/dt = 0
# because 0.001 * 100**2 = 10.
plt.axhline(
    y=100,
    color="red",
    linestyle="dashdot",
    linewidth=2.5,
    label=r"Asymptote $v=100$",
)

# Plot the exact solution
plt.plot(t, v_exact, label="Exact Solution", color="green", linewidth=2)

# Add labels, legend, and title (the title must match the ODE in dv_dt)
plt.xlabel(r"Time $t$ (s)")
plt.ylabel(r"Velocity $v$ (m/s)")
plt.title(
    (
        r"Slope Field, Asymptote, and Exact Solution for "
        r"$\dfrac{\mathrm{d}v}{\mathrm{d}t}= 10 - 0.001v^2$"
    ),
    pad=15,
)
plt.legend()
plt.xlim(t_start, t_end)
plt.ylim(v_min, v_max)
plt.grid()

# Save and show the plot
plt.tight_layout()
plt.savefig("../images/skydiver-slope-field-analytical.svg", transparent=True)
plt.show()
