mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-17 11:29:58 +08:00
Adding a script to benchmark the queue (#3272)
* added benchmark queue script * changelg * fix instructions
This commit is contained in:
parent
5203c5ddb1
commit
f36445522f
@ -15,7 +15,8 @@
|
||||
No changes to highlight.
|
||||
|
||||
## Testing and Infrastructure Changes:
|
||||
No changes to highlight.
|
||||
* Adds a script to benchmark the performance of the queue and adds some instructions on how to use it. By [@freddyaboulton](https://github.com/freddyaboulton) and [@abidlabs](https://github.com/abidlabs) in [PR 3272](https://github.com/gradio-app/gradio/pull/3272)
|
||||
|
||||
|
||||
## Breaking Changes:
|
||||
No changes to highlight.
|
||||
|
116
scripts/benchmark_queue.py
Normal file
116
scripts/benchmark_queue.py
Normal file
@ -0,0 +1,116 @@
|
||||
'''
|
||||
A script that benchmarks the queue performance, can be used to compare the performance
|
||||
of the queue on a given branch vs the main branch. By default, runs 100 jobs in batches
|
||||
of 20 and prints the average time per job. The inference time for each job (without the
|
||||
network overhead of sending/receiving the data) is 0.5 seconds. Each job sends one of:
|
||||
a text, image, audio, or video input and the output is the same as the input.
|
||||
|
||||
Navigate to the root directory of the gradio repo and run:
|
||||
>> python scripts/benchmark_queue.py
|
||||
|
||||
You can specify the number of jobs to run with the -n parameter (jobs are run in batches of 20):
|
||||
>> python scripts/benchmark_queue.py -n 1000
|
||||
|
||||
The results are printed to the console, but you can specify a path to save the results
|
||||
to with the -o parameter:
|
||||
>> python scripts/benchmark_queue.py -n 1000 -o results.json
|
||||
'''
|
||||
|
||||
import gradio as gr
|
||||
from gradio import media_data
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import pandas as pd
|
||||
import argparse
|
||||
|
||||
|
||||
def identity_with_sleep(x):
    """Return *x* unchanged after a fixed delay that simulates inference time."""
    simulated_inference_seconds = 0.5
    time.sleep(simulated_inference_seconds)
    return x
|
||||
|
||||
|
||||
# Build a demo with four echo endpoints (text, image, audio, video), each
# returning its input after a 0.5 s sleep.
# NOTE(review): the order of the .click() registrations determines the
# fn_index of each endpoint (text=0, img=1, audio=2, video=3); the
# FN_INDEX_TO_DATA table below relies on this exact order — do not reorder.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_txt = gr.Text()
            output_text = gr.Text()
            submit_text = gr.Button()
            submit_text.click(identity_with_sleep, input_txt, output_text, api_name="text")
        with gr.Column():
            input_img = gr.Image()
            output_img = gr.Image()
            submit_img = gr.Button()
            submit_img.click(identity_with_sleep, input_img, output_img, api_name="img")
        with gr.Column():
            input_audio = gr.Audio()
            output_audio = gr.Audio()
            submit_audio = gr.Button()
            submit_audio.click(identity_with_sleep, input_audio, output_audio, api_name="audio")
        with gr.Column():
            input_video = gr.Video()
            output_video = gr.Video()
            submit_video = gr.Button()
            submit_video.click(identity_with_sleep, input_video, output_video, api_name="video")

# Queue at most 50 pending jobs, processing up to 20 concurrently; launch in
# the background (prevent_thread_lock=True) so the benchmark driver below can
# run in the same process.
demo.queue(max_size=50, concurrency_count=20).launch(prevent_thread_lock=True, quiet=True)
|
||||
|
||||
|
||||
# Maps a human-readable endpoint name to (fn_index, payload).  The fn_index
# values must match the registration order of the .click() handlers in the
# Blocks definition above (text=0, img=1, audio=2, video=3).
FN_INDEX_TO_DATA = {
    "text": (0, "A longish text " * 15),
    "image": (1, media_data.BASE64_IMAGE),
    "audio": (2, media_data.BASE64_AUDIO),
    "video": (3, media_data.BASE64_VIDEO)
}
|
||||
|
||||
|
||||
async def get_prediction(host):
    """Submit one randomly-chosen job over the queue websocket and time it.

    Picks a random endpoint (text/image/audio/video), drives the gradio queue
    protocol (answer ``send_hash`` and ``send_data`` requests, wait for
    ``process_completed``) and measures the end-to-end wall-clock time,
    including network overhead.

    Parameters
    ----------
    host : str
        Websocket URL of the queue join endpoint (``ws://.../queue/join``).

    Returns
    -------
    dict | None
        ``{"fn_to_hit": <endpoint name>, "duration": <seconds>}`` on success,
        or ``None`` if the queue rejected the job (queue full).  ``main()``
        filters out falsy results, so rejected jobs are simply retried by the
        next batch.
    """
    async with websockets.connect(host) as ws:
        name = random.choice(["image", "text", "audio", "video"])
        fn_to_hit, data = FN_INDEX_TO_DATA[name]
        start = time.time()

        completed = False
        while not completed:
            msg = json.loads(await ws.recv())
            if msg["msg"] == "send_hash":
                # Identify this client + target function to the queue.
                await ws.send(json.dumps({"fn_index": fn_to_hit, "session_hash": "shdce"}))
            if msg["msg"] == "send_data":
                await ws.send(json.dumps({"data": [data], "fn_index": fn_to_hit}))
            if msg["msg"] == "queue_full":
                # Fix: previously this coroutine looped forever when the queue
                # rejected the job; return None so the caller can skip it.
                return None
            if msg["msg"] == "process_completed":
                completed = True
        end = time.time()
        return {"fn_to_hit": name, "duration": end - start}
|
||||
|
||||
|
||||
async def main(host, n_results=100, batch_size=20):
    """Run ``n_results`` benchmark jobs against ``host`` and aggregate timings.

    Jobs are launched in concurrent batches of ``batch_size`` via
    :func:`get_prediction`; failed/rejected jobs (falsy results) are dropped
    and the loop keeps going until ``n_results`` successful jobs are recorded.

    Parameters
    ----------
    host : str
        Websocket URL of the queue join endpoint.
    n_results : int
        Number of successful jobs to collect (default 100).
    batch_size : int
        Number of jobs launched concurrently per batch (default 20, matching
        the original hard-coded value).

    Returns
    -------
    dict
        ``{"fn_to_hit": [names...], "duration": [mean seconds...]}`` — the
        mean duration per endpoint, as two parallel lists.
    """
    results = []
    while len(results) < n_results:
        # Fix: never launch more jobs than are still needed (the original
        # always launched a full batch of 20, overshooting n_results when it
        # was not a multiple of 20).
        n_to_launch = min(batch_size, n_results - len(results))
        batch_results = await asyncio.gather(*[get_prediction(host) for _ in range(n_to_launch)])
        # Drop falsy results (e.g. jobs rejected because the queue was full).
        results.extend(result for result in batch_results if result)

    # Mean duration per endpoint.  agg(["mean"]) is the documented list form
    # (the original passed a set literal, which relies on set iteration).
    data = pd.DataFrame(results).groupby("fn_to_hit").agg(["mean"])
    data.columns = data.columns.get_level_values(0)
    data = data.reset_index()
    data = {"fn_to_hit": data["fn_to_hit"].to_list(), "duration": data["duration"].to_list()}
    return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Fix: the description was a copy-paste error ("Upload a demo to a space").
    parser = argparse.ArgumentParser(description="Benchmark the gradio queue")
    parser.add_argument("-n", "--n_jobs", type=int, help="number of jobs", default=100, required=False)
    parser.add_argument("-o", "--output", type=str, help="path to write output to", required=False)
    args = parser.parse_args()

    # The demo is served over http(s); the queue speaks websockets at /queue/join.
    host = f"{demo.local_url.replace('http', 'ws')}queue/join"
    data = asyncio.run(main(host, n_results=args.n_jobs))
    # Collapse the two parallel lists into {endpoint_name: mean_duration}.
    data = dict(zip(data["fn_to_hit"], data["duration"]))

    print(data)

    if args.output:
        print("Writing results to:", args.output)
        # Fix: close the output file deterministically (the original used
        # json.dump(data, open(args.output, "w")), leaking the handle).
        with open(args.output, "w") as f:
            json.dump(data, f)
|
Loading…
Reference in New Issue
Block a user