#!/usr/bin/env python
import numpy as np
import scipy.optimize as opt
import matplotlib.pyplot as plt
import matplotlib.animation as ani
from math import *
import copy
from numpy.random import default_rng


def print_progress(step, total):
    """
    Print a 60-character text progress bar with a percentage.

    While step < total the line ends with a carriage return, so the next
    call overwrites it in place; once step >= total the line is terminated
    with a newline.

    Args:
        step (int): progress counter
        total (int): counter at completion
    """
    bar_width = 60
    ratio = step / total
    percentage = int(ratio * 100)

    # number of filled cells, clamped to the width of the bar
    filled = min(max(int(ratio * bar_width), 0), bar_width)
    bar = "|" * filled + " " * (bar_width - filled)

    line_end = "\r" if step < total else "\n"
    print("simulation progress [" + bar + "] " + str(percentage) + " %", end=line_end)
        
class Pixel:
    """
    One pixel of a black-and-white bitmap: its position and its color.

    Args:
        x (float): x coordinate
        y (float): y coordinate
        color (int): 0 for white and 1 for black pixels
    """

    def __init__(self, x, y, color):
        self.x, self.y, self.color = x, y, color


class Circle:
    """
    A circle.

    Used for building an approximation of a bitmap image.
    Random changes use the module-level ``random`` generator.

    Args:
        x (float): center x coordinate
        y (float): center y coordinate
        r (float): radius
    """

    def __init__(self, x, y, r):
        self.x = x
        self.y = y
        self.r = r

    def randomize(self):
        """
        Randomize the center and radius of the circle.

        The center is drawn from [0, 1) x [0, 1) and the radius from [0.01, 0.1).
        """
        self.x = random.random()
        self.y = random.random()
        self.r = 0.09*random.random()+0.01

    def is_above_line(self, x0, y0, k):
        """
        Checks if the center of the circle is above the line
        passing through point (x0,y0) with slope k, (y-y0) = k(x-x0).

        Args:
            x0 (float): x coordinate
            y0 (float): y coordinate
            k (float): slope

        Returns:
            bool: True, if the center is strictly above the given line.
        """
        return self.y > k*(self.x - x0) + y0

    def is_inside_box(self, x0, x1, y0, y1):
        """
        Checks if the center of the circle is strictly inside the rectangle
        :math:`(x_0, x_1) \\times (y_0, y_1)`.

        Note the argument order: both x bounds come first, then both y bounds.

        Args:
            x0 (float): x coordinate of the left edge
            x1 (float): x coordinate of the right edge
            y0 (float): y coordinate of the bottom edge
            y1 (float): y coordinate of the top edge

        Returns:
            bool: True, if the center is inside the box
        """
        return x0 < self.x < x1 and y0 < self.y < y1

    def alter(self):
        """
        Make a small random change in the center and radius of the circle.

        Each center coordinate moves by up to +-0.05 and is clamped to [0, 1];
        the radius moves by up to +-0.04 and is clamped to [0.01, 0.15].
        """
        self.x = min(1, max(0, self.x + 0.1*random.random()-0.05))
        self.y = min(1, max(0, self.y + 0.1*random.random()-0.05))
        self.r = min(0.15, max(0.01, self.r + 0.08*random.random()-0.04))

    def covers(self, x, y):
        """
        Checks if the circle covers the point (x,y).

        This is true, if the distance between the given point
        and the center of the circle is smaller than the radius of the circle.

        Args:
            x (float): x coordinate
            y (float): y coordinate

        Returns:
            bool: True, if circle covers the point.
        """
        # compare squared distances to avoid a square root
        return (x-self.x)**2 + (y-self.y)**2 < self.r**2



