Rotated bounding box NMS implementation for CPU#9450
Rotated bounding box NMS implementation for CPU#9450zy1git wants to merge 3 commits intopytorch:mainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9450
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit ae5fb41 with merge base d7400a3 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
|
||
| auto ovr = single_box_iou_rotated<scalar_t>( | ||
| dets[i].data_ptr<scalar_t>(), dets[j].data_ptr<scalar_t>()); | ||
| if (ovr >= iou_threshold) { |
There was a problem hiding this comment.
Flagging that this is different from the iou threshold comparison we have in the non-rotated case:
See my other comment about unifying the implementation, which should resolve this as a consequence.
| namespace { | ||
|
|
||
| template <typename scalar_t> | ||
| at::Tensor nms_rotated_cpu_kernel( |
There was a problem hiding this comment.
This is exactly the same implementation we already have for the non-rotated case, the only difference being the iou computation:
Could we consider fusing the two implementations, perhaps templating over the iou computation function?
| return torch.ops.torchvision.nms(boxes, scores, iou_threshold) | ||
|
|
||
|
|
||
| def nms_rotated(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: |
There was a problem hiding this comment.
any reason to expose nms_rotated instead of just handling all this within a single nms function?
For iou, we chose not to expose iou_rotated at the Python layer.
Summary:
Implemented rotated box NMS (Non-Maximum Suppression) for CPU, adapted from Detectron2's nms_rotated implementation. The NMS algorithm is identical to standard NMS — sort by scores, suppress overlapping boxes — but uses
single_box_iou_rotatedfor IoU computation instead of axis-aligned intersection. The public API follows the existing nms op pattern in TorchVision.Test Plan:
Added TestNMSRotated test class adapted from Detectron2's test suite:
0° rotation test: rotated NMS with angle=0 should match reference horizontal NMS (IoU thresholds 0.2, 0.5, 0.8)
90° rotation test: rotated NMS with angle=90 and swapped width/height should match reference horizontal NMS
180° rotation test: rotated NMS with angle=180 should match reference horizontal NMS
TorchScript compatibility test
Results are compared using edit distance (≤ 1 allowed) to account for floating-point precision differences at IoU threshold boundaries.
Run
pytest test/test_ops.py::TestNMSRotated -vAll tests pass locally.