Skip to content

Commit f35c3c5

Browse files
authored
Updated error-checking codes
1 parent 9c478ef commit f35c3c5

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

utils/NN_trainer.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,16 @@ def drop_static_cols(self):
100100
if df1[c].nunique() == 1:
101101
cols_to_be_dropped.append(c)
102102
df2 = df1.drop(cols_to_be_dropped, axis=1)
103-
if len(cols_to_be_dropped) > 0:
103+
if len(cols_to_be_dropped) == 0:
104+
print("Nothing to be dropped")
105+
if len(cols_to_be_dropped) == 1:
106+
print("Dropped the following column:", cols_to_be_dropped[0])
107+
if len(cols_to_be_dropped) > 1:
104108
print("Dropped the following columns:", end=" ")
105109
for i in cols_to_be_dropped[:-1]:
106110
print(i, end=", ")
107111
print("and " + cols_to_be_dropped[-1], end=".")
108-
else:
109-
print("Nothing to be dropped")
112+
110113
df2 = df1
111114
self.df = df2
112115

@@ -496,14 +499,17 @@ def rmse_test(self):
496499

497500
return round(error,3)
498501

499-
def save_model(self):
502+
def save_model(self,filename=None):
500503
"""
501504
Saves the fitted model in a h5 file
502505
"""
503506
if self.fitted_:
504507
model = self.model
505-
var = str(self.output_var)
506-
filename = var + "_model" + ".h5"
508+
if filename is not None:
509+
filename = filename
510+
else:
511+
var = str(self.output_var)
512+
filename = "model_" + var + ".h5"
507513
model.save(filename)
508514
else:
509515
print("Nothing to be saved. Model not fitted yet!")

0 commit comments

Comments
 (0)