class Solution:
    """
    A solution to the covering problem.

    The solution is a set of circles that are supposed to cover
    a given set of points and not cover another set of points.

    The solution also has a fitness value which determines how
    good a solution it is. The higher the fitness the better the solution.

    Typically the solution is created as the offspring of two previous
    solutions. If such parent solutions are not given, a completely
    random solution is generated.

    If a reference is given, the fitness of the created solution
    is immediately calculated with :meth:`Solution.calculate_fitness` and saved.
    If the reference is None, the fitness is set to 0.

    All random draws use the module-level ``random`` generator
    (``random()``, ``integers()``).

    Args:
        reference (list): an array of :class:`Pixel` objects defining the points to cover and not cover
        mom (Solution): a parent
        dad (Solution): another parent; required whenever mom is given
    """

    def __init__(self, reference, mom=None, dad=None):
        self.circles = []
        if mom is None:
            self.create_random_solution()
        else:
            self.crossover(mom, dad)
        if reference is not None:
            self.fitness = self.calculate_fitness(reference)
        else:
            self.fitness = 0

    def create_random_solution(self, n_circles=20):
        """
        Create the solution as a random set of :class:`Circle` objects.

        This is done by repeatedly calling :meth:`Solution.add_circle`.

        Args:
            n_circles (int): number of random circles to create
        """
        self.circles = []
        for _ in range(n_circles):
            self.add_circle()

    def crossover(self, mom, dad):
        """
        Randomly pick a crossover operation.

        The options are :meth:`Solution.crossover_line`
        and :meth:`Solution.crossover_box`, each with 50 % probability.

        Args:
            mom (Solution): a parent
            dad (Solution): a second parent
        """
        if random.random() < 0.5:
            self.crossover_box(mom, dad)
        else:
            self.crossover_line(mom, dad)

    def crossover_line(self, mom, dad):
        """
        Create the solution as a crossover of two parents.

        The crossover is created as follows:

            * A random pivot point is drawn in :math:`[0.2,0.8] \\times [0.2,0.8]`.
            * A random slope k is drawn as the tangent of an angle uniform in [-0.4 pi, 0.4 pi).
            * Define a line that passes through the pivot at the given slope.
            * Take all the :class:`Circle` objects of the first parent whose centers are above the line.
            * Take all the :class:`Circle` objects of the second parent whose centers are not above the line.
            * Create the new solution as a combination of copies of these :class:`Circle` objects.

        Args:
            mom (Solution): a parent
            dad (Solution): a second parent
        """
        # Keep the pivot away from the edges of the unit square so that the
        # line really splits the image between the parents; a pivot near a
        # corner lets one parent contribute almost every circle.
        x = 0.2 + 0.6*random.random()
        y = 0.2 + 0.6*random.random()
        k = tan(random.random()*0.8*pi - 0.4*pi)

        self.circles = [Circle(m.x, m.y, m.r) for m in mom.circles
                        if m.is_above_line(x, y, k)]
        self.circles += [Circle(d.x, d.y, d.r) for d in dad.circles
                         if not d.is_above_line(x, y, k)]

    def crossover_box(self, mom, dad):
        """
        Create the solution as a crossover of two parents.

        The crossover is created as follows:

            * A random rectangle :math:`[x_0,x_1] \\times [y_0,y_1]` is drawn inside
              the unit square, at least 0.1 wide and 0.1 tall.
            * Take all the :class:`Circle` objects of the first parent whose centers are in the box.
            * Take all the :class:`Circle` objects of the second parent whose centers are outside the box.
            * Create the new solution as a combination of copies of these :class:`Circle` objects.

        Args:
            mom (Solution): a parent
            dad (Solution): a second parent
        """
        x0 = 0.9*random.random()
        y0 = 0.9*random.random()
        x1 = 0.1 + x0 + (0.9 - x0)*random.random()
        y1 = 0.1 + y0 + (0.9 - y0)*random.random()

        # Circle.is_inside_box takes its bounds as (x0, x1, y0, y1):
        # both x bounds first, then both y bounds.
        self.circles = [Circle(m.x, m.y, m.r) for m in mom.circles
                        if m.is_inside_box(x0, x1, y0, y1)]
        self.circles += [Circle(d.x, d.y, d.r) for d in dad.circles
                         if not d.is_inside_box(x0, x1, y0, y1)]

    def add_circle(self, x=None, y=None, r=None):
        """
        Adds a :class:`Circle` to the solution.

        A center coordinate left as None is drawn uniformly from [0, 1), so
        calling this without arguments adds a completely random circle.
        A radius of None or 0 is replaced by a random radius in [0.01, 0.1).

        Args:
            x (float): center x coordinate, or None for a random value
            y (float): center y coordinate, or None for a random value
            r (float): radius, or None / 0 for a random value
        """
        if x is None:
            x = random.random()
        if y is None:
            y = random.random()
        if r is None or r == 0:
            # 0 is still accepted as "random radius" for backward compatibility
            r = 0.01 + 0.09*random.random()

        self.circles.append(Circle(x, y, r))

    def remove_circle(self, threshold=3):
        """
        Randomly removes one :class:`Circle` from the solution.

        Every circle, including the last one, is equally likely to be removed.

        Args:
            threshold (int): Only remove the circle if there are more :class:`Circle` s than this.
        """
        if len(self.circles) > threshold:
            # Generator.integers excludes the upper bound, so this draws
            # an index in [0, len - 1].
            index = random.integers(len(self.circles))
            del self.circles[index]

    def mutate(self, reference):
        """
        Randomly changes the solution.

        Each :class:`Circle` has a 5 % chance of being changed by
        :meth:`Circle.alter`.
        In addition, there is a 20 % chance a :class:`Circle` is
        removed with :meth:`Solution.remove_circle` and a 20 % chance a
        :class:`Circle` is added with :meth:`Solution.add_circle`.
        When a circle is added, its center point will be the coordinates of a
        randomly chosen black pixel (color 1) of the reference; if the reference
        has no black pixels, a completely random circle is added instead.

        After the mutation, the fitness of the new solution is calculated
        with :meth:`Solution.calculate_fitness`.

        Args:
            reference (list): an array of :class:`Pixel` objects defining the points to cover and not cover
        """
        for c in self.circles:
            if random.random() < 0.05:
                c.alter()

        if random.random() < 0.2:
            self.remove_circle()

        if random.random() < 0.2:
            black_pixels = [p for p in reference if p.color == 1]
            if black_pixels:
                pixel = black_pixels[random.integers(len(black_pixels))]
                self.add_circle(pixel.x, pixel.y)
            else:
                self.add_circle()

        self.fitness = self.calculate_fitness(reference)

    def black(self, x, y):
        """
        Checks if the solution covers the point (x,y).

        Returns 1 if at least one :class:`Circle` covers the point,
        and 0 if none of the circles cover the point.

        The name of the method refers to the color of the pixel being covered.
        The solution is supposed to cover all black pixels but not white ones.

        Args:
            x (float): x coordinate
            y (float): y coordinate

        Returns:
            int: 1, if the solution covers the point, otherwise 0.
        """
        return 1 if any(c.covers(x, y) for c in self.circles) else 0

    def calculate_fitness(self, reference):
        """
        Calculates the overall fitness of the solution.

        The solution is good if its circles cover all the black pixels
        but none of the white ones. It's also better to have as few circles
        as possible.
        Therefore the function checks each :class:`Pixel` in the reference image
        and adds 1 to the fitness for each correctly covered or uncovered pixel.
        A penalty of 0.1 is given for each circle.

        Args:
            reference (list): an array of :class:`Pixel` objects defining the points to cover and not cover

        Returns:
            float: the fitness
        """
        correct = sum(1 for pixel in reference
                      if self.black(pixel.x, pixel.y) == pixel.color)
        return correct - 0.1*len(self.circles)

    def draw(self, fig, scale, alpha=1.0):
        """
        Draws the solution as a set of red circles.

        Args:
            fig (matplotlib figure): the figure to draw on
            scale (float): scaling factor to change the size of the image
            alpha (float): the alpha (opacity) of the circles to be drawn
        """
        axes = fig.gca()
        for c in self.circles:
            axes.add_artist(plt.Circle((scale*c.x, scale*c.y), scale*c.r, color='r', alpha=alpha))
        
    

