Skip to content

Commit

Permalink
added save_path feature for plot_model()
Browse files Browse the repository at this point in the history
  • Loading branch information
bhanuteja2001 committed Aug 26, 2021
1 parent 33a65e0 commit 513d2c7
Showing 1 changed file with 69 additions and 22 deletions.
91 changes: 69 additions & 22 deletions pycaret/internal/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -6080,8 +6080,12 @@ def cluster():
plot_filename = f"{plot_name}.html"

if save:
fig.write_html(plot_filename)
logger.info(f"Saving '{plot_filename}' in current active directory")
if save == True:
fig.write_html(plot_filename)
logger.info(f"Saving '{plot_filename}' in current active directory")
else:
fig.write_html(plot_filename)
logger.info(f"Saving '{plot_filename}' in directory mentioned")

elif system:
if display_format == "streamlit":
Expand Down Expand Up @@ -6225,8 +6229,13 @@ def _tsne_anomaly():
plot_filename = f"{plot_name}.html"

if save:
fig.write_html(f"{plot_filename}")
logger.info(f"Saving '{plot_filename}' in current active directory")
if save == True:
fig.write_html(f"{plot_filename}")
logger.info(f"Saving '{plot_filename}' in current active directory")
else:
fig.write_html(plot_filename)
logger.info(f"Saving '{plot_filename}' in directory mentioned")

elif system:
if display_format == "streamlit":
st.write(fig)
Expand Down Expand Up @@ -6317,8 +6326,12 @@ def _tsne_clustering():
plot_filename = f"{plot_name}.html"

if save:
fig.write_html(f"{plot_filename}")
logger.info(f"Saving '{plot_filename}' in current active directory")
if save == True:
fig.write_html(f"{plot_filename}")
logger.info(f"Saving '{plot_filename}' in current active directory")
else:
fig.write_html(plot_filename)
logger.info(f"Saving '{plot_filename}' in directory mentioned")
elif system:
if display_format == "streamlit":
st.write(fig)
Expand Down Expand Up @@ -6386,8 +6399,12 @@ def distribution():
plot_filename = f"{plot_name}.html"

if save:
fig.write_html(f"{plot_filename}")
logger.info(f"Saving '{plot_filename}' in current active directory")
if save == True:
fig.write_html(f"{plot_filename}")
logger.info(f"Saving '{plot_filename}' in current active directory")
else:
fig.write_html(plot_filename)
logger.info(f"Saving '{plot_filename}' in directory mentioned")
elif system:
fig.show()

Expand Down Expand Up @@ -6750,8 +6767,12 @@ def lift():
y_test__, predict_proba__, figsize=(10, 6)
)
if save:
logger.info(f"Saving '{plot_name}.png' in current active directory")
plt.savefig(f"{plot_name}.png", bbox_inches="tight")
if save == True:
logger.info(f"Saving '{plot_name}.png' in current active directory")
plt.savefig(f"{plot_name}.png", bbox_inches="tight")
else:
logger.info(f"Saving '{plot_name}.png' in the path specified")
plt.savefig("{}/{}.png".format(save,plot_name))
elif system:
plt.show()
plt.close()
Expand All @@ -6775,8 +6796,12 @@ def gain():
y_test__, predict_proba__, figsize=(10, 6)
)
if save:
logger.info(f"Saving '{plot_name}.png' in current active directory")
plt.savefig(f"{plot_name}.png", bbox_inches="tight")
if save == True:
logger.info(f"Saving '{plot_name}.png' in current active directory")
plt.savefig(f"{plot_name}.png", bbox_inches="tight")
else:
logger.info(f"Saving '{plot_name}.png' in the path specified")
plt.savefig("{}/{}.png".format(save,plot_name))
elif system:
plt.show()
plt.close()
Expand Down Expand Up @@ -6918,8 +6943,12 @@ def tree():
display.move_progress()
display.clear_output()
if save:
logger.info(f"Saving '{plot_name}.png' in current active directory")
plt.savefig(f"{plot_name}.png", bbox_inches="tight")
if save == True:
logger.info(f"Saving '{plot_name}.png' in current active directory")
plt.savefig(f"{plot_name}.png", bbox_inches="tight")
else:
logger.info(f"Saving '{plot_name}.png' in the path specified")
plt.savefig("{}/{}.png".format(save,plot_name))
elif system:
plt.show()
plt.close()
Expand Down Expand Up @@ -6963,8 +6992,12 @@ def calibration():
display.move_progress()
display.clear_output()
if save:
logger.info(f"Saving '{plot_name}.png' in current active directory")
plt.savefig(f"{plot_name}.png", bbox_inches="tight")
if save == True:
logger.info(f"Saving '{plot_name}.png' in current active directory")
plt.savefig(f"{plot_name}.png", bbox_inches="tight")
else:
logger.info(f"Saving '{plot_name}.png' in the path specified")
plt.savefig("{}/{}.png".format(save,plot_name))
elif system:
plt.show()
plt.close()
Expand Down Expand Up @@ -7223,8 +7256,12 @@ def _feature(n: int):
display.move_progress()
display.clear_output()
if save:
logger.info(f"Saving '{plot_name}.png' in current active directory")
plt.savefig(f"{plot_name}.png", bbox_inches="tight")
if save == True:
logger.info(f"Saving '{plot_name}.png' in current active directory")
plt.savefig(f"{plot_name}.png", bbox_inches="tight")
else:
logger.info(f"Saving '{plot_name}.png' in the path specified")
plt.savefig("{}/{}.png".format(save,plot_name))
elif system:
plt.show()
plt.close()
Expand Down Expand Up @@ -7262,8 +7299,12 @@ def ks():
data_y, predict_proba__, figsize=(10, 6)
)
if save:
logger.info(f"Saving '{plot_name}.png' in current active directory")
plt.savefig(f"{plot_name}.png", bbox_inches="tight")
if save == True:
logger.info(f"Saving '{plot_name}.png' in current active directory")
plt.savefig(f"{plot_name}.png", bbox_inches="tight")
else:
logger.info(f"Saving '{plot_name}.png' in the path specified")
plt.savefig("{}/{}.png".format(save,plot_name))
elif system:
plt.show()
plt.close()
Expand Down Expand Up @@ -7607,7 +7648,10 @@ def summary(show: bool = True):
shap_plot = shap.summary_plot(shap_values, test_X, show=show, **kwargs)

if save:
plt.savefig(f"SHAP {plot}.png", bbox_inches="tight")
if save == True:
plt.savefig(f"SHAP {plot}.png", bbox_inches="tight")
else:
plt.savefig("{}/{}.png".format(save,plot))
return shap_plot

def correlation(show: bool = True):
Expand Down Expand Up @@ -7640,7 +7684,10 @@ def correlation(show: bool = True):
logger.info("model type detected: type 2")
shap.dependence_plot(dependence, shap_values, test_X, show=show, **kwargs)
if save:
plt.savefig(f"SHAP {plot}.png", bbox_inches="tight")
if save == True:
plt.savefig(f"SHAP {plot}.png", bbox_inches="tight")
else:
plt.savefig("{}/{}.png".format(save,plot_name))
return None

def reason(show: bool = True):
Expand Down

0 comments on commit 513d2c7

Please sign in to comment.