rm(list = ls()) # clean glob. computing environment
# NOTE(review): rm(list = ls()) is discouraged in shared scripts — it clears
# only global objects (not packages/options); prefer restarting the session.
# gc() # garbage collect

We fetch the dataset programmatically from its source and load it into a pandas DataFrame.

# Peek at the first five rows of the raw transactions DataFrame (Python side).
df.head()
##    accountNumber  customerId  ...  expirationDateKeyInMatch  isFraud
## 0      737265056   737265056  ...                     False    False
## 1      737265056   737265056  ...                     False    False
## 2      737265056   737265056  ...                     False    False
## 3      737265056   737265056  ...                     False    False
## 4      830329091   830329091  ...                     False    False
## 
## [5 rows x 29 columns]
library(reticulate)

# Pull the pandas DataFrame `df` from the Python session into an R data frame.
d <- reticulate::py$df # convert to r obj. for dplyr manipulation (faster), and 'gg' plotting (prettier)

Let’s switch to R for piping.

Q1: Load

“Programmatically download and load the transactions data. Describe the structure. Provide some additional basic summary statistics for each field. Be sure to include a count of null, minimum, maximum, and unique values where appropriate.”

# Spot-check a cell: missing values arrive as empty strings, not NA.
d[1, 20] # `echoBuffer`, i.e. 20th col. in 1st row: ""
## [1] ""

We have some rows and cols with '' in lieu of NA. Let’s fix that. Specifically, we can do away with echoBuffer, merchantCity, merchantState, merchantZip, posOnPremises, and recurringAuthInd. (Their missingness is 100%.)

# `, fig.align='center', cache=TRUE, echo=F`
library(naniar)
library(dplyr)

# Visualize missingness by column/data type; empty strings are recoded to NA
# first so they register as missing. (`.fns` without `.cols` is deprecated in
# dplyr >= 1.1 — supply `everything()` explicitly.)
d %>%
  mutate(across(everything(), ~replace(., . == '', NA))) %>%
  visdat::vis_dat(warn_large_data = FALSE) +
    ggplot2::scale_fill_manual(values = c('numeric' = '#FF9999', 'logical' = '#CCCCFF', 'missing' = 'gray53', 'character' = '#99CCFF')) + #CCFFFF #azure2
    ggplot2::labs(title = 'Missingness: \nby Data Type')

Let us display that missingness also numerically.

# Recode '' to NA, then count missing values per column.
# (Explicit `everything()` replaces the deprecated bare `.fns =` form.)
d %>%
  mutate(across(everything(), ~replace(., . == '', NA))) %>%
  is.na() %>%
  colSums() # missingness per col.
##            accountNumber               customerId              creditLimit 
##                        0                        0                        0 
##           availableMoney      transactionDateTime        transactionAmount 
##                        0                        0                        0 
##             merchantName               acqCountry      merchantCountryCode 
##                        0                     4562                      724 
##             posEntryMode         posConditionCode     merchantCategoryCode 
##                     4054                      409                        0 
##           currentExpDate          accountOpenDate  dateOfLastAddressChange 
##                        0                        0                        0 
##                  cardCVV               enteredCVV          cardLast4Digits 
##                        0                        0                        0 
##          transactionType               echoBuffer           currentBalance 
##                      698                   786363                        0 
##             merchantCity            merchantState              merchantZip 
##                   786363                   786363                   786363 
##              cardPresent            posOnPremises         recurringAuthInd 
##                        0                   786363                   786363 
## expirationDateKeyInMatch                  isFraud 
##                        0                        0

And let’s calculate some basic descriptive statistics for our numeric variables.

# summary stats for numeric cols
# For each statistic, apply it across all numeric columns, excluding
# identifier-like columns with no meaningful summary. Passing `na.rm = TRUE`
# through `across()`'s `...` is deprecated; wrap in an explicit lambda instead.
idCols <- c('accountNumber', 'customerId', 'cardCVV', 'enteredCVV', 'cardLast4Digits')
purrr::map_dfr(lst(min, median, mean, max, sd),
               function(statFn) summarize(d[, !(names(d) %in% idCols)],
                                          across(where(is.numeric), ~statFn(., na.rm = TRUE))),
               .id = 'summary_stats')