def draw(frame, history, image):
    """
    Draws one frame of the population animation.

    Used as the frame callback for animation. The image is drawn as a
    greyscale bitmap, and every solution in ``history[frame]`` is overlaid
    with faint circles.

    Args:
        frame (int): index of the frame to draw
        history (list): list of populations at different times
        image (array): number array representing the image to cover
    """
    population = history[frame]
    plt.clf()
    plt.imshow(image, cmap='Greys', interpolation='nearest')
    # read_image maps the top-left min(height, width) square of the image
    # onto unit coordinates, so scale by that same side length to make the
    # circles line up with the image pixels (matters for tall images).
    size = min(image.shape[0], image.shape[1])
    fig = plt.gcf()
    for solu in population:
        solu.draw(fig, size, alpha=0.1)
        
        
def animate(history, image):
    """
    Play back the recorded populations as a matplotlib animation.

    One frame is rendered per entry of ``history`` using :func:`draw`;
    the call blocks until the window is closed.

    Args:
        history (list): list of populations at different times
        image (array): number array representing the image to cover
    """
    frame_count = len(history)
    print(f"animating {frame_count} frames")
    figure = plt.figure()
    # The animation object must stay referenced while the window is shown,
    # otherwise matplotlib may discard it before it runs.
    _animation = ani.FuncAnimation(figure, draw, frame_count, fargs=(history, image))
    plt.show(block=True)
    plt.close(figure)
    
    
