add naswot
graph_dit/naswot/pycls/core/distributed.py · 157 additions · new file
@@ -0,0 +1,157 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Distributed helpers."""

import multiprocessing
import os
import signal
import threading
import traceback

import torch
from pycls.core.config import cfg


def is_master_proc():
    """Determines if the current process is the master process.

    Master process is responsible for logging, writing and loading checkpoints. In
    the multi GPU setting, we assign the master role to the rank 0 process. When
    training using a single GPU, there is a single process which is considered master.
    """
    return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0
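
# Illustrative usage (not part of the original pycls file): is_master_proc() is
# meant to guard side effects that should only happen once per job, such as
# logging or checkpoint writing. The model and checkpoint path below are
# made-up placeholders.
#
#   if is_master_proc():
#       torch.save(model.state_dict(), "checkpoint.pt")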


def init_process_group(proc_rank, world_size):
    """Initializes the default process group."""
    # Set the GPU to use
    torch.cuda.set_device(proc_rank)
    # Initialize the process group
    torch.distributed.init_process_group(
        backend=cfg.DIST_BACKEND,
        init_method="tcp://{}:{}".format(cfg.HOST, cfg.PORT),
        world_size=world_size,
        rank=proc_rank,
    )


def destroy_process_group():
    """Destroys the default process group."""
    torch.distributed.destroy_process_group()
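
# Illustrative note (not part of the original file): the init_method URL is built
# from config values, so with assumed settings cfg.HOST = "localhost" and
# cfg.PORT = 10001 every worker rendezvous at the same TCP address:
#
#   init_method = "tcp://{}:{}".format("localhost", 10001)  # -> "tcp://localhost:10001"
#   # init_process_group(proc_rank=0, world_size=cfg.NUM_GPUS) would then bind
#   # rank 0 of a cfg.NUM_GPUS-sized group to GPU 0.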


def scaled_all_reduce(tensors):
    """Performs the scaled all_reduce operation on the provided tensors.

    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse size of the
    process group (equivalent to cfg.NUM_GPUS).
    """
    # There is no need for reduction in the single-proc case
    if cfg.NUM_GPUS == 1:
        return tensors
    # Queue the reductions
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)
    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()
    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / cfg.NUM_GPUS)
    return tensors
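
# Illustrative usage (not part of the original file): because the sum is scaled
# by 1 / cfg.NUM_GPUS, scaled_all_reduce effectively averages values across
# processes. The loss and acc tensors below are hypothetical names.
#
#   loss, acc = scaled_all_reduce([loss, acc])
#   # With 4 GPUs, each value is summed over ranks and multiplied by 0.25.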


class ChildException(Exception):
    """Wraps an exception from a child process."""

    def __init__(self, child_trace):
        super(ChildException, self).__init__(child_trace)


class ErrorHandler(object):
    """Multiprocessing error handler (based on fairseq's).

    Listens for errors in child processes and propagates the tracebacks to the parent.
    """

    def __init__(self, error_queue):
        # Shared error queue
        self.error_queue = error_queue
        # Children processes sharing the error queue
        self.children_pids = []
        # Start a thread listening to errors
        self.error_listener = threading.Thread(target=self.listen, daemon=True)
        self.error_listener.start()
        # Register the signal handler
        signal.signal(signal.SIGUSR1, self.signal_handler)

    def add_child(self, pid):
        """Registers a child process."""
        self.children_pids.append(pid)

    def listen(self):
        """Listens for errors in the error queue."""
        # Wait until there is an error in the queue
        child_trace = self.error_queue.get()
        # Put the error back for the signal handler
        self.error_queue.put(child_trace)
        # Invoke the signal handler
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, _sig_num, _stack_frame):
        """Signal handler."""
        # Kill children processes
        for pid in self.children_pids:
            os.kill(pid, signal.SIGINT)
        # Propagate the error from the child process
        raise ChildException(self.error_queue.get())
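
# Illustrative flow (not part of the original file): a child that crashes puts its
# traceback on the shared queue (see run() below); the listener thread wakes up,
# re-queues the trace, and sends SIGUSR1 to the parent, whose handler SIGINTs the
# remaining children and re-raises the trace as a ChildException.
#
#   error_queue = multiprocessing.SimpleQueue()
#   handler = ErrorHandler(error_queue)
#   handler.add_child(some_pid)          # some_pid is a placeholder
#   error_queue.put("Traceback (...)")   # simulates a failing child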


def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs):
    """Runs a function from a child process."""
    try:
        # Initialize the process group
        init_process_group(proc_rank, world_size)
        # Run the function
        fun(*fun_args, **fun_kwargs)
    except KeyboardInterrupt:
        # Killed by the parent process
        pass
    except Exception:
        # Propagate exception to the parent process
        error_queue.put(traceback.format_exc())
    finally:
        # Destroy the process group
        destroy_process_group()


def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
    """Runs a function in a multi-proc setting (unless num_proc == 1)."""
    # There is no need for multi-proc in the single-proc case
    fun_kwargs = fun_kwargs if fun_kwargs else {}
    if num_proc == 1:
        fun(*fun_args, **fun_kwargs)
        return
    # Handle errors from training subprocesses
    error_queue = multiprocessing.SimpleQueue()
    error_handler = ErrorHandler(error_queue)
    # Run each training subprocess
    ps = []
    for i in range(num_proc):
        p_i = multiprocessing.Process(
            target=run, args=(i, num_proc, error_queue, fun, fun_args, fun_kwargs)
        )
        ps.append(p_i)
        p_i.start()
        error_handler.add_child(p_i.pid)
    # Wait for each subprocess to finish
    for p in ps:
        p.join()
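
# Illustrative usage (not part of the original file): multi_proc_run is the entry
# point for launching one process per GPU. The train_model function and the
# num_proc value are placeholders; in pycls the callable is typically the trainer
# and num_proc comes from cfg.NUM_GPUS.
#
#   def train_model():
#       ...  # each rank trains; rank 0 (is_master_proc) logs and checkpoints
#
#   multi_proc_run(num_proc=2, fun=train_model)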