About a month ago, I submitted a request to Google Research Cloud to use TPUs for free. Fortunately, I received the approval yesterday. It lets me use 5 regular Cloud TPUs and 100 preemptible Cloud TPUs for free for 30 days, and all I had to submit was my GCP project name.
Next, I had to change my existing TensorFlow program so that it could run on TPUs. I can't simply change tf.device('/gpu:0') to tf.device('/tpu:0') in the code to run training on a Google TPU. Fortunately, there are many documents that explain how to modify the code for this, such as TPUEstimator, Using TPUs, etc.
Here are some tips about porting code for TPUs:
1. We can only use TPUEstimator for training
classifier = tf.contrib.tpu.TPUEstimator(
    model_fn=model_wrapper,
    config=run_config,
    use_tpu=FLAGS.use_tpu,
    train_batch_size=64,
    batch_axis=[0, 0],
    params={'optimizer': opt})
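Pay attention to 'batch_axis'. It tells TPUEstimator which dimension to split each input tensor along when distributing it across the TPU shards. I use [0, 0] for the data and the labels because my data is in 'NHWC' format, where dimension 0 is the batch dimension.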
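As a rough illustration of what splitting along dimension 0 means, the plain-Python sketch below computes the per-shard shape of a batch. It does not call TensorFlow; the helper name shard_shape, the 224x224x3 image size, and the shard count of 8 are assumptions made for the example, not values from my program.

```python
def shard_shape(global_shape, batch_axis, num_shards):
    """Return the per-shard shape when a tensor is split evenly along batch_axis."""
    size = global_shape[batch_axis]
    if size % num_shards != 0:
        raise ValueError(
            "dimension of size %d is not divisible by %d shards" % (size, num_shards))
    shape = list(global_shape)
    shape[batch_axis] = size // num_shards
    return tuple(shape)


NUM_SHARDS = 8  # assumed shard count, for illustration only

# NHWC images and 1-D labels, both split along dimension 0 (batch_axis = [0, 0]).
print(shard_shape((64, 224, 224, 3), 0, NUM_SHARDS))  # (8, 224, 224, 3)
print(shard_shape((64,), 0, NUM_SHARDS))              # (8,)
```

With a different layout, such as one that puts the batch dimension last, the batch_axis entry for that tensor would have to point at that dimension instead of 0.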
2. model_fn and data_input_fn for TPUEstimator take more arguments than those for a regular tf.estimator.Estimator. We need to fetch some values (such as 'batch_size') from params.
def data_input_fn(params):
    batch = params['batch_size']
    ...

def model_fn(features, labels, mode, config, params):
    ...
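A quick way to catch a missing 'params' argument before launching a job is to inspect the function signatures. The sketch below uses only the standard library; accepts_params and plain_input_fn are hypothetical names introduced for this example and are not part of TensorFlow.

```python
import inspect


def accepts_params(fn):
    """Return True if fn declares an argument named 'params'."""
    return "params" in inspect.signature(fn).parameters


def data_input_fn(params):
    # The batch size is read from params instead of being hard-coded.
    batch = params["batch_size"]
    return batch


def plain_input_fn():
    # A regular Estimator-style input function without 'params'.
    return None


print(accepts_params(data_input_fn))   # True
print(accepts_params(plain_input_fn))  # False
```

Running a check like this on both the input function and the model function is cheaper than waiting for the job to fail on the TPU.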
3. TPUs don't support some operations, such as
images = tf.contrib.image.rotate(images, tf.random_uniform([1], minval=-math.pi / 4.0, maxval=math.pi / 4.0))
so try to avoid using them.
4. Use tf.data carefully, or it will report data shape errors. The code below has run correctly so far:
dataset = files.apply(tf.contrib.data.parallel_interleave(
    tf.data.TFRecordDataset, sloppy=True, cycle_length=buff_size))
dataset = dataset.map(_parse_function)
dataset = dataset.repeat()
# batch_and_drop_remainder gives every batch a fixed, statically known batch dimension.
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
dataset = dataset.shuffle(batch_size * buff_size)
iterator = dataset.make_initializable_iterator()
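One thing worth knowing about the pipeline above is that it shuffles after batching, so whole batches are reordered while the examples inside each batch stay together. The plain-Python sketch below shows the difference between the two orders. It is a simplification: it shuffles the full list rather than using a bounded shuffle buffer as tf.data does, and all function names are made up for this example.

```python
import random


def batch_drop_remainder(items, batch_size):
    """Group items into full batches and drop the final partial batch."""
    full = len(items) - len(items) % batch_size
    return [items[i:i + batch_size] for i in range(0, full, batch_size)]


def batch_then_shuffle(items, batch_size, seed=0):
    """Batch first, then shuffle: the order of batches changes, their contents do not."""
    batches = batch_drop_remainder(list(items), batch_size)
    random.Random(seed).shuffle(batches)
    return batches


def shuffle_then_batch(items, batch_size, seed=0):
    """Shuffle first, then batch: individual examples are mixed across batches."""
    items = list(items)
    random.Random(seed).shuffle(items)
    return batch_drop_remainder(items, batch_size)


records = list(range(10))  # stand-ins for parsed examples
print(batch_then_shuffle(records, 4))
print(shuffle_then_batch(records, 4))
```

If mixing at the example level matters for training, the shuffle step would need to come before the batching step. I have not tested that variant on a TPU.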
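5. Because we are using TPUEstimator, we can't initialize the tf.data iterator with 'session.run()' ourselves, so a little trick is needed: register the iterator's initializer in a collection that gets run automatically when the session is created.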
def data_input_fn(params):
    ...
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
    ...
6. When training on a Cloud TPU, TensorFlow on the GCP VM instance only supports loading datasets from, and storing models to, Google Cloud Storage (gs:// paths).
run_config = tf.contrib.tpu.RunConfig(
    master=master,
    evaluation_master=master,
    model_dir='gs://my-project/models/',
    session_config=tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=True),
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=FLAGS.iterations,
        num_shards=FLAGS.num_shards))
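Since a local path only fails once the job is already running, it can help to validate paths up front. Below is a minimal sketch of such a check; is_gcs_path is a hypothetical helper, not a TensorFlow function, and it only looks at the form of the string, not at whether the bucket actually exists.

```python
def is_gcs_path(path):
    """Return True for strings of the form gs://bucket or gs://bucket/object."""
    prefix = "gs://"
    if not path.startswith(prefix):
        return False
    rest = path[len(prefix):]
    # A bucket name must follow the prefix directly.
    return len(rest) > 0 and not rest.startswith("/")


print(is_gcs_path("gs://my-project/models/"))  # True
print(is_gcs_path("/home/me/models/"))         # False
```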
7. There currently aren't any hooks for TPUEstimator in TensorFlow 1.9, so I can't see any progress reports on the console after launching a TPU program. I hope Google improves this as soon as possible.