Skip to content

Commit

Permalink
Improve parameter safety with more refinement types
Browse files Browse the repository at this point in the history
  • Loading branch information
dpkatz committed Jun 7, 2018
1 parent 3dd1ddc commit 76014f3
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 84 deletions.
8 changes: 4 additions & 4 deletions examples/binary_classification/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ import LightGBM.Utils.Test (fileDiff)

trainParams :: [P.Param]
trainParams =
[ P.App P.Binary
[ P.Objective P.Binary
, P.Metric [P.BinaryLogloss, P.AUC]
, P.TrainingMetric True
, P.LearningRate 0.1
, P.NumLeaves 63
, P.LearningRate $$(refineTH 0.1)
, P.NumLeaves $$(refineTH 63)
, P.FeatureFraction $$(refineTH 0.8)
, P.BaggingFreq $$(refineTH 5)
, P.BaggingFraction $$(refineTH 0.8)
, P.MinDataInLeaf 50
, P.MinSumHessianInLeaf 5.0
, P.MinSumHessianInLeaf $$(refineTH 5.0)
, P.IsSparse True
]

Expand Down
6 changes: 3 additions & 3 deletions examples/lambdarank/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ import LightGBM.Utils.Test (fileDiff)

trainParams :: [P.Param]
trainParams =
[ P.App P.LambdaRank
[ P.Objective P.LambdaRank
, P.Metric [P.NDCG (Just [1, 3, 5])]
, P.TrainingMetric True
, P.LearningRate 0.1
, P.LearningRate $$(refineTH 0.1)
, P.BaggingFreq $$(refineTH 1)
, P.BaggingFraction $$(refineTH 0.9)
, P.MinDataInLeaf 50
, P.MinSumHessianInLeaf 5.0
, P.MinSumHessianInLeaf $$(refineTH 5.0)
, P.IsSparse True
]

Expand Down
4 changes: 2 additions & 2 deletions examples/multiclass_classification/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ import LightGBM.Utils.Test (fileDiff)

trainParams :: [P.Param]
trainParams =
[ P.App (P.MultiClass P.MultiClassSimple 5)
[ P.Objective (P.MultiClass P.MultiClassSimple 5)
, P.TrainingMetric True
, P.EarlyStoppingRounds $$(refineTH 10)
, P.LearningRate 0.05
, P.LearningRate $$(refineTH 0.05)
]

-- The data files for this test don't have any headers
Expand Down
6 changes: 3 additions & 3 deletions examples/regression/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ import LightGBM.Utils.Test (fileDiff)

trainParams :: [P.Param]
trainParams =
[ P.App $ P.Regression P.L2
[ P.Objective $ P.Regression P.L2
, P.TrainingMetric True
, P.LearningRate 0.05
, P.LearningRate $$(refineTH 0.05)
, P.FeatureFraction $$(refineTH 0.9)
, P.BaggingFreq $$(refineTH 5)
, P.BaggingFraction $$(refineTH 0.8)
, P.MinDataInLeaf 100
, P.MinSumHessianInLeaf 5.0
, P.MinSumHessianInLeaf $$(refineTH 5.0)
, P.IsSparse True
]

Expand Down
8 changes: 4 additions & 4 deletions examples/titanic/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ import ConvertData (csvFilter, predsToKaggleFormat, testFilter)

trainParams :: [P.Param]
trainParams =
[ P.App P.Binary
[ P.Objective P.Binary
, P.Metric [P.BinaryLogloss, P.AUC]
, P.TrainingMetric True
, P.LearningRate 0.1
, P.NumLeaves 63
, P.LearningRate $$(refineTH 0.1)
, P.NumLeaves $$(refineTH 63)
, P.FeatureFraction $$(refineTH 0.8)
, P.BaggingFreq $$(refineTH 5)
, P.BaggingFraction $$(refineTH 0.8)
, P.MinDataInLeaf 50
, P.MinSumHessianInLeaf 5.0
, P.MinSumHessianInLeaf $$(refineTH 5.0)
, P.IsSparse True
, P.LabelColumn $ P.ColName "Survived"
, P.IgnoreColumns [P.ColName "PassengerId"]
Expand Down
60 changes: 31 additions & 29 deletions src/LightGBM/Internal/CommandLineWrapper.hs
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ mkTweedieString :: P.TweedieRegressionParam -> String
mkTweedieString (P.TweedieVariancePower p) = "tweedie_variance_power=" ++ show p

mkDartString :: P.DARTParam -> String
mkDartString (P.DropRate r) = "drop_rate=" ++ show r
mkDartString (P.SkipDrop r) = "skip_drop=" ++ show r
mkDartString (P.MaxDrop r) = "max_drop=" ++ show r
mkDartString (P.DropRate r) = "drop_rate=" ++ show (unrefine r)
mkDartString (P.SkipDrop r) = "skip_drop=" ++ show (unrefine r)
mkDartString (P.MaxDrop r) = "max_drop=" ++ show (unrefine r)
mkDartString (P.UniformDrop b) = "uniform_drop=" ++ show b
mkDartString (P.XGBoostDARTMode b) = "xgboost_dart_mode=" ++ show b
mkDartString (P.DropSeed b) = "drop_seed=" ++ show b
Expand All @@ -109,13 +109,13 @@ colSelPrefix (P.ColName _) = "name:"

-- | Construct the option string for the command.
mkOptionString :: P.Param -> [String]
mkOptionString (P.App (P.MultiClass P.MultiClassSimple n)) =
mkOptionString (P.Objective (P.MultiClass P.MultiClassSimple n)) =
["application=multiclass", "num_classes=" ++ show n]
mkOptionString (P.App (P.MultiClass P.MultiClassOneVsAll n)) =
mkOptionString (P.Objective (P.MultiClass P.MultiClassOneVsAll n)) =
["application=multiclassova", "num_classes=" ++ show n]
mkOptionString (P.App (P.Regression (P.Tweedie tparams))) =
mkOptionString (P.Objective (P.Regression (P.Tweedie tparams))) =
["application=tweedie"] ++ map mkTweedieString tparams
mkOptionString (P.App a) = ["application=" ++ (applicationPMap M.! a)]
mkOptionString (P.Objective a) = ["application=" ++ (applicationPMap M.! a)]
mkOptionString (P.BoostingType (P.DART dartParams)) =
["boosting=dart"] ++ map mkDartString dartParams
mkOptionString (P.BoostingType b) = ["boosting=" ++ (boosterPMap M.! b)]
Expand All @@ -124,8 +124,8 @@ mkOptionString (P.ValidationData fs) =
["valid=" ++ intercalate "," (map show fs)]
mkOptionString (P.PredictionData f) = ["data=" ++ show f]
mkOptionString (P.Iterations n) = ["num_iterations=" ++ show n]
mkOptionString (P.LearningRate d) = ["learning_rate=" ++ show d]
mkOptionString (P.NumLeaves n) = ["num_leaves=" ++ show n]
mkOptionString (P.LearningRate d) = ["learning_rate=" ++ show (unrefine d)]
mkOptionString (P.NumLeaves n) = ["num_leaves=" ++ show (unrefine n)]
mkOptionString (P.Parallelism P.Serial) = ["tree_learner=serial"]
mkOptionString (P.Parallelism (P.FeaturePar params)) =
"tree_learner=feature" : mkParaOptions params
Expand All @@ -137,10 +137,11 @@ mkOptionString (P.NumThreads n) = ["num_threads=" ++ show n]
mkOptionString (P.Device P.CPU) = ["device=cpu"]
mkOptionString (P.Device (P.GPU gpuParams)) =
"device=gpu" : map mkGPUOption gpuParams
mkOptionString (P.RandomSeed s) = ["seed=" ++ show s]
mkOptionString (P.MaxDepth n) = ["max_depth=" ++ show n]
mkOptionString (P.MinDataInLeaf n) = ["min_data_in_leaf=" ++ show n]
mkOptionString (P.MinSumHessianInLeaf d) =
["min_sum_hessian_in_leaf=" ++ show d]
["min_sum_hessian_in_leaf=" ++ show (unrefine d)]
mkOptionString (P.FeatureFraction f) =
["feature_fraction=" ++ show (unrefine f)]
mkOptionString (P.FeatureFractionSeed s) = ["feature_fraction_seed=" ++ show s]
Expand All @@ -150,22 +151,22 @@ mkOptionString (P.BaggingFreq n) = ["bagging_freq=" ++ show (unrefine n)]
mkOptionString (P.BaggingFractionSeed n) = ["bagging_seed=" ++ show n]
mkOptionString (P.EarlyStoppingRounds r) =
["early_stopping_round=" ++ show (unrefine r)]
mkOptionString (P.Regularization_L1 d) = ["lambda_l1=" ++ show d]
mkOptionString (P.Regularization_L2 d) = ["lambda_l2=" ++ show d]
mkOptionString (P.MaxDeltaStep s) = ["max_delta_step=" ++ show s]
mkOptionString (P.MinSplitGain sg) = ["min_split_gain=" ++ show sg]
mkOptionString (P.TopRate b) = ["top_rate=" ++ show b]
mkOptionString (P.OtherRate b) = ["other_rate=" ++ show b]
mkOptionString (P.MinDataPerGroup b) = ["min_data_per_group=" ++ show b]
mkOptionString (P.MaxCatThreshold b) = ["max_cat_threshold=" ++ show b]
mkOptionString (P.CatSmooth b) = ["cat_smooth=" ++ show b]
mkOptionString (P.CatL2 b) = ["cat_l2=" ++ show b]
mkOptionString (P.MaxCatToOneHot b) = ["max_cat_to_onehot=" ++ show b]
mkOptionString (P.TopK b) = ["top_k=" ++ show b]
mkOptionString (P.Regularization_L1 d) = ["lambda_l1=" ++ show (unrefine d)]
mkOptionString (P.Regularization_L2 d) = ["lambda_l2=" ++ show (unrefine d)]
mkOptionString (P.MaxDeltaStep s) = ["max_delta_step=" ++ show (unrefine s)]
mkOptionString (P.MinSplitGain sg) = ["min_split_gain=" ++ show (unrefine sg)]
mkOptionString (P.TopRate b) = ["top_rate=" ++ show (unrefine b)]
mkOptionString (P.OtherRate b) = ["other_rate=" ++ show (unrefine b)]
mkOptionString (P.MinDataPerGroup b) = ["min_data_per_group=" ++ show (unrefine b)]
mkOptionString (P.MaxCatThreshold b) = ["max_cat_threshold=" ++ show (unrefine b)]
mkOptionString (P.CatSmooth b) = ["cat_smooth=" ++ show (unrefine b)]
mkOptionString (P.CatL2 b) = ["cat_l2=" ++ show (unrefine b)]
mkOptionString (P.MaxCatToOneHot b) = ["max_cat_to_onehot=" ++ show (unrefine b)]
mkOptionString (P.TopK b) = ["top_k=" ++ show (unrefine b)]
mkOptionString (P.MonotoneConstraint cs) =
["monotone_constraint=" ++ intercalate "," (map (directionPMap M.!) cs)]
mkOptionString (P.MaxBin n) = ["max_bin=" ++ show (unrefine n)]
mkOptionString (P.MinDataInBin n) = ["min_data_in_bin=" ++ show n]
mkOptionString (P.MinDataInBin n) = ["min_data_in_bin=" ++ show (unrefine n)]
mkOptionString (P.DataRandomSeed i) = ["data_random_seed=" ++ show i]
mkOptionString (P.OutputModel f) = ["output_model=" ++ show f]
mkOptionString (P.InputModel f) = ["input_model=" ++ show f]
Expand All @@ -189,7 +190,7 @@ mkOptionString (P.PredictRawScore b) = ["predict_raw_score=" ++ show b]
mkOptionString (P.PredictLeafIndex b) = ["predict_leaf_index=" ++ show b]
mkOptionString (P.PredictContrib b) = ["predict_contrib=" ++ show b]
mkOptionString (P.BinConstructSampleCount n) =
["bin_construct_sample_cnt=" ++ show n]
["bin_construct_sample_cnt=" ++ show (unrefine n)]
mkOptionString (P.NumIterationsPredict n) =
["num_iterations_predict=" ++ show n]
mkOptionString (P.PredEarlyStop b) = ["pred_early_stop=" ++ show b]
Expand All @@ -201,14 +202,15 @@ mkOptionString (P.InitScoreFile f) = ["init_score_file=" ++ f]
mkOptionString (P.ValidInitScoreFile f) =
["valid_init_score_file=" ++ intercalate "," f]
mkOptionString (P.ForcedSplits f) = ["forced_splits=" ++ f]
mkOptionString (P.Sigmoid d) = ["sigmoid=" ++ show d]
mkOptionString (P.Alpha d) = ["alpha=" ++ show d]
mkOptionString (P.FairC d) = ["fair_c=" ++ show d]
mkOptionString (P.PoissonMaxDeltaStep d) = ["poisson_max_delta_step=" ++ show d]
mkOptionString (P.Sigmoid d) = ["sigmoid=" ++ show (unrefine d)]
mkOptionString (P.Alpha d) = ["alpha=" ++ show (unrefine d)]
mkOptionString (P.FairC d) = ["fair_c=" ++ show (unrefine d)]
mkOptionString (P.PoissonMaxDeltaStep d) =
["poisson_max_delta_step=" ++ show (unrefine d)]
mkOptionString (P.ScalePosWeight d) = ["scale_pos_weight=" ++ show d]
mkOptionString (P.BoostFromAverage b) = ["boost_from_average=" ++ show b]
mkOptionString (P.IsUnbalance b) = ["is_unbalance=" ++ show b]
mkOptionString (P.MaxPosition n) = ["max_position=" ++ show n]
mkOptionString (P.MaxPosition n) = ["max_position=" ++ show (unrefine n)]
mkOptionString (P.LabelGain ds) =
["label_gain=" ++ intercalate "," (map show ds)]
mkOptionString (P.RegSqrt b) = ["reg_sqrt=" ++ show b]
Expand Down
84 changes: 45 additions & 39 deletions src/LightGBM/Parameters.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,48 +45,54 @@ import GHC.Generics (Generic)
import Numeric.Natural (Natural)

import LightGBM.Utils.Types
( OneToTwoLeftSemiClosed
( IntGreaterThanOne
, LeftOpenProperFraction
, NonNegativeDouble
, OneToTwoLeftSemiClosed
, OpenProperFraction
, PositiveDouble
, PositiveInt
, ProperFraction
)

-- | Parameters control the behavior of lightGBM.
data Param
= App Application -- ^ Application (regression, binary classification, etc.)
= Objective Application -- ^ Regression, binary classification, etc.
| BoostingType Booster -- ^ Booster to apply - by default is 'GBDT'
| TrainingData FilePath -- ^ Path to training data
| ValidationData [FilePath] -- ^ Paths to validation data files (supports multi-validation)
| PredictionData FilePath -- ^ Path to data to use for a prediction
| Iterations Natural -- ^ Number of boosting iterations - default is 100
| LearningRate Double
| NumLeaves Natural
| Parallelism ParallelismStyle
| NumThreads Natural
| Device Device
| MaxDepth Natural
| MinDataInLeaf Natural
| MinSumHessianInLeaf Double
| FeatureFraction ProperFraction
| FeatureFractionSeed Int
| BaggingFraction ProperFraction
| LearningRate PositiveDouble -- ^ Scale how quickly parameters change in training
| NumLeaves PositiveInt -- ^ Maximum number of leaves in one tree
| Parallelism ParallelismStyle -- ^ Called 'tree_learner' in the LightGBM docs
| NumThreads Natural -- ^ Number of threads for LightGBM to use
| Device Device -- ^ GPU or CPU
| RandomSeed Int -- ^ A random seed used to generate other random seeds
| MaxDepth Natural -- ^ Limit the depth of the tree model
| MinDataInLeaf Natural -- ^ Minimum data count in a leaf
| MinSumHessianInLeaf NonNegativeDouble -- ^ Minimal sum of the Hessian in one leaf
| BaggingFraction LeftOpenProperFraction
| BaggingFreq PositiveInt
| BaggingFractionSeed Int
| EarlyStoppingRounds PositiveInt
| Regularization_L1 Double
| Regularization_L2 Double
| MaxDeltaStep Double
| MinSplitGain Double
| TopRate Double -- ^ GOSS only
| OtherRate Double -- ^ GOSS only
| MinDataPerGroup Natural -- ^ Minimum number of data points per categorical group
| MaxCatThreshold Natural
| CatSmooth Double
| CatL2 Double -- ^ L2 regularization in categorical split
| FeatureFraction LeftOpenProperFraction
| FeatureFractionSeed Int
| EarlyStoppingRounds PositiveInt -- ^ Stop training if a validation metric doesn't improve in the last n rounds
| Regularization_L1 NonNegativeDouble
| Regularization_L2 NonNegativeDouble
| MaxDeltaStep PositiveDouble
| MinSplitGain NonNegativeDouble
| TopRate ProperFraction -- ^ GOSS only
| OtherRate ProperFraction -- ^ GOSS only
| MinDataPerGroup PositiveInt -- ^ Minimum number of data points per categorical group
| MaxCatThreshold PositiveInt
| CatSmooth NonNegativeDouble
| CatL2 NonNegativeDouble -- ^ L2 regularization in categorical split
| MaxCatToOneHot PositiveInt
| TopK Natural -- ^ VotingPar only
| TopK PositiveInt -- ^ VotingPar only
| MonotoneConstraint [Direction] -- ^ Length of directions = number of features
| MaxBin PositiveInt
| MinDataInBin Natural
| MaxBin IntGreaterThanOne
| MinDataInBin PositiveInt
| DataRandomSeed Int
| OutputModel FilePath -- ^ Where to persist the model after training
| InputModel FilePath -- ^ Filepath to a persisted model to use for prediction or additional training
Expand All @@ -104,7 +110,7 @@ data Param
| PredictRawScore Bool -- ^ Prediction Only; true = raw scores only, false = transformed scores
| PredictLeafIndex Bool -- ^ Prediction Only
| PredictContrib Bool -- ^ Prediction Only
| BinConstructSampleCount Natural
| BinConstructSampleCount PositiveInt
| NumIterationsPredict Natural -- ^ Prediction Only; how many trained predictions
| PredEarlyStop Bool
| PredEarlyStopFreq Natural
Expand All @@ -114,14 +120,14 @@ data Param
| InitScoreFile FilePath
| ValidInitScoreFile [FilePath]
| ForcedSplits FilePath
| Sigmoid Double -- ^ Used in Binary classification and LambdaRank
| Alpha Double -- ^ Used in Huber loss and Quantile regression
| FairC Double -- ^ Used in Fair loss
| PoissonMaxDeltaStep Double -- ^ Used in Poisson regression
| Sigmoid PositiveDouble -- ^ Used in Binary classification and LambdaRank
| Alpha OpenProperFraction -- ^ Used in Huber loss and Quantile regression
| FairC PositiveDouble -- ^ Used in Fair loss
| PoissonMaxDeltaStep PositiveDouble -- ^ Used in Poisson regression
| ScalePosWeight Double -- ^ Used in Binary classification
| BoostFromAverage Bool -- ^ Used only in RegressionL2 task
| IsUnbalance Bool -- ^ Used in Binary classification (set to true if training data are unbalanced)
| MaxPosition Natural -- ^ Used in LambdaRank
| MaxPosition PositiveInt -- ^ Used in LambdaRank
| LabelGain [Double] -- ^ Used in LambdaRank
| RegSqrt Bool -- ^ Only used in RegressionL2
| Metric [Metric] -- ^ Loss Metric
Expand Down Expand Up @@ -198,8 +204,8 @@ data ParallelismStyle
instance Hashable ParallelismStyle

data GPUParam
= GpuPlatformId Int
| GpuDeviceId Int
= GpuPlatformId Natural
| GpuDeviceId Natural
| GpuUseDP Bool
deriving (Eq, Show, Generic)
instance Hashable GPUParam
Expand Down Expand Up @@ -264,15 +270,15 @@ data Application
| Binary -- ^ Binary classification
| MultiClass MultiClassStyle NumClasses -- ^ Multi-class
| CrossEntropy XEApp
| LambdaRank
| LambdaRank -- ^ A ranking algo
deriving (Eq, Show, Generic)
instance Hashable Application

-- | Parameters exclusively for the DART booster
data DARTParam
= DropRate Double
| SkipDrop Double
| MaxDrop PositiveInt
= DropRate ProperFraction -- ^ Dropout rate
| SkipDrop ProperFraction -- ^ Probability of skipping a drop
| MaxDrop PositiveInt -- ^ Max number of dropped trees on one iteration
| UniformDrop Bool
| XGBoostDARTMode Bool
| DropSeed Int
Expand Down
23 changes: 23 additions & 0 deletions src/LightGBM/Utils/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ module LightGBM.Utils.Types
-- * Refined Types
OneToTwoLeftSemiClosed
, ProperFraction
, LeftOpenProperFraction
, OpenProperFraction
, PositiveInt
, IntGreaterThanOne
, PositiveDouble
, NonNegativeDouble

-- * Logging Types
, OutLog (..)
, ErrLog (..)
Expand All @@ -26,13 +32,30 @@ import qualified Refined as R
type ProperFraction
= R.Refined (R.And (R.Not (R.LessThan 0)) (R.Not (R.GreaterThan 1))) Double

-- | A 'Double' in the half-open range (0, 1]: strictly greater than 0 and
-- not greater than 1.
--
-- The inclusive upper bound is expressed as @'R.Not' ('R.GreaterThan' 1)@,
-- while the exclusive lower bound uses @'R.GreaterThan' 0@ directly.
type LeftOpenProperFraction
= R.Refined (R.And (R.GreaterThan 0) (R.Not (R.GreaterThan 1))) Double

-- | A 'Double' in the open range (0, 1): strictly greater than 0 and
-- strictly less than 1 (both endpoints are excluded).
type OpenProperFraction
= R.Refined (R.And (R.GreaterThan 0) (R.LessThan 1)) Double

-- | A 'Double' in the half-open range [1, 2): not less than 1 and strictly
-- less than 2.
--
-- The inclusive lower bound is expressed as @'R.Not' ('R.LessThan' 1)@.
type OneToTwoLeftSemiClosed
= R.Refined (R.And (R.Not (R.LessThan 1)) (R.LessThan 2)) Double

-- | An 'Int' in the range [1, @'maxBound' :: 'Int'@], i.e. an 'Int' that
-- satisfies the 'R.Positive' predicate.
type PositiveInt = R.Refined R.Positive Int

-- | An 'Int' in the range [2, @'maxBound' :: 'Int'@], i.e. an 'Int' that is
-- strictly greater than 1.
type IntGreaterThanOne = R.Refined (R.GreaterThan 1) Int

-- | A 'Double' > 0.0, i.e. a 'Double' that satisfies the 'R.Positive'
-- predicate (zero itself is excluded).
type PositiveDouble = R.Refined R.Positive Double

-- | A 'Double' >= 0.0, i.e. a 'Double' that satisfies the 'R.NonNegative'
-- predicate (zero is allowed).
type NonNegativeDouble = R.Refined R.NonNegative Double

-- | Hash a refined value by hashing the underlying value obtained with
-- 'R.unrefine'; the predicate @p@ does not contribute to the hash.
--
-- NOTE(review): neither 'Hashable' nor 'R.Refined' appears to be defined in
-- this module, so this looks like an orphan instance -- confirm that no other
-- module or dependency provides a conflicting instance.
instance (Hashable a, R.Predicate p a) => Hashable (R.Refined p a) where
hashWithSalt salt refinedA = hashWithSalt salt (R.unrefine refinedA)

Expand Down

0 comments on commit 76014f3

Please sign in to comment.