from __future__ import print_function, division, absolute_import

from contextlib import contextmanager
import argparse
import os.path
import os
import csv
import time
import sys
from functools import partial
import rpy2.robjects
import rpy2.robjects.numpy2ri
import shutil as sh

import pandas as pd
import numpy as np

import scipy
from numpy.testing import assert_allclose

import seaborn
import matplotlib.pyplot as plt

from libmmc import expansion, get_clustering


def create_matplotlib_heatmap(M, filename):
    """Save a heatmap of the square matrix M to `filename` with matplotlib.

    The color scale is fixed to [-1, 1] with the 'jet' colormap, both axes
    are labelled 1..p (p is the number of rows of M), the y axis is inverted
    so that the first row is drawn at the top, and cells are drawn square.

    Parameters
    ----------
    M : 2d array
        Matrix to draw; only ``M.shape[0]`` is used to build the tick labels.
    filename : str
        Output path handed to ``plt.savefig``.

    """
    p = M.shape[0]
    tick_positions = np.arange(p) + 0.5
    tick_labels = list(range(1, p + 1))
    fig = plt.figure()
    try:
        plt.pcolor(M, vmin=-1, vmax=1, cmap=plt.cm.jet)
        plt.yticks(tick_positions, tick_labels)
        plt.xticks(tick_positions, tick_labels)
        plt.colorbar()
        ax = plt.gca()
        ax.invert_yaxis()
        ax.set_aspect('equal')
        plt.savefig(filename)
    finally:
        # Release the figure even if drawing or saving fails; otherwise
        # pyplot keeps every figure alive when several heatmaps are drawn.
        plt.close(fig)


def create_seaborn_heatmap(M, filename):
    """Save a heatmap of the matrix M to `filename` with seaborn.

    Matrices with fewer than 20 rows are drawn with seaborn's defaults.
    Larger matrices are drawn without cell borders or tick labels and are
    rasterized.

    Parameters
    ----------
    M : 2d array
        Matrix to draw.
    filename : str
        Output path handed to ``fig.savefig``.

    """
    p = M.shape[0]
    fig, ax = plt.subplots()
    try:
        # Draw explicitly onto the axes created above instead of relying on
        # whatever pyplot currently considers the active axes.
        if p < 20:
            seaborn.heatmap(M, ax=ax)
        else:
            seaborn.heatmap(
                    M, ax=ax, linewidths=0.0,
                    xticklabels=False, yticklabels=False, rasterized=True)
        fig.savefig(filename)
    finally:
        # Release the figure even if drawing or saving fails; otherwise
        # pyplot keeps every figure alive when several heatmaps are drawn.
        plt.close(fig)


def create_rdata_heatmap(M, filename):
    """Store the matrix M in an RData file as the R object 'my.heatmap'.

    The matrix is converted through rpy2's numpy bridge, bound to the name
    'my.heatmap' in the embedded R session, and written to `filename` with
    R's compressed ``save``.

    """
    r = rpy2.robjects.r
    r.assign('my.heatmap', rpy2.robjects.numpy2ri.numpy2ri(M))
    # NOTE: the filename is interpolated into R source text, so a path
    # containing a single quote would break the generated command.
    save_command = "save(my.heatmap, file='%s', compress=TRUE)" % filename
    r(save_command)


