import sys
import copy
from math import *
import numpy as np
from numpy.random import default_rng
import matplotlib.pyplot as plt
import matplotlib.animation as ani
    
def print_progress(step, total):
    """
    Prints a progress bar.

    While ``step < total`` the bar is printed with a carriage return so the
    next call overwrites it in place; once ``step >= total`` it is printed
    with a trailing newline.

    Args:
        step (int): progress counter
        total (int): counter at completion. A non-positive total is
            treated as an already-completed job (full bar, 100 %).
    """
    total_bar_length = 60
    # Avoid ZeroDivisionError for an empty job: report it as complete.
    fraction = step / total if total > 0 else 1.0
    percentage = int(fraction * 100)
    # Clamp so steps outside [0, total] still render a bar of fixed width.
    bar_fill = min(max(int(fraction * total_bar_length), 0), total_bar_length)

    bar = "|" * bar_fill + " " * (total_bar_length - bar_fill)
    message = "simulation progress [" + bar + "] " + str(percentage) + " %"
    if step < total:
        print(message, end="\r")
    else:
        print(message)
        
        
def expand_discretization(wave,
            left_boundary_free = False,
            right_boundary_free = False):
    """
    Adds zeros to the ends of the wave vector, if necessary.

    If the system has fixed boundaries, the wave function is always zero
    there, so those end points are unnecessary in the simulation proper.
    For such boundaries, the wave is simulated without the zero end points,
    and this function reinserts the zeros after the simulation.

    Args:
        wave (array): discretized wave function
        left_boundary_free (bool): If True, free medium boundary condition is applied at the left bound
        right_boundary_free (bool): If True, free medium boundary condition is applied at the right bound

    Returns:
        array: ``wave`` with a 0 prepended for a fixed left boundary and a 0
        appended for a fixed right boundary. If both boundaries are free,
        ``wave`` itself is returned unchanged.
    """
    expanded = wave
    if not left_boundary_free:
        expanded = np.concatenate(([0], expanded))
    if not right_boundary_free:
        expanded = np.concatenate((expanded, [0]))
    return expanded


def draw(frame, wave, x):
    """
    Draws the wave.

    Used as the per-frame callback of the animation: the current figure is
    cleared, and the snapshot ``wave[frame]`` is plotted against ``x`` on
    fresh axes with the vertical range fixed to [-2, 2].

    Args:
        frame (int): index of the frame to be drawn
        wave (list): list of waves as a time series
        x (array): x coordinates, same length as ``wave[frame]``
    """
    plt.clf()
    ax = plt.axes()
    # Fixed y-range so the plot does not rescale from frame to frame.
    ax.set_ylim([-2, 2])
    ax.set_xlabel("position, x")
    ax.set_ylabel("wavefunction, u")
    ax.plot(x, wave[frame])
    
    
def animate(wave, L,
            left_boundary_free = False,
            right_boundary_free = False):
    """
    Animates the simulation.

    Each snapshot is first padded with the fixed-boundary zeros via
    :meth:`expand_discretization`, then the frames are played back with
    :meth:`draw` over x coordinates spread evenly on [0, L].
    The caller's ``wave`` list is not modified.

    Args:
        wave (list): list of waves as a time series
        L (float): length of simulated region :math:`L`
        left_boundary_free (bool): If True, free medium boundary condition is applied at the left bound
        right_boundary_free (bool): If True, free medium boundary condition is applied at the right bound
    """
    # Build a new list rather than overwriting wave[i]: the original version
    # silently replaced the caller's simulation history with padded copies.
    frames = [expand_discretization(w, left_boundary_free, right_boundary_free)
              for w in wave]
    nframes = len(frames)

    n_nodes = len(frames[0])
    x = np.linspace(0, L, n_nodes)

    fig = plt.figure()
    # The reference must stay alive until plt.show() returns; an
    # unreferenced FuncAnimation can be garbage-collected and stop.
    motion = ani.FuncAnimation(fig, draw, nframes, interval=10, fargs=(frames, x))
    plt.show()
    
    
def propagator(n_nodes, v, dt, dx,
            left_boundary_free = False,
            right_boundary_free = False):
    """
    Calculates the matrix describing the time evolution of the system.

    The discretized wave equation can be written in matrix form as

    .. math ::

        u(t + \\Delta t) = M u(t) - u(t - \\Delta t).

    Here :math:`M` is a sparse matrix which has the diagonal elements

    .. math ::

        M_{i,i} = 2 - 2 v^2 \\left( \\frac{\\Delta t}{\\Delta x} \\right)^2

    and off-diagonals

    .. math ::

        M_{i,i\\pm 1} = v^2 \\left( \\frac{\\Delta t}{\\Delta x} \\right)^2.

    If the medium is fixed to :math:`u = 0` at the boundaries, this is
    true for all off-diagonals. If the medium is free to move at the boundaries,
    the first or last row off-diagonals must be adjusted to

    .. math ::

        M_{0,1} = M_{N-1,N-2} = 2 v^2 \\left( \\frac{\\Delta t}{\\Delta x} \\right)^2.

    Args:
        n_nodes (int): number of grid points :math:`N` (i.e., the number of elements in the vector :math:`u`)
        v (float): wave speed
        dt (float): time step, :math:`\\Delta t`
        dx (float): distance between grid points, :math:`\\Delta x`
        left_boundary_free (bool): If True, free medium boundary condition is applied at the left bound
        right_boundary_free (bool): If True, free medium boundary condition is applied at the right bound

    Returns:
        array: the :math:`N \\times N` matrix :math:`M`
    """
    # c = v^2 dt^2 / dx^2 (the squared Courant number)
    c = v*v*dt*dt/(dx*dx)

    # Tridiagonal matrix: 2 - 2c on the diagonal, c on both neighbours.
    # On its own this corresponds to fixed (u = 0) boundaries.
    m = ((2 - 2*c) * np.eye(n_nodes)
         + c * np.eye(n_nodes, k=1)
         + c * np.eye(n_nodes, k=-1))

    # Free boundaries: the end node's missing outer neighbour is mirrored
    # onto its inner neighbour (u_{-1} = u_1), doubling that coupling to 2c.
    # A single node has no neighbour to couple to, so nothing changes then.
    if n_nodes > 1:
        if left_boundary_free:
            m[0, 1] = 2*c
        if right_boundary_free:
            m[-1, -2] = 2*c

    return m
    
    
def add_gaussian(string, a = 1, center = 0.5):
    """
    Initialize the wave to a gaussian initial shape.

    The function adjusts the shape of the wave by adding
    a narrow gaussian function to the wave function in place.
    Since this operation is additive,
    it can be called several times to add several pulses.

    The pulse is ``a * exp(-(5 / N) * (i - i0)**2)`` with ``N = len(string)``
    and peak index ``i0 = N * center - 0.5``, so its width grows like
    ``sqrt(N)`` nodes.

    Args:
        string (array): discretized wave function to be modified
        a (float): pulse height
        center (float): fractional position of the peak, should be between 0 and 1
    """
    n_nodes = len(string)
    if n_nodes == 0:
        # Nothing to modify; also avoids dividing by zero below.
        return
    peak_index = n_nodes*center - 0.5
    width_factor = 5.0/float(n_nodes)
    for i in range(n_nodes):
        string[i] += a * exp(-width_factor * (i - peak_index)**2)


def add_sinusoidal(string, n, a = 1):
    """
    Initialize the wave to a sinusoidal initial shape.

    The function adjusts the shape of the wave by adding
    ``a * sin(n * pi * (i + 1) / (N + 1))`` (``N = len(string)``) to each
    element in place. Node ``i`` is treated as lying at fractional position
    ``(i + 1) / (N + 1)``, so the sine goes to zero at the (omitted) fixed
    end points of the system.
    Since this operation is additive,
    it can be called several times to build the wave as a
    superposition of sine functions.

    Args:
        string (array): discretized wave function to be modified
        n (int): number of antinodes
        a (float): amplitude
    """
    n_nodes = len(string)
    for i in range(n_nodes):
        # The amplitude was previously ignored; scale the sine by a.
        string[i] += a * sin(n*pi*(i+1)/(n_nodes+1))
    
def add_triangle(string, a = 1, peak = 0.5):
    """
    Initialize the wave to a triangular initial shape.

    The function adjusts the shape of the wave by adding
    a triangular function to the wave function in place.
    The triangle rises linearly from 0 (at index -1) to ``a`` at index
    ``top = peak * N - 0.5`` and falls linearly back to 0 at index ``N``
    (``N = len(string)``), i.e. it goes to zero at the edges of the system.
    Since this operation is additive,
    it can be called several times to add several triangles.

    Args:
        string (array): discretized wave function to be modified
        a (float): amplitude
        peak (float): fractional position of the peak, should be between 0 and 1
    """
    n_nodes = len(string)
    top = peak*n_nodes-0.5
    for i in range(n_nodes):
        # "+=" (not "="): the docstring promises an additive operation;
        # plain assignment discarded any previously added shape.
        if i < top:
            string[i] += a*(i+1)/(top+1)
        else:
            string[i] += a*(i-n_nodes)/(top-n_nodes)
            
def first_step(initial_wave, matrix, initial_velocity = 0, dt = 0):
    """
    Performs the first time step.

    Since the wave equation is a second degree partial differential equation
    with respect to time, one needs to know the wave function at two
    time steps in order to simulate the wave dynamics.

    Often one knows instead the initial shape of the wave, :math:`u(0)`, and the
    initial velocity of the medium, :math:`\\partial_t u(0)`.
    Especially if the medium is initially at rest, :math:`\\partial_t u(0) = 0`.

    This function calculates the shape of the wave after the first time step as

    .. math ::

        u(\\Delta t) = \\frac{1}{2} M u(0) + \\partial_t u(0) \\Delta t.

    Args:
        initial_wave (array): discretized wave function in the beginning, :math:`u(0)`
        matrix (array): matrix :math:`M` - should be pre-calculated using :meth:`propagator`
        initial_velocity (array): velocity of the wave medium in the beginning, :math:`\\partial_t u(0)`;
            the default 0 means the medium starts at rest
        dt (float): time step :math:`\\Delta t`

    Returns:
        array: discretized wave function after the first time step, :math:`u(\\Delta t)`
    """
    return 0.5 * matrix @ initial_wave + dt*initial_velocity
  
