Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexhaoge committed Dec 27, 2020
1 parent e9f6ec4 commit 02b8b9a
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 8 deletions.
81 changes: 73 additions & 8 deletions MLSR/primary.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def grid_search_and_result(
log_dir: str,
score=None,
verbose: int = 2,
k: int = 5):
k: int = 5,
fit_params: dict = None):
"""
交叉验证网格搜索,测试集和训练集得分,混淆矩阵和ROC曲线绘制
Args:
Expand All @@ -82,6 +83,7 @@ def grid_search_and_result(
score: 评分指标,默认使用f1和acc,最后用f1 refit
verbose: 日志级别,0为静默
k: 交叉验证折数
fit_params: 训练时参数
Returns: 训练好的GridSearchCV模型
Expand All @@ -94,12 +96,16 @@ def grid_search_and_result(
}
gsCV = GridSearchCV(
estimator=pipe,
cv=k, n_jobs=-1, param_grid=grid,
cv=k, n_jobs=-1,
param_grid=grid,
scoring=scoring,
refit='f1',
verbose=verbose
)
gsCV.fit(Xtrain, ytrain)
if fit_params is None:
gsCV.fit(Xtrain, ytrain)
else:
gsCV.fit(Xtrain, ytrain, **fit_params)
dump(gsCV, log_dir + '/gsCV')
dump(gsCV.best_estimator_, log_dir + '/best_model')
file_prefix = log_dir + '/' + strftime("%Y_%m_%d_%H_%M_%S", localtime())
Expand Down Expand Up @@ -171,17 +177,29 @@ def do_random_forest(dataset: DataSet, log_dir: str = '../log', grid: dict = Non
"""
from sklearn.ensemble import RandomForestClassifier
if grid is None:
# raw grid
# grid = {
# 'rf__criterion': ['gini', 'entropy'],
# 'rf__n_estimators': [100, 300, 600, 800, 1200],
# 'rf__min_samples_split': [2, 5], # 这里数字是随机给的,无根据
# 'rf__min_samples_leaf': [1, 4], # 这里数字是随机给的,无根据
# 'rf__bootstrap': [True, False],
# 'rf__min_impurity_decrease': [0., 0.01, 0.1],
# 'rf__class_weight': ['balanced', 'balanced_subsample', None],
# 'rf__warm_start': [True, False],
# 'rf__oob_score': [True, False],
# 'rf__ccp_alpha': [0., 0.1, 0.5]
# }
# fine grid
grid = {
'rf__criterion': ['gini', 'entropy'],
'rf__n_estimators': [100, 300, 600, 800, 1200],
'rf__min_samples_split': [2, 5], # 这里数字是随机给的,无根据
'rf__n_estimators': [80, 100, 150, 200, 500],
'rf__min_samples_split': [1, 2], # 这里数字是随机给的,无根据
'rf__min_samples_leaf': [1, 4], # 这里数字是随机给的,无根据
'rf__bootstrap': [True, False],
'rf__min_impurity_decrease': [0., 0.01, 0.1],
'rf__class_weight': ['balanced', 'balanced_subsample', None],
'rf__warm_start': [True, False],
'rf__oob_score': [True, False],
'rf__ccp_alpha': [0., 0.1, 0.5]
'rf__ccp_alpha': [0., 0.1, 0.001]
}
pipe = Pipeline([
('scaler', MinMaxScaler()),
Expand Down Expand Up @@ -221,8 +239,10 @@ def do_svm(dataset: DataSet, log_dir: str = '../log', grid: dict = None):
'SVM__kernel': ['linear', 'rbf', 'poly'],
'SVM__C': [0.7, 0.8, 0.9, 0.95, 1, 1.05, 1.1, 1.2, 1.5, 2],
'SVM__degree': [2, 3, 4],
'SVM__gamma': [0.001, 'scale'],
'SVM__decision_function_shape': ['ovo', 'ovr'],
'SVM__break_ties': [True, False],
'SVM__tol': [1e-2, 1e-3, 1e-4, 1e-5]
}
pipe = Pipeline([
('scaler', MinMaxScaler()),
Expand Down Expand Up @@ -285,3 +305,48 @@ def do_naive_bayes(dataset: DataSet, log_dir: str = '../log', grid: dict = None)
])
Xtrain, Xtest, ytrain, ytest = train_test_split(dataset.features, dataset.label, train_size=0.7)
return grid_search_and_result(Xtrain, ytrain, Xtest, ytest, pipe, grid, log_dir)


def do_xgb(dataset: DataSet, log_dir: str = '../log', grid: dict = None):
"""
训练Xgboost
Args:
grid:超参数搜索空间的网格,不填则使用默认搜索空间
dataset:输入数据集,将会按照0.7, 0.3比例分为训练集和测试集
log_dir:输出结果文件的目录
Returns:返回训练好的GridSearchCV模型
"""
from xgboost import XGBClassifier
if grid is None:
grid = {
'xgb__n_estimators': [80, 100, 150, 200, 400, 500, 600, 800],
'xgb__max_depth': [6, 8, 10, 15, 20],
'xgb__colsample_bytree': [0.8, 1],
'xgb__learning_rate': [0.01, 0.1, 0.3],
# 'xgb__n_estimators': [1]
}
train_param = {
'xgb__early_stopping_rounds': 100
}
pipe = Pipeline([
('scaler', MinMaxScaler()),
('xgb', XGBClassifier(
objective='multi:softmax',
n_jobs=-1,
booster='gbtree',
verbosity = 2,
verbose = True
)
)
])
Xtrain, Xtest, ytrain, ytest = train_test_split(dataset.features, dataset.label, train_size=0.7)
gscv = grid_search_and_result(Xtrain, ytrain, Xtest, ytest, pipe, grid, log_dir)
best_model = gscv.best_estimator_
file = open(log_dir + '/feature.txt', 'a')
XGBClassifier.coef_
file.write('\nfeature importance\n')
file.write(best_model['xgb'].feature_importances_.__str__())
file.close()
return gscv
3 changes: 3 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def get_arguments():
parser.add_argument('--nb', action='store_true', help='Train naive bayes')
parser.add_argument('--svm', action='store_true', help='Train svm')
parser.add_argument('--lr', action='store_true', help='Train logistic regression')
parser.add_argument('--xgb', action='store_true', help='Train xgboost')
return parser.parse_args()


Expand All @@ -37,3 +38,5 @@ def get_arguments():
do_svm(zz, 'log/svm')
if args.lr:
do_svm(zz, 'log/lr')
if args.xgb:
do_xgb(zz, 'log/xgb')

0 comments on commit 02b8b9a

Please sign in to comment.