Update - push last changes
ymarghi committed Feb 16, 2024
1 parent aa456c8 commit d922440
Showing 8 changed files with 132 additions and 72 deletions.
Binary file modified mmidas/utils/__pycache__/analysis_cells_tree.cpython-39.pyc
Binary file modified mmidas/utils/__pycache__/cluster_analysis.cpython-39.pyc
Binary file modified mmidas/utils/__pycache__/dataloader.cpython-39.pyc
Binary file modified mmidas/utils/__pycache__/tree_based_analysis.cpython-39.pyc
103 changes: 55 additions & 48 deletions mmidas/utils/analysis_cells_tree.py
@@ -6,7 +6,9 @@
from copy import deepcopy

class Node():
'''Simple Node class. Each instance contains a list of children and parents.'''
'''
Simple Node class. Each instance contains a list of children and parents.
'''

def __init__(self,name,C_list=[],P_list=[]):
self.name=name
@@ -29,9 +31,11 @@ def children(self,C_list=[],P_list=[])
return [Node(n,C_list,P_list) for n in self.C_name_list]

def get_valid_classifications(current_node_list,C_list,P_list,valid_classes):
'''Recursively generates all possible classifications that are valid,
based on the hierarchical tree defined by `C_list` and `P_list` \n
`current_node_list` is a list of Node objects. It is initialized as a list with only the root Node.'''
'''
Recursively generates all valid classifications, based on the hierarchical
tree defined by `C_list` and `P_list`.\n
`current_node_list` is a list of Node objects, initialized as a list containing only the root Node.
'''
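# Illustrative example (hypothetical tree): with root 'n1' whose children are 'n2' and 'n3',
# and 'n2' a parent of leaves 'a' and 'b', the valid classifications are
# ['n1'], ['n2', 'n3'] and ['a', 'b', 'n3'], i.e. every horizontal cut through the hierarchy.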

current_node_list.sort(key=lambda x: x.name)
valid_classes.append(sorted([node.name for node in current_node_list]))
@@ -47,10 +51,12 @@ def get_valid_classifications(current_node_list,C_list,P_list,valid_classes):


class HTree():
'''Class to work with hierarchical tree .csv generated for the transcriptomic data.
`htree_file` is full path to a .csv. The original .csv was generated from dend.RData,
processed with `dend_functions.R` and `dend_parents.R` (Ref. Rohan/Zizhen)'''
def __init__(self,htree_df=None,htree_file=None):
'''
Class to work with the hierarchical tree .csv generated for the transcriptomic data.
`htree_file` is the full path to a .csv file. The original .csv was generated from dend.RData,
processed with `dend_functions.R` and `dend_parents.R` (Ref. Rohan/Zizhen).
'''
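# Hypothetical usage: HTree(htree_file='dend.csv') reads the table, fills missing
# 'parent' entries with 'root', zeroes the y-coordinate of leaf nodes, and exposes
# each column (child, parent, x, y, col, isleaf, ...) as a same-named attribute.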
def __init__(self, htree_df=None, htree_file=None):

#Load and rename columns from filename
if htree_file is not None:
@@ -67,6 +73,7 @@ def __init__(self,htree_df=None,htree_file=None):
htree_df['y'].values[htree_df['isleaf'].values] = 0.0
htree_df['col'].fillna('#000000',inplace=True)
htree_df['parent'].fillna('root',inplace=True)
htree_df['child'] = np.array([c.strip() for c in htree_df['child']])

#Sorting for convenience
htree_df = htree_df.sort_values(by=['y', 'x'], axis=0, ascending=[True, True]).copy(deep=True)
@@ -87,11 +94,30 @@ def df2obj(self,htree_df):
for key in htree_df.columns:
setattr(self, key, htree_df[key].values)
return

def get_marker(self, exclude=[]):

if len(exclude)==0:
subclass_list = ['L2/3', 'L4', 'L5', 'L6', 'IT', 'PT', 'NP', 'CT', 'VISp', 'ALM', 'Sst', 'Vip', 'Lamp5', 'Pvalb', 'Sncg', 'Serpinf1']

t_clusters = self.child[self.isleaf]
marker_genes = []
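# Leaf (t-type) names are assumed to be space-separated strings such as
# 'Sst Calb2 Pdlim5' (hypothetical example): every token after the subclass prefix
# that is not in subclass_list is collected as a marker gene.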
for ttype in t_clusters:
indxs = [ch for ch in range(len(ttype)) if ttype[ch].find(' ')==0]
indxs = np.array(indxs + [len(ttype)])
for i_idx in range(len(indxs)-1):
tmp_gene = ttype[indxs[i_idx]+1:indxs[i_idx+1]]
if tmp_gene not in subclass_list:
marker_genes.append(tmp_gene)

return np.unique(marker_genes)


def plot(self,figsize=(15,10),fontsize=10,skeletononly=True,
skeletoncol='#BBBBBB',skeletonalpha=1.0,ls='-',txtleafonly=True,
fig=None, ax=None, linewidth=1, save=False, path=[], n_node=0,
hline_nodes=[], n_c=[], cell_count=[0], add_marker=False):
skeletoncol='#BBBBBB',skeletonalpha=1.0, ls='-',txtleafonly=True,
fig=None, ax=None, linewidth=1, save=False, path=[], n_node=0,
marker='s', marker_size=12, hline_nodes=[], n_c=[], cell_count=[0],
add_marker=False, margin_y=0.001):
if fig is None:
fig = plt.figure(figsize=figsize)
ax = plt.gca()
@@ -120,20 +146,20 @@ def plot(self,figsize=(15,10),fontsize=10,skeletononly=True,
for i in np.flatnonzero(self.isleaf):
label = self.child[i]
plt.text(self.x[i], self.y[i], label,
color=self.col[i],
color='black',
horizontalalignment='center',
verticalalignment='top',
rotation=90,
fontsize=fontsize)

for parent in np.unique(self.parent):
#Get position of the parent node:
p_ind = np.flatnonzero(self.child==parent)
p_ind = np.flatnonzero(self.child==parent).squeeze()
if p_ind.size==0: #Enters here for any root node
p_ind = np.flatnonzero(self.parent==parent)
p_ind = np.flatnonzero(self.parent==parent).squeeze()
xp = self.x[p_ind]
yp = 1.1*np.max(self.y)
else:
else:
xp = self.x[p_ind]
yp = self.y[p_ind]

@@ -145,55 +171,36 @@
for c_ind in all_c_inds:
xc = self.x[c_ind]
yc = self.y[c_ind]
plt.plot([xc, xc], [yc, yp], color=skeletoncol,
alpha=skeletonalpha,ls=ls, linewidth=linewidth)
plt.plot([xc, xp], [yp, yp], color=skeletoncol,
alpha=skeletonalpha, ls=ls, linewidth=linewidth)
plt.plot([xc, xc], [yc, yp], color=skeletoncol, alpha=skeletonalpha, ls=ls, linewidth=linewidth)
plt.plot([xc, xp], [yp, yp], color=skeletoncol, alpha=skeletonalpha, ls=ls, linewidth=linewidth)
if skeletononly==False:
ax.axis('off')
ax.set_xlim([np.min(self.x) - a, np.max(self.x) + a])
ax.set_ylim([np.min(self.y), 1.1*np.max(self.y)])
plt.tight_layout()
sc = [3.1, 3, 2.7]
if add_marker:
print('add marker')
ax.axis('off')
ax.set_xlim([np.min(self.x) - 1, np.max(self.x) + 1])
ax.set_ylim([np.min(self.y), 1.1 * np.max(self.y)])
for i, s in enumerate(self.child):
if i < n_node:
ax.plot(self.x[i], self.y[i], 's', c=col[i], ms=12)
if self.y[i] > 0:
m_y = self.y[i] + margin_y*self.y[i]
else:
m_y = margin_y
print(i, s, self.col[i])
if isinstance(marker, list):
ax.plot(self.x[i], m_y, marker[i], color=self.col[i], ms=marker_size)
else:
ax.plot(self.x[i], m_y, marker, color=self.col[i], ms=marker_size)
plt.tight_layout()
if save:
# for ii in range(len(hline_nodes)):
# ax.axhline(y=y_node[ii][0] + 0.1*y_node[ii][0], linewidth=1.5,
# ls='--', color='black')
# ax.text(0.01, sc[ii]*(y_node[ii][0] + 0.15*y_node[ii][0]),
# 'K = ' + str(n_c[ii]),
# transform=ax.transAxes, fontsize=18,
# verticalalignment='top')
for spine in plt.gca().spines.values():
spine.set_visible(False)
plt.savefig(path + '/subtree.png', dpi=600)

# fig = plt.figure(figsize=(15, 3))
# ax = plt.gca()
# ax.bar(range(len(cell_count)), -1*cell_count,
# color=self.col)
# ax.set_ylabel('# Cells', fontsize=20)
# ax.set_xlim([-1, len(cell_count)+1])
# ax.set_xticks([])
# ax.set_yticks([])
# ax.spines['top'].set_visible(False)
# ax.spines['right'].set_visible(False)
# ax.spines['bottom'].set_visible(False)
# ax.spines['left'].set_visible(False)
# ax.text(np.argmax(cell_count)-2, -np.max(
# cell_count) - 120, np.max(
# cell_count), fontsize=16)
# # ax.text(np.argmin(cell_count)-1.5, 20,
# # np.min(cell_count), fontsize=16)
# plt.tight_layout()
# plt.savefig(path + '/cell_count.png', dpi=600)

return

def plotnodes(self,nodelist,fig=None):
31 changes: 19 additions & 12 deletions mmidas/utils/cluster_analysis.py
@@ -143,15 +143,17 @@ def K_selection(data_dict, num_category, n_arm, thr=0.95):
norm_aitchison_dist = norm_aitchison_dist / np.max(norm_aitchison_dist)
recon_loss = []
norm_recon = []
l_recon = []

for a in range(n_arm):
recon_loss.append(np.array(data_dict['recon_loss'][a]))
# print(np.min(recon_loss[a]), np.max(recon_loss[a]))
tmp = recon_loss[a] - np.min(recon_loss[a])
norm_recon.append(tmp / np.max(tmp))
# norm_recon.append(recon_loss[a])
l_recon.append(recon_loss[a])

norm_recon_mean = np.mean(norm_recon, axis=0)
l_recon_mean = np.mean(l_recon, axis=0)
neg_cons = 1 - np.mean(data_dict['con_mean'], axis=0)
consensus = np.mean(data_dict['con_mean'], axis=0)
mean_cost = (neg_cons + norm_recon_mean + norm_aitchison_dist) / 3 # cplmixVAE_data['d_qz']
@@ -163,15 +165,19 @@ def K_selection(data_dict, num_category, n_arm, thr=0.95):
K = None
else:
plot_flag = True
ordered_rec = norm_recon_mean[indx]
ordered_rec = l_recon_mean[indx]
ordered_cons = consensus[indx]
tmp_ind = np.where(ordered_cons > 0.95)[0]
for tt in range(len(tmp_ind)):
i = len(tmp_ind) - tt - 1
if (ordered_cons[tmp_ind[i]] > ordered_cons[tmp_ind[i]-1]) and (ordered_rec[tmp_ind[i]] < ordered_rec[tmp_ind[i]-1]):
selected_idx = tmp_ind[i]
K = data_dict['num_pruned'][indx][selected_idx]
break
tmp_ind = np.where(ordered_cons > thr)[0]
max_changes_indx = np.where(np.diff(ordered_cons[tmp_ind]) == max(np.diff(ordered_cons[tmp_ind])))[0][0] + 1
selected_idx = max_changes_indx
K = data_dict['num_pruned'][indx][selected_idx]
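# New selection rule: among candidates whose ordered consensus exceeds `thr`,
# pick the index right after the largest jump in consensus; the previous rule
# (consensus rising while reconstruction loss drops) is kept below for reference.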

# for tt in range(len(tmp_ind)):
# i = len(tmp_ind) - tt - 1
# if (ordered_cons[tmp_ind[i]] > ordered_cons[tmp_ind[i]-1]) and (ordered_rec[tmp_ind[i]] < ordered_rec[tmp_ind[i]-1]):
# selected_idx = tmp_ind[i]
# K = data_dict['num_pruned'][indx][selected_idx]
# break

fig = plt.figure(figsize=[10, 5])
ax = fig.add_subplot()
@@ -181,19 +187,20 @@ def K_selection(data_dict, num_category, n_arm, thr=0.95):
ax.set_xlabel('Categories', fontsize=14)
ax.set_xticks(data_dict['num_pruned'][indx])
ax.set_xticklabels(data_dict['num_pruned'][indx], fontsize=8, rotation=90)
y_max = np.max([np.max(data_dict['d_qc']), np.max(neg_cons)]) + 0.1
if plot_flag:
ax.vlines(data_dict['num_pruned'][indx][selected_idx], 0, 1, colors='gray', linestyles='dotted')
ax.vlines(data_dict['num_pruned'][indx][selected_idx], 0, y_max, colors='gray', linestyles='dotted')
ax.hlines(neg_cons[indx][selected_idx], min(data_dict['num_pruned']), max(data_dict['num_pruned']), colors='gray', linestyles='dotted')

ax.legend(loc='upper right')
ax.set_ylim([0, 1])
ax.set_ylim([0, y_max])
ax.grid(True)
plt.show()

if plot_flag:
print(f"Selected number of clusters: {data_dict['num_pruned'][indx][selected_idx]} with consensus {consensus[indx][selected_idx]}")

return data_dict['num_pruned'][indx], norm_recon_mean[indx], consensus[indx], K
return data_dict['num_pruned'][indx], l_recon_mean[indx], consensus[indx], K



2 changes: 1 addition & 1 deletion mmidas/utils/dataloader.py
@@ -108,4 +108,4 @@ def get_loaders(dataset, label=[], seed=None, batch_size=128, train_size=0.9):
all_data = TensorDataset(data_set_troch, all_ind_torch)
alldata_loader = DataLoader(all_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)

return train_loader, test_loader, alldata_loader, train_ind, test_ind
return train_loader, test_loader, alldata_loader
68 changes: 57 additions & 11 deletions mmidas/utils/tree_based_analysis.py
@@ -1,13 +1,63 @@
import matplotlib.pyplot as plt
from mmidas.utils.analysis_cells_tree import HTree, do_merges
from matplotlib.backends.backend_pdf import PdfPages
import scipy.stats as stats
from sklearn.linear_model import LinearRegression
import numpy as np
import pandas as pd
from mmidas.utils.analysis_cells_tree import HTree, do_merges

resolution = 600

def corr_analysis(state, cell):

n_gene = cell.shape[-1]
all_corr, all_geneID = [], []
for s in range(state.shape[-1]):
# compute cross-correlation using the Pearson correlation coefficient
cor_coef, p_val = np.zeros(n_gene), np.zeros(n_gene)
for g in range(n_gene):
if np.max(cell[:, g]) > 0:
zind = np.where(cell[:, g] > 0)
if len(zind[0]) > 4:
cor_coef[g], p_val[g] = \
stats.pearsonr(state[zind[0], s],
cell[zind[0], g])
else:
cor_coef[g], p_val[g] = 0, 0
else:
cor_coef[g], p_val[g] = 0, 0

g_id = np.argsort(np.abs(cor_coef))
# gene.append(dataset['gene_id'][g_id[-10:]])
# max_corr.append(cor_coef[g_id])
all_corr.append(np.sort(np.abs(cor_coef)))
all_geneID.append(g_id)
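# Genes are ranked by the absolute Pearson correlation with the current state
# dimension: `all_corr` holds the sorted |r| values (ascending, strongest last)
# and `all_geneID` the corresponding gene indices.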

# create a linear regression model
# zind = np.where(expression[:, g_id] > 0)
# x = s_val[zind[0], s]
# y = expression[zind[0], g_id]
# model = LinearRegression()
# model.fit(np.expand_dims(x, -1), np.expand_dims(y, -1))
#
# # predict y from the data
# x_new = np.linspace(np.min(x)-.2, np.max(x)+.2, 100)
# y_new = model.predict(x_new[:, np.newaxis])
#
# # plot the results
# ax = plt.axes()
# ax.scatter(x, y, alpha=0.3, s=15, c='black')
# ax.plot(x_new, y_new, c='black')
# ax.axis('scaled')
# ax.set_ylim([np.min(y)-0.2, np.max(y)+0.2])
# ax.set_xlabel('S conditioning on Z=' + str(cat+1),
# fontsize=8)
# ax.set_ylabel('Gene Expression for ' + gene[-1], fontsize=8)
# ax.set_title('corr. coef. {:.2f}'.format(max_corr[-1]))
# plt.savefig(folder + '/state_analysis/state_' + str(s) +
# '_gene_corr_K' + str(
# cat+1) + '_g_' + gene[-1] +
# '.png', dpi=resolution, bbox_inches='tight')
# plt.close('all')

return all_corr, all_geneID


def get_merged_types(htree_file, cells_labels, num_classes=0, ref_leaf=[], node='n4'):
# get the tree
Expand Down Expand Up @@ -56,14 +106,10 @@ def get_merged_types(htree_file, cells_labels, num_classes=0, ref_leaf=[], node=
#Plot updated tree:
kept_subtree = HTree(htree_df=kept_subtree_df)

kept_subtree_df['isleaf'].loc[
kept_subtree_df['child'].isin(kept_leaf_nodes)] = True
kept_subtree_df['y'].loc[
kept_subtree_df['child'].isin(kept_leaf_nodes)] = 0.0
kept_subtree_df['isleaf'].loc[kept_subtree_df['child'].isin(kept_leaf_nodes)] = True
kept_subtree_df['y'].loc[kept_subtree_df['child'].isin(kept_leaf_nodes)] = 0.0
mod_subtree = HTree(htree_df=kept_subtree_df)


mod_subtree.update_layout()


return merged_cells_labels, mod_subtree, subtree
