#!/usr/bin/env python
import copy
from math import *
import numpy as np
from numpy.random import default_rng
import matplotlib.pyplot as plt
import matplotlib.animation as ani
from scipy.optimize import curve_fit
import time



def print_progress(step, total):
    """
    Prints a one-line text progress bar.

    While step < total the line ends with a carriage return so that the
    next call overwrites it; once step >= total a newline is printed.

    Args:
        step (int): progress counter
        total (int): counter at completion; a non-positive total is
            treated as already complete instead of dividing by zero
    """
    total_bar_length = 60

    if total > 0:
        fraction = step / total
    else:
        fraction = 1.0

    percentage = int(fraction * 100)
    # clamp so that out-of-range steps still give a bar of fixed length
    bar_fill = min(max(int(fraction * total_bar_length), 0), total_bar_length)

    message = ("simulation progress ["
               + "|" * bar_fill
               + " " * (total_bar_length - bar_fill)
               + "] " + str(percentage) + " %")

    if step < total:
        print(message, end="\r")
    else:
        print(message)
        

def show_cluster(grid, cut_corners = False):
    """
    Plot the grid as a black-and-white image, save it to cluster.png and show it.

    Args:
        grid (array): grid where the cluster is stored as values of 1
        cut_corners (bool): If True, crop the picture to a square around the
            grid midpoint that just encloses the cluster (plus a 2-pixel
            border) instead of drawing the entire grid.
    """
    if not cut_corners:
        image = grid
    else:
        mid = int(grid.shape[0]/2)
        # half-width of the crop: cluster radius plus a border of 2,
        # but never reaching closer than 2 pixels to the grid edge
        half = min(int(calculate_cluster_radius(grid, [mid, mid])) + 2, mid - 2)
        low, high = mid - half, mid + half
        image = grid[low:high, low:high]

    plt.clf()
    axes = plt.axes()
    axes.set_aspect('equal')
    plt.pcolormesh(image, cmap='Greys', vmin=0, vmax=1)
    plt.savefig('cluster.png', dpi=1000)
    plt.show()
    

    
def draw(frame, history):
    """
    Render a single stored snapshot; used as the FuncAnimation callback.

    Args:
        frame (int): index of the snapshot in history to render
        history (list): list of systems at different times
    """
    plt.clf()
    plt.axes().set_aspect('equal')
    plt.pcolormesh(history[frame], cmap='Greys', vmin=0, vmax=1)


def animate(history):
    """
    Animate the simulation in an interactive matplotlib window.

    Each entry of history becomes one frame, drawn with :meth:`draw`.
    If history is empty, only a message is printed and no window is opened.

    Args:
        history (list): list of systems at different times
    """
    nframes = len(history)
    if nframes == 0:
        print("nothing to animate")
        return

    print("animating "+str(nframes)+" frames")
    fig = plt.figure()
    # The animation object must stay referenced while the window is open:
    # matplotlib stops an animation once its object is garbage collected.
    motion = ani.FuncAnimation(fig, draw, nframes, fargs=(history,))
    plt.show(block=True)
    plt.close(fig)
    del motion


def count_particles_in_square(center, boxsize, grid):
    """
    Count the number of points in the grid
    with the value of 1 for which
    the x and y indices differ from the coordinates
    [x0 ,y0] given in 'center' by at most boxsize.

    That is, count how many of the grid points in
    [x0-boxsize, x0+boxsize] x [y0-boxsize, y0+boxsize]
    (bounds inclusive) are 1. Parts of the box outside the grid
    are ignored.

    Args:
        center (array): cluster center coordinates [x0, y0]
        boxsize (int): maximum distance from center
        grid (array): the grid to be analyzed

    Returns:
        int: the number of gridpoints = 1 in the box
    """
    x0 = int(center[0])
    y0 = int(center[1])

    # Clip to the grid: negative slice bounds would otherwise wrap around
    # to the far side of the array. The +1 makes the upper bounds inclusive.
    xlo = max(x0 - boxsize, 0)
    xhi = max(x0 + boxsize + 1, 0)
    ylo = max(y0 - boxsize, 0)
    yhi = max(y0 + boxsize + 1, 0)

    return int(np.count_nonzero(grid[xlo:xhi, ylo:yhi]))
  
  
