import pandas as pd
import numpy as np

import bson
import model as mdl
import mongo_setup as mongo_setup
from specific_seq_samplings import Sequence

from bokeh.layouts import column, row
from bokeh.models import ColumnDataSource, Select, CheckboxGroup, LinearColorMapper, CustomJS, TextAreaInput
from bokeh.palettes import Blues4
from bokeh.plotting import figure, curdoc



# SET GLOBALS

# Tissue types the model is run for; the same list feeds both tissue selectors
tissues = [
    'Adrenal Gland',
    'Brain',
    'Colon',
    'Esophagus',
    'Fallopian tube',
    'Heart',
    'Kidney',
    'Liver',
    'Lung',
    'Ovary',
    'Pancreas',
    'Prostate',
    'Salivary gland',
    'Small intestine',
    'Smooth muscle',
    'Spleen',
    'Stomach',
    'Testis',
    'Thyroid',
]

# RBP names come from the first row of the Kd table, skipping its first
# column; "UNBOUND" is added as one extra row label at the end
_kd_header = np.loadtxt(
    '/Users/olafsohr/Documents/hipergator/RBP_webpage/kd_absolute.csv',
    dtype=object,
    delimiter=',',
    max_rows=1,
)
rbps = list(_kd_header[1:]) + ["UNBOUND"]

# Sequence shown in the text box and run through the model when the app starts
default_val="GGAGTCTCACTGTGTCTCCCAGCCTGGAAAGCAATGGCCGCGATCTCAGCTCACTGCAACCTCCGCCTCCTGAGTTCAAGCGATTCTCCTGCCTCAGTCTCCTGAGTAAGTGGGATTCAGACGCCTGCCACCATGCCAGGCTAACTTTTGTATTTTTAGTAGAGACGGAGTTTCACTATGTTGGCCAGGCTGGTCTTGAACTCCTGACCTCAGAGGAAAAAAAATAATGTAACTAGAGAAGAGAGTTGAGCAGAGAACTGAGCTCTTTGAAGTGCTGGTAGTTTGTCTCAACCCATGAAGCAGCTTAACAAGCAACTATAACTAAGCATAGAGGTTGGTACTAAGAAGTGCCTTTCCTGACGTCTCTGCTGCTTGGAACCGCTTCTAGAGCAGTCTCTGCTTTTGCCTTGCTTGCTGCCAGCTAGACTGTGACGACAGCACATCCACCCTCCACCTCTAGCCCAGACACCCCCATTTCTACTTATAATCAAGAGAAAAGCTCTAAGTATCTGGCATTGCCCTAGGCTGCTTTAGTGTTAAAAGAAAAGTTTGCTGAAAAAGTAAGATATCTTCTGCCAGGAAATCAAGGAGGAAAAAAAAAATCATTTTCTCGATTTTGCTCTAAACTGCTGCATCTGTCTATGCCAAACTAATCAATACCGATTGCACCACCAAACTCCATTGCAAATTCAGCTGTGAGGAGATTCCCTTTCAGACAACTTTGCTGAAAGCAGCTTGGAAATTCGGTGTCGAAGGGTCTGCCACGTTTTCATGCTTGCATTTTGGGCTCCAAATTGGCACTGGGAAGGGGTTACTGAGAGCACAAGGCTGATACCAGGCCCTACTTTTAAACGTTCATCTACTTACAATCCTAGTATTTCTCTAAAAACCAAAACCTCTTTGAATTAACAGTTTCATGCTGTGAATTTCTAGTGGGAGATCTTTTCCTTGATATTGACGACACAATTTTCCATGTACTTTTAAAGCAGGGAGTGGGGAAAAGTATTTTGAGGGGACATTTTCATCATCAGTTCAGCTT"

# Project-level database initialisation from mongo_setup
# (presumably registers the connection used by Sequence queries — confirm)
mongo_setup.global_init()



#DEFINE FUNCTIONS

