首页人工智能Scikit-learn3.kNN算法的超参数对鸢...

3.kNN算法的超参数对鸢尾花分类的影响

此篇主要是看下KNN算法中k值(n_neighbors)的对分类结果的影响。

1.代码:

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import ai_utils


data_file = './data/Iris.csv'

# 使用特征列
feature_columns = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']

# 类别列
Species_dict = {
	'Iris-setosa':0,  # 山鸢尾
	'Iris-versicolor':1,  # 变色鸢尾
	'Iris-virginica':2 # 维吉尼亚鸢尾
}


def investigate_knn(iris_data,sel_cols,k):
	# 获取数据集特征
	X = iris_data[sel_cols].values
	# 获取数据标签
	y = iris_data['Label'].values

	# 划分数据集
	X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1 / 3, random_state=10)

	# 声明模型
	knn_model = KNeighborsClassifier(n_neighbors=k)
	# 训练模型
	knn_model.fit(X_train, y_train)
	# 评价模型
	accuracy = knn_model.score(X_test, y_test)
	print('k={}时,模型准确率为:{:.2f}'.format(k , accuracy * 100))

	# 可视化边界
	ai_utils.plot_knn_boundary(knn_model,X_test, y_test,'SepalLengthCm vs SepalWidthCm , k={}',format(k),save_fig='k={}'.format(k))


def main():
	"""
		主函数
	"""
	# 读取数据,将Id列设置为索引列
	iris_data = pd.read_csv(data_file,index_col='Id')
	# 使用map()函数将类别转变为相对应的数字
	iris_data['Label'] = iris_data['Species'].map(Species_dict)

	# k值列表
	k_list = [3,5,10]
	# 特征列,如需可视化只能选择两个特征,做二维可视化
	sel_cols = ['SepalLengthCm', 'SepalWidthCm']
	for k in k_list:
		investigate_knn(iris_data,sel_cols,k)

if __name__ == '__main__':
	main()

2.输出结果:

k=3时,模型准确率为:66.00
k=5时,模型准确率为:68.00
k=10时,模型准确率为:78.00

可以看到,k值越大,分类边界越光滑。
那么k值如何确定那?后续会讲到将采用交叉验证的方法得到。

RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments