Google has published its quantization method in this paper. It runs the feed-forward pass in int8 but keeps float32 for back-propagation, since back-propagation needs more precision to accumulate gradients. One question came to mind right after reading the paper: why were all the performance tests run on mobile phones (ARM architecture)? A model quantized with Google's method needs not only int8 additions and multiplications but also bit-shift operations. My suspicion was that the AVX instruction set on Intel x86_64 can accelerate MAC (multiply-accumulate) operations but does not help with the bit-shift operations.
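To make the bit-shift point concrete: in the paper, the output of a quantized layer is computed from an int32 accumulator using the rescaling factor M = S1·S2/S3 (the scales of the two operands and of the output), and M is approximated as M = 2^(-n)·M0 with M0 in [0.5, 1). Applying it therefore takes an integer multiply followed by a right shift. The pure-Python sketch below shows both steps for one 8-bit (uint8, as in the paper) dot product. The function names, zero points and scales are hypothetical, and production kernels such as gemmlowp use saturating int32 arithmetic with slightly different rounding.

```python
def quantize_multiplier(real_multiplier):
    """Approximate M in (0, 1) as (m0 / 2**31) * 2**-shift, with m0 / 2**31 in [0.5, 1)."""
    assert 0.0 < real_multiplier < 1.0
    shift = 0
    while real_multiplier < 0.5:       # normalize M into [0.5, 1), counting the shifts
        real_multiplier *= 2.0
        shift += 1
    # Store M0 as a Q31 fixed-point integer, kept within the int32 range.
    m0 = min(int(round(real_multiplier * (1 << 31))), (1 << 31) - 1)
    return m0, shift


def quantized_dot(a_q, b_q, a_zp, b_zp, m0, shift, out_zp):
    """uint8 dot product: multiply-accumulate, then integer-only rescaling."""
    # 1. MAC: accumulate products of zero-point-corrected values (an int32 in real kernels).
    acc = sum((a - a_zp) * (b - b_zp) for a, b in zip(a_q, b_q))
    # 2. Fixed-point multiply by M0: (acc * m0) / 2**31, rounded to nearest.
    scaled = (acc * m0 + (1 << 30)) >> 31
    # 3. The bit-shift step: rounding right shift by `shift` bits.
    if shift > 0:
        scaled = (scaled + (1 << (shift - 1))) >> shift
    # 4. Add the output zero point and saturate to the uint8 range.
    return max(0, min(255, scaled + out_zp))
```

With M = 0.01, `quantize_multiplier` returns `shift = 6`, and `quantized_dot([130, 140, 120, 200], [100, 90, 150, 60], 128, 100, m0, shift, 128)` gives 94, matching `round(-3400 * 0.01) + 128` computed in floating point. Steps 2 and 3 are the part that relies on shifts rather than plain MACs.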

To verify my suspicion, I wrote a ResNet-50 model (float32) to classify the CIFAR-100 dataset. After training it for a few epochs, I measured the inference time on the 10,000-image test set, fed as a single batch, with my eval.py (see the appendix). The result is:

Time: 5.58819s

Then I followed these steps to add tf.contrib.quantize.create_training_graph() and tf.contrib.quantize.create_eval_graph() to my code. This time, the inference time is:

Time: 6.23221s

A little disappointing: the quantized (int8) version of the model did not speed up inference on the x86 CPU. It was actually about 11.5% slower (6.23 s vs. 5.59 s). Maybe we need to find a more powerful quantization algorithm.
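For context on the second measurement: create_training_graph() and create_eval_graph() rewrite the TensorFlow graph in place by inserting fake-quantization ops (such as FakeQuantWithMinMaxVars). These ops take float32 tensors and return float32 tensors whose values have been snapped onto an 8-bit grid, so the rewritten graph still runs its convolutions on float32 tensors, with the fake-quant ops added on top. The function below is a simplified pure-Python sketch of that rounding, assuming `min_val <= 0 <= max_val`; the name `fake_quant` is hypothetical and its zero-point nudging only approximates TensorFlow's kernel.

```python
def fake_quant(x, min_val, max_val, num_bits=8):
    """Simulated quantization: float in, float out, snapped onto 2**num_bits levels."""
    assert min_val <= 0.0 <= max_val and min_val < max_val
    levels = (1 << num_bits) - 1            # 255 steps for 8 bits
    scale = (max_val - min_val) / levels    # width of one quantization step
    # Nudge the range so the zero point is an integer and 0.0 is exactly representable.
    zero_point = round(-min_val / scale)
    nudged_min = -zero_point * scale
    nudged_max = nudged_min + levels * scale
    # Clamp into the range, round to the nearest level, then map back to float.
    clamped = min(max(x, nudged_min), nudged_max)
    q = round((clamped - nudged_min) / scale)   # integer level in [0, levels]
    return q * scale + nudged_min
```

For example, `fake_quant(0.3, -1.0, 1.0)` returns a float within half a step (1/255) of 0.3, `fake_quant(0.0, -1.0, 1.0)` returns exactly 0.0, and values outside the range are clamped to its ends.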

Appendix:

```python
# eval.py  (TensorFlow 1.x: tf.contrib no longer exists in TensorFlow 2.x)
import argparse
import sys
import time

import tensorflow as tf

import resnet_v2                      # ResNet v2 definitions from the TF-Slim models repository
from input_data import Cifar100Data   # the author's CIFAR-100 reader

EVAL_SAMPLES = 10000
BATCH_SIZE = 10000

MODEL_PATH = './models/'
MODEL_NAME = 'cifar_resnet_50'


def cnn_part(images):
    print(images.shape)
    # is_training=False makes batch norm use its moving statistics during evaluation.
    logits, _ = resnet_v2.resnet_v2_50(images, 100, is_training=False)
    return logits


def main(_):
    # Build the whole model on the CPU, since the goal is to time x86 inference.
    with tf.device('/cpu:0'):
        images = tf.placeholder(tf.float32, [BATCH_SIZE, 32, 32, 3])
        labels = tf.placeholder(tf.int64, [BATCH_SIZE])

        with tf.contrib.slim.arg_scope([tf.contrib.slim.conv2d],
                                       weights_initializer=tf.truncated_normal_initializer(mean=0, stddev=0.1)):
            image_vector = cnn_part(images)

        if FLAGS.quant:
            # Rewrite the eval graph in place with fake-quantization ops. This must run
            # before the Saver is created, and the checkpoint must come from training
            # with tf.contrib.quantize.create_training_graph().
            tf.contrib.quantize.create_eval_graph()

    correct_prediction = tf.equal(tf.argmax(image_vector, 1), labels)
    correct_prediction = tf.cast(correct_prediction, tf.float32)
    accuracy = tf.reduce_mean(correct_prediction)

    data = Cifar100Data('/disk3/cifar/cifar-100-python/test')

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, MODEL_PATH + MODEL_NAME + '-' + str(FLAGS.epoch))

        batch = data.next_batch(BATCH_SIZE)

        # Time several runs: the first one also pays one-time setup costs.
        for i in range(3):
            begin = time.time()
            res = sess.run(accuracy, feed_dict={images: batch[0], labels: batch[1]})
            print("Time: %gs" % (time.time() - begin))
            print(res)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=str,
                        default='8',
                        help='Epoch of checkpoint for evaluation')
    parser.add_argument('--quant', action='store_true',
                        help='Evaluate with the create_eval_graph() fake-quantization rewrite')
    FLAGS, unparsed = parser.parse_known_args()

    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
```