Allow image uploads to gr.load_chat (#10345)
* changes (multiple iterations)
* add changeset
* Update gradio/external.py
* simplify tests

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Parent: 3750082b26
Commit: 39f0c23303

.changeset/brown-insects-say.md (new file, 5 lines)
```diff
@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:Allow image uploads to gr.load_chat
```
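To see the feature this changeset announces end to end: a minimal usage sketch, assuming an OpenAI-compatible server (e.g. Ollama) running locally and `openai` installed. The URL, model, and `file_types` values mirror the docstring example further down.

```python
import gradio as gr

# With file_types set, the resulting ChatInterface uses a MultimodalTextbox,
# so users can attach text-encoded files (inlined into the prompt) and
# images (sent to the endpoint as image_url content parts).
gr.load_chat(
    "http://localhost:11434/v1/",
    model="qwen2.5",
    token="***",
    file_types=["text_encoded", "image"],
).launch()
```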
gradio/external.py

```diff
@@ -17,10 +17,12 @@ import huggingface_hub
 from gradio_client import Client
 from gradio_client.client import Endpoint
 from gradio_client.documentation import document
+from gradio_client.utils import encode_url_or_file_to_base64
 from packaging import version
 
 import gradio
 from gradio import components, external_utils, utils
+from gradio.components.multimodal_textbox import MultimodalValue
 from gradio.context import Context
 from gradio.exceptions import (
     GradioVersionIncompatibleError,
@@ -31,6 +33,7 @@ from gradio.processing_utils import save_base64_to_cache, to_binary
 if TYPE_CHECKING:
     from gradio.blocks import Blocks
     from gradio.chat_interface import ChatInterface
+    from gradio.components.chatbot import MessageDict
     from gradio.interface import Interface
```
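The new `encode_url_or_file_to_base64` import does the heavy lifting for images. As a rough sketch of what it returns for a local file (the real helper in `gradio_client.utils` also handles http(s) URLs; this approximation uses only the standard library):

```python
import base64
import mimetypes
from pathlib import Path

def encode_file_to_data_uri(path: str) -> str:
    # Approximation of encode_url_or_file_to_base64 for a local file:
    # a data: URI that OpenAI-style chat APIs accept as an image_url.
    mime = mimetypes.guess_type(path)[0] or "application/octet-stream"
    payload = base64.b64encode(Path(path).read_bytes()).decode("utf-8")
    return f"data:{mime};base64,{payload}"
```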
```diff
@@ -586,12 +589,146 @@ def from_spaces_interface(
     return interface
 
 
+TEXT_FILE_EXTENSIONS = (
+    ".doc",
+    ".docx",
+    ".rtf",
+    ".epub",
+    ".odt",
+    ".odp",
+    ".pptx",
+    ".txt",
+    ".md",
+    ".py",
+    ".ipynb",
+    ".js",
+    ".jsx",
+    ".html",
+    ".css",
+    ".java",
+    ".cs",
+    ".php",
+    ".c",
+    ".cc",
+    ".cpp",
+    ".cxx",
+    ".cts",
+    ".h",
+    ".hh",
+    ".hpp",
+    ".rs",
+    ".R",
+    ".Rmd",
+    ".swift",
+    ".go",
+    ".rb",
+    ".kt",
+    ".kts",
+    ".ts",
+    ".tsx",
+    ".m",
+    ".mm",
+    ".mts",
+    ".scala",
+    ".dart",
+    ".lua",
+    ".pl",
+    ".pm",
+    ".t",
+    ".sh",
+    ".bash",
+    ".zsh",
+    ".bat",
+    ".coffee",
+    ".csv",
+    ".log",
+    ".ini",
+    ".cfg",
+    ".config",
+    ".json",
+    ".proto",
+    ".yaml",
+    ".yml",
+    ".toml",
+    ".sql",
+)
+IMAGE_FILE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif", ".webp")
+
+
+def format_conversation(
+    history: list[MessageDict], new_message: str | MultimodalValue
+) -> list[dict]:
+    conversation = []
+    for message in history:
+        if isinstance(message["content"], str):
+            conversation.append(
+                {"role": message["role"], "content": message["content"]}
+            )
+        elif isinstance(message["content"], tuple):
+            image_message = {
+                "role": message["role"],
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": encode_url_or_file_to_base64(message["content"][0])
+                        },
+                    }
+                ],
+            }
+            conversation.append(image_message)
+        else:
+            raise ValueError(
+                f"Invalid message format: {message['content']}. Messages must be either strings or tuples."
+            )
+    if isinstance(new_message, str):
+        text = new_message
+        files = []
+    else:
+        text = new_message.get("text", None)
+        files = new_message.get("files", [])
+    image_files, text_encoded = [], []
+    for file in files:
+        if file.lower().endswith(TEXT_FILE_EXTENSIONS):
+            text_encoded.append(file)
+        else:
+            image_files.append(file)
+
+    for image in image_files:
+        conversation.append(
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": encode_url_or_file_to_base64(image)},
+                    }
+                ],
+            }
+        )
+    if text or text_encoded:
+        text = text or ""
+        text += "\n".join(
+            [
+                f"\n## {Path(file).name}\n{Path(file).read_text()}"
+                for file in text_encoded
+            ]
+        )
+        conversation.append(
+            {"role": "user", "content": [{"type": "text", "text": text}]}
+        )
+    return conversation
+
+
 @document()
 def load_chat(
     base_url: str,
     model: str,
     token: str | None = None,
     *,
+    file_types: Literal["text_encoded", "image"]
+    | list[Literal["text_encoded", "image"]]
+    | None = "text_encoded",
     system_message: str | None = None,
     streaming: bool = True,
     **kwargs,
```
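A worked example of what `format_conversation` produces, assuming it is importable from `gradio.external` and that `photo.jpg` exists on disk (the data URI below is abbreviated):

```python
from gradio.external import format_conversation

history = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi! How can I help?"},
]
new_message = {"text": "Describe this picture", "files": ["photo.jpg"]}

conversation = format_conversation(history, new_message)
# Image messages are appended before the final text message:
# [
#     {"role": "user", "content": "Hello"},
#     {"role": "assistant", "content": "Hi! How can I help?"},
#     {"role": "user", "content": [{"type": "image_url",
#         "image_url": {"url": "data:image/jpeg;base64,..."}}]},
#     {"role": "user", "content": [{"type": "text", "text": "Describe this picture"}]},
# ]
```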
```diff
@@ -602,9 +739,19 @@ def load_chat(
         base_url: The base URL of the endpoint, e.g. "http://localhost:11434/v1/"
         model: The name of the model you are loading, e.g. "llama3.2"
         token: The API token or a placeholder string if you are using a local model, e.g. "ollama"
+        file_types: The file types allowed to be uploaded by the user. "text_encoded" allows uploading any text-encoded file (which is simply appended to the prompt), and "image" adds image upload support. Set to None to disable file uploads.
         system_message: The system message to use for the conversation, if any.
         streaming: Whether the response should be streamed.
         kwargs: Additional keyword arguments to pass into ChatInterface for customization.
     Example:
         import gradio as gr
+        gr.load_chat(
+            "http://localhost:11434/v1/",
+            model="qwen2.5",
+            token="***",
+            file_types=["text_encoded", "image"],
+            system_message="You are a silly assistant.",
+        ).launch()
     """
     try:
         from openai import OpenAI
```
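Per the new signature, `file_types` accepts a single literal, a list, or None. A short sketch of the three configurations (the vision model name is an illustrative assumption):

```python
import gradio as gr

# Default: text-encoded files only, inlined into the prompt.
gr.load_chat("http://localhost:11434/v1/", model="llama3.2", token="***")

# Images only, e.g. for a vision model (model name is hypothetical):
gr.load_chat(
    "http://localhost:11434/v1/", model="llava", token="***", file_types="image"
)

# No uploads at all: a plain, non-multimodal textbox.
gr.load_chat(
    "http://localhost:11434/v1/", model="llama3.2", token="***", file_types=None
)
```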
```diff
@@ -618,29 +765,32 @@
     start_message = (
         [{"role": "system", "content": system_message}] if system_message else []
     )
+    file_types = utils.none_or_singleton_to_list(file_types)
 
-    def open_api(message: str, history: list | None) -> str | None:
+    def open_api(message: str | MultimodalValue, history: list | None) -> str | None:
         history = history or start_message
         if len(history) > 0 and isinstance(history[0], (list, tuple)):
             history = ChatInterface._tuples_to_messages(history)
+        conversation = format_conversation(history, message)
         return (
             client.chat.completions.create(
                 model=model,
-                messages=history + [{"role": "user", "content": message}],
+                messages=conversation,  # type: ignore
             )
             .choices[0]
             .message.content
         )
 
     def open_api_stream(
-        message: str, history: list | None
+        message: str | MultimodalValue, history: list | None
     ) -> Generator[str, None, None]:
         history = history or start_message
         if len(history) > 0 and isinstance(history[0], (list, tuple)):
             history = ChatInterface._tuples_to_messages(history)
+        conversation = format_conversation(history, message)
         stream = client.chat.completions.create(
             model=model,
-            messages=history + [{"role": "user", "content": message}],
+            messages=conversation,  # type: ignore
             stream=True,
         )
         response = ""
```
```diff
@@ -649,6 +799,23 @@
             response += chunk.choices[0].delta.content
             yield response
 
+    supported_extensions = []
+    for file_type in file_types:
+        if file_type == "text_encoded":
+            supported_extensions += TEXT_FILE_EXTENSIONS
+        elif file_type == "image":
+            supported_extensions += IMAGE_FILE_EXTENSIONS
+        else:
+            raise ValueError(
+                f"Invalid file type: {file_type}. Must be 'text_encoded' or 'image'."
+            )
+
     return ChatInterface(
-        open_api_stream if streaming else open_api, type="messages", **kwargs
+        open_api_stream if streaming else open_api,
+        type="messages",
+        multimodal=bool(file_types),
+        textbox=gradio.MultimodalTextbox(file_types=supported_extensions)
+        if file_types
+        else None,
+        **kwargs,
     )
```
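Because the interface is built with `multimodal=bool(file_types)`, the chat function receives a dict rather than a plain string whenever uploads are enabled. A sketch of the value `format_conversation` expects, with keys matching the `MultimodalValue` annotation used in the signatures above:

```python
# What open_api / open_api_stream receive as `message` when uploads are enabled:
message = {
    "text": "Summarize the notes and describe the diagram",
    "files": ["notes.md", "diagram.png"],  # local paths, routed by extension
}
# ".md" is in TEXT_FILE_EXTENSIONS  -> contents inlined into the text prompt;
# ".png" is in IMAGE_FILE_EXTENSIONS -> base64-encoded as an image_url part.
```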
Chatbot guide (markdown):

````diff
@@ -16,15 +16,15 @@ $ pip install --upgrade gradio
 
 ## Note for OpenAI-API compatible endpoints
 
-If you have a chat server serving an OpenAI-API compatible endpoint (e.g. Ollama), you can spin up a ChatInterface in a single line of Python. First, also run `pip install openai`. Then, with your own URL, model, and optional token:
+If you have a chat server serving an OpenAI-API compatible endpoint (such as Ollama), you can spin up a ChatInterface in a single line of Python. First, also run `pip install openai`. Then, with your own URL, model, and optional token:
 
 ```python
 import gradio as gr
 
-gr.load_chat("http://localhost:11434/v1/", model="llama3.2", token="ollama").launch()
+gr.load_chat("http://localhost:11434/v1/", model="llama3.2", token="***").launch()
 ```
 
-If you have your own model, keep reading to see how to create an application around any chat model in Python!
+Read about `gr.load_chat` in [the docs](https://www.gradio.app/docs/gradio/load_chat). If you have your own model, keep reading to see how to create an application around any chat model in Python!
 
 ## Defining a chat function
````
Tests:

```diff
@@ -523,3 +523,39 @@ def test_load_callable():
     )
 
     assert isinstance(result, gr.Blocks)
+
+
+@patch("openai.OpenAI")
+def test_load_chat_basic(mock_openai):
+    mock_client = MagicMock()
+    mock_client.chat.completions.create.return_value.choices[
+        0
+    ].message.content = "Hello human!"
+    mock_openai.return_value = mock_client
+
+    chat = gr.load_chat(
+        "http://fake-api.com/v1",
+        model="test-model",
+        token="fake-token",
+        streaming=False,
+    )
+    response = chat.fn("Hi AI!", None)
+    assert response == "Hello human!"
+
+
+@patch("openai.OpenAI")
+def test_load_chat_with_streaming(mock_openai):
+    mock_client = MagicMock()
+    mock_stream = [
+        MagicMock(choices=[MagicMock(delta=MagicMock(content="Hello"))]),
+        MagicMock(choices=[MagicMock(delta=MagicMock(content=" World"))]),
+        MagicMock(choices=[MagicMock(delta=MagicMock(content="!"))]),
+    ]
+    mock_client.chat.completions.create.return_value = mock_stream
+    mock_openai.return_value = mock_client
+    chat = gr.load_chat(
+        "http://fake-api.com/v1", model="test-model", token="fake-token", streaming=True
+    )
+    response_stream = chat.fn("Hi!", None)
+    responses = list(response_stream)
+    assert responses == ["Hello", "Hello World", "Hello World!"]
```
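A sketch of a companion test for the new image path, in the same mocking style as the tests above. The test itself is hypothetical: it assumes `chat.fn` accepts the multimodal dict directly, and the imports shown are normally already present at the top of the test module.

```python
from unittest.mock import MagicMock, patch

import gradio as gr


@patch("openai.OpenAI")
def test_load_chat_with_image_upload(mock_openai, tmp_path):
    mock_client = MagicMock()
    mock_client.chat.completions.create.return_value.choices[
        0
    ].message.content = "I see an image."
    mock_openai.return_value = mock_client

    # Any bytes work here: the file is base64-encoded, never decoded as an image.
    img = tmp_path / "pixel.png"
    img.write_bytes(b"\x89PNG\r\n\x1a\n")

    chat = gr.load_chat(
        "http://fake-api.com/v1",
        model="test-model",
        token="fake-token",
        file_types=["image"],
        streaming=False,
    )
    response = chat.fn({"text": "What is this?", "files": [str(img)]}, None)
    assert response == "I see an image."

    # The image should go out as an image_url content part before the text.
    sent = mock_client.chat.completions.create.call_args.kwargs["messages"]
    assert sent[-2]["content"][0]["type"] == "image_url"
```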