1.准备数据集:
本次以图像三分类为例,准备猫、狗、熊猫三种动物的图片数据(每种各1000张图片),依次存放在’./dataset/cats’、’./dataset/dogs’、’./dataset/pandas’文件夹中。
2.训练模型:
# 导入所需工具包
import matplotlib
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from keras.models import Sequential
from keras.layers import Dropout
from keras.layers.core import Dense
from keras.optimizers import SGD
from keras import initializers
from keras import regularizers
import utils_paths # 主要用于图像路径处理操作,具体代码参考最后的附录
import matplotlib.pyplot as plt
import numpy as np
import argparse
import random
import pickle
import cv2
import os
print("------开始读取数据------")
data = []
labels = []
# 拿到图像数据路径,方便后续读取
imagePaths = sorted(list(utils_paths.list_images('./dataset')))
random.seed(42)
random.shuffle(imagePaths)
# 遍历读取数据
for imagePath in imagePaths:
# 读取图像数据,由于使用神经网络,需要给拉平成一维
image = cv2.imread(imagePath)
image = cv2.resize(image, (32, 32)).flatten()
data.append(image)
# 读取标签
label = imagePath.split(os.path.sep)[-2]
labels.append(label)
# 对图像数据做scale操作
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)
# 切分数据集
(trainX, testX, trainY, testY) = train_test_split(data,
labels, test_size=0.25, random_state=42)
# 转换标签为one-hot encoding格式
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# 构造网络模型结构:本次为3072-128-64-3
model = Sequential()
# kernel_regularizer=regularizers.l2(0.01) L2正则化项
# initializers.TruncatedNormal 初始化参数方法,截断高斯分布
model.add(Dense(128, input_shape=(3072,), activation="relu" ,kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dropout(0.5))
model.add(Dense(64, activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dropout(0.5))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
# 初始化参数
INIT_LR = 0.001
EPOCHS = 2000
# 模型编译
print("------准备训练网络------")
opt = SGD(lr=INIT_LR)
model.compile(loss="categorical_crossentropy", optimizer=opt,
metrics=["accuracy"])
# 拟合模型
H = model.fit(trainX, trainY, validation_data=(testX, testY),
epochs=EPOCHS, batch_size=32)
# 测试网络模型
print("------正在评估模型------")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1),
predictions.argmax(axis=1), target_names=lb.classes_))
# 绘制结果曲线
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N[1500:], H.history["accuracy"][1500:], label="train_acc")
plt.plot(N[1500:], H.history["val_accuracy"][1500:], label="val_acc")
plt.title("Training and Validation Accuracy (Simple NN)")
plt.xlabel("Epoch #")
plt.ylabel("Accuracy")
plt.legend()
plt.savefig('./output/simple_nn_plot_acc.png')
plt.figure()
plt.plot(N, H.history["loss"], label="train_loss")
plt.plot(N, H.history["val_loss"], label="val_loss")
plt.title("Training and Validation Loss (Simple NN)")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend()
plt.savefig('./output/simple_nn_plot_loss.png')
# 保存模型到本地
print("------正在保存模型------")
model.save('././output/simple_nn.model')
f = open('./output/simple_nn_lb.pickle', "wb") # 保存标签数据
f.write(pickle.dumps(lb))
f.close()
运行得到如下文件数据:
3.加载模型进行预测:
# 导入所需工具包
from keras.models import load_model
import argparse
import pickle
import cv2
# 加载测试数据并进行相同预处理操作
image = cv2.imread('./cs_image/panda.jpg')
output = image.copy()
image = cv2.resize(image, (32, 32))
# scale图像数据
image = image.astype("float") / 255.0
# 对图像进行拉平操作
image = image.flatten()
image = image.reshape((1, image.shape[0]))
# 读取模型和标签
print("------读取模型和标签------")
model = load_model('./output/simple_nn.model')
lb = pickle.loads(open('./output/simple_nn_lb.pickle', "rb").read())
# 预测
preds = model.predict(image)
# 得到预测结果以及其对应的标签
i = preds.argmax(axis=1)[0]
label = lb.classes_[i]
# 在图像中把结果画出来
text = "{}: {:.2f}%".format(label, preds[0][i] * 100)
cv2.putText(output, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)
# 绘图
cv2.imshow("Image", output)
cv2.waitKey(0)
最终得到预测结果:
4.附录:
utils_paths.py
代码如下:
import os
image_types = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")
def list_images(basePath, contains=None):
# 返回有效的图片路径数据集
return list_files(basePath, validExts=image_types, contains=contains)
def list_files(basePath, validExts=None, contains=None):
# 遍历图片数据目录,生成每张图片的路径
for (rootDir, dirNames, filenames) in os.walk(basePath):
# 循环遍历当前目录中的文件名
for filename in filenames:
# if the contains string is not none and the filename does not contain
# the supplied string, then ignore the file
if contains is not None and filename.find(contains) == -1:
continue
# 通过确定.的位置,从而确定当前文件的文件扩展名
ext = filename[filename.rfind("."):].lower()
# 检查文件是否为图像,是否应进行处理
if validExts is None or ext.endswith(validExts):
# 构造图像路径
imagePath = os.path.join(rootDir, filename)
yield imagePath