def count_particles_in_circle(center, R_max, grid):
    """
    Count the number of points in the grid
    with the value of 1 for which the distance
    from center is at most R_max.

    That is, count how many of the grid points
    in a circle with radius R_max (boundary included)
    belong to the cluster.

    Args:
        center (array): cluster center coordinates [x0, y0]
        R_max (float): maximum distance from center
        grid (array): the grid to be analyzed

    Returns:
        int: the number of gridpoints = 1 in the circle
    """
    # indices of all occupied sites
    rows, cols = np.nonzero(grid)
    dx = rows - center[0]
    dy = cols - center[1]

    # compare squared distances to avoid square roots
    return int(np.count_nonzero(dx*dx + dy*dy <= R_max*R_max))


def random_step(x,y):
    """
    Take a random step of length 1
    from the given coordinates (x,y)
    and return the new coordinates.

    The four directions +x, -x, +y, -y are equally likely.

    Args:
        x (int): the x coordinate
        y (int): the y coordinate

    Returns:
        list: the new coordinates, [x,y]
    """
    # random() is uniform on [0, 1); multiplying by 4 is exact in binary
    # floating point, so truncation yields 0..3 with the same thresholds
    # 0.25 / 0.5 / 0.75 as comparing the raw value.
    direction = int(random.random() * 4)
    dx, dy = ((1, 0), (-1, 0), (0, 1), (0, -1))[direction]
    return [x + dx, y + dy]



def neighbor_is_in_cluster(x,y, grid):
    """
    Check if any of the neighbors of the given point
    are already a part of the cluster.

    The neighbors of the point (x,y) are the points
    (x-1, y), (x+1, y), (x, y-1) and (x, y+1).
    Neighbors that fall outside the grid are ignored
    (they do not wrap around to the opposite edge).

    A point is part of the cluster if that point
    has a nonzero value (e.g. 1) in the given array 'grid'.

    Args:
        x (int): the x coordinate
        y (int): the y coordinate
        grid (array): the grid to be checked

    Returns:
        bool: True if any neighbor is in the cluster
    """
    n_rows, n_cols = grid.shape
    for nx, ny in ((x-1, y), (x+1, y), (x, y-1), (x, y+1)):
        # explicit bounds check: negative indices would silently wrap
        if 0 <= nx < n_rows and 0 <= ny < n_cols and grid[nx, ny]:
            return True
    return False




def pick_random_point(center, distance):
    """
    Choose a random point at the given distance from
    the center point.

    Distance should be a real number
    and center should be an array of two real numbers,
    [x, y]. The result is rounded to the nearest integer coordinates,
    so it lies within sqrt(0.5) of the exact circle.

    Args:
        center (list): center point coordinates [x0, y0]
        distance (float): distance from center

    Returns:
        array: integer coordinates as an array [x, y]
    """
    angle = random.random()*2.0*pi
    point = center + distance * np.array([cos(angle), sin(angle)])
    # round instead of truncating: int() rounds toward zero, which shifts
    # every release point by up to one site toward the origin
    return np.rint(point).astype(int)


    
def mind_the_gap(x, y, center, cluster_radius, margin=10):
    """
    Pull a particle that has wandered too far back toward the cluster.

    A walker far away from the cluster could spend a very long time
    wandering aimlessly. To save simulation time, a particle whose distance
    from center exceeds cluster_radius + margin is moved 0.5*margin closer
    to the center along the line joining it to the center. A particle
    within that distance is left where it is.

    Args:
        x (int): the x coordinate
        y (int): the y coordinate
        center (array): cluster center coordinates [x0, y0]
        cluster_radius (float): the maximum radius of the cluster
        margin (float): the allowed distance in addition to cluster radius

    Returns:
        int, int: new x, new y (unchanged if the particle was close enough;
        otherwise shifted by whole sites, each shift truncated toward zero)
    """
    offset_x = x-center[0]
    offset_y = y-center[1]
    dist_sq = offset_x*offset_x + offset_y*offset_y

    limit_sq = (cluster_radius+margin)*(cluster_radius+margin)
    if dist_sq <= limit_sq:
        # close enough: leave the particle alone
        return x, y

    # Shrinking the vector r0 = (offset_x, offset_y) to a length
    # |r1| = |r0| - 0.5*margin in the same direction moves each coordinate
    # by offset * (|r1| - |r0|) / |r0| = offset * (-0.5*margin / |r0|).
    scaler = -0.5*margin/sqrt(dist_sq)
    return x + int(scaler*offset_x), y + int(scaler*offset_y)
    

