4

Python 多进程使用单机多GPU加速推理

 7 months ago
source link: https://xujinzh.github.io/2024/01/17/python-torch-multi-gpu-py/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

Python 多进程使用单机多GPU加速推理

发表于2024-01-17|更新于2024-01-17|technologypython
字数总计:593|阅读时长:2分钟|阅读量:9

我这里有个需求就是能够使用本机多个GPU对只能使用单GPU的模型进行推理,以能够释放多GPU的潜力,加速推理,节约时间。因为模型需要使用 torch 进行GPU运算,简单的调用 python 内建的 multiprocessing 无法正常执行,需要使用 torch.multiprocessing,后者支持前者完全相同的操作,但扩展了前者以便通过 multiprocessing.Queue 发送的所有张量将其数据移动到共享内存中,并且只会向其他进程发送一个句柄。

多进程使用多GPU

import math
import os

# 使用torch的multiprocessing,对原生multiprocessing进行了封装
import torch.multiprocessing as mp
from tqdm import tqdm
from fluorescence import add_pseudo_color, detect_fluorescence
from utils import show_allfiles

def split_list_to_nested_list(img_path_list, divider=8):
"""
把一个长列表划分为均匀的子列表,返回包含这些子列表的嵌套列表

@param img_path_list: 图像列表
@param divider: 划分子列表的个数
"""
stride = int(math.ceil(len(img_path_list) / divider))
img_paths_nested = [
img_path_list[i * stride : i * stride + stride] for i in range(divider)
]
# 返回的嵌套列表长度一定等于divider
return img_paths_nested


if __name__ == "__main__":
# 该目录下存放了所有伪彩色图像,等待进行荧光检测
data_path = "/disk0/images"
# 把所有图片的地址找到,组成图片地址列表
img_paths = show_allfiles(path=data_path)
# 均匀划分图片,使用多进程加速检测
img_paths_nested = split_list_to_nested_list(img_path_list=img_paths, divider=divider)
for i, p in enumerate(img_paths_nested):
print(f"第{i}段{len(p)}张图像")
# 设置多进程启动模式
mp.set_start_method("spawn", force=True)
# 检测服务器GPU个数
divider = torch.cuda.device_count()
# 对每一个GPU编号设置设备名
devices = [
torch.device(f"cuda:{i}") if torch.cuda.is_available() else torch.device("cpu")
for i in range(divider)
]
# 超参数
threshold_remove_flu = 8.1
# 创建一个进程列表,并启动每一个进程
processes = []
for dev, imgs in zip(devices, img_paths_nested):
p = mp.Process(
target=detect_fluorescence,
args=(
imgs, # 图像列表
dev, # GPU设备名
"vit_h",
"/disk1/datasets/models/sam/sam_vit_h_4b8939.pth",
threshold_remove_flu,
64,
0.75,
0.75,
100,
1500,
150000,
0.5,
),
name=f"Process-{dev}",
)
p.start()
processes.append(p)
print(f"Started {p.name}")

# 等待所有进程处理完成
for p in processes:
p.join()
print(f"Finished {p.name}")
# 打印所有图像处理完成
print(f"Finished all")

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK