
The generalized semi-supervised elastic-net


jlaria/s2net


s2net


Overview

s2net is an R package that implements the generalized semi-supervised elastic-net.

  • The method extends the supervised elastic-net problem, which makes it a practical approach to feature selection in semi-supervised settings.
  • Its mathematical formulation is general and covers a wide range of models.
  • The implementation is flexible and fast: it is written in C++ using RcppArmadillo and integrated into R via Rcpp modules.
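For orientation, the supervised elastic-net that s2net extends minimizes a squared-error loss plus a weighted sum of an l1 (lasso) penalty and a squared l2 (ridge) penalty. The base-R sketch below only evaluates that supervised objective for a given coefficient vector. It omits the semi-supervised terms entirely, and the 1/(2n) loss scaling, the hypothetical function name `elastic_net_objective`, and the reuse of the names `lambda1`/`lambda2` (borrowed from `s2Params`) are assumptions for illustration, not the package's internal definitions.

```r
# Supervised elastic-net objective (sketch; the scaling convention is an assumption):
#   (1 / (2n)) * ||y - X beta||^2 + lambda1 * ||beta||_1 + (lambda2 / 2) * ||beta||_2^2
# The unlabeled-data terms of the semi-supervised extension are NOT included.
elastic_net_objective <- function(beta, X, y, lambda1, lambda2) {
  n <- nrow(X)
  residuals <- y - X %*% beta
  loss <- sum(residuals^2) / (2 * n)       # squared-error loss
  l1 <- lambda1 * sum(abs(beta))           # lasso penalty: encourages sparsity
  l2 <- (lambda2 / 2) * sum(beta^2)        # ridge penalty: stabilizes correlated features
  loss + l1 + l2
}

# Toy check: with beta = 0 only the loss term remains (sum(y^2) / (2n) = 2 / 6)
X <- matrix(c(1, 2, 3, 4, 5, 6), nrow = 3)
y <- c(1, 0, -1)
elastic_net_objective(c(0, 0), X, y, lambda1 = 0.1, lambda2 = 0.1)
#> [1] 0.3333333
```

Setting `lambda2 = 0` recovers the lasso and `lambda1 = 0` recovers ridge regression; the elastic-net combines both so that selection (from the l1 term) remains stable when features are correlated.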

Installation

You can install the released version of s2net from CRAN with:

install.packages("s2net")

The development version can be installed from GitHub with:

devtools::install_github("jlaria/s2net", build_vignettes = TRUE)

Example

This basic example shows how to use the package. More detailed examples can be found in the package documentation and vignettes.

library(s2net)
# The Auto-MPG dataset is included for benchmarking
data("auto_mpg")

Semi-supervised data consists of a labeled dataset xL, its labels yL, and unlabeled data xU. The s2net package includes the function s2Data to process semi-supervised datasets.

head(auto_mpg$P2$xL, 2) # labeled data
#>    displacement horsepower weight acceleration year origin
#> 15          113         91   2372         15.0   70      3
#> 19           97         84   2130         14.5   70      3
head(auto_mpg$P2$yL, 2) # labels
#> [1] 24 27
head(auto_mpg$P2$xU, 2) # unlabeled data
#>   displacement horsepower weight acceleration year origin
#> 1          307         17   3504         12.0   70      1
#> 2          350         35   3693         11.5   70      1
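To make the xL / yL / xU layout concrete, this base-R sketch carves a comparable semi-supervised split out of the built-in `mtcars` data by withholding the labels of most rows. The dataset, the choice of `mpg` as the response, the selected columns, and the size of the labeled set are illustrative assumptions unrelated to `auto_mpg`.

```r
# Toy semi-supervised split from the built-in mtcars data.
# Assumptions (illustrative only): response = mpg, 10 labeled rows.
set.seed(1)
predictors <- mtcars[, c("disp", "hp", "wt", "qsec", "cyl")]
response <- mtcars$mpg

labeled_idx <- sample(nrow(mtcars), 10)

xL <- predictors[labeled_idx, ]   # labeled features
yL <- response[labeled_idx]       # labels for the rows of xL
xU <- predictors[-labeled_idx, ]  # unlabeled features (labels withheld)
yU <- response[-labeled_idx]      # held out; only usable for validation

c(labeled = nrow(xL), unlabeled = nrow(xU))
```

The key property is that xL and xU share the same columns while only xL has labels; yU exists here only because the labels were hidden artificially, which mirrors how `auto_mpg$P2$yU` is later used for validation.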

train = s2Data(auto_mpg$P2$xL, auto_mpg$P2$yL, auto_mpg$P2$xU, preprocess = TRUE)

head(train$xL, 2)
#>    displacement horsepower     weight acceleration      year    origin2
#> 15    0.1788500  1.0544632  0.1799762   -0.6392182 -1.878311 -0.6575667
#> 19   -0.5510247  0.7397884 -0.5209896   -0.8471591 -1.878311 -0.6575667
#>     origin3
#> 15 1.356622
#> 19 1.356622

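The data is centered and scaled, factor variables are automatically converted to numerical dummies, and constant columns are removed. Validation or test data must be preprocessed with the same transformation as the training data, which is done by passing the training object as the preprocess argument. Here the unlabeled data xU, together with its held-out labels yU, serves as validation data: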

valid = s2Data(auto_mpg$P2$xU, auto_mpg$P2$yU, preprocess = train)
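The preprocessing described above can be approximated in base R. This is a sketch, not the s2Data implementation: it expands factors into treatment-coded dummies with `model.matrix` (dropping the first level, which matches the `origin2`/`origin3` columns shown earlier), drops constant columns, and centers and scales using statistics computed on the training rows only, then reuses those statistics for new rows. The helper names `fit_preprocessor` and `apply_preprocessor` are hypothetical, and computing the statistics from a single training data frame is an assumption (s2Data may also take the unlabeled data into account).

```r
# Hypothetical helpers mimicking the preprocessing steps described above.
fit_preprocessor <- function(train_df) {
  mf <- model.frame(~ ., data = train_df)
  tt <- attr(mf, "terms")
  xlev <- .getXlevels(tt, mf)                       # remember training factor levels
  mm <- model.matrix(tt, mf)[, -1, drop = FALSE]    # dummies; drop intercept column
  sds <- apply(mm, 2, sd)
  keep <- colnames(mm)[sds > 0]                     # remove constant columns
  list(terms = tt, xlev = xlev, keep = keep,
       center = colMeans(mm[, keep, drop = FALSE]), scale = sds[keep])
}

apply_preprocessor <- function(prep, new_df) {
  # Reuse the training factor levels so the dummy columns line up
  mf <- model.frame(prep$terms, new_df, xlev = prep$xlev)
  mm <- model.matrix(prep$terms, mf)[, prep$keep, drop = FALSE]
  # Center and scale with the *training* statistics, never the new data's own
  scale(mm, center = prep$center, scale = prep$scale)
}

train_df <- data.frame(
  weight = c(2372, 2130, 3504, 3693),
  origin = factor(c("3", "3", "1", "2"), levels = c("1", "2", "3")),
  const  = 1  # constant column, removed by fit_preprocessor
)
valid_df <- data.frame(
  weight = c(2800, 3100),
  origin = factor(c("2", "1"), levels = c("1", "2", "3")),
  const  = 1
)

prep <- fit_preprocessor(train_df)
x_train <- apply_preprocessor(prep, train_df)
x_valid <- apply_preprocessor(prep, valid_df)
colnames(x_train)  # "weight" "origin2" "origin3"
```

Fitting the statistics once and applying them to every new batch is what keeps validation data on the same scale as the training data; rescaling validation data with its own mean and standard deviation would silently shift the inputs the model sees.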

There are two ways to fit a semi-supervised elastic-net with s2net. The simplest is the function s2netR, which returns an object of S3 class s2netR.

model = s2netR(train, params = s2Params(lambda1 = 0.01, lambda2 = 0.01, gamma1 = 0.01, gamma2 = 100, gamma3 = 0.1))

class(model)
#> [1] "s2netR"
model$beta
#>             [,1]
#> [1,] -0.28152012
#> [2,]  0.04116177
#> [3,] -3.02848437
#> [4,]  0.61602553
#> [5,]  3.65674054
#> [6,]  0.71547766
#> [7,]  0.43169913

ypred = predict(model, valid$xL)

When fitting the semi-supervised elastic-net many times on the same training data (for example, while searching for the best hyperparameters), it is faster to use the C++ class s2net instead.

obj = new(s2net, train, 0) # 0 = linear

# We fit the model with
obj$fit(s2Params(0.01, 0.01, 0.01, 100, 0.1), 0, 2) # frame = 0 (ExtJT), proj = 2 (auto)

obj$beta
#>             [,1]
#> [1,] -0.28700933
#> [2,]  0.04228791
#> [3,] -3.02580178
#> [4,]  0.61559052
#> [5,]  3.65723926
#> [6,]  0.71451133
#> [7,]  0.43040118

ypred = obj$predict(valid$xL, 0) # 0 = default predictions
# or
ypred = predict(obj, valid$xL)
#> Warning in if (class(newX) == "s2Data") {: the condition has length > 1 and only
#> the first element will be used

Further examples

The package documentation and vignettes contain more examples. To list the available vignettes:

vignette(package="s2net")