def simulate_single_particle(grid, center, cluster_radius, start):
    """
    Run a diffusion simulation for a single particle on a grid.

    * Place the particle at the given position.
    * Let the particle perform a random walk on the grid
      via :meth:`random_step`.
    * If at any time the particle wanders too far away from the cluster,
      move it closer to save simulation time (:meth:`mind_the_gap`).
    * Let the random walk continue until the walker is inside the grid
      and next to a grid point with the value 1
      (:meth:`neighbor_is_in_cluster`).
    * Change the value of the grid point at the last position of
      the walker to 1. Physically, grid represents a cluster
      and this action represents a particle joining the cluster.
    * Check if the new particle increases the maximum radius of the cluster.
      The new radius is returned in the end.

    The grid is modified in place.

    Args:
        grid (array): the grid representing the cluster
        center (list): the center coordinates of the cluster, [x0, y0]
        cluster_radius (float): maximum distance between the center and all
            other points in the cluster
        start (list): the starting coordinates of the particle, [x, y]

    Returns:
        float: the new radius (never smaller than cluster_radius)
    """
    n_rows, n_cols = grid.shape

    x = int(start[0])
    y = int(start[1])

    # run a random walk until the particle meets the cluster
    while True:
        x, y = random_step(x, y)

        # if the particle is too far from the cluster,
        # move it a bit closer
        x, y = mind_the_gap(x, y, center, cluster_radius)

        # Only stick while inside the grid: writing grid[x, y] with a
        # negative index would silently mark a site on the opposite edge.
        inside = 0 <= x < n_rows and 0 <= y < n_cols
        if inside and neighbor_is_in_cluster(x, y, grid):
            break

    # the particle joins the cluster
    grid[x, y] = 1

    # the cluster radius grows if the new particle is the farthest one
    dx = x - center[0]
    dy = y - center[1]
    return float(max(cluster_radius, sqrt(dx*dx + dy*dy)))
    

def create_grid(n):
    """
    Build an n x n integer grid containing a single seed particle.

    All entries are 0 except the middle point, which is set to 1 and
    serves as the initial seed for the cluster.

    Args:
        n (int): size of the grid

    Returns:
        array, array: grid, center coordinates [x0, y0]
    """
    mid = int(n/2)
    center = np.array([mid, mid])

    # one independent list per row, converted to a numpy integer array
    grid = np.array([[0] * n for _ in range(n)])
    grid[mid, mid] = 1

    return grid, center



def diffusion_simulation(grid, center, particles, recording_interval = 100):
    """
    Run a full diffusion simulation on the given grid.

    The grid must be an array of zeros and ones.
    It represents the space in which the simulation takes place as a
    discrete lattice model. That is, each particle to be simulated
    can only have integer coordinates.
    Each zero in the grid represents empty space.
    Each one in the grid represents a position filled by the cluster.

    The simulation is carried out as follows:

    * Particles are released, one-by-one, at a random point a short distance
      outside the cluster (:meth:`pick_random_point`).
    * A particle diffuses until it ends up at a position next to the existing cluster
      via :meth:`simulate_single_particle`.
    * When this happens, the particle attaches to the cluster.
      This is represented by changing the value of the grid to 1 at that position.
    * The simulation ends after the given number of particles have
      joined the cluster, or earlier if the cluster grows so close to the
      edge of the grid that the walkers would no longer fit inside it.

    The grid is modified in place.

    Args:
        grid (array): a square grid of integer values
        center (array): center coordinates [x0, y0]
        particles (int): number of particles to simulate
        recording_interval (int): record the state of the system after this many particles
            (values below 1 are treated as 1)

    Returns:
        list: list of grids showing a time series of the evolution of the system
    """
    start_time = time.perf_counter()
    history = []
    recording_interval = max(int(recording_interval), 1)

    # New particles are released this far outside the current cluster radius.
    release_distance = 5
    # Room a walker may need beyond the cluster radius: mind_the_gap() lets it
    # drift up to its default margin of 10, and the neighbor check looks one
    # more site further out.
    walker_room = 12

    cluster_radius = calculate_cluster_radius(grid, center)
    # distance from the center to the closest edge of the grid
    free_space = min(center[0], center[1],
                     grid.shape[0] - 1 - center[0],
                     grid.shape[1] - 1 - center[1])

    # send particles one by one
    for i in range(particles):

        if cluster_radius + walker_room >= free_space:
            # leading newline: the progress bar line ends in a carriage return
            print("\ncluster reached the edge of the grid after "+str(i)+" particles")
            break

        start = pick_random_point(center, cluster_radius + release_distance)
        cluster_radius = simulate_single_particle(grid, center, cluster_radius, start)

        print_progress(i+1, particles)

        if i%recording_interval == 0:
            history.append(copy.deepcopy(grid))

    end_time = time.perf_counter()
    print("simulation took "+str(end_time-start_time)+" s")

    return history
        



