Spawn
torch.multiprocessing.spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn')fn:需要在每个子进程中执行的目标函数。该函数的第一个参数必须是子进程的全局排名(rank)。args:传递给目标函数fn的额外参数,以元组形式提供。nprocs:要启动的子进程数量。join:布尔值,指示是否等待所有子进程执行完毕。默认为True。daemon:布尔值,指示是否将子进程设置为守护进程。默认为False。start_method:启动子进程的方法,可选值有'spawn'、'fork'和'forkserver',默认为'spawn'。
当前代码中 spawn 函数调用后的执行情况:
- 主进程创建子进程:主进程调用
torch.multiprocessing.spawn函数后,会创建num_processes个子进程。每个子进程都会执行test_loop函数。 - 传递参数:
test_loop函数的第一个参数会被自动设置为该子进程的全局排名(rank),取值范围是0到num_processes - 1。之后,test_loop函数会接收到args元组中的参数,即num_processes和args。 - 子进程执行:每个子进程独立执行
test_loop函数,进行分布式环境初始化、创建deep_ep.Buffer实例、执行测试逻辑等操作。在test_loop函数中,会调用init_dist函数初始化分布式环境,确保各个子进程之间可以进行通信。 - 同步与结束:如果
join参数为True(默认值),主进程会等待所有子进程执行完毕后才会继续执行后续代码。当所有子进程执行完test_loop函数,释放资源并退出后,主进程也会结束。