import numpy as np
import os
import copy

# A game of Ultimate Tic-Tac-Toe
# The board is a 3x3 grid of 3x3 tic-tac-toe boards.
# Playing in cell (r,c) of a local board sends your opponent
# to local board (r,c). 
# If sent to a board that is already won or full, you may play in any open board.
# Win 3 local boards in a row to win.

# Number of players; players are numbered 1..n_players
n_players = 2

# Symbols for players. PSYM is indexed by the value stored in a grid cell,
# so index 0 (an empty cell) is drawn as a blank.
P1, P2 = "X", "O"
PSYM = [" ", P1, P2]

# Size of the meta-board and local boards
META = 3   # the meta-board is a 3x3 arrangement of local boards
LOC = 3    # each local board is a 3x3 grid of cells


def new_game():
    """
    Builds the starting position for a game of Ultimate Tic-Tac-Toe.

    The state is a two-element list [grid, next_board]:
      grid        -- a 9x9 integer numpy array of player tokens (0 = empty)
      next_board  -- the (row, col) of the local board the mover must use,
                     or None when any local board may be chosen

    Player 1 moves first and is unrestricted on that move.

    Returns: 
        list: the initial game state [grid, next_board]
    """
    side = META * LOC
    empty_grid = np.zeros([side, side], dtype=int)
    unrestricted = None   # first move may go anywhere
    return [empty_grid, unrestricted]


def next_player(player):
    """
    Returns the player who moves after the given player.

    Args:
        player (int): the player to check
    
    Returns: 
        int: the number of the next player
    """
    # Players are numbered 1..n_players; the last one wraps back to 1.
    following = 1 + (player % n_players)
    return following


def previous_player(player):
    """
    The player whose turn came before the given player.

    Players are numbered 1..n_players and take turns in increasing order,
    so the player before 1 is n_players.

    Args:
        player (int): the player to check
    
    Returns: 
        int: the number of the previous player
    """
    # Convert to 0-based (player - 1), step back one (- 1), wrap around,
    # then convert back to 1-based. With exactly two players this happens
    # to equal next_player, but not for any other player count.
    return (player - 2) % n_players + 1


def _local_winner(grid, br, bc):
    """
    Checks if a local board has been won.

    A local board is won when every cell of one of its rows, columns or
    diagonals holds the same non-zero player token. Lines are checked in
    the order: row i, column i (for each i), main diagonal, anti-diagonal.

    Args:
        grid (array): the full 9x9 play area
        br (int): row index of the local board (0-2)
        bc (int): column index of the local board (0-2)

    Returns: 
        int: the winning player number, or 0 if no winner
    """
    # Extract the LOC x LOC local board
    r0 = br * LOC
    c0 = bc * LOC
    local = grid[r0:r0 + LOC, c0:c0 + LOC]

    # Collect every line that can win. The diagonals are taken with np.diag
    # so they follow LOC instead of hard-coded 3x3 indices.
    lines = []
    for i in range(LOC):
        lines.append(local[i, :])
        lines.append(local[:, i])
    lines.append(np.diag(local))
    lines.append(np.diag(np.fliplr(local)))

    for line in lines:
        if line[0] != 0 and np.all(line == line[0]):
            return int(line[0])   # plain int, as documented

    return 0


def _meta_board(grid):
    """
    Summarises the play area as a 3x3 array of local-board results.

    Args:
        grid (array): the full 9x9 play area

    Returns:
        array: for each local board, its winning player number (0 if none)
    """
    winners = np.zeros((META, META), dtype=int)
    # np.ndindex visits (0,0), (0,1), ... in row-major order
    for board_row, board_col in np.ndindex(META, META):
        winners[board_row, board_col] = _local_winner(grid, board_row, board_col)
    return winners


def _meta_winner(meta):
    """
    Checks if a player has won the meta-board (3 local boards in a row).

    Args:
        meta (array): 3x3 array of local board winners

    Returns: 
        int: winning player number, or 0 if no winner
    """
    for i in range(META):
        if meta[i, 0] != 0 and np.all(meta[i, :] == meta[i, 0]):
            return meta[i, 0]
        if meta[0, i] != 0 and np.all(meta[:, i] == meta[0, i]):
            return meta[0, i]

    if meta[0, 0] != 0 and meta[0,0] == meta[1,1] == meta[2,2]:
        return meta[0, 0]
    if meta[0, 2] != 0 and meta[0,2] == meta[1,1] == meta[2,0]:
        return meta[0, 2]

    return 0


