def auto_parallel(metagraph, model):
    """Rewrite ``metagraph`` in place using the "autoparallel" graph optimizer.

    A ``RewriterConfig`` is built with the "autoparallel" optimizer enabled
    and its replica count taken from ``FLAGS.num_gpus``. The graph returned
    by ``tf_optimizer.OptimizeGraph`` is then copied over
    ``metagraph.graph_def`` and ``UpdateCollection`` is invoked.

    Args:
        metagraph: Meta-graph object passed to the optimizer; its
            ``graph_def`` is overwritten with the optimized graph.
        model: Forwarded unchanged to ``UpdateCollection``.

    Returns:
        None. ``metagraph`` is modified in place.
    """
    # Function-scope import: the grappler module is only loaded when this
    # function actually runs.
    from tensorflow.python.grappler import tf_optimizer

    config = rewriter_config_pb2.RewriterConfig()
    config.optimizers.append("autoparallel")

    # Configure the auto-parallel options through a single handle.
    parallel_opts = config.auto_parallel
    parallel_opts.enable = True
    parallel_opts.num_replicas = FLAGS.num_gpus

    rewritten_graph = tf_optimizer.OptimizeGraph(config, metagraph)
    metagraph.graph_def.CopyFrom(rewritten_graph)

    # NOTE(review): presumably brings the metagraph's collections back in
    # line with the rewritten graph -- confirm against UpdateCollection.
    UpdateCollection(metagraph, model)
Figure 1: Flow chart of the auto_parallel.py file