Knowhow/Vision

Ultralytics YOLOv8 NMS 포함해서 onnx로 변환하기

침닦는수건 2024. 7. 17. 18:54
반응형

https://github.com/ultralytics/ultralytics/tree/main]

 

GitHub - ultralytics/ultralytics: NEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLite

NEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLite - ultralytics/ultralytics

github.com

 

Detection 모델하면 떠오르는 YOLO 시리즈를 제공하는 툴로 ultralytics가 있다. 완성도 있는 팩토링으로 단 몇 줄만으로 detector를 학습시킬 수 있기 때문에 굉장히 편하고 좋다. 

 

ONNX 및 TensoRT 변환도 지원하는 친절한 모습을 보이고 있는데, 이 중 ONNX 변환은 생각만큼 친절하지 않았다. 

 

간단히 말하면 model.export(format="onnx") 호출로 onnx로 변환해주긴 하는데 NMS가 빠져있는 형태로 변환해준다. 

 

그 결과 ONNX 변환 모델을 inference 해보면, [x, y, w, h, c, c0, c2, c3, ...] 과 같이 bounding box 정보와 class 별 confidence를 나타내는 [N x (4+C)] 형태가 아니라, [(4+C) x 8400] 과 같이 이상한 dimension으로 뱉어준다. 

  • 참고로 8400은 YOLOv8이 사용하는 anchor의 개수로, anchor 마다 추정한 결과를 NMS 없이 통째로 뱉어준 것이다. 
  • 게다가, 가끔 4+C가 아니라 4+C-1과 같이 class 개수가 안 맞는 형태로 변환될 때도 있다. 
  • NMS가 빠져있으면, 사용 환경에서 NMS를 따로 구현해서 후처리해야 하므로 추가 개발이 필요하다. 속도 문제도 심하고.

 

따라서 ultralytics YOLOv8을 ONNX로 변환하고자 할 때는, 지원되는 export() 기능을 쓰면 귀찮은 일이 많아진다. 

 

직접 NMS를 넣어서 직접 ONNX로 변환하는 것이 가장 쉽다. 

 

Ultralytics YOLOv8 ONNX로 직접 변환하기

workflow는 다음과 같다. 

  • NMS 함수를 따로 꺼내서, 약간의 수정하기 (bit 연산자가 있어서 ONNX 변환이 안됨.)
  • torch model을 fp32으로 만들어두기 (ultralytics는 자동으로 모델을 fp16으로 저장한다. fp32로 바꿔야 성공하더라.)
  • torch model + NMS 함수를 또 다른 nn.Module class로 만들어 두기
  • ONNX 변환하기

1) NMS 함수 정리