def _local_board_playable(grid, br, bc):
    """
    Reports whether moves can still be made in a local board.

    A board is closed once somebody has won it, or once every cell is filled.

    Args:
        grid (array): the full 9x9 play area
        br (int): row index of the local board (0-2)
        bc (int): column index of the local board (0-2)

    Returns:
        boolean: True if the board is playable, False otherwise
    """
    # A won board is closed even if it still has empty cells.
    if _local_winner(grid, br, bc):
        return False
    top, left = br * LOC, bc * LOC
    cells = grid[top:top + LOC, left:left + LOC]
    return np.any(cells == 0)


def _valid_moves_in_board(grid, br, bc):
    """
    Collects the empty cells of one local board.

    Cells are listed row by row, left to right.

    Args:
        grid (array): the full 9x9 play area
        br (int): row index of the local board (0-2)
        bc (int): column index of the local board (0-2)

    Returns: 
        list: (row, col) positions in the full 9x9 grid
    """
    row0, col0 = br * LOC, bc * LOC
    return [(row0 + dr, col0 + dc)
            for dr in range(LOC)
            for dc in range(LOC)
            if grid[row0 + dr, col0 + dc] == 0]


def all_plays(state):
    """
    Enumerates the legal moves for the player about to move.

    When next_board names a local board that is still open, only its empty
    cells are legal. Otherwise every empty cell of every open local board
    is legal, listed board by board in row-major order.

    Args:
        state (list): [grid, next_board]

    Returns: 
        list: valid (row, col) positions in the full 9x9 grid
    """
    grid, next_board = state

    # Forced board that is still open: moves are confined to it.
    if next_board is not None:
        forced_r, forced_c = next_board
        if _local_board_playable(grid, forced_r, forced_c):
            return _valid_moves_in_board(grid, forced_r, forced_c)

    # Free choice across all open boards.
    return [cell
            for mr in range(META)
            for mc in range(META)
            if _local_board_playable(grid, mr, mc)
            for cell in _valid_moves_in_board(grid, mr, mc)]


def all_reasonable_plays(state, player=None):
    """
    Returns the moves worth considering; here that is every legal move.

    Args:
        state (list): [grid, next_board]
        player (int): unused; accepted so play_game can call all games alike

    Returns: 
        list: valid (row, col) positions
    """
    # No move is pruned in Ultimate TTT, and the mover's identity does
    # not affect which moves are legal.
    legal = all_plays(state)
    return legal


def can_continue(state):
    """
    Tells whether the game is still in progress.

    Play stops as soon as the meta-board has a winner, or when no legal
    move remains.

    Args:
        state (list): [grid, next_board]

    Returns: 
        boolean: True if the game is still going, False otherwise
    """
    grid, _next_board = state
    if _meta_winner(_meta_board(grid)):
        return False                 # meta-board already decided
    return bool(all_plays(state))    # any legal move left?


def check_winner(state):
    """
    Reports the overall winner of the game, if any.

    A result of 0 covers both an unfinished game and a drawn one.

    Args:
        state (list): [grid, next_board]

    Returns: 
        int: winning player number, or 0
    """
    grid, _next_board = state
    board_results = _meta_board(grid)
    return _meta_winner(board_results)


def make_move(play, player, state):
    """
    Places a token and updates which local board is active next.

    The cell's position inside its local board, (row % LOC, col % LOC),
    names the board the opponent must play in next. If that board is won
    or full, next_board becomes None (free choice).

    Args:
        play (tuple): (row, col) in the full 9x9 grid
        player (int): the player making the move
        state (list): [grid, next_board] - modified in place
    """
    board = state[0]
    row, col = play

    board[row, col] = player

    target = (row % LOC, col % LOC)
    state[1] = target if _local_board_playable(board, *target) else None


def copy_game(state):
    """
    Duplicates a game state so the copy can be changed independently.

    The grid array is deep-copied. next_board is shared as-is; it is a
    tuple or None.

    Args:
        state (list): [grid, next_board]
        
    Returns: 
        list: a copy of state
    """
    board, target = state
    return [copy.deepcopy(board), target]


