gradio/scripts/benchmark_queue.py
Abubakar Abid d3b7f73bcf
Update view api page to use Python client (#3765)
* Update view api page

* simplify

* update

* changes

* changes

* updated info

* formatting

* changes

* fixes

* save

* moved

* remove test input

* tweaks

* formatting

* add raw

* serialize

* fixes

* refactor

* fixes

* fixes

* Fetch api

* lower case

* view api

* fix tests

* format

* rough design

* readme

* api docs

* examples

* format

* formatting

* format

* version

* client changes

* formatting

* update client

* more example inputs

* api docs fixes

* remove notebook

* fix demo

* demo notebook

* styling on code snippet

* formatting

* fix audio, model3d

* format

* fix tests

* version

* cleanup

* format

* format

* format

* fixes

* version

* fix tests

* version

* format

* test

* format

* changelog

* changelog

---------

Co-authored-by: freddyaboulton <alfonsoboulton@gmail.com>
Co-authored-by: aliabd <ali.si3luwa@gmail.com>
2023-04-13 16:20:33 -07:00

119 lines
4.2 KiB
Python

'''
A script that benchmarks the queue performance, can be used to compare the performance
of the queue on a given branch vs the main branch. By default, runs 100 jobs in batches
of 20 and prints the average time per job for each input type. The inference time for each job (without the
network overhead of sending/receiving the data) is 0.5 seconds. Each job sends one of:
a text, image, audio, or video input and the output is the same as the input.
Navigate to the root directory of the gradio repo and run:
>> python scripts/benchmark_queue.py
You can specify the number of jobs to run with the -n parameter (jobs are always sent in concurrent batches of 20):
>> python scripts/benchmark_queue.py -n 1000
The results are printed to the console, but you can specify a path to save the results
to with the -o parameter:
>> python scripts/benchmark_queue.py -n 1000 -o results.json
'''
import argparse
import asyncio
import json
import random
import time
import pandas as pd
import websockets
import gradio as gr
from gradio_client import media_data
def identity_with_sleep(x):
    """Echo ``x`` back unchanged after a fixed pause of 0.5 seconds.

    The pause stands in for model inference time, so every benchmarked job
    carries the same known amount of "work" on top of the queue overhead.
    """
    pause_seconds = 0.5
    time.sleep(pause_seconds)
    return x
# Four independent echo endpoints, one per input type. The order in which the
# click handlers are registered below must stay in sync with the fn_index
# values hard-coded in FN_INDEX_TO_DATA (text, img, audio, video).
with gr.Blocks() as demo, gr.Row():
    with gr.Column():
        input_txt = gr.Text()
        output_text = gr.Text()
        submit_text = gr.Button()
        submit_text.click(
            fn=identity_with_sleep,
            inputs=input_txt,
            outputs=output_text,
            api_name="text",
        )
    with gr.Column():
        input_img = gr.Image()
        output_img = gr.Image()
        submit_img = gr.Button()
        submit_img.click(
            fn=identity_with_sleep,
            inputs=input_img,
            outputs=output_img,
            api_name="img",
        )
    with gr.Column():
        input_audio = gr.Audio()
        output_audio = gr.Audio()
        submit_audio = gr.Button()
        submit_audio.click(
            fn=identity_with_sleep,
            inputs=input_audio,
            outputs=output_audio,
            api_name="audio",
        )
    with gr.Column():
        input_video = gr.Video()
        output_video = gr.Video()
        submit_video = gr.Button()
        submit_video.click(
            fn=identity_with_sleep,
            inputs=input_video,
            outputs=output_video,
            api_name="video",
        )

# Enable the queue (up to 20 concurrent workers, at most 50 waiting jobs) and
# start the server without blocking, so the benchmark below can drive it.
demo.queue(max_size=50, concurrency_count=20)
demo.launch(prevent_thread_lock=True, quiet=True)
# Payload name -> (fn_index of the matching echo endpoint, payload sent).
# The fn_index values follow the registration order of the click handlers
# in ``demo``.
FN_INDEX_TO_DATA = dict(
    text=(0, "A longish text " * 15),
    image=(1, media_data.BASE64_IMAGE),
    audio=(2, media_data.BASE64_AUDIO),
    video=(3, media_data.BASE64_VIDEO),
)
async def get_prediction(host):
    """Submit one randomly chosen job through the queue websocket and time it.

    Opens a websocket to ``host``, answers the server's ``send_hash`` and
    ``send_data`` requests, and waits for ``process_completed``.

    Args:
        host: websocket URL of the queue's ``join`` endpoint.

    Returns:
        ``{"fn_to_hit": name, "duration": seconds}``, where ``name`` is one of
        the FN_INDEX_TO_DATA keys and ``seconds`` is the elapsed time from just
        after the connection opened until completion. Returns ``None`` if the
        server rejected the job with ``queue_full``.
    """
    async with websockets.connect(host) as ws:
        name = random.choice(["image", "text", "audio", "video"])
        fn_to_hit, data = FN_INDEX_TO_DATA[name]
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which makes it the appropriate clock for measuring durations.
        start = time.perf_counter()
        while True:
            msg = json.loads(await ws.recv())
            msg_type = msg["msg"]
            if msg_type == "send_hash":
                await ws.send(json.dumps({"fn_index": fn_to_hit, "session_hash": "shdce"}))
            elif msg_type == "send_data":
                await ws.send(json.dumps({"data": [data], "fn_index": fn_to_hit}))
            elif msg_type == "queue_full":
                # The job was rejected and will never complete; waiting for
                # more messages would only fail on the closed socket. Report
                # it as dropped (main() skips falsy results).
                return None
            elif msg_type == "process_completed":
                break
        end = time.perf_counter()
    return {"fn_to_hit": name, "duration": end - start}
async def main(host, n_results=100, batch_size=20):
    """Run ``n_results`` benchmark jobs against ``host`` and average per type.

    Jobs are launched concurrently in batches of at most ``batch_size``. The
    final batch is shrunk so no more jobs than needed are launched. Jobs for
    which get_prediction returns a falsy value are dropped and retried in a
    later batch.

    Args:
        host: websocket URL of the queue's ``join`` endpoint.
        n_results: number of completed jobs to collect.
        batch_size: maximum number of jobs run concurrently per batch.

    Returns:
        dict of two parallel lists: ``"fn_to_hit"`` (input type names, in
        sorted order) and ``"duration"`` (mean seconds per job of that type).
        Both lists are empty when ``n_results`` <= 0.

    Raises:
        ValueError: if ``batch_size`` < 1 (the loop would never progress).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    results = []
    while len(results) < n_results:
        # Only launch the jobs still needed, so a count that is not a
        # multiple of batch_size is not overshot.
        this_batch = min(batch_size, n_results - len(results))
        batch_results = await asyncio.gather(
            *[get_prediction(host) for _ in range(this_batch)]
        )
        results.extend(result for result in batch_results if result)
    if not results:
        # Grouping an empty frame by a missing column would raise KeyError.
        return {"fn_to_hit": [], "duration": []}
    # groupby sorts its keys, so both lists come out ordered by input type.
    means = pd.DataFrame(results).groupby("fn_to_hit")["duration"].mean()
    return {"fn_to_hit": means.index.to_list(), "duration": means.to_list()}
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the performance of the gradio queue")
    parser.add_argument("-n", "--n_jobs", type=int, help="number of jobs", default=100, required=False)
    parser.add_argument("-o", "--output", type=str, help="path to write output to", required=False)
    args = parser.parse_args()
    # Switch the scheme from http(s) to ws(s); assumes local_url ends in "/".
    host = f"{demo.local_url.replace('http', 'ws')}queue/join"
    data = asyncio.run(main(host, n_results=args.n_jobs))
    # Turn main()'s parallel lists into {input type: mean duration in seconds}.
    data = dict(zip(data["fn_to_hit"], data["duration"]))
    print(data)
    if args.output:
        print("Writing results to:", args.output)
        # Use a context manager so the file is flushed and closed.
        with open(args.output, "w") as f:
            json.dump(data, f)