Fixes audio streaming issues (#5179)

* changes

* changes

* version bump

* version bump

* version

* add changeset

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
aliabid94 2023-08-11 10:16:30 -07:00 committed by GitHub
parent f440e7c3bb
commit 6fb92b48a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 34 additions and 15 deletions

View File

@@ -0,0 +1,5 @@
+---
+"gradio": patch
+---
+fix:Fixes audio streaming issues

View File

@@ -1344,7 +1344,11 @@ Received outputs:
 return output
 def handle_streaming_outputs(
-self, fn_index: int, data: list, session_hash: str | None, run: int | None
+self,
+fn_index: int,
+data: list,
+session_hash: str | None,
+run: int | None,
 ) -> list:
 if session_hash is None or run is None:
 return data
@@ -1422,17 +1426,20 @@ Received outputs:
 is_generating, iterator = None, None
 else:
 inputs = self.preprocess_data(fn_index, inputs, state)
-iterator = iterators.get(fn_index, None) if iterators else None
-was_generating = iterator is not None
+old_iterator = iterators.get(fn_index, None) if iterators else None
+was_generating = old_iterator is not None
 result = await self.call_function(
-fn_index, inputs, iterator, request, event_id, event_data
+fn_index, inputs, old_iterator, request, event_id, event_data
 )
 data = self.postprocess_data(fn_index, result["prediction"], state)
-if result["is_generating"] or was_generating:
-data = self.handle_streaming_outputs(
-fn_index, data, session_hash, id(iterator)
-)
 is_generating, iterator = result["is_generating"], result["iterator"]
+if is_generating or was_generating:
+data = self.handle_streaming_outputs(
+fn_index,
+data,
+session_hash=session_hash,
+run=id(old_iterator) if was_generating else id(iterator),
+)
 block_fn.total_runtime += result["duration"]
 block_fn.total_runs += 1

View File

@@ -343,11 +343,20 @@ class Audio(
 if isinstance(y, tuple):
 sample_rate, data = y
 file_path = self.audio_to_temp_file(
-data, sample_rate, dir=self.DEFAULT_TEMP_DIR, format=self.format
+data,
+sample_rate,
+format="mp3" if self.streaming else self.format,
 )
 self.temp_files.add(file_path)
 else:
-file_path = self.make_temp_copy_if_needed(y)
+if isinstance(y, Path):
+y = str(y)
+if self.streaming and not y.endswith(".mp3"):
+sample_rate, data = processing_utils.audio_from_file(y)
+file_path = self.audio_to_temp_file(data, sample_rate, format="mp3")
+self.temp_files.add(file_path)
+else:
+file_path = self.make_temp_copy_if_needed(y)
 return {"name": file_path, "data": None, "is_file": True}
 def stream_output(self, y):

View File

@@ -321,10 +321,8 @@ class IOComponent(Component):
 )
 return self.pil_to_temp_file(pil_image, dir, format="png")
-def audio_to_temp_file(
-self, data: np.ndarray, sample_rate: int, dir: str, format: str
-):
-temp_dir = Path(dir) / self.hash_bytes(data.tobytes())
+def audio_to_temp_file(self, data: np.ndarray, sample_rate: int, format: str):
+temp_dir = Path(self.DEFAULT_TEMP_DIR) / self.hash_bytes(data.tobytes())
 temp_dir.mkdir(exist_ok=True, parents=True)
 filename = str(temp_dir / f"audio.{format}")
 processing_utils.audio_to_file(sample_rate, data, filename, format=format)

View File

@@ -1 +1 @@
-3.40.0
+3.40.1