Skip to content

Commit

Permalink
Encapsulate fair regression params
Browse files Browse the repository at this point in the history
  • Loading branch information
dpkatz committed Jun 12, 2018
1 parent 750f2cc commit a51d433
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
7 changes: 5 additions & 2 deletions src/LightGBM/Internal/CommandLineWrapper.hs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ applicationPMap =
[ (P.Regression P.L1, "regression_l1")
, (P.Regression P.L2, "regression_l2")
, (P.Regression P.Huber, "huber")
, (P.Regression P.Fair, "fair")
, (P.Regression P.Poisson, "poisson")
, (P.Regression P.Quantile, "quantile")
, (P.Regression P.MAPE, "mape")
Expand All @@ -92,6 +91,9 @@ applicationPMap =
mkTweedieString :: P.TweedieRegressionParam -> String
mkTweedieString (P.TweedieVariancePower p) = "tweedie_variance_power=" ++ show p

mkFairString :: P.FairRegressionParam -> String
mkFairString (P.FairC pd) = "fair_c=" ++ show (unrefine pd)

mkDartString :: P.DARTParam -> String
mkDartString (P.DropRate r) = "drop_rate=" ++ show (unrefine r)
mkDartString (P.SkipDrop r) = "skip_drop=" ++ show (unrefine r)
Expand All @@ -116,6 +118,8 @@ mkOptionString (P.Objective (P.MultiClass P.MultiClassOneVsAll n)) =
["application=multiclassova", "num_classes=" ++ show n]
mkOptionString (P.Objective (P.Regression (P.Tweedie tparams))) =
["application=tweedie"] ++ map mkTweedieString tparams
mkOptionString (P.Objective (P.Regression (P.Fair fparams))) =
["application=fair"] ++ map mkFairString fparams
mkOptionString (P.Objective a) = ["application=" ++ (applicationPMap M.! a)]
mkOptionString (P.BoostingType (P.DART dartParams)) =
["boosting=dart"] ++ map mkDartString dartParams
Expand Down Expand Up @@ -205,7 +209,6 @@ mkOptionString (P.ValidInitScoreFile f) =
mkOptionString (P.ForcedSplits f) = ["forced_splits=" ++ f]
mkOptionString (P.Sigmoid d) = ["sigmoid=" ++ show (unrefine d)]
mkOptionString (P.Alpha d) = ["alpha=" ++ show (unrefine d)]
mkOptionString (P.FairC d) = ["fair_c=" ++ show (unrefine d)]
mkOptionString (P.PoissonMaxDeltaStep d) =
["poisson_max_delta_step=" ++ show (unrefine d)]
mkOptionString (P.ScalePosWeight d) = ["scale_pos_weight=" ++ show d]
Expand Down
10 changes: 8 additions & 2 deletions src/LightGBM/Parameters.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ module LightGBM.Parameters
, DARTParam(..)
, Device(..)
, Direction(..)
, FairRegressionParam(..)
, GOSSParam(..)
, GPUParam(..)
, LocalListenPort
Expand Down Expand Up @@ -119,7 +120,6 @@ data Param
| ForcedSplits FilePath
| Sigmoid PositiveDouble -- ^ Used in Binary classification and LambdaRank
| Alpha OpenProperFraction -- ^ Used in Huber loss and Quantile regression
| FairC PositiveDouble -- ^ Used in Fair loss
| PoissonMaxDeltaStep PositiveDouble -- ^ Used in Poisson regression
| ScalePosWeight Double -- ^ Used in Binary classification
| BoostFromAverage Bool -- ^ Used only in RegressionL2 task
Expand All @@ -132,6 +132,12 @@ data Param
| TrainingMetric Bool
deriving (Eq, Show)

-- | Parameters for Fair loss regression
data FairRegressionParam =
FairC PositiveDouble
deriving (Eq, Show, Generic)
instance Hashable FairRegressionParam

-- | Different types of Boosting approaches
data Booster
= GBDT -- ^ Gradient Boosting Decision Tree
Expand All @@ -151,7 +157,7 @@ data RegressionApp
= L1 -- ^ Absolute error metric
| L2 -- ^ RMS errror metric
| Huber
| Fair
| Fair [FairRegressionParam]
| Poisson
| Quantile
| MAPE
Expand Down

0 comments on commit a51d433

Please sign in to comment.