sklearn鸢尾花数据集的分类

摘要
本文基于sklearn鸢尾花数据集,通过数据可视化、统计分析和机器学习算法建立了分类模型。利用matplotlib实现特征分布可视化,采用KNN分类器进行模型训练,最终达到97.78%的测试准确率,为植物分类问题提供有效的数学建模方法。

  1. 问题重述
    鸢尾花分类任务需要根据花卉的形态学特征(萼片长度、萼片宽度、花瓣长度、花瓣宽度)准确识别其所属品种(Setosa、Versicolour、Virginica)。建立可靠的数学模型对植物分类研究和自动化识别具有重要意义。

  2. 数据预处理

2.1 数据描述
数据集包含150个样本,每个样本具有4个特征:
$$X = {x_1^{(sepal\ length)}, x_2^{(sepal\ width)}, x_3^{(petal\ length)}, x_4^{(petal\ width)}}$$
目标变量为三类标签:
$$y \in {0^{(Setosa)}, 1^{(Versicolour)}, 2^{(Virginica)}}$$

2.2 数据可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

iris = load_iris()
features = iris.data.T

plt.figure(figsize=(12,6))
plt.suptitle("Feature Distribution by Class")

for i in range(4):
plt.subplot(2,2,i+1)
for c in range(3):
plt.hist(features[i][iris.target==c], alpha=0.7,
label=iris.target_names[c])
plt.xlabel(iris.feature_names[i])
plt.legend()
plt.tight_layout()
plt.show()

feature dist

  1. 模型建立

3.1 K最近邻算法
选择KNN分类器,其决策函数为:
$$\hat{y} = \mathop{\arg\min}\limits_{c} \sum_{i=1}^k \mathbb{I}(y_i = c)$$
其中距离度量采用欧式距离:
$$d(x^{(i)},x^{(j)}) = \sqrt{\sum_{m=1}^4 (x_m^{(i)} - x_m^{(j)})^2}$$

3.2 模型训练

1
2
3
4
5
6
7
8
9
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

X_train, X_test, y_train, y_test = train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42)

knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
  1. 结果分析

4.1 分类性能
测试集混淆矩阵可视化:

1
2
3
4
5
6
7
8
from sklearn.metrics import ConfusionMatrixDisplay

plt.figure(figsize=(8,6))
ConfusionMatrixDisplay.from_estimator(knn, X_test, y_test,
display_labels=iris.target_names,
cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.show()

confusion matrix

4.2 决策边界
二维特征投影可视化:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris

# 加载数据
iris = load_iris()
X = iris.data[:, :2] # 只使用前两个特征
y = iris.target

# 创建并训练KNN模型
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X, y)

# 创建网格点
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
np.arange(y_min, y_max, 0.02))

# 预测网格点的类别
Z = knn.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# 绘制决策边界
plt.figure(figsize=(10,6))
plt.contourf(xx, yy, Z, alpha=0.4,
cmap=ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF']))
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolor='k',
cmap=ListedColormap(['red', 'green', 'blue']))
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1])
plt.title("KNN Decision Boundary (Sepal features)")
plt.show()

decision boundary

  1. 结论
    本研究表明,基于KNN算法的分类模型在鸢尾花数据集上表现优异,测试准确率达97.78%。通过特征可视化发现花瓣尺寸具有更好的类别区分度,后续研究可考虑特征加权优化提升模型性能。

注:实际使用时需确保已安装所需库:

1
pip install scikit-learn matplotlib numpy
Donate
  • Copyright: Copyright is owned by the author. For commercial reprints, please contact the author for authorization. For non-commercial reprints, please indicate the source.
  • Copyrights © 2023-2025 John Doe
  • Visitors: | Views:

请我喝杯茶吧~

支付宝
微信