Skip to content

Commit

Permalink
Encapsulate Binary Classification parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
dpkatz committed Jun 12, 2018
1 parent 3facc11 commit b9c613f
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/binary_classification/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import LightGBM.Utils.Test (fileDiff)

trainParams :: [P.Param]
trainParams =
[ P.Objective P.BinaryClassification
[ P.Objective $ P.BinaryClassification []
, P.Metric [P.BinaryLogloss, P.AUC]
, P.TrainingMetric True
, P.LearningRate $$(refineTH 0.1)
Expand Down
2 changes: 1 addition & 1 deletion examples/titanic/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import ConvertData (csvFilter, predsToKaggleFormat, testFilter)

trainParams :: [P.Param]
trainParams =
[ P.Objective P.BinaryClassification
[ P.Objective $ P.BinaryClassification []
, P.Metric [P.BinaryLogloss, P.AUC]
, P.TrainingMetric True
, P.LearningRate $$(refineTH 0.1)
Expand Down
7 changes: 5 additions & 2 deletions src/LightGBM/Internal/CommandLineWrapper.hs
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,14 @@ applicationPMap =
, (P.Regression P.Quantile, "quantile")
, (P.Regression P.MAPE, "mape")
, (P.Regression P.Gamma, "gamma")
, (P.BinaryClassification, "binary")
, (P.CrossEntropy P.XEntropy, "xentropy")
, (P.CrossEntropy P.XEntropyLambda, "xentlambda")
, (P.LambdaRank, "lambdarank")
]

mkBinaryClassString :: P.BinaryClassParam -> String
mkBinaryClassString (P.IsUnbalance b) = "is_unbalance=" ++ fmap toLower (show b)

mkTweedieString :: P.TweedieRegressionParam -> String
mkTweedieString (P.TweedieVariancePower p) = "tweedie_variance_power=" ++ show p

Expand Down Expand Up @@ -114,6 +116,8 @@ colSelPrefix (P.ColName _) = "name:"

-- | Construct the option string for the command.
mkOptionString :: P.Param -> [String]
mkOptionString (P.Objective (P.BinaryClassification bcParams)) =
["application=binary"] ++ map mkBinaryClassString bcParams
mkOptionString (P.Objective (P.MultiClass P.MultiClassSimple n)) =
["application=multiclass", "num_classes=" ++ show n]
mkOptionString (P.Objective (P.MultiClass P.MultiClassOneVsAll n)) =
Expand Down Expand Up @@ -215,7 +219,6 @@ mkOptionString (P.Sigmoid d) = ["sigmoid=" ++ show (unrefine d)]
mkOptionString (P.Alpha d) = ["alpha=" ++ show (unrefine d)]
mkOptionString (P.ScalePosWeight d) = ["scale_pos_weight=" ++ show d]
mkOptionString (P.BoostFromAverage b) = ["boost_from_average=" ++ show b]
mkOptionString (P.IsUnbalance b) = ["is_unbalance=" ++ show b]
mkOptionString (P.MaxPosition n) = ["max_position=" ++ show (unrefine n)]
mkOptionString (P.LabelGain ds) =
["label_gain=" ++ intercalate "," (map show ds)]
Expand Down
9 changes: 7 additions & 2 deletions src/LightGBM/Parameters.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module LightGBM.Parameters
( -- * Parameters
Param(..)
, Application(..)
, BinaryClassParam(..)
, Booster(..)
, DARTParam(..)
, Device(..)
Expand Down Expand Up @@ -123,7 +124,6 @@ data Param
| Alpha OpenProperFraction -- ^ Used in Huber loss and Quantile regression
| ScalePosWeight Double -- ^ Used in Binary classification
| BoostFromAverage Bool -- ^ Used only in RegressionL2 task
| IsUnbalance Bool -- ^ Used in Binary classification (set to true if training data are unbalanced)
| MaxPosition PositiveInt -- ^ Used in LambdaRank
| LabelGain [Double] -- ^ Used in LambdaRank
| RegSqrt Bool -- ^ Only used in RegressionL2
Expand All @@ -132,6 +132,11 @@ data Param
| TrainingMetric Bool
deriving (Eq, Show)

data BinaryClassParam =
IsUnbalance Bool -- ^ Set to true if training data are unbalanced
deriving (Eq, Show, Generic)
instance Hashable BinaryClassParam

--- | Parameters for Fair loss regression
data FairRegressionParam =
FairC PositiveDouble
Expand Down Expand Up @@ -269,7 +274,7 @@ type NumClasses = Natural
-- | LightGBM can be used for a variety of applications
data Application
= Regression RegressionApp
| BinaryClassification
| BinaryClassification [BinaryClassParam]
| MultiClass MultiClassStyle NumClasses
| CrossEntropy XEApp
| LambdaRank -- ^ A ranking algorithm
Expand Down

0 comments on commit b9c613f

Please sign in to comment.