def draw_population_to_file(population, reference, image, filename):
    """
    Takes an entire set of :class:`Solution` objects, draws them
    superimposed on top of each other over the image, and saves the
    figure to a file. The current figure is cleared before and after.

    Args:
        population (list): a set of :class:`Solution` objects
        reference (list): an array of :class:`Pixel` objects (not used by this function)
        image (array): number array representing the image to cover
        filename (str): name of the file to write
    """
    plt.clf()
    plt.imshow(image, cmap='Greys', interpolation='nearest')
    # read_image maps the top-left min(height, width) square of the image
    # onto unit coordinates, so scale by that same side length.
    size = min(image.shape[0], image.shape[1])
    fig = plt.gcf()
    for solu in population:
        solu.draw(fig, size, alpha=0.1)
    plt.savefig(filename, bbox_inches='tight')
    plt.clf()


def draw_solution(image, sol):
    """
    Draws the image as a bitmap and the solution as a collection of
    semi-transparent circles, then shows the figure.

    Args:
        image (array): number array representing the image to cover
        sol (Solution): the solution to visualize
    """
    plt.imshow(image, cmap='Greys', interpolation='nearest')
    fig = plt.gcf()
    # read_image maps the top-left min(height, width) square of the image
    # onto unit coordinates, so scale by that same side length.
    size = min(image.shape[0], image.shape[1])
    sol.draw(fig, size, alpha=0.5)
    plt.show()


def write_solution_to_file(solution, filename):
    """
    Writes the center coordinates and radii of the circles
    of the solution to a file, one circle per line as ``x, y, r``.

    Args:
        solution (Solution): the solution to record
        filename (str): name of the file to write (overwritten if it exists)
    """
    # the context manager closes the file even if a write fails
    with open(filename, 'w') as f:
        for c in solution.circles:
            f.write(", ".join(str(value) for value in (c.x, c.y, c.r)) + "\n")
    
    
