Commit e519ace

fixed gain index sorting
tobybaker committed Mar 28, 2024
1 parent dd3648a commit e519ace
Showing 1 changed file with 15 additions and 3 deletions.
gritic/posteriortablegen.py (15 additions & 3 deletions)
@@ -5,7 +5,17 @@
import numpy as np
import pandas as pd


def apply_penalty_to_table(sample_table,prior_penalty):

    # Down-weight each route's probability by exp(-prior_penalty * Average_N_Events),
    # then renormalise so probabilities sum to one per (Sample_ID, Segment_ID) pair.
    prior_table = sample_table[['Sample_ID','Segment_ID','Route','Average_N_Events','Probability']].copy().drop_duplicates()
    n_vars = len(prior_table[['Sample_ID','Segment_ID']].drop_duplicates())
    if n_vars > 1:
        prior_table['Probability'] = prior_table.groupby(['Sample_ID','Segment_ID'], group_keys=False).apply(lambda g: np.multiply(g.Probability,np.exp(-prior_penalty*g.Average_N_Events))/np.sum(np.multiply(g.Probability,np.exp(-prior_penalty*g.Average_N_Events))))
    else:
        prior_table['Probability'] = np.multiply(prior_table['Probability'].values,np.exp(-prior_penalty*prior_table['Average_N_Events'].values))/np.sum(np.multiply(prior_table['Probability'].values,np.exp(-prior_penalty*prior_table['Average_N_Events'].values)))
    sample_table = sample_table.drop(columns=['Probability']).merge(prior_table,how='inner')
    return sample_table.copy()

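The penalty acts as an exponential prior over the number of events: each route's probability is multiplied by exp(-prior_penalty * Average_N_Events) and renormalised within its (Sample_ID, Segment_ID) group, so routes that imply fewer events are favoured. A minimal standalone sketch of that reweighting on a made-up two-route segment (toy values, not GRITIC output):

    import numpy as np
    import pandas as pd

    # One segment with two candidate routes at equal prior probability.
    toy = pd.DataFrame({
        'Sample_ID': ['S1', 'S1'], 'Segment_ID': ['seg1', 'seg1'],
        'Route': ['A', 'B'], 'Average_N_Events': [1.0, 3.0],
        'Probability': [0.5, 0.5],
    })
    prior_penalty = 2.7  # the default used in this file
    w = toy['Probability'] * np.exp(-prior_penalty * toy['Average_N_Events'])
    toy['Probability'] = w / w.sum()
    print(toy[['Route', 'Probability']])
    # Route A ends up at ~0.995 because it implies two fewer events.
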
def load_route_table(path):
read_cols = ['Sample_ID','Segment_ID','Route','Average_N_Events','Average_Pre_WGD_Losses','Average_Post_WGD_Losses','Probability','Chromosome','Segment_Start','Segment_End','Major_CN','Minor_CN','WGD_Status','N_Mutations']
@@ -56,9 +66,11 @@ def load_timing_from_dict(segment_path):
    timing_dict = pickle.load(input_file)
    input_file.close()
    return timing_dict
def get_sample_posterior_table(sample_table_path,input_dir,sample_id):
def get_sample_posterior_table(sample_table_path,input_dir,sample_id,apply_penalty,prior_penalty=2.7):

    sample_table = pd.read_csv(sample_table_path,sep='\t',dtype={'Chromosome':str})
    if apply_penalty:
        sample_table = apply_penalty_to_table(sample_table,prior_penalty)
    full_segment_table = sample_table[['Segment_ID','Chromosome','Segment_Start','Segment_End','Major_CN','Minor_CN','N_Mutations']].drop_duplicates()
    node_table = sample_table[['Route','Node','Node_Phasing','Major_CN','Minor_CN','WGD_Status']].drop_duplicates()

@@ -84,12 +96,12 @@ def get_sample_posterior_table(sample_table_path,input_dir,sample_id):
    segment_frame = pd.concat(segment_frames)
    segment_frame = pd.merge(full_segment_table,segment_frame,on=['Segment_ID'],how='inner')
    segment_frame = pd.merge(segment_frame,node_table,how='inner')
    segment_frame = segment_frame.sort_values(by=['Segment_ID','Posterior_Sample_Index','Gain_Timing'])
    segment_frame['Gain_Index'] = segment_frame.groupby(['Segment_ID','Posterior_Sample_Index']).cumcount()
    segment_frame = segment_frame.sort_values(by=['Segment_ID','Posterior_Sample_Index','Gain_Index'])
    return segment_frame
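This is the sorting fix the commit message refers to: within each (Segment_ID, Posterior_Sample_Index) group, rows are ordered by Gain_Timing before cumcount() assigns Gain_Index, so index 0 always corresponds to the earliest gain in that posterior draw. A small standalone sketch of the pattern (toy data, not GRITIC output):

    import pandas as pd

    # Two posterior draws for one segment, each with two gains in scrambled order.
    df = pd.DataFrame({
        'Segment_ID': ['seg1'] * 4,
        'Posterior_Sample_Index': [0, 0, 1, 1],
        'Gain_Timing': [0.8, 0.2, 0.5, 0.9],
    })
    df = df.sort_values(by=['Segment_ID', 'Posterior_Sample_Index', 'Gain_Timing'])
    # cumcount() runs 0, 1, 2, ... within each group, in the sorted (timing) order.
    df['Gain_Index'] = df.groupby(['Segment_ID', 'Posterior_Sample_Index']).cumcount()
    print(df)
    # Gain_Index 0 now always pairs with the earliest Gain_Timing in each draw.
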

def get_segment_posterior_table_summary(segment_posterior_table):
    segment_posterior_table['Gain_Index'] = segment_posterior_table.groupby(['Segment_ID','Posterior_Sample_Index']).cumcount()
    n_samples = segment_posterior_table['Posterior_Sample_Index'].max()+1

    segment_posterior_summary = {'Gain_Index':[],'Proportion':[],'Timing_Median':[],'Timing_Low_CI':[],'Timing_High_CI':[],'Pre_WGD_Probability':[],'Post_WGD_Probability':[],'WGD_Timing_Median':[],'WGD_Timing_Low_CI':[],'WGD_Timing_High_CI':[]}
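The loop that fills this dict is truncated in the diff view, but the column names suggest per-Gain_Index summaries taken across posterior draws. A hedged sketch of how such fields are commonly computed; the helper name and the 90% interval are assumptions, not code from this commit:

    import numpy as np

    # Hypothetical helper: summarise one gain index across posterior draws.
    # 'timings' holds the Gain_Timing values observed for that index.
    def summarise_gain(timings, n_samples):
        proportion = len(timings) / n_samples  # share of draws containing this gain
        median = np.median(timings)
        low_ci, high_ci = np.quantile(timings, [0.05, 0.95])  # assumed 90% interval
        return proportion, median, low_ci, high_ci
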
