from __future__ import print_function, division, absolute_import

import argparse

import numpy as np
import matplotlib.pyplot as plt

from libmmc import (sample_correlation_matrix, get_clustering,
        get_corr_fig2, get_corr_fig2_desired_clustering,
        get_clustering_accuracy)


def _mean_accuracies(R, desired_v, ns, sigmas, nsamples):
    """Return the clustering accuracies averaged over ``nsamples`` draws.

    For every ``n`` in ``ns``:
    - ``nsamples`` matrices are drawn with ``sample_correlation_matrix(R, n)``.
    - Each matrix is clustered with ``get_clustering`` over the sigma grid ``sigmas``.
    - Each clustering is scored against ``desired_v`` with ``get_clustering_accuracy``.

    The result has one row per entry of ``ns`` and three columns holding the
    three accuracy values, averaged over the draws.
    """
    rows = []
    for n in ns:
        totals = np.zeros(3, dtype=float)
        for _ in range(nsamples):
            C = sample_correlation_matrix(R, n)
            # Only the clustering itself is scored; the other two values
            # returned by get_clustering are not needed here.
            observed_v, _, _ = get_clustering(C, sigmas)
            totals += np.asarray(
                    get_clustering_accuracy(observed_v, desired_v),
                    dtype=float)
        rows.append(totals / nsamples)
    return np.array(rows)


def _plot_accuracies(ns, results, path):
    """Plot the three accuracy columns of ``results`` against ``ns`` to ``path``."""
    # Use a dedicated figure and close it afterwards so repeated calls do not
    # draw on top of a previous plot.
    fig = plt.figure()
    clustered, separated, overall = results.T
    lines = plt.plot(
            ns, clustered, 'red',
            ns, separated, 'blue',
            ns, overall, 'black')
    plt.ylabel("Accuracy")
    plt.xlabel("Number of observations")
    plt.legend(
            lines,
            ('Clustered pairs', 'Separated pairs', 'Overall'),
            loc='lower right')
    plt.savefig(path)
    plt.close(fig)


def main(args):
    """Plot clustering accuracy against the number of observations.

    For each ``n`` in 4..36, ``args.samples`` clusterings are scored against
    the desired Figure 2 clustering and their accuracies are averaged. The
    three averaged curves (clustered pairs, separated pairs, overall) are
    saved to ``fig4.svg``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide the following attributes:

        - ``samples``: positive int.
        - ``sigma_low`` and ``sigma_high``: floats bounding the sigma search grid.
        - ``sigma_num``: integral number of grid points.

    Raises
    ------
    ValueError
        If ``samples`` is not positive or ``sigma_num`` is not integral.
    """
    nsamples = args.samples
    if nsamples <= 0:
        raise ValueError('--samples must be positive, got %r' % (nsamples,))
    # np.linspace requires an integer count; accept 7.0 but reject 7.5.
    sigma_num = int(args.sigma_num)
    if sigma_num != args.sigma_num:
        raise ValueError(
                '--sigma-num must be an integer, got %r' % (args.sigma_num,))
    sigmas = np.linspace(args.sigma_low, args.sigma_high, num=sigma_num)
    R = get_corr_fig2()
    desired_v = get_corr_fig2_desired_clustering()
    ns = np.arange(4, 37, dtype=int)
    results = _mean_accuracies(R, desired_v, ns, sigmas, nsamples)
    _plot_accuracies(ns, results, 'fig4.svg')


def build_parser():
    """Build the command-line parser for this script.

    ``--sigma-num`` is parsed as an int because it is passed to
    ``np.linspace(num=...)``, which requires an integer count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--samples', type=int, default=1000,
            help='Number of sampled correlation matrices averaged for each '
                 'number of observations (Default: 1000).')
    parser.add_argument('--sigma-low', type=float, default=0.05,
            help='Low value of sigma (Default: 0.05).')
    parser.add_argument('--sigma-high', type=float, default=0.35,
            help='High value of sigma (Default: 0.35).')
    parser.add_argument('--sigma-num', type=int, default=7,
            help='Number of values of sigma to search (Default: 7).')
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    main(args)
