Spawn

torch.multiprocessing.spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn')
  • fn:需要在每个子进程中执行的目标函数。该函数的第一个参数必须是子进程的全局排名(rank)。
  • args:传递给目标函数 fn 的额外参数,以元组形式提供。
  • nprocs:要启动的子进程数量。
  • join:布尔值,指示是否等待所有子进程执行完毕。默认为 True
  • daemon:布尔值,指示是否将子进程设置为守护进程。默认为 False
  • start_method:启动子进程的方法,可选值有 'spawn''fork''forkserver',默认为 'spawn'

当前代码中 spawn 函数调用后的执行情况:

  1. 主进程创建子进程:主进程调用 torch.multiprocessing.spawn 函数后,会创建 num_processes 个子进程。每个子进程都会执行 test_loop 函数。
  2. 传递参数test_loop 函数的第一个参数会被自动设置为该子进程的全局排名(rank),取值范围是 0num_processes - 1。之后,test_loop 函数会接收到 args 元组中的参数,即 num_processesargs
  3. 子进程执行:每个子进程独立执行 test_loop 函数,进行分布式环境初始化、创建 deep_ep.Buffer 实例、执行测试逻辑等操作。在 test_loop 函数中,会调用 init_dist 函数初始化分布式环境,确保各个子进程之间可以进行通信。
  4. 同步与结束:如果 join 参数为 True(默认值),主进程会等待所有子进程执行完毕后才会继续执行后续代码。当所有子进程执行完 test_loop 函数,释放资源并退出后,主进程也会结束。