
5. Tuning an Iris Classification Model with Cross-Validation

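1. Cross-validation:

  • A model often has several hyperparameters to tune;
  • Grid search optimizes model performance by trying every combination of the given parameter values and scoring each combination with cross-validation;
  • In sklearn, GridSearchCV() performs grid search.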
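The idea behind grid search can be sketched by hand: loop over every parameter combination, score each one with k-fold cross-validation, and keep the best. This is a minimal illustration, not how GridSearchCV() is implemented internally. It assumes scikit-learn's bundled copy of the Iris data (load_iris) in place of a local CSV file, and the names results, best_params, and best_score are chosen here for illustration.

```python
from itertools import product

from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Assumption: the bundled Iris dataset stands in for a local CSV file.
X, y = load_iris(return_X_y=True)

# Candidate values for each hyperparameter.
param_grid = {'n_neighbors': [3, 5, 7], 'p': [1, 2]}

results = []
for n_neighbors, p in product(param_grid['n_neighbors'], param_grid['p']):
    model = KNeighborsClassifier(n_neighbors=n_neighbors, p=p)
    # 5-fold cross-validation: train on 4 folds, score on the held-out fold, 5 times.
    fold_scores = cross_val_score(model, X, y, cv=5)
    results.append(({'n_neighbors': n_neighbors, 'p': p}, fold_scores.mean()))

# Keep the combination with the highest mean validation accuracy.
best_params, best_score = max(results, key=lambda item: item[1])
print('best parameters:', best_params)
print('mean CV accuracy: {:.3f}'.format(best_score))
```

GridSearchCV() wraps this loop and, by default, also refits the winning model on all of the data passed to fit().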

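For example, a grid search over two hyperparameters, C and gamma, produces one cross-validation score for each (C, gamma) combination.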
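A grid over C and gamma like the one just described can be sketched as follows. The value ranges are illustrative assumptions rather than values from this article, the data comes from scikit-learn's bundled load_iris, and score_table is a name chosen here. The sketch arranges the mean validation accuracy into a table with one row per C and one column per gamma.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Assumption: the bundled Iris dataset stands in for a local CSV file.
X, y = load_iris(return_X_y=True)

# Illustrative value ranges (assumed, not taken from the article).
param_grid = {'C': [0.1, 1.0, 10.0], 'gamma': [0.01, 0.1, 1.0]}

search = GridSearchCV(estimator=SVC(kernel='rbf'), param_grid=param_grid, cv=5)
search.fit(X, y)

# cv_results_ holds one entry per parameter combination.
results = pd.DataFrame(search.cv_results_)
score_table = results.pivot(index='param_C', columns='param_gamma',
                            values='mean_test_score')
print(score_table.round(3))
print('best parameters:', search.best_params_)
```

Each cell of the table is the accuracy averaged over the 5 validation folds for that (C, gamma) pair.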

2. Code:

import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

data_file = './data/Iris.csv'

SPECIES_LABEL_DICT = {
	'Iris-setosa': 0,  # Iris setosa
	'Iris-versicolor': 1,  # Iris versicolor
	'Iris-virginica': 2  # Iris virginica
}

# Feature columns to use
FEAT_COLS = ['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']


def main():
	"""
		Main function: tune three classifiers with grid search and report test accuracy
	"""
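	# Load the dataset and map species names to integer labels
	iris_data = pd.read_csv(data_file, index_col='Id')
	iris_data['Label'] = iris_data['Species'].map(SPECIES_LABEL_DICT)

	# Feature matrix
	X = iris_data[FEAT_COLS].values

	# Label vector
	y = iris_data['Label'].values

	# Split into training set (2/3) and test set (1/3)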
	X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1 / 3, random_state=10)

	model_dict = {'KNN': (KNeighborsClassifier(), {'n_neighbors': [3, 5, 7], 'p': [1, 2]}),
				  'Logistic Regression': (LogisticRegression(), {'C': [1e-2, 1, 1e2]}),
				  'SVM': (SVC(), {'C': [1e-2, 1, 1e2]})}

	for model_name, (model, model_params) in model_dict.items():
		# Run the grid search on the training set; cv is the number of folds
		clf = GridSearchCV(estimator=model, param_grid=model_params, cv=5)
		clf.fit(X_train, y_train)
		# Best model, refit on the whole training set with the best parameters
		best_model = clf.best_estimator_

		# Evaluate on the held-out test set
		accuracy = best_model.score(X_test, y_test)
		print('{} test accuracy: {:.2f}%'.format(model_name, accuracy * 100))
		# Best parameters found by the search
		print('{} best parameters: {}'.format(model_name, clf.best_params_))


if __name__ == '__main__':
	main()

3. Output:

KNN test accuracy: 96.00%
KNN best parameters: {'n_neighbors': 3, 'p': 2}
Logistic Regression test accuracy: 96.00%
Logistic Regression best parameters: {'C': 100.0}
SVM test accuracy: 98.00%
SVM best parameters: {'C': 1}
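The accuracy reported above is measured on the test set, which the grid search never sees. The search selects parameters by cross-validation accuracy on the training set, available as best_score_. A minimal sketch comparing the two, assuming scikit-learn's bundled load_iris in place of ./data/Iris.csv (so the numbers it prints are not guaranteed to match the output above):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

# Assumption: the bundled Iris dataset stands in for ./data/Iris.csv.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=1 / 3, random_state=10)

clf = GridSearchCV(estimator=SVC(), param_grid={'C': [1e-2, 1, 1e2]}, cv=5)
clf.fit(X_train, y_train)

# Mean accuracy over the 5 validation folds for the winning parameters.
print('cross-validation accuracy: {:.2f}%'.format(clf.best_score_ * 100))

# With the default refit=True, clf.score() delegates to the refitted best model,
# so it matches clf.best_estimator_.score() on the same data.
test_accuracy = clf.score(X_test, y_test)
print('test accuracy: {:.2f}%'.format(test_accuracy * 100))
```

A large gap between the two figures would suggest the chosen parameters fit the validation folds better than they generalize to unseen data.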