mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-18 12:50:30 +08:00
fixed naming
This commit is contained in:
parent
c65f9a599f
commit
f9034db75b
@ -28,7 +28,7 @@ from gradio.outputs import OutputComponent
|
||||
from gradio.outputs import State as o_State # type: ignore
|
||||
from gradio.outputs import get_output_instance
|
||||
from gradio.process_examples import load_from_cache, process_example
|
||||
from gradio.routes import PredictRequest
|
||||
from gradio.routes import PredictBody
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import flask
|
||||
@ -559,7 +559,7 @@ class Interface(Launchable):
|
||||
else:
|
||||
return predictions
|
||||
|
||||
def process_api(self, data: PredictRequest, username: str = None) -> Dict[str, Any]:
|
||||
def process_api(self, data: PredictBody, username: str = None) -> Dict[str, Any]:
|
||||
flag_index = None
|
||||
if data.example_id is not None:
|
||||
if self.cache_examples:
|
||||
|
@ -7,7 +7,7 @@ from typing import Dict, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from gradio.routes import QueuePushRequest
|
||||
from gradio.routes import QueuePushBody
|
||||
|
||||
|
||||
DB_FILE = "gradio_queue.db"
|
||||
@ -109,9 +109,9 @@ def pop() -> Tuple[int, str, Dict, str]:
|
||||
return result[0], result[1], json.loads(result[2]), result[3]
|
||||
|
||||
|
||||
def push(request: QueuePushRequest) -> Tuple[str, int]:
|
||||
action = request.action
|
||||
input_data = json.dumps({'data': request.data})
|
||||
def push(body: QueuePushBody) -> Tuple[str, int]:
|
||||
action = body.action
|
||||
input_data = json.dumps({'data': body.data})
|
||||
hash = generate_hash()
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
|
@ -66,7 +66,7 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
||||
###########
|
||||
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
class PredictBody(BaseModel):
|
||||
session_hash: Optional[str]
|
||||
example_id: Optional[int]
|
||||
data: List[Any]
|
||||
@ -80,19 +80,19 @@ class FlagData(BaseModel):
|
||||
flag_index: Optional[int]
|
||||
|
||||
|
||||
class FlagRequest(BaseModel):
|
||||
class FlagBody(BaseModel):
|
||||
data: FlagData
|
||||
|
||||
|
||||
class InterpretRequest(BaseModel):
|
||||
class InterpretBody(BaseModel):
|
||||
data: List[Any]
|
||||
|
||||
|
||||
class QueueStatusRequest(BaseModel):
|
||||
class QueueStatusBody(BaseModel):
|
||||
hash: str
|
||||
|
||||
|
||||
class QueuePushRequest(BaseModel):
|
||||
class QueuePushBody(BaseModel):
|
||||
action: str
|
||||
data: Any
|
||||
|
||||
@ -248,7 +248,7 @@ def api_docs(request: Request):
|
||||
|
||||
|
||||
@app.post("/api/predict/", dependencies=[Depends(login_check)])
|
||||
async def predict(body: PredictRequest, username: str = Depends(get_current_user)):
|
||||
async def predict(body: PredictBody, username: str = Depends(get_current_user)):
|
||||
if app.launchable.stateful:
|
||||
session_hash = body.session_hash
|
||||
state = app.state_holder.get(
|
||||
@ -271,7 +271,7 @@ async def predict(body: PredictRequest, username: str = Depends(get_current_user
|
||||
|
||||
|
||||
@app.post("/api/flag/", dependencies=[Depends(login_check)])
|
||||
async def flag(body: FlagRequest, username: str = Depends(get_current_user)):
|
||||
async def flag(body: FlagBody, username: str = Depends(get_current_user)):
|
||||
if app.launchable.analytics_enabled:
|
||||
await utils.log_feature_analytics(app.launchable.ip_address, "flag")
|
||||
await run_in_threadpool(
|
||||
@ -287,7 +287,7 @@ async def flag(body: FlagRequest, username: str = Depends(get_current_user)):
|
||||
|
||||
|
||||
@app.post("/api/interpret/", dependencies=[Depends(login_check)])
|
||||
async def interpret(body: InterpretRequest):
|
||||
async def interpret(body: InterpretBody):
|
||||
if app.launchable.analytics_enabled:
|
||||
await utils.log_feature_analytics(app.launchable.ip_address, "interpret")
|
||||
raw_input = body.data
|
||||
@ -301,13 +301,13 @@ async def interpret(body: InterpretRequest):
|
||||
|
||||
|
||||
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
|
||||
async def queue_push(body: QueuePushRequest):
|
||||
async def queue_push(body: QueuePushBody):
|
||||
job_hash, queue_position = queueing.push(body)
|
||||
return {"hash": job_hash, "queue_position": queue_position}
|
||||
|
||||
|
||||
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
|
||||
async def queue_status(body: QueueStatusRequest):
|
||||
async def queue_status(body: QueueStatusBody):
|
||||
status, data = queueing.get_status(body.hash)
|
||||
return {"status": status, "data": data}
|
||||
|
||||
|
@ -4,7 +4,7 @@ import os
|
||||
import unittest
|
||||
|
||||
from gradio import queueing
|
||||
from gradio.routes import QueuePushRequest
|
||||
from gradio.routes import QueuePushBody
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -31,10 +31,10 @@ class TestQueuingActions(unittest.TestCase):
|
||||
queueing.close()
|
||||
|
||||
def test_push_pop_status(self):
|
||||
request = QueuePushRequest(data="test1", action="predict")
|
||||
request = QueuePushBody(data="test1", action="predict")
|
||||
hash1, position = queueing.push(request)
|
||||
self.assertEquals(position, 0)
|
||||
request = QueuePushRequest(data="test2", action="predict")
|
||||
request = QueuePushBody(data="test2", action="predict")
|
||||
hash2, position = queueing.push(request)
|
||||
self.assertEquals(position, 1)
|
||||
status, position = queueing.get_status(hash2)
|
||||
@ -46,7 +46,7 @@ class TestQueuingActions(unittest.TestCase):
|
||||
self.assertEquals(action, "predict")
|
||||
|
||||
def test_jobs(self):
|
||||
request = QueuePushRequest(data="test1", action="predict")
|
||||
request = QueuePushBody(data="test1", action="predict")
|
||||
hash1, _ = queueing.push(request)
|
||||
hash2, position = queueing.push(request)
|
||||
self.assertEquals(position, 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user