def read_solution_from_file(reference, filename):
    """
    Reads a :class:`Solution` from a file.

    The file should contain the center coordinates and radii
    of the circles that make up the solution, separated by commas,
    one circle per line, as recorded by :meth:`write_solution_to_file`.
    Blank lines are ignored.

    Args:
        reference (list): an array of :class:`Pixel` objects defining the points to cover and not cover
        filename (str): name of the file to read

    Returns:
        Solution: the solution constructed from the information in the file

    Raises:
        OSError: if the file cannot be opened
        ValueError: if a value cannot be parsed as a float
        IndexError: if a line has fewer than three fields
    """
    with open(filename) as f:
        lines = f.readlines()

    # Passing None skips the fitness evaluation of the throw-away random
    # circles created by the constructor; the fitness is computed once below.
    sol = Solution(None)
    sol.circles = []

    for line in lines:
        if not line.strip():
            continue  # tolerate blank lines, e.g. an extra trailing newline
        parts = line.split(",")
        x = float(parts[0])
        y = float(parts[1])
        r = float(parts[2])
        # Build the Circle directly: Solution.add_circle may treat a value
        # of exactly 0 as "pick a random value" instead of a literal 0.
        sol.circles.append(Circle(x, y, r))

    sol.fitness = sol.calculate_fitness(reference)

    return sol
    


