Here is my experimental code for distributed TensorFlow, adapted from an example:

import tensorflow as tf
import argparse
import time

# Populated elsewhere (presumably by argparse) before main() is called;
# main() reads FLAGS.ps and FLAGS.worker.
FLAGS = None


def main():
    """Run one task of a two-process parameter-server demo.

    The cluster has one ``ps`` task and one ``worker`` task on localhost.
    If ``FLAGS.ps`` is truthy, this process becomes the parameter server and
    blocks in ``server.join()``. Otherwise it becomes worker task
    ``FLAGS.worker``: it repeatedly increments the scalar variable ``myc``,
    which lives on the ps task, and prints the new value once per second.
    """
    print(tf.__version__)

    cluster_spec = tf.train.ClusterSpec({
        'worker': ['localhost:1829'],
        'ps': ['localhost:1057'],
    })

    if FLAGS.ps:
        server = tf.train.Server(cluster_spec, job_name='ps', task_index=0)
        server.join()  # Serves the variable to workers; does not return.
        return

    # NOTE: only one worker address is listed above, so FLAGS.worker must be 0
    # unless more worker addresses are added to cluster_spec.
    server = tf.train.Server(cluster_spec, job_name='worker',
                             task_index=FLAGS.worker)
    print(server.target)

    # Pin the variable to the parameter server so all workers share one copy.
    with tf.device('/job:ps/task:0'):
        init = tf.constant_initializer(0)
        c = tf.get_variable('myc', shape=[], initializer=init)

    # No device scope: the add is placed by the session's master; the assign
    # writes the incremented value back into `c` on the ps task.
    incremented = tf.add(c, 1)
    train_op = tf.assign(c, incremented)
    # Build the check once, outside the loop, so the graph does not grow.
    is_initialized = tf.is_variable_initialized(c)

    with tf.Session(target=server.target) as sess:
        # The variable's value lives on the ps task and survives individual
        # worker sessions. Initialize it only if nobody has done so yet, so a
        # (re)started worker does not reset the shared counter to zero.
        # (Two workers starting at the same instant could still both init.)
        if not sess.run(is_initialized):
            sess.run(c.initializer)
        while True:
            value = sess.run(train_op)
            print(value)
            time.sleep(1)
...

The important thing is that we need to use ‘tf.assign()’ to push the updated value back to the Variable on the Parameter Server. In this example, the ‘tf.add’ operation runs on task 0 of the ‘worker’ job. But when we deploy a more complicated application across many tasks, things get weird: an operation in the pipeline sometimes even runs on the ‘ps’ job! The official solution to this problem is to use ‘tf.train.replica_device_setter()’, which automatically places Variables on the parameter servers and replicates Operations across the workers. What does ‘tf.train.replica_device_setter()’ do? Let’s look at the core of its implementation:

# Excerpt of TensorFlow's `tf.train.replica_device_setter`; lines shown as
# `...` were elided by the article. It builds a `_ReplicaDeviceChooser` and
# returns its `device_function`, presumably meant to be passed to
# `tf.device(...)` (as in the TF docs) so that each op created in that scope
# is assigned a device.
def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
                          worker_device="/job:worker", merge_devices=True,
                          cluster=None, ps_ops=None, ps_strategy=None):
# (Elided by the article: the docstring and the code preceding the lines below.)
...
  if ps_ops is None:
    # TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be
    # placed in the parameter server.
    # Default: only these variable op types are pinned to the ps job.
    ps_ops = ["Variable", "VariableV2", "VarHandleOp"]

  if not merge_devices:
    # Only a warning; merge_devices=False is still passed to the chooser below.
    logging.warning(
        "DEPRECATION: It is recommended to set merge_devices=true in "
        "replica_device_setter")
  if ps_strategy is None:
    # Default strategy for picking which of the `ps_tasks` ps tasks gets an op.
    ps_strategy = _RoundRobinStrategy(ps_tasks)
  if not six.callable(ps_strategy):
    # The strategy is later called with an op to choose a task index.
    raise TypeError("ps_strategy must be callable")
  chooser = _ReplicaDeviceChooser(
      ps_tasks, ps_device, worker_device, merge_devices, ps_ops, ps_strategy)
  # Return the bound method itself, not a device string.
  return chooser.device_function

By default, all Variable op types (‘Variable’, ‘VariableV2’, ‘VarHandleOp’) are counted as ‘ps_ops’, and Operations are handled with a replication strategy, as the name ‘_ReplicaDeviceChooser’ suggests. Its ‘device_function’ looks like this:

# Excerpt of `_ReplicaDeviceChooser.device_function`: given an op (or its
# NodeDef), return the device string it should be placed on.
def device_function(self, op):
# (Elided by the article. `current_device`, used below, is defined in the
# elided part -- presumably the device already requested for `op`; confirm
# against the upstream TensorFlow source.)
...
    # Accept either a raw NodeDef or an Operation (read its node_def).
    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
    # Route to the ps device only when ps tasks and a ps device are configured
    # and this op's *type* is one of the ps op types.
    if self._ps_tasks and self._ps_device and node_def.op in self._ps_ops:
      ps_device = pydev.DeviceSpec.from_string(self._ps_device)
    
      current_job, ps_job = current_device.job, ps_device.job
      # Let the strategy choose the ps task only if ps_device names a job and
      # the op's current job is unset or is that same job.
      if ps_job and (not current_job or current_job == ps_job):
        ps_device.task = self._ps_strategy(op)
  
      # Combine with the op's current device spec; per the TF DeviceSpec docs,
      # fields set on current_device take precedence (confirm for your version).
      ps_device.merge_from(current_device)
      return ps_device.to_string()
        
    # Every other op goes to the worker device, likewise merged with
    # current_device.
    worker_device = pydev.DeviceSpec.from_string(self._worker_device or "")
    worker_device.merge_from(current_device)
    return worker_device.to_string()

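Every op whose type is listed in ‘self._ps_ops’ is placed on ‘ps_device’ (provided ‘ps_tasks’ is non-zero and a ‘ps_device’ is set); all other ops go to ‘worker_device’, merged with whatever device the op already requested.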