Create SQL Rules from rpart model

[This article was first published on R (en) - Analytik dat, and kindly contributed to R-bloggers].

Mapping the output of an rpart tree to SQL statements is not easy. The rpart package only lets you print the rules, which you then have to translate into a SQL CASE statement by hand. Fortunately, we can write a new function to do this job.

To test the function, I will use the german_data dataset from the riv package, which is hosted on GitHub:

library(devtools)
# Current devtools versions expect "user/repo"; the separate username argument is no longer supported
install_github("tomasgreif/riv")
library(riv)

First we create some (rather naive) model:

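library(rpart)                    # provides rpart()
x <- german_data
x$gbbin <- NULL                   # drop gbbin so it is not used as a predictor
model <- rpart(gb ~ ., data = x)  # predict gb from all remaining columns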

The resulting model has a lot of leaves (the 12 nodes marked with * below):

  1) root 1000 300 good (0.7000000 0.3000000)  
    2) ca_status=A13,A14 457  60 good (0.8687090 0.1312910) *
    3) ca_status=A11,A12 543 240 good (0.5580110 0.4419890)  
      6) mob< 22.5 306 106 good (0.6535948 0.3464052)  
       12) credit_history=A32,A33,A34 278  85 good (0.6942446 0.3057554)  
         24) credit_amount< 7491.5 271  79 good (0.7084871 0.2915129)  
           48) purpose=A40,A41,A410,A42,A43,A45,A48,A49 256  69 good (0.7304688 0.2695312)  
             96) mob< 11.5 73   9 good (0.8767123 0.1232877) *
             97) mob>=11.5 183  60 good (0.6721311 0.3278689)  
              194) credit_amount>=1387.5 118  29 good (0.7542373 0.2457627) *
              195) credit_amount< 1387.5 65  31 good (0.5230769 0.4769231)  
                390) property=A121,A122 45  14 good (0.6888889 0.3111111) *
                391) property=A123,A124 20   3 bad (0.1500000 0.8500000) *
           49) purpose=A44,A46 15   5 bad (0.3333333 0.6666667) *
         25) credit_amount>=7491.5 7   1 bad (0.1428571 0.8571429) *
       13) credit_history=A30,A31 28   7 bad (0.2500000 0.7500000) *
      7) mob>=22.5 237 103 bad (0.4345992 0.5654008)  
       14) savings=A64,A65 41  12 good (0.7073171 0.2926829) *
       15) savings=A61,A62,A63 196  74 bad (0.3775510 0.6224490)  
         30) mob< 47.5 160  69 bad (0.4312500 0.5687500)  
           60) purpose=A41 23   6 good (0.7391304 0.2608696) *
           61) purpose=A40,A410,A42,A43,A45,A46,A49 137  52 bad (0.3795620 0.6204380) *
         31) mob>=47.5 36   5 bad (0.1388889 0.8611111) *
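
The number at the start of each line is rpart's node id, and the lines ending in * are the leaves. As a minimal sketch of how to recover the leaf that each training row falls into, the example below uses the kyphosis data that ships with rpart. That dataset is an assumption made here so the code runs without the riv package; the same approach should apply to the model above.

```r
library(rpart)

# kyphosis ships with rpart; it stands in for german_data here
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# fit$where holds, for each training row, the row index in fit$frame
# of the leaf it ends up in; the row names of fit$frame are the node ids
leaf_id <- rownames(fit$frame)[fit$where]

# Number of training rows per leaf
print(table(leaf_id))
```

These leaf ids are a convenient reference when checking node labels produced by any other scoring route, such as a SQL CASE expression.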

Now we can call the parse_tree function:

parse_tree(x,model)

And we get the following result (line breaks added for readability):

case
  when ca_status in ('A13','A14') then 'node_2'
  when ca_status in ('A11','A12') AND mob <  22.5 AND credit_history in ('A32','A33','A34') AND credit_amount <  7492 AND purpose in ('A40','A41','A410','A42','A43','A45','A48','A49') AND mob <  11.5 then 'node_96'
  when ca_status in ('A11','A12') AND mob <  22.5 AND credit_history in ('A32','A33','A34') AND credit_amount <  7492 AND purpose in ('A40','A41','A410','A42','A43','A45','A48','A49') AND mob >= 11.5 AND credit_amount >= 1388 then 'node_194'
  when ca_status in ('A11','A12') AND mob <  22.5 AND credit_history in ('A32','A33','A34') AND credit_amount <  7492 AND purpose in ('A40','A41','A410','A42','A43','A45','A48','A49') AND mob >= 11.5 AND credit_amount <  1388 AND property in ('A121','A122') then 'node_390'
  when ca_status in ('A11','A12') AND mob <  22.5 AND credit_history in ('A32','A33','A34') AND credit_amount <  7492 AND purpose in ('A40','A41','A410','A42','A43','A45','A48','A49') AND mob >= 11.5 AND credit_amount <  1388 AND property in ('A123','A124') then 'node_391'
  when ca_status in ('A11','A12') AND mob <  22.5 AND credit_history in ('A32','A33','A34') AND credit_amount <  7492 AND purpose in ('A44','A46') then 'node_49'
  when ca_status in ('A11','A12') AND mob <  22.5 AND credit_history in ('A32','A33','A34') AND credit_amount >= 7492 then 'node_25'
  when ca_status in ('A11','A12') AND mob <  22.5 AND credit_history in ('A30','A31') then 'node_13'
  when ca_status in ('A11','A12') AND mob >= 22.5 AND savings in ('A64','A65') then 'node_14'
  when ca_status in ('A11','A12') AND mob >= 22.5 AND savings in ('A61','A62','A63') AND mob <  47.5 AND purpose in ('A41') then 'node_60'
  when ca_status in ('A11','A12') AND mob >= 22.5 AND savings in ('A61','A62','A63') AND mob <  47.5 AND purpose in ('A40','A410','A42','A43','A45','A46','A49') then 'node_61'
  when ca_status in ('A11','A12') AND mob >= 22.5 AND savings in ('A61','A62','A63') AND mob >= 47.5 then 'node_31'
end

This is valid SQL that can be used in most database engines (I’m using this in SQLite and PostgreSQL).
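
To show where such an expression fits, here is a small sketch that embeds a CASE expression in a SELECT statement and, if the DBI and RSQLite packages happen to be installed, runs it against an in-memory SQLite table. The table name applications, the toy rows, and the shortened CASE expression with its else branch are all assumptions made for illustration; in practice you would paste in the full expression.

```r
# Abbreviated stand-in for the full CASE expression
# (the else branch is added only for this toy example)
case_sql <- "case when ca_status in ('A13','A14') then 'node_2' else 'other' end"

# "applications" is a hypothetical table name
query <- sprintf("SELECT ca_status, %s AS node FROM applications", case_sql)
cat(query, "\n")

# Optional: run the query in SQLite when DBI and RSQLite are available
if (requireNamespace("DBI", quietly = TRUE) &&
    requireNamespace("RSQLite", quietly = TRUE)) {
  con <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
  DBI::dbWriteTable(con, "applications",
                    data.frame(ca_status = c("A13", "A11"),
                               stringsAsFactors = FALSE))
  print(DBI::dbGetQuery(con, query))
  DBI::dbDisconnect(con)
}
```

Keeping the CASE expression in a single string makes it easy to reuse in a view or an UPDATE statement as well.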

The parse_tree function takes two arguments: a data frame and a model. The variables used in the model must exist in the data frame and have the same types. You can find the parse_tree function on GitHub. Let me know if this works for you.
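
For readers who want to see the idea without opening the repository, here is a minimal sketch of how such a conversion could work. It is not the parse_tree implementation: the helper names rule_to_sql and tree_to_case are hypothetical, it takes only the model as input, and it relies on rpart's path.rpart to list the split rules leading to each leaf. It assumes factor levels contain no commas and no comparison operators, and that the tree has at least one split. Note that path.rpart formats split points with a limited number of significant digits, so thresholds may appear rounded.

```r
library(rpart)

# Hypothetical helper: turn one rpart rule label into a SQL condition
rule_to_sql <- function(rule) {
  if (grepl("<|>=", rule)) {
    # numeric split, e.g. "mob< 22.5" or "mob>=22.5"
    sub("^(.*?)(>=|<)\\s*(.*)$", "\\1 \\2 \\3", rule, perl = TRUE)
  } else {
    # factor split, e.g. "ca_status=A13,A14"
    var <- sub("=.*$", "", rule)
    lev <- strsplit(sub("^[^=]*=", "", rule), ",", fixed = TRUE)[[1]]
    lev <- gsub("'", "''", lev, fixed = TRUE)  # escape single quotes for SQL
    paste0(var, " in (", paste0("'", lev, "'", collapse = ","), ")")
  }
}

# Hypothetical helper: build one CASE expression with a branch per leaf
tree_to_case <- function(model) {
  frame <- model$frame
  leaves <- as.numeric(rownames(frame)[frame$var == "<leaf>"])
  # One character vector of rules per leaf; the first element is "root"
  paths <- path.rpart(model, nodes = leaves, print.it = FALSE)
  whens <- vapply(names(paths), function(node) {
    rules <- paths[[node]][-1]  # drop "root"
    cond <- paste(vapply(rules, rule_to_sql, character(1)),
                  collapse = " AND ")
    sprintf("when %s then 'node_%s'", cond, node)
  }, character(1))
  paste("case", paste(whens, collapse = " "), "end")
}

# Usage on the kyphosis data that ships with rpart
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
sql <- tree_to_case(fit)
cat(sql, "\n")
```

Parsing the printed labels keeps the sketch short, but it is fragile; a more robust version would read the split variables and thresholds from the model object directly.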
