王小新 编译自 Medium 量子位 出品 | 公众号 QbitAI Alexandre Attia是《辛普森一家》的狂热粉丝。他看了一系列辛普森剧集,想建立一个能识别其中人物的神经网络。 接下来让我们跟着他的文章来了解下该如何建立一个用于识别《辛普森一家》中各个角色的神经网络。 要实现这个项目不是很困难,可能会比较耗时,因为需要手动标注每个人物的多张照片。 目前在网上没有《辛普森一家》人物的训练数据集,所以我正在标注各类图片来构建训练数据集。这个数据集的第一个版本已经挂在Kaggle上了,将持续进行更新,希望这个数据集能帮到大家。 在学了用TensorFlow构建不同项目后,我决定用Keras,因为它比TensorFlow更为简单易上手,而且以TensorFlow作为后端,具有很强的兼容性。Keras是Francois Chollet用Python语言编写的一个深度学习库。 本文基于卷积神经网络(CNN)来完成此项目,CNN网络是一种能够学习许多特征的多层前馈神经网络。 准备数据集 该数据集目前有18类,有以下人物:Homer,Marge,Lisa,Bart,Burns,Grampa,Flanders,Moe,Krusty,Sideshow Bob,Skinner,Milhouse等。 我的目标是达到20类,当然类别越多越好。各类样本的大小不一,图片背景也不尽相同,主要是从第4至24季的剧集中提取出来的。 △部分人物的图片 在训练集中,每个人物各大约包括1000个样本(还在标注数据来达到这个数量)。每个人物不一定处于图像中间,有时周围还带有其他人物。 △人物的样本量分布 通过label_data.py函数,我们可以从AVI电影中标注数据:得到裁剪后的图片(左部分或右部分),或者完整版,然后仅需输入人物名称的一部分,如对Charles Montgomery Burns输入burns。 添加数据时,我也使用了Keras模型。对视频进行截图,每一帧可转化得到3张图片,分别是左部分、右部分和完整版,然后通过编写算法来分类每张图片。 之后,我检查了此算法的分类效果,虽然是手动的,但这是一个渐进的过程,速度将会不断提升,特别是对出现频率较低的小类别人物。 数据预处理 在预处理图片时,第一步是调整样本大小。为了节省数据内存,先将样本转换为float32类型,并除以255进行归一化。 然后,使用Keras的自带函数,将各类人物的标签从名字转换为数字,再利用one-hot编码转换成矢量: importkeras importcv2pic_size = 64num_classes = 10img = cv2.resize(img, (pic_size, pic_size)).astype( 'float32') / 255....y = keras.utils.to_categorical(y, num_classes) 进而,使用sklearn库的train_test_split函数,将数据集分成训练集和测试集。 构建模型 现在让我们开始进入最有趣的部分:定义网络模型。 首先,我们构建了一个前馈网络,包括4个带有ReLU激活函数的卷积层和一个全连接的隐藏层(随着数据量的增大,atv,可能会进一步加深网络)。 这个模型与Keras文档中的CIFAR示例模型比较相近,接下来还会使用更多数据对其他模型进行测试。我还在模型中加入了Dropout层来防止网络过拟合。在输出层中,使用softmax函数来输出各类的所属概率。 损失函数为分类交叉熵(Categorical Cross Entropy)。优化器optimizer使用了随机梯度下降中的RMS Prop方法,通过该权重临近窗口的梯度平均值来确定该点的学习率。 训练模型 这个模型在训练集上迭代训练了200次,其中批次大小为32。 由于目前的数据集样本不多,我还用了数据增强操作,使用Keras库可以很快地实现。 这实际上是对图片进行一些随机变化,如小角度旋转和加噪声等,所以输入模型的样本都不大相同。这有助于防止模型过拟合,提高模型的泛化能力。 datagen = ImageDataGenerator( featurewise_center= False, # set input mean to 0 over the datasetsamplewise_center= False, # set each sample mean to 0featurewise_std_normalization= False, # divide inputs by std samplewise_std_normalization= False, # divide each input by its stdrotation_range= 0, # randomly rotate images in the range width_shift_range= 0.1, # randomly shift images horizontally height_shift_range= 0.1, # randomly shift images vertically horizontal_flip= True, # randomly flip imagesvertical_flip= False) # randomly flip images (责任编辑:本港台直播) |