##   summary_stats creditLimit availableMoney transactionAmount currentBalance
## 1           min      250.00      -1005.630            0.0000          0.000
## 2        median     7500.00       3184.860           87.9000       2451.760
## 3          mean    10759.46       6250.725          136.9858       4508.739
## 4           max    50000.00      50000.000         2011.5400      47498.810
## 5            sd    11636.17       8880.784          147.7256       6457.442

Do note that there are some suspect values, e.g. enteredCVV being nil (0), or 0 for cardLast4Digits.

# drop cols with 100% missingness, and drop na
# `all_of()` replaces the superseded `one_of()`; the dropped columns are all
# confirmed present (each 100% missing per the counts above).
e <- d %>%
  mutate(across(everything(), ~replace(., . == '', NA))) %>%
  select(-all_of(c('echoBuffer', 'merchantCity', 'merchantState',
                   'merchantZip', 'posOnPremises', 'recurringAuthInd'))) %>%
  tidyr::drop_na()

$0 transactions seem to be “verification” transactions allegedly used by merchants to validate card / account details.

Q2: Plot

“Plot a histogram of the processed amounts of each transaction, the transactionAmount column. Report any structure you find and any hypotheses you have about that structure.”

After dropping ~22k zero-dollar transactions, which we may surmise to have been so-called “verification” transactions, we have mode (8.2) < median (92.2) < mean (141) -> Chi-squared dist. or Poisson dist. Or right-skewed (positive skewness) distribution. What stands out is that most transactions happen in the lower dollar-brackets (dropping off rather dramatically after ~$500 threshold), and that by far the greatest chunk of the transactions happen under ~$92.

Another thing to think about is, it’d seem as though the transactions’ dynamics are driven by what we might refer to as “frivolous” transactions, i.e. small-amount transactions. Since these are credit cards, that would seem to be a suboptimal behavior for the card-holder who, instead, should aim to use their credit cards for big-ticket items. This would appear to be a feature to exploit, though perhaps the fact we see some semblance of this behavior is an indication that it has already been exploited.

Q3: Data Wrangling - Duplicate Transactions

“Duplicated transactions: (i) reversed transaction (purchase followed by a reversal), and (ii) multi-swipe transactions (vendor accidentally charges a customer’s card multiple times within a short time span).”

# reversed transactions (discounting verif. / $`0` transactions)
# Parse the timestamp from the transactionDateTime column directly — the
# original indexed `e[..., 5]` positionally, which breaks if columns move.
e %>%
  filter(transactionAmount != 0) %>%
  mutate(dateTime = lubridate::ymd_hms(transactionDateTime)) %>%
  filter(transactionType == 'REVERSAL') %>%
  select(accountNumber, transactionAmount, currentBalance, dateTime, merchantName, transactionType) %>%
  group_by(accountNumber, transactionAmount) %>% # 19,488 records
  as.data.frame() %>%
  select(transactionAmount) %>%
  sum() # $2,787,895
## [1] 2787895

Discounting the zero-dollar verification transactions, we have 19,488 records worth $2,787,895 in reversed transactions.

# NOTE(review): `multiSwipeTransaction` is defined outside this excerpt —
# presumably the candidate duplicate-swipe records; confirm upstream.
# Keep one row per (account, amount) group: the first swipe is treated as the
# legitimate transaction.
firstTransactions <- multiSwipeTransaction %>%
  group_by(accountNumber, transactionAmount) %>%
  summarise() # %>% # 6,763 x 2
## `summarise()` has grouped output by 'accountNumber'. You can override using the
## `.groups` argument.
  # as.data.frame() %>%
  # select(transactionAmount) %>%
  # sum() # $996,154.8

# Extra (duplicate) swipes = all multi-swipe rows minus the first of each group.
dim(multiSwipeTransaction)[1] - dim(firstTransactions)[1] # 7442
## [1] 7442
# Dollar value of those duplicate swipes.
sum(multiSwipeTransaction$transactionAmount) - sum(firstTransactions$transactionAmount) # $1,097,507
## [1] 1097507

Discounting the zero-dollar verification transactions, we have 7,442 records worth $1,097,507 in multi-swipe transactions.

Q4: Model

