首页人工智能Scikit-learn1.使用scikit-le...

1.使用scikit-learn的欧式距离分类算法实现鸢尾花分类

数据来源:

本次scikit-learn系列教程使用经典的Iris数据集。它是常用的分类实验数据集,由Fisher在1936年收集整理。
Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。

  • 150个样本
  • 每个样本有4个特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度
  • 3个类别:变色鸢尾(Versicolor)、维吉尼亚鸢尾(Virginica)、山鸢尾(Setosa)

2.任务:

使用欧式距离分类算法实现鸢尾花分类。

3.方法:

使用机器学习工具库scikit-learn中的以下方法。

  • 数据集划分:train_test_split()
  • 计算空间中两个点的距离,即欧氏距离(近朱者赤):euclidean()

4.代码:

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from scipy.spatial.distance import euclidean

import ai_utils

data_file = './data/Iris.csv'

# 使用特征列
feature_columns = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']

# 类别列
Species = ['Iris-setosa',       # 山鸢尾
           'Iris-versicolor',   # 变色鸢尾
           'Iris-virginica'     # 维吉尼亚鸢尾
           ]

def get_predict_label(test_feature_cols,train_data):
	"""
		计算空间中两个点的欧氏距离
	"""
	distance_list = []

	for idx,row in train_data.iterrows():
		# 训练样本特征列
		train_feature_cols = row[feature_columns].values

		# 计算欧式距离
		distance = euclidean(test_feature_cols,train_feature_cols)
		distance_list.append(distance)

	# 获取最小距离所对应的位置
	position = np.argmin(distance_list)
	# 得到预测标签
	predict_label = train_data.iloc[position]['Species']

	return predict_label


def main():
	"""
		主函数
	"""
	# 读取数据,将Id列设置为索引列
	iris_data = pd.read_csv(data_file,index_col='Id')

	# EDA,查看数据分布
	ai_utils.do_eda_plot_for_iris(iris_data)

	# 划分数据集,test_size用于指定测试集的大小,random_state用于指定随机状态,通常设定一个固定的数字用于重复实验
	train_data,test_data = train_test_split(iris_data,test_size=1/3,random_state=100)

	# 预测对的个数
	accuracy_counts = 0

	# 构建分类器
	for idx,row in test_data.iterrows():
		# 测试样本特征列
		test_feature_cols = row[feature_columns].values

		# 预测标签
		predict_label = get_predict_label(test_feature_cols,train_data)

		# 真实标签
		real_label = row['Species']

		# 输出测试样本的预测标签和真实标签
		print('样本{}的真实标签为{} --> 预测标签为{}'.format(idx , real_label , predict_label))

		if real_label == predict_label:
			accuracy_counts += 1

	# 计算预测准确率
	accuracy = accuracy_counts / test_data.shape[0]
	print('预测准确率为:{:.2f}%'.format(accuracy*100))


if __name__ == '__main__':
	main()

5.输出结果:

样本129的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本12的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本119的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本16的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本124的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本136的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本33的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本2的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本117的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本46的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本41的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本116的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本27的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本29的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本146的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本98的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本63的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本78的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本123的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本113的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本126的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本32的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本147的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本30的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本70的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本150的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本76的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本21的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本74的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本121的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本82的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本100的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本120的真实标签为Iris-virginica --> 预测标签为Iris-versicolor
样本13的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本17的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本52的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本47的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本90的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本137的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本115的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本42的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本91的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本103的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本110的真实标签为Iris-virginica --> 预测标签为Iris-virginica
样本38的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本7的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本26的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本22的真实标签为Iris-setosa --> 预测标签为Iris-setosa
样本93的真实标签为Iris-versicolor --> 预测标签为Iris-versicolor
样本10的真实标签为Iris-setosa --> 预测标签为Iris-setosa
预测准确率为:98.00%

6.所需脚本ai_utils.py:

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
import seaborn as sns


