Title: Evolutionary Learning of Globally Optimal Trees
Description: Commonly used classification and regression tree methods like the CART algorithm are recursive partitioning methods that build the model in a forward stepwise search. Although this approach is known to be an efficient heuristic, the results of recursive tree methods are only locally optimal, as splits are chosen to maximize homogeneity at the next step only. An alternative way to search over the parameter space of trees is to use global optimization methods like evolutionary algorithms. The 'evtree' package implements an evolutionary algorithm for learning globally optimal classification and regression trees in R. CPU- and memory-intensive tasks are fully computed in C++ while the 'partykit' package is leveraged to represent the resulting trees in R, providing unified infrastructure for summaries, visualizations, and predictions.
Authors: Thomas Grubinger [aut, cre], Achim Zeileis [aut], Karl-Peter Pfeiffer [aut]
Maintainer: Thomas Grubinger <[email protected]>
License: GPL-2 | GPL-3
Version: 1.0-8
Built: 2024-11-06 19:14:05 UTC
Source: https://github.com/r-forge/partykit
Marketing case study about a (fictitious) American book club to whose customers a book about “The Art History of Florence” was advertised.
data("BBBClub")
data("BBBClub")
A data frame containing 1,300 observations on 11 variables.
choice: factor. Did the customer buy the advertised book?
gender: factor indicating gender.
amount: total amount of money spent at the BBB Club.
freq: number of books purchased at the BBB Club.
last: number of months since the last purchase.
first: number of months since the first purchase.
child: number of children's books purchased.
youth: number of youth books purchased.
cook: number of cookbooks purchased.
diy: number of do-it-yourself books purchased.
art: number of art books purchased.
The data is a marketing case study about a (fictitious) American book club, taken from the Marketing Engineering textbook of Lilien and Rangaswamy (2004). In this case study, a brochure of the book “The Art History of Florence” was sent to 20,000 customers, 1,806 of whom bought the book. A subsample of 1,300 customers is provided in BBBClub for building a predictive model for choice.
The use of a cost matrix is suggested for this dataset. Classifying a customer who purchased the book as a non-buyer is worse (cost = 5) than classifying a customer who did not purchase the book as a buyer (cost = 1).
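Elsewhere in this documentation such costs are encoded via case weights (see the GermanCredit and StatlogHeart examples below); a minimal sketch along those lines, assuming the buyer level of choice is labeled "yes" (check levels(BBBClub$choice)):

## Not run:
## sketch: encode the suggested 5:1 costs as case weights
## (assumes the buyer level of `choice` is "yes")
data("BBBClub", package = "evtree")
bw <- rep(1L, nrow(BBBClub))
bw[BBBClub$choice == "yes"] <- 5L  # misclassifying a buyer is five times worse
suppressWarnings(RNGversion("3.5.0"))
set.seed(1090)
ev_w <- evtree(choice ~ ., data = BBBClub, weights = bw,
  minbucket = 10, maxdepth = 2)
table(predict(ev_w), BBBClub$choice)
## End(Not run)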
Complements to Lilien and Rangaswamy (2004).
Lilien GL, Rangaswamy A (2004). Marketing Engineering: Computer-Assisted Marketing Analysis and Planning, 2nd edition. Victoria, BC: Trafford Publishing.
## Not run:
## data, packages, random seed
data("BBBClub", package = "evtree")
library("rpart")
suppressWarnings(RNGversion("3.5.0"))
set.seed(1090)

## learn trees
ev <- evtree(choice ~ ., data = BBBClub, minbucket = 10, maxdepth = 2)
rp <- as.party(rpart(choice ~ ., data = BBBClub, minbucket = 10, model = TRUE))
ct <- ctree(choice ~ ., data = BBBClub, minbucket = 10, mincrit = 0.99)

## visualization
plot(ev)
plot(rp)
plot(ct)

## accuracy: misclassification rate
mc <- function(obj) 1 - mean(predict(obj) == BBBClub$choice)
c("evtree" = mc(ev), "rpart" = mc(rp), "ctree" = mc(ct))

## complexity: number of terminal nodes
c("evtree" = width(ev), "rpart" = width(rp), "ctree" = width(ct))

## compare structure of predictions
ftable(tab <- table(evtree = predict(ev), rpart = predict(rp),
  ctree = predict(ct), observed = BBBClub$choice))

## compare customer predictions only (absolute, proportion correct)
sapply(c("evtree", "rpart", "ctree"), function(nam) {
  mt <- margin.table(tab, c(match(nam, names(dimnames(tab))), 4))
  c(abs = as.vector(rowSums(mt))[2],
    rel = round(100 * prop.table(mt, 1)[2, 2], digits = 3))
})
## End(Not run)
Data on married women who were either not pregnant or did not know if they were pregnant at the time of interview. The task is to predict a woman's current contraceptive method choice (no use, long-term methods, short-term methods) based on her demographic and socio-economic characteristics.
data("ContraceptiveChoice")
data("ContraceptiveChoice")
A data frame containing 1,473 observations on 10 variables.
wife's age in years.
ordered factor indicating the wife's education, with levels "low", "medium-low", "medium-high", and "high".
ordered factor indicating the husband's education, with levels "low", "medium-low", "medium-high", and "high".
number of children.
binary variable indicating the wife's religion, with levels "non-Islam" and "Islam".
binary variable indicating if the wife is working.
ordered factor indicating the husband's occupation, with levels "low", "medium-low", "medium-high", and "high".
standard of living index, with levels "low", "medium-low", "medium-high", and "high".
binary variable indicating media exposure, with levels "good" and "not good".
factor variable indicating the contraceptive method used, with levels "no-use", "long-term", and "short-term".
This dataset is a subset of the 1987 National Indonesia Contraceptive Prevalence Survey and was created by Tjen-Sien Lim.
It has been taken from the UCI Repository of Machine Learning Databases at http://archive.ics.uci.edu/ml/.
Lim, T.-S., Loh, W.-Y. & Shih, Y.-S. (2000). A Comparison of Prediction Accuracy, Complexity, and Training Time of Thirty-Three Old and New Classification Algorithms. Machine Learning, 40(3), 203–228.
data("ContraceptiveChoice") summary(ContraceptiveChoice) ## Not run: suppressWarnings(RNGversion("3.5.0")) set.seed(1090) contt <- evtree(contraceptive_method_used ~ . , data = ContraceptiveChoice) contt table(predict(contt), ContraceptiveChoice$contraceptive_method_used) plot(contt) ## End(Not run)
data("ContraceptiveChoice") summary(ContraceptiveChoice) ## Not run: suppressWarnings(RNGversion("3.5.0")) set.seed(1090) contt <- evtree(contraceptive_method_used ~ . , data = ContraceptiveChoice) contt table(predict(contt), ContraceptiveChoice$contraceptive_method_used) plot(contt) ## End(Not run)
Learning of globally optimal classification and regression trees by using evolutionary algorithms.
evtree(formula, data, subset, na.action, weights, control = evtree.control(...), ...)
formula: a symbolic description of the model to be fit; no interactions should be used.
data, subset, na.action: arguments controlling formula processing via model.frame.
weights: optional integer vector of case weights.
control: a list of control arguments specified via evtree.control.
...: arguments passed to evtree.control.
Globally optimal classification and regression trees are learned by using an evolutionary algorithm. Roughly, the algorithm works as follows. First, a set of trees is initialized with random split rules in the root nodes. Second, mutation and crossover operators are applied to modify the trees' structure and the tests that are applied in the internal nodes. After each modification step a survivor selection mechanism selects the best candidate models for the next iteration. In this evolutionary process the mean quality of the population increases over time. The algorithm terminates when the quality of the best trees does not improve further, but not later than a maximum number of iterations specified by niterations in evtree.control.
More details on the algorithm are provided in Grubinger et al. (2014), which is also available as vignette("evtree", package = "evtree").
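For instance, a shorter run can be forced by lowering the iteration budget; a minimal sketch with purely illustrative (not recommended) values:

## sketch: bound the evolutionary search explicitly (values are illustrative)
suppressWarnings(RNGversion("3.5.0"))
set.seed(1090)
ev_short <- evtree(Species ~ ., data = iris,
  control = evtree.control(niterations = 200L, ntrees = 50L))
ev_short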
The resulting trees can be summarized and visualized by the print.constparty and plot.constparty methods provided by the partykit package. Moreover, the predict.party method can be used to compute fitted responses, probabilities (for classification trees), and nodes.
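A short sketch of these prediction types, using the iris fit from the examples below:

## sketch: prediction types provided by predict.party
## (assumes `ev_iris` fitted as in the examples below)
head(predict(ev_iris))                  # fitted responses (default)
head(predict(ev_iris, type = "prob"))   # class probabilities
head(predict(ev_iris, type = "node"))   # terminal node IDs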
An object of class party.
Grubinger T, Zeileis A, Pfeiffer KP (2014). evtree: Evolutionary Learning of Globally Optimal Classification and Regression Trees in R. Journal of Statistical Software, 61(1), 1-29. doi:10.18637/jss.v061.i01
## regression
suppressWarnings(RNGversion("3.5.0"))
set.seed(1090)
airq <- subset(airquality, !is.na(Ozone) & complete.cases(airquality))
ev_air <- evtree(Ozone ~ ., data = airq)
ev_air
plot(ev_air)
mean((airq$Ozone - predict(ev_air))^2)

## classification
## (note that different equivalent "perfect" splits for the setosa species
## in the iris data may be found on different architectures/systems)
ev_iris <- evtree(Species ~ ., data = iris)
## IGNORE_RDIFF_BEGIN
ev_iris
## IGNORE_RDIFF_END
plot(ev_iris)
table(predict(ev_iris), iris$Species)
1 - mean(predict(ev_iris) == iris$Species)
Various parameters that control aspects of the evtree fit.
evtree.control(minbucket = 7L, minsplit = 20L, maxdepth = 9L,
  niterations = 10000L, ntrees = 100L, alpha = 1,
  operatorprob = list(pmutatemajor = 0.2, pmutateminor = 0.2,
    pcrossover = 0.2, psplit = 0.2, pprune = 0.2),
  seed = NULL, ...)
minbucket: the minimum sum of weights in a terminal node.
minsplit: the minimum sum of weights in a node in order to be considered for splitting.
maxdepth: maximum depth of the tree. Note that the memory requirements increase by the square of the maximum tree depth.
niterations: in case the run does not converge, it terminates after a specified number of iterations defined by niterations.
ntrees: the number of trees in the population.
alpha: regulates the complexity part of the cost function. Increasing values of alpha encourage decreasing tree sizes.
operatorprob: list or vector of probabilities for the selection of variation operators. May also be specified partially, in which case the default values are still used for the unspecified arguments. Always scaled to sum to 100 percent.
seed: a numeric seed to initialize the random number generator (for reproducibility). By default the seed is randomly drawn using runif in order to inherit the state of .Random.seed.
...: additional arguments.
A list with the (potentially processed) control parameters.
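As an illustration of partial specification, a sketch with arbitrary values: only pcrossover is set, so the remaining operator probabilities keep their defaults and are rescaled together with it.

## sketch: partially specified control parameters (illustrative values)
ctrl <- evtree.control(minbucket = 10L, maxdepth = 4L, alpha = 2,
  operatorprob = list(pcrossover = 0.4))
str(ctrl)  # inspect the processed parameter list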
The dataset contains data on past credit applicants. The applicants are rated as good or bad credit risks. Models of this data can be used to determine whether new applicants present a good or bad credit risk.
data("GermanCredit")
data("GermanCredit")
A data frame containing 1,000 observations on 21 variables.
factor variable indicating the status of the existing checking account, with levels "... < 0 DM", "0 <= ... < 200 DM", "... >= 200 DM/salary for at least 1 year", and "no checking account".
duration in months.
factor variable indicating credit history, with levels "no credits taken/all credits paid back duly", "all credits at this bank paid back duly", "existing credits paid back duly till now", "delay in paying off in the past", and "critical account/other credits existing".
factor variable indicating the credit's purpose, with levels "car (new)", "car (used)", "furniture/equipment", "radio/television", "domestic appliances", "repairs", "education", "retraining", "business", and "others".
credit amount.
factor variable indicating savings account/bonds, with levels "... < 100 DM", "100 <= ... < 500 DM", "500 <= ... < 1000 DM", "... >= 1000 DM", and "unknown/no savings account".
ordered factor indicating the duration of the current employment, with levels "unemployed", "... < 1 year", "1 <= ... < 4 years", "4 <= ... < 7 years", and "... >= 7 years".
installment rate in percentage of disposable income.
factor variable indicating personal status and sex, with levels "male:divorced/separated", "female:divorced/separated/married", "male:single", "male:married/widowed", and "female:single".
factor variable indicating other debtors, with levels "none", "co-applicant", and "guarantor".
number of years at the present residence.
factor variable indicating the client's highest valued property, with levels "real estate", "building society savings agreement/life insurance", "car or other", and "unknown/no property".
client's age.
factor variable indicating other installment plans, with levels "bank", "stores", and "none".
factor variable indicating housing, with levels "rent", "own", and "for free".
number of existing credits at this bank.
factor indicating employment status, with levels "unemployed/unskilled - non-resident", "unskilled - resident", "skilled employee/official", and "management/self-employed/highly qualified employee/officer".
number of people being liable to provide maintenance.
binary variable indicating if the customer has a registered telephone number.
binary variable indicating if the customer is a foreign worker.
binary variable indicating credit risk, with levels "good" and "bad".
The use of a cost matrix is suggested for this dataset. It is worse to classify a customer as good when they are bad (cost = 5) than to classify a customer as bad when they are good (cost = 1).
The original data was provided by:
Professor Dr. Hans Hofmann, Institut fuer Statistik und Oekonometrie, Universitaet Hamburg, FB Wirtschaftswissenschaften, Von-Melle-Park 5, 2000 Hamburg 13
The dataset has been taken from the UCI Repository of Machine Learning Databases at http://archive.ics.uci.edu/ml/.
data("GermanCredit") summary(GermanCredit) ## Not run: gcw <- array(1, nrow(GermanCredit)) gcw[GermanCredit$credit_risk == "bad"] <- 5 suppressWarnings(RNGversion("3.5.0")) set.seed(1090) gct <- evtree(credit_risk ~ . , data = GermanCredit, weights = gcw) gct table(predict(gct), GermanCredit$credit_risk) plot(gct) ## End(Not run)
data("GermanCredit") summary(GermanCredit) ## Not run: gcw <- array(1, nrow(GermanCredit)) gcw[GermanCredit$credit_risk == "bad"] <- 5 suppressWarnings(RNGversion("3.5.0")) set.seed(1090) gct <- evtree(credit_risk ~ . , data = GermanCredit, weights = gcw) gct table(predict(gct), GermanCredit$credit_risk) plot(gct) ## End(Not run)
The data was generated to simulate registration of high energy gamma particles in a Major Atmospheric Gamma-Ray Imaging Cherenkov (MAGIC) Gamma Telescope. The task is to distinguish gamma rays (signal) from hadronic showers (background).
data("MAGICGammaTelescope")
data("MAGICGammaTelescope")
A data frame containing 19,020 observations on 11 variables.
fLength: major axis of ellipse [mm].
fWidth: minor axis of ellipse [mm].
fSize: 10-log of sum of content of all pixels [in #phot].
fConc: ratio of sum of two highest pixels over fSize [ratio].
fConc1: ratio of highest pixel over fSize [ratio].
fAsym: distance from highest pixel to center, projected onto major axis [mm].
fM3Long: 3rd root of third moment along major axis [mm].
fM3Trans: 3rd root of third moment along minor axis [mm].
fAlpha: angle of major axis with vector to origin [deg].
fDist: distance from origin to center of ellipse [mm].
class: binary variable, with levels "gamma" (signal) and "hadron" (background).
Classifying a background event as signal is worse than classifying a signal event as background. For a meaningful comparison of different classifiers the use of an ROC curve with thresholds 0.01, 0.02, 0.05, 0.1, 0.2 is suggested.
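In the UCI description these thresholds refer to the accepted background (false positive) rate. A base-R sketch of such an evaluation, assuming the fitted tree mgtt from the example below; note that a single tree yields only a few distinct probability values, so the resulting curve is coarse.

## Not run:
## sketch: true positive rate at fixed background-acceptance rates
## (assumes the fit `mgtt` from the example below)
p_gamma <- predict(mgtt, type = "prob")[, "gamma"]
is_gamma <- MAGICGammaTelescope$class == "gamma"
for (fpr in c(0.01, 0.02, 0.05, 0.1, 0.2)) {
  cutoff <- quantile(p_gamma[!is_gamma], probs = 1 - fpr)  # background quantile
  cat("FPR <=", fpr, "-> TPR:", mean(p_gamma[is_gamma] > cutoff), "\n")
}
## End(Not run)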
The original data was provided by:
R. K. Bock, Major Atmospheric Gamma Imaging Cherenkov Telescope project (MAGIC), rkb '@' mail.cern.ch, https://magic.mppmu.mpg.de/
and was donated by:
P. Savicky, Institute of Computer Science, AS of CR, Czech Republic, savicky '@' cs.cas.cz
The dataset has been taken from the UCI Repository of Machine Learning Databases at http://archive.ics.uci.edu/ml/.
Bock, R.K., Chilingarian, A., Gaug, M., Hakl, F., Hengstebeck, T., Jirina, M., Klaschka, J., Kotrc, E., Savicky, P., Towers, S., Vaicilius, A., Wittek W. (2004). Methods for Multidimensional event Classification: a Case Study Using Images From a Cherenkov Gamma-Ray Telescope. Nuclear Instruments and Methods in Physics Research Section A: Accelerators, Spectrometers, Detectors and Associated Equipment, 516(1), 511–528.
P. Savicky, E. Kotrc (2004). Experimental Study of Leaf Confidences for Random Forest. In Proceedings of COMPSTAT, pp. 1767–1774. Physica Verlag, Heidelberg, Germany.
J. Dvorak, P. Savicky (2007). Softening Splits in Decision Trees Using Simulated Annealing. In Proceedings of the 8th International Conference on Adaptive and Natural Computing Algorithms, Part I, pp. 721–729, Springer-Verlag, New-York.
data("MAGICGammaTelescope") summary(MAGICGammaTelescope) ## Not run: suppressWarnings(RNGversion("3.5.0")) set.seed(1090) mgtt <- evtree(class ~ . , data = MAGICGammaTelescope) mgtt table(predict(mgtt), MAGICGammaTelescope$class) plot(mgtt) ## End(Not run)
data("MAGICGammaTelescope") summary(MAGICGammaTelescope) ## Not run: suppressWarnings(RNGversion("3.5.0")) set.seed(1090) mgtt <- evtree(class ~ . , data = MAGICGammaTelescope) mgtt table(predict(mgtt), MAGICGammaTelescope$class) plot(mgtt) ## End(Not run)
Models of this data predict the absence or presence of heart disease.
data("StatlogHeart")
data("StatlogHeart")
A data frame containing 270 observations on 14 variables.
age in years.
binary variable indicating sex.
factor variable indicating the chest pain type, with levels "typical angina", "atypical angina", "non-anginal pain", and "asymptomatic".
resting blood pressure.
serum cholesterol in mg/dl.
binary variable indicating if fasting blood sugar > 120 mg/dl.
factor variable indicating resting electrocardiographic results, with levels "0" (normal), "1" (having ST-T wave abnormality: T wave inversions and/or ST elevation or depression of > 0.05 mV), and "2" (showing probable or definite left ventricular hypertrophy by Estes' criteria).
the maximum heart rate achieved.
binary variable indicating the presence of exercise-induced angina.
oldpeak = ST depression induced by exercise relative to rest.
ordered factor variable describing the slope of the peak exercise ST segment, with levels "upsloping", "flat", and "downsloping".
number of major vessels colored by fluoroscopy.
factor variable thal, with levels "normal", "fixed defect", and "reversible defect".
binary variable indicating the "presence" or "absence" of heart disease.
The use of a cost matrix is suggested for this dataset. It is worse to classify patients with heart disease as patients without heart disease (cost = 5) than to classify patients without heart disease as having heart disease (cost = 1).
The dataset has been taken from the UCI Repository of Machine Learning Databases at http://archive.ics.uci.edu/ml/.
data("StatlogHeart") summary(StatlogHeart) shw <- array(1, nrow(StatlogHeart)) shw[StatlogHeart$heart_disease == "presence"] <- 5 suppressWarnings(RNGversion("3.5.0")) set.seed(1090) sht <- evtree(heart_disease ~ . , data = StatlogHeart, weights = shw) sht table(predict(sht), StatlogHeart$heart_disease) plot(sht)
data("StatlogHeart") summary(StatlogHeart) shw <- array(1, nrow(StatlogHeart)) shw[StatlogHeart$heart_disease == "presence"] <- 5 suppressWarnings(RNGversion("3.5.0")) set.seed(1090) sht <- evtree(heart_disease ~ . , data = StatlogHeart, weights = shw) sht table(predict(sht), StatlogHeart$heart_disease) plot(sht)