def nontechnical_analysis(args, df, mask, C, clustering):
    """Write the user-facing outputs derived from a finished clustering.

    Clusters are renumbered as "modules" in order of decreasing average
    degree, and variables are ordered by module and then by decreasing
    degree within the module.  Depending on which output paths are set on
    `args`, this writes a csv summary and up to three heatmaps (unsorted,
    sorted and smoothed).

    Parameters
    ----------
    args : argparse.Namespace
        Reads `heatmap_style`, `verbose`, `csv_out`, `unsorted_heatmap`,
        `sorted_heatmap` and `smoothed_heatmap`.  An output is skipped when
        its path is None.
    df : pandas.DataFrame
        Data frame whose index holds the variable names.
    mask : boolean array
        Selects the rows of `df` that took part in the clustering.
    C : 2d array
        Correlation matrix among the clustered variables.  Its diagonal is
        assumed to be 1 (the self-correlation is subtracted as 1 below).
    clustering : 1d integer array
        Cluster label of each clustered variable; labels are assumed to be
        0..k-1 where k is ``clustering.max() + 1``.

    """
    # Use the command line args to define the heatmap style.
    create_heatmap = dict(
            seaborn=create_seaborn_heatmap,
            matplotlib=create_matplotlib_heatmap,
            rdata=create_rdata_heatmap,
            )[args.heatmap_style]

    # Map each row name to its position in the data frame that was passed in.
    all_row_names = df.index.values
    row_index_map = {name: idx for idx, name in enumerate(all_row_names)}

    # If some variables are uninformative for clustering,
    # the correlation matrix and the cluster vector will have smaller
    # dimensions than the number of rows in the data frame.
    remaining_row_names = df[mask].index.values

    # Count the variables included in the clustering, and the clusters.
    p = clustering.shape[0]
    k = clustering.max() + 1

    # To sort the modules and to sort the variables within the modules,
    # we want to use absolute values of correlations.
    C_abs = np.abs(C)

    # For each cluster, record its member indices and the degree of each
    # member: the mean absolute correlation with the other members.
    # Members of singleton clusters keep a degree of zero.
    selections = []
    degrees = np.zeros(p, dtype=float)
    for cluster in range(k):
        selection = np.flatnonzero(clustering == cluster)
        selections.append(selection)
        if selection.size > 1:
            submatrix = C_abs[np.ix_(selection, selection)]
            # Subtract 1 to remove the self-correlation on the diagonal.
            degrees[selection] = (
                    (submatrix.sum(axis=0) - 1) / (selection.size - 1))

    # Modules should be reordered according to decreasing "average degree".
    average_degrees = [degrees[selection].mean() for selection in selections]
    module_to_cluster = np.argsort(average_degrees)[::-1]
    cluster_to_module = {
            cluster: module
            for module, cluster in enumerate(module_to_cluster)}

    # Sort by module, then by decreasing degree, then by original position.
    triples = sorted(
            (cluster_to_module[clustering[i]], -degrees[i], i)
            for i in range(p))
    new_to_old_idx = [old_i for _module, _neg_degree, old_i in triples]

    # Make a csv file if requested.
    if args.csv_out is not None:
        if args.verbose:
            print('preparing to write the csv output to',
                    os.path.abspath(args.csv_out))
        header = ('Gene', 'Module', 'Entry Index', 'Average Degree', 'Degree')
        # The csv module wants a binary file on Python 2 but a text file
        # opened with newline='' on Python 3.
        if sys.version_info[0] >= 3:
            fout = open(args.csv_out, 'w', newline='')
        else:
            fout = open(args.csv_out, 'wb')
        with fout:
            writer = csv.writer(fout)
            writer.writerow(header)
            for old_i in new_to_old_idx:
                name = remaining_row_names[old_i]
                cluster = clustering[old_i]
                # Module and entry indices are reported 1-based.
                writer.writerow((
                        name,
                        cluster_to_module[cluster] + 1,
                        row_index_map[name] + 1,
                        average_degrees[cluster],
                        degrees[old_i],
                        ))

    # Draw the first heatmap.
    # Plot using something like
    # http://stackoverflow.com/questions/15988413/
    if args.unsorted_heatmap is not None:
        if args.verbose:
            print('preparing to write a heatmap to',
                    os.path.abspath(args.unsorted_heatmap))
        create_heatmap(C, args.unsorted_heatmap)

    # Prepare to create the sorted heatmaps.
    # This code has a lot of room for speed improvements.
    C_new = C[np.ix_(new_to_old_idx, new_to_old_idx)]
    clustering_new = clustering[new_to_old_idx]

    # Draw the second heatmap (reordered according to the clustering).
    if args.sorted_heatmap is not None:
        if args.verbose:
            print('preparing to write a heatmap to',
                    os.path.abspath(args.sorted_heatmap))
        create_heatmap(C_new, args.sorted_heatmap)

    # Draw the third heatmap (smoothed).
    if args.smoothed_heatmap is not None:
        if args.verbose:
            print('preparing to write a heatmap to',
                    os.path.abspath(args.smoothed_heatmap))

        # Make a smoothed correlation array.
        # This code has a lot of room for speed improvements.
        # NOTE(review): `expansion` presumably returns a p-by-k 0/1 cluster
        # membership matrix -- confirm against libmmc.
        S = expansion(clustering_new)
        block_mask = S.dot(S.T)
        denom = np.outer(S.sum(axis=0), S.sum(axis=0))
        small = S.T.dot(C_new).dot(S) / denom
        C_all_smoothed = S.dot(small).dot(S.T)
        # Keep the raw correlations where block_mask is 1 and use the
        # block-averaged values everywhere else.
        C_smoothed = (
                C_all_smoothed * (1 - block_mask) +
                C_new * block_mask)

        # Draw the heatmap.
        create_heatmap(C_smoothed, args.smoothed_heatmap)

#def transfer_results(resout):
#    print(resout)
#    print(os.getcwd())


