本港台开奖现场直播 j2开奖直播报码现场
当前位置: 新闻频道 > IT新闻 >

wzatv:一文看懂迁移学习:怎样用预训练模型搞定深度(3)

时间:2017-07-02 22:35来源:本港台现场报码 作者:开奖直播现场 点击:
在这种情况下,因为数据与预训练模型的训练数据相似度很高,因此我们不需要重新训练模型。我们只需要将输出层改制成符合问题情境下的结构就好。

在这种情况下,因为数据与预训练模型的训练数据相似度很高,因此我们不需要重新训练模型。我们只需要将输出层改制成符合问题情境下的结构就好。

我们使用预处理模型作为模式提取器。

比如说我们使用在ImageNet上训练的模型来辨认一组新照片中的小猫小狗。在这里,需要被辨认的图片与ImageNet库中的图片类似,但是我们的输出结果中只需要两项——猫或者狗。

在这个例子中,我们需要做的就是把dense layer和最终softmax layer的输出从1000个类别改为2个类别。

场景二:数据集小,数据相似度不高

在这种情况下,我们可以冻结预训练模型中的前k个层中的权重,然后重新训练后面的n-k个层,当然最后一层也需要根据相应的输出格式来进行修改。

因为数据的相似度不高,重新训练的过程就变得非常关键。而新数据集大小的不足,则是通过冻结预训练模型的前k层进行弥补。

场景三:数据集大,数据相似度不高

在这种情况下,因为我们有一个很大的数据集,所以神经网络的训练过程将会比较有效率。然而,因为实际数据与预训练模型的训练数据之间存在很大差异,采用预训练模型将不会是一种高效的方式。

因此最好的方法还是将预处理模型中的权重全都初始化后在新数据集的基础上重头开始训练。

场景四:数据集大,数据相似度高

这就是最理想的情况,采用预训练模型会变得非常高效。最好的运用方式是保持模型原有的结构和初始权重不变,随后在新数据集的基础上重新训练。

6. 在手写数字识别中使用预训练模型

现在,让我们尝试来用预训练模型去解决一个简单的问题。

我曾经使用vgg16作为预训练的模型结构,并把它应用到手写数字识别上。

让我们先来看看这个问题对应着之前四种场景中的哪一种。我们的训练集(MNIST)有大约60,000张左右的手写数字图片,这样的数据集显然是偏小的。所以这个问题应该属于场景一或场景二。

我们可以尝试把两种对应的方法都用一下,看看最终的效果。

只重新训练输出层 & dense layer

这里我们采用vgg16作为特征提取器。随后这些特征,会被传递到依据我们数据集训练的dense layer上。输出层同样由与我们问题相对应的softmax层函数所取代。

在vgg16中,输出层是一个拥有1000个类别的softmax层。我们把这层去掉,换上一层只有10个类别的softmax层。我们只训练这些层,然后就进行数字识别的尝试。

# importing required librariesfromkeras.models importSequential fromscipy.misc importimreadget_ipython().magic( 'matplotlib inline') importmatplotlib.pyplot asplt importnumpy asnp importkeras fromkeras.layers importDense importpandas aspd fromkeras.applications.vgg16 importVGG16 fromkeras.preprocessing importimage fromkeras.applications.vgg16 importpreprocess_input importnumpy asnp fromkeras.applications.vgg16 importdecode_predictionstrain=pd.read_csv( "R/Data/Train/train.csv")test=pd.read_csv( "R/Data/test.csv")train_path= "R/Data/Train/Images/train/"test_path= "R/Data/Train/Images/test/"fromscipy.misc importimresize # preparing the train datasettrain_img=[] fori inrange(len(train)): temp_img=image.load_img(train_path+train[ 'filename'][i],target_size=( 224, 224)) temp_img=image.img_to_array(temp_img) train_img.append(temp_img) #converting train images to array and applying mean subtraction processingtrain_img=np.array(train_img) train_img=preprocess_input(train_img) # applying the same procedure with the test datasettest_img=[] fori inrange(len(test)): temp_img=image.load_img(test_path+test[ 'filename'][i],target_size=( 224, 224)) temp_img=image.img_to_array(temp_img) test_img.append(temp_img)test_img=np.array(test_img) test_img=preprocess_input(test_img) # loading VGG16 model weightsmodel = VGG16(weights= 'imagenet', include_top= False) # Extracting features from the train dataset using the VGG16 pre-trained modelfeatures_train=model.predict(train_img) # Extracting features from the train dataset using the VGG16 pre-trained modelfeatures_test=model.predict(test_img) # flattening the layers to conform to MLP inputtrain_x=features_train.reshape( 49000, 25088) # converting target variable to arraytrain_y=np.asarray(train[ 'label']) # performing one-hot encoding for the target variabletrain_y=pd.get_dummies(train_y)train_y=np.array(train_y) # creating training and validation setfromsklearn.model_selection importtrain_test_splitX_train, X_valid, Y_train, Y_valid=train_test_split(train_x,train_y,test_size= 0.3, random_state= 42) # creating a mlp modelfromkeras.layers importDense, Activationmodel=Sequential()model.add(Dense( 1000, input_dim= 25088, activation= 'relu',kernel_initializer= 'uniform'))keras.layers.core.Dropout( 0.3, noise_shape= None, seed= None)model.add(Dense( 500,input_dim= 1000,activation= 'sigmoid'))keras.layers.core.Dropout( 0.4, noise_shape= None, seed= None)model.add(Dense( 150,input_dim= 500,activation= 'sigmoid'))keras.layers.core.Dropout( 0.2, noise_shape= None, seed= None)model.add(Dense(units= 10))model.add(Activation( 'softmax'))model.compile(loss= 'categorical_crossentropy', optimizer= "adam", metrics=[ 'accuracy']) # fitting the model model.fit(X_train, Y_train, epochs= 20, batch_size= 128,validation_data=(X_valid,Y_valid))

冻结最初几层网络的权重

(责任编辑:本港台直播)
顶一下
(0)
0%
踩一下
(0)
0%
------分隔线----------------------------
栏目列表
推荐内容