def draw(state):
    """
    Draws the full Ultimate Tic-Tac-Toe board with text graphics.

    Clears the terminal, then prints column indices, the 9x9 grid and a
    short status summary. Each local 3x3 board is drawn with its own grid
    lines, and local boards are separated by empty space.

    Boards the current player may play in are highlighted: their grid
    lines are drawn with '#' instead of '|' and '---+---+---'. A board that
    has been won is filled entirely with the winner's symbol.
    Row and column numbers are shown for easy move entry.

    Args:
        state (list): [grid, next_board]
    """
    grid, next_board = state
    meta = _meta_board(grid)

    os.system('cls' if os.name == 'nt' else 'clear')

    # Highlight the forced board if it is still playable. Otherwise the
    # player has free choice, so highlight every playable board (the same
    # rule all_plays uses to list moves).
    highlighted = None
    if next_board is not None:
        br, bc = next_board
        if _local_board_playable(grid, br, bc):
            highlighted = {(br, bc)}
    if highlighted is None:
        highlighted = {(br, bc) for br in range(META)
                                for bc in range(META)
                                if _local_board_playable(grid, br, bc)}

    # Each local board is 11 chars wide: " x | x | x "
    GAP = "    "        # space between neighbouring local boards
    MARGIN = "      "   # left margin, wide enough for the row labels

    # Column header: global column indices 0-8. Each index takes 4 chars so
    # it sits above its cell; the gap is one char shorter to stay aligned.
    header = MARGIN
    for meta_c in range(META):
        for local_c in range(LOC):
            header += f" {meta_c * LOC + local_c}  "
        if meta_c < META - 1:
            header += GAP[:-1]
    print(header)
    print()

    for meta_r in range(META):
        for local_r in range(LOC):
            row_idx = meta_r * LOC + local_r

            # Data row: " x | x | x " for each local board
            line = f"  {row_idx}   "
            for meta_c in range(META):
                c0 = meta_c * LOC
                w = meta[meta_r, meta_c]
                sep = "#" if (meta_r, meta_c) in highlighted else "|"

                for local_c in range(LOC):
                    # A won board is filled with its winner; otherwise show
                    # the cell itself (PSYM[0] is a blank for empty cells).
                    token = w if w != 0 else grid[row_idx, c0 + local_c]
                    line += f" {PSYM[token]} "
                    if local_c < LOC - 1:
                        line += sep

                if meta_c < META - 1:
                    line += GAP

            print(line)

            # Separator row between local rows (not after the last row)
            if local_r < LOC - 1:
                sep_line = MARGIN
                for meta_c in range(META):
                    if (meta_r, meta_c) in highlighted:
                        sep_line += "###########"
                    else:
                        sep_line += "---+---+---"
                    if meta_c < META - 1:
                        sep_line += GAP
                print(sep_line)

        # Empty line between meta-rows
        if meta_r < META - 1:
            print()

    print()
    w1 = np.sum(meta == 1)
    w2 = np.sum(meta == 2)
    print(f"  Local boards won  —  {P1}: {w1}   {P2}: {w2}")
    if next_board is None:
        print("  Active board: free choice")
    else:
        print(f"  Active board: {next_board}")
    print()


def ask_for_move(player, state):
    """
    Asks a human player for a move.

    Displays the board and the list of valid moves, then reads a row and a
    column from standard input until a valid move is entered. Entering -9
    for either value exits the program.

    Args:
        player (int): the player whose turn it is
        state (list): [grid, next_board]

    Returns: 
        tuple: (row, col) in the full 9x9 grid

    Raises:
        SystemExit: if the player enters -9
        EOFError: if standard input is closed (previously this looped forever)
    """
    draw(state)
    valid = all_plays(state)
    valid_set = set(valid)   # O(1) membership checks

    print("You are " + PSYM[player])
    print("  Valid moves: " + ", ".join(str(m) for m in valid))
    print("  (Enter row and column in the full 9x9 grid)")

    def read_int(prompt):
        # Return None for non-integer input. Only ValueError is caught, so
        # EOFError and KeyboardInterrupt still propagate. Entering -9 exits
        # at once, so a player quitting at the row prompt is not asked for
        # a column. SystemExit is used because quit() is provided by the
        # site module and is not guaranteed to exist.
        try:
            value = int(input(prompt))
        except ValueError:
            print("  Please enter integers (or -9 to quit)")
            return None
        if value == -9:
            raise SystemExit
        return value

    while True:
        row = read_int("  Enter row (0-8): ")
        if row is None:
            continue
        col = read_int("  Enter col (0-8): ")
        if col is None:
            continue

        if (row, col) in valid_set:
            return (row, col)
        print(f"  ({row}, {col}) is not a valid move. Try again.")


def declare_winner(winner):
    """
    Prints the final result of the game.

    Args:
        winner (int): winning player number, or 0 for a draw
    """
    is_draw = winner == 0
    message = "The game ended in a draw!" if is_draw else PSYM[winner] + " won!"
    print(message)