sgs reproducible example

Fabio Feser

2024-07-13

Introduction

The sgs R package fits sparse-group SLOPE (SGS) and group SLOPE (gSLOPE) models. The package has implementations for linear and logistic regression, both of which are demonstrated here. The package also uses strong screening rules to reduce computational time, described in detail in F. Feser, M. Evangelou (2024) “Strong Screening Rules for Group-based SLOPE Models”. Screening rules are applied by default here; their impact is demonstrated in the Screening section at the end.

Sparse-group SLOPE

Sparse-group SLOPE (SGS) is a penalised regression approach that performs bi-level selection with FDR control under orthogonal designs. SGS is described in detail in F. Feser, M. Evangelou (2023) “Sparse-group SLOPE: adaptive bi-level selection with FDR-control”.
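To make the bi-level penalty concrete, the sketch below evaluates the SGS norm as we understand its definition in the SGS paper: a convex combination, weighted by alpha, of a SLOPE penalty on the sorted absolute coefficients and a group SLOPE penalty on the sorted group norms scaled by the square root of the group size. The helper sgs_penalty is hypothetical (not a package function), and the penalty sequences v and w are arbitrary non-increasing toy values rather than the FDR-derived sequences the package computes. Setting alpha = 0 leaves only the group SLOPE part, i.e. the gSLOPE penalty.

```r
# Hypothetical helper: evaluates the SGS penalty for a coefficient vector.
# beta:   coefficient vector (no intercept)
# groups: group label for each coefficient
# v:      non-increasing variable penalty weights, one per coefficient
# w:      non-increasing group penalty weights, one per group
# alpha:  mixing parameter between the SLOPE and group SLOPE parts
sgs_penalty = function(beta, groups, v, w, alpha) {
    # SLOPE part: the largest weight is paired with the largest |beta_i|
    slope_part = sum(v * sort(abs(beta), decreasing = TRUE))
    # group SLOPE part: group l2 norms scaled by sqrt(group size)
    grp_norms = sapply(unique(groups), function(g) {
        idx = groups == g
        sqrt(sum(idx)) * sqrt(sum(beta[idx]^2))
    })
    gslope_part = sum(w * sort(grp_norms, decreasing = TRUE))
    alpha * slope_part + (1 - alpha) * gslope_part
}

# Toy example: 5 variables in 2 groups (sizes 2 and 3)
beta = c(3, 0, 0, -4, 0)
groups = c(1, 1, 2, 2, 2)
sgs_penalty(beta, groups, v = c(1, 0.8, 0.6, 0.4, 0.2), w = c(1, 0.5), alpha = 0.95)
```

Here the SLOPE part is 1 × 4 + 0.8 × 3 = 6.4 and the group part is 1 × 4√3 + 0.5 × 3√2, so the result is 0.95 × 6.4 + 0.05 × (4√3 + 1.5√2).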

Linear regression

Data

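For this example, a \(400 \times 500\) input matrix (400 observations, 500 variables) is sampled from a multivariate Gaussian distribution with no correlation. The variables follow a simple grouping structure of 100 groups: 20 groups each of sizes 3, 4, 5, 6 and 7.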

library(sgs)
groups = c(rep(1:20, each = 3), rep(21:40, each = 4), rep(41:60,
    each = 5), rep(61:80, each = 6), rep(81:100, each = 7))

data = gen_toy_data(p = 500, n = 400, groups = groups, seed_id = 3)
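The grouping structure can be checked with base R alone (no sgs functions are needed). The grouping vector is repeated here so that the snippet runs on its own:

```r
# Same grouping vector as above: 20 groups each of sizes 3, 4, 5, 6 and 7
groups = c(rep(1:20, each = 3), rep(21:40, each = 4), rep(41:60,
    each = 5), rep(61:80, each = 6), rep(81:100, each = 7))

length(groups)           # number of variables: 500
length(unique(groups))   # number of groups: 100
table(table(groups))     # number of groups of each size: 20 of each
```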

Fitting an SGS model

We now fit an SGS model to the data using linear regression. The SGS model has several hyperparameters that can be tuned or selected. Of particular importance is \(\lambda\), which controls the level of sparsity in the model. We first set it manually and then use cross-validation to tune it. The other parameters are left at their default values, although they can easily be changed.

model = fit_sgs(X = data$X, y = data$y, groups = groups, type = "linear",
    lambda = 0.5, alpha = 0.95, vFDR = 0.1, gFDR = 0.1, standardise = "l2",
    intercept = TRUE, verbose = FALSE, screen = TRUE)

Note: we have fitted an intercept and applied \(\ell_2\) standardisation. This is the recommended usage when applying SGS with linear regression. The \(\lambda\) values can also be calculated automatically, starting at the null model and continuing along a path whose length and end point are set by path_length and min_frac:

model_path = fit_sgs(X = data$X, y = data$y, groups = groups,
    type = "linear", lambda = "path", path_length = 5, min_frac = 0.1,
    alpha = 0.95, vFDR = 0.1, gFDR = 0.1, standardise = "l2",
    intercept = TRUE, verbose = FALSE, screen = TRUE)

Output of model

The package provides several useful outputs after fitting a model. The beta vector contains the fitted coefficients (note that its first entry is the intercept). We can also recover the indices of the non-zero variables and groups, which are counted from the first variable, not the intercept.

model$beta[model$selected_var + 1]  # the +1 is to account for the intercept
##  [1] -1.6734720 -1.4893849  3.1678789  1.6350236  5.0369209 -1.4619740
##  [7]  1.2779414 -0.9364065  4.9809334 -0.6548784 -1.4470782 -2.2343330
## [13]  1.6188817 -1.6665858 -0.7298015
model$group_effects[model$selected_grp]
## [1] 3.879978 5.295647 1.461974 1.277941 5.068190 1.588364 2.234333 2.435343
model$selected_var
##  [1]  97  99 100 133 136 170 217 231 234 260 263 334 391 393 394
model$selected_grp
## [1] 30 39 46 56 59 64 76 85
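The +1 offset used above arises because the first entry of beta is the intercept, while selected_var counts from the first variable. A toy illustration with a made-up coefficient vector (not package output):

```r
# Toy coefficient vector: intercept first, then 6 variable coefficients
beta = c(0.7, 0, 2.5, 0, 0, -1.2, 0)

# Indices of non-zero variables, counted from the first variable (intercept dropped)
selected_var = which(beta[-1] != 0)
selected_var               # 2 5

# Shift by one to index into beta, which still contains the intercept
beta[selected_var + 1]     # 2.5 -1.2
```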

We define a function that calculates various metrics, including the FDR and sensitivity:

fdr_sensitivity = function(fitted_ids, true_ids, num_coef) {
    # calculates the FDR, sensitivity, FPR and F1 score of a selection
    num_true = length(intersect(fitted_ids, true_ids))  # true positives
    num_false = length(fitted_ids) - num_true           # false positives
    num_missed = length(true_ids) - num_true            # false negatives
    num_negatives = num_coef - length(true_ids)         # truly null coefficients
    out = list()
    out$fdr = num_false/(num_true + num_false)
    if (is.nan(out$fdr)) {
        out$fdr = 0  # nothing selected: define the FDR as 0
    }
    out$sensitivity = num_true/length(true_ids)
    if (length(true_ids) == 0) {
        out$sensitivity = 1  # no true signals: define sensitivity as 1
    }
    out$fpr = num_false/num_negatives
    out$f1 = (2 * num_true)/(2 * num_true + num_false + num_missed)
    if (is.nan(out$f1)) {
        out$f1 = 1
    }
    return(out)
}

Calculating the relevant metrics gives

fdr_sensitivity(fitted_ids = model$selected_var, true_ids = data$true_var_id,
    num_coef = 500)
## $fdr
## [1] 0
## 
## $sensitivity
## [1] 0.5357143
## 
## $fpr
## [1] 0
## 
## $f1
## [1] 0.6976744
fdr_sensitivity(fitted_ids = model$selected_grp, true_ids = data$true_grp_id,
    num_coef = 100)
## $fdr
## [1] 0
## 
## $sensitivity
## [1] 0.8
## 
## $fpr
## [1] 0
## 
## $f1
## [1] 0.8888889
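The reported values can be reproduced from simple counts. All 15 selected variables are true signals (FDR of 0), and a sensitivity of 0.5357143 = 15/28 implies 28 true variables; at the group level, 8 correctly selected groups and a sensitivity of 0.8 imply 10 true groups. The counts below are inferred from the printed output rather than read from the data object:

```r
# Counts inferred from the printed metrics (not read from data)
var_tp = 15; var_fp = 0; var_true = 28   # 15 selected variables, all true
grp_tp = 8;  grp_fp = 0; grp_true = 10   # 8 selected groups, all true

# Sensitivity and F1 as defined in fdr_sensitivity()
var_tp / var_true                                          # 0.5357143
2 * var_tp / (2 * var_tp + var_fp + (var_true - var_tp))   # 0.6976744
grp_tp / grp_true                                          # 0.8
2 * grp_tp / (2 * grp_tp + grp_fp + (grp_true - grp_tp))   # 0.8888889
```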

At the variable level the sensitivity is only 0.54, so the model is too sparse: our choice of \(\lambda = 0.5\) is too high. We can instead use cross-validation to tune \(\lambda\).

Cross validation

Cross-validation is used to fit SGS models along a \(\lambda\) path of length \(20\). The first value, \(\lambda_\text{max}\), is chosen to give the null model and the path is terminated at \(\lambda_\text{min} = \delta \lambda_\text{max}\), where \(\delta\) is some value between \(0\) and \(1\) (given by min_frac in the function). The one-standard-error (1se) rule is used to choose the optimal model.

cv_model = fit_sgs_cv(X = data$X, y = data$y, groups = groups,
    type = "linear", path_length = 20, nfolds = 10, alpha = 0.95,
    vFDR = 0.1, gFDR = 0.1, min_frac = 0.05, standardise = "l2",
    intercept = TRUE, verbose = TRUE, screen = TRUE)
##  [1] "Fold 1/10 done. Error: 10556.1875568791"
##  [2] "Fold 1/10 done. Error: 10176.3075312482"
##  [3] "Fold 1/10 done. Error: 9594.76270583567"
##  [4] "Fold 1/10 done. Error: 8927.68806554738"
##  [5] "Fold 1/10 done. Error: 8258.45089987855"
##  [6] "Fold 1/10 done. Error: 7300.23189665579"
##  [7] "Fold 1/10 done. Error: 5632.95350477924"
##  [8] "Fold 1/10 done. Error: 4297.15577778921"
##  [9] "Fold 1/10 done. Error: 3297.94689016112"
## [10] "Fold 1/10 done. Error: 2552.59512158659"
## [11] "Fold 1/10 done. Error: 1966.64378517332"
## [12] "Fold 1/10 done. Error: 1488.63381602324"
## [13] "Fold 1/10 done. Error: 1117.25280329421"
## [14] "Fold 1/10 done. Error: 846.029195120426"
## [15] "Fold 1/10 done. Error: 648.663435142151"
## [16] "Fold 1/10 done. Error: 491.946844991723"
## [17] "Fold 1/10 done. Error: 372.100964214661"
## [18] "Fold 1/10 done. Error: 283.271603105957"
## [19] "Fold 1/10 done. Error: 220.055000961875"
## [20] "Fold 1/10 done. Error: 173.002145611579"
##  [1] "Fold 2/10 done. Error: 9701.5337813471" 
##  [2] "Fold 2/10 done. Error: 9327.9632666549" 
##  [3] "Fold 2/10 done. Error: 8515.37096726923"
##  [4] "Fold 2/10 done. Error: 7730.98407526827"
##  [5] "Fold 2/10 done. Error: 6954.61096532999"
##  [6] "Fold 2/10 done. Error: 6084.91002790858"
##  [7] "Fold 2/10 done. Error: 4640.4911961945" 
##  [8] "Fold 2/10 done. Error: 3518.28844172758"
##  [9] "Fold 2/10 done. Error: 2682.35798174033"
## [10] "Fold 2/10 done. Error: 2074.0985691893" 
## [11] "Fold 2/10 done. Error: 1600.0457464207" 
## [12] "Fold 2/10 done. Error: 1201.21077110527"
## [13] "Fold 2/10 done. Error: 907.519555401387"
## [14] "Fold 2/10 done. Error: 693.062982682486"
## [15] "Fold 2/10 done. Error: 532.020606657758"
## [16] "Fold 2/10 done. Error: 414.085312963308"
## [17] "Fold 2/10 done. Error: 325.101634094595"
## [18] "Fold 2/10 done. Error: 255.817644750413"
## [19] "Fold 2/10 done. Error: 204.450581281332"
## [20] "Fold 2/10 done. Error: 165.863886563115"
##  [1] "Fold 3/10 done. Error: 8933.04777724514"
##  [2] "Fold 3/10 done. Error: 8570.34645083006"
##  [3] "Fold 3/10 done. Error: 7787.05616993491"
##  [4] "Fold 3/10 done. Error: 6950.61817164697"
##  [5] "Fold 3/10 done. Error: 6165.5026457958" 
##  [6] "Fold 3/10 done. Error: 5012.13323372854"
##  [7] "Fold 3/10 done. Error: 3999.42239272502"
##  [8] "Fold 3/10 done. Error: 3123.6709338521" 
##  [9] "Fold 3/10 done. Error: 2465.4201557276" 
## [10] "Fold 3/10 done. Error: 1965.61776389263"
## [11] "Fold 3/10 done. Error: 1548.0319946035" 
## [12] "Fold 3/10 done. Error: 1184.69549543653"
## [13] "Fold 3/10 done. Error: 912.275358953856"
## [14] "Fold 3/10 done. Error: 709.432475321353"
## [15] "Fold 3/10 done. Error: 557.533573418488"
## [16] "Fold 3/10 done. Error: 430.036049657087"
## [17] "Fold 3/10 done. Error: 330.342085692851"
## [18] "Fold 3/10 done. Error: 253.391638672593"
## [19] "Fold 3/10 done. Error: 195.043122409481"
## [20] "Fold 3/10 done. Error: 152.408529288772"
##  [1] "Fold 4/10 done. Error: 10447.1219276337"
##  [2] "Fold 4/10 done. Error: 10160.9993436738"
##  [3] "Fold 4/10 done. Error: 9563.86072998723"
##  [4] "Fold 4/10 done. Error: 8826.81026023373"
##  [5] "Fold 4/10 done. Error: 7889.70074457788"
##  [6] "Fold 4/10 done. Error: 6741.96561308748"
##  [7] "Fold 4/10 done. Error: 5374.76551609077"
##  [8] "Fold 4/10 done. Error: 4182.3609752351" 
##  [9] "Fold 4/10 done. Error: 3264.44160220578"
## [10] "Fold 4/10 done. Error: 2573.51835514238"
## [11] "Fold 4/10 done. Error: 2039.20607228275"
## [12] "Fold 4/10 done. Error: 1592.21708242753"
## [13] "Fold 4/10 done. Error: 1198.30261062227"
## [14] "Fold 4/10 done. Error: 910.928513043909"
## [15] "Fold 4/10 done. Error: 693.396868636509"
## [16] "Fold 4/10 done. Error: 533.675122500809"
## [17] "Fold 4/10 done. Error: 416.531934070464"
## [18] "Fold 4/10 done. Error: 318.765958323353"
## [19] "Fold 4/10 done. Error: 246.79697595908" 
## [20] "Fold 4/10 done. Error: 193.907644941551"
##  [1] "Fold 5/10 done. Error: 7348.93447972298"
##  [2] "Fold 5/10 done. Error: 7084.02742456543"
##  [3] "Fold 5/10 done. Error: 6484.93419843303"
##  [4] "Fold 5/10 done. Error: 5710.58799293503"
##  [5] "Fold 5/10 done. Error: 5032.16053019211"
##  [6] "Fold 5/10 done. Error: 4093.52749333715"
##  [7] "Fold 5/10 done. Error: 3192.87606426484"
##  [8] "Fold 5/10 done. Error: 2503.55047741129"
##  [9] "Fold 5/10 done. Error: 2005.57925174851"
## [10] "Fold 5/10 done. Error: 1644.10226091624"
## [11] "Fold 5/10 done. Error: 1285.41228246945"
## [12] "Fold 5/10 done. Error: 1006.42768966073"
## [13] "Fold 5/10 done. Error: 759.301423714843"
## [14] "Fold 5/10 done. Error: 578.076488783258"
## [15] "Fold 5/10 done. Error: 444.607646100111"
## [16] "Fold 5/10 done. Error: 339.87419820333" 
## [17] "Fold 5/10 done. Error: 263.086500284417"
## [18] "Fold 5/10 done. Error: 204.145389120758"
## [19] "Fold 5/10 done. Error: 160.79435990752" 
## [20] "Fold 5/10 done. Error: 128.764425644495"
##  [1] "Fold 6/10 done. Error: 11876.2011690611"
##  [2] "Fold 6/10 done. Error: 11584.1155930641"
##  [3] "Fold 6/10 done. Error: 11041.8870591511"
##  [4] "Fold 6/10 done. Error: 9964.99368626758"
##  [5] "Fold 6/10 done. Error: 8950.04746639382"
##  [6] "Fold 6/10 done. Error: 7485.12307515559"
##  [7] "Fold 6/10 done. Error: 5870.82801499167"
##  [8] "Fold 6/10 done. Error: 4642.42850063496"
##  [9] "Fold 6/10 done. Error: 3718.24939933727"
## [10] "Fold 6/10 done. Error: 3019.80647995502"
## [11] "Fold 6/10 done. Error: 2407.43347297786"
## [12] "Fold 6/10 done. Error: 1865.49907660826"
## [13] "Fold 6/10 done. Error: 1425.59039541041"
## [14] "Fold 6/10 done. Error: 1099.87311555504"
## [15] "Fold 6/10 done. Error: 854.062276685889"
## [16] "Fold 6/10 done. Error: 666.747537258937"
## [17] "Fold 6/10 done. Error: 526.701048800189"
## [18] "Fold 6/10 done. Error: 414.174683252046"
## [19] "Fold 6/10 done. Error: 327.372999613192"
## [20] "Fold 6/10 done. Error: 262.561009883553"
##  [1] "Fold 7/10 done. Error: 10049.7155169484"
##  [2] "Fold 7/10 done. Error: 9740.24047520683"
##  [3] "Fold 7/10 done. Error: 8906.86209977837"
##  [4] "Fold 7/10 done. Error: 8044.07831496467"
##  [5] "Fold 7/10 done. Error: 7033.63538123545"
##  [6] "Fold 7/10 done. Error: 5823.40115836027"
##  [7] "Fold 7/10 done. Error: 4428.58306914515"
##  [8] "Fold 7/10 done. Error: 3327.11000137871"
##  [9] "Fold 7/10 done. Error: 2524.13716080726"
## [10] "Fold 7/10 done. Error: 1924.9606420577" 
## [11] "Fold 7/10 done. Error: 1423.63263296581"
## [12] "Fold 7/10 done. Error: 1040.4532686365" 
## [13] "Fold 7/10 done. Error: 765.319912646728"
## [14] "Fold 7/10 done. Error: 568.862001569514"
## [15] "Fold 7/10 done. Error: 424.744842531921"
## [16] "Fold 7/10 done. Error: 322.723240003498"
## [17] "Fold 7/10 done. Error: 249.844629559846"
## [18] "Fold 7/10 done. Error: 195.736838599198"
## [19] "Fold 7/10 done. Error: 156.695072359077"
## [20] "Fold 7/10 done. Error: 127.699521295209"
##  [1] "Fold 8/10 done. Error: 8398.02089705292"
##  [2] "Fold 8/10 done. Error: 8057.58572752561"
##  [3] "Fold 8/10 done. Error: 7199.67225395134"
##  [4] "Fold 8/10 done. Error: 6567.96505373682"
##  [5] "Fold 8/10 done. Error: 5986.43402477515"
##  [6] "Fold 8/10 done. Error: 5070.2112726879" 
##  [7] "Fold 8/10 done. Error: 3964.99476337378"
##  [8] "Fold 8/10 done. Error: 3092.5870482136" 
##  [9] "Fold 8/10 done. Error: 2431.62606021511"
## [10] "Fold 8/10 done. Error: 1916.1887879502" 
## [11] "Fold 8/10 done. Error: 1498.56344382699"
## [12] "Fold 8/10 done. Error: 1143.15792873192"
## [13] "Fold 8/10 done. Error: 875.22284327683" 
## [14] "Fold 8/10 done. Error: 678.831028531544"
## [15] "Fold 8/10 done. Error: 532.332734739539"
## [16] "Fold 8/10 done. Error: 409.717175846945"
## [17] "Fold 8/10 done. Error: 314.806027294298"
## [18] "Fold 8/10 done. Error: 245.302334713172"
## [19] "Fold 8/10 done. Error: 194.111722075755"
## [20] "Fold 8/10 done. Error: 156.213909406013"
##  [1] "Fold 9/10 done. Error: 9906.21445729386"
##  [2] "Fold 9/10 done. Error: 9673.5778494729" 
##  [3] "Fold 9/10 done. Error: 8706.89778059913"
##  [4] "Fold 9/10 done. Error: 7584.62067191002"
##  [5] "Fold 9/10 done. Error: 6647.70010944011"
##  [6] "Fold 9/10 done. Error: 5364.96198079038"
##  [7] "Fold 9/10 done. Error: 4073.04118845149"
##  [8] "Fold 9/10 done. Error: 3091.26225184671"
##  [9] "Fold 9/10 done. Error: 2363.48436038567"
## [10] "Fold 9/10 done. Error: 1809.92401484567"
## [11] "Fold 9/10 done. Error: 1354.2041418273" 
## [12] "Fold 9/10 done. Error: 1005.11124242728"
## [13] "Fold 9/10 done. Error: 746.035216936631"
## [14] "Fold 9/10 done. Error: 558.360925053933"
## [15] "Fold 9/10 done. Error: 420.292826392314"
## [16] "Fold 9/10 done. Error: 319.054427302379"
## [17] "Fold 9/10 done. Error: 242.538120035674"
## [18] "Fold 9/10 done. Error: 182.15600632965" 
## [19] "Fold 9/10 done. Error: 139.192579186823"
## [20] "Fold 9/10 done. Error: 107.500742305331"
##  [1] "Fold 10/10 done. Error: 8020.37927493245"
##  [2] "Fold 10/10 done. Error: 7697.88346726543"
##  [3] "Fold 10/10 done. Error: 7031.12374296387"
##  [4] "Fold 10/10 done. Error: 6294.32718409467"
##  [5] "Fold 10/10 done. Error: 5615.5007155595" 
##  [6] "Fold 10/10 done. Error: 4459.8208525311" 
##  [7] "Fold 10/10 done. Error: 3321.47838394126"
##  [8] "Fold 10/10 done. Error: 2497.98978127158"
##  [9] "Fold 10/10 done. Error: 1915.64625146079"
## [10] "Fold 10/10 done. Error: 1487.13968831616"
## [11] "Fold 10/10 done. Error: 1124.67389999549"
## [12] "Fold 10/10 done. Error: 847.217759120669"
## [13] "Fold 10/10 done. Error: 644.045437152742"
## [14] "Fold 10/10 done. Error: 497.475319279266"
## [15] "Fold 10/10 done. Error: 387.312978726709"
## [16] "Fold 10/10 done. Error: 306.915855797651"
## [17] "Fold 10/10 done. Error: 246.216858492722"
## [18] "Fold 10/10 done. Error: 196.936936931074"
## [19] "Fold 10/10 done. Error: 160.103168436104"
## [20] "Fold 10/10 done. Error: 133.417302410296"

The verbose output shows the error for each fold at each \(\lambda\) value. For a more succinct summary, we can use the print function:

print(cv_model)
## 
##  regression type:  
## 
##          lambda     error estimated_non_zero
##  [1,] 1.8885960 9523.7357                  0
##  [2,] 1.6131093 9207.3047                  1
##  [3,] 1.3778075 8483.2428                  2
##  [4,] 1.1768288 7660.2673                  3
##  [5,] 1.0051665 6853.3743                  4
##  [6,] 0.8585444 5743.6287                 14
##  [7,] 0.7333098 4449.9434                 15
##  [8,] 0.6263430 3427.6404                 15
##  [9,] 0.5349793 2666.8889                 15
## [10,] 0.4569427 2096.7952                 16
## [11,] 0.3902891 1624.7847                 22
## [12,] 0.3333582 1237.4624                 22
## [13,] 0.2847318  935.0866                 22
## [14,] 0.2431984  714.0932                 22
## [15,] 0.2077234  549.4968                 24
## [16,] 0.1774231  423.4776                 24
## [17,] 0.1515426  328.7270                 26
## [18,] 0.1294373  254.9699                 26
## [19,] 0.1105565  200.4616                 27
## [20,] 0.0944298  160.1339                 27
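The lambda column above appears to be log-linearly spaced: consecutive values differ by a constant factor, and the last value equals min_frac = 0.05 times the first, consistent with \(\lambda_k = \lambda_\text{max}\,\delta^{(k-1)/19}\) for \(k = 1, \dots, 20\). Under that assumption, the path can be reproduced from \(\lambda_\text{max}\) with the hypothetical helper lambda_path below. It only mirrors the printed values and is not the package's internal code; \(\lambda_\text{max}\) itself is computed by the package from the data.

```r
# Hypothetical helper: log-linear lambda path from lambda_max down to
# min_frac * lambda_max, with path_length values in total
lambda_path = function(lambda_max, min_frac, path_length) {
    exp(seq(log(lambda_max), log(min_frac * lambda_max), length.out = path_length))
}

# lambda_max taken from the first row of the printed CV summary
path = lambda_path(lambda_max = 1.888596, min_frac = 0.05, path_length = 20)
round(path[c(1, 2, 3, 20)], 4)   # 1.8886 1.6131 1.3778 0.0944
```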

The best model is found to be the one at the end of the path:

cv_model$best_lambda_id
## [1] 20
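For reference, a generic version of the one-standard-error rule is sketched below: among the \(\lambda\) values whose mean CV error is within one standard error of the minimum, pick the largest (sparsest) one. The printed error column appears to be the mean of the ten fold errors (for \(\lambda\) number 20, the fold errors above average to 160.13). The function one_se_rule and its standard-error formula (standard deviation across folds divided by the square root of the number of folds) are assumptions for illustration; the package's internal computation may differ in detail.

```r
# Hypothetical helper: one-standard-error rule.
# errors: matrix with one row per fold and one column per lambda,
#         columns ordered from the largest lambda to the smallest.
one_se_rule = function(errors) {
    mean_err = colMeans(errors)
    se_err = apply(errors, 2, sd) / sqrt(nrow(errors))
    best = which.min(mean_err)
    threshold = mean_err[best] + se_err[best]
    # first (i.e. largest-lambda) column whose mean error is within one SE
    which(mean_err <= threshold)[1]
}

# Fold errors for lambda 19 and 20, copied (rounded) from the verbose output above
errors = cbind(
    c(220.06, 204.45, 195.04, 246.80, 160.79, 327.37, 156.70, 194.11, 139.19, 160.10),
    c(173.00, 165.86, 152.41, 193.91, 128.76, 262.56, 127.70, 156.21, 107.50, 133.42))
colMeans(errors)      # approximately 200.46 and 160.13
one_se_rule(errors)   # 2, i.e. the smaller lambda
```

Here the threshold is about 160.13 + 13.9, so \(\lambda\) 19 (mean error 200.46) falls outside it; the remaining columns in the printed summary all have larger mean errors, which is consistent with best_lambda_id = 20.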

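Checking the metrics again, we see that CV has produced a model with a more suitable level of sparsity: the variable sensitivity rises from 0.54 to 0.93 and the group sensitivity from 0.8 to 1, while the empirical FDRs (0.037 for variables and 0.091 for groups) remain below the specified vFDR and gFDR of 0.1.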

fdr_sensitivity(fitted_ids = cv_model$fit$selected_var, true_ids = data$true_var_id,
    num_coef = 500)
## $fdr
## [1] 0.03703704
## 
## $sensitivity
## [1] 0.9285714
## 
## $fpr
## [1] 0.002118644
## 
## $f1
## [1] 0.9454545
fdr_sensitivity(fitted_ids = cv_model$fit$selected_grp, true_ids = data$true_grp_id,
    num_coef = 100)
## $fdr
## [1] 0.09090909
## 
## $sensitivity
## [1] 1
## 
## $fpr
## [1] 0.01111111
## 
## $f1
## [1] 0.952381
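These values can also be traced back to counts. The printed summary reports 27 non-zero coefficients at the chosen \(\lambda\); an FDR of 1/27 corresponds to one false discovery, 26/28 gives the sensitivity, and 1/472 the FPR (472 = 500 - 28 null variables). At the group level, 11 selected groups with one false discovery give 1/11 and 1/90. The counts are inferred from the printed output:

```r
# Counts inferred from the printed metrics of the CV model
var_sel = 27; var_fp = 1; var_true = 28; p = 500   # variables
grp_sel = 11; grp_fp = 1; grp_true = 10; m = 100   # groups

var_fdr = var_fp / var_sel                  # 0.03703704
var_sens = (var_sel - var_fp) / var_true    # 0.9285714
var_fpr = var_fp / (p - var_true)           # 0.002118644
grp_fdr = grp_fp / grp_sel                  # 0.09090909
grp_fpr = grp_fp / (m - grp_true)           # 0.01111111

# Both empirical FDRs are below the targets vFDR = gFDR = 0.1
c(var_fdr < 0.1, grp_fdr < 0.1)             # TRUE TRUE
```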

Plot

We can visualise the solution using the plot function:

plot(cv_model, how_many = 10)

Prediction

The package implements a predict function, which can be used on both regular and CV model fits.

predict(model, data$X)[1:5]
## [1]  7.5631945 -2.5428616  7.8497662  0.8790488 -4.8362971
dim(predict(cv_model, data$X))
## [1] 400  20
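The 400 × 20 result for the CV fit suggests one column of predictions per \(\lambda\) value on the path. For a linear model with an intercept, each column is the linear predictor \(\beta_0 + X\beta\) for that \(\lambda\). The sketch below reproduces this shape with a made-up coefficient matrix; it assumes coefficients on the original scale of \(X\) with the intercept in the first row, and it does not call the package's predict method.

```r
set.seed(1)
n = 6; p = 4; n_lambda = 3
X = matrix(rnorm(n * p), nrow = n)

# Made-up coefficients: row 1 is the intercept, one column per lambda value
B = matrix(c(1, 0, 0, 0, 0,
             1, 2, 0, 0, 0,
             1, 2, -1, 0, 0.5), nrow = p + 1, ncol = n_lambda)

# Linear predictor for every observation and every lambda at once
pred = cbind(1, X) %*% B
dim(pred)   # 6 3
```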

Logistic regression

As mentioned, the package can also fit SGS to a binary response. First, we generate some binary data, using the same input matrix, \(X\), and true \(\beta\) as before. We split the data into training and test sets to assess the model's classification performance.

sigmoid = function(x) {
    1/(1 + exp(-x))
}
y = ifelse(sigmoid(data$X %*% data$true_beta + rnorm(400)) >
    0.5, 1, 0)
train_y = y[1:350]
test_y = y[351:400]
train_X = data$X[1:350, ]
test_X = data$X[351:400, ]
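Because the sigmoid is strictly increasing with sigmoid(0) = 0.5, thresholding sigmoid(Xβ + ε) at 0.5 is the same as checking whether Xβ + ε > 0. The labels above are therefore generated from a latent-variable model with Gaussian noise (a probit-type mechanism) rather than by sampling from logistic probabilities. A quick numerical check on arbitrary values:

```r
sigmoid = function(x) {
    1/(1 + exp(-x))
}

set.seed(42)
z = rnorm(1000, sd = 3)   # arbitrary stand-in for X %*% beta + noise

# Mathematically, sigmoid(z) > 0.5 holds exactly when z > 0; in floating point
# the two can only disagree for positive z extremely close to zero
y_sigmoid = ifelse(sigmoid(z) > 0.5, 1, 0)
y_sign = ifelse(z > 0, 1, 0)
identical(y_sigmoid, y_sign)   # TRUE
```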

Fitting and prediction

We can again apply CV.

cv_model = fit_sgs_cv(X = train_X, y = train_y, groups = groups,
    type = "logistic", path_length = 20, nfolds = 10, alpha = 0.95,
    vFDR = 0.1, gFDR = 0.1, min_frac = 0.05, standardise = "l2",
    intercept = FALSE, verbose = TRUE, screen = TRUE)
##  [1] "Fold 1/10 done. Error: 0.457142857142857"
##  [2] "Fold 1/10 done. Error: 0.342857142857143"
##  [3] "Fold 1/10 done. Error: 0.314285714285714"
##  [4] "Fold 1/10 done. Error: 0.285714285714286"
##  [5] "Fold 1/10 done. Error: 0.285714285714286"
##  [6] "Fold 1/10 done. Error: 0.171428571428571"
##  [7] "Fold 1/10 done. Error: 0.142857142857143"
##  [8] "Fold 1/10 done. Error: 0.114285714285714"
##  [9] "Fold 1/10 done. Error: 0.114285714285714"
## [10] "Fold 1/10 done. Error: 0.114285714285714"
## [11] "Fold 1/10 done. Error: 0.142857142857143"
## [12] "Fold 1/10 done. Error: 0.142857142857143"
## [13] "Fold 1/10 done. Error: 0.171428571428571"
## [14] "Fold 1/10 done. Error: 0.171428571428571"
## [15] "Fold 1/10 done. Error: 0.2"              
## [16] "Fold 1/10 done. Error: 0.228571428571429"
## [17] "Fold 1/10 done. Error: 0.228571428571429"
## [18] "Fold 1/10 done. Error: 0.2"              
## [19] "Fold 1/10 done. Error: 0.2"              
## [20] "Fold 1/10 done. Error: 0.2"              
##  [1] "Fold 2/10 done. Error: 0.4"              
##  [2] "Fold 2/10 done. Error: 0.4"              
##  [3] "Fold 2/10 done. Error: 0.342857142857143"
##  [4] "Fold 2/10 done. Error: 0.285714285714286"
##  [5] "Fold 2/10 done. Error: 0.285714285714286"
##  [6] "Fold 2/10 done. Error: 0.314285714285714"
##  [7] "Fold 2/10 done. Error: 0.285714285714286"
##  [8] "Fold 2/10 done. Error: 0.142857142857143"
##  [9] "Fold 2/10 done. Error: 0.142857142857143"
## [10] "Fold 2/10 done. Error: 0.142857142857143"
## [11] "Fold 2/10 done. Error: 0.142857142857143"
## [12] "Fold 2/10 done. Error: 0.142857142857143"
## [13] "Fold 2/10 done. Error: 0.142857142857143"
## [14] "Fold 2/10 done. Error: 0.171428571428571"
## [15] "Fold 2/10 done. Error: 0.142857142857143"
## [16] "Fold 2/10 done. Error: 0.142857142857143"
## [17] "Fold 2/10 done. Error: 0.142857142857143"
## [18] "Fold 2/10 done. Error: 0.142857142857143"
## [19] "Fold 2/10 done. Error: 0.142857142857143"
## [20] "Fold 2/10 done. Error: 0.142857142857143"
##  [1] "Fold 3/10 done. Error: 0.514285714285714" 
##  [2] "Fold 3/10 done. Error: 0.314285714285714" 
##  [3] "Fold 3/10 done. Error: 0.314285714285714" 
##  [4] "Fold 3/10 done. Error: 0.285714285714286" 
##  [5] "Fold 3/10 done. Error: 0.314285714285714" 
##  [6] "Fold 3/10 done. Error: 0.257142857142857" 
##  [7] "Fold 3/10 done. Error: 0.228571428571429" 
##  [8] "Fold 3/10 done. Error: 0.171428571428571" 
##  [9] "Fold 3/10 done. Error: 0.142857142857143" 
## [10] "Fold 3/10 done. Error: 0.114285714285714" 
## [11] "Fold 3/10 done. Error: 0.0857142857142857"
## [12] "Fold 3/10 done. Error: 0.0857142857142857"
## [13] "Fold 3/10 done. Error: 0.0857142857142857"
## [14] "Fold 3/10 done. Error: 0.0857142857142857"
## [15] "Fold 3/10 done. Error: 0.0857142857142857"
## [16] "Fold 3/10 done. Error: 0.0857142857142857"
## [17] "Fold 3/10 done. Error: 0.0857142857142857"
## [18] "Fold 3/10 done. Error: 0.114285714285714" 
## [19] "Fold 3/10 done. Error: 0.142857142857143" 
## [20] "Fold 3/10 done. Error: 0.142857142857143" 
##  [1] "Fold 4/10 done. Error: 0.542857142857143"
##  [2] "Fold 4/10 done. Error: 0.314285714285714"
##  [3] "Fold 4/10 done. Error: 0.285714285714286"
##  [4] "Fold 4/10 done. Error: 0.342857142857143"
##  [5] "Fold 4/10 done. Error: 0.314285714285714"
##  [6] "Fold 4/10 done. Error: 0.257142857142857"
##  [7] "Fold 4/10 done. Error: 0.285714285714286"
##  [8] "Fold 4/10 done. Error: 0.257142857142857"
##  [9] "Fold 4/10 done. Error: 0.257142857142857"
## [10] "Fold 4/10 done. Error: 0.285714285714286"
## [11] "Fold 4/10 done. Error: 0.257142857142857"
## [12] "Fold 4/10 done. Error: 0.257142857142857"
## [13] "Fold 4/10 done. Error: 0.257142857142857"
## [14] "Fold 4/10 done. Error: 0.257142857142857"
## [15] "Fold 4/10 done. Error: 0.228571428571429"
## [16] "Fold 4/10 done. Error: 0.228571428571429"
## [17] "Fold 4/10 done. Error: 0.228571428571429"
## [18] "Fold 4/10 done. Error: 0.228571428571429"
## [19] "Fold 4/10 done. Error: 0.2"              
## [20] "Fold 4/10 done. Error: 0.2"              
##  [1] "Fold 5/10 done. Error: 0.485714285714286" 
##  [2] "Fold 5/10 done. Error: 0.314285714285714" 
##  [3] "Fold 5/10 done. Error: 0.314285714285714" 
##  [4] "Fold 5/10 done. Error: 0.314285714285714" 
##  [5] "Fold 5/10 done. Error: 0.228571428571429" 
##  [6] "Fold 5/10 done. Error: 0.171428571428571" 
##  [7] "Fold 5/10 done. Error: 0.2"               
##  [8] "Fold 5/10 done. Error: 0.171428571428571" 
##  [9] "Fold 5/10 done. Error: 0.0857142857142857"
## [10] "Fold 5/10 done. Error: 0.114285714285714" 
## [11] "Fold 5/10 done. Error: 0.114285714285714" 
## [12] "Fold 5/10 done. Error: 0.142857142857143" 
## [13] "Fold 5/10 done. Error: 0.114285714285714" 
## [14] "Fold 5/10 done. Error: 0.142857142857143" 
## [15] "Fold 5/10 done. Error: 0.142857142857143" 
## [16] "Fold 5/10 done. Error: 0.142857142857143" 
## [17] "Fold 5/10 done. Error: 0.142857142857143" 
## [18] "Fold 5/10 done. Error: 0.142857142857143" 
## [19] "Fold 5/10 done. Error: 0.142857142857143" 
## [20] "Fold 5/10 done. Error: 0.142857142857143" 
##  [1] "Fold 6/10 done. Error: 0.4"               
##  [2] "Fold 6/10 done. Error: 0.4"               
##  [3] "Fold 6/10 done. Error: 0.314285714285714" 
##  [4] "Fold 6/10 done. Error: 0.257142857142857" 
##  [5] "Fold 6/10 done. Error: 0.228571428571429" 
##  [6] "Fold 6/10 done. Error: 0.171428571428571" 
##  [7] "Fold 6/10 done. Error: 0.142857142857143" 
##  [8] "Fold 6/10 done. Error: 0.142857142857143" 
##  [9] "Fold 6/10 done. Error: 0.114285714285714" 
## [10] "Fold 6/10 done. Error: 0.0857142857142857"
## [11] "Fold 6/10 done. Error: 0.0857142857142857"
## [12] "Fold 6/10 done. Error: 0.0857142857142857"
## [13] "Fold 6/10 done. Error: 0.0857142857142857"
## [14] "Fold 6/10 done. Error: 0.0857142857142857"
## [15] "Fold 6/10 done. Error: 0.0857142857142857"
## [16] "Fold 6/10 done. Error: 0.0857142857142857"
## [17] "Fold 6/10 done. Error: 0.0857142857142857"
## [18] "Fold 6/10 done. Error: 0.114285714285714" 
## [19] "Fold 6/10 done. Error: 0.114285714285714" 
## [20] "Fold 6/10 done. Error: 0.114285714285714" 
##  [1] "Fold 7/10 done. Error: 0.428571428571429" 
##  [2] "Fold 7/10 done. Error: 0.285714285714286" 
##  [3] "Fold 7/10 done. Error: 0.285714285714286" 
##  [4] "Fold 7/10 done. Error: 0.314285714285714" 
##  [5] "Fold 7/10 done. Error: 0.228571428571429" 
##  [6] "Fold 7/10 done. Error: 0.2"               
##  [7] "Fold 7/10 done. Error: 0.2"               
##  [8] "Fold 7/10 done. Error: 0.171428571428571" 
##  [9] "Fold 7/10 done. Error: 0.0571428571428572"
## [10] "Fold 7/10 done. Error: 0.0571428571428572"
## [11] "Fold 7/10 done. Error: 0.0857142857142857"
## [12] "Fold 7/10 done. Error: 0.0857142857142857"
## [13] "Fold 7/10 done. Error: 0.0857142857142857"
## [14] "Fold 7/10 done. Error: 0.114285714285714" 
## [15] "Fold 7/10 done. Error: 0.0857142857142857"
## [16] "Fold 7/10 done. Error: 0.0857142857142857"
## [17] "Fold 7/10 done. Error: 0.0857142857142857"
## [18] "Fold 7/10 done. Error: 0.0857142857142857"
## [19] "Fold 7/10 done. Error: 0.114285714285714" 
## [20] "Fold 7/10 done. Error: 0.0857142857142857"
##  [1] "Fold 8/10 done. Error: 0.457142857142857" 
##  [2] "Fold 8/10 done. Error: 0.314285714285714" 
##  [3] "Fold 8/10 done. Error: 0.371428571428571" 
##  [4] "Fold 8/10 done. Error: 0.371428571428571" 
##  [5] "Fold 8/10 done. Error: 0.342857142857143" 
##  [6] "Fold 8/10 done. Error: 0.257142857142857" 
##  [7] "Fold 8/10 done. Error: 0.2"               
##  [8] "Fold 8/10 done. Error: 0.171428571428571" 
##  [9] "Fold 8/10 done. Error: 0.142857142857143" 
## [10] "Fold 8/10 done. Error: 0.142857142857143" 
## [11] "Fold 8/10 done. Error: 0.142857142857143" 
## [12] "Fold 8/10 done. Error: 0.142857142857143" 
## [13] "Fold 8/10 done. Error: 0.114285714285714" 
## [14] "Fold 8/10 done. Error: 0.0857142857142857"
## [15] "Fold 8/10 done. Error: 0.0857142857142857"
## [16] "Fold 8/10 done. Error: 0.0857142857142857"
## [17] "Fold 8/10 done. Error: 0.0857142857142857"
## [18] "Fold 8/10 done. Error: 0.0857142857142857"
## [19] "Fold 8/10 done. Error: 0.114285714285714" 
## [20] "Fold 8/10 done. Error: 0.114285714285714" 
##  [1] "Fold 9/10 done. Error: 0.4"              
##  [2] "Fold 9/10 done. Error: 0.4"              
##  [3] "Fold 9/10 done. Error: 0.371428571428571"
##  [4] "Fold 9/10 done. Error: 0.342857142857143"
##  [5] "Fold 9/10 done. Error: 0.314285714285714"
##  [6] "Fold 9/10 done. Error: 0.2"              
##  [7] "Fold 9/10 done. Error: 0.228571428571429"
##  [8] "Fold 9/10 done. Error: 0.2"              
##  [9] "Fold 9/10 done. Error: 0.171428571428571"
## [10] "Fold 9/10 done. Error: 0.171428571428571"
## [11] "Fold 9/10 done. Error: 0.142857142857143"
## [12] "Fold 9/10 done. Error: 0.142857142857143"
## [13] "Fold 9/10 done. Error: 0.142857142857143"
## [14] "Fold 9/10 done. Error: 0.171428571428571"
## [15] "Fold 9/10 done. Error: 0.2"              
## [16] "Fold 9/10 done. Error: 0.171428571428571"
## [17] "Fold 9/10 done. Error: 0.171428571428571"
## [18] "Fold 9/10 done. Error: 0.171428571428571"
## [19] "Fold 9/10 done. Error: 0.171428571428571"
## [20] "Fold 9/10 done. Error: 0.171428571428571"
##  [1] "Fold 10/10 done. Error: 0.457142857142857"
##  [2] "Fold 10/10 done. Error: 0.285714285714286"
##  [3] "Fold 10/10 done. Error: 0.285714285714286"
##  [4] "Fold 10/10 done. Error: 0.314285714285714"
##  [5] "Fold 10/10 done. Error: 0.257142857142857"
##  [6] "Fold 10/10 done. Error: 0.142857142857143"
##  [7] "Fold 10/10 done. Error: 0.171428571428571"
##  [8] "Fold 10/10 done. Error: 0.114285714285714"
##  [9] "Fold 10/10 done. Error: 0.142857142857143"
## [10] "Fold 10/10 done. Error: 0.142857142857143"
## [11] "Fold 10/10 done. Error: 0.171428571428571"
## [12] "Fold 10/10 done. Error: 0.2"              
## [13] "Fold 10/10 done. Error: 0.2"              
## [14] "Fold 10/10 done. Error: 0.2"              
## [15] "Fold 10/10 done. Error: 0.2"              
## [16] "Fold 10/10 done. Error: 0.2"              
## [17] "Fold 10/10 done. Error: 0.2"              
## [18] "Fold 10/10 done. Error: 0.171428571428571"
## [19] "Fold 10/10 done. Error: 0.171428571428571"
## [20] "Fold 10/10 done. Error: 0.171428571428571"

and again use the predict function:

predictions = predict(cv_model, test_X)

For logistic regression, the predict function returns both the predicted class probabilities (response) and the predicted classes (class). We can use these to check the prediction accuracy on the test set, which is \(82\%\).

predictions$response[1:5, cv_model$best_lambda_id]
## [1] 0.4090457 0.6485261 0.1668541 0.1466649 0.3901439
predictions$class[1:5, cv_model$best_lambda_id]
## [1] 0 1 0 0 0
sum(predictions$class[, cv_model$best_lambda_id] == test_y)/length(test_y)
## [1] 0.82
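The five printed probabilities and classes are consistent with a 0.5 cut-off: only 0.6485261 exceeds 0.5, and it is the only observation labelled 1. Assuming that threshold, the class labels and the accuracy computation can be sketched in base R; the observed labels below are hypothetical:

```r
# Probabilities copied from the output above
probs = c(0.4090457, 0.6485261, 0.1668541, 0.1466649, 0.3901439)

# Assumed rule: predict class 1 when the probability exceeds 0.5
pred_class = ifelse(probs > 0.5, 1, 0)
pred_class                          # 0 1 0 0 0

# Accuracy is the share of correct predictions; mean() of a logical vector
# gives the same value as the sum()/length() form used above
true_class = c(0, 1, 1, 0, 0)       # hypothetical observed labels
mean(pred_class == true_class)      # 0.8
table(predicted = pred_class, observed = true_class)
```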

Group SLOPE

Group SLOPE (gSLOPE) applies adaptive group penalisation to control the group FDR under orthogonal designs. gSLOPE is described in detail in Brzyski, D., Gossmann, A., Su, W., Bogdan, M. (2019). Group SLOPE – Adaptive Selection of Groups of Predictors. gSLOPE is implemented in the sgs package with the same features as SGS. Here, we briefly demonstrate how to fit a gSLOPE model.

groups = c(rep(1:20, each = 3), rep(21:40, each = 4), rep(41:60,
    each = 5), rep(61:80, each = 6), rep(81:100, each = 7))

data = gen_toy_data(p = 500, n = 400, groups = groups, seed_id = 3)
model = fit_gslope(X = data$X, y = data$y, groups = groups, type = "linear",
    lambda = 0.5, gFDR = 0.1, standardise = "l2", intercept = TRUE,
    verbose = FALSE, screen = TRUE)

Screening

Screening rules allow the input dimensionality to be reduced before fitting. The strong screening rules for gSLOPE and SGS are described in detail in Feser, F., Evangelou, M. (2024). Strong Screening Rules for Group-based SLOPE Models. Here, we demonstrate the effectiveness of screening rules by comparing fitting times with and without screening. For SGS:

screen_time = system.time(model_screen <- fit_sgs(X = data$X,
    y = data$y, groups = groups, type = "linear", path_length = 100,
    alpha = 0.95, vFDR = 0.1, gFDR = 0.1, standardise = "l2",
    intercept = TRUE, verbose = FALSE, screen = TRUE))
no_screen_time = system.time(model_no_screen <- fit_sgs(X = data$X,
    y = data$y, groups = groups, type = "linear", path_length = 100,
    alpha = 0.95, vFDR = 0.1, gFDR = 0.1, standardise = "l2",
    intercept = TRUE, verbose = FALSE, screen = FALSE))
screen_time
##    user  system elapsed 
##    8.67    0.16    9.18
no_screen_time
##    user  system elapsed 
##   10.86    0.09   12.28

and for gSLOPE:

screen_time = system.time(model_screen <- fit_gslope(X = data$X,
    y = data$y, groups = groups, type = "linear", path_length = 100,
    gFDR = 0.1, standardise = "l2", intercept = TRUE, verbose = FALSE,
    screen = TRUE))
no_screen_time = system.time(model_no_screen <- fit_gslope(X = data$X,
    y = data$y, groups = groups, type = "linear", path_length = 100,
    gFDR = 0.1, standardise = "l2", intercept = TRUE, verbose = FALSE,
    screen = FALSE))
screen_time
##    user  system elapsed 
##    4.23    0.12    5.06
no_screen_time
##    user  system elapsed 
##   11.25    0.13   12.13
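The elapsed times above translate into the following speed-ups from screening for this particular run (timings will vary by machine):

```r
# Elapsed times (seconds) copied from the system.time() output above
sgs_elapsed = c(screen = 9.18, no_screen = 12.28)
gslope_elapsed = c(screen = 5.06, no_screen = 12.13)

# Speed-up factor: time without screening divided by time with screening
round(sgs_elapsed[["no_screen"]] / sgs_elapsed[["screen"]], 2)       # 1.34
round(gslope_elapsed[["no_screen"]] / gslope_elapsed[["screen"]], 2) # 2.4
```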

Reference