def run_simulation(initial_wave, next_wave, matrix, dt, time, recording_dt = 0):
    """
    Run a dynamic wave simulation.

    The discretized wave equation can be written in matrix form as

    .. math ::

        u(t + \\Delta t) = M u(t) - u(t - \\Delta t).

    This function iterates this step ``int(time / dt)`` times, printing a
    progress bar as it goes.

    The function returns a list containing the wave function at different times.
    A snapshot is stored after step ``i`` (0-based) whenever
    ``i % recording_interval == 0``, where ``recording_interval`` is
    ``int(recording_dt / dt)``; the first stored snapshot is therefore
    :math:`u(2 \\Delta t)`. If ``recording_dt < dt`` (including the
    default 0), only that first snapshot is stored.

    Args:
        initial_wave (array): discretized wave function at the start, :math:`u(0)`
        next_wave (array): discretized wave function after first step, :math:`u(\\Delta t)`
        matrix (array): matrix :math:`M` - should be pre-calculated using :meth:`propagator`
        dt (float): time step :math:`\\Delta t`
        time (float): total simulation time
        recording_dt (float): time interval between saved wavefunctions
    Returns:
        list: time evolution of the wave function
    """

    n_steps = int(time/dt)
    if recording_dt < dt:
        # NOTE(review): this records only the first step; if the intent was
        # "record every step" the interval should be 1 instead — confirm.
        recording_interval = n_steps
    else:
        recording_interval = int(recording_dt/dt)

    history = []

    current_wave = next_wave
    past_wave = initial_wave

    for i in range(n_steps):
        # apply dynamics: u(j+1) = M u(j) - u(j-1).
        # The expression allocates a new array each step, so snapshots
        # already stored in history are never overwritten later.
        future_wave = matrix @ current_wave - past_wave
        # move u(j) to u(j-1) and u(j+1) to u(j)
        past_wave, current_wave = current_wave, future_wave
        print_progress(i+1, n_steps)

        if i % recording_interval == 0:
            history.append(current_wave)

    return history
  
def main():
    """
    The main program. Simulates a 1D wave.

    A string of length 2.0 with fixed ends and wave speed 1.0 starts from a
    Gaussian pulse at rest. It is simulated for 8.0 time units with
    dt = 0.02, recorded every 0.05 time units, and then animated.
    """
    # boundary conditions
    left_bound_free = False
    right_bound_free = False

    # number of simulated grid points in x (fixed end points, which are
    # always zero, are not part of this count)
    n_nodes = 99

    # length of the string
    length = 2.0
    # if we have fixed bounds, the first and last
    # grid points are not explicitly included, but let's count them
    dn = 0
    if not left_bound_free:
        dn += 1
    if not right_bound_free:
        dn += 1
    # distance between grid points (dx) and the time step (dt).
    # With these values v*dt/dx = 1, the stability limit of this explicit scheme.
    dx = length/(n_nodes+dn-1)
    dt = 0.02

    # timing
    simulation_time = 8.0
    sample_dt = 0.05

    # wave speed
    v = 1.0

    # Calculate the propagator matrix for time dynamics
    matrix = propagator(n_nodes, v, dt, dx,
            left_bound_free,
            right_bound_free)
    # initial conditions
    wave_initial = np.zeros(n_nodes)
    v_initial = np.zeros(n_nodes)

    # choose the initial shape of your wave
    initial_position = 0.2
    add_gaussian(wave_initial, a = 1, center = initial_position)
    #add_triangle(wave_initial, a = 1, peak = initial_position)

    # calculate the shape of the wave one timestep after start
    # since wave equation is 2nd degree, this cannot be obtained
    # from wave equation

    # option 1: the medium is initially at rest.
    # Pass by keyword: the former positional call had dt and v_initial
    # swapped relative to first_step's (initial_velocity, dt) signature.
    wave_dt = first_step(wave_initial, matrix, initial_velocity=v_initial, dt=dt)
    # option 2: the wave moves initially to the right
    # note - this only works if the initial shape was gaussian
    #wave_dt = np.zeros(n_nodes)
    #add_gaussian(wave_dt, a = 1, center = initial_position + v*dt/length)
    # record evolution of the wave
    wavefunction = [wave_initial]
    #
    # Run simulation
    #
    wavefunction += run_simulation(wave_initial, wave_dt, matrix, dt, simulation_time, sample_dt)

    # Show the result
    animate(wavefunction, length, left_bound_free, right_bound_free)


if __name__ == "__main__":
    main()