将detectron2模型改成单参数NHWC
detectron2 导出的模型它所要的输入参数是 Tuple[Tensor[N, C, H, W], Tensor[N, 3]] 这样的形式。第二个 Tensor 的 3 的内容是 H, W, Scale;这个 Scale 是图片预处理前后图片尺寸的比值[1]。如果我们输入的图片没有大小的变化,第二个参数是没必要输入的,做在模型里就好了。另外,Tuple 的输入格式 Triton 不支持。
修改的思路是,导出成 torchscript 模型,用 torch.jit.load 加载,包装一层,再 torch.jit.trace 并保存。
import torch import torch.jit from torch.jit import RecursiveScriptModule class MyModel(torch.nn.Module): def __init__(self, model: RecursiveScriptModule): super(MyModel, self).__init__() self.model = model def forward(self, x): x = x.permute(0, 3, 1, 2) return self.model((x, torch.reshape(torch.tensor([x.shape[2], x.shape[3], 1], dtype=torch.float32, device="cuda"), [1, 3]))) def main(): src_model_ts_filename = "./output1/model.ts" dst_model_ts_filename = "./wrap/model.pt" model = torch.jit.load(src_model_ts_filename) model = MyModel(model) traced = torch.jit.trace(model, torch.randint(0, 255, size=[1, 800, 800, 3], dtype=torch.uint8, device="cuda")) traced.save(dst_model_ts_filename) if __name__ == "__main__": main()