"""
This is a more complicated and specific iterative refinement of a bipartition.

Find a local minimum of a function of two sets.
The domain of the function consists of all possible bipartitions of a set.
Locality is defined with respect to the operation that moves a member
from one set to the other set.

In this module, the cost function is assumed to take a specific form:
the negative of the sum of within-cluster signed edge values.

"""
from __future__ import print_function, division, absolute_import

import numpy as np
from numpy.testing import assert_allclose

__all__ = ['faster_refinement']


def _faster_single_refinement(M, alpha, beta):
    """
    Evaluate every single-member move and return the best neighbor.

    Parameters
    ----------
    M : 2d array of edge values
        Larger positive values are evidence of shared cluster membership.
        The value of a bipartition is the sum of the entries of M within
        the first cluster plus the sum of the entries within the second
        cluster, and its cost is the negative of that value.
    alpha : set
        First set in the initial bipartition, given as indices of M.
    beta : set
        Second set in the initial bipartition, given as indices of M.

    Returns
    -------
    initial_cost : float
        Cost of the bipartition (alpha, beta).
    final_cost : float
        Cost after the single best move, or the initial cost if no move
        strictly improves the bipartition.
    a : set
        First set after the best move (alpha itself if nothing moved).
    b : set
        Second set after the best move (beta itself if nothing moved).

    """
    # Extract the within-cluster blocks and the cross-cluster block of M.
    first_idx = np.array(list(alpha), dtype=int)
    second_idx = np.array(list(beta), dtype=int)
    within_first = M[np.ix_(first_idx, first_idx)]
    within_second = M[np.ix_(second_idx, second_idx)]
    across = M[np.ix_(first_idx, second_idx)]

    # Per-member attachment to its current cluster (diagonal included).
    stay_first = within_first.sum(axis=0)
    stay_second = within_second.sum(axis=0)

    # Per-member attachment to the opposite cluster; the member's own
    # diagonal entry is added because it moves along with the member.
    leave_first = across.sum(axis=1) + np.diag(within_first)
    leave_second = across.sum(axis=0) + np.diag(within_second)

    initial_cost = -(stay_first.sum() + stay_second.sum())

    # Scan the candidate moves, first alpha -> beta and then beta -> alpha.
    # Only a strictly larger gain replaces the incumbent, so the earliest
    # candidate wins ties and a non-positive gain never triggers a move.
    top_gain = 0
    refined = (alpha, beta)
    for member, gain in zip(first_idx, leave_first - stay_first):
        if gain > top_gain:
            top_gain = gain
            refined = (alpha - {member}, beta | {member})
    for member, gain in zip(second_idx, leave_second - stay_second):
        if gain > top_gain:
            top_gain = gain
            refined = (alpha | {member}, beta - {member})

    # Each off-diagonal edge appears twice in the symmetric matrix,
    # hence the factor of two when converting the gain to a cost change.
    new_alpha, new_beta = refined
    return initial_cost, initial_cost - 2*top_gain, new_alpha, new_beta


def faster_refinement(M, threshold, alpha, beta):
    """
    Repeatedly apply single refinement steps until no improvement is found.

    Each step moves at most one member between the two sets.  Iteration
    stops as soon as a step fails to lower the cost by a strictly positive
    amount that is also at least `threshold`.

    Parameters
    ----------
    M : 2d array of edge values
        Larger positive values are evidence of shared cluster membership.
        The single-step cost update relies on M being symmetric.
    threshold : float
        Cost differences smaller than this value are considered negligible.
        A zero or negative threshold means any strict improvement is
        accepted; the search still terminates once no move improves.
    alpha : set
        First set in the initial bipartition.
    beta : set
        Second set in the initial bipartition.

    Returns
    -------
    initial_cost : float
        Cost of the initial bipartition.
    final_cost : float
        Cost of the returned bipartition.
    a : set
        First set of the refined bipartition.
    b : set
        Second set of the refined bipartition.

    """
    initial_cost = None
    cost = None
    a, b = alpha, beta
    while True:
        icost, fcost, na, nb = _faster_single_refinement(M, a, b)
        if initial_cost is None:
            # The first step reports the cost of the caller's bipartition.
            initial_cost = icost
            cost = icost
        improvement = icost - fcost
        # Continue only on a strictly positive, non-negligible improvement.
        # Requiring improvement > 0 guarantees termination when threshold
        # is zero or negative (a converged step has improvement == 0), and
        # writing the test as a negated conjunction also stops on NaN.
        if not (improvement > 0 and improvement >= threshold):
            break
        cost, a, b = fcost, na, nb
    return initial_cost, cost, a, b
