fixed naming

This commit is contained in:
Abubakar Abid 2022-03-07 13:01:43 -06:00
parent c65f9a599f
commit f9034db75b
4 changed files with 20 additions and 20 deletions

View File

@ -28,7 +28,7 @@ from gradio.outputs import OutputComponent
from gradio.outputs import State as o_State # type: ignore
from gradio.outputs import get_output_instance
from gradio.process_examples import load_from_cache, process_example
from gradio.routes import PredictRequest
from gradio.routes import PredictBody
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import flask
@ -559,7 +559,7 @@ class Interface(Launchable):
else:
return predictions
def process_api(self, data: PredictRequest, username: str = None) -> Dict[str, Any]:
def process_api(self, data: PredictBody, username: str = None) -> Dict[str, Any]:
flag_index = None
if data.example_id is not None:
if self.cache_examples:

View File

@ -7,7 +7,7 @@ from typing import Dict, Tuple
import requests
from gradio.routes import QueuePushRequest
from gradio.routes import QueuePushBody
DB_FILE = "gradio_queue.db"
@ -109,9 +109,9 @@ def pop() -> Tuple[int, str, Dict, str]:
return result[0], result[1], json.loads(result[2]), result[3]
def push(request: QueuePushRequest) -> Tuple[str, int]:
action = request.action
input_data = json.dumps({'data': request.data})
def push(body: QueuePushBody) -> Tuple[str, int]:
action = body.action
input_data = json.dumps({'data': body.data})
hash = generate_hash()
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()

View File

@ -66,7 +66,7 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
###########
class PredictRequest(BaseModel):
class PredictBody(BaseModel):
session_hash: Optional[str]
example_id: Optional[int]
data: List[Any]
@ -80,19 +80,19 @@ class FlagData(BaseModel):
flag_index: Optional[int]
class FlagRequest(BaseModel):
class FlagBody(BaseModel):
data: FlagData
class InterpretRequest(BaseModel):
class InterpretBody(BaseModel):
data: List[Any]
class QueueStatusRequest(BaseModel):
class QueueStatusBody(BaseModel):
hash: str
class QueuePushRequest(BaseModel):
class QueuePushBody(BaseModel):
action: str
data: Any
@ -248,7 +248,7 @@ def api_docs(request: Request):
@app.post("/api/predict/", dependencies=[Depends(login_check)])
async def predict(body: PredictRequest, username: str = Depends(get_current_user)):
async def predict(body: PredictBody, username: str = Depends(get_current_user)):
if app.launchable.stateful:
session_hash = body.session_hash
state = app.state_holder.get(
@ -271,7 +271,7 @@ async def predict(body: PredictRequest, username: str = Depends(get_current_user
@app.post("/api/flag/", dependencies=[Depends(login_check)])
async def flag(body: FlagRequest, username: str = Depends(get_current_user)):
async def flag(body: FlagBody, username: str = Depends(get_current_user)):
if app.launchable.analytics_enabled:
await utils.log_feature_analytics(app.launchable.ip_address, "flag")
await run_in_threadpool(
@ -287,7 +287,7 @@ async def flag(body: FlagRequest, username: str = Depends(get_current_user)):
@app.post("/api/interpret/", dependencies=[Depends(login_check)])
async def interpret(body: InterpretRequest):
async def interpret(body: InterpretBody):
if app.launchable.analytics_enabled:
await utils.log_feature_analytics(app.launchable.ip_address, "interpret")
raw_input = body.data
@ -301,13 +301,13 @@ async def interpret(body: InterpretRequest):
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
async def queue_push(body: QueuePushRequest):
async def queue_push(body: QueuePushBody):
job_hash, queue_position = queueing.push(body)
return {"hash": job_hash, "queue_position": queue_position}
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
async def queue_status(body: QueueStatusRequest):
async def queue_status(body: QueueStatusBody):
status, data = queueing.get_status(body.hash)
return {"status": status, "data": data}

View File

@ -4,7 +4,7 @@ import os
import unittest
from gradio import queueing
from gradio.routes import QueuePushRequest
from gradio.routes import QueuePushBody
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -31,10 +31,10 @@ class TestQueuingActions(unittest.TestCase):
queueing.close()
def test_push_pop_status(self):
request = QueuePushRequest(data="test1", action="predict")
request = QueuePushBody(data="test1", action="predict")
hash1, position = queueing.push(request)
self.assertEquals(position, 0)
request = QueuePushRequest(data="test2", action="predict")
request = QueuePushBody(data="test2", action="predict")
hash2, position = queueing.push(request)
self.assertEquals(position, 1)
status, position = queueing.get_status(hash2)
@ -46,7 +46,7 @@ class TestQueuingActions(unittest.TestCase):
self.assertEquals(action, "predict")
def test_jobs(self):
request = QueuePushRequest(data="test1", action="predict")
request = QueuePushBody(data="test1", action="predict")
hash1, _ = queueing.push(request)
hash2, position = queueing.push(request)
self.assertEquals(position, 1)