本港台开奖现场直播 j2开奖直播报码现场
当前位置: 新闻频道 > IT新闻 >

【j2开奖】专栏 | 手机端运行卷积神经网络实践:基于TensorFlow和OpenCV实现文档检测功能(6)

时间:2017-05-31 19:19来源:报码现场 作者:j2开奖直播 点击:
在不断改进训练样本的过程中,还根据真实样本图片的统计情况和各种途径的反馈信息,刻意模拟了一些更复杂的样本场景,比如凌乱的背景环境、直线边

在不断改进训练样本的过程中,还根据真实样本图片的统计情况和各种途径的反馈信息,刻意模拟了一些更复杂的样本场景,比如凌乱的背景环境、直线边缘干扰等等

经过不断的调整和优化,最终才训练出一个满意的模型,可以再次通过下面这张图表中的第二列看一下神经网络模型的边缘检测效果:

  

【j2开奖】专栏 | 手机端运行卷积神经网络实践:基于TensorFlow和OpenCV实现文档检测功能

在手机设备上运行 TensorFlow

在手机上使用 TensorFlow 库

TensorFlow 官方是支持 iOS 和 Android 的,而且有清晰的文档,照着做就行。但是因为 TensorFlow 是依赖于 protobuf 3 的,所以有可能会遇到一些其他的问题,比如下面这两种,就是我们在两个不同的 iOS APP 中遇到的问题和解决办法,可以作为一个参考:

A 产品使用的是 protobuf 2,同时由于各种历史原因,使用并且停留在了很旧的某个版本的 Base 库上,而 protobuf 3 的内部也使用了 Base 库,当 A 产品升级到 protobuf 3 后,protobuf 3 的 Base 库和 A 源码中的 Base 库产生了一些奇怪的冲突,最后的解决办法是手动修改了 A 源码中的 Base 库,避免编译时的冲突

B 产品也是使用的 protobuf 2,而且 B 产品使用到的多个第三方模块 (没有源码,只有二进制文件) 也是依赖于 protobuf 2,直接升级 B 产品使用的 protobuf 库就行不通了,最后采用的方法是修改 TensorFlow 和 TensorFlow 中使用的 protobuf 3 的源代码,把 protobuf 3 换了一个命名空间,这样两个不同版本的 protobuf 库就可以共存了

Android 上因为本身是可以使用动态库的,所以即便 app 必须使用 protobuf 2 也没有关系,不同的模块使用 dlopen 的方式加载各自需要的特定版本的库就可以了。

在手机上使用训练得到的模型文件

模型通常都是在 PC 端训练的,对于大部分使用者,都是用 Python 编写的代码,得到 ckpt 格式的模型文件。在使用模型文件的时候,一种做法就是用代码重新构建出完整的神经网络,然后加载这个 ckpt 格式的模型文件,如果是在 PC 上使用模型文件,用这个方法其实也是可以接受的,复制粘贴一下 Python 代码就可以重新构建整个神经网络。但是,在手机上只能使用 TensorFlow 提供的 C++ 接口,如果还是用同样的思路,就需要用 C++ API 重新构建一遍神经网络,这个工作量就有点大了,而且 C++ API 使用起来比 Python API 复杂的多,所以,在 PC 上训练完网络后,还需要把 ckpt 格式的模型文件转换成 pb 格式的模型文件,这个 pb 格式的模型文件,atv,是用 protobuf 序列化得到的二进制文件,里面包含了神经网络的具体结构以及每个矩阵的数值,使用这个 pb 文件的时候,不需要再用代码构建完整的神经网络结构,只需要反序列化一下就可以了,这样的话,用 C++ API 编写的代码就会简单很多,其实这也是 TensorFlow 推荐的使用方法,在 PC 上使用模型的时候,也应该使用这种 pb 文件 (训练过程中使用 ckpt 文件)。

HED 网络在手机上遇到的奇怪 crash

在手机上加载 pb 模型文件并且运行的时候,遇到过一个诡异的错误,内容如下:

  Invalid argument: No OpKernel was registered to support Op 'Mul' with these attrs. Registered devices: [CPU], Registered kernels:

  device='CPU'; T in [DT_FLOAT]

  [[Node: hed/mul_1 = Mul[T=DT_INT32](hed/strided_slice_2, hed/mul_1/y)]]

之所以诡异,是因为从字面上看,这个错误的含义是缺少乘法操作 (Mul),但是我用其他的神经网络模型做过对比,乘法操作模块是可以正常工作的。

Google 搜索后发现很多人遇到过类似的情况,但是错误信息又并不相同,后来在 TensorFlow 的 github issues 里终于找到了线索,综合起来解释,是因为 TensorFlow 是基于操作 (Operation) 来模块化设计和编码的,每一个数学计算模块就是一个 Operation,由于各种原因,比如内存占用大小、GPU 独占操作等等,mobile 版的 TensorFlow,并没有包含所有的 Operation,mobile 版的 TensorFlow 支持的 Operation 只是 PC 完整版 TensorFlow 的一个子集,我遇到的这个错误,就是因为使用到的某个 Operation 并不支持 mobile 版。

按照这个线索,在 Python 代码中逐个排查,后来定位到了出问题的代码,修改前后的代码如下:

  def deconv(inputs, upsample_factor):

  input_shape = tf.shape(inputs)

  # Calculate the ouput size of the upsampled tensor

  upsampled_shape = tf.pack([input_shape[0],

  input_shape[1] * upsample_factor,

  input_shape[2] * upsample_factor,

  1])

  upsample_filter_np = bilinear_upsample_weights(upsample_factor, 1)

  upsample_filter_tensor = tf.constant(upsample_filter_np)

  # Perform the upsampling

  upsampled_inputs = tf.nn.conv2d_transpose(inputs, upsample_filter_tensor,

  output_shape=upsampled_shape,

  strides=[1, upsample_factor, upsample_factor, 1])

  return upsampled_inputs

  def deconv_mobile_version(inputs, upsample_factor, upsampled_shape):

  upsample_filter_np = bilinear_upsample_weights(upsample_factor, 1)

  upsample_filter_tensor = tf.constant(upsample_filter_np)

  # Perform the upsampling

  upsampled_inputs = tf.nn.conv2d_transpose(inputs, upsample_filter_tensor,

  output_shape=upsampled_shape,

  strides=[1, upsample_factor, upsample_factor, 1])

  return upsampled_inputs

(责任编辑:本港台直播)
顶一下
(0)
0%
踩一下
(0)
0%
------分隔线----------------------------
栏目列表
推荐内容