-
Notifications
You must be signed in to change notification settings - Fork 0
/
InfoGain.java
127 lines (125 loc) · 3.44 KB
/
InfoGain.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package ml_lab;
import java.util.*;
import java.sql.*;
/**
 * Console program that reads a table from a MySQL database, computes the
 * entropy of a two-valued target column, and reports the information gain of
 * every other column, printing the column with the highest gain as the root.
 *
 * <p>Rows whose target value equals the literal {@code 'yes'} are treated as
 * the positive class by {@link #gain}; all other rows count as negative.
 */
public class InfoGain
{
    /** Database (schema) name entered by the user in {@link #main}. */
    static String db;
    /** Table name entered by the user in {@link #main}; read by {@link #gain}. */
    static String tab;

    /**
     * Computes the binary entropy (base 2) of a set split into two classes.
     *
     * @param Sn number of rows in one class
     * @param Sp number of rows in the other class
     * @return 0 when either class is empty (including an empty set), 1 when
     *         both classes are the same non-zero size, otherwise
     *         {@code -a*log2(a) - b*log2(b)} with {@code a = Sn/(Sn+Sp)} and
     *         {@code b = Sp/(Sn+Sp)}
     */
    public double Entropy(double Sn, double Sp)
    {
        // Zero check comes first so that an empty set (0, 0) yields 0 rather than 1.
        if (Sn == 0 || Sp == 0)
        {
            return 0;
        }
        if (Sn == Sp)
        {
            return 1;
        }
        double total = Sn + Sp;
        double a = Sn / total;
        double b = Sp / total;
        return -a * log2(a) - b * log2(b);
    }

    /** Returns the base-2 logarithm of {@code x}. */
    private static double log2(double x)
    {
        return Math.log(x) / Math.log(2);
    }

    /**
     * Computes the information gain obtained by splitting the table {@link #tab}
     * on one attribute.
     *
     * <p>A single grouped query returns, for every distinct attribute value, the
     * number of rows and the number of rows whose target equals {@code 'yes'}.
     * Reading both counts from the same row keeps them aligned even when some
     * attribute value has no positive rows, and the weights are derived from the
     * actual row total instead of a fixed table size.
     *
     * @param E      entropy of the whole table with respect to the target
     * @param att    name of the attribute (column) to split on
     * @param st     open statement used to run the query; not closed here
     * @param target name of the target column
     * @return {@code E} minus the size-weighted entropy of the partitions;
     *         {@code E} itself when the table has no rows
     * @throws SQLException if the query fails
     */
    public double gain(double E, String att, Statement st, String target) throws SQLException
    {
        // Column and table names cannot be bound as parameters; callers must
        // pass trusted identifiers (main validates them against the schema).
        String que = "Select " + att + " ,COUNT(*),SUM(CASE WHEN " + target
                + "='yes' THEN 1 ELSE 0 END) from " + tab + " GROUP BY " + att;
        double total = 0;
        double weighted = 0;
        try (ResultSet rs = st.executeQuery(que))
        {
            while (rs.next())
            {
                double size = rs.getDouble(2);
                double positives = rs.getDouble(3);
                total += size;
                weighted += size * Entropy(positives, size - positives);
            }
        }
        if (total == 0)
        {
            return E;
        }
        return E - weighted / total;
    }

    /**
     * Looks up a column by name, ignoring case.
     *
     * @return the name as spelled in {@code columns}, or {@code null} if absent
     */
    private static String findColumn(List<String> columns, String name)
    {
        for (String column : columns)
        {
            if (column.equalsIgnoreCase(name))
            {
                return column;
            }
        }
        return null;
    }

    /**
     * Prompts for a database, table and target column, then prints the table
     * entropy, the information gain of each non-target column and the column
     * with the highest gain.
     *
     * @param args unused
     */
    public static void main(String args[])
    {
        Scanner obj = new Scanner(System.in);
        InfoGain in = new InfoGain();
        System.out.println("Java Program for Information Gain...");
        System.out.println("Enter the database name..");
        db = obj.nextLine();
        System.out.println("Enter the table name..");
        tab = obj.nextLine();
        try
        {
            Class.forName("com.mysql.jdbc.Driver");
            try (Connection conn = DriverManager.getConnection(
                         "jdbc:mysql://localhost:3306/" + db + "?autoReconnect=true&useSSL=false", "root", "");
                 Statement st = conn.createStatement())
            {
                // Fetch the column list once, bound to both table and schema so a
                // same-named table in another schema cannot leak columns in.
                List<String> columns = new ArrayList<String>();
                try (PreparedStatement ps = conn.prepareStatement(
                        "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS"
                        + " WHERE TABLE_NAME = ? and TABLE_SCHEMA = ? ORDER BY ORDINAL_POSITION"))
                {
                    ps.setString(1, tab);
                    ps.setString(2, db);
                    try (ResultSet rs = ps.executeQuery())
                    {
                        while (rs.next())
                        {
                            columns.add(rs.getString(1));
                        }
                    }
                }
                System.out.println("The Attributes are:");
                for (String column : columns)
                {
                    System.out.print(column + " ");
                }
                System.out.print("\n");
                if (columns.isEmpty())
                {
                    System.out.println("No columns found for table " + tab + " in database " + db);
                    return;
                }

                System.out.println("Enter the target attribute");
                String target = findColumn(columns, obj.nextLine().trim());
                if (target == null)
                {
                    // Also prevents arbitrary text from being spliced into the queries below.
                    System.out.println("Target attribute is not a column of " + tab);
                    return;
                }

                List<Double> classCounts = new ArrayList<Double>();
                try (ResultSet rs = st.executeQuery(
                        "SELECT " + target + " , COUNT(*) FROM " + tab + " GROUP BY " + target))
                {
                    while (rs.next())
                    {
                        classCounts.add(rs.getDouble(2));
                    }
                }
                if (classCounts.size() > 2)
                {
                    System.out.println("Target attribute must have at most two distinct values, found "
                            + classCounts.size());
                    return;
                }
                double first = classCounts.size() > 0 ? classCounts.get(0) : 0;
                double second = classCounts.size() > 1 ? classCounts.get(1) : 0;
                double E = in.Entropy(first, second);
                System.out.println("Entropy :" + E);

                String root = null;
                // Start below any possible gain so a root is chosen even when every gain is 0.
                double max = Double.NEGATIVE_INFINITY;
                for (String column : columns)
                {
                    if (column.equalsIgnoreCase(target))
                    {
                        continue;
                    }
                    System.out.println("Attribute :" + column);
                    double tmp = in.gain(E, column, st, target);
                    System.out.println("Information Gain:" + tmp + "\n");
                    if (max < tmp)
                    {
                        max = tmp;
                        root = column;
                    }
                }
                System.out.println("Root:" + root);
            }
        }
        catch (Exception e)
        {
            // Top-level boundary of a console tool: report the failure and exit normally.
            System.out.println(e);
        }
    }
}