gradio/test/test_chat_interface.py
aliabid94 91a7a31cd1
Store configs per session in the backend (#8030)
* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* add changeset

* unrelated fix

* Update gradio/blocks.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
2024-04-22 11:20:05 -07:00

278 lines
10 KiB
Python

import tempfile
from concurrent.futures import wait
from pathlib import Path
from unittest.mock import patch
import pytest
import gradio as gr
def invalid_fn(message):
    """Hand ``message`` straight back.

    Unlike the other chat callbacks in this module it accepts a single
    argument (no ``history``). It is not referenced in this view of the
    file; presumably kept to exercise signature handling — confirm.
    """
    echoed = message
    return echoed
def double(message, history):
    """Return ``message`` twice, separated by one space; ``history`` is ignored."""
    return " ".join((message, message))
async def async_greet(message, history):
    """Coroutine producing ``"hi, "`` followed by ``message``; ``history`` is ignored."""
    prefix = "hi, "
    return prefix + message
def stream(message, history):
    """Yield each non-empty prefix of ``message``, shortest first.

    ``history`` is ignored; an empty ``message`` yields nothing.
    """
    for end in range(1, len(message) + 1):
        yield message[:end]
async def async_stream(message, history):
    """Async generator yielding each non-empty prefix of ``message``, shortest first.

    ``history`` is ignored; an empty ``message`` yields nothing.
    """
    for end in range(1, len(message) + 1):
        yield message[:end]
def count(message, history):
    """Return how many entries ``history`` holds, as a string; ``message`` is ignored."""
    entries = len(history)
    return f"{entries}"
def echo_system_prompt_plus_message(message, history, system_prompt, tokens):
    """Stream prefixes of ``"<system_prompt> <message>"``, shortest first.

    ``tokens`` is passed through ``int()`` and caps how many prefixes are
    yielded (never more than the combined response's length). ``history``
    is ignored.
    """
    response = f"{system_prompt} {message}"
    limit = min(len(response), int(tokens))
    for end in range(1, limit + 1):
        yield response[:end]
class TestInit:
    """Construction-time behaviour of ``gr.ChatInterface``.

    The example-caching tests patch ``gradio.utils.get_cache_folder`` to
    return a fresh ``tempfile.mkdtemp()`` directory, so each test caches into
    its own folder.
    """

    def test_no_fn(self):
        # Constructing without the chat callback must raise TypeError.
        with pytest.raises(TypeError):
            gr.ChatInterface()

    def test_configuring_buttons(self):
        # Passing ``None`` for a button leaves that attribute as ``None``.
        chatbot = gr.ChatInterface(double, submit_btn=None, retry_btn=None)
        assert chatbot.submit_btn is None
        assert chatbot.retry_btn is None

    def test_concurrency_limit(self):
        chat = gr.ChatInterface(double, concurrency_limit=10)
        assert chat.concurrency_limit == 10
        fns = [fn for fn in chat.fns if fn.name in {"_submit_fn", "_api_submit_fn"}]
        # ``all()`` over an empty list is True, so without this check the
        # test would pass vacuously if the submit functions were renamed.
        assert fns
        assert all(fn.concurrency_limit == 10 for fn in fns)

    def test_custom_textbox(self):
        def chat():
            return "Hello"

        # Only construction with custom chatbot/textbox components is checked;
        # no message is submitted.
        gr.ChatInterface(
            chat,
            chatbot=gr.Chatbot(height=400),
            textbox=gr.Textbox(placeholder="Type Message", container=False, scale=7),
            title="Test",
            clear_btn="Clear",
        )
        gr.ChatInterface(
            chat,
            chatbot=gr.Chatbot(height=400),
            textbox=gr.MultimodalTextbox(container=False, scale=7),
            title="Test",
            clear_btn="Clear",
        )

    def test_events_attached(self):
        chatbot = gr.ChatInterface(double)
        dependencies = chatbot.fns
        textbox = chatbot.textbox._id
        submit_btn = chatbot.submit_btn._id
        # Textbox "submit" and submit-button "click" trigger one shared dependency.
        assert any(
            d.targets == [(textbox, "submit"), (submit_btn, "click")]
            for d in dependencies
        )
        # Each auxiliary button has a dependency whose first trigger is its click.
        for btn_id in [
            chatbot.retry_btn._id,
            chatbot.clear_btn._id,
            chatbot.undo_btn._id,
        ]:
            assert any(d.targets[0] == (btn_id, "click") for d in dependencies)

    def test_example_caching(self):
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            chatbot = gr.ChatInterface(
                double, examples=["hello", "hi"], cache_examples=True
            )
            prediction_hello = chatbot.examples_handler.load_from_cache(0)
            prediction_hi = chatbot.examples_handler.load_from_cache(1)
            assert prediction_hello[0].root[0] == ("hello", "hello hello")
            assert prediction_hi[0].root[0] == ("hi", "hi hi")

    @pytest.mark.asyncio
    async def test_example_caching_lazy(self):
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            chatbot = gr.ChatInterface(
                double, examples=["hello", "hi"], cache_examples="lazy"
            )
            # Lazily cache only example 0.
            async for _ in chatbot.examples_handler.async_lazy_cache(0, "hello"):
                pass
            prediction_hello = chatbot.examples_handler.load_from_cache(0)
            assert prediction_hello[0].root[0] == ("hello", "hello hello")
            # Example 1 was never cached, so loading it must fail. Nothing after
            # the raising call inside ``pytest.raises`` would execute, so the
            # block holds only that call.
            with pytest.raises(IndexError):
                chatbot.examples_handler.load_from_cache(1)

    def test_example_caching_async(self):
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            chatbot = gr.ChatInterface(
                async_greet, examples=["abubakar", "tom"], cache_examples=True
            )
            prediction_hello = chatbot.examples_handler.load_from_cache(0)
            prediction_hi = chatbot.examples_handler.load_from_cache(1)
            assert prediction_hello[0].root[0] == ("abubakar", "hi, abubakar")
            assert prediction_hi[0].root[0] == ("tom", "hi, tom")

    def test_example_caching_with_streaming(self):
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            chatbot = gr.ChatInterface(
                stream, examples=["hello", "hi"], cache_examples=True
            )
            prediction_hello = chatbot.examples_handler.load_from_cache(0)
            prediction_hi = chatbot.examples_handler.load_from_cache(1)
            # For a streaming fn the cache holds the final (full) yield.
            assert prediction_hello[0].root[0] == ("hello", "hello")
            assert prediction_hi[0].root[0] == ("hi", "hi")

    def test_example_caching_with_streaming_async(self):
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            chatbot = gr.ChatInterface(
                async_stream, examples=["hello", "hi"], cache_examples=True
            )
            prediction_hello = chatbot.examples_handler.load_from_cache(0)
            prediction_hi = chatbot.examples_handler.load_from_cache(1)
            assert prediction_hello[0].root[0] == ("hello", "hello")
            assert prediction_hi[0].root[0] == ("hi", "hi")

    def test_default_accordion_params(self):
        chatbot = gr.ChatInterface(
            echo_system_prompt_plus_message,
            additional_inputs=["textbox", "slider"],
        )
        accordion = next(
            comp
            for comp in chatbot.blocks.values()
            if comp.get_config().get("name") == "accordion"
        )
        config = accordion.get_config()
        assert config.get("open") is False
        assert config.get("label") == "Additional Inputs"

    def test_setting_accordion_params(self, monkeypatch):
        chatbot = gr.ChatInterface(
            echo_system_prompt_plus_message,
            additional_inputs=["textbox", "slider"],
            additional_inputs_accordion=gr.Accordion(open=True, label="MOAR"),
        )
        accordion = next(
            comp
            for comp in chatbot.blocks.values()
            if comp.get_config().get("name") == "accordion"
        )
        config = accordion.get_config()
        assert config.get("open") is True
        assert config.get("label") == "MOAR"

    def test_example_caching_with_additional_inputs(self, monkeypatch):
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            chatbot = gr.ChatInterface(
                echo_system_prompt_plus_message,
                additional_inputs=["textbox", "slider"],
                examples=[["hello", "robot", 100], ["hi", "robot", 2]],
                cache_examples=True,
            )
            prediction_hello = chatbot.examples_handler.load_from_cache(0)
            prediction_hi = chatbot.examples_handler.load_from_cache(1)
            assert prediction_hello[0].root[0] == ("hello", "robot hello")
            # ``tokens=2`` truncates the streamed response to two characters.
            assert prediction_hi[0].root[0] == ("hi", "ro")

    def test_example_caching_with_additional_inputs_already_rendered(self, monkeypatch):
        with patch(
            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
        ):
            # The additional inputs are created (rendered) before the
            # ChatInterface that uses them.
            with gr.Blocks():
                with gr.Accordion("Inputs"):
                    text = gr.Textbox()
                    slider = gr.Slider()
                    chatbot = gr.ChatInterface(
                        echo_system_prompt_plus_message,
                        additional_inputs=[text, slider],
                        examples=[["hello", "robot", 100], ["hi", "robot", 2]],
                        cache_examples=True,
                    )
            prediction_hello = chatbot.examples_handler.load_from_cache(0)
            prediction_hi = chatbot.examples_handler.load_from_cache(1)
            assert prediction_hello[0].root[0] == ("hello", "robot hello")
            assert prediction_hi[0].root[0] == ("hi", "ro")
class TestAPI:
    """Checks of the API endpoints a ``gr.ChatInterface`` exposes to clients.

    ``connect`` is a fixture not defined in this file; it is used as a
    context manager that yields a client for the given app.
    """

    def test_get_api_info(self):
        info = gr.ChatInterface(double).get_api_info()
        named = info["named_endpoints"]
        assert len(named) == 1
        assert len(info["unnamed_endpoints"]) == 0
        assert "/chat" in named

    def test_streaming_api(self, connect):
        demo = gr.ChatInterface(stream).queue()
        expected = ["h", "he", "hel", "hell", "hello"]
        with connect(demo) as c:
            pending = c.submit("hello")
            wait([pending])
            # Every streamed prefix is recorded, shortest first.
            assert pending.outputs() == expected

    def test_streaming_api_async(self, connect):
        demo = gr.ChatInterface(async_stream).queue()
        expected = ["h", "he", "hel", "hell", "hello"]
        with connect(demo) as c:
            pending = c.submit("hello")
            wait([pending])
            assert pending.outputs() == expected

    def test_non_streaming_api(self, connect):
        demo = gr.ChatInterface(double)
        with connect(demo) as c:
            reply = c.predict("hello")
            assert reply == "hello hello"

    def test_non_streaming_api_async(self, connect):
        demo = gr.ChatInterface(async_greet)
        with connect(demo) as c:
            reply = c.predict("gradio")
            assert reply == "hi, gradio"

    def test_streaming_api_with_additional_inputs(self, connect):
        demo = gr.ChatInterface(
            echo_system_prompt_plus_message,
            additional_inputs=["textbox", "slider"],
        ).queue()
        # With tokens=7 only the first seven prefixes of "robot hello" stream.
        expected = [
            "r",
            "ro",
            "rob",
            "robo",
            "robot",
            "robot ",
            "robot h",
        ]
        with connect(demo) as c:
            pending = c.submit("hello", "robot", 7)
            wait([pending])
            assert pending.outputs() == expected