def read_image(filename=None):
    """
    Tries to read a black and white bitmap image and
    create an array of :class:`Pixel` objects based on it.

    If no filename is given, or reading the file fails, the function uses
    a default array representing a star. A message is printed when a given
    file could not be read.

    Pixel values are rounded to 0/1 and inverted so that 1 = black and
    0 = white. For images with color channels only the first channel is used.
    Only the top-left square of side min(height, width) is turned into
    :class:`Pixel` objects, with coordinates scaled to [0, 1).

    Args:
        filename (str): name of the file to read, or None for the default image

    Returns:
        array, array: the image as (number array, :class:`Pixel` object array)
    """

    # Read the image to cover. 1 = Black, 0 = White.
    image = None
    if filename is not None:
        try:
            image = plt.imread(filename)  # read the image
        except Exception as err:  # best effort: fall back to the default image
            print("could not read " + str(filename) + " (" + str(err) + "), using the default image")
        else:
            image = np.round(image + 0.1, 0)  # flatten the image to black and white if not already
            if image.ndim == 3:
                image = image[:, :, 0]  # take the first channel of a color image
            image = 1 - image  # invert colors

    if image is None:  # default image
        image = np.array(
        [[ 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ],
         [ 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ],
         [ 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0 ],
         [ 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0 ],
         [ 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0 ],
         [ 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1 ],
         [ 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0 ],
         [ 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0 ],
         [ 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0 ],
         [ 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0 ],
         [ 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0 ],
         [ 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0 ],
         [ 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0 ],
         [ 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0 ],
         [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
          ]
        )

    # Turn the number array to a Pixel array.
    # It will be truncated to square array even if the image is not a square.
    size = min(image.shape[0], image.shape[1])
    delta = 1.0 / size
    reference = [Pixel(j * delta, i * delta, image[i, j])
                 for i in range(size) for j in range(size)]

    return image, reference
    
    
def fill_population(population, pop_size, reference):
    """
    Top up the population in place with random solutions.

    New :class:`Solution` objects are appended until the list holds pop_size
    entries; a population that is already that large is left untouched.

    Args:
        population (list): a set of :class:`Solution` objects
        pop_size (int): the number of solutions in the population in the end
        reference (list): an array of :class:`Pixel` objects defining the points to cover and not cover
    """
    missing = pop_size - len(population)
    population.extend(Solution(reference) for _ in range(missing))


def sort_by_fitness(population):
    """
    Order the population in place from the fittest to the least fit.

    Afterwards population[0] holds the best solution. Solutions with equal
    fitness keep their previous relative order.

    Args:
        population (list): a set of :class:`Solution` objects
    """
    # reverse=True keeps the sort stable, just like an ascending sort on -fitness
    population.sort(key=lambda solution: solution.fitness, reverse=True)


def kill_the_unfit(population, kill_percentage=0.5):
    """
    Kills a given portion of the population.

    The population is first sorted via :meth:`sort_by_fitness` so that
    the best solutions are in the beginning of the list.
    The number of survivors is ``int((1 - kill_percentage) * len(population))``,
    but at least one (the best solution always survives).

    If the fitness of the best solution and the best solution scheduled
    to die differ by more than 1 %, the list is simply truncated from the end.

    If the population is very homogeneous, individuals are instead killed
    randomly (each with probability kill_percentage, never the best one, and
    never more than the quota), so that the same solutions don't always survive.

    By default, half of the solutions are killed. This fraction can be
    adjusted using kill_percentage. This number must be
    strictly larger than 0 and smaller than 1. If it leaves nothing to kill,
    the population is only sorted.

    Args:
        population (list): a set of :class:`Solution` objects
        kill_percentage (float): The fraction of solutions to kill.
    """
    sort_by_fitness(population)
    pop_size = len(population)
    survivors = max(1, int((1 - kill_percentage) * pop_size))
    if survivors >= pop_size:
        return  # nothing to kill

    # Relative fitness spread between the best solution and the best one
    # scheduled to die; use the absolute spread if the best fitness is 0.
    best = population[0].fitness
    spread = best - population[survivors].fitness
    fit_delta = spread / abs(best) if best != 0 else spread

    if fit_delta > 0.01:
        # there is variance, kill the least fit
        del population[survivors:]
    else:
        # there isn't much variance, kill randomly to give mutations a chance to survive
        kill_quota = pop_size - survivors

        # try to kill each solution individually, starting from the least fit;
        # index 0 (the best solution) is never considered
        for i in range(pop_size - 1, 0, -1):
            if random.random() < kill_percentage:
                del population[i]
                kill_quota -= 1
                if kill_quota == 0:
                    break

    
def best_solution(population, already_sorted=False):
    """
    Return the fittest solution in the population.

    Unless already_sorted is True, the population is first sorted in place
    with :meth:`sort_by_fitness`; either way the first element is returned.

    Args:
        population (list): a set of :class:`Solution` objects
        already_sorted (bool): if True, the best solution is assumed to be the first in population

    Returns:
        Solution: the best solution
    """
    if already_sorted:
        return population[0]

    sort_by_fitness(population)
    return population[0]



def breed(population, pop_size, reference):
    """
    Fill the population with new solutions by breeding
    the members of the previous generation.

    Breeding is done by randomly choosing two different
    solutions from the previous generation and combining them with
    :meth:`Solution.crossover`. This creates offspring that
    may potentially combine the good traits of their parents.

    In addition, every offspring goes through :meth:`Solution.mutate`
    to add variance; this also computes its fitness.

    Only the solutions present when the function is called act as parents,
    so new children do not breed within the same generation. The children
    are appended to the end of the list. A single parent is crossed with
    itself; an empty population is filled with random solutions instead.

    Args:
        population (list): a set of :class:`Solution` objects, extended in place
        pop_size (int): size of the population in the end
        reference (list): an array of :class:`Pixel` objects defining the points to cover and not cover
    """
    n_parents = len(population)
    if n_parents == 0:
        # nobody to breed: start over from random solutions
        fill_population(population, pop_size, reference)
        return

    # we want to add pop_size - len(population) new children in the population
    for _ in range(n_parents, pop_size):
        if n_parents == 1:
            mom = dad = population[0]
        else:
            # two distinct parents drawn from the previous generation only
            mom_index, dad_index = random.choice(n_parents, size=2, replace=False)
            mom = population[mom_index]
            dad = population[dad_index]

        # reference=None skips the fitness evaluation in the constructor;
        # mutate() computes the fitness of the final child once.
        child = Solution(None, mom, dad)
        child.mutate(reference)
        population.append(child)
                        
  
def mutate_all(population, reference):
    """
    Apply :meth:`Solution.mutate` to every solution in the population.

    When the population is very homogeneous, crossover barely changes it and
    isolated mutations rarely survive selection, so mutating everyone at once
    is a way to restore diversity.

    Args:
        population (list): a set of :class:`Solution` objects, mutated in place
        reference (list): an array of :class:`Pixel` objects defining the points to cover and not cover
    """
    for individual in population:
        individual.mutate(reference)
  
    
def run_evolution(population, reference, n_generations, kill_percentage=0.5, n_snapshots=50):
    """
    Runs the evolutionary optimization algorithm.

    The algorithm tries to find an optimal :class:`Solution` by
    taking a population of solutions, killing the least fit ones
    with :meth:`kill_the_unfit` and allowing the survivors to :meth:`breed`,
    creating new solutions.

    A deep copy of the population is recorded every
    ``max(1, n_generations // n_snapshots)`` generations (starting with the
    first), giving roughly n_snapshots frames for :meth:`animate`.

    Args:
        population (list): a set of :class:`Solution` objects
        reference (list): an array of :class:`Pixel` objects defining the points to cover and not cover
        n_generations (int): number of generations (iterations)
        kill_percentage (float): the fraction of solutions to kill each generation
        n_snapshots (int): approximate number of population snapshots to record

    Returns:
        list, list: fitness of the best solution after each generation, population history
    """
    pop_size = len(population)
    fitness_evolution = []
    history = []
    stride = max(1, n_generations // max(1, n_snapshots))

    for iteration in range(n_generations):

        kill_the_unfit(population, kill_percentage)
        breed(population, pop_size, reference)

        # New children may have been added after the survivors that
        # kill_the_unfit sorted, so sort again before picking the best.
        best_sol = best_solution(population)
        fitness_evolution.append(best_sol.fitness)
        print_progress(iteration + 1, n_generations)

        if iteration % stride == 0:
            history.append(copy.deepcopy(population))

    return fitness_evolution, history
    
    
    
def main(n_runs, n_gens, pop_size, image_file, read_from_file=False):
    """
    The main program.

    Reads the image and runs an evolutionary optimization
    to create a set of circles that approximates the image.

    After the optimization, the best solution is written to best_solution.csv,
    the final population is saved to final_generation.png, and the best
    solution, an animation of the evolution and the convergence curve
    are shown.

    Args:
        n_runs (int): number of optimization runs
        n_gens (int): number of generations in a run
        pop_size (int): population size
        image_file (str): name of the image file (None for the default image)
        read_from_file (bool): if True, an existing solution is read from best_solution.csv
    """

    # read the image to approximate
    image, reference = read_image(image_file)

    population = []
    convergence = []
    history = []
    cont_file = "best_solution.csv"

    # create the initial population, optionally seeded with a saved solution
    if read_from_file:
        try:
            population.append(read_solution_from_file(reference, cont_file))
        except (OSError, ValueError, IndexError):
            # missing, unreadable or malformed file: start from scratch
            print("could not read " + cont_file)

    fill_population(population, pop_size, reference)
    best_sol = copy.deepcopy(best_solution(population))

    for run in range(n_runs):
        print("starting optimization " + str(run + 1) + " / " + str(n_runs))

        # evolve the population
        conv, hist = run_evolution(population, reference, n_gens)
        convergence += conv
        history += hist

        # check if we have the best solution
        current_best = best_solution(population)
        if current_best.fitness > best_sol.fitness:
            best_sol = copy.deepcopy(current_best)

        # mutate all solutions to add population variance
        mutate_all(population, reference)

    print("final best fitness = ", best_sol.fitness)

    # visualize and save the solution
    write_solution_to_file(best_sol, cont_file)
    draw_population_to_file(population, reference, image, "final_generation.png")
    draw_solution(image, best_sol)
    animate(history, image)

    # visualize evolution of the best solution
    plt.plot(convergence)
    plt.xlabel("generation")
    plt.ylabel("best solution fitness")
    plt.show()
    


# Module-level random number generator shared by every function in this file
# (they call random.random(), random.integers() and random.choice()).
# It is created at import time so the module also works when imported.
# Note that the name shadows the standard-library module `random`, which
# this file does not import.
random = default_rng()

if __name__ == "__main__":

    n_population = 10
    n_runs = 1
    n_gens = 100

    # optimize for a small star
    main(n_runs, n_gens, n_population, None)

    # optimize for a large star
    #main(n_runs, n_gens, n_population, "starman.png")

    # optimize for Mona Lisa
    #main(n_runs, n_gens, n_population, "lisa.png")
    




