Closed this PR since it had already been merged by a different commit.
Enable differentiable training and update cluster indices
In the current clustering implementation, the original weights of the clustered layers and the cluster indices are not updated during each training step. Although training alters the values of the cluster centroids (changes in the other, non-clustered layers are reflected in the gradients), the stale original weights will not always match the constantly updated cluster centroids, and this mismatch creates problems during training. To fix this issue, we want to update the original weights after backpropagation; in the next training step, the indices are then re-generated using the updated centroids. This PR makes these changes and adds unit tests for them.
Details of the implementation:
As shown in the figure below, in the forward pass of our current clustering implementation, density-based or linear methods are first used to initialize the centroids (c) for the weights of each layer. Then, the original set of weights (W) is grouped into several clusters using the centroid values. Afterward, the association between the weights and the centroids is calculated from c and W and stored as indices. Finally, within a single cluster, the centroid value is shared among all the weights and used in the forward pass instead of the original weights.
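This forward pass can be sketched in plain NumPy. The function and variable names below are illustrative only, not the actual tfmot API; linear centroid initialization is assumed:

```python
import numpy as np

def cluster_forward(W, n_clusters):
    """Illustrative sketch of the clustering forward pass (not the tfmot API)."""
    # Linear centroid initialization: evenly spaced over the weight range.
    c = np.linspace(W.min(), W.max(), n_clusters)
    # Associate each weight with its nearest centroid (the tf.math.argmin step).
    indices = np.argmin(np.abs(W.reshape(-1, 1) - c.reshape(1, -1)), axis=1)
    # Replace each weight by its centroid value (the tf.gather step).
    clustered_W = c[indices].reshape(W.shape)
    return c, indices, clustered_W

W = np.array([[0.1, 0.9], [0.85, 0.15]])
c, indices, clustered_W = cluster_forward(W, n_clusters=2)
# clustered_W is [[0.1, 0.9], [0.9, 0.1]]: every weight now carries a centroid value.
```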
In the current backpropagation, the clustered weights receive the gradients from the wrapped layer. These gradients are fed into the gather node, which groups them by index and accumulates them as the gradients of the centroids. However, because tf.math.argmin is non-differentiable, TensorFlow's automatic differentiation computes no gradients for the original weights W.
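The backward pass of the gather node described above amounts to a per-index sum, which can be sketched in NumPy (the values below are made up for illustration):

```python
import numpy as np

# Upstream gradients arriving at the clustered weights (illustrative values).
grad_clustered = np.array([0.5, -0.2, 0.3, 0.1])
indices = np.array([0, 1, 1, 0])  # weight-to-centroid association
n_clusters = 2

# Backward pass of tf.gather: accumulate gradients per centroid index
# (what tf.math.unsorted_segment_sum computes in TensorFlow).
grad_centroids = np.zeros(n_clusters)
np.add.at(grad_centroids, indices, grad_clustered)

# argmin is non-differentiable, so the original weights W get no gradient.
grad_W = np.zeros_like(grad_clustered)
```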
1) How to update the original weights?
A small modification of the training graph (a gradient approximation using the straight-through estimator, see Reference below) is used to override the gradient during backpropagation:

```python
clustered_weights = tf.gather(cluster_centroids, indices) * tf.sign(original_weights + 1e+6)
```

In the forward pass, the multiply does not change the output of the graph, because tf.sign(original_weights + 1e+6) is a matrix of ones (the large offset makes every entry positive). In backpropagation, via tf.custom_gradient, the multiply is treated as an add and the tf.sign as an identity. Essentially, the graph becomes:

```python
clustered_weights = tf.gather(cluster_centroids, indices) + tf.identity(original_weights + 1e+6)
```

In this way, the original weights can be updated by TensorFlow's automatic differentiation.
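Under this straight-through estimator, the backward pass behaves as if the output were `centroids[indices] + W`, so the original weights receive the upstream gradient unchanged while the centroids still accumulate it per index. A NumPy sketch of both passes (names and values are illustrative, not the PR's code):

```python
import numpy as np

def ste_forward(W, centroids, indices):
    # Forward: the sign factor is all ones (W + 1e6 > 0 elementwise),
    # so the output equals the gathered centroid values exactly.
    return centroids[indices] * np.sign(W + 1e6)

def ste_backward(upstream, indices, n_clusters):
    # Backward: multiply is treated as add and sign as identity, so
    # dL/dW = upstream (straight-through), and
    # dL/dcentroids = per-index sum of upstream (gather backward).
    grad_W = upstream.copy()
    grad_centroids = np.zeros(n_clusters)
    np.add.at(grad_centroids, indices, upstream)
    return grad_W, grad_centroids

W = np.array([0.1, 0.9, 0.85, 0.15])
centroids = np.array([0.1, 0.9])
indices = np.array([0, 1, 1, 0])
out = ste_forward(W, centroids, indices)  # [0.1, 0.9, 0.9, 0.1]
grad_W, grad_c = ste_backward(np.array([0.5, -0.2, 0.3, 0.1]), indices, 2)
```

Both the original weights and the centroids now receive gradients, which is exactly what the stale-weight problem requires.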
2) How to update the cluster indices?
The indices are not differentiable themselves and are calculated only in the forward pass during training. Therefore, they are updated with tf.assign in the forward pass, inside the call function. This will require some extra changes for tf.distribute, which are not covered in this PR.
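Putting the pieces together, one training step could look like the NumPy sketch below: the indices are re-derived from the current weights and centroids at the start of each forward pass (the role of the tf.assign in call), and then both the centroids and, thanks to the straight-through estimator, the original weights are updated. The loop structure and the gradient signal are hypothetical, for illustration only:

```python
import numpy as np

def training_step(W, centroids, lr=0.1):
    """One illustrative step: re-generate indices, then update both
    centroids and original weights (not the actual PR code)."""
    # Forward: refresh the indices from the current W and centroids
    # (in the PR this is done with tf.assign inside call()).
    indices = np.argmin(
        np.abs(W.reshape(-1, 1) - centroids.reshape(1, -1)), axis=1)
    clustered_W = centroids[indices]
    # Backward (straight-through): made-up upstream gradient signal.
    upstream = clustered_W - 0.5
    grad_centroids = np.zeros(len(centroids))
    np.add.at(grad_centroids, indices, upstream)
    centroids -= lr * grad_centroids
    W -= lr * upstream  # possible only because of the straight-through estimator
    return indices

W = np.array([0.1, 0.9, 0.85, 0.15])
centroids = np.array([0.2, 0.8])
indices = training_step(W, centroids)
```

Because the indices are recomputed from the updated W and centroids on every step, the weight-to-centroid association can change as training progresses, which is the behaviour this PR enables.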
Result table: As shown in the table below, the changes in this PR significantly improve accuracy when the number of clusters is small, and give limited benefit for the other configurations.
| Model | Number of clusters | tfmot | tfmot+this PR | delta |
| ------------- | ------------- | ------------- | ------------- | ------------- |
| Mobilenet_v1 | full model (all 64) | 65.03% | 66.65% | 1.62% |
| | | 3.11 MB | 3.06 MB | -0.05 MB |
| | selective clustering (32 32 32) | 49.72% | 68% | 18.28% |
| | | 7.17 MB | 6.99 MB | -0.18 MB |
| | selective clustering (256 256 32) | 70.16% | 69.32% | -0.84% |
| | | 8.32 MB | 7.68 MB | -0.64 MB |
| Mobilenet_v2 | full model (all 32) | 68.26% | 69.09% | 0.83% |
| | | 2.65 MB | 2.64 MB | -0.01 MB |
| | selective clustering (8 8 8) | 35.05% | 67.28% | 32.23% |
| | | 6.25 MB | 6.23 MB | -0.02 MB |
| | selective clustering (16 16 16) | 67.10% | 70.94% | 3.84% |
| | | 6.59 MB | 6.42 MB | -0.17 MB |
| | selective clustering (256 256 32) | 72.30% | 72.30% | 0% |
| | | 7.31 MB | 7.18 MB | -0.13 MB |
| DS-CNN-L | full model (all 32) | 94.77% | 94.86% | 0.09% |
| | | 0.33 MB | 0.33 MB | 0 MB |
| | full model (all 8) | 73.51% | 86.83% | 13.32% |
| | | 0.19 MB | 0.19 MB | 0 MB |
Reference: Y. Bengio, N. Léonard, and A. Courville. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432, 2013.
- weixin_39757893 5 months ago
Hi and , I have just filled in all the results in the description. Could you please take a look at the PR and let me know your thoughts? Also, I'm not sure how long the description should be. Thanks, , for reviewing.
- weixin_39977642 5 months ago
As noted in the call, I'll take a look at this, but after it merges, when I have the time.
- weixin_39737233 5 months ago
, could you check what is holding up this PR, please?
- weixin_39977642 5 months ago
Yes, I am.