“Build a predictive model to determine whether a given transaction will be fraudulent (isFraud) or not.”

# a bit more of that feature engineering stuff
# Number of distinct merchants in the cleaned data.
length(unique(e$merchantName)) # 2489 vendors
## [1] 2489

No particular pattern to observe as a time series, except that the count of no fraud increases over the year, which, however, may merely be on account of more transactions happening as time goes on. (We see no dramatic changes when looking at the data as a fraction of the total. We may cautiously disregard time as a variable for the purposes of our modeling.)

library(RColorBrewer)
# Build a 13-color palette from two sequential brewer palettes.
# (`<-` replaces `=` for assignment, per house style.)
myCols <- c(brewer.pal(name = 'Blues', n = 8)[3:8], brewer.pal(name = 'GnBu', n = 8)[2:8])

# Histogram of fraudulent-purchase amounts, faceted by merchant category.
# Timestamp is parsed from its own column (was a fragile positional index).
e %>%
  filter(transactionAmount != 0) %>%
  mutate(dateTime = lubridate::ymd_hms(transactionDateTime)) %>%
  select(-all_of(c('transactionDateTime'))) %>%
  mutate(date = as.Date(dateTime)) %>%
  filter(transactionType == 'PURCHASE' & isFraud == 'TRUE') %>%
  ggplot(., aes(x = transactionAmount,
                fill = merchantCategoryCode)) +
    geom_histogram() +
    # facet_wrap(~merchantCategoryCode) +
    facet_wrap(~merchantCategoryCode, # facet_grid()
               labeller = as_labeller(c('rideshare' = 'Rideshare', 'entertainment' = 'Entertainment', 'mobileapps' = 'Mobile Apps', 'fastfood' = 'Fast Food', 'food_delivery' = 'Food Delivery',
                                        'auto' = 'Auto', 'online_retail' = 'Online Retail', 'gym' = 'Gym', 'health' = 'Health', 'personal care' = 'Personal Care',
                                        'food' = 'Food', 'fuel' = 'Fuel', 'online_subscriptions' = 'Online Subscriptions', 'online_gifts' = 'Online Gifts', 'hotels' = 'Hotels',
                                        'airline' = 'Airline', 'furniture' = 'Furniture', 'subscriptions' = 'Subscriptions', 'cable/phone' = 'Cable / Phone'))) +
    theme_minimal() +
    theme(panel.background = element_rect(fill = 'azure2'), panel.grid = element_line(color = 'white')) +
    scale_fill_manual('Legend',
                      labels = c('Airline', 'Auto', 'Entertainment', 'Fast Food', 'Food', 'Furniture', 'Health', 'Hotels', 'Online Gifts',
                                 'Online Retail', 'Personal Care', 'Rideshare', 'Subscriptions'),
                      values = myCols) +
    scale_color_manual(values = myCols) +
    labs(title = bquote(bold('Fraud') ~'per Merchant Category Code'), x = 'Transaction Amount (U.S. $)', y = 'Count')
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

It seems plausible that certain merchant-type transactions are more prone to fraud, such as online retail. Including this info in the model might not be a bad idea.

Already did a corrplot and there’s not a whole lot in there. Let’s do some manipulations to come up with new variables that might help with the predictiveness. For instance, we shall surmise that frauds may tend to happen at a particular time in a day (perhaps later on), and that there are some merchants in that online retail business (or elsewhere) that are a good indicator of a fraud potentially occurring. Let’s add two new column indicators like that to help with the modeling.

# Count fraudulent PURCHASE transactions per merchant, descending.
# Timestamp parsed from its own column (was a fragile positional index).
e %>%
  filter(transactionAmount != 0) %>% # filter out verif. transactions
  mutate(dateTime = lubridate::ymd_hms(transactionDateTime)) %>%
  select(-all_of(c('transactionDateTime'))) %>%
  mutate(date = as.Date(dateTime)) %>%
  mutate(isFraud = ifelse(isFraud == 'TRUE', 1, 0), # convert to numeric
         expirationDateKeyInMatch = ifelse(expirationDateKeyInMatch == 'TRUE', 1, 0), # ditto
         cardPresent = ifelse(cardPresent == 'TRUE', 1, 0)) %>% # ditto
  filter(transactionType == 'PURCHASE') %>%
  mutate(merchantCategoryCode = as.numeric(as.factor(merchantCategoryCode))) %>%
  filter(isFraud == 1) %>%
  group_by(merchantName, isFraud) %>%
  count() %>%
  arrange(desc(n)) %>%
  print(n=22) # top 22 with count >= 100
