mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-18 10:44:33 +08:00
Fixes audio streaming issues (#5179)
* changes * changes * version bump * version bump * version * add changeset --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co> Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
f440e7c3bb
commit
6fb92b48a9
5
.changeset/fair-shirts-show.md
Normal file
5
.changeset/fair-shirts-show.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Fixes audio streaming issues
|
@ -1344,7 +1344,11 @@ Received outputs:
|
||||
return output
|
||||
|
||||
def handle_streaming_outputs(
|
||||
self, fn_index: int, data: list, session_hash: str | None, run: int | None
|
||||
self,
|
||||
fn_index: int,
|
||||
data: list,
|
||||
session_hash: str | None,
|
||||
run: int | None,
|
||||
) -> list:
|
||||
if session_hash is None or run is None:
|
||||
return data
|
||||
@ -1422,17 +1426,20 @@ Received outputs:
|
||||
is_generating, iterator = None, None
|
||||
else:
|
||||
inputs = self.preprocess_data(fn_index, inputs, state)
|
||||
iterator = iterators.get(fn_index, None) if iterators else None
|
||||
was_generating = iterator is not None
|
||||
old_iterator = iterators.get(fn_index, None) if iterators else None
|
||||
was_generating = old_iterator is not None
|
||||
result = await self.call_function(
|
||||
fn_index, inputs, iterator, request, event_id, event_data
|
||||
fn_index, inputs, old_iterator, request, event_id, event_data
|
||||
)
|
||||
data = self.postprocess_data(fn_index, result["prediction"], state)
|
||||
if result["is_generating"] or was_generating:
|
||||
data = self.handle_streaming_outputs(
|
||||
fn_index, data, session_hash, id(iterator)
|
||||
)
|
||||
is_generating, iterator = result["is_generating"], result["iterator"]
|
||||
if is_generating or was_generating:
|
||||
data = self.handle_streaming_outputs(
|
||||
fn_index,
|
||||
data,
|
||||
session_hash=session_hash,
|
||||
run=id(old_iterator) if was_generating else id(iterator),
|
||||
)
|
||||
|
||||
block_fn.total_runtime += result["duration"]
|
||||
block_fn.total_runs += 1
|
||||
|
@ -343,11 +343,20 @@ class Audio(
|
||||
if isinstance(y, tuple):
|
||||
sample_rate, data = y
|
||||
file_path = self.audio_to_temp_file(
|
||||
data, sample_rate, dir=self.DEFAULT_TEMP_DIR, format=self.format
|
||||
data,
|
||||
sample_rate,
|
||||
format="mp3" if self.streaming else self.format,
|
||||
)
|
||||
self.temp_files.add(file_path)
|
||||
else:
|
||||
file_path = self.make_temp_copy_if_needed(y)
|
||||
if isinstance(y, Path):
|
||||
y = str(y)
|
||||
if self.streaming and not y.endswith(".mp3"):
|
||||
sample_rate, data = processing_utils.audio_from_file(y)
|
||||
file_path = self.audio_to_temp_file(data, sample_rate, format="mp3")
|
||||
self.temp_files.add(file_path)
|
||||
else:
|
||||
file_path = self.make_temp_copy_if_needed(y)
|
||||
return {"name": file_path, "data": None, "is_file": True}
|
||||
|
||||
def stream_output(self, y):
|
||||
|
@ -321,10 +321,8 @@ class IOComponent(Component):
|
||||
)
|
||||
return self.pil_to_temp_file(pil_image, dir, format="png")
|
||||
|
||||
def audio_to_temp_file(
|
||||
self, data: np.ndarray, sample_rate: int, dir: str, format: str
|
||||
):
|
||||
temp_dir = Path(dir) / self.hash_bytes(data.tobytes())
|
||||
def audio_to_temp_file(self, data: np.ndarray, sample_rate: int, format: str):
|
||||
temp_dir = Path(self.DEFAULT_TEMP_DIR) / self.hash_bytes(data.tobytes())
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
filename = str(temp_dir / f"audio.{format}")
|
||||
processing_utils.audio_to_file(sample_rate, data, filename, format=format)
|
||||
|
@ -1 +1 @@
|
||||
3.40.0
|
||||
3.40.1
|
Loading…
Reference in New Issue
Block a user