def read_grid_from_file(filename):
    """
    Read a cluster from a file.

    The first line of the file should contain the
    coordinates of the initial seed of the cluster.
    This should be followed by the coordinates of all
    other points in the cluster, line by line:

    .. code-block::

        x0 y0
        x1 y1
        x2 y2
        ...

    Blank lines are skipped. The grid is sized 2*x0 x 2*y0 (indexed as
    grid[x, y]) so that the seed lies at its middle.

    The function returns a grid describing the cluster and the
    coordinates of the initial seed.

    Args:
        filename (str): file to be read

    Returns:
        array, array: grid (boolean), cluster center coordinates [x0, y0]

    Raises:
        OSError: if the file cannot be opened
        ValueError: if a coordinate is not an integer
        IndexError: if the file is empty, a line has fewer than two values,
            or a point lies outside the grid
    """
    with open(filename) as f:
        lines = f.readlines()

    parts = lines[0].split()
    center = np.array([int(parts[0]), int(parts[1])])

    # the first index is x, so the grid needs 2*x0 rows and 2*y0 columns
    grid = np.zeros((2*center[0], 2*center[1]), dtype=bool)

    # the seed line is included here as well: it is part of the cluster
    for line in lines:
        parts = line.split()
        if parts:
            grid[int(parts[0]), int(parts[1])] = True

    return grid, center


def linearfit(x, a, b):
    """Straight-line model :math:`ax+b`, e.g. as the model function for ``curve_fit``.

    Works on scalars and, element-wise, on numpy arrays.

    Args:
        x (float): the variable
        a (float): the slope
        b (float): the constant term

    Returns:
        float: the value of the line at x
    """
    slope_term = a * x
    return slope_term + b


def calculate_fractal_dimension(grid, center, printout=True, fit_range=(0.1, 0.6)):
    """
    Calculate the fractal dimension of the cluster and optionally
    display the plot from which the dimension can be determined.

    In :math:`d` dimensions, the size :math:`s` of an object increases as
    :math:`s = cr^d` as the linear size :math:`r` of the object grows,
    where :math:`c` is some constant.
    For instance, the area of a 2D circle is :math:`A = \\pi r^2`
    and the volume of a 3D ball is :math:`V = \\frac{4}{3} \\pi r^3`.

    Taking the logarithm, we get :math:`\\ln(A) = \\ln(\\pi) + 2 \\ln(r)`
    and :math:`\\ln(V) = \\ln(\\frac{4}{3}\\pi) + 3 \\ln(r)` and in general
    :math:`\\ln(s) = \\ln(c) + d \\ln(r)`. This means we can find the
    dimensionality of an object by looking at how its size grows as its
    linear length grows on a logarithmic scale.

    Although ideal mathematical shapes tend to have integer dimensionality,
    real objects are usually fractals at least on some length scales. This
    means their dimensionality can be non-integer.

    The fractal dimension of the cluster is calculated as follows:

    * Consider circles of radius R = 1,2,3,... around the cluster center.
    * Calculate the number of cluster particles, N, (grid points with value 1)
      in each circle (using :meth:`count_particles_in_circle`).
    * Optionally: plot the N(R) data on a log-log scale.
    * Estimate the region where the data falls on a line: by default only
      radii between 10 % and 60 % of the largest radius are used (see fit_range).
    * Fit a linear function to the data, in this region.
      The slope of the fitted line is the fractal dimension.

    Args:
        grid (array): an array of points where the value of 1 denotes
              a point that belongs in the cluster
        center (array): the coordinates of the initial seed of the cluster
        printout (bool): if True, results are shown
        fit_range (tuple): (low, high) fractions of the largest radius R_max;
              only radii with low*R_max <= R <= high*R_max enter the fit

    Returns:
        float: the calculated estimate for fractal dimension

    Raises:
        ValueError: if fewer than 3 radii fall inside the fit range
    """
    # maximum value for R
    cluster_radius = calculate_cluster_radius(grid, center)
    if cluster_radius > grid.shape[0]/2-1:
        cluster_radius = grid.shape[0]/2-1

    # N(R) for R = 1, 2, ..., R_max
    radii = np.arange(1, int(cluster_radius) + 1)
    counts = np.array([count_particles_in_circle(center, R, grid) for R in radii])

    # log(0) is undefined; empty circles only occur if the center itself is empty
    has_particles = counts > 0
    radii = radii[has_particles]
    counts = counts[has_particles]

    logsize = np.log(radii)
    logcount = np.log(counts)

    # Only the linear part of the log-log plot should be fitted: very small
    # circles are dominated by the lattice discretization, and the outer part
    # of a growing cluster is still sparse, so both ends are excluded.
    low, high = fit_range
    in_fit = (radii >= low*cluster_radius) & (radii <= high*cluster_radius)
    n_fit = int(np.count_nonzero(in_fit))
    if n_fit < 3:
        raise ValueError("cluster is too small to estimate the fractal dimension: "
                         "only "+str(n_fit)+" radii in the fit range")

    popt, pcov = curve_fit(linearfit,
                           logsize[in_fit],
                           logcount[in_fit],
                           p0 = [1.5, 2.0]) # initial guess for parameters

    if printout:
        print("estimated value for fractal dimension: "+str(popt[0]) )
        plt.plot(logsize, logcount, 'o')
        plt.plot(logsize, linearfit(logsize, popt[0], popt[1]))
        plt.xlabel("ln R")
        plt.ylabel("ln N")
        plt.show()

    return popt[0]





