main.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import re
import sys

from mllib.decisionTree import DecisionTree


def train(training_set, option):
    """Read the training file and fit a decision tree using the given split criterion."""
    data = []
    with open(training_set, 'r') as f:
        for line in f:
            # split on commas, spaces and newlines; [:-1] drops the empty token
            # produced by the trailing newline (assumes every line ends with one)
            data.append(re.split(',| |\n', line)[:-1])
    attributes = data[0]
    target_attr = attributes[-1]
    # Decision Tree algorithm starts here
    dt = DecisionTree(option)
    dt.fit(data[1:], attributes, target_attr)
    return dt
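

# Input format assumed by train(), shown here only as an illustration: the first row
# is a header of attribute names with the class label last, and every later row is one
# training example. The attribute and class names below are hypothetical placeholders,
# not taken from the repository's data files, e.g.
#
#   XB,XC,XD,Class
#   0,1,1,0
#   1,0,1,1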


def print_stats(dt, to_print, validation_set, test_set):
    """Optionally print the tree structure, then report accuracy on the test set.

    Note: main() currently reports these directly rather than calling this helper.
    """
    if to_print == "yes":
        dt.stats()
    dt.accuracy(test_set)


def main():
    if len(sys.argv) != 7:
        print('''
        Invalid input arguments.
        Please specify the input as:
            python main.py <L> <K> <training_set> <validation_set> <test_set> <to_print>
        ''')
        sys.exit(1)
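
    # Example invocation (the file names below are hypothetical placeholders):
    #   python main.py 5 10 training_set.csv validation_set.csv test_set.csv yes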

    # read inputs from command line
    val_l = int(sys.argv[1])
    val_k = int(sys.argv[2])
    train_set = sys.argv[3]
    validation_set = sys.argv[4]
    test_set = sys.argv[5]
    to_print = sys.argv[6]

    # train and evaluate a decision tree that splits on information gain (entropy)
    print("information gain decision tree")
    dt1 = train(train_set, DecisionTree.ENTROPY)
    dt1.accuracy(test_set)
    print("AFTER PRUNING")
    dt1.prune(validation_set, val_l, val_k)
    dt1.accuracy(test_set)

    # train and evaluate a decision tree that splits on variance impurity
    print("variance impurity decision tree")
    dt2 = train(train_set, DecisionTree.VARIANCE_IMPURITY)
    dt2.accuracy(test_set)
    print("AFTER PRUNING")
    dt2.prune(validation_set, val_l, val_k)
    dt2.accuracy(test_set)

    # optionally print the learned trees
    if to_print == "yes":
        print("\nDecision tree: information gain")
        dt1.stats()
        print("\nDecision tree: variance impurity")
        dt2.stats()


if __name__ == '__main__':
    main()