## # A tibble: 1,016 × 3
## # Groups:   merchantName, isFraud [1,016]
##    merchantName                  isFraud     n
##    <chr>                           <dbl> <int>
##  1 Lyft                                1   707
##  2 ebay.com                            1   609
##  3 Fresh Flowers                       1   516
##  4 Uber                                1   481
##  5 cheapfast.com                       1   415
##  6 walmart.com                         1   403
##  7 sears.com                           1   396
##  8 oldnavy.com                         1   376
##  9 staples.com                         1   374
## 10 alibaba.com                         1   359
## 11 amazon.com                          1   345
## 12 gap.com                             1   341
## 13 target.com                          1   338
## 14 apple.com                           1   335
## 15 discount.com                        1   316
## 16 American Airlines                   1   269
## 17 Fresh eCards                        1   200
## 18 Blue Mountain Online Services       1   196
## 19 Next Day Online Services            1   162
## 20 Blue Mountain eCards                1   113
## 21 Fresh Online Services               1   106
## 22 Mobile eCards                       1   100
## # … with 994 more rows

The top fraudulent transactions “offenders”. Displaying just the ones with more than 100 observations thereof.

# Merchants with >= 100 observed fraudulent transactions (from the table above).
topOffender <- c('Lyft', 'ebay.com', 'Fresh Flowers', 'Uber', 'cheapfast.com',
                 'walmart.com', 'sears.com', 'oldnavy.com', 'staples.com', 'alibaba.com',
                 'amazon.com', 'gap.com', 'target.com', 'apple.com', 'discount.com',
                 'American Airlines', 'Fresh eCards', 'Blue Mountain Online Services', 'Next Day Online Services', 'Blue Mountain eCards',
                 'Fresh Online Services', 'Mobile eCards')
# Build the numeric modeling frame `f` with engineered indicator features.
f <- e %>%
  filter(transactionAmount != 0) %>% # filter out verif. transactions
  mutate(dateTime = lubridate::ymd_hms(transactionDateTime)) %>% # parse from named column (was a fragile positional index)
  select(-all_of(c('transactionDateTime'))) %>%
  # mutate(date = as.Date(dateTime)) %>%
  mutate(date = as.Date(dateTime),
         hour = lubridate::hour(dateTime),
         minute = lubridate::minute(dateTime),
         second = lubridate::second(dateTime)) %>%
  mutate(isFraud = ifelse(isFraud == 'TRUE', 1, 0), # convert to numeric
         expirationDateKeyInMatch = ifelse(expirationDateKeyInMatch == 'TRUE', 1, 0), # ditto
         cardPresent = ifelse(cardPresent == 'TRUE', 1, 0)) %>% # ditto
  filter(transactionType == 'PURCHASE') %>%
  mutate(merchantCategoryCode = as.numeric(as.factor(merchantCategoryCode))) %>% # convert to numeric
  mutate(theWeeHoursTransactions = ifelse(hour >= 18 | hour <= 6, 1, 0)) %>% # col. for transactions happening in "the wee hours" 6pm-6am
  mutate(topOffender = ifelse(merchantName %in% topOffender, 1, 0)) %>% # col. indicator if top "offender" merchant (fraud-wise)
  mutate(wrongPin = ifelse(as.character(cardCVV) == as.character(enteredCVV), 0, 1)) %>% # 1 when the entered CVV does NOT match the card's (the name says PIN but it is a CVV check)
  select(-all_of(c('cardCVV', 'enteredCVV', 'hour', 'minute', 'second'))) %>%
  # group_by(isFraud, wrongPin) %>%
  # count()
  select(where(is.numeric)) # `select_if` is superseded

Another new variable we may be curious to create, and explore, is checking whether at the time of a transaction, the CVV code entered matched the card’s CVV. The assumption being that a fraudulent transaction may be one where the CVV isn’t matched. Let’s try that, and check.