def do_eda_plot_for_iris(iris_data):
    """
        对鸢尾花数据集进行简单的可视化
        参数:
            - iris_data: 鸢尾花数据集
    """
    category_color_dict = {
        'Iris-setosa':      'red',      # 山鸢尾
        'Iris-versicolor':  'blue',     # 变色鸢尾
        'Iris-virginica':   'green'     # 维吉尼亚鸢尾
    }

    fig, axes = plt.subplots(2, 1, figsize=(8, 8))

    for category_name, category_color in category_color_dict.items():

        # 查看数据的萼片长度(SepalLengthCm)和萼片宽度(SepalWidthCm)
        iris_data[iris_data['Species'] == category_name].plot(ax=axes[0], kind='scatter',
                                                              x='SepalLengthCm', y='SepalWidthCm', label=category_name,
                                                              color=category_color)
        # 查看数据的花瓣长度(PetalLengthCm)和花瓣宽度(PetalWidthCm)
        iris_data[iris_data['Species'] == category_name].plot(ax=axes[1], kind='scatter',
                                                              x='PetalLengthCm', y='PetalWidthCm', label=category_name,
                                                              color=category_color)

    axes[0].set_xlabel('Sepal Length')
    axes[0].set_ylabel('Sepal Width')
    axes[0].set_title('Sepal Length vs Sepal Width')

    axes[1].set_xlabel('Petal Length')
    axes[1].set_ylabel('Petal Width')
    axes[1].set_title('Petal Length vs Petal Width')

    plt.tight_layout()
    plt.savefig('./iris_eda.png')
    plt.show()


def do_pair_plot_for_iris(iris_data):
    """
        对鸢尾花数据集的样本特征关系进行可视化
        参数:
            - iris_data: 鸢尾花数据集
    """
    g = sns.pairplot(data=iris_data[['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm', 'Species']],
                     hue='Species')
    plt.tight_layout()
    plt.show()
    g.savefig('./iris_pairplot.png')


def plot_knn_boundary(knn_model, X, y, fig_title, save_fig):
    """
        绘制二维平面的kNN边界
        参数:
            knn_mode:   训练好的kNN模型
            X:          数据集特征
            y:          数据集标签
            fig_title:  图像名称
            save_fig:   保存图像的路径
    """
    h = .02  # step size in the mesh

    # Create color maps
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])

    # point in the mesh [x_min, x_max]x[y_min, y_max].
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    Z = knn_model.predict(np.c_[xx.ravel(), yy.ravel()])

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

    # Plot also the training points
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold,
                edgecolor='k', s=20)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title(fig_title)

    plt.savefig(save_fig)

    plt.show()


def plot_feat_and_price(house_data):
    """
        绘制每列特征与房价的关系
        参数:
            -house_data: 房屋价格数据集
    """
    feat_cols = ['bedrooms', 'bathrooms', 'sqft_living', 'sqft_lot', 'sqft_above', 'sqft_basement']
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    for i, feat_col in enumerate(feat_cols):
        house_data[[feat_col, 'price']].plot.scatter(x=feat_col, y='price', alpha=0.5,
                                                     ax=axes[int(i / 3), i - 3 * int(i / 3)])
    plt.tight_layout()
    plt.savefig('./house_feat.png')
    plt.show()


def plot_fitting_line(linear_reg_model, X, y, fig_title, save_fig):
    """
        绘制线性拟合曲线
        参数:
            linear_reg_model:   训练好的线性回归模型
            X:                  数据集特征
            y:                  数据集标签
            fig_title:          图像名称
            save_fig:           保存图像的路径
    """
    # 线性回归模型的系数
    coef = linear_reg_model.coef_

    # 线性回归模型的截距
    intercept = linear_reg_model.intercept_

    # 绘制样本点
    plt.scatter(X, y, alpha=0.5)

    # 绘制拟合线
    plt.plot(X, X * coef + intercept, c='red')

    plt.title(fig_title)
    plt.savefig(save_fig)
    plt.show()
RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments