暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

keras进阶,我从Layer开始之八

工程师milter 2020-08-21
751

继续keras源码阅读,不急不躁,慢就是快。

490-502

这里处理了一个很trick的问题。就是当layer只是简单将input返回作为output时,要对output进行一下复制,防止丢失tensor的元数据。

复制的函数就是 K.identify。

这个函数的签名说的很清楚:

复制一个tensor的值,感觉这个函数还是很有用的。值得用心记一下。

503-512

这段代码不用考虑,因为注释说的很明白:

513-517

这段同样是对mask进行处理。上一篇文章中,看到Embedding layer计算mask时,根据input是否等于0,产生了mask。

但是有一个问题,就是当输出是多个tensor时,需要每个tensor有一个mask。这段代码就是做这件事儿的。为什么不用一个呢?

我想这样的设计是为了简单,清晰。每个output tensor都对应自己的mask。数据结构非常整齐。

518-530

又到了这段代码,在前面已经见过几次了。不过我觉得还是有必要认认真真读一下。因为它对理解keras组织神经网络的方式非常重要。

首先就是注释。说明了Node的作用。

1、追踪对该layer的call调用

2、追踪call调用产生的新的variable

3、更新output tensor的keras history

这一点我们之前已经学习过,就是tensor的坐标。强调一点,我觉得tensor坐标是一个非常重要的概念。它揭示出了keras中很多的设计初衷。

4、如果input tensor有自己的坐标,则也进行记录。

看完注释,自然是要进到方法内部去看看。

这是上面讲的第4点的内容。记录input tensor的坐标。

inbound_layers = []

node_indices = []

tensor_indices = []

就是这三项。当然,前提是input tensor 要有_keras_history这个属性,如果没有,就用None代替。

看到这里,我想到,keras现在强绑定tensorflow,那按道理,现在tensorflow的每个tensor应该都有_keras_history了吧。为了验证,我打开了tensorflow的源码。发现并没有。

这样就理解了一个限制,keras layer的输入必须是其他的keras的layer的输出,目的就是确保每个tensor都有_keras_history,以便keras能够全面地追踪并管理整个神经网络。

这一点,在我们使用tf.keras时要格外注意。

同时,进一步,可以更深刻地理解tensor和variable的区别。tensor是layer的输入和输出,在计算时才有值,同时,还记录了自己从哪里来这样的信息(即tensor坐标)。variable就是比较简单了,就是简单的多维数组,并且有初始化值。

这么联系起来一想,对keras就有了更深刻的感知了。

继续往下读:

这里是创建了Node,之前已经看过,在创建Node时,会把该node加到input layer的outbound_nodes数组中,同时加入该layer的inbound_nodes数组中。

有点绕,需要慢下来,确保自己弄清楚了。

Node添加完之后,就是为output tensor添加_keras_history

关键看下面的一行。

(self, len(self._inbound_nodes) - 1, i)

因为Node创建时,该layer的_inbound_nodes已经添加了一个Node,所以,直接使用该Node的index作为第二个坐标。

上面的_uses_learning_phase是keras中一个便利标志。可以让每个tensor针对训练阶段进行一些定制操作。K.in_train_phase就是配套的方法。这一点可以以后慢慢理解,现在知道有这么回事儿就行。

确定每个output tensor的 _uses_learning_phase时有三个考虑,一个就是该tensor自己的_uses_learning_phase属性,另一个就是产生该tensor的layer的_uses_learning_phase属性,另一个就是所有的input tensor的_uses_learning_phase属性。代码中已经明示了。

至此,我们就读完了__call__方法。撒花!

往期精彩回顾

【crash DP】上下班路上学动态规划(1)

【腾讯面试热身题】二叉树层次遍历(动画展示)

【阿里高频题】动画讲解二叉树深度优先遍历

【阿里高频题】二叉树深度遍历进阶版

【图解】快慢指针解链表判环问题

【强!】深度优先遍历求二叉树的直径

  换个风格学快速排序

【头条面试-40K】反转吧,链表!

【头条面试-40K】反转链表加强版

【面霸之路】双指针和滑动窗口结合

【送书!】最佳算法入门书籍

  我是如何简单粗暴攻克面试算法题入职阿里的?

【阿里面试热身题】数组去重(动画展示)




文章转载自工程师milter,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

评论