def retrieve_data(tx_name):
    """
    Create a DataFrame (DF) of sampling matrix stacked by tissue type for Tx
    Retrieve transcript sequence

    Every Sequence document whose rna_name equals tx_name contributes one
    tissue: its sampling matrix is labelled with the RBP names (rows) and
    "<position>_<base>" labels (columns), then stacked into one row per
    RBP/position pair and tagged with the document's tissue name.

    tx_name: transcript name matched against Sequence.rna_name
    RETURN dataFrame (columns RBP, Sequence, Probability, Tissue), sequence
    RAISES IndexError when no document matches tx_name
    """

    #Fetch data from mongo database
    records = list(Sequence.objects(rna_name=tx_name).all())
    if not records:
        # fail with a message naming the transcript instead of an opaque
        # "list index out of range" further down
        raise IndexError(f"no sampling data found for transcript {tx_name!r}")

    # the first match's sequence provides the column labels for every tissue
    sequence = records[0].rna_sequence
    idx = [f'{pos}_{base}' for pos, base in enumerate(sequence)]

    #add specific tissue data to a list of per-tissue frames
    frames = []
    for record in records:
        # create DF of sampling data matrix for specific tissue
        wide = pd.DataFrame(data=np.array(record.sampling_matrix),
                            index=rbps, columns=idx)
        #stack specific tissue data
        stacked = wide.stack().reset_index()
        stacked.columns = ["RBP", "Sequence", "Probability"]
        #add column of tissue names
        stacked["Tissue"] = record.tissue_name
        frames.append(stacked)

    # concatenate once: avoids re-copying the growing master DF on every
    # iteration and avoids concatenating onto an empty placeholder frame
    df = pd.concat(frames)

    #return master DF and sequence
    return df, sequence


def update_sequence(seq):
    """
    Run model.py for all tissues on sequence

    For every entry of the module-level tissues list the model output is
    labelled with the RBP names (rows) and "<position>_<base>" labels
    (columns), stacked into one row per RBP/position pair and tagged with
    the tissue name.

    seq: sequence to analyse, indexable by position (a string in this app)
    RETURN dataFrame (columns RBP, Sequence, Probability, Tissue),
           list of "<position>_<base>" labels
    """
    columns = ["RBP", "Sequence", "Probability", "Tissue"]
    idx = [f'{pos}_{base}' for pos, base in enumerate(seq)]

    frames = []
    for t in tissues:
        sites = np.ndarray.tolist(mdl.run_model(seq, t))

        wide = pd.DataFrame(data=sites, index=rbps, columns=idx)
        stacked = wide.stack().reset_index()
        stacked.columns = columns[:3]
        stacked["Tissue"] = t
        frames.append(stacked)

    if not frames:
        # no tissues configured: hand back an empty, correctly-shaped frame
        return pd.DataFrame(columns=columns), idx

    # concatenate once instead of growing the frame inside the loop, which
    # re-copied all earlier rows each time and started from an empty frame
    return pd.concat(frames), idx


def nix(val, lst):
    """
    Build a new list with every element of lst that differs from val;
    lst itself is left untouched
    """
    return list(filter(lambda item: item != val, lst))



# SET UP DATA

# Run the model once on the default sequence so there is data to plot
df, idx = update_sequence(default_val)

# One data source shared by both heat maps; it starts empty and is filled
# by get_data()
source = ColumnDataSource(
    data={"seq": [], "rbp": [], "t1_prob": [], "t2_prob": []}
)

# Colour scale for the probability columns, fixed to the 0-1 interval
mapper = LinearColorMapper(low=0, high=1, palette="Magma256")



# SET UP PLOT

def _tissue_heatmap(prob_column):
    """
    Build one RBP-by-position heat map whose cell colour is read from the
    given column of the shared data source
    """
    fig = figure(
        x_range=idx,
        y_range=list(rbps),
        width=15*len(idx),
        height=500,
        x_axis_location="below",
        toolbar_location="below",
        tools="hover,save,pan,box_zoom,reset,wheel_zoom,box_select",
    )
    # one unit-sized cell per (position, RBP) pair, coloured through the mapper
    fig.rect(
        x="seq",
        y="rbp",
        width=1,
        height=1,
        source=source,
        fill_color={'field': prob_column, 'transform': mapper},
        line_color=None,
    )
    return fig