# Correlation matrix of all engineered numeric features, displayed numerically.
f %>%
  cor() %>%
  round(digits = 4) %>%
  corrplot::corrplot(method = 'number', tl.cex = 0.7, number.cex = 0.6)

Nothing to write home about. We barely see ~.01 correlations (between isFraud) and transactionAmount (positive), cardPresent (negative), and topOffender (positive).

Not the big kahuna I was hoping it to be, but that’s alright.

It’s important that we create a balanced sample before proceeding. We don’t want one value of the dependent variable to drown out the other in the sample. Let’s do that.

# Build a balanced modeling sample: downsample the non-fraud majority class
# to roughly the size of the fraud class.
# NOTE(review): no RNG seed is set before sampling, so `f_0` is not
# reproducible — consider set.seed() here.
f_0 <- f %>% # `noFraud` table
  filter(isFraud == 0) %>% # 11,528 for T (1), 723,661 for F (0), i.e. ~0.0159 T:F ratio
  sample_frac(., size = 0.016, replace = FALSE)

f_1 <- f %>% # `fraud` table
  filter(isFraud == 1)

g <- rbind(f_0, f_1) # balanced dataset for modeling
g <- g %>%
  select(-all_of(c('expirationDateKeyInMatch')))


Model 1: Plain Vanilla Logistic Regression

Let’s try our first model, which will be a regular machine learning (ML) one. We shall switch back to Python for this.

# Hand the balanced data frame to Python as a pandas DataFrame.
g_pd <- r_to_py(g)
# (Python) view the R object from the Python side via reticulate's `r` proxy.
r.g_pd
##        accountNumber   customerId  ...  topOffender  wrongPin
## 0        652722129.0  652722129.0  ...          0.0       0.0
## 1        284919032.0  284919032.0  ...          0.0       0.0
## 2        955678177.0  955678177.0  ...          1.0       0.0
## 3        253508360.0  253508360.0  ...          0.0       0.0
## 4        798413865.0  798413865.0  ...          0.0       0.0
## ...              ...          ...  ...          ...       ...
## 23102    207667444.0  207667444.0  ...          1.0       0.0
## 23103    207667444.0  207667444.0  ...          0.0       0.0
## 23104    428856030.0  428856030.0  ...          1.0       0.0
## 23105    657364505.0  657364505.0  ...          1.0       0.0
## 23106    899818521.0  899818521.0  ...          1.0       0.0
## 
## [23107 rows x 13 columns]
# (Python) work on a copy so edits to df2 don't mutate the shared object.
df2 = r.g_pd.copy()

We already balanced the sample, converted to numeric, and dropped NAs. There’s not much left in the way of preprocessing, but let’s scale and one hot encode.

Some helper functions to massage the data to the right form for feeding into the model.

import numpy as np

from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
#  sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, GridSearchCV

# Numeric columns are standardized; ID-like / categorical columns are
# one-hot encoded (unknown test-set categories are ignored, not errors).
num_features = ['creditLimit', 'availableMoney', 'transactionAmount', 'cardLast4Digits', 'currentBalance']
num_transformer = Pipeline(steps=[('scaler', StandardScaler())])

cat_features = ['accountNumber', 'customerId', 'merchantCategoryCode', 'cardPresent', # 'isFraud',
                'theWeeHoursTransactions', 'topOffender', 'wrongPin']
cat_transformer = Pipeline(steps=[('onehot', OneHotEncoder(handle_unknown='ignore'))])

preprocess = ColumnTransformer(transformers=[('num', num_transformer, num_features),
                                             ('cat', cat_transformer, cat_features)])
X = df2.drop('isFraud', axis=1)
y = df2['isFraud']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state = 42) # reserve 20% for testing
# Fit the preprocessor on the training split only (avoids test-set leakage).
preprocess = preprocess.fit(X_train)
# was: myPreprocessor(X_train) — an undefined name; use the fitted transformer
preprocess.transform(X_train).shape
## (18485, 6034)
print(X_train.shape, X_test.shape, 
      y_train.shape, y_test.shape)
## (18485, 12) (4622, 12) (18485,) (4622,)
hyperparams = {'C':np.logspace(1, 10, 100), 'penalty':['l2'], 'max_iter':[100]} # 'max_iter':[5]

