Bootstrap

【监督学习】K 邻近算法步骤及matlab实现

(三)K 邻近算法

K 近邻算法(KNN,K-Nearest Neighbors)是一种简单且直观的监督学习方法,可用于分类和回归任务。它的工作原理是基于距离度量来找到与待预测样本最接近的K个训练样本,并根据这些“邻居”的信息来进行预测。KNN 的特点:

  • 非参数化方法:KNN 不假设数据服从任何特定的概率分布形式,因此适用于多种不同类型的数据集;
  • 懒惰学习:KNN 在训练阶段不做太多工作,只是简单地存储训练样本;所有的计算都在预测时进行;
  • 易于理解和实现:算法逻辑直接,容易上手。
    K 近邻算法根据不同的 K 的分类情况

1.算法步骤

开始
数据准备: 加载数据集
划分训练集和测试集如: 80%训练, 20%测试
是否需要数据标准化?
数据标准化
计算距离: 测试样本与所有训练样本的距离
选择距离度量如: 欧氏距离, 曼哈顿距离
计算并存储所有距离
排序与选择: 按距离升序排序, 选择前K个最近邻
投票决策: 统计前K个样本的类别标签, 多数表决
输出结果: 返回预测类别
结束
  1. 数据准备

    • 目标:确保数据格式正确,划分为训练集和测试集;
    • 关键步骤
      • 标准化:若特征量纲差异大,需标准化(如 Z-score)以避免距离计算偏差;
      • 划分数据集:通常按比例(如 8:2)随机分割,避免过拟合。
  2. 计算距离

    • 目标:量化测试样本与训练样本的相似性。
    • 常用距离公式
      • 欧氏距离(默认): d = ∑ i = 1 n ( x i − y i ) 2 d=\sqrt{\sum_{i=1}^{n}(x_i-y_i)^2} d=i=1n(xiyi)2
      • 曼哈顿距离 d = ∑ i = 1 n ∣ x i − y i ∣ d=\sum_{i=1}^{n}\lvert x_i-y_i \lvert d=i=1nxiyi
    • 实现:对每个测试样本,计算与所有训练样本的距离。
  3. 排序与选择

    • 目标:找到距离最近的 K 个样本;
    • 操作
      • 对距离数组升序排序;
      • 提取前 K 个样本的索引和标签。
  4. 投票决策

    • 目标:根据 K 个最近邻的标签确定预测类别;
    • 规则:多数表决(出现平票时可随机选择或加权投票)。
  5. 输出结果

    • 返回测试样本的预测标签。

2. MATLAB 实现

某电商平台希望根据客户的 历史行为数据 将其分为 高价值中价值低价值 三类,以便差异化运营。数据特征包括:

  • 最近购买天数(Recency)
  • 过去一年购买次数(Frequency)
  • 过去一年消费总额(Monetary)
  • 平均浏览时长(分钟)

目标变量:客户价值标签( 0=低价值1=中价值2=高价值
K 邻近算法分类客户价值

%% K 邻近算法根据历史行为数据判断客户价值
clc; clear; close all;

%% 1. 生成模拟电商数据(修正标签逻辑)
rng(42); 
num_customers = 1000;

% 生成特征数据(三类客户)
Recency = [abs(randn(300,1)*30 + 10);       % 高价值客户(标签2)
           abs(randn(400,1)*60 + 50);      % 中价值客户(标签1)
           abs(randn(300,1)*100 + 80)];    % 低价值客户(标签0)
Frequency = [abs(randn(300,1)*3 + 12);
             abs(randn(400,1)*5 + 8);
             abs(randn(300,1)*7 + 3)];
Monetary = [abs(randn(300,1)*0.3 + 2.5);
            abs(randn(400,1)*0.5 + 1.5);
            abs(randn(300,1)*0.8 + 0.5)];
BrowsingTime = [abs(randn(300,1)*5 + 20);
                abs(randn(400,1)*8 + 15);
                abs(randn(300,1)*10 + 5)];

% 合并特征并添加标签(三分类)
X = [Recency, Frequency, Monetary, BrowsingTime];
y = [2*ones(300,1);    % 高价值(标签2)
      ones(400,1);     % 中价值(标签1)
      zeros(300,1)];   % 低价值(标签0)

% 打乱数据顺序
shuffle_idx = randperm(num_customers);
X = X(shuffle_idx, :);
y = y(shuffle_idx);

%% 2. 数据标准化(强制标准化)
X_scaled = zscore(X); 

%% 3. 划分训练集与测试集(80%训练,20%测试)
train_ratio = 0.8;
train_size = floor(train_ratio * num_customers);
X_train = X_scaled(1:train_size, :);
y_train = y(1:train_size);
X_test = X_scaled(train_size+1:end, :);
y_test = y(train_size+1:end);

%% 4. 手动实现KNN算法
K = 15; % 近邻数
distance_metric = 'euclidean'; % 距离度量
y_pred = zeros(size(y_test));

for i = 1:size(X_test,1)
    % 计算距离(与所有训练样本)
    if strcmp(distance_metric, 'euclidean')
        distances = sqrt(sum((X_train - X_test(i,:)).^2, 2));
    elseif strcmp(distance_metric, 'manhattan')
        distances = sum(abs(X_train - X_test(i,:)), 2);
    end
    
    % 按距离排序并选择前K个
    [~, sorted_idx] = sort(distances);
    k_nearest_indices = sorted_idx(1:K);
    
    % 投票决策(多数表决)
    k_labels = y_train(k_nearest_indices);
    [unique_labels, ~, label_counts] = unique(k_labels);
    [max_count, max_idx] = max(histcounts(label_counts, length(unique_labels)));
    y_pred(i) = unique_labels(max_idx);
end

%% 5. 模型评估
% 计算准确率
accuracy = sum(y_pred == y_test) / numel(y_test);
fprintf('模型准确率: %.2f%%\n', accuracy * 100);

% 绘制混淆矩阵
classes = unique(y);
conf_mat = zeros(length(classes));
for i = 1:length(classes)
    for j = 1:length(classes)
        conf_mat(i,j) = sum(y_test == classes(i) & y_pred == classes(j));
    end
end

% 可视化混淆矩阵
figure;
imagesc(conf_mat);
colormap(jet);
colorbar;
xticks(1:length(classes));
yticks(1:length(classes));
xticklabels({'低价值','中价值','高价值'});
yticklabels({'低价值','中价值','高价值'});
xlabel('预测类别');
ylabel('真实类别');
title(['KNN分类混淆矩阵(K=',num2str(K),',模型准确率:',num2str(accuracy * 100),'%)']);

% 添加数值标签
for i = 1:length(classes)
    for j = 1:length(classes)
        text(j, i, num2str(conf_mat(i,j)),...
            'HorizontalAlignment', 'center',...
            'Color', 'white');
    end
end

参考资料

[1][5分钟学算法] #01 k近邻法_哔哩哔哩_bilibili
[2]【小萌五分钟】机器学习 | K近邻算法 KNN_哔哩哔哩_bilibili
[3]KNN简介_哔哩哔哩_bilibili

;