# first plot shows the t1_prob column, second plot the t2_prob column
p1 = _tissue_heatmap("t1_prob")
p2 = _tissue_heatmap("t2_prob")



# SET UP WIDGETS

# Text box holding the RNA sequence to analyse
# NOTE(review): TextAreaInput may cap input length via max_length — confirm
# that full-length transcripts can be entered
text = TextAreaInput(value=default_val, title="RNA Sequence")

# One tissue selector per heat map; each option list leaves out the tissue
# initially chosen in the other selector
_t1_default, _t2_default = "Heart", "Smooth muscle"
t1_select = Select(value=_t1_default, options=nix(_t2_default, tissues))
t2_select = Select(value=_t2_default, options=nix(_t1_default, tissues))



# SET UP CALLBACKS

def text_change(attrname, old, new):
    """
    Re-run the model when the text box changes and redraw both heat maps

    attrname, old, new: Bokeh on_change arguments; new is the entered text
    """
    global df, idx

    seq = new.strip()
    if not seq:
        # nothing to model; keep showing the previous result
        return

    #run model for the RNA sequence
    # rebind the module-level df/idx: get_data() reads the global df, so
    # binding the result to locals would silently discard it
    df, idx = update_sequence(seq)

    # the x axis is categorical over the position labels, so both plots
    # must follow the new sequence (same sizing rule as at start-up)
    for plot in (p1, p2):
        plot.x_range.factors = idx
        plot.width = 15*len(idx)

    # push the new probabilities for the selected tissues to the plots
    update()

text.on_change("value", text_change)

def t1_change(attrname, old, new):
    """
    Callback for the first tissue selector: take the newly chosen tissue
    out of the second selector's options, then refresh the plots
    """
    remaining_for_t2 = nix(new, tissues)
    t2_select.options = remaining_for_t2
    update()

# fire the callback whenever the first selector's value changes
t1_select.on_change("value", t1_change)

def t2_change(attrname, old, new):
    """
    Callback for the second tissue selector: take the newly chosen tissue
    out of the first selector's options, then refresh the plots

    attrname, old, new: Bokeh on_change arguments; new is the chosen tissue
    """
    # nix (not the undefined name "mix") filters the chosen tissue out
    t1_select.options = nix(new, tissues)
    update()

t2_select.on_change("value", t2_change)

def get_data(t1, t2):
    """
    Load the rows of the global df for two tissues into the shared data source

    The position and RBP labels are taken from the first tissue's rows; the
    two probability columns come from the first and second tissue.

    t1: tissue name whose probabilities fill the t1_prob column
    t2: tissue name whose probabilities fill the t2_prob column
    """
    d1 = df[df['Tissue'] == t1]
    d2 = df[df['Tissue'] == t2]

    # replace the whole dict in one assignment so the columns never have
    # mismatched lengths part-way through and only one update is sent
    source.data = {
        "seq": d1["Sequence"],
        "rbp": d1["RBP"],
        "t1_prob": d1["Probability"],
        "t2_prob": d2["Probability"],
    }

def update(selected=None):
    """
    Refresh the shared data source for the two currently selected tissues

    selected: unused; kept so existing callers can still pass it
    """
    t1, t2 = t1_select.value, t2_select.value
    print(f"Tissues: {t1} & {t2}")
    # get_data() updates the data source in place and returns nothing, so
    # there is no result to keep
    get_data(t1, t2)


# Initial render: fill the data source for the default tissue pair
update()

# Page layout: text box on top, the two heat maps side by side underneath,
# each with its tissue selector below it
_left_panel = column(p1, t1_select)
_right_panel = column(p2, t2_select)
curdoc().add_root(column(text, row(_left_panel, _right_panel)))
curdoc().title = "RBP Binding Probability by Tissue"
