Skip to content

Commit

Permalink
Fix - return variables
Browse files Browse the repository at this point in the history
  • Loading branch information
ymarghi committed Feb 15, 2024
1 parent 9a5b570 commit c023d1e
Showing 1 changed file with 31 additions and 21 deletions.
52 changes: 31 additions & 21 deletions mmidas/utils/cluster_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,14 @@ def K_selection(data_dict, num_category, n_arm, thr=0.95):

with sns.axes_style("darkgrid"):
data_dict['num_pruned'] = np.array(data_dict['num_pruned'])
data_dict['dz'] = np.array(data_dict['dz'])
data_dict['d_qz'] = np.array(data_dict['d_qz'])
data_dict['dc'] = np.array(data_dict['dc'])
data_dict['d_qc'] = np.array(data_dict['d_qc'])
data_dict['con_min'] = np.array(data_dict['con_min'])
data_dict['con_min'] = np.reshape(data_dict['con_min'], (int(n_comb), len(data_dict['d_qz'])))
data_dict['con_min'] = np.reshape(data_dict['con_min'], (int(n_comb), len(data_dict['d_qc'])))
data_dict['con_mean'] = np.array(data_dict['con_mean'])
data_dict['con_mean'] = np.reshape(data_dict['con_mean'], (int(n_comb), len(data_dict['d_qz'])))
data_dict['con_mean'] = np.reshape(data_dict['con_mean'], (int(n_comb), len(data_dict['d_qc'])))
indx = np.argsort(data_dict['num_pruned'])
norm_aitchison_dist = data_dict['dz'] - np.min(data_dict['dz'])
norm_aitchison_dist = data_dict['dc'] - np.min(data_dict['dc'])
norm_aitchison_dist = norm_aitchison_dist / np.max(norm_aitchison_dist)
recon_loss = []
norm_recon = []
Expand All @@ -157,33 +157,43 @@ def K_selection(data_dict, num_category, n_arm, thr=0.95):
mean_cost = (neg_cons + norm_recon_mean + norm_aitchison_dist) / 3 # cplmixVAE_data['d_qz']

# suggest the number of clusters
tmp_ind = np.where(consensus[indx] > thr)[0]
ordered_cons = consensus[indx]
for tt in range(len(tmp_ind)):
i = len(tmp_ind) - tt - 1
coeff = (np.abs(ordered_cons[tmp_ind[i]] - ordered_cons[tmp_ind[i]-1]) +
np.abs(ordered_cons[tmp_ind[i]] - ordered_cons[tmp_ind[i]+1])) / 2
coeff = np.round(coeff, 3)
if coeff < 5e-3:
break
selected_idx = tmp_ind[i]

if thr > max(consensus):
print("Privded minimum consensus is too high, please provide a lower value.")
plot_flag = False
K = None
else:
plot_flag = True
ordered_rec = norm_recon_mean[indx]
ordered_cons = consensus[indx]
tmp_ind = np.where(ordered_cons > 0.95)[0]
for tt in range(len(tmp_ind)):
i = len(tmp_ind) - tt - 1
if (ordered_cons[tmp_ind[i]] > ordered_cons[tmp_ind[i]-1]) and (ordered_rec[tmp_ind[i]] < ordered_rec[tmp_ind[i]-1]):
selected_idx = tmp_ind[i]
K = data_dict['num_pruned'][indx][selected_idx]
break

fig = plt.figure(figsize=[10, 5])
ax = fig.add_subplot()
ax.plot(data_dict['num_pruned'][indx], data_dict['d_qz'][indx], label='Average Distance')
ax.plot(data_dict['num_pruned'][indx], data_dict['d_qc'][indx], label='Average Distance')
ax.plot(data_dict['num_pruned'][indx], neg_cons[indx], label='Average Dissent (1 - Consensus)')
ax.set_xlim([np.min(data_dict['num_pruned'][indx])-1, num_category + 1])
ax.set_xlabel('Categories', fontsize=14)
ax.set_xticks(data_dict['num_pruned'][indx])
ax.set_xticklabels(data_dict['num_pruned'][indx], fontsize=8, rotation=90)
ax.vlines(data_dict['num_pruned'][indx][selected_idx], 0, 1, colors='gray', linestyles='dotted')
ax.hlines(neg_cons[indx][selected_idx], min(data_dict['num_pruned']), max(data_dict['num_pruned']), colors='gray', linestyles='dotted')
if plot_flag:
ax.vlines(data_dict['num_pruned'][indx][selected_idx], 0, 1, colors='gray', linestyles='dotted')
ax.hlines(neg_cons[indx][selected_idx], min(data_dict['num_pruned']), max(data_dict['num_pruned']), colors='gray', linestyles='dotted')

ax.legend(loc='upper right')
ax.set_ylim([0, 1])
ax.grid(True)
plt.show()
print(f"Selected number of clusters: {data_dict['num_pruned'][indx][selected_idx]} with consensus {consensus[indx][selected_idx]}")
return data_dict['num_pruned'], mean_cost, consensus, indx, data_dict['num_pruned'][indx][selected_idx]

if plot_flag:
print(f"Selected number of clusters: {data_dict['num_pruned'][indx][selected_idx]} with consensus {consensus[indx][selected_idx]}")

return data_dict['num_pruned'][indx], norm_recon_mean[indx], consensus[indx], K



Expand Down

0 comments on commit c023d1e

Please sign in to comment.