logit = LogisticRegression()
logitCv = GridSearchCV(logit, hyperparams, cv = 10)
logitCv.fit(myPreprocessor(X_train), y_train)
GridSearchCV(cv=10, estimator=LogisticRegression(),
             param_grid={'C': array([1.00000000e+01, 1.23284674e+01, 1.51991108e+01, 1.87381742e+01,
       2.31012970e+01, 2.84803587e+01, 3.51119173e+01, 4.32876128e+01,
       5.33669923e+01, 6.57933225e+01, 8.11130831e+01, 1.00000000e+02,
       1.23284674e+02, 1.51991108e+02, 1.87381742e+02, 2.31012970e+02,
       2.84803587e+02, 3.51119173e+02, 4.32876...
       8.11130831e+07, 1.00000000e+08, 1.23284674e+08, 1.51991108e+08,
       1.87381742e+08, 2.31012970e+08, 2.84803587e+08, 3.51119173e+08,
       4.32876128e+08, 5.33669923e+08, 6.57933225e+08, 8.11130831e+08,
       1.00000000e+09, 1.23284674e+09, 1.51991108e+09, 1.87381742e+09,
       2.31012970e+09, 2.84803587e+09, 3.51119173e+09, 4.32876128e+09,
       5.33669923e+09, 6.57933225e+09, 8.11130831e+09, 1.00000000e+10]),
                         'max_iter': [100], 'penalty': ['l2']})
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
print('Best Parameters: ', logitCv.best_params_)
## Best Parameters:  {'C': 10.0, 'max_iter': 100, 'penalty': 'l2'}
LogisticRegression(C=10.0)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
## 0.7882607519610495
y_pred = model.predict(myPreprocessor(X_test))
y_pred
## array([0, 0, 1, ..., 1, 1, 0])
from sklearn.metrics import accuracy_score

print("Accuracy: {:.2f}%".format(accuracy_score(y_test, y_pred)*100))
## Accuracy: 69.04%

Better than tossing a coin, but we might do better still. Let’s use Bayesian inference for the job. We’ll switch back to R to call Stan that way. Let’s do logits to cf. regular machine learning with Bayesian workflow, performance-wise.

Model 2: Bayesian Logistic Regression

# a tiny bit of additional preprocessing like in python:
# standardize the columns matching the first eight column names.
# (`TRUE` replaces the reassignable `T`; NOTE(review): `mutate_at` is
# superseded by `mutate(across(...))` in current dplyr.)
g <- g %>%
  dplyr::mutate_at(vars(starts_with(colnames(g)[1:8])), ~c(scale(., center = TRUE, scale = TRUE)))
# some bayesian packages
library(rstanarm)
library(loo)
library(projpred)
library(bayesplot)

It’s very important that we check again how balanced the sample is with respect to the dependent variable. Computing the model-matrix inverse from precompiled stan_glm() models would fail if the data were not balanced.

The resultant formula is sparse but let’s cautiously proceed.

# Build AND assign the model formula — the original computed it without
# assigning, leaving `thisFormula` (used below) undefined. The outer parens
# auto-print it, preserving the rendered output.
(thisFormula <- formula(paste("isFraud ~", paste(colnames(g)[1:(dim(g)[2]-1)][c(-10, -4, -5, -8, -1, -2, -7)], collapse = " + "))))
## isFraud ~ creditLimit + merchantCategoryCode + cardPresent + 
##     theWeeHoursTransactions + topOffender
g$isFraud <- factor(g$isFraud)

X <- model.matrix(thisFormula, data = g)
y <- g$isFraud
# can check the inversion
# MASS::ginv(X) # take inverse of a matrix
# model
t_prior <- student_t(df = 7, location = 0, scale = 2.5) # student t prior with 7 degrees of freedom for coeffs ought to be close to 0
# NOTE(review): t_prior is defined but the prior arguments are commented out
# below, so stan_glm uses its default priors.
post1 <- stan_glm(thisFormula, data = g, family = binomial(link = 'logit'), QR = TRUE, seed = 42, refresh = 0) # prior = t_prior, prior_intercept = t_prior,

