Bootstrap

机器学习——西瓜树决策树id3算法,matlab代码不能运行,来砍我

决策树的生成是一个递归过程。在决策树基本算法中,有三种情形会导致递归返回:

当前结点包含的样本全属于同一类别,无需划分

当前属性集为空,或是所有样本在所有属性上取值相同,无法划分

(3)当前结点包含的样本集合为空,不能划分。

在第(2)中情况下,我们把当前结点标记为叶结点,并将其类别设定为该结点所含样本最多的类别,即在利用当前结点的后验分布;

在第(3)种情况下,同样把当前结点标记为叶结点,但将其类别设定为其父结点所含样本最多的类别,即把父结点的样本分布作为当前结点的先验分布。

ID3(Iterative Dichotomiser 3)是一种经典的决策树学习算法,由 Ross Quinlan 在 1986 年提出。ID3 算法主要用于解决分类问题,它通过对数据集进行递归划分来构建决策树。

ID3 算法的基本思想是在每个节点上选择最佳的特征进行分割,以使得得到的子集尽可能地“纯净”。纯净度通常用信息增益(Information Gain)或基尼指数(Gini Index)等指标来衡量,这些指标可以反映数据集的纯度或不确定性程度。

ID3 算法的步骤如下:

  1. 若所有实例属于同一类,则将当前节点标记为叶节点,并以该类别作为节点的类别标签。
  2. 若特征集为空集,或者当前节点的所有实例属于同一类,则将当前节点标记为叶节点,并以当前节点中实例数最多的类别作为节点的类别标签。
  3. 否则,计算每个特征的信息增益(或基尼指数),选择信息增益(或基尼指数)最大的特征作为当前节点的划分特征。
  4. 根据选定的划分特征将数据集划分为多个子集,并为每个子集递归地应用上述步骤,构建子节点。

matlab代码如下:main.m

clc;clear;

data_name = 'xigua'; %数据名称,

data_r = 'csv';     %数据格式 
dir_ = cd;          %目录,默认同文件下 

