Google has published its quantization method in this paper. It uses int8 arithmetic for the forward pass but float32 for back-propagation, since back-propagation needs higher precision to accumulate gradients. A question came to mind right after reading the paper: why were all the performance tests run on mobile phones (ARM architecture)? A model quantized with Google's method needs not only int8 additions and multiplications but also bit-shift operations. The AVX instruction set on Intel x86_64 can accelerate MAC (multiply-accumulate) operations, but I suspected it could not speed up the bit-shift operations as well.
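To make the integer-only arithmetic concrete, here is a minimal NumPy sketch of an int8 matrix multiply followed by requantization. It is a simplification, not Google's implementation: it assumes symmetric quantization (zero point 0), and the helper names quantize, quantize_multiplier and int8_matmul are hypothetical. The int8 products are accumulated in int32 (the MAC part), and the float rescaling factor M = a_scale * b_scale / out_scale is replaced by an integer multiplier m0 plus a rounding right shift (the bit-shift part).

```python
import numpy as np


def quantize(x, scale):
    """Symmetric int8 quantization (zero point 0), a simplification of the paper's scheme."""
    return np.clip(np.round(x / scale), -127, 127).astype(np.int8)


def quantize_multiplier(m):
    """Return (m0, shift) with m ~= m0 * 2**-(31 + shift) and m0 in [2**30, 2**31)."""
    assert 0.0 < m < 1.0
    shift = 0
    while m < 0.5:  # normalize m into [0.5, 1)
        m *= 2.0
        shift += 1
    m0 = int(round(m * (1 << 31)))
    if m0 == (1 << 31):  # rounding pushed m0 out of int32 range
        m0 //= 2
        shift -= 1
    return m0, shift


def int8_matmul(qa, qb, a_scale, b_scale, out_scale):
    """Integer-only matmul: int32 accumulation, then fixed-point rescale."""
    # MAC: int8 x int8 products accumulated in int32
    acc = qa.astype(np.int32) @ qb.astype(np.int32)
    # Rescale by M = a_scale * b_scale / out_scale using only an integer
    # multiply and a rounding arithmetic right shift
    m0, shift = quantize_multiplier(a_scale * b_scale / out_scale)
    total_shift = 31 + shift
    prod = acc.astype(np.int64) * m0
    out = (prod + (1 << (total_shift - 1))) >> total_shift
    return np.clip(out, -127, 127).astype(np.int8)


# Hypothetical usage with random data
rng = np.random.default_rng(0)
a, b = rng.uniform(-1, 1, (4, 8)), rng.uniform(-1, 1, (8, 3))
a_scale = b_scale = 1 / 127
out_scale = 8 / 127  # |a @ b| <= 8 for this shape
q_out = int8_matmul(quantize(a, a_scale), quantize(b, b_scale), a_scale, b_scale, out_scale)
print(q_out * out_scale)  # approximately a @ b
```

On x86 the shift itself is cheap; the sketch only shows where shifts appear in the quantized pipeline.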
To verify my suspicion, I trained a ResNet-50 model (float32) to classify the CIFAR-100 dataset. After training for a few epochs, I measured inference speed with my eval.py script (listed in the appendix). The result was:
Time: 5.58819s
Then I followed these steps to add tf.contrib.quantize.create_training_graph() and tf.contrib.quantize.create_eval_graph() to my code. This time, the inference time was:
Time: 6.23221s
A little disappointing: the quantized (int8) version of the model did not speed up inference on the x86 CPU; it was in fact about 11.5% slower (6.23 s vs. 5.59 s). Maybe we need to find a more powerful quantization algorithm.
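For context on what the rewritten graph computes: tf.contrib.quantize.create_training_graph() and create_eval_graph() insert fake-quantization ops, which simulate 8-bit rounding of weights and activations on float32 tensors. Below is a minimal NumPy sketch of that quantize-then-dequantize step, assuming an asymmetric 8-bit range [x_min, x_max] that contains 0; the function name fake_quant is hypothetical, and this is not TensorFlow's exact implementation (which, for example, also nudges the range).

```python
import numpy as np


def fake_quant(x, x_min, x_max, num_bits=8):
    """Quantize x to num_bits levels over [x_min, x_max], then dequantize back to float32."""
    assert x_min <= 0.0 <= x_max and x_min < x_max
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (x_max - x_min) / (qmax - qmin)
    # Integer zero point, so that real 0.0 maps exactly onto a quantized level
    zero_point = int(min(max(round(qmin - x_min / scale), qmin), qmax))
    # Snap to the integer grid (values are still stored as floats)
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    return ((q - zero_point) * scale).astype(np.float32)


x = np.array([-1.0, 0.0, 0.5, 1.0, 2.0], dtype=np.float32)
print(fake_quant(x, -1.0, 1.0))  # 2.0 is clipped to roughly 1.0
```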
Appendix:
```python
# eval.py
from input_data import Cifar100Data
import tensorflow as tf
import numpy as np
import resnet_v2
import argparse
import time
import sys

EVAL_SAMPLES = 10000
BATCH_SIZE = 10000
MODEL_PATH = './models/'
MODEL_NAME = 'cifar_resnet_50'


def cnn_part(images):
    print(images.shape)
    # resnet_v2_50 returns (logits, end_points); only the logits are needed
    ivg, _ = resnet_v2.resnet_v2_50(images, 100)
    return ivg


def main(_):
    # Pin the model to the CPU so the timing measures x86 inference
    with tf.device('/cpu:0'):
        images = tf.placeholder(tf.float32, [BATCH_SIZE, 32, 32, 3])
        labels = tf.placeholder(tf.int64, [BATCH_SIZE])
        with tf.contrib.slim.arg_scope([tf.contrib.slim.conv2d],
                                       weights_initializer=tf.truncated_normal_initializer(mean=0, stddev=0.1)):
            image_vector = cnn_part(images)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=image_vector)
        loss = tf.reduce_mean(loss)
        opt = tf.train.AdamOptimizer(1e-3)
        # Never run below; the timing loop only fetches `accuracy`
        train_op = tf.contrib.slim.learning.create_train_op(loss, opt)
        correct_prediction = tf.equal(tf.argmax(image_vector, 1), labels)
        correct_prediction = tf.cast(correct_prediction, tf.float32)
        accuracy = tf.reduce_mean(correct_prediction)

    data = Cifar100Data('/disk3/cifar/cifar-100-python/test')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # A serialized GraphDef is binary, so open the file in 'rb' mode
        with tf.gfile.FastGFile('./models/cifar_resnet_50_quant.pb', 'rb') as fl:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(fl.read())
            tf.import_graph_def(graph_def, name='')
        saver.restore(sess, MODEL_PATH + MODEL_NAME + '-' + str(FLAGS.epoch))
        batch = data.next_batch(BATCH_SIZE)
        # Time three runs over the same batch of test images
        for i in range(3):
            begin = time.time()
            res = sess.run(accuracy, feed_dict={images: batch[0], labels: batch[1]})
            print("Time: %gs" % (time.time() - begin))
            print(res)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=str,
                        default='8',
                        help='Epoch of checkpoint for evaluation')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
```
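eval.py times each sess.run call with time.time(), and the first of the three runs can include one-time setup cost. A generic helper that separates untimed warm-up runs from timed runs might look like the sketch below; benchmark is a hypothetical name, and in eval.py fn would be something like lambda: sess.run(accuracy, feed_dict=...).

```python
import statistics
import time


def benchmark(fn, warmup=1, repeats=3):
    """Run fn() `warmup` times untimed, then `repeats` times timed; return seconds per run."""
    for _ in range(warmup):
        fn()  # warm-up: absorbs one-time costs such as setup or cache filling
    times = []
    for _ in range(repeats):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        times.append(time.perf_counter() - start)
    return times


times = benchmark(lambda: sum(range(100000)), warmup=1, repeats=5)
print("median: %gs, min: %gs" % (statistics.median(times), min(times)))
```

Reporting the median (or minimum) of several runs makes a comparison like 5.59 s vs. 6.23 s less sensitive to a single noisy measurement.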