Decision Tree Algorithm in Spark SQL
"""
Decision Tree Algorithm in Spark SQL
Written by JP Vijaykumar
Date: Sep 14 2020
This script is provided for educational purposes only.
Please modify the script as required to suit your environment.
I presented a script to process the decision tree algorithm using PL/SQL earlier.
I like Spark SQL for the following reasons:
01) It is open source.
02) It combines the rich functionality of Python and SQL.
03) It has data mining libraries.
04) It can be installed on my desktop to play around with.
Besides, I love scripting and complex algorithms.
I used the following URLs to install and set up Spark on my desktop.
https://www.youtube.com/watch?v=IQfG0faDrzE
https://www.youtube.com/watch?v=WQErwxRTiW0
http://media.sundog-soft.com/spark-python-install.pdf
I saved the following data in a CSV file and processed it using the decision tree algorithm.
outlook,temperature,humidity,wind,playball
sunny,hot,high,weak,no
sunny,hot,high,strong,no
overcast,hot,high,weak,yes
rain,mild,high,weak,yes
rain,cool,normal,weak,yes
rain,cool,normal,strong,no
overcast,cool,normal,strong,yes
sunny,mild,high,weak,no
sunny,cool,normal,weak,yes
rain,mild,normal,weak,yes
sunny,mild,normal,strong,yes
overcast,mild,high,strong,yes
overcast,hot,normal,weak,yes
rain,mild,high,strong,no
This script takes the last column of the csv file as CLASSIFIER and the remaining columns as CANDIDATE columns
to identify the ROOT NODE and split further.
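As a rough cross-check of the root-node selection described above, here is a minimal plain-Python sketch (no Spark required) that treats the last column as the classifier and picks the candidate column with the highest information gain. The names ROWS, records, entropy and gain are illustrative assumptions, not part of the script below. Because the script rounds intermediate values to 4 decimals, its printed gain can differ from this sketch in the last decimal place.

```python
import math
from collections import Counter

# The 14-row sample from this post, inlined so the sketch needs no CSV file.
HEADER = "outlook,temperature,humidity,wind,playball".split(",")
ROWS = [
    "sunny,hot,high,weak,no", "sunny,hot,high,strong,no",
    "overcast,hot,high,weak,yes", "rain,mild,high,weak,yes",
    "rain,cool,normal,weak,yes", "rain,cool,normal,strong,no",
    "overcast,cool,normal,strong,yes", "sunny,mild,high,weak,no",
    "sunny,cool,normal,weak,yes", "rain,mild,normal,weak,yes",
    "sunny,mild,normal,strong,yes", "overcast,mild,high,strong,yes",
    "overcast,hot,normal,weak,yes", "rain,mild,high,strong,no",
]
records = [dict(zip(HEADER, row.split(","))) for row in ROWS]

classifier = HEADER[-1]   # last column is the classifier
candidates = HEADER[:-1]  # remaining columns are the candidates


def entropy(labels):
    # Shannon entropy (base 2) of a list of class labels.
    total = len(labels)
    return -sum((n / total) * math.log2(n / total)
                for n in Counter(labels).values())


def gain(rows, column, target):
    # Information gain obtained by splitting rows on column.
    remainder = 0.0
    for value in set(r[column] for r in rows):
        subset = [r[target] for r in rows if r[column] == value]
        remainder += len(subset) / len(rows) * entropy(subset)
    return entropy([r[target] for r in rows]) - remainder


gains = {c: gain(records, c, classifier) for c in candidates}
root = max(gains, key=gains.get)  # column with the highest gain

print(round(entropy([r[classifier] for r in records]), 4))
print(root, gains[root])
```

On this sample the sketch selects outlook as the root node, which agrees with the script's output further down.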
This script cannot be run as it is.
It depends on the location of the python-spark-tutorial folder and the data folder.
Also, the Spark environment variables must be set properly.
Once Spark is installed and set up properly (as per the installation and setup URLs referenced above),
the script can be run.
If you already have Spark installed on your machine, please modify the CSV file location, cd to the python-spark-tutorial folder, save the
script under the python-spark-tutorial/python folder with an appropriate name, and run the script.
The output of this script is lengthy and detailed. It could be made more concise, but for information and educational purposes it
was left as it is.
For most of the SQL commands, I reused the script from my previous post.
Wherever I had used PL/SQL programming, I switched to Python scripting.
A few SQL functions are not supported in Spark SQL, so I used equivalent Spark functionality to accomplish the purpose.
There is further scope to tune the code and improve the formatting.
I was just curious to convert my PL/SQL code for the Decision Tree Algorithm into Spark SQL.
I processed the data from the CSV file and got the output rules.
This script can also be run interactively in a Jupyter notebook.
Version compatibility issues are plentiful, compounded by the unavailability of some libraries. Finding workarounds to meet the required
functionality is tricky and daunting.
Happy scripting.
"""
#################################SCRIPT STARTS HERE
#cd /c/python-spark-tutorial
#spark-submit.cmd python/pysparkDecisionTreeEntropy.py #to run the script from gitbash command prompt
from pyspark import SparkContext, SparkConf
conf = SparkConf().setAppName("DecisionTree").setMaster("local[*]")
sc = SparkContext(conf=conf)
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("DecisionTree").master("local[*]").getOrCreate()
spark.conf.set("spark.sql.crossJoin.enabled", "true") #To enable cartesian products in Spark SQL
from pyspark.sql.types import IntegerType,StringType
from pyspark.sql import functions as F
from pyspark.sql.window import Window
#import numpy as np
df = spark.read.format("csv").option("inferSchema","true").option("header","true").option("sep",",").load("e:/data/decisiontree.csv")
print(df.printSchema())
print(df.show())
column_list = df.columns
print(column_list)
df.createOrReplaceTempView("dtree")
CLASSIFIER = df.columns[-1]
print(CLASSIFIER)
column_list.remove( CLASSIFIER) #remove classifier from columns' list
print(column_list)
#capture row count of DataFrame
spark.sql("select cast(count(*) as float) val from dtree").createOrReplaceTempView("rowcount") #TABLE rowcount
rowCount = spark.sql("select * from rowcount")
print(type(rowCount))
rowCount.show()
print(rowCount.collect())
rowCount.foreach(print)
print(rowCount.collect()[0][0])
numRows = df.count()
exitLimit = 0
print(rowCount)
print(rowCount.first())
spark.sql(" select "+ CLASSIFIER +",cast(count(*) as float) val from dtree group by "+ CLASSIFIER +" ").createOrReplaceTempView("clscount") #TABLE clscount
#calculate entropy on the DataFrame
QUERY=" with clscount as ( select clscount.val/rowcount.val val from clscount,rowcount )"
QUERY=QUERY + "select round((-1)*(sum((val)*(log(2,(val))))),4) val from clscount "
spark.sql(QUERY).createOrReplaceTempView("entropy") #TABLE entropy
spark.sql("select * from entropy").show()
MGAIN = float(0)
for COL in column_list:
    print(COL)
    ###Assign multi-line values to a variable and execute through spark.sql
    QUERY="with t1 as (select "+ COL +" col ,cast(count(*) as float) val from dtree group by "+ COL +"), "
    QUERY=QUERY + " t2 as (select "+ COL +" col,"+ CLASSIFIER +" cls,cast(count(*) as float) val from dtree group by "+ COL +" ,"+ CLASSIFIER + " ), "
    QUERY=QUERY + " t3 as (select t1.col,t1.val val1 ,t2.cls,t2.val val2 from t1,t2 where t1.col = t2.col ), "
    QUERY=QUERY + " t4 as (select col,round((t3.val1/rowcount.val)*(sum((-1)*(t3.val2/t3.val1)*(log(2,(val2/val1))))),4) sum_val from t3,rowcount "
    QUERY=QUERY + " group by t3.col,t3.val1,t3.val2,rowcount.val) "
    QUERY=QUERY + " select cast((entropy.val - sum(t4.sum_val)) as float) gain from t4,entropy group by entropy.val"
    spark.sql(QUERY).show()
    spark.sql(QUERY).filter(F.col("gain") >= MGAIN).show()
    GAIN = spark.sql(QUERY).collect()[0][0]
    print(GAIN)
    #keep the column with the highest gain seen so far as the root node
    if GAIN > MGAIN:
        MGAIN = GAIN
        ROOTNODE = COL
        print(MGAIN,ROOTNODE)
print(MGAIN,ROOTNODE)
column_list.remove( ROOTNODE) #remove rootnode from columns' list
print(column_list)
QUERY="select t1.rnode,t1.val rn_val,t2.cls,t2.val cls_val from "
QUERY=QUERY + "( select "+ ROOTNODE +" rnode,count(*) val from dtree group by "+ ROOTNODE +") t1, "
QUERY=QUERY + "( select "+ ROOTNODE +" rnode,"+ CLASSIFIER +" cls,count(*) val from dtree "
QUERY=QUERY + " group by "+ ROOTNODE +","+ CLASSIFIER +") t2 where t1.rnode = t2.rnode "
spark.sql(QUERY).createOrReplaceTempView("rtree")
spark.sql("select * from rtree").show()
for CVAL in spark.sql("select * from rtree where rn_val = cls_val").collect():
    print(ROOTNODE,"->",CVAL[0],"(",CVAL[1],")",CLASSIFIER,"->",CVAL[2],"(",CVAL[3],")")
    PVAL=CVAL[0]
    numRows -=CVAL[1] #reduce numRows with classified rows' count
    if numRows > exitLimit :
        QUERY=QUERY + " and t1.rnode <>'"+ PVAL +"' "
spark.sql(QUERY).createOrReplaceTempView("rtree")
spark.sql("select * from rtree order by rnode").show()
QUERY="with t1 as (select rnode, cast(cls_val/rn_val as float) val from rtree), "
QUERY=QUERY + " t2 as ( select rnode ,round((-1)*(sum((val)*(log(2,(val))))),4) gain from t1 group by rnode),"
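#NOTE: the join condition below compares t2.rnode with itself, so t2 and t3 are effectively cross joined; t2.rnode = t3.rnode was probably intended
QUERY=QUERY + " t3 as (select distinct rnode,rn_val from rtree) select t2.rnode,t3.rn_val,t2.gain from t2,t3 where t2.rnode = t2.rnode"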
spark.sql(QUERY).createOrReplaceTempView("gtree") #gain table
spark.sql("select * from gtree").show()
LNODESPLIT = []
for COL in column_list:
    for RCOL in spark.sql("select distinct * from gtree ").collect():
        print(ROOTNODE,"->",RCOL[0]," : ",COL)
        QUERY="select "+ROOTNODE+" rnode,"+COL+" col,"+CLASSIFIER+" cls,count(*) val from dtree "
        QUERY=QUERY +" where 1=1 and "+ROOTNODE+"='"+RCOL[0]+"' "
        QUERY=QUERY + " group by "+ROOTNODE+","+COL+","+CLASSIFIER+" order by 1,2,3"
        spark.sql(QUERY).createOrReplaceTempView("ctree")
        spark.sql("select * from ctree").show()
        QUERY="with t1 as (select rnode, col,count(*) val from ctree group by rnode,col) "
        QUERY=QUERY + ", t2 as (select c.rnode,c.col,cls,round((-1)*(sum((c.val/t1.val)*(log(2,(c.val/t1.val))))),4) val from ctree c,t1 "
        QUERY=QUERY + " where 1=1 and c.rnode = t1.rnode and c.col=t1.col group by c.rnode,c.col,c.cls,t1.val) "
        QUERY=QUERY + ", t3 as (select rnode,col,sum(val) as val from t2 group by rnode,col)"
        QUERY=QUERY + ", t4 as (select t1.rnode,t1.col,(t1.val/g.rn_val)*(t3.val) val,(g.gain - (t1.val/g.rn_val)*(t3.val)) gain from t1,t3,gtree g where 1 = 1 "
        QUERY=QUERY + " and t1.rnode=t3.rnode and t1.col = t3.col and t3.rnode=g.rnode order by 1,2) "
        QUERY=QUERY + " select rnode ,sum(gain) gain from t4 group by rnode "
        for lval in spark.sql(QUERY).collect():
            print(lval[1],ROOTNODE,"->",lval[0]," : ",COL)
            LNODESPLIT.append([lval[1],ROOTNODE,lval[0],COL]) #append values to a list
#supply the list of columns' names enclosed in "[]"
lf = sc.parallelize(LNODESPLIT).toDF(["gain","rootnode","rval","lnode"]) #convert list into DataFrame
lf.select("*").orderBy(F.desc("gain")).show() #sort orderBy desc
for LVAL in lf.select("*").orderBy(F.desc("gain")).collect():
    if numRows > exitLimit :
        print(LVAL[0])
        QUERY="select t1.rnode,t1.lnode,t1.val rn_val,t2.cls,t2.val cls_val from "
        QUERY=QUERY + "( select "+ ROOTNODE +" rnode,"+LVAL[3]+" lnode,count(*) val from dtree "
        QUERY=QUERY + " where "+ROOTNODE+"='"+LVAL[2]+"' group by "+ ROOTNODE +","+LVAL[3]+" ) t1, "
        QUERY=QUERY + "( select "+ ROOTNODE +" rnode,"+LVAL[3]+" lnode, "+ CLASSIFIER +" cls,count(*) val from dtree "
        QUERY=QUERY + " where "+ROOTNODE+"='"+LVAL[2]+"' group by "+ ROOTNODE +","+LVAL[3]+","+ CLASSIFIER +") t2 "
        QUERY=QUERY + " where t1.rnode = t2.rnode and t1.lnode = t2.lnode order by 1,2"
        spark.sql(QUERY).createOrReplaceTempView("ltree")
        spark.sql("select * from ltree").show()
        for LLVAL in spark.sql("select * from ltree where rn_val = cls_val").collect():
            if numRows > exitLimit :
                print(ROOTNODE,"->",LLVAL[0]," : ",LVAL[3],"->",LLVAL[1],"(",LLVAL[2],") : ",CLASSIFIER,"->",LLVAL[3],"(",LLVAL[4],")")
                numRows -=LLVAL[2] #reduce numRows with classified rows' count
spark.stop()
#################################SCRIPT ENDS HERE
"""
References:
http://www.orafaq.com/node/3163
https://www.coursehero.com/file/17335804/Tutorial02/
OUTPUT FROM THE SCRIPT'S EXECUTION (some parts omitted, as the original output is lengthy):
root
|-- outlook: string (nullable = true)
|-- temperature: string (nullable = true)
|-- humidity: string (nullable = true)
|-- wind: string (nullable = true)
|-- playball: string (nullable = true)
None
+--------+-----------+--------+------+--------+
| outlook|temperature|humidity|  wind|playball|
+--------+-----------+--------+------+--------+
|   sunny|        hot|    high|  weak|      no|
|   sunny|        hot|    high|strong|      no|
|overcast|        hot|    high|  weak|     yes|
|    rain|       mild|    high|  weak|     yes|
|    rain|       cool|  normal|  weak|     yes|
|    rain|       cool|  normal|strong|      no|
|overcast|       cool|  normal|strong|     yes|
|   sunny|       mild|    high|  weak|      no|
|   sunny|       cool|  normal|  weak|     yes|
|    rain|       mild|  normal|  weak|     yes|
|   sunny|       mild|  normal|strong|     yes|
|overcast|       mild|    high|strong|     yes|
|overcast|        hot|  normal|  weak|     yes|
|    rain|       mild|    high|strong|      no|
+--------+-----------+--------+------+--------+
....
....
....
+------+
|   val|
+------+
|0.9403|
+------+
....
....
....
0.24690000712871552 outlook
....
....
....
outlook -> overcast ( 4 ) playball -> yes ( 4 )
....
....
....
+-----+------+------+---+-------+
|rnode| lnode|rn_val|cls|cls_val|
+-----+------+------+---+-------+
|sunny|  high|     3| no|      3|
|sunny|normal|     2|yes|      2|
+-----+------+------+---+-------+
outlook -> sunny : humidity -> high ( 3 ) : playball -> no ( 3 )
outlook -> sunny : humidity -> normal ( 2 ) : playball -> yes ( 2 )
6.58596
+-----+------+------+---+-------+
|rnode| lnode|rn_val|cls|cls_val|
+-----+------+------+---+-------+
| rain|strong|     2| no|      2|
| rain|  weak|     3|yes|      3|
+-----+------+------+---+-------+
outlook -> rain : wind -> strong ( 2 ) : playball -> no ( 2 )
outlook -> rain : wind -> weak ( 3 ) : playball -> yes ( 3 )
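The rules printed above can be checked against the input data. The following is a minimal plain-Python sketch, assuming the five rules are encoded by hand in a nested structure; RULES and classify are hypothetical names used only for illustration and are not produced by the script.

```python
# The 14-row sample from this post, inlined so the sketch needs no CSV file.
HEADER = "outlook,temperature,humidity,wind,playball".split(",")
ROWS = [
    "sunny,hot,high,weak,no", "sunny,hot,high,strong,no",
    "overcast,hot,high,weak,yes", "rain,mild,high,weak,yes",
    "rain,cool,normal,weak,yes", "rain,cool,normal,strong,no",
    "overcast,cool,normal,strong,yes", "sunny,mild,high,weak,no",
    "sunny,cool,normal,weak,yes", "rain,mild,normal,weak,yes",
    "sunny,mild,normal,strong,yes", "overcast,mild,high,strong,yes",
    "overcast,hot,normal,weak,yes", "rain,mild,high,strong,no",
]
records = [dict(zip(HEADER, row.split(","))) for row in ROWS]

# Hand-written encoding of the printed rules (an assumption of this sketch):
# a string is a final class, a tuple is (column to test next, value -> class).
RULES = {
    "overcast": "yes",
    "sunny": ("humidity", {"high": "no", "normal": "yes"}),
    "rain": ("wind", {"strong": "no", "weak": "yes"}),
}


def classify(row):
    # Follow the outlook branch, then the second-level column if there is one.
    branch = RULES[row["outlook"]]
    if isinstance(branch, str):
        return branch
    column, leaves = branch
    return leaves[row[column]]


# Count how many input rows the rules classify correctly.
correct = sum(classify(r) == r["playball"] for r in records)
print(correct, "of", len(records), "rows classified correctly")
```

On this sample every row is covered by exactly one rule, which is why numRows reaches the exit limit after the second-level split.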
"""