def xywh2xyxy(x):

    """
    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
    top-left corner and (x2, y2) is the bottom-right corner.

    Args:
        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = torch.empty_like(x)
    # y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    dw = x[..., 2] / 2  # half-width
    dh = x[..., 3] / 2  # half-height
    y[..., 0] = x[..., 0] - dw  # top left x
    y[..., 1] = x[..., 1] - dh  # top left y
    y[..., 2] = x[..., 0] + dw  # bottom right x
    y[..., 3] = x[..., 1] + dh  # bottom right y
    return y

def _get_covariance_matrix(boxes):
    """
    Generating covariance matrix from obbs.

    Args:
        boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.

    Returns:
        (torch.Tensor): Covariance metrixs corresponding to original rotated bounding boxes.
    """
    # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
    gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)
    a, b, c = gbbs.split(1, dim=-1)
    cos = c.cos()
    sin = c.sin()
    cos2 = cos.pow(2)
    sin2 = sin.pow(2)
    return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin

def batch_probiou(obb1, obb2, eps=1e-7):
    """
    Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.

    Args:
        obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
        obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (torch.Tensor): A tensor of shape (N, M) representing obb similarities.
    """
    obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
    obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2

    x1, y1 = obb1[..., :2].split(1, dim=-1)
    x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
    a1, b1, c1 = _get_covariance_matrix(obb1)
    a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))

    t1 = (
        ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
    ) * 0.25
    t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
    t3 = (
        ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
        / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
        + eps
    ).log() * 0.5
    bd = (t1 + t2 + t3).clamp(eps, 100.0)
    hd = (1.0 - (-bd).exp() + eps).sqrt()
    return 1 - hd
    
def non_max_suppression(
    prediction,
    conf_thres=0.25,
    iou_thres=0.7,
    classes=None,
    agnostic=False,
    multi_label=False,
    labels=(),
    max_det=300,
    nc=0,  # number of classes (optional)
    max_time_img=0.05,
    max_nms=30000,
    max_wh=7680,
    in_place=True,
    rotated=False,
):
    # Checks
    assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
    assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    bs = prediction.shape[0]  # batch size
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    nm = prediction.shape[1] - nc - 4
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 2.0 + max_time_img * bs  # seconds to quit after
    #multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    multi_label =  multi_label and (nc > 1)  # multiple labels per box (adds 0.5ms/img)

    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
    if not rotated:
        if in_place:
            prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
        else:
            prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1)  # xywh to xyxy

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]) and not rotated:
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)

        if multi_label:
            i, j = torch.where(cls > conf_thres)
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        scores = x[:, 4]  # scores

        boxes = x[:, :4] + c  # boxes (offset by class)
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            break  # time limit exceeded

    return output

 

필요한 함수는 위 4가지다. 복사붙여넣기 해서 사용하길 권장하지만 확인 차 수정사항을 기록하면 다음과 같다. 

  • xywh2xyxy 함수를 torch.Tensor만 취급하도록 변경. input type을 보고 원래 numpy.array도 취급했지만, 쓸 numpy는 쓸일이 없기 때문에 삭제함
  • non_max_suppression 함수 내 한 줄을 다음과 같이 바꿔주고 nms_rotated 함수가 사용되는 줄을 지운다.
### from
multi_label &= nc > 1
### to
multi_label =  multi_label and (nc > 1)  # multiple labels per box (adds 0.5ms/img)

 

빌어먹을 & bit 연산자가 &= inplace operator 형태로 박혀있어서 ONNX가 이걸 변환 못한다. 따라서 and로 바꿔줘야 한다. 

 

2) Torch model fp32으로 만들기

완전 간단하다. 그냥 float() 호출하면 된다. 애초에 fp16으로 저장되어 있어서 fp32로 끌어올릴 경우, 소실되는 weight는 없다.

pretrained = torch.load(args.model_path)
model = pretrained["model"].cuda().float()

 

3) Torch model 과 NMS를 포함한 nn.Module class 새로 만들기

class Wrapper(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        pretrained = torch.load(args.model_path)
        self.core = pretrained["model"].cuda().float()
        self.core.eval()
        self.save_dir = os.path.dirname(args.model_path)


    def forward(self, x):
        outputs = self.core(x)
        output = outputs[0]

        detections = non_max_suppression(output,
                                         conf_thres=0.25,
                                         iou_thres=0.7,
                                         classes=None,
                                         agnostic=False,
                                         rotated=False,
                                         )
        return detections

 

이것도 간단하다. 그냥 위와 같이 한 클래스 안에 몰아넣어두기만 하면 된다. 

 

4) ONNX 변환하기

### image_np : [H, w, 3] 0~1 float32
x = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(dim=0).cuda()

# test = model(x)
torch.onnx.export(model,
                  x,
                  args.save_path,
                  export_params=True,
                  do_constant_folding=True,
                  input_names=['input'],  # 모델의 입력값을 가리키는 이름
                  output_names=['output'],  # 모델의 출력값을 가리키는 이름
                  dynamic_axes={'input': {0: 'batch_size'},  # 가변적인 길이를 가진 차원
                                'output': {0: 'batch_size'}})
print("success")

 

여기서 사소한 주의점. 모델 변환할 때 더미로 들어가는 저 x라는 입력이 쓰레기 값이기 때문에 input type을 신경 안쓰곤 하는데 습관적으로 신경 안쓰면 나중에 inference할 때 input type 때문에 결과가 달라지는 실수를 많이 본다. 

 

무조건 이미지가 0~1 fp32로 표현된 데이터인지, 변환 시 사용하는 dummy, x도 0~1 fp32가 맞는지 항상 확인하자.

 

5) 결과 확인

위와 같이 원래 더럽게 큰 NMS도 안 된 결과를 뱉던 ONNX 모델이

NMS가 적용된, 기대하던 dimension을 내뱉는 ONNX 모델로 변한 걸 볼 수 있다. 

 

끄적끄적

이슈보면 앵무새처럼 복사붙여넣기 답변만 있고, 버전 확인해봐라. 이런 코멘트만 있는게 짜증난다. 진지하게 답변할 거면 하고, 안할 거면 아예 하지말던가....

반응형