def main_analysis(args, df):
    """Cluster the rows of `df` and produce the requested outputs.

    Rows without any variation are removed, the row-by-row correlation
    matrix is computed, `get_clustering` is run over a grid of sigma values,
    a short summary is printed, and `nontechnical_analysis` writes the csv
    and heatmap outputs.

    Parameters
    ----------
    args : argparse.Namespace
        Reads `verbose`, `correlation`, `sigma_low`, `sigma_high` and
        `sigma_num`, and is forwarded to `nontechnical_analysis`.
    df : pandas.DataFrame
        One row per variable and one column per observation.

    """
    tm0 = time.time()

    # Optionally report properties of the input csv file.
    if args.verbose:
        p, n = df.shape
        print('original number of variables in the input file:', p)
        print('number of observations per variable:', n)
        print()

    # If there is no variance in a row, the correlations cannot be computed.
    # A row varies exactly when its largest value exceeds its smallest one;
    # the comparison is also False for rows without any usable value.
    # Select rows positionally so that duplicated row labels are handled
    # one row at a time.
    has_variation = (df.max(axis=1) > df.min(axis=1)).values
    df = df[has_variation]

    # Compute the matrix of correlation coefficients.
    C = df.T.corr(method=args.correlation).values

    # Every row that survived the filter above takes part in the clustering.
    mask = np.ones(df.shape[0], dtype=bool)

    # Consider all values of tuning parameter sigma in this array.
    # The number of grid points must be an integer for np.linspace.
    sigmas = np.linspace(
            args.sigma_low, args.sigma_high, num=int(args.sigma_num))

    # Compute the clustering for each of the several values of sigma.
    # Each sigma corresponds to a different affinity matrix,
    # so the modularity matrix is also different for each sigma.
    # The goal is to find the clustering whose modularity is greatest
    # across all joint (sigma, partition) pairs.
    # In practice, we will look for an approximation of this global optimum.
    clustering, sigma, m = get_clustering(C, sigmas, verbose=args.verbose)

    # Count the number of clusters.
    k = clustering.max() + 1

    # Report a summary of the results of the technical analysis.
    print('after partition refinement:')
    print('  sigma:', sigma)
    print('  number of clusters:', k)
    print('  modulated modularity:', m)

    # Optionally report the elapsed time.
    if args.verbose:
        print('cumulative elapsed wall time in seconds:', time.time() - tm0)

    # Run the nontechnical analysis using the data frame and the less nerdy
    # of the outputs from the technical analysis.
    nontechnical_analysis(args, df, mask, C, clustering)

    if args.verbose:
        print('cumulative elapsed wall time in seconds:', time.time() - tm0)
        


@contextmanager
def pushd(new_dir):
    """Temporarily change the working directory to `new_dir`.

    The directory that was current on entry is restored when the ``with``
    block exits, including when the block raises an exception.

    """
    prev_dir = os.getcwd()
    os.chdir(new_dir)
    try:
        yield
    finally:
        # Always go back, even if the body of the with-statement failed.
        os.chdir(prev_dir)

def main():
    """Parse the command line, read the input csv file and run the analysis.

    The analysis runs with the working directory temporarily set to the
    directory that contains the input csv file, so relative output paths
    are resolved against that directory.

    """
    # Get command line arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('--verbose', '-v', action='store_true',
            help='Show more information on the screen.')
    parser.add_argument('--correlation',
            choices=('pearson', 'kendall', 'spearman'),
            default='pearson',
            help=(
                "Compute correlation coefficients using either "
                "'pearson' (standard correlation coefficient), "
                "'kendall' (Kendall Tau correlation coefficient), or "
                "'spearman' (Spearman rank correlation)."))
    parser.add_argument('--csv-in', required=True,
            help=(
                'Path to the data file, which is expected to be in '
                'comma-separated values (csv) format '
                'with row and column labels, and for which the rows '
                'are to be clustered.'))
    parser.add_argument('--sigma-low', type=float, default=0.05,
            help='Low value of sigma (Default: 0.05).')
    parser.add_argument('--sigma-high', type=float, default=0.50,
            help='High value of sigma (Default: 0.50).')
    # This is a count of grid points, so it has to be an integer.
    parser.add_argument('--sigma-num', type=int, default=451,
            help='Number of values of sigma to search (Default: 451).')
    parser.add_argument('--heatmap-style',
            choices=('seaborn', 'matplotlib', 'rdata'),
            default='matplotlib',
            help='Use this heatmap plotting style.')

    # Get output directory for website daemon.
    # NOTE: this option is parsed but not used anywhere in this function.
    parser.add_argument('--results-out',
            help='Absolute path for final results directory.')

    # Specify output file names relative to the directory
    # containing the input csv file.
    parser.add_argument('--csv-out',
            help=('Output csv path '
                  'relative to the directory containing the input csv file'))
    parser.add_argument('--unsorted-heatmap',
            help=('Output unsorted correlation heatmap image file name '
                  'relative to the directory containing the input csv file'))
    parser.add_argument('--sorted-heatmap',
            help=('Output sorted correlation heatmap image file name '
                  'relative to the directory containing the input csv file'))
    parser.add_argument('--smoothed-heatmap',
            help=('Output smoothed sorted correlation heatmap image file name '
                  'relative to the directory containing the input csv file'))
    args = parser.parse_args()

    # Optionally report the command line arguments.
    if args.verbose:
        print('sys.argv:')
        print(sys.argv)
        print()
        print('numpy version:', np.__version__)
        print('pandas version:', pd.__version__)
        print('scipy version:', scipy.__version__)
        print()

    # Read the csv file as a 'pandas' data frame.
    # This has the side-effect of checking that the file actually exists.
    df = pd.read_csv(args.csv_in, index_col=0)

    # Temporarily change to the directory containing the input csv file.
    # Continue the analysis from there.
    csv_in_dir = os.path.dirname(os.path.abspath(args.csv_in))
    with pushd(csv_in_dir):
        main_analysis(args, df)
    
# Run the command line interface only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()
