Allow `gr.Interface.from_pipeline()` and `gr.load()` to work within `gr.Blocks()` (#5231)
* add test
* external
* fixed in external
* add changeset
* close demo

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
parent b2f49cfa36
commit 87f1c2b4ac
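In practice, the change lets a demo loaded from the Hub sit inside a larger Blocks layout. A minimal sketch of the newly supported usage, mirroring the tests added in this commit (the model name is taken from those tests):

```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Tab("Hosted model"):
        # Before this fix, calling gr.load() inside an active Blocks context
        # misbehaved: the prebuilt components attached themselves to whichever
        # Blocks context was currently open.
        gr.load("models/microsoft/DialoGPT-medium")

demo.launch()
```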
.changeset/poor-papayas-behave.md | 5 +++++ (new file)
.changeset/poor-papayas-behave.md (new file)
@@ -0,0 +1,5 @@
+---
+"gradio": patch
+---
+
+fix:Allow `gr.Interface.from_pipeline()` and `gr.load()` to work within `gr.Blocks()`
@@ -152,8 +152,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwargs):
     pipelines = {
         "audio-classification": {
             # example model: ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
-            "inputs": components.Audio(source="upload", type="filepath", label="Input"),
-            "outputs": components.Label(label="Class"),
+            "inputs": components.Audio(
+                source="upload", type="filepath", label="Input", render=False
+            ),
+            "outputs": components.Label(label="Class", render=False),
             "preprocess": lambda i: to_binary,
             "postprocess": lambda r: postprocess_label(
                 {i["label"].split(", ")[0]: i["score"] for i in r.json()}
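The hunk above sets the pattern that every remaining hunk in this file repeats: each prebuilt input/output component in the `pipelines` table is constructed with `render=False`, so it no longer attaches itself to whatever Blocks context happens to be active when `from_model()` runs. A standalone sketch of that mechanism (gradio 3.x API; not taken from the diff):

```python
import gradio as gr

# A component built with render=False stays detached from any open Blocks
# context until render() is called in the layout where it actually belongs.
audio = gr.Audio(source="upload", type="filepath", label="Input", render=False)

with gr.Blocks() as demo:
    audio.render()  # explicitly attached here, and only here
```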
@@ -161,34 +163,38 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwargs):
         },
         "audio-to-audio": {
             # example model: facebook/xm_transformer_sm_all-en
-            "inputs": components.Audio(source="upload", type="filepath", label="Input"),
-            "outputs": components.Audio(label="Output"),
+            "inputs": components.Audio(
+                source="upload", type="filepath", label="Input", render=False
+            ),
+            "outputs": components.Audio(label="Output", render=False),
             "preprocess": to_binary,
             "postprocess": encode_to_base64,
         },
         "automatic-speech-recognition": {
             # example model: facebook/wav2vec2-base-960h
-            "inputs": components.Audio(source="upload", type="filepath", label="Input"),
-            "outputs": components.Textbox(label="Output"),
+            "inputs": components.Audio(
+                source="upload", type="filepath", label="Input", render=False
+            ),
+            "outputs": components.Textbox(label="Output", render=False),
             "preprocess": to_binary,
             "postprocess": lambda r: r.json()["text"],
         },
         "conversational": {
-            "inputs": [components.Textbox(), components.State()],  # type: ignore
-            "outputs": [components.Chatbot(), components.State()],  # type: ignore
+            "inputs": [components.Textbox(render=False), components.State(render=False)],  # type: ignore
+            "outputs": [components.Chatbot(render=False), components.State(render=False)],  # type: ignore
             "preprocess": chatbot_preprocess,
             "postprocess": chatbot_postprocess,
         },
         "feature-extraction": {
             # example model: julien-c/distilbert-feature-extraction
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Dataframe(label="Output"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Dataframe(label="Output", render=False),
             "preprocess": lambda x: {"inputs": x},
             "postprocess": lambda r: r.json()[0],
         },
         "fill-mask": {
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Label(label="Classification"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Label(label="Classification", render=False),
             "preprocess": lambda x: {"inputs": x},
             "postprocess": lambda r: postprocess_label(
                 {i["token_str"]: i["score"] for i in r.json()}
@@ -196,8 +202,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwargs):
         },
         "image-classification": {
             # Example: google/vit-base-patch16-224
-            "inputs": components.Image(type="filepath", label="Input Image"),
-            "outputs": components.Label(label="Classification"),
+            "inputs": components.Image(
+                type="filepath", label="Input Image", render=False
+            ),
+            "outputs": components.Label(label="Classification", render=False),
             "preprocess": to_binary,
             "postprocess": lambda r: postprocess_label(
                 {i["label"].split(", ")[0]: i["score"] for i in r.json()}
@@ -206,27 +214,27 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwargs):
         "question-answering": {
             # Example: deepset/xlm-roberta-base-squad2
             "inputs": [
-                components.Textbox(lines=7, label="Context"),
-                components.Textbox(label="Question"),
+                components.Textbox(lines=7, label="Context", render=False),
+                components.Textbox(label="Question", render=False),
             ],
             "outputs": [
-                components.Textbox(label="Answer"),
-                components.Label(label="Score"),
+                components.Textbox(label="Answer", render=False),
+                components.Label(label="Score", render=False),
             ],
             "preprocess": lambda c, q: {"inputs": {"context": c, "question": q}},
             "postprocess": lambda r: (r.json()["answer"], {"label": r.json()["score"]}),
         },
         "summarization": {
             # Example: facebook/bart-large-cnn
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Textbox(label="Summary"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Textbox(label="Summary", render=False),
             "preprocess": lambda x: {"inputs": x},
             "postprocess": lambda r: r.json()[0]["summary_text"],
         },
         "text-classification": {
             # Example: distilbert-base-uncased-finetuned-sst-2-english
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Label(label="Classification"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Label(label="Classification", render=False),
             "preprocess": lambda x: {"inputs": x},
             "postprocess": lambda r: postprocess_label(
                 {i["label"].split(", ")[0]: i["score"] for i in r.json()[0]}
@@ -234,32 +242,34 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwargs):
         },
         "text-generation": {
             # Example: gpt2
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Textbox(label="Output"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Textbox(label="Output", render=False),
             "preprocess": lambda x: {"inputs": x},
             "postprocess": lambda r: r.json()[0]["generated_text"],
         },
         "text2text-generation": {
             # Example: valhalla/t5-small-qa-qg-hl
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Textbox(label="Generated Text"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Textbox(label="Generated Text", render=False),
             "preprocess": lambda x: {"inputs": x},
             "postprocess": lambda r: r.json()[0]["generated_text"],
         },
         "translation": {
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Textbox(label="Translation"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Textbox(label="Translation", render=False),
             "preprocess": lambda x: {"inputs": x},
             "postprocess": lambda r: r.json()[0]["translation_text"],
         },
         "zero-shot-classification": {
             # Example: facebook/bart-large-mnli
             "inputs": [
-                components.Textbox(label="Input"),
-                components.Textbox(label="Possible class names (" "comma-separated)"),
-                components.Checkbox(label="Allow multiple true classes"),
+                components.Textbox(label="Input", render=False),
+                components.Textbox(
+                    label="Possible class names (" "comma-separated)", render=False
+                ),
+                components.Checkbox(label="Allow multiple true classes", render=False),
             ],
-            "outputs": components.Label(label="Classification"),
+            "outputs": components.Label(label="Classification", render=False),
             "preprocess": lambda i, c, m: {
                 "inputs": i,
                 "parameters": {"candidate_labels": c, "multi_class": m},
@@ -275,15 +285,18 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwargs):
             # Example: sentence-transformers/distilbert-base-nli-stsb-mean-tokens
             "inputs": [
                 components.Textbox(
-                    value="That is a happy person", label="Source Sentence"
+                    value="That is a happy person",
+                    label="Source Sentence",
+                    render=False,
                 ),
                 components.Textbox(
                     lines=7,
                     placeholder="Separate each sentence by a newline",
                     label="Sentences to compare to",
+                    render=False,
                 ),
             ],
-            "outputs": components.Label(label="Classification"),
+            "outputs": components.Label(label="Classification", render=False),
             "preprocess": lambda src, sentences: {
                 "inputs": {
                     "source_sentence": src,
@@ -296,32 +309,32 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwargs):
         },
         "text-to-speech": {
             # Example: julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Audio(label="Audio"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Audio(label="Audio", render=False),
             "preprocess": lambda x: {"inputs": x},
             "postprocess": encode_to_base64,
         },
         "text-to-image": {
             # example model: osanseviero/BigGAN-deep-128
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Image(label="Output"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Image(label="Output", render=False),
             "preprocess": lambda x: {"inputs": x},
             "postprocess": encode_to_base64,
         },
         "token-classification": {
             # example model: huggingface-course/bert-finetuned-ner
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.HighlightedText(label="Output"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.HighlightedText(label="Output", render=False),
             "preprocess": lambda x: {"inputs": x},
             "postprocess": lambda r: r,  # Handled as a special case in query_huggingface_api()
         },
         "document-question-answering": {
             # example model: impira/layoutlm-document-qa
             "inputs": [
-                components.Image(type="filepath", label="Input Document"),
-                components.Textbox(label="Question"),
+                components.Image(type="filepath", label="Input Document", render=False),
+                components.Textbox(label="Question", render=False),
             ],
-            "outputs": components.Label(label="Label"),
+            "outputs": components.Label(label="Label", render=False),
             "preprocess": lambda img, q: {
                 "inputs": {
                     "image": extract_base64_data(img),  # Extract base64 data
@@ -335,10 +348,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwargs):
         "visual-question-answering": {
             # example model: dandelin/vilt-b32-finetuned-vqa
             "inputs": [
-                components.Image(type="filepath", label="Input Image"),
-                components.Textbox(label="Question"),
+                components.Image(type="filepath", label="Input Image", render=False),
+                components.Textbox(label="Question", render=False),
             ],
-            "outputs": components.Label(label="Label"),
+            "outputs": components.Label(label="Label", render=False),
             "preprocess": lambda img, q: {
                 "inputs": {
                     "image": extract_base64_data(img),
@@ -351,8 +364,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwargs):
         },
         "image-to-text": {
             # example model: Salesforce/blip-image-captioning-base
-            "inputs": components.Image(type="filepath", label="Input Image"),
-            "outputs": components.Textbox(label="Generated Text"),
+            "inputs": components.Image(
+                type="filepath", label="Input Image", render=False
+            ),
+            "outputs": components.Textbox(label="Generated Text", render=False),
             "preprocess": to_binary,
             "postprocess": lambda r: r.json()[0]["generated_text"],
         },
@@ -369,9 +384,10 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwargs):
                 type="pandas",
                 headers=col_names,
                 col_count=(len(col_names), "fixed"),
+                render=False,
             ),
             "outputs": components.Dataframe(
-                label="Predictions", type="array", headers=["prediction"]
+                label="Predictions", type="array", headers=["prediction"], render=False
             ),
             "preprocess": rows_to_cols,
             "postprocess": lambda r: {
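The same treatment is applied below to `load_from_pipeline()`, which builds the component spec that `gr.Interface.from_pipeline()` consumes for local transformers pipelines. A sketch of the usage this unlocks, mirroring the new test at the end of the commit (the tab name is illustrative):

```python
import gradio as gr
import transformers

pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")

with gr.Blocks() as demo:
    with gr.Tab("Pipeline demo"):
        # Works after this change: the pipeline's prebuilt components no
        # longer leak into the surrounding Blocks layout.
        gr.Interface.from_pipeline(pipe)
```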
@@ -35,9 +35,12 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
     ):
         pipeline_info = {
             "inputs": components.Audio(
-                source="microphone", type="filepath", label="Input"
+                source="microphone",
+                type="filepath",
+                label="Input",
+                render=False,
             ),
-            "outputs": components.Label(label="Class"),
+            "outputs": components.Label(label="Class", render=False),
             "preprocess": lambda i: {"inputs": i},
             "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
         }
@@ -47,9 +50,9 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
     ):
         pipeline_info = {
             "inputs": components.Audio(
-                source="microphone", type="filepath", label="Input"
+                source="microphone", type="filepath", label="Input", render=False
             ),
-            "outputs": components.Textbox(label="Output"),
+            "outputs": components.Textbox(label="Output", render=False),
             "preprocess": lambda i: {"inputs": i},
             "postprocess": lambda r: r["text"],
         }
@@ -57,8 +60,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
         pipeline, pipelines.feature_extraction.FeatureExtractionPipeline
     ):
         pipeline_info = {
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Dataframe(label="Output"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Dataframe(label="Output", render=False),
             "preprocess": lambda x: {"inputs": x},
             "postprocess": lambda r: r[0],
         }
@@ -66,8 +69,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
         pipeline, pipelines.fill_mask.FillMaskPipeline
     ):
         pipeline_info = {
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Label(label="Classification"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Label(label="Classification", render=False),
             "preprocess": lambda x: {"inputs": x},
             "postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
         }
@@ -75,8 +78,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
         pipeline, pipelines.image_classification.ImageClassificationPipeline
     ):
         pipeline_info = {
-            "inputs": components.Image(type="filepath", label="Input Image"),
-            "outputs": components.Label(type="confidences", label="Classification"),
+            "inputs": components.Image(
+                type="filepath", label="Input Image", render=False
+            ),
+            "outputs": components.Label(label="Classification", render=False),
             "preprocess": lambda i: {"images": i},
             "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
         }
@@ -85,12 +90,12 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
     ):
         pipeline_info = {
             "inputs": [
-                components.Textbox(lines=7, label="Context"),
-                components.Textbox(label="Question"),
+                components.Textbox(lines=7, label="Context", render=False),
+                components.Textbox(label="Question", render=False),
             ],
             "outputs": [
-                components.Textbox(label="Answer"),
-                components.Label(label="Score"),
+                components.Textbox(label="Answer", render=False),
+                components.Label(label="Score", render=False),
             ],
             "preprocess": lambda c, q: {"context": c, "question": q},
             "postprocess": lambda r: (r["answer"], r["score"]),
@@ -99,8 +104,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
         pipeline, pipelines.text2text_generation.SummarizationPipeline
     ):
         pipeline_info = {
-            "inputs": components.Textbox(lines=7, label="Input"),
-            "outputs": components.Textbox(label="Summary"),
+            "inputs": components.Textbox(lines=7, label="Input", render=False),
+            "outputs": components.Textbox(label="Summary", render=False),
             "preprocess": lambda x: {"inputs": x},
             "postprocess": lambda r: r[0]["summary_text"],
         }
@@ -108,8 +113,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
         pipeline, pipelines.text_classification.TextClassificationPipeline
     ):
         pipeline_info = {
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Label(label="Classification"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Label(label="Classification", render=False),
             "preprocess": lambda x: [x],
             "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
         }
@@ -117,8 +122,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
         pipeline, pipelines.text_generation.TextGenerationPipeline
     ):
         pipeline_info = {
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Textbox(label="Output"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Textbox(label="Output", render=False),
             "preprocess": lambda x: {"text_inputs": x},
             "postprocess": lambda r: r[0]["generated_text"],
         }
@@ -126,8 +131,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
         pipeline, pipelines.text2text_generation.TranslationPipeline
     ):
         pipeline_info = {
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Textbox(label="Translation"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Textbox(label="Translation", render=False),
             "preprocess": lambda x: [x],
             "postprocess": lambda r: r[0]["translation_text"],
         }
@@ -135,8 +140,8 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
         pipeline, pipelines.text2text_generation.Text2TextGenerationPipeline
    ):
         pipeline_info = {
-            "inputs": components.Textbox(label="Input"),
-            "outputs": components.Textbox(label="Generated Text"),
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.Textbox(label="Generated Text", render=False),
             "preprocess": lambda x: [x],
             "postprocess": lambda r: r[0]["generated_text"],
         }
@@ -145,11 +150,13 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
     ):
         pipeline_info = {
             "inputs": [
-                components.Textbox(label="Input"),
-                components.Textbox(label="Possible class names (" "comma-separated)"),
-                components.Checkbox(label="Allow multiple true classes"),
+                components.Textbox(label="Input", render=False),
+                components.Textbox(
+                    label="Possible class names (" "comma-separated)", render=False
+                ),
+                components.Checkbox(label="Allow multiple true classes", render=False),
             ],
-            "outputs": components.Label(label="Classification"),
+            "outputs": components.Label(label="Classification", render=False),
             "preprocess": lambda i, c, m: {
                 "sequences": i,
                 "candidate_labels": c,
@@ -165,10 +172,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
     ):
         pipeline_info = {
             "inputs": [
-                components.Image(type="filepath", label="Input Document"),
-                components.Textbox(label="Question"),
+                components.Image(type="filepath", label="Input Document", render=False),
+                components.Textbox(label="Question", render=False),
             ],
-            "outputs": components.Label(label="Label"),
+            "outputs": components.Label(label="Label", render=False),
             "preprocess": lambda img, q: {"image": img, "question": q},
             "postprocess": lambda r: {i["answer"]: i["score"] for i in r},
         }
@@ -177,10 +184,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
     ):
         pipeline_info = {
             "inputs": [
-                components.Image(type="filepath", label="Input Image"),
-                components.Textbox(label="Question"),
+                components.Image(type="filepath", label="Input Image", render=False),
+                components.Textbox(label="Question", render=False),
             ],
-            "outputs": components.Label(label="Score"),
+            "outputs": components.Label(label="Score", render=False),
             "preprocess": lambda img, q: {"image": img, "question": q},
             "postprocess": lambda r: {i["answer"]: i["score"] for i in r},
         }
@@ -188,8 +195,10 @@ def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
         pipeline, pipelines.image_to_text.ImageToTextPipeline  # type: ignore
     ):
         pipeline_info = {
-            "inputs": components.Image(type="filepath", label="Input Image"),
-            "outputs": components.Textbox(label="Text"),
+            "inputs": components.Image(
+                type="filepath", label="Input Image", render=False
+            ),
+            "outputs": components.Textbox(label="Text", render=False),
             "preprocess": lambda i: {"images": i},
             "postprocess": lambda r: r[0]["generated_text"],
         }
@@ -236,8 +236,9 @@ class TestLoadInterface:
         except TooManyRequestsError:
             pass

-    def test_conversational(self):
-        io = gr.load("models/microsoft/DialoGPT-medium")
+    def test_conversational_in_blocks(self):
+        with gr.Blocks() as io:
+            gr.load("models/microsoft/DialoGPT-medium")
         app, _, _ = io.launch(prevent_thread_lock=True)
         client = TestClient(app)
         assert app.state_holder == {}
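The rewritten test drives the previously failing path: `gr.load()` now runs inside a live `gr.Blocks()` context, and the assertion confirms the app launches with a clean session-state holder. An illustrative follow-up check, not part of the diff, assuming gradio's standard `/config` route:

```python
# The launched app should serve its config once the embedded UI mounts.
response = client.get("/config")
assert response.status_code == 200
```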
@@ -5,9 +5,21 @@ import gradio as gr


 @pytest.mark.flaky
-class TestLoadFromPipeline:
-    def test_text_to_text_model_from_pipeline(self):
-        pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
-        io = gr.Interface.from_pipeline(pipe)
-        output = io("My name is Sylvain and I work at Hugging Face in Brooklyn")
-        assert isinstance(output, str)
+def test_text_to_text_model_from_pipeline():
+    pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
+    io = gr.Interface.from_pipeline(pipe)
+    output = io("My name is Sylvain and I work at Hugging Face in Brooklyn")
+    assert isinstance(output, str)
+
+
+@pytest.mark.flaky
+def test_interface_in_blocks():
+    pipe1 = transformers.pipeline(model="sshleifer/bart-tiny-random")
+    pipe2 = transformers.pipeline(model="sshleifer/bart-tiny-random")
+    with gr.Blocks() as demo:
+        with gr.Tab("Image Inference"):
+            gr.Interface.from_pipeline(pipe1)
+        with gr.Tab("Image Inference"):
+            gr.Interface.from_pipeline(pipe2)
+    demo.launch(prevent_thread_lock=True)
+    demo.close()
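Note the deliberately duplicated tab in `test_interface_in_blocks`: embedding two Interfaces built from identical pipelines in one Blocks app checks that each `from_pipeline()` call constructs its own components rather than re-rendering shared ones, which is exactly the failure mode the `render=False` changes above are meant to prevent.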