Reverts the revert that dropped changes from 2 PRs (#7495)

* Revert "Revert "changes""

This reverts commit 032435368d.

* patch
This commit is contained in:
Abubakar Abid 2024-02-20 18:36:10 -08:00 committed by GitHub
parent 254c7dc9c3
commit ddd4d3e4d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 150 additions and 77 deletions

View File

@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---
feat:Enable Ruff S101

View File

@ -0,0 +1,6 @@
---
"@gradio/dataframe": patch
"gradio": patch
---
fix:ensure Dataframe headers are aligned with content when scrolling

View File

@ -1183,7 +1183,8 @@ class Endpoint:
file_name = utils.decode_base64_to_file(x, dir=save_dir).name
elif isinstance(x, dict):
filepath = x.get("path")
assert filepath is not None, f"The 'path' field is missing in {x}"
if not filepath:
raise ValueError(f"The 'path' field is missing in {x}")
file_name = utils.download_file(
root_url + "file=" + filepath,
hf_token=hf_token,

View File

@ -187,9 +187,12 @@ def document_fn(fn: Callable, cls) -> tuple[str, list[dict], dict, str | None]:
if "args" in parameter_doc["doc"]:
parameter_doc["args"] = True
parameter_docs.append(parameter_doc)
assert (
len(parameters) == 0
), f"Documentation format for {fn.__name__} documents nonexistent parameters: {', '.join(parameters.keys())}. Valid parameters: {', '.join(signature.parameters.keys())}"
if parameters:
raise ValueError(
f"Documentation format for {fn.__name__} documents "
f"nonexistent parameters: {', '.join(parameters.keys())}. "
f"Valid parameters: {', '.join(signature.parameters.keys())}"
)
if len(returns) == 0:
return_docs = {}
elif len(returns) == 1:
@ -337,7 +340,7 @@ def generate_documentation():
inherited_fn["description"] = extract_instance_attr_doc(
cls, inherited_fn["name"]
)
except (ValueError, AssertionError):
except ValueError:
pass
documentation[mode][i]["fns"].append(inherited_fn)
return documentation

View File

@ -318,7 +318,10 @@ class FileSerializable(Serializable):
else:
file_name = utils.create_tmp_copy_of_file(filepath, dir=save_dir)
elif x.get("is_stream"):
assert x["name"] and root_url and save_dir
if not (x["name"] and root_url and save_dir):
raise ValueError(
"name and root_url and save_dir must all be present"
)
if not self.stream or self.stream_name != x["name"]:
self.stream = self._setup_stream(
root_url + "stream/" + x["name"], hf_token=hf_token

View File

@ -377,7 +377,8 @@ async def get_pred_from_sse_v0(
except asyncio.CancelledError:
pass
assert len(done) == 1
if len(done) != 1:
raise ValueError(f"Did not expect {len(done)} tasks to be done.")
for task in done:
return task.result()
@ -407,7 +408,8 @@ async def get_pred_from_sse_v1_v2(
except asyncio.CancelledError:
pass
assert len(done) == 1
if len(done) != 1:
raise ValueError(f"Did not expect {len(done)} tasks to be done.")
for task in done:
exception = task.exception()
if exception:

View File

@ -60,11 +60,11 @@ include = [
[tool.ruff]
extend = "../../pyproject.toml"
[tool.ruff.isort]
[tool.ruff.lint.isort]
known-first-party = [
"gradio_client"
]
[tool.pytest.ini_options]
GRADIO_ANALYTICS_ENABLED = "False"
HF_HUB_DISABLE_TELEMETRY = "1"
HF_HUB_DISABLE_TELEMETRY = "1"

View File

@ -1,6 +1,6 @@
pytest-asyncio
pytest==7.1.2
ruff==0.1.13
ruff==0.2.2
pyright==1.1.327
gradio
pydub==0.25.1

View File

@ -623,7 +623,8 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
def get_component(self, id: int) -> Component | BlockContext:
comp = self.blocks[id]
assert isinstance(comp, (components.Component, BlockContext)), f"{comp}"
if not isinstance(comp, (components.Component, BlockContext)):
raise TypeError(f"Block with id {id} is not a Component or BlockContext")
return comp
@property
@ -2379,7 +2380,8 @@ Received outputs:
continue
label = component["props"].get("label", f"parameter_{i}")
comp = self.get_component(component["id"])
assert isinstance(comp, components.Component)
if not isinstance(comp, components.Component):
raise TypeError(f"{comp!r} is not a Component")
info = component["api_info"]
example = comp.example_inputs()
python_type = client_utils.json_schema_to_python_type(info)
@ -2409,7 +2411,8 @@ Received outputs:
continue
label = component["props"].get("label", f"value_{o}")
comp = self.get_component(component["id"])
assert isinstance(comp, components.Component)
if not isinstance(comp, components.Component):
raise TypeError(f"{comp!r} is not a Component")
info = component["api_info"]
example = comp.example_inputs()
python_type = client_utils.json_schema_to_python_type(info)

View File

@ -199,7 +199,10 @@ class ChatInterface(Blocks):
textbox.container = False
textbox.show_label = False
textbox_ = textbox.render()
assert isinstance(textbox_, Textbox)
if not isinstance(textbox_, Textbox):
raise TypeError(
f"Expected a gr.Textbox, but got {type(textbox_)}"
)
self.textbox = textbox_
else:
self.textbox = Textbox(

View File

@ -323,6 +323,9 @@ def _create_backend(
"Please pass in a valid component name via the --template option. It must match the name of the python class."
)
if not module:
raise ValueError("Module not found")
readme_contents = textwrap.dedent(
"""
# {package_name}
@ -381,7 +384,6 @@ __all__ = ['{name}']
p = Path(inspect.getfile(gradio)).parent
python_file = backend / f"{name.lower()}.py"
assert module is not None
shutil.copy(
str(p / module / component.python_file_name),

View File

@ -247,7 +247,8 @@ def _publish(
except Exception:
latest_release = None
assert demo_dir
if not demo_dir:
raise ValueError("demo_dir must be set")
demo_path = resolve_demo(demo_dir)
if prefer_local or not latest_release:

View File

@ -78,7 +78,8 @@ def extract_class_source_code(
for node in ast.walk(class_ast):
if isinstance(node, ast.ClassDef) and node.name == class_name:
segment = ast.get_source_segment(code, node)
assert segment
if not segment:
raise ValueError("segment not found")
return segment, node.lineno
return None, None
@ -92,8 +93,9 @@ def create_or_modify_pyi(
current_impl, lineno = extract_class_source_code(source_code, class_name)
assert current_impl
assert lineno
if not (current_impl and lineno):
raise ValueError("Couldn't find class source code")
new_interface = create_pyi(current_impl, events)
pyi_file = source_file.with_suffix(".pyi")

View File

@ -194,7 +194,9 @@ class Audio(
if payload is None:
return payload
assert payload.path
if not payload.path:
raise ValueError("payload path missing")
# Need a unique name for the file to avoid re-using the same audio file if
# a user submits the same audio file twice
temp_file_path = Path(payload.path)

View File

@ -330,7 +330,8 @@ def component(cls_name: str, render: bool) -> Component:
obj = utils.component_or_layout_class(cls_name)(render=render)
if isinstance(obj, BlockContext):
raise ValueError(f"Invalid component: {obj.__class__}")
assert isinstance(obj, Component)
if not isinstance(obj, Component):
raise TypeError(f"Expected a Component instance, but got {obj.__class__}")
return obj
@ -363,5 +364,8 @@ def get_component_instance(
component_obj.render()
elif unrender and component_obj.is_rendered:
component_obj.unrender()
assert isinstance(component_obj, Component)
if not isinstance(component_obj, Component):
raise TypeError(
f"Expected a Component instance, but got {component_obj.__class__}"
)
return component_obj

View File

@ -76,9 +76,10 @@ class Dataset(Component):
]
# Narrow type to Component
assert all(
isinstance(c, Component) for c in self._components
), "All components in a `Dataset` must be subclasses of `Component`"
if not all(isinstance(c, Component) for c in self._components):
raise TypeError(
"All components in a `Dataset` must be subclasses of `Component`"
)
self._components = [c for c in self._components if isinstance(c, Component)]
self.proxy_url = proxy_url
for component in self._components:

View File

@ -152,7 +152,8 @@ class Dropdown(FormComponent):
if payload is None:
return None
elif self.multiselect:
assert isinstance(payload, list)
if not isinstance(payload, list):
raise TypeError("Multiselect dropdown payload must be a list")
return [
choice_values.index(choice) if choice in choice_values else None
for choice in payload

View File

@ -165,7 +165,8 @@ class Video(Component):
"""
if payload is None:
return None
assert payload.video.path
if not payload.video.path:
raise ValueError("Payload path missing")
file_name = Path(payload.video.path)
uploaded_format = file_name.suffix.replace(".", "")
needs_formatting = self.format is not None and uploaded_format != self.format
@ -257,7 +258,8 @@ class Video(Component):
else:
raise Exception(f"Cannot process type as video: {type(value)}")
assert processed_files[0]
if not processed_files[0]:
raise ValueError("Video data missing")
return VideoData(video=processed_files[0], subtitles=processed_files[1])
def _format_video(self, video: str | Path | None) -> FileData | None:

View File

@ -135,9 +135,9 @@ class LogMessage(BaseModel):
class GradioBaseModel(ABC):
def copy_to_dir(self, dir: str | pathlib.Path) -> GradioDataModel:
assert isinstance(self, (BaseModel, RootModel))
if isinstance(dir, str):
dir = pathlib.Path(dir)
if not isinstance(self, (BaseModel, RootModel)):
raise TypeError("must be used in a Pydantic model")
dir = pathlib.Path(dir)
# TODO: Making sure path is unique should be done in caller
def unique_copy(obj: dict):
@ -204,7 +204,8 @@ class FileData(GradioModel):
pathlib.Path(dir).mkdir(exist_ok=True)
new_obj = dict(self)
assert self.path
if not self.path:
raise ValueError("Source file path is not set")
new_name = shutil.copy(self.path, dir)
new_obj["path"] = new_name
return self.__class__(**new_obj)

View File

@ -431,7 +431,10 @@ def from_spaces_blocks(space: str, hf_token: str | None) -> Blocks:
# Use end_to_end_fn here to properly upload/download all files
predict_fns = []
for fn_index, endpoint in enumerate(client.endpoints):
assert isinstance(endpoint, Endpoint)
if not isinstance(endpoint, Endpoint):
raise TypeError(
f"Expected endpoint to be an Endpoint, but got {type(endpoint)}"
)
helper = client.new_helper(fn_index)
if endpoint.backend_fn:
predict_fns.append(endpoint.make_end_to_end_fn(helper))

View File

@ -194,8 +194,9 @@ def tabular_wrapper(client: InferenceClient, pipeline: str):
# automatically loaded when using the tabular_classification and tabular_regression methods.
# See: https://github.com/huggingface/huggingface_hub/issues/2015
def tabular_inner(data):
assert pipeline in ["tabular_classification", "tabular_regression"]
assert client.model is not None
if pipeline not in ("tabular_classification", "tabular_regression"):
raise TypeError(f"pipeline type {pipeline!r} not supported")
assert client.model # noqa: S101
if pipeline == "tabular_classification":
return client.tabular_classification(data, model=client.model)
else:

View File

@ -421,9 +421,11 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
# Add deserialized object to row
features[label] = {"dtype": "string", "_type": "Value"}
try:
assert Path(deserialized).exists()
row.append(str(Path(deserialized).relative_to(self.dataset_dir)))
except (AssertionError, TypeError, ValueError):
deserialized_path = Path(deserialized)
if not deserialized_path.exists():
raise FileNotFoundError(f"File {deserialized} not found")
row.append(str(deserialized_path.relative_to(self.dataset_dir)))
except (FileNotFoundError, TypeError, ValueError):
deserialized = "" if deserialized is None else str(deserialized)
row.append(deserialized)

View File

@ -346,7 +346,8 @@ class Examples:
batch=self.batch,
)
assert self.outputs is not None
if self.outputs is None:
raise ValueError("self.outputs is missing")
cache_logger.setup(self.outputs, self.cached_folder)
for example_id, _ in enumerate(self.examples):
print(f"Caching example {example_id + 1}/{len(self.examples)}")
@ -405,7 +406,8 @@ class Examples:
examples = list(csv.reader(cache))
example = examples[example_id + 1] # +1 to adjust for header
output = []
assert self.outputs is not None
if self.outputs is None:
raise ValueError("self.outputs is missing")
for component, value in zip(self.outputs, example):
value_to_use = value
try:
@ -417,9 +419,10 @@ class Examples:
component, components.File
):
value_to_use = value_as_dict
assert utils.is_update(value_as_dict)
if not utils.is_update(value_as_dict):
raise TypeError("value wasn't an update") # caught below
output.append(value_as_dict)
except (ValueError, TypeError, SyntaxError, AssertionError):
except (ValueError, TypeError, SyntaxError):
output.append(component.read_from_flag(value_to_use))
return output
@ -784,9 +787,7 @@ def special_args(
# Inject user token
elif type_hint in (Optional[oauth.OAuthToken], oauth.OAuthToken):
oauth_info = (
session["oauth_info"] if "oauth_info" in session else None
)
oauth_info = session.get("oauth_info", None)
oauth_token = (
oauth.OAuthToken(
token=oauth_info["access_token"],
@ -912,7 +913,8 @@ def make_waveform(
return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)]
def get_color_gradient(c1, c2, n):
assert n > 1
if n < 1:
raise ValueError("Must have at least one stop in gradient")
c1_rgb = np.array(hex_to_rgb(c1)) / 255
c2_rgb = np.array(hex_to_rgb(c2)) / 255
mix_pcts = [x / (n - 1) for x in range(n)]

View File

@ -179,8 +179,14 @@ class Interface(Blocks):
if additional_inputs is None:
additional_inputs = []
assert isinstance(inputs, (str, list, Component))
assert isinstance(outputs, (str, list, Component))
if not isinstance(inputs, (str, list, Component)):
raise TypeError(
f"inputs must be a string, list, or Component, not {inputs}"
)
if not isinstance(outputs, (str, list, Component)):
raise TypeError(
f"outputs must be a string, list, or Component, not {outputs}"
)
if not isinstance(inputs, list):
inputs = [inputs]
@ -279,7 +285,10 @@ class Interface(Blocks):
InterfaceTypes.OUTPUT_ONLY,
]:
for o in self.output_components:
assert isinstance(o, Component)
if not isinstance(o, Component):
raise TypeError(
f"Output component must be a Component, not {type(o)}"
)
if o.interactive is None:
# Unless explicitly otherwise specified, force output components to
# be non-interactive
@ -418,11 +427,17 @@ class Interface(Blocks):
except (TypeError, ValueError):
param_names = utils.default_input_labels()
for component, param_name in zip(self.input_components, param_names):
assert isinstance(component, Component)
if not isinstance(component, Component):
raise TypeError(
f"Input component must be a Component, not {type(component)}"
)
if component.label is None:
component.label = param_name
for i, component in enumerate(self.output_components):
assert isinstance(component, Component)
if not isinstance(component, Component):
raise TypeError(
f"Output component must be a Component, not {type(component)}"
)
if component.label is None:
if len(self.output_components) == 1:
component.label = "output"
@ -795,7 +810,10 @@ class Interface(Blocks):
flag_components = self.input_components + self.output_components
for flag_btn, (label, value) in zip(flag_btns, self.flagging_options):
assert isinstance(value, str)
if not isinstance(value, str):
raise TypeError(
f"Flagging option value must be a string, not {value!r}"
)
flag_method = FlagMethod(self.flagging_callback, label, value)
flag_btn.click(
lambda: Button(value="Saving...", interactive=False),

View File

@ -41,7 +41,8 @@ class JupyterReloader(BaseReloader):
@property
def running_app(self) -> App:
assert self.running_demo.server
if not self.running_demo.server:
raise RuntimeError("Server not running")
return self.running_demo.server.running_app # type: ignore
@property

View File

@ -264,8 +264,9 @@ def move_files_to_cache(
payload.path = payload.url
elif not block.proxy_url:
# If the file is on a remote server, do not move it to cache.
temp_file_path = move_resource_to_block_cache(payload.path, block)
assert temp_file_path is not None
temp_file_path = block.move_resource_to_block_cache(payload.path)
if temp_file_path is None:
raise ValueError("Did not determine a file path for the resource.")
payload.path = temp_file_path
if add_urls:

View File

@ -82,7 +82,8 @@ class RangedFileResponse(Response):
self.stat_result = stat_result
def set_range_headers(self, range: ClosedRange) -> None:
assert self.stat_result
if not self.stat_result:
raise ValueError("No stat result to set range headers with")
total_length = self.stat_result.st_size
content_length = len(range)
self.headers[

View File

@ -399,9 +399,6 @@ class GradioMultiPartParser:
upload_id: str | None = None,
upload_progress: FileUploadProgress | None = None,
) -> None:
assert (
multipart is not None
), "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.max_files = max_files
@ -538,11 +535,11 @@ class GradioMultiPartParser:
# (regular, non-async functions), that would block the event loop in
# the main thread.
for part, data in self._file_parts_to_write:
assert part.file # for type checkers
assert part.file # for type checkers # noqa: S101
await part.file.write(data)
part.file.sha.update(data) # type: ignore
for part in self._file_parts_to_finish:
assert part.file # for type checkers
assert part.file # for type checkers # noqa: S101
await part.file.seek(0)
self._file_parts_to_write.clear()
self._file_parts_to_finish.clear()

View File

@ -171,7 +171,7 @@ class App(FastAPI):
def build_proxy_request(self, url_path):
url = httpx.URL(url_path)
assert self.blocks
assert self.blocks # noqa: S101
# Don't proxy a URL unless it's a URL specifically loaded by the user using
# gr.load() to prevent SSRF or harvesting of HF tokens by malicious Spaces.
is_safe_url = any(
@ -801,7 +801,8 @@ class App(FastAPI):
files_to_copy = []
locations: list[str] = []
for temp_file in form.getlist("files"):
assert isinstance(temp_file, GradioUploadFile)
if not isinstance(temp_file, GradioUploadFile):
raise TypeError("File is not an instance of GradioUploadFile")
if temp_file.filename:
file_name = Path(temp_file.filename).name
name = client_utils.strip_invalid_filename_characters(file_name)

View File

@ -114,7 +114,7 @@ class Tunnel:
if time.time() - start_timestamp >= TUNNEL_TIMEOUT_SECONDS:
_raise_tunnel_error()
assert self.proc is not None
assert self.proc is not None # noqa: S101
if self.proc.stdout is None:
continue

View File

@ -95,7 +95,7 @@ class BaseReloader(ABC):
)
def swap_blocks(self, demo: Blocks):
assert self.running_app.blocks
assert self.running_app.blocks # noqa: S101
# Copy over the blocks to get new components and events but
# not a new queue
self.running_app.blocks._queue.block_fns = demo.fns

View File

@ -178,7 +178,7 @@ const response = await fetch(
const audio_file = await response.blob();
const app = await client("abidlabs/whisper");
const result = await client.predict("/predict", [audio_file]);
const result = await app.predict("/predict", [audio_file]);
```
## Using events

View File

@ -328,5 +328,6 @@
left: 0;
z-index: var(--layer-1);
box-shadow: var(--shadow-drop);
overflow: hidden;
}
</style>

View File

@ -105,6 +105,9 @@ exclude = [
]
[tool.ruff]
exclude = ["gradio/node/*.py", ".venv/*", "gradio/_frontend_code/*.py"]
[tool.ruff.lint]
extend-select = [
"ARG",
"B",
@ -114,6 +117,7 @@ extend-select = [
"I",
"N",
"PL",
"S101",
"SIM",
"UP",
"W",
@ -133,9 +137,8 @@ ignore = [
"UP006", # use `list` instead of `List` for type annotations (fails for 3.8)
"UP007", # use X | Y for type annotations (TODO: can be enabled once Pydantic plays nice with them)
]
exclude = ["gradio/node/*.py", ".venv/*", "gradio/_frontend_code/*.py"]
[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"demo/*" = [
"ARG",
"E402", # Demos may have imports not at the top
@ -155,9 +158,11 @@ exclude = ["gradio/node/*.py", ".venv/*", "gradio/_frontend_code/*.py"]
]
"client/python/test/*" = [
"ARG",
"S101", # tests may use assertions
]
"test/*" = [
"ARG",
"S101", # tests may use assertions
]
[tool.pytest.ini_options]

View File

@ -23,4 +23,4 @@ typing_extensions~=4.0
uvicorn>=0.14.0
typer[all]>=0.9,<1.0
tomlkit==0.12.0
ruff >= 0.1.7
ruff>=0.2.2

View File

@ -161,7 +161,7 @@ requests==2.28.1
# transformers
respx==0.19.2
# via -r requirements.in
ruff==0.1.13
ruff==0.2.2
# via -r requirements.in
rfc3986[idna2008]==1.5.0
# via httpx

View File

@ -276,13 +276,10 @@ class TestLoadInterface:
try:
if resp.status_code != 200:
warnings.warn("Request for speech recognition model failed!")
if (
assert (
"Could not complete request to HuggingFace API"
in resp.json()["error"]
):
pass
else:
raise AssertionError()
not in resp.json()["error"]
)
else:
assert resp.json()["data"] is not None
finally: