import time

import pandas as pd
import numpy as np

import bson

import mongo_setup as mongo_setup
from specific_seq_samplings import Sequence

from bokeh.io import curdoc
from bokeh.layouts import column, row
from bokeh.models import ColumnDataSource, DataRange1d, Select, PreText, Select, CheckboxGroup
from bokeh.models import LinearColorMapper, SetValue, Dropdown
from bokeh.palettes import Blues4
from bokeh.plotting import figure, show
from bokeh.models.callbacks import CustomJS

# Tissue names used to build the tissue menu and the selector options.
tissues = [
    'Adrenal Gland', 'Brain', 'Colon', 'Esophagus', 'Fallopian tube',
    'Heart', 'Kidney', 'Liver', 'Lung', 'Ovary', 'Pancreas', 'Prostate',
    'Salivary gland', 'Small intestine', 'Smooth muscle', 'Spleen',
    'Stomach', 'Testis', 'Thyroid',
]
# RBP row labels: every cell of the first row of kd_absolute.csv except the
# first one, followed by an extra "Unbound" label.
rbps = [
    *np.loadtxt('/home/olafsohr/RBP/src/data/kd_absolute.csv',
                dtype=object, delimiter=',', max_rows=1)[1:],
    "Unbound",
]

def main():
    """Initialise the database connection and build the comparison plot.

    The plot is built for the hard-coded transcript "ENST00000631435.1".
    Afterwards the CPU time consumed (``time.process_time``) is printed in
    seconds.
    """
    cpu_start = time.process_time()
    mongo_setup.global_init()
    bokeh_comp_plot("ENST00000631435.1")
    elapsed = round(time.process_time() - cpu_start, 2)
    print('Time: %ss' % (elapsed,))


def retrieve_data(tx_name):
    """
    Build a long-format table of sampling probabilities for one transcript.

    Every ``Sequence`` document stored for ``tx_name`` is fetched. Each
    document's ``sampling_matrix`` is labelled with ``rbps`` as rows and
    ``"<position>_<nucleotide>"`` as columns. It is then stacked into one row
    per (RBP, position) cell, tagged with the document's ``tissue_name``.

    Parameters
    ----------
    tx_name : str
        Transcript identifier matched against ``rna_name``.

    Returns
    -------
    (pandas.DataFrame, sequence)
        The frame has columns ``RBP``, ``Sequence``, ``Probability`` and
        ``Tissue``; each tissue keeps its own 0-based row index. The second
        item is the ``rna_sequence`` of the first document returned; all
        documents are assumed to share it (TODO confirm).

    Raises
    ------
    ValueError
        If no document exists for ``tx_name``.
    """
    found_tx = list(Sequence.objects(rna_name=tx_name).all())
    if not found_tx:
        # Previously this surfaced as a bare IndexError on found_seq[0].
        raise ValueError(f"No sampling data found for transcript {tx_name!r}")

    full_seq = found_tx[0].rna_sequence
    # Column labels "<position>_<nucleotide>", e.g. "0_A", "1_C", ...
    idx = [f"{pos}_{nt}" for pos, nt in enumerate(full_seq)]

    frames = []
    for tx in found_tx:
        # Convert each matrix on its own; stacking all of them into one
        # np.array fails when their shapes differ.
        wide = pd.DataFrame(columns=idx, index=list(rbps),
                            data=np.asarray(tx.sampling_matrix))
        long = wide.stack().reset_index()
        long.columns = ["RBP", "Sequence", "Probability"]
        long["Tissue"] = tx.tissue_name
        frames.append(long)

    # One concat instead of growing a frame inside the loop (quadratic, and
    # seeding with an empty object-dtype frame is deprecated in pandas).
    return pd.concat(frames), full_seq


def nix(val, lst):
    """Return a new list holding the items of ``lst`` that differ from ``val``.

    Order is preserved; every occurrence of ``val`` is removed.
    """
    return list(filter(lambda item: item != val, lst))


