gradio/demo/rt-detr-object-detection/run.py

# type: ignore
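"""Gradio demo: streaming object detection on video with RT-DETR.

Reads an uploaded video with OpenCV, runs batched RT-DETR inference on
subsampled frames, draws the predicted bounding boxes, and yields the
result as a sequence of short MP4 segments so the output video component
can start playing before the whole video has been processed.
"""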
import spaces
import gradio as gr
import cv2
from PIL import Image
import torch
import time
import numpy as np
import uuid
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor  # type: ignore
from draw_boxes import draw_bounding_boxes

image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd").to("cuda")

# Only every SUBSAMPLE-th frame is run through the detector.
SUBSAMPLE = 2
@spaces.GPU
def stream_object_detection(video, conf_threshold):
    cap = cv2.VideoCapture(video)

    video_codec = cv2.VideoWriter_fourcc(*"mp4v")  # type: ignore
    fps = int(cap.get(cv2.CAP_PROP_FPS))

    desired_fps = fps // SUBSAMPLE
    # Output is written at half resolution (frames are resized below).
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) // 2
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) // 2

    iterating, frame = cap.read()
    n_frames = 0

    # Each yielded segment goes to its own uniquely named MP4 file.
    name = f"output_{uuid.uuid4()}.mp4"
    segment_file = cv2.VideoWriter(name, video_codec, desired_fps, (width, height))  # type: ignore
    batch = []

    while iterating:
        frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if n_frames % SUBSAMPLE == 0:
            batch.append(frame)

        # Run inference once enough frames for a two-second segment accumulate.
        if len(batch) == 2 * desired_fps:
            inputs = image_processor(images=batch, return_tensors="pt").to("cuda")

            print(f"starting batch of size {len(batch)}")
            start = time.time()
            with torch.no_grad():
                outputs = model(**inputs)
            end = time.time()
            print("time taken for inference", end - start)

            start = time.time()
            boxes = image_processor.post_process_object_detection(
                outputs,
                target_sizes=torch.tensor([(height, width)] * len(batch)),
                threshold=conf_threshold,
            )

            for array, box in zip(batch, boxes):
                pil_image = draw_bounding_boxes(
                    Image.fromarray(array), box, model, conf_threshold
                )
                frame = np.array(pil_image)
                # Convert RGB back to BGR for OpenCV's VideoWriter.
                frame = frame[:, :, ::-1].copy()
                segment_file.write(frame)

            batch = []
            segment_file.release()
            yield name
            end = time.time()
            print("time taken for processing boxes", end - start)

            # Start a fresh segment file for the next batch.
            name = f"output_{uuid.uuid4()}.mp4"
            segment_file = cv2.VideoWriter(
                name, video_codec, desired_fps, (width, height)
            )  # type: ignore

        iterating, frame = cap.read()
        n_frames += 1

    # Trailing frames shorter than a full segment are dropped here.
    cap.release()
    segment_file.release()
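
# Design note: the generator yields a series of short MP4 files rather than
# individual frames; gr.Video(streaming=True) plays each segment as it
# arrives, so detections appear while later frames are still being processed.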
with gr.Blocks() as demo:
    gr.HTML(
        """
    <h1 style='text-align: center'>
    Video Object Detection with <a href='https://huggingface.co/PekingU/rtdetr_r50vd' target='_blank'>RT-DETR</a>
    </h1>
    """
    )
    with gr.Row():
        with gr.Column():
            video = gr.Video(label="Video Source")
            conf_threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.30,
            )
        with gr.Column():
            # streaming=True plays each yielded MP4 segment as it arrives.
            output_video = gr.Video(
                label="Processed Video", streaming=True, autoplay=True
            )

    video.upload(
        fn=stream_object_detection,
        inputs=[video, conf_threshold],
        outputs=[output_video],
    )

    # The example video is expected to sit next to this script.
    gr.Examples(
        examples=["3285790-hd_1920_1080_30fps.mp4"],
        inputs=[video],
    )


if __name__ == "__main__":
    demo.launch()
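
# To try this locally (a sketch, assuming a CUDA GPU and the sibling
# draw_boxes.py helper from this demo directory):
#   pip install gradio spaces transformers torch opencv-python pillow
#   python run.py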