数据来源:
本次scikit-learn系列教程使用经典的Iris数据集。它是常用的分类实验数据集,由Fisher在1936年收集整理。
Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。
- 150个样本
- 每个样本有4个特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度
- 3个类别:变色鸢尾(Versicolor)、维吉尼亚鸢尾(Virginica)、山鸢尾(Setosa)
2.任务:
使用欧式距离分类算法实现鸢尾花分类。
3.方法:
使用机器学习工具库scikit-learn中的以下方法。
- 数据集划分:train_test_split()
- 计算空间中两个点的距离,即欧氏距离(近朱者赤):euclidean()
4.代码:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from scipy.spatial.distance import euclidean
import ai_utils
data_file = './data/Iris.csv'
# 使用特征列
feature_columns = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
# 类别列
Species = ['Iris-setosa', # 山鸢尾
'Iris-versicolor', # 变色鸢尾
'Iris-virginica' # 维吉尼亚鸢尾
]
def get_predict_label(test_feature_cols,train_data):
"""
计算空间中两个点的欧氏距离
"""
distance_list = []
for idx,row in train_data.iterrows():
# 训练样本特征列
train_feature_cols = row[feature_columns].values
# 计算欧式距离
distance = euclidean(test_feature_cols,train_feature_cols)
distance_list.append(distance)
# 获取最小距离所对应的位置
position = np.argmin(distance_list)
# 得到预测标签
predict_label = train_data.iloc[position]['Species']
return predict_label
def main():
"""
主函数
"""
# 读取数据,将Id列设置为索引列
iris_data = pd.read_csv(data_file,index_col='Id')
# EDA,查看数据分布
ai_utils.do_eda_plot_for_iris(iris_data)
# 划分数据集,test_size用于指定测试集的大小,random_state用于指定随机状态,通常设定一个固定的数字用于重复实验
train_data,test_data = train_test_split(iris_data,test_size=1/3,random_state=100)
# 预测对的个数
accuracy_counts = 0
# 构建分类器
for idx,row in test_data.iterrows():
# 测试样本特征列
test_feature_cols = row[feature_columns].values
# 预测标签
predict_label = get_predict_label(test_feature_cols,train_data)
# 真实标签
real_label = row['Species']
# 输出测试样本的预测标签和真实标签
print('样本{}的真实标签为{} --> 预测标签为{}'.format(idx , real_label , predict_label))
if real_label == predict_label:
accuracy_counts += 1
# 计算预测准确率
accuracy = accuracy_counts / test_data.shape[0]
print('预测准确率为:{:.2f}%'.format(accuracy*100))
if __name__ == '__main__':
main()
5.输出结果:
样本129的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本12的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本119的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本16的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本124的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本136的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本33的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本2的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本117的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本46的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本41的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本116的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本27的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本29的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本146的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本98的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本63的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本78的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本123的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本113的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本126的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本32的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本147的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本30的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本70的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本150的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本76的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本21的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本74的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本121的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本82的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本100的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本120的真实标签为Iris-virginica --> 预测标签为Iris-versicolor
样本13的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本17的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本52的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本47的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本90的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本137的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本115的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本42的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本91的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本103的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本110的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本38的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本7的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本26的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本22的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本93的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本10的真实标签为Iris-setosa --> 预测标签为Iris-setosa
预测准确率为:98.00%
6.所需脚本ai_utils.py:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
import seaborn as sns
def do_eda_plot_for_iris(iris_data):
"""
对鸢尾花数据集进行简单的可视化
参数:
- iris_data: 鸢尾花数据集
"""
category_color_dict = {
'Iris-setosa': 'red', # 山鸢尾
'Iris-versicolor': 'blue', # 变色鸢尾
'Iris-virginica': 'green' # 维吉尼亚鸢尾
}
fig, axes = plt.subplots(2, 1, figsize=(8, 8))
for category_name, category_color in category_color_dict.items():
# 查看数据的萼片长度(SepalLengthCm)和萼片宽度(SepalWidthCm)
iris_data[iris_data['Species'] == category_name].plot(ax=axes[0], kind='scatter',
x='SepalLengthCm', y='SepalWidthCm', label=category_name,
color=category_color)
# 查看数据的花瓣长度(PetalLengthCm)和花瓣宽度(PetalWidthCm)
iris_data[iris_data['Species'] == category_name].plot(ax=axes[1], kind='scatter',
x='PetalLengthCm', y='PetalWidthCm', label=category_name,
color=category_color)
axes[0].set_xlabel('Sepal Length')
axes[0].set_ylabel('Sepal Width')
axes[0].set_title('Sepal Length vs Sepal Width')
axes[1].set_xlabel('Petal Length')
axes[1].set_ylabel('Petal Width')
axes[1].set_title('Petal Length vs Petal Width')
plt.tight_layout()
plt.savefig('./iris_eda.png')
plt.show()
def do_pair_plot_for_iris(iris_data):
"""
对鸢尾花数据集的样本特征关系进行可视化
参数:
- iris_data: 鸢尾花数据集
"""
g = sns.pairplot(data=iris_data[['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm', 'Species']],
hue='Species')
plt.tight_layout()
plt.show()
g.savefig('./iris_pairplot.png')
def plot_knn_boundary(knn_model, X, y, fig_title, save_fig):
"""
绘制二维平面的kNN边界
参数:
knn_mode: 训练好的kNN模型
X: 数据集特征
y: 数据集标签
fig_title: 图像名称
save_fig: 保存图像的路径
"""
h = .02 # step size in the mesh
# Create color maps
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])
# point in the mesh [x_min, x_max]x[y_min, y_max].
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
np.arange(y_min, y_max, h))
Z = knn_model.predict(np.c_[xx.ravel(), yy.ravel()])
# Put the result into a color plot
Z = Z.reshape(xx.shape)
plt.figure()
plt.pcolormesh(xx, yy, Z, cmap=cmap_light)
# Plot also the training points
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold,
edgecolor='k', s=20)
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.title(fig_title)
plt.savefig(save_fig)
plt.show()
def plot_feat_and_price(house_data):
"""
绘制每列特征与房价的关系
参数:
-house_data: 房屋价格数据集
"""
feat_cols = ['bedrooms', 'bathrooms', 'sqft_living', 'sqft_lot', 'sqft_above', 'sqft_basement']
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
for i, feat_col in enumerate(feat_cols):
house_data[[feat_col, 'price']].plot.scatter(x=feat_col, y='price', alpha=0.5,
ax=axes[int(i / 3), i - 3 * int(i / 3)])
plt.tight_layout()
plt.savefig('./house_feat.png')
plt.show()
def plot_fitting_line(linear_reg_model, X, y, fig_title, save_fig):
"""
绘制线性拟合曲线
参数:
linear_reg_model: 训练好的线性回归模型
X: 数据集特征
y: 数据集标签
fig_title: 图像名称
save_fig: 保存图像的路径
"""
# 线性回归模型的系数
coef = linear_reg_model.coef_
# 线性回归模型的截距
intercept = linear_reg_model.intercept_
# 绘制样本点
plt.scatter(X, y, alpha=0.5)
# 绘制拟合线
plt.plot(X, X * coef + intercept, c='red')
plt.title(fig_title)
plt.savefig(save_fig)
plt.show()