Skip to content

Commit

Permalink
Move getColumn out of DataSet
Browse files Browse the repository at this point in the history
It doesn't really belong there - currently the DataSet is just the
interface to feed data to LightGBM.  Nothing more.
  • Loading branch information
dpkatz committed Jun 21, 2018
1 parent f16b255 commit 8b002a9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
11 changes: 10 additions & 1 deletion examples/titanic/Main.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
-- | Titanic survivorship example

{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}

module Main where
Expand Down Expand Up @@ -52,6 +53,14 @@ accuracy predictions knowns =
totalCount = length matches
in fromIntegral matchCount / fromIntegral totalCount

-- | Read a single column out of the file behind a 'DS.DataSet'.
--
-- The contents of @dataPath@ are loaded with 'BSL.readFile' and passed to
-- 'readColumn' together with the requested column index; the resulting
-- vector is returned as a plain list.  The data set's own header flag is
-- translated to the matching @CSV.HasHeader@ / @CSV.NoHeader@ value first.
getColumn :: Read a => Int -> DS.DataSet -> IO [a]
getColumn colIndex DS.CSVFile {..} = do
  contents <- BSL.readFile dataPath
  pure . V.toList $ readColumn colIndex csvHeader contents
  where
    -- Header setting in the form 'readColumn' expects.
    csvHeader = case hasHeader of
      DS.HasHeader True -> CSV.HasHeader
      DS.HasHeader False -> CSV.NoHeader

trainModel :: IO LGBM.Model
trainModel =
TMP.withSystemTempFile "filtered_train" $ \trainFile trainHandle -> do
Expand All @@ -75,7 +84,7 @@ trainModel =
case predResults of
Left e -> error $ "Error preticting results: " ++ show e
Right predictionSet -> do
predictions <- DS.getColumn 0 predictionSet :: IO [Double]
predictions <- getColumn 0 predictionSet :: IO [Double]
LGBM.writeCsvFile predictionFile predictionSet

valData <- BSL.readFile valFile
Expand Down
16 changes: 1 addition & 15 deletions src/LightGBM/DataSet.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,10 @@ module LightGBM.DataSet (
DataSet (..)
, HasHeader(..)
, readCsvFile
, writeCsvFile
, getColumn) where
, writeCsvFile) where

import qualified Data.ByteString.Lazy as BSL
import qualified Data.Csv as CSV
import qualified Data.Vector as V
import System.Directory (copyFile)

import LightGBM.Utils.Csv (readColumn)

-- N.B. Right now it's just a data file, but we can add better types
-- (e.g. some sort of dataframe) as other options as we move forward.
-- | A set of data to use for training or prediction.
Expand Down Expand Up @@ -51,11 +45,3 @@ writeCsvFile ::
-> DataSet -- ^ The data to persist
-> IO ()
writeCsvFile outPath CSVFile {..} = copyFile dataPath outPath

-- | Load one column from the file backing a 'DataSet' and return it as a
-- list.
--
-- The file at @dataPath@ is read with 'BSL.readFile'; 'readColumn' then pulls
-- out the column at @colIndex@, using the data set's header flag converted to
-- the corresponding @CSV.HasHeader@ / @CSV.NoHeader@ value.
getColumn :: Read a => Int -> DataSet -> IO [a]
getColumn colIndex CSVFile {..} =
  toColumn <$> BSL.readFile dataPath
  where
    -- Extract the requested column from the raw file contents.
    toColumn bytes = V.toList (readColumn colIndex headerMode bytes)
    -- Header setting in the form 'readColumn' expects.
    headerMode = case hasHeader of
      HasHeader True -> CSV.HasHeader
      HasHeader False -> CSV.NoHeader

0 comments on commit 8b002a9

Please sign in to comment.