1
0
mirror of https://github.com/msberends/AMR.git synced 2025-10-19 11:06:20 +02:00

(v3.0.0.9003) eucast_rules fix, new tidymodels integration

This commit is contained in:
2025-06-13 14:03:21 +02:00
parent 3742e9e994
commit 72db2b2562
22 changed files with 760 additions and 107 deletions

View File

@@ -1244,7 +1244,9 @@ try_colour <- function(..., before, after, collapse = " ") {
}
}
is_dark <- function() {
if (is.null(AMR_env$is_dark_theme)) {
AMR_env$current_theme <- tryCatch(getExportedValue("getThemeInfo", ns = asNamespace("rstudioapi"))()$editor, error = function(e) NULL)
if (!identical(AMR_env$current_theme, AMR_env$former_theme) || is.null(AMR_env$is_dark_theme)) {
AMR_env$former_theme <- AMR_env$current_theme
AMR_env$is_dark_theme <- !has_colour() || tryCatch(isTRUE(getExportedValue("getThemeInfo", ns = asNamespace("rstudioapi"))()$dark), error = function(e) FALSE)
}
isTRUE(AMR_env$is_dark_theme)

4
R/ab.R
View File

@@ -655,7 +655,9 @@ generalise_antibiotic_name <- function(x) {
x <- trimws(gsub(" +", " ", x, perl = TRUE))
# remove last couple of words if they numbers or units
x <- gsub("( ([0-9]{3,}|U?M?C?G|L))+$", "", x, perl = TRUE)
# move HIGH to end
# remove whitespace prior to numbers if preceded by A-Z
x <- gsub("([A-Z]+) +([0-9]+)", "\\1\\2", x, perl = TRUE)
# move HIGH to the end
x <- trimws(gsub("(.*) HIGH(.*)", "\\1\\2 HIGH", x, perl = TRUE))
x
}

View File

@@ -208,7 +208,7 @@ age_groups <- function(x, split_at = c(12, 25, 55, 75), na.rm = FALSE) {
split_at <- c(0, split_at)
}
split_at <- split_at[!is.na(split_at)]
stop_if(length(split_at) == 1, "invalid value for `split_at`") # only 0 is available
stop_if(length(split_at) == 1, "invalid value for `split_at`.") # only 0 is available
# turn input values to 'split_at' indices
y <- x

View File

@@ -361,3 +361,15 @@
#' @examples
#' dosage
"dosage"
#' Data Set with `r format(nrow(esbl_isolates), big.mark = " ")` ESBL Isolates
#'
#' A data set containing `r format(nrow(esbl_isolates), big.mark = " ")` microbial isolates with MIC values of common antibiotics and a binary `esbl` column for extended-spectrum beta-lactamase (ESBL) production. This data set contains randomised fictitious data but reflects reality and can be used to practise AMR-related machine learning, e.g., classification modelling with [tidymodels](https://amr-for-r.org/articles/AMR_with_tidymodels.html).
#' @format A [tibble][tibble::tibble] with `r format(nrow(esbl_isolates), big.mark = " ")` observations and `r ncol(esbl_isolates)` variables:
#' - `esbl`\cr Logical indicator if the isolate is ESBL-producing
#' - `genus`\cr Genus of the microorganism
#' - `AMC:COL`\cr MIC values for 17 antimicrobial agents, transformed to class [`mic`] (see [as.mic()])
#' @details See our [tidymodels integration][amr-tidymodels] for an example using this data set.
#' @examples
#' esbl_isolates
"esbl_isolates"

View File

@@ -442,7 +442,7 @@ eucast_rules <- function(x,
# big speed gain! only analyse unique rows:
pm_distinct(`.rowid`, .keep_all = TRUE) %pm>%
as.data.frame(stringsAsFactors = FALSE)
x[, col_mo] <- as.mo(as.character(x[, col_mo, drop = TRUE]), info = info)
x[, col_mo] <- as.mo(as.character(x[, col_mo, drop = TRUE]), info = FALSE)
# rename col_mo to prevent interference with joined columns
colnames(x)[colnames(x) == col_mo] <- ".col_mo"
col_mo <- ".col_mo"
@@ -450,8 +450,8 @@ eucast_rules <- function(x,
x <- left_join_microorganisms(x, by = col_mo, suffix = c("_oldcols", ""))
x$gramstain <- mo_gramstain(x[, col_mo, drop = TRUE], language = NULL, info = FALSE)
x$genus_species <- trimws(paste(x$genus, x$species))
if (isTRUE(info) && NROW(x) > 10000) {
message_(" OK.", add_fn = list(font_green, font_bold), as_note = FALSE)
if (isTRUE(info) && NROW(x.bak) > 10000) {
message_("OK.", add_fn = list(font_green, font_bold), as_note = FALSE)
}
n_added <- 0
@@ -624,31 +624,16 @@ eucast_rules <- function(x,
eucast_rules_df <- eucast_rules_df %pm>%
rbind_AMR(eucast_rules_df_total %pm>%
subset(reference.rule_group %like% "breakpoint" & reference.version == version_breakpoints))
# eucast_rules_df <- subset(
# eucast_rules_df,
# reference.rule_group %unlike% "breakpoint" |
# (reference.rule_group %like% "breakpoint" & reference.version == version_breakpoints)
# )
}
if (any(c("all", "expected_phenotypes") %in% rules)) {
eucast_rules_df <- eucast_rules_df %pm>%
rbind_AMR(eucast_rules_df_total %pm>%
subset(reference.rule_group %like% "expected" & reference.version == version_expected_phenotypes))
# eucast_rules_df <- subset(
# eucast_rules_df,
# reference.rule_group %unlike% "expected" |
# (reference.rule_group %like% "expected" & reference.version == version_expected_phenotypes)
# )
}
if (any(c("all", "expert") %in% rules)) {
eucast_rules_df <- eucast_rules_df %pm>%
rbind_AMR(eucast_rules_df_total %pm>%
subset(reference.rule_group %like% "expert" & reference.version == version_expertrules))
# eucast_rules_df <- subset(
# eucast_rules_df,
# reference.rule_group %unlike% "expert" |
# (reference.rule_group %like% "expert" & reference.version == version_expertrules)
# )
}
## filter out AmpC de-repressed cephalosporin-resistant mutants ----
# no need to filter on version number here - the rules contain these version number, so are inherently filtered
@@ -671,6 +656,9 @@ eucast_rules <- function(x,
# we only hints on remaining rows in `eucast_rules_df`
screening_abx <- as.character(AMR::antimicrobials$ab[which(AMR::antimicrobials$ab %like% "-S$")])
screening_abx <- screening_abx[screening_abx %in% unique(unlist(strsplit(EUCAST_RULES_DF$and_these_antibiotics[!is.na(EUCAST_RULES_DF$and_these_antibiotics)], ", *")))]
if (isTRUE(info)) {
cat("\n")
}
for (ab_s in screening_abx) {
ab <- gsub("-S$", "", ab_s)
if (ab %in% names(cols_ab) && !ab_s %in% names(cols_ab)) {
@@ -901,7 +889,9 @@ eucast_rules <- function(x,
}
for (i in seq_len(length(custom_rules))) {
rule <- custom_rules[[i]]
rows <- which(eval(parse(text = rule$query), envir = x))
rows <- tryCatch(which(eval(parse(text = rule$query), envir = x)),
error = function(e) stop_(paste0(conditionMessage(e), font_red(" (check available data and compare with the custom rules set)")), call = FALSE)
)
cols <- as.character(rule$result_group)
cols <- c(
cols[cols %in% colnames(x)], # direct column names
@@ -915,9 +905,8 @@ eucast_rules <- function(x,
get_antibiotic_names(cols)
)
if (isTRUE(info)) {
# print rule
cat(italicise_taxonomy(
word_wrap(format_custom_query_rule(rule$query, colours = FALSE),
word_wrap(rule_text,
width = getOption("width") - 30,
extra_indent = 6
),

View File

@@ -432,11 +432,17 @@ pillar_shaft.mic <- function(x, ...) {
}
crude_numbers <- as.double(x)
operators <- gsub("[^<=>]+", "", as.character(x))
# colourise operators
operators[!is.na(operators) & operators != ""] <- font_silver(operators[!is.na(operators) & operators != ""], collapse = NULL)
out <- trimws(paste0(operators, trimws(format(crude_numbers))))
out[is.na(x)] <- font_na(NA)
# make trailing zeroes less visible
out[out %like% "[.]"] <- gsub("([.]?0+)$", font_silver("\\1"), out[out %like% "[.]"], perl = TRUE)
if (is_dark()) {
fn <- font_silver
} else {
fn <- font_white
}
out[out %like% "[.]"] <- gsub("([.]?0+)$", fn("\\1"), out[out %like% "[.]"], perl = TRUE)
create_pillar_column(out, align = "right", width = max(nchar(font_stripstyle(out))))
}

View File

@@ -31,13 +31,17 @@
#'
#' These functions can be used for generating random MIC values and disk diffusion diameters, for AMR data analysis practice. By providing a microorganism and antimicrobial drug, the generated results will reflect reality as much as possible.
#' @param size Desired size of the returned vector. If used in a [data.frame] call or `dplyr` verb, will get the current (group) size if left blank.
#' @param mo Any [character] that can be coerced to a valid microorganism code with [as.mo()].
#' @param mo Any [character] that can be coerced to a valid microorganism code with [as.mo()]. Can be the same length as `size`.
#' @param ab Any [character] that can be coerced to a valid antimicrobial drug code with [as.ab()].
#' @param prob_SIR A vector of length 3: the probabilities for "S" (1st value), "I" (2nd value) and "R" (3rd value).
#' @param skew Direction of skew for MIC or disk values, either `"right"` or `"left"`. A left-skewed distribution has the majority of the data on the right.
#' @param severity Skew severity; higher values will increase the skewedness. Default is `2`; use `0` to prevent skewedness.
#' @param ... Ignored, only in place to allow future extensions.
#' @details The base \R function [sample()] is used for generating values.
#'
#' Generated values are based on the EUCAST `r max(as.integer(gsub("[^0-9]", "", subset(clinical_breakpoints, guideline %like% "EUCAST")$guideline)))` guideline as implemented in the [clinical_breakpoints] data set. To create specific generated values per bug or drug, set the `mo` and/or `ab` argument.
#' @details
#' Internally, MIC and disk zone values are sampled based on clinical breakpoints defined in the [clinical_breakpoints] data set. To create specific generated values per bug or drug, set the `mo` and/or `ab` argument. The MICs are sampled on a log2 scale and disks linearly, using weighted probabilities. The weights are based on the `skew` and `severity` arguments:
#' * `skew = "right"` places more emphasis on lower MIC or higher disk values.
#' * `skew = "left"` places more emphasis on higher MIC or lower disk values.
#' * `severity` controls the exponential bias applied.
#' @return class `mic` for [random_mic()] (see [as.mic()]) and class `disk` for [random_disk()] (see [as.disk()])
#' @name random
#' @rdname random
@@ -47,8 +51,13 @@
#' random_disk(25)
#' random_sir(25)
#'
#' # add more skewedness, make more realistic by setting a bug and/or drug:
#' disks <- random_disk(100, severity = 2, mo = "Escherichia coli", ab = "CIP")
#' plot(disks)
#' # `plot()` and `ggplot2::autoplot()` allow for coloured bars if `mo` and `ab` are set
#' plot(disks, mo = "Escherichia coli", ab = "CIP", guideline = "CLSI 2025")
#'
#' \donttest{
#' # make the random generation more realistic by setting a bug and/or drug:
#' random_mic(25, "Klebsiella pneumoniae") # range 0.0625-64
#' random_mic(25, "Klebsiella pneumoniae", "meropenem") # range 0.0625-16
#' random_mic(25, "Streptococcus pneumoniae", "meropenem") # range 0.0625-4
@@ -57,26 +66,60 @@
#' random_disk(25, "Klebsiella pneumoniae", "ampicillin") # range 11-17
#' random_disk(25, "Streptococcus pneumoniae", "ampicillin") # range 12-27
#' }
random_mic <- function(size = NULL, mo = NULL, ab = NULL, ...) {
random_mic <- function(size = NULL, mo = NULL, ab = NULL, skew = "right", severity = 1, ...) {
meet_criteria(size, allow_class = c("numeric", "integer"), has_length = 1, is_positive = TRUE, is_finite = TRUE, allow_NULL = TRUE)
meet_criteria(mo, allow_class = "character", has_length = 1, allow_NULL = TRUE)
meet_criteria(mo, allow_class = "character", has_length = c(1, size), allow_NULL = TRUE)
meet_criteria(ab, allow_class = "character", has_length = 1, allow_NULL = TRUE)
meet_criteria(skew, allow_class = "character", is_in = c("right", "left"), has_length = 1)
meet_criteria(severity, allow_class = c("numeric", "integer"), has_length = 1, is_positive_or_zero = TRUE, is_finite = TRUE)
if (is.null(size)) {
size <- NROW(get_current_data(arg_name = "size", call = -3))
}
random_exec("MIC", size = size, mo = mo, ab = ab)
if (length(mo) > 1) {
out <- rep(NA_mic_, length(size))
p <- progress_ticker(n = length(unique(mo)), n_min = 10, title = "Generating random MIC values")
for (mo_ in unique(mo)) {
p$tick()
out[which(mo == mo_)] <- random_exec("MIC", size = sum(mo == mo_), mo = mo_, ab = ab, skew = skew, severity = severity)
}
out <- as.mic(out, keep_operators = "none")
if (stats::runif(1) > 0.5 && length(unique(out)) > 1) {
out[out == min(out)] <- paste0("<=", out[out == min(out)])
}
if (stats::runif(1) > 0.5 && length(unique(out)) > 1) {
out[out == max(out)] <- paste0(">=", out[out == max(out)])
}
return(out)
} else {
random_exec("MIC", size = size, mo = mo, ab = ab, skew = skew, severity = severity)
}
}
#' @rdname random
#' @export
random_disk <- function(size = NULL, mo = NULL, ab = NULL, ...) {
random_disk <- function(size = NULL, mo = NULL, ab = NULL, skew = "left", severity = 1, ...) {
meet_criteria(size, allow_class = c("numeric", "integer"), has_length = 1, is_positive = TRUE, is_finite = TRUE, allow_NULL = TRUE)
meet_criteria(mo, allow_class = "character", has_length = 1, allow_NULL = TRUE)
meet_criteria(mo, allow_class = "character", has_length = c(1, size), allow_NULL = TRUE)
meet_criteria(ab, allow_class = "character", has_length = 1, allow_NULL = TRUE)
meet_criteria(skew, allow_class = "character", is_in = c("right", "left"), has_length = 1)
meet_criteria(severity, allow_class = c("numeric", "integer"), has_length = 1, is_positive_or_zero = TRUE, is_finite = TRUE)
if (is.null(size)) {
size <- NROW(get_current_data(arg_name = "size", call = -3))
}
random_exec("DISK", size = size, mo = mo, ab = ab)
if (length(mo) > 1) {
out <- rep(NA_mic_, length(size))
p <- progress_ticker(n = length(unique(mo)), n_min = 10, title = "Generating random MIC values")
for (mo_ in unique(mo)) {
p$tick()
out[which(mo == mo_)] <- random_exec("DISK", size = sum(mo == mo_), mo = mo_, ab = ab, skew = skew, severity = severity)
}
out <- as.disk(out)
return(out)
} else {
random_exec("DISK", size = size, mo = mo, ab = ab, skew = skew, severity = severity)
}
}
#' @rdname random
@@ -90,78 +133,60 @@ random_sir <- function(size = NULL, prob_SIR = c(0.33, 0.33, 0.33), ...) {
sample(as.sir(c("S", "I", "R")), size = size, replace = TRUE, prob = prob_SIR)
}
random_exec <- function(method_type, size, mo = NULL, ab = NULL) {
df <- AMR::clinical_breakpoints %pm>%
pm_filter(guideline %like% "EUCAST") %pm>%
pm_arrange(pm_desc(guideline)) %pm>%
subset(guideline == max(guideline) &
method == method_type &
type == "human")
random_exec <- function(method_type, size, mo = NULL, ab = NULL, skew = "right", severity = 1) {
df <- AMR::clinical_breakpoints %pm>% subset(method == method_type & type == "human")
if (!is.null(mo)) {
mo_coerced <- as.mo(mo)
mo_include <- c(
mo_coerced,
as.mo(mo_genus(mo_coerced)),
as.mo(mo_family(mo_coerced)),
as.mo(mo_order(mo_coerced))
)
df_new <- df %pm>%
subset(mo %in% mo_include)
if (nrow(df_new) > 0) {
df <- df_new
} else {
warning_("in `random_", tolower(method_type), "()`: no rows found that match mo '", mo, "', ignoring argument `mo`")
}
mo_coerced <- as.mo(mo, info = FALSE)
mo_include <- c(mo_coerced, as.mo(mo_genus(mo_coerced)), as.mo(mo_family(mo_coerced)), as.mo(mo_order(mo_coerced)))
df_new <- df %pm>% subset(mo %in% mo_include)
if (nrow(df_new) > 0) df <- df_new
}
if (!is.null(ab)) {
ab_coerced <- as.ab(ab)
df_new <- df %pm>%
subset(ab %in% ab_coerced)
if (nrow(df_new) > 0) {
df <- df_new
} else {
warning_("in `random_", tolower(method_type), "()`: no rows found that match ab '", ab, "' (", ab_name(ab_coerced, tolower = TRUE, language = NULL), "), ignoring argument `ab`")
}
df_new <- df %pm>% subset(ab %in% ab_coerced)
if (nrow(df_new) > 0) df <- df_new
}
if (method_type == "MIC") {
# set range
mic_range <- c(0.001, 0.002, 0.005, 0.010, 0.025, 0.0625, 0.125, 0.250, 0.5, 1, 2, 4, 8, 16, 32, 64, 128, 256)
lowest_mic <- min(df$breakpoint_S, na.rm = TRUE)
lowest_mic <- log2(lowest_mic) + sample(c(-3:2), 1)
lowest_mic <- 2^lowest_mic
highest_mic <- max(df$breakpoint_R, na.rm = TRUE)
highest_mic <- log2(highest_mic) + sample(c(-3:1), 1)
highest_mic <- max(lowest_mic * 2, 2^highest_mic)
# get highest/lowest +/- random 1 to 3 higher factors of two
max_range <- mic_range[min(
length(mic_range),
which(mic_range == max(df$breakpoint_R[!is.na(df$breakpoint_R)], na.rm = TRUE)) + sample(c(1:3), 1)
)]
min_range <- mic_range[max(
1,
which(mic_range == min(df$breakpoint_S, na.rm = TRUE)) - sample(c(1:3), 1)
)]
mic_range_new <- mic_range[mic_range <= max_range & mic_range >= min_range]
if (length(mic_range_new) == 0) {
mic_range_new <- mic_range
}
out <- as.mic(sample(mic_range_new, size = size, replace = TRUE))
# 50% chance that lowest will get <= and highest will get >=
out <- skewed_values(COMMON_MIC_VALUES, size = size, min = lowest_mic, max = highest_mic, skew = skew, severity = severity)
if (stats::runif(1) > 0.5 && length(unique(out)) > 1) {
out[out == min(out)] <- paste0("<=", out[out == min(out)])
}
if (stats::runif(1) > 0.5 && length(unique(out)) > 1) {
out[out == max(out)] <- paste0(">=", out[out == max(out)])
}
return(out)
return(as.mic(out))
} else if (method_type == "DISK") {
set_range <- seq(
from = as.integer(min(df$breakpoint_R[!is.na(df$breakpoint_R)], na.rm = TRUE) / 1.25),
to = as.integer(max(df$breakpoint_S, na.rm = TRUE) * 1.25),
disk_range <- seq(
from = floor(min(df$breakpoint_R[!is.na(df$breakpoint_R)], na.rm = TRUE) / 1.25),
to = ceiling(max(df$breakpoint_S[df$breakpoint_S != 50], na.rm = TRUE) * 1.25),
by = 1
)
out <- sample(set_range, size = size, replace = TRUE)
out[out < 6] <- sample(c(6:10), length(out[out < 6]), replace = TRUE)
out[out > 50] <- sample(c(40:50), length(out[out > 50]), replace = TRUE)
disk_range <- disk_range[disk_range >= 6 & disk_range <= 50]
out <- skewed_values(disk_range, size = size, min = min(disk_range), max = max(disk_range), skew = skew, severity = severity)
return(as.disk(out))
}
}
skewed_values <- function(values, size, min, max, skew = c("right", "left"), severity = 1) {
skew <- match.arg(skew)
range_vals <- values[values >= min & values <= max]
if (length(range_vals) < 2) range_vals <- values
ranks <- seq_along(range_vals)
weights <- switch(skew,
right = rev(ranks)^severity,
left = ranks^severity
)
weights <- weights / sum(weights)
sample(range_vals, size = size, replace = TRUE, prob = weights)
}

View File

@@ -159,7 +159,7 @@
#'
#' The function [is.sir()] detects if the input contains class `sir`. If the input is a [data.frame] or [list], it iterates over all columns/items and returns a [logical] vector.
#'
#' The base R function [as.double()] can be used to retrieve quantitative values from a `sir` object: `"S"` = 1, `"I"`/`"SDD"` = 2, `"R"` = 3. All other values are rendered `NA` . **Note:** Do not use `as.integer()`, since that (because of how R works internally) will return the factor level indices, and not these aforementioned quantitative values.
#' The base R function [as.double()] can be used to retrieve quantitative values from a `sir` object: `"S"` = 1, `"I"`/`"SDD"` = 2, `"R"` = 3. All other values are rendered `NA`. **Note:** Do not use `as.integer()`, since that (because of how R works internally) will return the factor level indices, and not these aforementioned quantitative values.
#'
#' The function [is_sir_eligible()] returns `TRUE` when a column contains at most 5% potentially invalid antimicrobial interpretations, and `FALSE` otherwise. The threshold of 5% can be set with the `threshold` argument. If the input is a [data.frame], it iterates over all columns and returns a [logical] vector.
#' @section Interpretation of SIR:

Binary file not shown.

262
R/tidymodels.R Normal file
View File

@@ -0,0 +1,262 @@
#' AMR Extensions for Tidymodels
#'
#' This family of functions allows using AMR-specific data types such as `<mic>` and `<sir>` inside `tidymodels` pipelines.
#' @inheritParams recipes::step_center
#' @details
#' You can read more in our online [AMR with tidymodels introduction](https://amr-for-r.org/articles/AMR_with_tidymodels.html).
#'
#' Tidyselect helpers include:
#' - [all_mic()] and [all_mic_predictors()] to select `<mic>` columns
#' - [all_sir()] and [all_sir_predictors()] to select `<sir>` columns
#'
#' Pre-processing pipeline steps include:
#' - [step_mic_log2()] to convert MIC columns to numeric (via `as.numeric()`) and apply a log2 transform, to be used with [all_mic_predictors()]
#' - [step_sir_numeric()] to convert SIR columns to numeric (via `as.numeric()`), to be used with [all_sir_predictors()]: `"S"` = 1, `"I"`/`"SDD"` = 2, `"R"` = 3. All other values are rendered `NA`. Keep this in mind for further processing, especially if the model does not allow for `NA` values.
#'
#' These steps integrate with `recipes::recipe()` and work like standard preprocessing steps. They are useful for preparing data for modelling, especially with classification models.
#' @seealso [recipes::recipe()], [as.mic()], [as.sir()]
#' @name amr-tidymodels
#' @keywords internal
#' @export
#' @examples
#' library(tidymodels)
#'
#' # The below approach formed the basis for this paper: DOI 10.3389/fmicb.2025.1582703
#' # Presence of ESBL genes was predicted based on raw MIC values.
#'
#'
#' # example data set in the AMR package
#' esbl_isolates
#'
#' # Prepare a binary outcome and convert to ordered factor
#' data <- esbl_isolates %>%
#' mutate(esbl = factor(esbl, levels = c(FALSE, TRUE), ordered = TRUE))
#'
#' # Split into training and testing sets
#' split <- initial_split(data)
#' training_data <- training(split)
#' testing_data <- testing(split)
#'
#' # Create and prep a recipe with MIC log2 transformation
#' mic_recipe <- recipe(esbl ~ ., data = training_data) %>%
#' # Optionally remove non-predictive variables
#' remove_role(genus, old_role = "predictor") %>%
#' # Apply the log2 transformation to all MIC predictors
#' step_mic_log2(all_mic_predictors()) %>%
#' prep()
#'
#' # View prepped recipe
#' mic_recipe
#'
#' # Apply the recipe to training and testing data
#' out_training <- bake(mic_recipe, new_data = NULL)
#' out_testing <- bake(mic_recipe, new_data = testing_data)
#'
#' # Fit a logistic regression model
#' fitted <- logistic_reg(mode = "classification") %>%
#' set_engine("glm") %>%
#' fit(esbl ~ ., data = out_training)
#'
#' # Generate predictions on the test set
#' predictions <- predict(fitted, out_testing) %>%
#' bind_cols(out_testing)
#'
#' # Evaluate predictions using standard classification metrics
#' our_metrics <- metric_set(accuracy, kap, ppv, npv)
#' metrics <- our_metrics(predictions, truth = esbl, estimate = .pred_class)
#'
#' # Show performance:
#' # - negative predictive value (NPV) of ~98%
#' # - positive predictive value (PPV) of ~94%
#' metrics
all_mic <- function() {
x <- tidymodels_amr_select(levels(NA_mic_))
names(x)
}
#' @rdname amr-tidymodels
#' @export
all_mic_predictors <- function() {
x <- tidymodels_amr_select(levels(NA_mic_))
intersect(x, recipes::has_role("predictor"))
}
#' @rdname amr-tidymodels
#' @export
all_sir <- function() {
x <- tidymodels_amr_select(levels(NA_sir_))
names(x)
}
#' @rdname amr-tidymodels
#' @export
all_sir_predictors <- function() {
x <- tidymodels_amr_select(levels(NA_sir_))
intersect(x, recipes::has_role("predictor"))
}
#' @rdname amr-tidymodels
#' @export
step_mic_log2 <- function(
recipe,
...,
role = NA,
trained = FALSE,
columns = NULL,
skip = FALSE,
id = recipes::rand_id("mic_log2")) {
recipes::add_step(
recipe,
step_mic_log2_new(
terms = rlang::enquos(...),
role = role,
trained = trained,
columns = columns,
skip = skip,
id = id
)
)
}
step_mic_log2_new <- function(terms, role, trained, columns, skip, id) {
recipes::step(
subclass = "mic_log2",
terms = terms,
role = role,
trained = trained,
columns = columns,
skip = skip,
id = id
)
}
#' @rawNamespace if(getRversion() >= "3.0.0") S3method(recipes::prep, step_mic_log2)
prep.step_mic_log2 <- function(x, training, info = NULL, ...) {
col_names <- recipes::recipes_eval_select(x$terms, training, info)
recipes::check_type(training[, col_names], types = "ordered")
step_mic_log2_new(
terms = x$terms,
role = x$role,
trained = TRUE,
columns = col_names,
skip = x$skip,
id = x$id
)
}
#' @rawNamespace if(getRversion() >= "3.0.0") S3method(recipes::bake, step_mic_log2)
bake.step_mic_log2 <- function(object, new_data, ...) {
recipes::check_new_data(object$columns, object, new_data)
for (col in object$columns) {
new_data[[col]] <- log2(as.numeric(as.mic(new_data[[col]])))
}
new_data
}
#' @export
print.step_mic_log2 <- function(x, width = max(20, options()$width - 35), ...) {
title <- "Log2 transformation of MIC columns"
recipes::print_step(x$columns, x$terms, x$trained, title, width)
invisible(x)
}
#' @rawNamespace if(getRversion() >= "3.0.0") S3method(recipes::tidy, step_mic_log2)
tidy.step_mic_log2 <- function(x, ...) {
if (recipes::is_trained(x)) {
res <- tibble::tibble(terms = x$columns)
} else {
res <- tibble::tibble(terms = recipes::sel2char(x$terms))
}
res$id <- x$id
res
}
#' @rdname amr-tidymodels
#' @export
step_sir_numeric <- function(
recipe,
...,
role = NA,
trained = FALSE,
columns = NULL,
skip = FALSE,
id = recipes::rand_id("sir_numeric")) {
recipes::add_step(
recipe,
step_sir_numeric_new(
terms = rlang::enquos(...),
role = role,
trained = trained,
columns = columns,
skip = skip,
id = id
)
)
}
step_sir_numeric_new <- function(terms, role, trained, columns, skip, id) {
recipes::step(
subclass = "sir_numeric",
terms = terms,
role = role,
trained = trained,
columns = columns,
skip = skip,
id = id
)
}
#' @rawNamespace if(getRversion() >= "3.0.0") S3method(recipes::prep, step_sir_numeric)
prep.step_sir_numeric <- function(x, training, info = NULL, ...) {
col_names <- recipes::recipes_eval_select(x$terms, training, info)
recipes::check_type(training[, col_names], types = "ordered")
step_sir_numeric_new(
terms = x$terms,
role = x$role,
trained = TRUE,
columns = col_names,
skip = x$skip,
id = x$id
)
}
#' @rawNamespace if(getRversion() >= "3.0.0") S3method(recipes::bake, step_sir_numeric)
bake.step_sir_numeric <- function(object, new_data, ...) {
recipes::check_new_data(object$columns, object, new_data)
for (col in object$columns) {
new_data[[col]] <- as.numeric(as.sir(new_data[[col]]))
}
new_data
}
#' @export
print.step_sir_numeric <- function(x, width = max(20, options()$width - 35), ...) {
title <- "Numeric transformation of SIR columns"
recipes::print_step(x$columns, x$terms, x$trained, title, width)
invisible(x)
}
#' @rawNamespace if(getRversion() >= "3.0.0") S3method(recipes::tidy, step_sir_numeric)
tidy.step_sir_numeric <- function(x, ...) {
if (recipes::is_trained(x)) {
res <- tibble::tibble(terms = x$columns)
} else {
res <- tibble::tibble(terms = recipes::sel2char(x$terms))
}
res$id <- x$id
res
}
tidymodels_amr_select <- function(check_vector) {
df <- get_current_data()
ind <- which(
vapply(
FUN.VALUE = logical(1),
df,
function(x) all(x %in% c(check_vector, NA), na.rm = TRUE) & any(x %in% check_vector),
USE.NAMES = TRUE
),
useNames = TRUE
)
ind
}