import pandas as pd
import numpy as np

import bson
import mongo_setup as mongo_setup
from specific_seq_samplings import Sequence

from bokeh.layouts import column, row
from bokeh.models import ColumnDataSource, Select, CheckboxGroup, LinearColorMapper, Dropdown
from bokeh.palettes import Blues4
from bokeh.plotting import figure, curdoc


# SET GLOBALS
# Tissue names offered by the two Select widgets.
tissues = [
    'Adrenal Gland', 'Brain', 'Colon', 'Esophagus',
    'Fallopian tube', 'Heart', 'Kidney', 'Liver', 'Lung',
    'Ovary', 'Pancreas', 'Prostate', 'Salivary gland',
    'Small intestine', 'Smooth muscle', 'Spleen',
    'Stomach', 'Testis', 'Thyroid',
]

# RBP labels: the first row of kd_absolute.csv minus its first cell
# (presumably a row-label header -- TODO confirm), plus an extra "UNBOUND" row.
rbps = list(np.loadtxt('/home/olafsohr/RBP/src/data/kd_absolute.csv',
                       dtype=object, delimiter=',', max_rows=1)[1:]) + ["UNBOUND"]

# Presumably opens/registers the Mongo connection that Sequence.objects(...)
# relies on -- verify against mongo_setup.
mongo_setup.global_init()



#DEFINE FUNCTIONS

def retrieve_data(tx_name):
    """
    Build a long-format DataFrame of sampling probabilities for one transcript.

    Every Sequence document stored for ``tx_name`` contributes one tissue. Its
    ``sampling_matrix`` is read as a table with one row per RBP (global
    ``rbps``) and one column per sequence position. It is stacked into rows
    with columns ["RBP", "Sequence", "Probability", "Tissue"], where
    "Sequence" is a ``"<position>_<nucleotide>"`` label.

    RETURN dataFrame, sequence (the rna_sequence of the first matching record)

    Raises LookupError if no record exists for ``tx_name``.
    """
    # Fetch data from mongo database
    found_tx = list(Sequence.objects(rna_name=tx_name).all())
    if not found_tx:
        raise LookupError(f"No sequence records found for transcript {tx_name!r}")

    sequence = found_tx[0].rna_sequence
    # Column labels shared by every tissue: "<position>_<nucleotide>".
    idx = [f'{pos}_{base}' for pos, base in enumerate(sequence)]

    frames = []
    for tx in found_tx:
        # Wide matrix for one tissue: rows = RBPs, columns = positions.
        wide = pd.DataFrame(columns=idx, index=rbps, data=np.array(tx.sampling_matrix))
        stacked = wide.stack().reset_index()
        stacked.columns = ["RBP", "Sequence", "Probability"]
        stacked["Tissue"] = tx.tissue_name
        frames.append(stacked)

    # Concatenate once. Growing a frame inside the loop is quadratic, and
    # seeding it with an empty object-dtype frame degrades column dtypes.
    df = pd.concat(frames)

    # Return master DF and sequence
    return df, sequence


def nix(val, lst):
    """
    Return a new list of the elements of ``lst`` that are not equal to ``val``.

    Order is preserved and ``lst`` itself is not modified.
    """
    return list(filter(lambda item: item != val, lst))


df, fullSeq = retrieve_data("ENST00000631435.1")

# x-axis factors "<position>_<nucleotide>", the same labels used in df["Sequence"].
idx = [f'{pos}_{base}' for pos, base in enumerate(fullSeq)]
source = ColumnDataSource(data=dict(seq=[], rbp=[], t1_prob=[], t2_prob=[]))
# Colour scale spans 0..1.
mapper = LinearColorMapper(palette="Magma256", low=0, high=1)

# Fixed: the original p1 listed "wheel_zoo", which is not a Bokeh tool name.
_HEATMAP_TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom,box_select"


def _heatmap_figure(prob_field):
    """Build one RBP-by-position heatmap coloured by ``source`` column ``prob_field``."""
    fig = figure(x_range=idx, y_range=list(rbps), width=15 * len(fullSeq), height=500,
                 x_axis_location="below", toolbar_location="below",
                 tools=_HEATMAP_TOOLS)
    fig.rect(x="seq", y="rbp", width=1, height=1, source=source,
             fill_color={'field': prob_field, 'transform': mapper},
             line_color=None)
    return fig


# Both figures share `source`; they differ only in which probability column
# drives the fill colour. (The original p2 passed the invalid keyword
# `x_Axis_location`.)
p1 = _heatmap_figure('t1_prob')
p2 = _heatmap_figure('t2_prob')

# Initial selections. Each widget's options exclude the other widget's value,
# so the same tissue cannot be selected on both sides.
_INITIAL_T1 = 'Heart'
_INITIAL_T2 = 'Smooth muscle'

ticker1 = Select(value=_INITIAL_T1, options=nix(_INITIAL_T2, tissues))


def ticker1_change(attrname, old, new):
    """Remove ticker1's new tissue from ticker2's options, then refresh the plots."""
    # ticker2 and update are module-level names defined below; they are
    # looked up when the callback fires, not when it is defined.
    ticker2.options = nix(new, tissues)
    update()


ticker1.on_change("value", ticker1_change)

ticker2 = Select(value=_INITIAL_T2, options=nix(_INITIAL_T1, tissues))


def ticker2_change(attrname, old, new):
    """Remove ticker2's new tissue from ticker1's options, then refresh the plots."""
    ticker1.options = nix(new, tissues)
    update()


ticker2.on_change("value", ticker2_change)

def get_data(t1, t2):
    """
    Load the rows for tissues ``t1`` and ``t2`` from ``df`` into ``source``.

    Rows of the two tissues are paired positionally. This assumes both tissues
    were stacked in the same RBP x position order and have the same number of
    rows (TODO confirm for tissues missing from the database).

    Returns None.
    """
    d1 = df[df['Tissue'] == t1]
    d2 = df[df['Tissue'] == t2]
    # Replace every column in one assignment. Setting source.data key by key
    # leaves columns of unequal length between steps, which makes Bokeh warn,
    # and emits one change event per column.
    source.data = dict(
        seq=d1["Sequence"].to_numpy(),
        rbp=d1["RBP"].to_numpy(),
        t1_prob=d1["Probability"].to_numpy(),
        t2_prob=d2["Probability"].to_numpy(),
    )

def update(selected=None):
    """
    Refresh ``source`` via get_data for the tissues currently selected in
    ticker1 and ticker2.

    ``selected`` is unused and kept only for call compatibility.
    """
    t1, t2 = ticker1.value, ticker2.value
    print(f"Tissues: {t1} & {t2}")
    # get_data updates `source` in place and returns nothing.
    get_data(t1, t2)

# Fill the shared source for the initial tissue pair before the page is served.
update()

_doc = curdoc()
_page = row(column(p1, ticker1), column(p2, ticker2))
_doc.add_root(_page)
_doc.title = "RBP Binding Probability by Tissue"