%% 数据预处理
filename = fullfile([dir_ '\' data_name '.' data_r]);%文件名
% 获取属性标签
data = readtable(filename,"VariableNamingRule","preserve");
size_data = size(data); %数据大小
if isempty(data.Properties.VariableDescriptions) %英文属性值,无描述
  
labels = data.Properties.VariableNames(1,1:size(data,2)-1); %获取属性值,必须是英文
else %使用原始列标题以支持中文属性值
   labels = cell(1,size_data(2)-1);
   for i = 1:size_data(2)-1  
        VariableDescriptions = data.Properties.VariableDescriptions;%获取原始名称
      labels{i} = VariableDescriptions{i}(9:length(VariableDescriptions{1})-1);%添加标签 
    end
end

% 获取数据集
opts = detectImportOptions(filename);%检查数据
opts = setvartype(opts,opts.VariableNames,'char');
data = readtable(filename,opts) %读入数据
dataset = data{:,:}; %获取数据集
% 调用函数
myTree = ID3(dataset,labels);%生成决策树,并画出来

另创一个文件ID3.m,  ID3代码:


function myTree = ID3(dataset,labels)
% 输入参数:
% dataset:数据集,元胞数组或字符串数组
% labels:属性标签,元胞数组或字符串数组 

myTree = createTree(dataset,labels); %生成决策树
[nodeids,nodevalue,branchvalue] = print_tree(myTree); %解析决策树
tree_plot(nodeids,nodevalue,branchvalue); %画出
end

%% 使用熵最小策略构建决策树
function myTree = createTree(dataset,labels)

% 数据为空,则报错
if(isempty(dataset))
    error('必须提供数据!')
end
size_data = size(dataset);
% 数据大小与属性数量不一致,则报错
if (size_data(2)-1)~=length(labels)
    error('属性数量与数据集不一致!')
end

classList = dataset(:,size_data(2));
%全为同一类,熵为0,返回
if length(unique(classList))==1
    myTree =  char(classList(1));
    return 
end
%%属性集为空,应该用找最多数的那一类,这里取值NONE
if size_data(2) == 1
    myTree =  'NONE';
    %myTree =  char(classList(1));
    return
end
% 选取特征属性
bestFeature = chooseFeature(dataset); 
bestFeatureLabel = char(labels(bestFeature));
% 构建树
myTree = containers.Map;
leaf = containers.Map;
% 该属性下的不同取值 
featValues = dataset(:,bestFeature); 
uniqueVals = unique(featValues);
% 删除该属性
labels=[labels(1:bestFeature-1) labels(bestFeature+1:length(labels))]; %删除该属性
% 对该属性下不同取值,递归调用ID3函数
for i=1:length(uniqueVals)
    subLabels = labels(:)';
    value = char(uniqueVals(i));
    subdata = splitDataset(dataset,bestFeature,value);%数据集分割
    leaf(value) = createTree(subdata,subLabels); %递归调用
    myTree(char(bestFeatureLabel)) = leaf;
end
end

%% 计算信息熵
function shannonEnt = calShannonEnt(dataset)
data_size = size(dataset);
labels = dataset(:,data_size(2));
numEntries = data_size(1);
labelCounts = containers.Map;
for i = 1:length(labels)
    label = char(labels(i));
    if labelCounts.isKey(label)
        labelCounts(label) = labelCounts(label)+1; 
    else
        labelCounts(label) = 1;
    end  
end
shannonEnt = 0.0;
for key = labelCounts.keys
    key = char(key);
    labelCounts(key);
    prob = labelCounts(key) / numEntries;
    shannonEnt = shannonEnt - prob*(log(prob)/log(2));
end  
end

% 选择熵最小的属性特征
function bestFeature=chooseFeature(dataset,~)
baseEntropy = calShannonEnt(dataset);
data_size = size(dataset);
numFeatures = data_size(2) - 1;
minEntropy = 2.0;
bestFeature = 0;
for i = 1:numFeatures
    uniqueVals = unique(dataset(:,i));
    newEntropy = 0.0;
    for j=1:length(uniqueVals)
        value = uniqueVals(j);
        subDataset = splitDataset(dataset,i,value);
        size_sub = size(subDataset);
        prob = size_sub(1)/data_size(1);
     
        newEntropy = newEntropy + prob*calShannonEnt(subDataset);
    end

    if newEntropy<minEntropy
        minEntropy = newEntropy;
        bestFeature = i;
    end
end
end
% 分割数据集,取出该特征值为value的所有样本,并去除该属性
function subDataset = splitDataset(dataset,axis,value)
subDataset = {};
data_size = size(dataset);
for i=1:data_size(1)
    data = dataset(i,:);
    if string(data(axis)) == string(value)
        subDataset = [subDataset;[data(1:axis-1) data(axis+1:length(data))]];
    end
end
end
% 层序遍历决策树,返回nodeids,nodevalue,branchvalue
function [nodeids_,nodevalue_,branchvalue_] = print_tree(tree)
nodeids(1) = 0;
nodeid = 0;
nodevalue={};
branchvalue={};

queue = {tree} ;%创建队列
while ~isempty(queue)
    node = queue{1}; %取数据
    queue(1) = []; %出队
    if string(class(node))~="containers.Map" %叶节点
        nodeid = nodeid+1;
        nodevalue = [nodevalue,{node}];
    elseif length(node.keys)==1 %节点
        nodevalue = [nodevalue,node.keys];
        node_info = node(char(node.keys));
        nodeid = nodeid+1;
        branchvalue = [branchvalue,node_info.keys];
        for i=1:length(node_info.keys)
            nodeids = [nodeids,nodeid];
        end       
    end
    
    if string(class(node))=="containers.Map" 
        keys = node.keys();
        for i = 1:length(keys)
            key = keys{i};         
            queue=[queue,{node(key)}]; %入队
        end
    end
nodeids_=nodeids;
nodevalue_=nodevalue;
branchvalue_ = branchvalue;
end
end
%% 参考treeplot,画图
function tree_plot(p,nodevalue,branchvalue)

[x,y,h] = treelayout(p); %x:横坐标,y:纵坐标;h:树的深度
f = find(p~=0); %非0节点
pp = p(f); %非0值
X = [x(f); x(pp); NaN(size(f))];
Y = [y(f); y(pp); NaN(size(f))];
X = X(:);
Y = Y(:);
n = length(p);
if n<500
    hold on;
    %plot(x,y,'ro',X,Y,'r-')
    set(gcf,'Position',get(0,'ScreenSize'))
    plot(X,Y,'r-');
    nodesize = length(x);
    for i=1:nodesize
        t = text(x(i),y(i),nodevalue{1,i},'HorizontalAlignment','center');      
        t.EdgeColor = 'blue';
        t.BackgroundColor = 'w';
    end
    for i=2:nodesize
        j = 3*i-5;%获取连线坐标
        t=text((X(j)+X(j+1))/2,(Y(j)+Y(j+1))/2,branchvalue{1,i-1},'HorizontalAlignment','center');
        t.BackgroundColor = 'w';
    end
    hold off
else
    plot(X,Y,'r-');
end
xlabel(['height = ' int2str(h)]);
axis([0 1 0 1]);
end                

数据集用excel表格,表格名称用xigua.csv

 

仔细按照流程去设置,不能运行,请来砍我!!!!!!!!!

 

;