# 在其中还包括了一个 print_tree() 函数,它能够递归地一行一个地打印出决策树的节点。经过它打印的不是一个明显的树结构,但它能给我们关于树结构的大致印象,并能帮助决策。


# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
    """Partition ``dataset`` on one column.

    Args:
        index: Column position to compare.
        value: Threshold; rows with ``row[index] < value`` go left.
        dataset: Iterable of indexable rows.

    Returns:
        A ``(left, right)`` tuple of lists. Row order is preserved and the
        row objects themselves are shared with ``dataset`` (not copied).
    """
    left, right = [], []
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right


# The name matches pytest's ``test_*`` pattern, so opt out of collection in
# case this helper is ever star-imported into a test module.
test_split.__test__ = False


# Calculate the Gini index for a split dataset
def gini_index(groups, class_values):
    """Score a candidate split; 0.0 means every non-empty group is pure.

    Args:
        groups: Iterable of row lists; the class label is ``row[-1]``.
        class_values: The distinct class labels to score against.

    Returns:
        The sum of ``p * (1 - p)`` over every (class, non-empty group) pair,
        where ``p`` is the fraction of the group's rows with that class.
    """
    # Extract each non-empty group's labels once instead of once per class.
    # Empty groups are skipped so the proportion never divides by zero.
    labels_per_group = [[row[-1] for row in group] for group in groups if group]

    # NOTE(review): every group contributes equally regardless of its size
    # (the per-group score is not weighted by row count) -- confirm intended.
    gini = 0.0
    for class_value in class_values:
        for labels in labels_per_group:
            proportion = labels.count(class_value) / float(len(labels))
            gini += proportion * (1.0 - proportion)
    return gini


# Select the best split point for a dataset
def get_split(dataset):
    """Exhaustively search for the split with the lowest Gini score.

    Every value of every column except the last (the class label) is tried
    as a threshold. On equal scores the first candidate found is kept,
    because the comparison is a strict ``<``.

    Args:
        dataset: Non-empty list of rows whose last element is the label.

    Returns:
        A dict with keys ``'index'`` (column), ``'value'`` (threshold) and
        ``'groups'`` (the ``(left, right)`` partition for that split).
    """
    class_values = list(set(row[-1] for row in dataset))
    # 999 is a "no candidate yet" sentinel; the first real score replaces it.
    b_index, b_value, b_score, b_groups = 999, 999, 999, None
    for index in range(len(dataset[0]) - 1):
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            gini = gini_index(groups, class_values)
            if gini < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], gini, groups
    return {'index': b_index, 'value': b_value, 'groups': b_groups}


# Create a terminal node value
def to_terminal(group):
    """Return the most frequent class label (``row[-1]``) in ``group``.

    Raises:
        ValueError: If ``group`` is empty (``max`` of an empty set).
    """
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)


# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
    """Recursively grow the tree below ``node``, modifying it in place.

    The node's ``'groups'`` entry is consumed and replaced by ``'left'`` and
    ``'right'`` entries, each either a nested node dict or a class label.

    Args:
        node: Dict as returned by ``get_split``.
        max_depth: Depth at which both children are forced to be terminal.
        min_size: A child with this many rows or fewer becomes terminal.
        depth: Depth of ``node`` itself.
    """
    left, right = node.pop('groups')

    # check for a no split: one side is empty, so both children share the
    # majority label of all rows
    if not left or not right:
        node['left'] = node['right'] = to_terminal(left + right)
        return

    # check for max depth
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return

    # process left child
    if len(left) <= min_size:
        node['left'] = to_terminal(left)
    else:
        node['left'] = get_split(left)
        split(node['left'], max_depth, min_size, depth + 1)

    # process right child
    if len(right) <= min_size:
        node['right'] = to_terminal(right)
    else:
        node['right'] = get_split(right)
        split(node['right'], max_depth, min_size, depth + 1)


# Build a decision tree
def build_tree(train, max_depth, min_size):
    """Build a decision tree from training rows.

    Args:
        train: Rows to fit; passed to ``get_split`` to choose the root split.
        max_depth: Forwarded to ``split`` as the depth limit.
        min_size: Forwarded to ``split`` as the minimum group size.

    Returns:
        The root node produced by ``get_split`` after ``split`` has expanded
        it in place, starting at depth 1.
    """
    # Fit the rows the caller passed in, not the module-level ``dataset``.
    root = get_split(train)
    split(root, max_depth, min_size, 1)
    return root


# Print a decision tree
def print_tree(node, depth=0):
    """Print the tree one node per line, indented one space per level.

    Args:
        node: A node dict with ``'index'``, ``'value'``, ``'left'`` and
            ``'right'`` keys, or any non-dict value, which is printed as a
            leaf.
        depth: Current nesting level; controls the indentation.
    """
    if isinstance(node, dict):
        # Columns are shown 1-based (X1, X2, ...).
        print('%s[X%d < %.3f]' % (depth * ' ', node['index'] + 1, node['value']))
        print_tree(node['left'], depth + 1)
        print_tree(node['right'], depth + 1)
    else:
        print('%s[%s]' % (depth * ' ', node))


dataset = [
    [2.771244718, 1.784783929, 0],
    [1.728571309, 1.169761413, 0],
    [3.678319846, 2.81281357, 0],
    [3.961043357, 2.61995032, 0],
    [2.999208922, 2.209014212, 0],
    [7.497545867, 3.162953546, 1],
    [9.00220326, 3.339047188, 1],
    [7.444542326, 0.476683375, 1],
    [10.12493903, 3.234550982, 1],
    [6.642287351, 3.319983761, 1],
]
tree = build_tree(dataset, 1, 1)
print_tree(tree)

# 在运行过程中,我们能修改树的最大深度,并在打印的树上观察其影响。
# 当最大深度为 1 时(即调用 build_tree() 函数时第二个参数),我们可以发现该树使用了我们之前发现的完美分割点(作为树的唯一分割点)。该树只有一个节点,也被称为决策树桩。
#
# [X1 < 6.642]
#  [0]
#  [1]
#
# (责任编辑:本港台直播) |