def auto_parallel(metagraph, model):
  """Rewrite ``metagraph`` in place with Grappler's AutoParallel optimizer.

  Builds a ``RewriterConfig`` whose ``optimizers`` list names the
  "autoparallel" pass, enables that pass with ``FLAGS.num_gpus`` replicas,
  runs ``tf_optimizer.OptimizeGraph`` on ``metagraph``, and copies the
  resulting GraphDef back into ``metagraph.graph_def``.

  Args:
    metagraph: MetaGraphDef to optimize. Its ``graph_def`` is overwritten
      with the optimized graph.
    model: Forwarded unchanged to ``UpdateCollection`` (defined elsewhere;
      presumably re-syncs the metagraph's collections with the rewritten
      graph -- TODO confirm against its definition).

  Returns:
    None. The result is delivered by mutating ``metagraph``.
  """
  # Local import (as in the original): the grappler wrapper is only loaded
  # when this function is actually called.
  from tensorflow.python.grappler import tf_optimizer

  config = rewriter_config_pb2.RewriterConfig(optimizers=["autoparallel"])
  # Reading a sub-message field yields a live reference, so the assignments
  # below are written into ``config`` itself.
  parallel_opts = config.auto_parallel
  parallel_opts.enable = True
  # Replica count comes from the module-level FLAGS (defined elsewhere).
  parallel_opts.num_replicas = FLAGS.num_gpus

  # OptimizeGraph returns a new GraphDef; copy it over the existing one so
  # the caller's metagraph object now holds the rewritten graph.
  metagraph.graph_def.CopyFrom(tf_optimizer.OptimizeGraph(config, metagraph))
  UpdateCollection(metagraph, model)

Figure 1: Flow chart of the auto_parallel.py file

results matching ""

    No results matching ""