My Python code uses the slim library to train a classification model in TensorFlow:
with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope(weight_decay=0.001)):
    logits, _ = mobilenet_v2.mobilenet(images, NUM_CLASSES)
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
cross_entropy = tf.reduce_mean(cross_entropy)
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = tf.contrib.slim.learning.create_train_op(cross_entropy, opt, global_step=global_step)
...
sess.run(train_op)
It runs fine. However, no matter what value I give 'weight_decay', the training accuracy easily climbs above 90%. It seems that 'weight_decay' simply has no effect.
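For reference, what I expect 'weight_decay' to do is add an L2 penalty on the weights to the training objective. A minimal sketch of the intended effect (my own illustration; slim actually attaches the penalty per layer via a regularizer, typically only to the kernel weights):

wd = 0.001
# Sum of squared weights scaled by the decay factor; if this term were part of the
# objective being minimized, fitting the training set exactly would become harder.
l2_penalty = wd * tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()])
expected_objective = cross_entropy + l2_penalty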
To find out why, I looked into the TensorFlow source for 'tf.losses.sparse_softmax_cross_entropy()':
# tensorflow/python/ops/losses/losses_impl.py
@tf_export("losses.sparse_softmax_cross_entropy")
def sparse_softmax_cross_entropy(
    labels, logits, weights=1.0, scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
  ...
  with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
                      (logits, labels, weights)) as scope:
    # As documented above in Args, labels contain class IDs and logits contains
    # 1 probability per class ID, so we expect rank(logits) - rank(labels) == 1;
    # therefore, expected_rank_diff=1.
    labels, logits, weights = _remove_squeezable_dimensions(
        labels, logits, weights, expected_rank_diff=1)
    losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                         logits=logits,
                                                         name="xentropy")
    return compute_weighted_loss(
        losses, weights, scope, loss_collection, reduction=reduction)
'losses.sparse_softmax_cross_entropy()' simply calls 'tf.nn.sparse_softmax_cross_entropy_with_logits()' and then passes the per-example losses to 'compute_weighted_loss()'. Let's look into the implementation of 'compute_weighted_loss()':
# tensorflow/python/ops/losses/losses_impl.py
@tf_export("losses.compute_weighted_loss")
def compute_weighted_loss(
    losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
  ...
    loss = math_ops.cast(loss, input_dtype)
    util.add_loss(loss, loss_collection)
    return loss
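One side observation of my own before going on: with the default reduction of SUM_BY_NONZERO_WEIGHTS, the loss returned here is already a scalar, so the extra 'tf.reduce_mean()' in my training code is effectively a no-op:

loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
print(loss.shape)                   # () -- already reduced to a scalar
print(tf.reduce_mean(loss).shape)   # () -- unchanged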
What is the secret in 'util.add_loss()'?
# tensorflow/python/ops/losses/util.py
@tf_export("losses.add_loss")
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
  ...
  if loss_collection:
    ops.add_to_collection(loss_collection, loss)
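A quick inspection sketch at this point (assuming the training graph from the top of the post has been built; the printed tensor name is only an example):

data_losses = tf.get_collection(tf.GraphKeys.LOSSES)
print(data_losses)
# e.g. [<tf.Tensor 'sparse_softmax_cross_entropy_loss/value:0' shape=() dtype=float32>]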
So the loss created by 'losses.sparse_softmax_cross_entropy()' is added to the 'GraphKeys.LOSSES' collection. Then where do the regularization losses on the weights go? Are they added to the same collection? Let's check. Every layer written with 'tf.layers' or 'tf.contrib.slim' inherits from 'class Layer', and when the layer calls 'add_variable()' with a regularizer, it calls 'add_loss()'. Let's check 'add_loss()' of the base class 'Layer':
@tf_export('layers.Layer')
class Layer(checkpointable.CheckpointableBase):
  ...
  def add_loss(self, losses, inputs=None):
    ...
    _add_elements_to_collection(losses, ops.GraphKeys.REGULARIZATION_LOSSES)
Here is the catch: the regularization loss on the weights is not added to 'GraphKeys.LOSSES' but to 'GraphKeys.REGULARIZATION_LOSSES'. Then how can we gather all the losses at training time? Grepping for 'REGULARIZATION_LOSSES' in the TensorFlow source leads to 'get_total_loss()':
# tensorflow/python/ops/losses/util.py
@tf_export("losses.get_total_loss")
def get_total_loss(add_regularization_losses=True, name="total_loss"):
  ...
  losses = get_losses()
  if add_regularization_losses:
    losses += get_regularization_losses()
  return math_ops.add_n(losses, name=name)
That is the secret of losses in 'tf.layers' and 'tf.contrib.slim': we should use 'get_total_loss()' to fetch the model loss and the regularization loss together!
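In other words, 'get_total_loss()' does roughly the following bookkeeping by hand (a sketch, assuming the graph built at the top of the post):

# The cross-entropy term lives in GraphKeys.LOSSES ...
data_losses = tf.get_collection(tf.GraphKeys.LOSSES)
# ... while every weight-decay penalty created under training_scope() lives here:
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
print(len(reg_losses))   # one entry per regularized variable
total_loss = tf.add_n(data_losses + reg_losses)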
After changing my code:
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
cross_entropy = tf.reduce_mean(cross_entropy)
global_step = tf.contrib.framework.get_or_create_global_step()
loss = tf.contrib.slim.losses.get_total_loss()
train_op = tf.contrib.slim.learning.create_train_op(loss, opt, global_step=global_step)
...
sess.run(train_op)
Now 'weight_decay' works as expected: the training accuracy can no longer reach a high value so easily.
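A quick way to verify the fix (a sketch; it assumes cross_entropy is the only entry in the LOSSES collection):

reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
# The objective being minimized is now the data loss plus the weight penalties,
# so the printed total should exceed cross_entropy by exactly the reg term.
print(sess.run([cross_entropy, reg, loss]))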