In TensorFlow, we only need the snippet below to assign a device to an Operation:
# Ops created inside this `with` block are assigned the device '/GPU:0'.
with tf.device('/GPU:0'):
    ...
    result = tf.matmul(a, b)
How is this implemented? Let’s take a look.
Python has a mechanism called a ‘context manager’. For example, we can use one to wrap a few lines of code:
from contextlib import contextmanager
@contextmanager
def tag(name):
    """Print ``[name]``, run the ``with`` body, then print ``[/name]``.

    The closing tag is printed from a ``finally`` clause, so it still appears
    when the body raises. The exception itself then propagates unchanged.

    Args:
        name: Text placed inside the square brackets of both tags.
    """
    print(f"[{name}]")
    try:
        yield
    finally:
        # Runs even if the body raised, so the tags always stay balanced.
        print(f"[/{name}]")
# Each word is printed between tag()'s opening and closing output.
with tag("robin"):
    for word in ("what", "is", "nature's"):
        print(word)
The result of running this script is:
[robin]
what
is
nature's
[/robin]
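Decorated with ‘@contextmanager’, the function ‘tag()’ acts like a wrapper. The code before ‘yield’ runs when the ‘with’ block is entered, and the code after ‘yield’ runs when the block is exited. The statements inside its ‘context’ are therefore surrounded by both.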
TensorFlow uses the same principle. Here is the definition of ‘tf.device()’:
@tf_export("device")
def device(device_name_or_function):
    """Return the device-scope context manager used by ``with tf.device(...)``.

    Dispatches on the execution mode:

    * Eager mode: delegates to ``context.device()``. Callables are rejected,
      because device functions are not supported eagerly (see the TODO).
    * Graph mode: delegates to ``device()`` of the default graph.

    Args:
        device_name_or_function: A device name, or (graph mode only) a
            callable device function.

    Raises:
        RuntimeError: If eager execution is enabled and
            ``device_name_or_function`` is callable.
    """
    ...
    if context.executing_eagerly():
        # TODO(agarwal): support device functions in EAGER mode.
        if callable(device_name_or_function):
            raise RuntimeError(
                "tf.device does not support functions when eager execution "
                "is enabled.")
        return context.device(device_name_or_function)
    else:
        return get_default_graph().device(device_name_or_function)
In graph mode, this calls the ‘device()’ method of class ‘Graph’ on the default graph. Its implementation:
@tf_export("Graph")
class Graph(object):
    """Excerpt of TensorFlow's ``Graph`` class; ``...`` marks omitted code."""
    ...
    @tf_contextlib.contextmanager
    def device(self, device_name_or_function):
        """Scope a device for the ops created inside the ``with`` body.

        ``_add_device_to_stack`` registers the device (presumably by pushing
        onto ``self._device_function_stack``, which is popped on exit).

        Args:
            device_name_or_function: A device name or a device function.
        """
        ...
        self._add_device_to_stack(device_name_or_function, offset=2)
        try:
            yield
        finally:
            # Pop even if the body raised, so no stale scope is left behind.
            self._device_function_stack.pop_obj()
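The key line is ‘self._add_device_to_stack()’. Entering the ‘device’ context adds the requested device (a device name or a device function) to the graph’s device function stack. Leaving the context pops it again in the ‘finally’ clause. When the developer creates an Operation inside the context, TensorFlow reads the device from this stack and sets it on that Operation.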
Let’s follow the code path that creates an Operation:
@tf_export("Graph")
class Graph(object):
    """Excerpt of TensorFlow's ``Graph`` class; ``...`` marks omitted code."""
    ...
    def create_op(
            self,
            op_type,
            inputs,
            dtypes,  # pylint: disable=redefined-outer-name
            input_types=None,
            name=None,
            attrs=None,
            op_def=None,
            compute_shapes=True,
            compute_device=True):
        """Build an ``Operation`` in this graph and return it.

        ``node_def`` and ``control_inputs`` are not parameters; they come from
        the elided part of the body.

        Args:
            compute_device: If true, the device function stack is applied to
                the new op (via ``_create_op_helper``).

        Returns:
            The newly constructed ``Operation``.
        """
        ...
        with self._mutation_lock():
            ret = Operation(
                node_def,
                self,
                inputs=inputs,
                output_types=dtypes,
                control_inputs=control_inputs,
                input_types=input_types,
                original_op=self._default_original_op,
                op_def=op_def)
            # Device assignment happens here, right after construction.
            self._create_op_helper(ret, compute_device=compute_device)
        return ret

    def _create_op_helper(self, op, compute_device=True):
        """Finish setting up ``op``; only the device step is shown here."""
        ...
        if compute_device:
            self._apply_device_functions(op)

    def _apply_device_functions(self, op):
        """Set ``op``'s device from the entries on the device function stack."""
        ...
        # NOTE(review): order is whatever peek_objs() returns (presumably the
        # most recently pushed entry first) - confirm.
        for device_spec in self._device_function_stack.peek_objs():
            # An entry without a function ends the walk; later entries are
            # not consulted.
            if device_spec.function is None:
                break
            # Each function receives the op (possibly already carrying a
            # device set by an earlier entry) and returns the device to set.
            op._set_device(device_spec.function(op))
        # Store a snapshot of the stack's metadata on the op.
        op._device_code_locations = self._snapshot_device_function_stack_metadata()
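‘self._device_function_stack.peek_objs()’ is where it peeks at the device entries on the stack. Each entry’s ‘function’ is called with the new Operation, and the result is set on the Operation with ‘op._set_device()’.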