前不久,谷歌公布了一项最新技术,可以教机器画画。今天,谷歌开源了代码。在我们研究其代码之前,首先先按要求设置Magenta环境。(https://github.com/tensorflow/magenta/blob/master/README.md) 本文详细解释了Sketch-RNN的TensorFlow代码,即之前发布的两篇文章《Teaching Machines to Draw》和《A Neural Representation of Sketch Drawings》中描述的循环神经网络模型(RNN)。 模型概览 sketch-rnn是序列到序列的变体自动编码器。编码器RNN是双向RNN,解码器是自回归混合密度RNN。你可以使用enc_model,dec_model,enc_size,dec_size设置指定要使用的RNN单元格的类型和RNN的大小。 编码器将采用一个潜在代码z,一个维度为z_size的浮点矢量。像VAE一样,我们可以对z强制执行高斯IID分布,并使用kl_weight来控制KL发散损失项的强度。KL散度损失与重建损失之间将会有一个权衡。我们还允许潜在的代码存储信息的一些空间,而不是纯高斯IID。一旦KL损失期限低于kl_tolerance,我们将停止对该期限的优化。 对于中小型数据集,丢失(dropout)和数据扩充是避免过度拟合的非常有用的技术。我们提供了输入丢失、输出丢失、不存在内存丢失的循环丢失三个选项。实际上,我们只使用循环丢失,通常根据数据集将其设置在65%到90%之间。层次归一化和反复丢失可以一起使用,形成了一个强大的组合,用于在小型数据集上训练循环神经网络。 谷歌提供了两种数据增强技术。第一个是随机缩放训练图像大小的random_scale_factor。第二种增加技术(sketch-rnn论文中未使用)剔除线笔划中的随机点。给定一个具有超过2点的线段,我们可以随机放置线段内的点,并且仍然保持类似的矢量图像。这种类型的数据增强在小数据集上使用时非常强大,并且对矢量图是唯一的,因为难以在文本或MIDI数据中删除随机字符或音符,并且也不可能在像素图像数据中丢弃随机像素而不引起大的视觉差异。我们通常将数据增加参数设置为10%至20%。如果在与普通示例相比较的情况下,人类观众几乎没有差异,那么我们应用数据增强技术,而不考虑训练数据集的大小。 有效地使用丢弃和数据扩充,可以避免过度拟合到一个小的训练集。 训练模型 要训练模型,首先需要一个包含训练/验证/测试例子的数据集。我们提供了指向aaron_sheep数据集的链接,默认情况下,该模型将使用此轻量级数据集。 使用示例: sketch_rnn_train --log_root=checkpoint_path --data_dir=dataset_path --hparams={"data_set"="dataset_filename.npz"} 我们建议你在模型和数据集内部创建子目录,以保存自己的数据和检查点。 TensorBoard日志将存储在checkpoint_path内,用于查看训练/验证/测试数据集中各种损失的训练曲线。 以下是模型的完整选项列表以及默认设置: data_set='aaron_sheep.npz', # Our dataset. save_every=500, # Number of batches percheckpoint creation. dec_rnn_size=512, # Size of decoder. dec_model='lstm', # Decoder: lstm, layer_norm orhyper. enc_rnn_size=256, # Size of encoder. enc_model='lstm', # Encoder: lstm, layer_norm orhyper. z_size=128, # Size of latent vector z.Recommend 32, 64 or 128. kl_weight=0.5, # KL weight of loss equation.Recommend 0.5 or 1.0. kl_weight_start=0.01, # KL start weight when annealing. kl_tolerance=0.2, # Level of KL loss at which to stopoptimizing for KL. batch_size=100, # Minibatch size. Recommendleaving at 100. grad_clip=1.0, # Gradient clipping. Recommendleaving at 1.0. num_mixture=20, # Number of mixtures in Gaussianmixture model. learning_rate=0.001, # Learning rate. decay_rate=0.9999, # Learning rate decay per minibatch. kl_decay_rate=0.99995, # KL annealing decay rate per minibatch. min_learning_rate=0.00001, # Minimum learning rate. use_recurrent_dropout=True, # Recurrent Dropout without Memory Loss.Recomended. recurrent_dropout_prob=0.90, # Probabilityof recurrent dropout keep. use_input_dropout=False, # Input dropout. Recommend leaving False. input_dropout_prob=0.90, # Probability of input dropout keep. use_output_dropout=False, # Output droput. Recommend leaving False. output_dropout_prob=0.90, # Probability of output dropout keep. random_scale_factor=0.15, # Random scaling data augmentionproportion. augment_stroke_prob=0.10, # Point dropping augmentation proportion. conditional=True, # If False, use decoder-only model. (责任编辑:本港台直播) |