Mirror of https://github.com/gradio-app/gradio.git, synced 2025-03-07 11:46:51 +08:00.
* push * chat interface * remove video artifact * changes * restore * changes * more functions * clean * chat interface * chat interface * Update gradio/chat_interface.py * nit * changes * api fix * chat interface * guide * guide * changes * fixes * Added ChatInterface Examples for Langchain, OpenAI Streaming, and HF's Text Generation Inference (StarChat) * docstring * tests * tests * tests * tests * rename * guide * chatbot * conclusion * demo notebooks * clog * test * chat interface * fixes * functional test * test * notebook * guides * typing * docstring * script * upgrade pyright * upgrade pyright * Update CHANGELOG.md * revert pyright upgrade * typing * redirects * Update CHANGELOG.md * guide * guide * readme * screenshot * add to readme * quickstart * readme * Added transformers open-source LLM example using ChatInterface * Minor nits to guide * Minor tweak to test - use connect fixture * website fixes * nav * guide * fix tests' * fix * type * clear * chat interface * fix test * fix guide * handle edge case * edge case * typing * fix example caching with streaming * typing --------- Co-authored-by: pngwn <hello@pngwn.io> Co-authored-by: Yuvraj Sharma <48665385+yvrjsharma@users.noreply.github.com> Co-authored-by: freddyaboulton <alfonsoboulton@gmail.com> Co-authored-by: aliabd <ali.si3luwa@gmail.com>
108 lines · 3.2 KiB · Python
import time
|
|
|
|
import pytest
|
|
|
|
import gradio as gr
|
|
|
|
|
|
def invalid_fn(message):
    """Echo helper with a deliberately wrong signature.

    ChatInterface expects a ``(message, history)`` callable; this one takes
    only ``message`` and is used to assert that a UserWarning is raised.
    """
    reply = message
    return reply
|
def double(message, history):
    """Chat fn that replies with the message repeated twice, space-separated."""
    return f"{message} {message}"
|
def stream(message, history):
    """Generator chat fn that streams the message one character at a time.

    Yields growing prefixes: "h", "he", ..., up to the full message.
    """
    partial = ""
    for ch in message:
        partial += ch
        yield partial
|
def count(message, history):
    """Chat fn that replies with the number of prior history turns, as a string."""
    return f"{len(history)}"
|
class TestInit:
    """Constructor-level tests for gr.ChatInterface: argument validation,
    button configuration, event wiring, and example caching."""

    @staticmethod
    def _has_event(dependencies, target_id, trigger):
        # True when some dependency listens for `trigger` on exactly `target_id`.
        return any(
            dep["targets"] == [target_id] and dep["trigger"] == trigger
            for dep in dependencies
        )

    def test_no_fn(self):
        # The chat fn argument is mandatory; omitting it is a TypeError.
        with pytest.raises(TypeError):
            gr.ChatInterface()

    def test_invalid_fn_inputs(self):
        # A fn lacking the (message, history) signature should warn, not crash.
        with pytest.warns(UserWarning):
            gr.ChatInterface(invalid_fn)

    def test_configuring_buttons(self):
        demo = gr.ChatInterface(double, submit_btn=None, retry_btn=None)
        assert demo.submit_btn is None
        assert demo.retry_btn is None

    def test_events_attached(self):
        demo = gr.ChatInterface(double)
        deps = demo.dependencies
        # Submitting the textbox must trigger the chat fn.
        assert self._has_event(deps, demo.textbox._id, "submit")
        # Each auxiliary button must have a click handler attached.
        for button in (demo.submit_btn, demo.retry_btn, demo.clear_btn, demo.undo_btn):
            assert self._has_event(deps, button._id, "click")

    @pytest.mark.asyncio
    async def test_example_caching(self):
        demo = gr.ChatInterface(
            double, examples=["hello", "hi"], cache_examples=True
        )
        cached_hello = await demo.examples_handler.load_from_cache(0)
        cached_hi = await demo.examples_handler.load_from_cache(1)
        # Cached chatbot value is a [user message, bot reply] pair.
        assert cached_hello[0][0] == ["hello", "hello hello"]
        assert cached_hi[0][0] == ["hi", "hi hi"]

    @pytest.mark.asyncio
    async def test_example_caching_with_streaming(self):
        demo = gr.ChatInterface(
            stream, examples=["hello", "hi"], cache_examples=True
        )
        cached_hello = await demo.examples_handler.load_from_cache(0)
        cached_hi = await demo.examples_handler.load_from_cache(1)
        # Only the final streamed chunk is cached, so the reply equals the prompt.
        assert cached_hello[0][0] == ["hello", "hello"]
        assert cached_hi[0][0] == ["hi", "hi"]
|
class TestAPI:
    """Tests of the HTTP/client API surface exposed by a ChatInterface app."""

    def test_get_api_info(self):
        demo = gr.ChatInterface(double)
        info = gr.blocks.get_api_info(demo.get_config_file())
        # Exactly one named endpoint (/chat) and no anonymous ones.
        assert len(info["named_endpoints"]) == 1
        assert len(info["unnamed_endpoints"]) == 0
        assert "/chat" in info["named_endpoints"]

    def test_streaming_api(self, connect):
        demo = gr.ChatInterface(stream).queue()
        with connect(demo) as client:
            job = client.submit("hello")
            # Poll until the generator fn has finished streaming.
            while not job.done():
                time.sleep(0.1)
            expected = ["hello"[: k + 1] for k in range(len("hello"))]
            assert job.outputs() == expected

    def test_non_streaming_api(self, connect):
        demo = gr.ChatInterface(double)
        with connect(demo) as client:
            assert client.predict("hello") == "hello hello"