def write_cluster_file(grid, center, filename='cluster.txt'):
    """
    Write the cluster in a datafile which can be
    later read in using :meth:`read_grid_from_file`.

    The first line holds the center coordinates "x0 y0"; every following
    line holds the indices "i j" of one occupied site, in row-major order.

    Args:
        grid (array): the grid where values of 1 represent the cluster
        center (array): cluster center coordinates [x0, y0]
        filename (str): output file, 'cluster.txt' by default
    """
    lines = [str(center[0])+" "+str(center[1])]
    # argwhere covers the whole grid, square or not, in row-major order
    lines += [str(i)+" "+str(j) for i, j in np.argwhere(grid)]

    with open(filename, 'w') as f:
        f.write("\n".join(lines) + "\n")
    
    
def calculate_cluster_radius(grid, center):
    """
    Calculates the distance from the center coordinates for every
    occupied site in the grid and returns the maximum distance found.

    That is, the function calculates the radius of the smallest
    circle that is centered at the given center coordinates
    and which overlaps the entire cluster.

    Args:
        grid (array): the grid where values of 1 represent the cluster
        center (array): cluster center coordinates [x0, y0]

    Returns:
        float: the calculated cluster radius (0.0 for an empty grid)
    """
    rows, cols = np.nonzero(grid)
    if rows.size == 0:
        return 0.0

    dx = rows - center[0]
    dy = cols - center[1]
    return float(np.sqrt(np.max(dx*dx + dy*dy)))
    
    
def calculate_particles_in_cluster(grid):
    """
    Count the occupied (truthy) sites of a 2D grid.

    Args:
        grid (array): the grid where values of 1 represent the cluster

    Returns:
        int: number of particles
    """
    n_rows, n_cols = grid.shape
    return sum(1 for row in range(n_rows)
                 for col in range(n_cols)
                 if grid[row, col])
    
    
def main(filename = None, size = 300, n_particles = 0):
    """
    The main program.

    Runs a simulation, shows the resulting cluster and
    calculates its fractal dimension.
    The end result is written in a file using :meth:`write_cluster_file`.

    Can either start from scratch (a cluster of 1 particle)
    or read in the result from a previous simulation and
    use that as the starting point.

    If a file is read successfully, the size parameter does nothing.
    If no file is given or the file cannot be read or parsed, a new
    size x size simulation grid is created. Other errors propagate.

    Args:
        filename (str): name of the file to read for the starting configuration
        size (int): the size of the simulation space, if starting from scratch
        n_particles (int): the number of particles to simulate
    """
    grid, center = None, None

    if filename is not None:
        try:
            # read a file containing cluster data
            grid, center = read_grid_from_file(filename)
            print("Read starting configuration from "+str(filename))
        except (OSError, ValueError, IndexError) as error:
            # missing/unreadable file, bad numbers, or malformed lines
            print("Could not read "+str(filename)+" ("+str(error)+")")

    if grid is None:
        # if no valid file was found, create a new grid
        grid, center = create_grid(size)
        print("Created a "+str(size)+" x "+str(size)+" simulation area.")

    # run simulation
    history = diffusion_simulation(grid, center, n_particles, recording_interval=max(n_particles//100,1) )

    # save and see the results
    write_cluster_file(grid, center)
    show_cluster(grid)
    animate(history)

    # you need a lot of particles to estimate the fractal dimension
    if n_particles >= 1000:
        calculate_fractal_dimension(grid, center)

    
    
#
# Module-wide random number generator, used by random_step() and
# pick_random_point(). It is created on import as well, so those
# functions also work when this file is used as a module.
#
random = default_rng()

# Run the main program only if this file is executed,
# not when it is imported as a module.
if __name__ == "__main__":
    # edit the main function call to adjust simulation behaviour
    main(n_particles = 100)
