# Bootstrap
#
# 决策树 - 信息增益的计算 (decision tree: computing information gain)

import numpy as np
import pandas as pd

from collections import Counter
import math
from math import log

# 熵
# print(-(1 / 3) * log(1 / 3, 2) - (2 / 3) * log(2 / 3, 2))


def calc_ent(datasets):
    """Return the empirical entropy H(D) of the dataset's class labels.

    Each row of *datasets* is a sample whose LAST element is the class
    label. Entropy is measured in bits (log base 2):

        H(D) = -sum_k p_k * log2(p_k)

    Returns 0 for an empty dataset (same as the original loop version).
    """
    total = len(datasets)
    # Frequency of each class label (last column of every row).
    label_count = Counter(row[-1] for row in datasets)
    return -sum((n / total) * log(n / total, 2)
                for n in label_count.values())

# 经验条件熵


def cond_ent(datasets, axis=0):
    """Return the empirical conditional entropy H(D|A) for feature *axis*.

    Partitions the samples by their value in column *axis* and returns the
    probability-weighted average of each partition's label entropy:

        H(D|A) = sum_v |D_v| / |D| * H(D_v)

    Fixes vs. original: the stray debug ``print`` is removed, and the local
    result no longer shadows the function's own name.
    """
    total = len(datasets)
    # Group rows by their value in the chosen feature column.
    partitions = {}
    for row in datasets:
        partitions.setdefault(row[axis], []).append(row)
    return sum((len(rows) / total) * calc_ent(rows)
               for rows in partitions.values())

# 信息增益


def info_gain(ent, cond_ent):
    """Information gain g(D, A) = H(D) - H(D|A)."""
    gain = ent - cond_ent
    return gain


def info_gain_train(datasets):
    """Choose the root-split feature by maximum information gain.

    *datasets* is an indexable collection of rows whose last column is the
    class label; the remaining columns are features. Feature names come
    from the module-level ``labels`` list (must have >= feature count
    entries).

    Prints each feature's gain and returns a human-readable string naming
    the best root feature. (The original's bare debug prints of the
    feature count and the dataset entropy are removed.)
    """
    feature_count = len(datasets[0]) - 1  # last column is the label
    ent = calc_ent(datasets)              # H(D), shared by every feature
    gains = []
    for axis in range(feature_count):
        gain = info_gain(ent, cond_ent(datasets, axis=axis))
        gains.append((axis, gain))
        print('特征({}) - info_gain - {:.3f}'.format(labels[axis], gain))
    # Feature with the largest gain becomes the root split.
    best_axis, _ = max(gains, key=lambda item: item[1])
    return '特征({})的信息增益最大,选择为根节点特征'.format(labels[best_axis])

# labels = ["天气", "温度", "湿度", "刮风", '类别']
# datasets = pd.DataFrame([
#     ["晴", "高", "中", "否", '否'],
#     ["晴", "高", "中", "是", '否'],
#     ["阴天", "高", "高", "否", '是'],
#     ["雨", "高", "高", "否", '是'],
#     ["雨", "低", "高", "否", '否'],
#     ["晴", "中", "中", "是", '是'],
#     ["阴天", "中", "高", "是", '否'],
# ])


# Toy weather dataset: three feature columns (天气/weather, 湿度/humidity,
# 刮风/windy) followed by the class-label column (类别).
labels = ["天气", "湿度", "刮风", '类别']
datasets = pd.DataFrame([
    ["晴", "中", "否", '否'],
    ["晴", "中", "是", '否'],
    ["阴天", "高", "否", '是'],
    ["雨", "高", "否", '是']
])


# Demo driver: show the raw dataset, then report the best root-split
# feature. (The stray trailing ';' — a SyntaxError in Python — is removed.)
print(datasets)
print(info_gain_train(np.array(datasets)))