def bokeh_comp_plot(tx_name):
    """
    Build two RBP heatmaps for ``tx_name`` that compare a pair of tissues.

    ``p1`` colours each (position, RBP) cell by ``t1_prob`` and ``p2`` by
    ``t2_prob``. Both read from one shared ``ColumnDataSource`` and use a 0-1
    Magma colour scale. At present only the tissue ``Dropdown`` is shown: the
    tissue selectors and the plot layout are still commented out, so
    ``source`` is never filled.
    """
    df, fullSeq = retrieve_data(tx_name)

    # Categorical x labels "<position>_<nucleotide>", as built in retrieve_data.
    idx = [f"{pos}_{nt}" for pos, nt in enumerate(fullSeq)]
    source = ColumnDataSource(data=dict(seq=[], rbp=[], t1_prob=[], t2_prob=[]))
    mapper = LinearColorMapper(palette="Magma256", low=0, high=1)

    def make_heatmap(prob_field):
        """Return a heatmap figure coloured by column ``prob_field`` of ``source``."""
        fig = figure(x_range=idx, y_range=list(rbps), width=15 * len(fullSeq), height=500,
                     x_axis_location="below", toolbar_location="below",
                     tools="hover,save,pan,box_zoom,reset,wheel_zoom,box_select")
        fig.rect(x="seq", y="rbp", width=1, height=1, source=source,
                 fill_color={'field': prob_field, 'transform': mapper},
                 line_color=None)
        return fig

    p1 = make_heatmap('t1_prob')
    p2 = make_heatmap('t2_prob')

    dropdown1 = Dropdown(label="Select Tissue", menu=list(zip(tissues, tissues)))
    # Log the clicked menu item in the browser console. The previous code
    # string had a stray quote that left a JS string literal unterminated.
    dropdown1.js_on_event(
        "menu_item_click",
        CustomJS(code="console.log('dropdown: ' + this.item, this.toString())"),
    )

    # NOTE(review): the tissue selectors are disabled, so ticker1/ticker2 are
    # not defined and the callbacks below raise NameError if ever invoked.
    # Python callbacks attach with on_change (not js_on_change) and only run
    # inside a Bokeh server document, e.g.:
    #   ticker1 = Select(value='Heart', options=nix('Smooth muscle', tissues))
    #   ticker2 = Select(value='Smooth muscle', options=nix('Heart', tissues))
    #   ticker1.on_change("value", ticker1_change)
    #   ticker2.on_change("value", ticker2_change)

    def ticker1_change(attrname, old, new):
        # Keep the other selector from offering the tissue just chosen.
        ticker2.options = nix(new, tissues)
        update()

    def ticker2_change(attrname, old, new):
        ticker1.options = nix(new, tissues)
        update()

    def get_data(t1, t2):
        """Fill ``source`` with tissue ``t1`` as t1_prob and ``t2`` as t2_prob."""
        d1 = df[df['Tissue'] == t1]
        d2 = df[df['Tissue'] == t2]
        # Replace the data in one assignment. Setting columns one by one left
        # the source with unequal column lengths between assignments.
        source.data = dict(
            seq=d1["Sequence"].to_numpy(),
            rbp=d1["RBP"].to_numpy(),
            t1_prob=d1["Probability"].to_numpy(),
            # Paired positionally with d1; assumes both tissues were stacked
            # in the same (RBP, position) order -- TODO confirm.
            t2_prob=d2["Probability"].to_numpy(),
        )

    def update(selected=None):
        t1, t2 = ticker1.value, ticker2.value
        print(f"Tissues: {t1} & {t2}")
        get_data(t1, t2)

    # update()

    show(dropdown1)
    # show(row(column(p1, ticker1), column(p2, ticker2)))



def load_all_data(idx, tissueToMatrix, *, rbp_names=None):
    """
    Return a DataFrame of "stacked" probability matrices for all tissue types.

    Each tissue's matrix is converted with ``load_dataframe``. The results are
    concatenated in the iteration order of ``tissueToMatrix``.

    Parameters
    ----------
    idx : list
        Column labels of every matrix (one per nucleotide position).
    tissueToMatrix : mapping
        Tissue name -> 2-D probability matrix (rows = RBPs, columns = ``idx``).
    rbp_names : list, optional
        Row labels of the matrices; defaults to the module-level ``rbps``.

    Returns
    -------
    pandas.DataFrame
        Columns ``RBP``, ``Sequence``, ``Probability``, ``Tissue``. Each
        tissue keeps its own 0-based row index, so index labels repeat across
        tissues. The frame is empty (with these columns) when
        ``tissueToMatrix`` is empty.
    """
    frames = [load_dataframe(tissue, idx, tissueToMatrix, rbp_names=rbp_names)
              for tissue in tissueToMatrix]
    if not frames:
        return pd.DataFrame(columns=["RBP", "Sequence", "Probability", "Tissue"])
    # A single concat is linear. Concatenating inside the loop, starting from
    # an empty object-dtype frame, re-copied all accumulated rows per tissue.
    return pd.concat(frames)


def load_dataframe(tissue, idx, tissueToMatrix, *, rbp_names=None):
    """
    Return a DataFrame of the "stacked" probability matrix for one tissue type.

    The matrix ``tissueToMatrix[tissue]`` is labelled with RBPs as rows and
    ``idx`` as columns, then stacked into one row per (RBP, position) cell.
    Rows are in row-major order: all positions of the first RBP come first.

    Parameters
    ----------
    tissue : hashable
        Key into ``tissueToMatrix``; also written to the ``Tissue`` column.
    idx : list
        Column labels (one per nucleotide position).
    tissueToMatrix : mapping
        Tissue name -> 2-D probability matrix.
    rbp_names : list, optional
        Row labels; defaults to the module-level ``rbps``.

    Returns
    -------
    pandas.DataFrame
        Columns ``RBP``, ``Sequence``, ``Probability``, ``Tissue``.

    Raises
    ------
    KeyError
        If ``tissue`` is not in ``tissueToMatrix``.
    """
    names = list(rbps if rbp_names is None else rbp_names)
    wide = pd.DataFrame(columns=idx, index=names, data=tissueToMatrix[tissue])
    long = wide.stack().reset_index()
    long.columns = ["RBP", "Sequence", "Probability"]
    long["Tissue"] = tissue
    return long


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()

