Add conversational and image-to-text pipelines to Interface.load (#3011)

* Add conversational and image-to-text

* remove redundant state

* CHANGELOG

* Add media to changelog

* Lint

* Format code

* Fix typos in CHANGELOG
This commit is contained in:
Freddy Boulton 2023-01-18 20:49:03 +01:00 committed by GitHub
parent f7f5398e4c
commit 32af45cd0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 91 additions and 3 deletions

View File

@ -1,7 +1,31 @@
# Upcoming Release
## New Features:
No changes to highlight.
### Extended support for Interface.load! 🏗️
You can now load `image-to-text` and `conversational` pipelines from the hub!
### Image-to-text Demo
```python
io = gr.Interface.load("models/nlpconnect/vit-gpt2-image-captioning",
api_key="<optional-api-key>")
io.launch()
```
<img width="1087" alt="image" src="https://user-images.githubusercontent.com/41651716/213260197-dc5d80b4-6e50-4b3a-a764-94980930ac38.png">
### Conversational Demo
```python
chatbot = gr.Interface.load("models/microsoft/DialoGPT-medium",
api_key="<optional-api-key>")
chatbot.launch()
```
![chatbot_load](https://user-images.githubusercontent.com/41651716/213260220-3eaa25b7-a38b-48c6-adeb-2718bdf297a2.gif)
By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3011](https://github.com/gradio-app/gradio/pull/3011)
## Bug Fixes:
* Fixes bug where interpretation event was not configured correctly by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2993](https://github.com/gradio-app/gradio/pull/2993)

View File

@ -63,6 +63,32 @@ def load_blocks_from_repo(
return blocks
def chatbot_preprocess(text, state):
    """Build the Hugging Face inference payload for a conversational pipeline.

    Args:
        text: The new user message to send to the model.
        state: Prior session state; when present, its ``"conversation"`` dict
            (with ``"generated_responses"`` and ``"past_user_inputs"`` lists)
            is forwarded so the model sees the full dialogue history.

    Returns:
        A dict shaped like ``{"inputs": {...}}`` as expected by the HF
        conversational inference endpoint.
    """
    inputs = {"generated_responses": None, "past_user_inputs": None, "text": text}
    if state is not None:
        # Carry the accumulated dialogue history into this request.
        conversation = state["conversation"]
        inputs["generated_responses"] = conversation["generated_responses"]
        inputs["past_user_inputs"] = conversation["past_user_inputs"]
    return {"inputs": inputs}
def chatbot_postprocess(response):
    """Turn a conversational inference response into chatbot output plus state.

    Args:
        response: HTTP response whose ``.json()`` payload contains a
            ``"conversation"`` dict with parallel ``"past_user_inputs"`` and
            ``"generated_responses"`` lists.

    Returns:
        A tuple of (list of (user, bot) message pairs for the Chatbot
        component, the raw parsed JSON to persist as session state).
    """
    data = response.json()
    conversation = data["conversation"]
    # Pair each user turn with the model reply at the same position.
    history = list(
        zip(
            conversation["past_user_inputs"],
            conversation["generated_responses"],
        )
    )
    return history, data
def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs):
model_url = "https://huggingface.co/{}".format(model_name)
api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
@ -76,7 +102,6 @@ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs
response.status_code == 200
), f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter."
p = response.json().get("pipeline_tag")
pipelines = {
"audio-classification": {
# example model: ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
@ -101,6 +126,12 @@ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs
"preprocess": to_binary,
"postprocess": lambda r: r.json()["text"],
},
"conversational": {
"inputs": [components.Textbox(), components.State()], # type: ignore
"outputs": [components.Chatbot(), components.State()], # type: ignore
"preprocess": chatbot_preprocess,
"postprocess": chatbot_postprocess,
},
"feature-extraction": {
# example model: julien-c/distilbert-feature-extraction
"inputs": components.Textbox(label="Input"),
@ -125,6 +156,12 @@ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs
{i["label"].split(", ")[0]: i["score"] for i in r.json()}
),
},
"image-to-text": {
"inputs": components.Image(type="filepath", label="Input Image"),
"outputs": components.Textbox(),
"preprocess": to_binary,
"postprocess": lambda r: r.json()[0]["generated_text"],
},
"question-answering": {
# Example: deepset/xlm-roberta-base-squad2
"inputs": [
@ -311,7 +348,12 @@ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs
}
kwargs = dict(interface_info, **kwargs)
kwargs["_api_mode"] = True # So interface doesn't run pre/postprocess.
# So interface doesn't run pre/postprocess
# except for conversational interfaces which
# are stateful
kwargs["_api_mode"] = p != "conversational"
interface = gradio.Interface(**kwargs)
return interface

View File

@ -228,6 +228,28 @@ class TestLoadInterface:
except TooManyRequestsError:
pass
def test_image_to_text(self):
    """Loading an image-to-text hub model yields a string caption."""
    interface = gr.Interface.load("models/nlpconnect/vit-gpt2-image-captioning")
    try:
        caption = interface("gradio/test_data/lion.jpg")
    except TooManyRequestsError:
        # Hub rate limiting is not a test failure.
        return
    assert isinstance(caption, str)
def test_conversational(self):
    """A loaded conversational model round-trips chat state per session."""
    chatbot = gr.Interface.load("models/microsoft/DialoGPT-medium")
    app, _, _ = chatbot.launch(prevent_thread_lock=True)
    client = TestClient(app)
    # No session state exists before the first prediction request.
    assert app.state_holder == {}
    payload = {"session_hash": "foo", "data": ["Hi!", None], "fn_index": 0}
    result = client.post("/api/predict/", json=payload).json()
    assert isinstance(result["data"], list)
    # First output is the chatbot message-pair list.
    assert isinstance(result["data"][0], list)
    # The session's conversation state was stored under its hash.
    assert isinstance(app.state_holder["foo"], dict)
def test_speech_recognition_model(self):
io = gr.Interface.load("models/facebook/wav2vec2-base-960h")
try: