Skip to content

Commit

Permalink
add supervision
Browse files Browse the repository at this point in the history
  • Loading branch information
SangbumChoi committed Nov 9, 2024
1 parent 5a1f6a3 commit 3bfc219
Showing 1 changed file with 43 additions and 65 deletions.
108 changes: 43 additions & 65 deletions docs/source/en/model_doc/vitpose.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ from transformers import (
url = "http://images.cocodataset.org/val2017/000000000139.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Stage 1. Run Object Detector
# User can replace this object_detector part
# Stage 1. Run Object Detector (User can replace this object_detector part)
person_image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
inputs = person_image_processor(images=image, return_tensors="pt")
Expand Down Expand Up @@ -95,16 +94,15 @@ def pascal_voc_to_coco(bboxes: np.ndarray) -> np.ndarray:

return bboxes

# 0 index indicates human label in COCO
# Human label refers 0 index in COCO dataset
boxes = results[0]["boxes"][results[0]["labels"] == 0]
boxes = [pascal_voc_to_coco(boxes.cpu().numpy())]

# Stage 2. Run ViTPose
config = VitPoseConfig()
image_processor = VitPoseImageProcessor.from_pretrained("nielsr/vitpose-base-simple")
model = VitPoseForPoseEstimation.from_pretrained("nielsr/vitpose-base-simple")

config = VitPoseConfig()

# Stage 2. Run ViTPose
pixel_values = image_processor(image, boxes=boxes, return_tensors="pt").pixel_values

with torch.no_grad():
Expand All @@ -117,6 +115,35 @@ for pose_result in pose_results:
x, y, score = keypoint
print(f"coordinate : [{x}, {y}], score : {score}")

# Visualization for supervision user
import supervision as sv

key_points = sv.KeyPoints(xy=torch.cat([pose_result['keypoints'].unsqueeze(0) for pose_result in pose_results]).cpu().numpy())

edge_annotator = sv.EdgeAnnotator(
color=sv.Color.GREEN,
thickness=5
)
annotated_frame = edge_annotator.annotate(
scene=image.copy(),
key_points=key_points
)

# Visualization for advanced user
def draw_points(image, keypoints, pose_keypoint_color, keypoint_score_threshold, radius, show_keypoint_weight):
if pose_keypoint_color is not None:
assert len(pose_keypoint_color) == len(keypoints)
for kid, kpt in enumerate(keypoints):
x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
if kpt_score > keypoint_score_threshold:
color = tuple(int(c) for c in pose_keypoint_color[kid])
if show_keypoint_weight:
cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)
transparency = max(0, min(1, kpt_score))
cv2.addWeighted(image, transparency, image, 1 - transparency, 0, dst=image)
else:
cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)

def draw_links(image, keypoints, keypoint_edges, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight, stick_width = 2):
height, width, _ = image.shape
if keypoint_edges is not None and link_colors is not None:
Expand Down Expand Up @@ -153,53 +180,6 @@ def draw_links(image, keypoints, keypoint_edges, link_colors, keypoint_score_thr
else:
cv2.line(image, pos1, pos2, color, thickness=thickness)

def visualize_keypoints(
image,
pose_result,
keypoint_edges=None,
keypoint_score_threshold=0.3,
keypoint_colors=None,
link_colors=None,
radius=4,
thickness=1,
show_keypoint_weight=False,
):
"""Draw keypoints and links on an image.
Args:
image (`numpy.ndarray`):
The image to draw poses on. It will be modified in-place.
pose_result (`List[numpy.ndarray]`):
The poses to draw. Each element is a set of K keypoints as a Kx3 numpy.ndarray, where each keypoint
is represented as x, y, score.
keypoint_edges (`List[tuple]`, *optional*):
Mapping index of the keypoint_edges links.
keypoint_score_threshold (`float`, *optional*, defaults to 0.3):
Minimum score of keypoints to be shown.
keypoint_colors (`numpy.ndarray`, *optional*):
Color of N keypoints. If None, the keypoints will not be drawn.
link_colors (`numpy.ndarray`, *optional*):
Color of M links. If None, the links will not be drawn.
radius (`int`, *optional*, defaults to 4):
Radius of keypoint circles.
thickness (`int`, *optional*, defaults to 1):
Thickness of lines.
show_keypoint_weight (`bool`, *optional*, defaults to False):
Whether to adjust keypoint and link visibility based on the keypoint scores.
Returns:
`numpy.ndarray`: Image with drawn keypoints and links.
"""
for keypoints in pose_result:
keypoints = np.array(keypoints, copy=False)

# draw each point on image
draw_points(image, keypoints, keypoint_colors, keypoint_score_threshold, radius, show_keypoint_weight)

# draw links
draw_links(image, keypoints, keypoint_edges, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight)

return image

# Note: keypoint_edges and color palette are dataset-specific
keypoint_edges = config.edges
Expand Down Expand Up @@ -233,20 +213,18 @@ link_colors = palette[[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 1
keypoint_colors = palette[[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]]

pose_results = [result["keypoints"] for result in pose_results]
numpy_image = np.array(image)

result = visualize_keypoints(
np.array(image),
pose_result,
keypoint_edges=keypoint_edges,
keypoint_score_threshold=0.3,
keypoint_colors=keypoint_colors,
link_colors=link_colors,
radius=4,
thickness=1,
show_keypoint_weight=False,
)
for keypoints in pose_result:
keypoints = np.array(keypoints, copy=False)

# draw each point on image
draw_points(numpy_image, keypoints, keypoint_colors, keypoint_score_threshold=0.3, radius=4, show_keypoint_weight=False)

# draw links
draw_links(numpy_image, keypoints, keypoint_edges, link_colors, keypoint_score_threshold=0.3, thickness=1, show_keypoint_weight=False)

pose_image = Image.fromarray(result)
pose_image = Image.fromarray(numpy_image)
pose_image
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vitpose-coco.jpg" alt="drawing" width="600"/>
Expand Down

0 comments on commit 3bfc219

Please sign in to comment.