Declare exports in __all__ for type checking (#10238)

* Declare exports

* add changeset

* type fixes

* more type fixes

* add changeset

* notebooks

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Dmitry Ustalov 2024-12-23 23:33:22 +01:00 committed by GitHub
parent f0cf3b789a
commit 3f192100d6
14 changed files with 144 additions and 23 deletions

View File

@@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---
fix:Declare exports in __all__ for type checking
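For context, the rule this changeset leans on: strict type checkers (pyright's implicit re-export rule, mypy with --no-implicit-reexport) only treat names listed in a package's __all__, or re-imported as "name as name", as its public API, so "from gradio import ..." resolves cleanly under strict settings. A minimal illustrative sketch with made-up module names:

# mypackage/__init__.py (hypothetical package, for illustration only)
from .audio import Audio
from .button import Button

# Without this, strict pyright/mypy treat Audio and Button as private
# imports of __init__.py rather than re-exported public names.
__all__ = ["Audio", "Button"]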

View File

@@ -1356,8 +1356,8 @@ class Endpoint:
f"File {file_path} exceeds the maximum file size of {max_file_size} bytes "
f"set in {component_config.get('label', '') + ''} component."
)
with open(file_path, "rb") as f:
files = [("files", (orig_name.name, f))]
with open(file_path, "rb") as f_:
files = [("files", (orig_name.name, f_))]
r = httpx.post(
self.client.upload_url,
headers=self.client.headers,

View File

@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: agent_chatbot"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio transformers>=4.47.0"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from dataclasses import asdict\n", "from transformers import Tool, ReactCodeAgent # type: ignore\n", "from transformers.agents import stream_to_gradio, HfApiEngine # type: ignore\n", "\n", "# Import tool from Hub\n", "image_generation_tool = Tool.from_space(\n", " space_id=\"black-forest-labs/FLUX.1-schnell\",\n", " name=\"image_generator\",\n", " description=\"Generates an image following your prompt. Returns a PIL Image.\",\n", " api_name=\"/infer\",\n", ")\n", "\n", "llm_engine = HfApiEngine(\"Qwen/Qwen2.5-Coder-32B-Instruct\")\n", "# Initialize the agent with both tools and engine\n", "agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)\n", "\n", "\n", "def interact_with_agent(prompt, history):\n", " messages = []\n", " yield messages\n", " for msg in stream_to_gradio(agent, prompt):\n", " messages.append(asdict(msg))\n", " yield messages\n", " yield messages\n", "\n", "\n", "demo = gr.ChatInterface(\n", " interact_with_agent,\n", " chatbot= gr.Chatbot(\n", " label=\"Agent\",\n", " type=\"messages\",\n", " avatar_images=(\n", " None,\n", " \"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png\",\n", " ),\n", " ),\n", " examples=[\n", " [\"Generate an image of an astronaut riding an alligator\"],\n", " [\"I am writing a children's book for my daughter. Can you help me with some illustrations?\"],\n", " ],\n", " type=\"messages\",\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: agent_chatbot"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio transformers>=4.47.0"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from dataclasses import asdict\n", "from transformers import Tool, ReactCodeAgent # type: ignore\n", "from transformers.agents import stream_to_gradio, HfApiEngine # type: ignore\n", "\n", "# Import tool from Hub\n", "image_generation_tool = Tool.from_space( # type: ignore\n", " space_id=\"black-forest-labs/FLUX.1-schnell\",\n", " name=\"image_generator\",\n", " description=\"Generates an image following your prompt. Returns a PIL Image.\",\n", " api_name=\"/infer\",\n", ")\n", "\n", "llm_engine = HfApiEngine(\"Qwen/Qwen2.5-Coder-32B-Instruct\")\n", "# Initialize the agent with both tools and engine\n", "agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)\n", "\n", "\n", "def interact_with_agent(prompt, history):\n", " messages = []\n", " yield messages\n", " for msg in stream_to_gradio(agent, prompt):\n", " messages.append(asdict(msg)) # type: ignore\n", " yield messages\n", " yield messages\n", "\n", "\n", "demo = gr.ChatInterface(\n", " interact_with_agent,\n", " chatbot= gr.Chatbot(\n", " label=\"Agent\",\n", " type=\"messages\",\n", " avatar_images=(\n", " None,\n", " \"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png\",\n", " ),\n", " ),\n", " examples=[\n", " [\"Generate an image of an astronaut riding an alligator\"],\n", " [\"I am writing a children's book for my daughter. Can you help me with some illustrations?\"],\n", " ],\n", " type=\"messages\",\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@@ -4,7 +4,7 @@ from transformers import Tool, ReactCodeAgent # type: ignore
from transformers.agents import stream_to_gradio, HfApiEngine # type: ignore
# Import tool from Hub
image_generation_tool = Tool.from_space(
image_generation_tool = Tool.from_space( # type: ignore
space_id="black-forest-labs/FLUX.1-schnell",
name="image_generator",
description="Generates an image following your prompt. Returns a PIL Image.",
@@ -20,7 +20,7 @@ def interact_with_agent(prompt, history):
messages = []
yield messages
for msg in stream_to_gradio(agent, prompt):
messages.append(asdict(msg))
messages.append(asdict(msg)) # type: ignore
yield messages
yield messages

View File

@@ -119,3 +119,121 @@ if not IS_WASM:
from gradio.ipython_ext import load_ipython_extension
__version__ = get_package_version()
__all__ = [
"Accordion",
"AnnotatedImage",
"Annotatedimage",
"Audio",
"BarPlot",
"Blocks",
"BrowserState",
"Brush",
"Button",
"CSVLogger",
"ChatInterface",
"ChatMessage",
"Chatbot",
"Checkbox",
"CheckboxGroup",
"Checkboxgroup",
"ClearButton",
"Code",
"ColorPicker",
"Column",
"CopyData",
"DataFrame",
"Dataframe",
"Dataset",
"DateTime",
"DeletedFileData",
"DownloadButton",
"DownloadData",
"Dropdown",
"DuplicateButton",
"EditData",
"Eraser",
"Error",
"EventData",
"Examples",
"File",
"FileData",
"FileExplorer",
"FileSize",
"Files",
"FlaggingCallback",
"Gallery",
"Group",
"HTML",
"Highlight",
"HighlightedText",
"Highlightedtext",
"IS_WASM",
"Image",
"ImageEditor",
"ImageMask",
"Info",
"Interface",
"JSON",
"Json",
"KeyUpData",
"Label",
"LikeData",
"LinePlot",
"List",
"LoginButton",
"Markdown",
"Matrix",
"MessageDict",
"Mic",
"Microphone",
"Model3D",
"MultimodalTextbox",
"NO_RELOAD",
"Number",
"Numpy",
"OAuthProfile",
"OAuthToken",
"Paint",
"ParamViewer",
"PlayableVideo",
"Plot",
"Progress",
"Radio",
"Request",
"RetryData",
"Row",
"ScatterPlot",
"SelectData",
"SimpleCSVLogger",
"Sketchpad",
"Slider",
"State",
"Tab",
"TabItem",
"TabbedInterface",
"Tabs",
"Text",
"TextArea",
"Textbox",
"Theme",
"Timer",
"UndoData",
"UploadButton",
"Video",
"Warning",
"WaveformOptions",
"__version__",
"close_all",
"deploy",
"get_package_version",
"load",
"load_chat",
"load_ipython_extension",
"mount_gradio_app",
"on",
"render",
"set_static_paths",
"skip",
"update",
]
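A quick sanity check for a list this long (a sketch, assuming an environment with this change installed):

import gradio

# Every name declared in __all__ should resolve to a real attribute of the package.
missing = [name for name in gradio.__all__ if not hasattr(gradio, name)]
assert not missing, f"__all__ names missing from gradio: {missing}"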

View File

@@ -1475,10 +1475,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
return False
if any(block.stateful for block in dependency.inputs):
return False
if any(block.stateful for block in dependency.outputs):
return False
return True
return not any(block.stateful for block in dependency.outputs)
def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None):
"""

View File

@@ -71,7 +71,7 @@ def get_param_name(param):
def format_none(value):
"""Formats None and NonType values."""
if value is None or value is type(None) or value == "None" or value == "NoneType":
if value is None or value is type(None) or value in ("None", "NoneType"):
return "None"
return value

View File

@@ -147,10 +147,10 @@ class NativePlot(Component):
every=every,
inputs=inputs,
)
for key, val in kwargs.items():
if key == "color_legend_title":
for key_, val in kwargs.items():
if key_ == "color_legend_title":
self.color_title = val
if key in [
if key_ in [
"stroke_dash",
"overlay_point",
"x_label_angle",
@@ -161,7 +161,7 @@
"width",
]:
warnings.warn(
f"Argument '{key}' has been deprecated.", DeprecationWarning
f"Argument '{key_}' has been deprecated.", DeprecationWarning
)
def get_block_name(self) -> str:

View File

@@ -276,7 +276,7 @@ class Video(StreamingOutput, Component):
"""
if self.streaming:
return value # type: ignore
if value is None or value == [None, None] or value == (None, None):
if value is None or value in ([None, None], (None, None)):
return None
if isinstance(value, (str, Path)):
processed_files = (self._format_video(value), None)
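The rewrite above is behavior-preserving only because both sentinel forms stay listed: the "in" operator compares with ==, and a list never compares equal to a tuple. For illustration:

[None, None] == (None, None)                      # False: list vs. tuple
[None, None] in ([None, None], (None, None))      # True
(None, None) in ([None, None], (None, None))      # True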

View File

@@ -615,7 +615,7 @@ def load_chat(
[{"role": "system", "content": system_message}] if system_message else []
)
def open_api(message: str, history: list | None) -> str:
def open_api(message: str, history: list | None) -> str | None:
history = history or start_message
if len(history) > 0 and isinstance(history[0], (list, tuple)):
history = ChatInterface._tuples_to_messages(history)
@@ -641,7 +641,8 @@
)
response = ""
for chunk in stream:
response += chunk.choices[0].delta.content
yield response
if chunk.choices[0].delta.content is not None:
response += chunk.choices[0].delta.content
yield response
return ChatInterface(open_api_stream if streaming else open_api, type="messages")
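The added guard matters because OpenAI-style chat completion streams include chunks whose delta carries no content (for example the role-only first chunk and the final stop chunk), and None cannot be concatenated to a str. A minimal sketch of the accumulation pattern, using a stand-in stream object shaped like those chunks:

def accumulate(stream):
    # Each chunk mirrors an OpenAI streaming response; delta.content may be None.
    response = ""
    for chunk in stream:
        content = chunk.choices[0].delta.content
        if content is not None:
            response += content
            yield response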

View File

@@ -126,8 +126,7 @@ class ThemeClass:
if (
not prop.startswith("_")
or prop.startswith("_font")
or prop == "_stylesheets"
or prop == "name"
or prop in ("_stylesheets", "name")
) and isinstance(getattr(self, prop), (list, str)):
schema["theme"][prop] = getattr(self, prop)
return schema

View File

@@ -1244,7 +1244,7 @@ def diff(old, new):
if obj1 == obj2:
return edits
if type(obj1) != type(obj2):
if type(obj1) is not type(obj2):
edits.append(("replace", path, obj2))
return edits
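Comparing types with "is" / "is not" instead of == / != is the identity check linters expect (flake8/ruff E721: use "is" for exact type comparisons, isinstance() for subclass-aware ones), and it is what the test-suite changes below switch to as well. A quick illustration:

type(1) is type(2)         # True, both are the int class object
type(1) is not type(1.0)   # True, int vs. float
type(1) == type(2)         # also True, but E721-style checks flag the pattern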

View File

@@ -126,5 +126,5 @@ class TestGallery:
output = gallery.postprocess(
[np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)]
)
if type(output.root[0]) == GalleryImage:
if isinstance(output.root[0], GalleryImage):
assert output.root[0].image.path.endswith(".jpeg")

View File

@@ -304,7 +304,7 @@ class TestGetTypeHints:
for x in test_objs:
hints = get_type_hints(x)
assert len(hints) == 1
assert hints["s"] == str
assert hints["s"] is str
assert len(get_type_hints(GenericObject())) == 0