Skip to content

Commit

Permalink
Improve the interface of the predict function
Browse files Browse the repository at this point in the history
Have the predict function return a DataSet consisting of the predicted
output rather than nothing.
  • Loading branch information
dpkatz committed Jun 2, 2018
1 parent 1e196f8 commit a4e8b61
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 15 deletions.
4 changes: 3 additions & 1 deletion examples/binary_classification/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ main = do
LGBM.trainNewModel modelName trainParams trainingData testData 100
case model of
Left e -> print e
Right m -> LGBM.predict m testData predictionFile
Right m -> do
_ <- LGBM.predict m testData predictionFile
return ()

modelB <- fileDiff modelName "golden_model.txt"
modelP <- fileDiff predictionFile "golden_prediction.txt"
Expand Down
4 changes: 3 additions & 1 deletion examples/lambdarank/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ main = do

case model of
Left e -> print e
Right m -> LGBM.predict m testData predictionFile
Right m -> do
_ <- LGBM.predict m testData predictionFile
return ()

modelB <- fileDiff modelName "golden_model.txt"
modelP <- fileDiff predictionFile "golden_prediction.txt"
Expand Down
4 changes: 3 additions & 1 deletion examples/multiclass_classification/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ main = do
LGBM.trainNewModel modelName trainParams trainingData testData 100
case model of
Left e -> print e
Right m -> LGBM.predict m testData predictionFile
Right m -> do
_ <- LGBM.predict m testData predictionFile
return ()

modelB <- fileDiff modelName "golden_model.txt"
modelP <- fileDiff predictionFile "golden_prediction.txt"
Expand Down
4 changes: 3 additions & 1 deletion examples/regression/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ main = do
LGBM.trainNewModel modelName trainParams trainingData testData 100
case model of
Left e -> print e
Right m -> LGBM.predict m testData predictionFile
Right m -> do
_ <- LGBM.predict m testData predictionFile
return ()

modelB <- fileDiff modelName "golden_model.txt"
modelP <- fileDiff predictionFile "golden_prediction.txt"
Expand Down
7 changes: 3 additions & 4 deletions examples/titanic/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,8 @@ trainModel =
Right m -> do
print $ "Model trained and saved to file: " ++ modelName

LGBM.predict m validationData predictionFile
predictions <-
map read . lines <$> readFile predictionFile :: IO [Double]
predictionSet <- LGBM.predict m validationData predictionFile
predictions <- LGBM.dsToList predictionSet :: IO [Double]
valData <- BSL.readFile valFile
let knowns = V.toList $ readColumn 0 CSV.HasHeader valData :: [Int]
print $ "Self Accuracy: " ++ show (accuracy (round <$> predictions) knowns :: Double)
Expand All @@ -91,7 +90,7 @@ main = do
hClose testHandle
TMP.withSystemTempFile "predictions" $ \predFile predHandle -> do
hClose predHandle
LGBM.predict m (loadData testFile) predFile
_ <- LGBM.predict m (loadData testFile) predFile

withFile "TitanicSubmission.csv" WriteMode $ \submHandle -> do
testBytes <- BSL.readFile testFile
Expand Down
14 changes: 7 additions & 7 deletions src/LightGBM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ module LightGBM
loadDataFromFile
, DataSet
, HasHeader(..)
, dsToList
-- * Models
, Model
, trainNewModel
Expand Down Expand Up @@ -85,6 +86,10 @@ newtype HasHeader = HasHeader
-- Describe a data file on disk as a 'DataSet': the caller states whether
-- the file has a header and where it lives.  Nothing is read here; the
-- two arguments are simply stored in the 'DataSet' value.
loadDataFromFile :: HasHeader -> FilePath -> DataSet
loadDataFromFile header path = DataSet path header

-- | Read the file behind a 'DataSet' and parse each line into a value
-- using 'read'.
--
-- When the data set is flagged as having a header (see 'HasHeader'), the
-- first line is skipped instead of being handed to 'read', where it would
-- otherwise fail to parse as a record.
--
-- The file is read with 'readFile' and parsed with 'read', which is
-- partial: a line that does not parse as @a@ raises an error when the
-- corresponding list element is evaluated.
dsToList :: Read a => DataSet -> IO [a]
dsToList ds = map read . dropHeader . lines <$> readFile (dataPath ds)
  where
    -- A header row, when present, is not a data record and must not be
    -- parsed as one.
    dropHeader
      | getHeader (hasHeader ds) = drop 1
      | otherwise = id

-- | A model to use to make predictions
data Model = Model
{ modelPath :: FilePath
Expand Down Expand Up @@ -117,18 +122,13 @@ trainNewModel modelOutputPath trainingParams trainingData validationData numRoun
-- Wrap the path of an existing model file in a 'Model' value.  The file
-- itself is not opened or checked here.
loadModelFromFile :: FilePath -> Model
loadModelFromFile path = Model {modelPath = path}

-- FIXME:
-- - we might want to return the predictions in a better form
-- than just the file...
-- - Duplication of the exec path between predict and
-- train. Use a Reader monad maybe?
-- | Predict the results of new inputs and persist the results to an
-- output file.
predict ::
Model -- ^ A model to do prediction with
-> DataSet -- ^ The new input data for prediction
-> FilePath -- ^ Where to persist the prediction outputs
-> IO ()
-> IO DataSet -- ^ The prediction output DataSet
predict model inputData predictionOutputPath = do
let dataParams = [P.HasHeader (getHeader . hasHeader $ inputData)]
runParams =
Expand All @@ -138,4 +138,4 @@ predict model inputData predictionOutputPath = do
, P.OutputResult predictionOutputPath
]
_ <- CLW.run lightgbmExe $ concat [dataParams, runParams]
return ()
return $ DataSet predictionOutputPath (HasHeader False)

0 comments on commit a4e8b61

Please sign in to comment.