# Posterior point estimates of the regression coefficients.
round(coef(post1), 2)
##             (Intercept)             creditLimit    merchantCategoryCode 
##                   -0.68                    0.03                   -0.34 
##             cardPresent theWeeHoursTransactions             topOffender 
##                   -0.14                   -0.04                    1.45
# 90% posterior credible intervals per coefficient.
round(posterior_interval(post1, prob = 0.9), 2)
##                            5%   95%
## (Intercept)             -0.76 -0.61
## creditLimit              0.01  0.06
## merchantCategoryCode    -0.37 -0.30
## cardPresent             -0.23 -0.06
## theWeeHoursTransactions -0.08  0.01
## topOffender              1.37  1.54
(loo1 <- loo(post1, save_psis = TRUE)) # smoothed leave-one-out cross-validation (psis-loo) to compute expected log predictive density (elpd)
## 
## Computed from 4000 by 23107 log-likelihood matrix
## 
##          Estimate   SE
## elpd_loo -15120.4 41.4
## p_loo         5.9  0.1
## looic     30240.8 82.7
## ------
## Monte Carlo SE of elpd_loo is 0.0.
## 
## All Pareto k estimates are good (k < 0.5).
## See help('pareto-k-diagnostic') for details.
post0 <- update(post1, formula = isFraud ~ 1, QR = FALSE, refresh=0) # compare to baseline
(loo0 <- loo(post0))
## 
## Computed from 4000 by 23107 log-likelihood matrix
## 
##          Estimate  SE
## elpd_loo -16017.5 0.3
## p_loo         1.0 0.0
## looic     32035.0 0.7
## ------
## Monte Carlo SE of elpd_loo is 0.0.
## 
## All Pareto k estimates are good (k < 0.5).
## See help('pareto-k-diagnostic') for details.
loo_compare(loo0, loo1) # `post1` clearly better, i.e. covariates contain valuable info for predictions
##       elpd_diff se_diff
## post1    0.0       0.0 
## post0 -897.1      41.4
# predicted probabilities
linpred <- posterior_linpred(post1)
preds <- posterior_epred(post1)
pred <- colMeans(preds) # mean posterior probability per observation
pr <- as.integer(pred >= 0.5) # hard classification at the 0.5 threshold
   
# posterior classification accuracy
# NOTE(review): `y` is a factor; `y == 0` matches against the '0' level, and
# the xor trick counts agreements between `pr` and the true class.
round(mean(xor(pr, as.integer(y == 0))), 2)
## [1] 0.63
# posterior balanced classification accuracy
round((mean(xor(pr[y == 0] > 0.5, as.integer(y[y == 0])))+mean(xor(pr[y == 1] < 0.5, as.integer(y[y == 1]))))/2, 2)
## [1] 0.63

Now, let’s predict on not yet seen data using leave-one-out crossval.

# loo predictive probabilities
# Expected predicted probability per observation under leave-one-out
# importance weights (PSIS), i.e. out-of-sample probabilities.
ploo <- E_loo(preds, loo1$psis_object, type = 'mean', log_ratios = -log_lik(post1))$value

# loo classification accuracy
round(mean(xor(ploo > 0.5, as.integer(y == 0))), 2)
## [1] 0.63
# loo balanced classification accuracy (averages per-class accuracies)
round((mean(xor(ploo[y == 0] > 0.5, as.integer(y[y == 0])))+mean(xor(ploo[y == 1] < 0.5,as.integer(y[y == 1]))))/2, 2)
## [1] 0.63

An important consideration: on the face of it, it seems that this performed worse than our plain vanilla model. But, (i) we tossed a lot of data to invert the matrix, which, given more time, we could easily fix by doing a better-balanced sample, and / or by hand-writing the function, and (ii) consider how much more robust (and therefore trustworthy) this is.

Importantly also, unlike regular ML, this is a far more explainable model. That is, not a black box. For robustness, modularity, explainability, regular Bayesian workflow, to my mind, bests all else. As to the “hand-writing the function” bit, we mean not doing a precompiled model from a package, but instead writing our own Stan program, and calling the compiler ourselves, which would more readily handle hierarchical models, and messier datasets.

Given more time, we could explore some deep learning model, and fix the issues with the sample to boost some Bayesian model.