from collections import namedtuple from datetime import datetime, timedelta from enum import Enum from pathlib import Path from typing import ClassVar, Dict, List, Literal, Optional, Set, Tuple, Union from uuid import UUID import pytest from gradio_client.utils import json_schema_to_python_type from pydantic import Field, confloat, conint, conlist from pydantic.networks import AnyUrl, EmailStr, IPvAnyAddress from gradio.data_classes import GradioModel, GradioRootModel class StringModel(GradioModel): data: str answer: ClassVar = "Dict(data: str)" class IntegerRootModel(GradioRootModel): root: int answer: ClassVar = "int" class FloatModel(GradioModel): data: float answer: ClassVar = "Dict(data: float)" class ListModel(GradioModel): items: List[int] answer: ClassVar = "Dict(items: List[int])" class DictModel(GradioModel): data_dict: Dict[str, int] answer: ClassVar = "Dict(data_dict: Dict(str, int))" class DictModel2(GradioModel): data_dict: Dict[str, List[float]] answer: ClassVar = "Dict(data_dict: Dict(str, List[float]))" class OptionalModel(GradioModel): optional_data: Optional[int] answer: ClassVar = "Dict(optional_data: int | None)" class ColorEnum(Enum): RED = "red" GREEN = "green" BLUE = "blue" class EnumRootModel(GradioModel): color: ColorEnum answer: ClassVar = "Dict(color: Literal[red, green, blue])" class EmailModel(GradioModel): email: EmailStr answer: ClassVar = "Dict(email: str)" class RootWithNestedModel(GradioModel): nested_int: IntegerRootModel nested_enum: EnumRootModel nested_dict: DictModel2 answer: ClassVar = "Dict(nested_int: int, nested_enum: Dict(color: Literal[red, green, blue]), nested_dict: Dict(data_dict: Dict(str, List[float])))" class LessNestedModel(GradioModel): nested_int: int nested_enum: ColorEnum nested_dict: Dict[str, List[Union[int, float]]] answer: ClassVar = "Dict(nested_int: int, nested_enum: Literal[red, green, blue], nested_dict: Dict(str, List[int | float]))" class StatusModel(GradioModel): status: Literal["active", "inactive"] answer: ClassVar = "Dict(status: Literal[active, inactive])" class PointModel(GradioRootModel): root: Tuple[float, float] answer: ClassVar = "Tuple[float, float]" class UuidModel(GradioModel): uuid: UUID answer: ClassVar = "Dict(uuid: str)" class UrlModel(GradioModel): url: AnyUrl answer: ClassVar = "Dict(url: str)" class CustomFieldModel(GradioModel): name: str = Field(..., title="Name of the item", max_length=50) price: float = Field(..., title="Price of the item", gt=0) answer: ClassVar = "Dict(name: str, price: float)" class DurationModel(GradioModel): duration: timedelta answer: ClassVar = "Dict(duration: str)" class IPv4Model(GradioModel): ipv4_address: IPvAnyAddress answer: ClassVar = "Dict(ipv4_address: str)" class DateTimeModel(GradioModel): created_at: datetime updated_at: datetime answer: ClassVar = "Dict(created_at: str, updated_at: str)" class SetModel(GradioModel): unique_numbers: Set[int] answer: ClassVar = "Dict(unique_numbers: List[int])" class ItemModel(GradioModel): name: str price: float class OrderModel(GradioModel): items: List[ItemModel] answer: ClassVar = "Dict(items: List[Dict(name: str, price: float)])" class TemperatureUnitEnum(Enum): CELSIUS = "Celsius" FAHRENHEIT = "Fahrenheit" KELVIN = "Kelvin" class TemperatureConversionModel(GradioModel): temperature: confloat(ge=-273.15, le=1.416808) from_unit: TemperatureUnitEnum to_unit: TemperatureUnitEnum = Field(..., title="Target temperature unit") answer: ClassVar = "Dict(temperature: float, from_unit: Literal[Celsius, Fahrenheit, Kelvin], to_unit: All[Literal[Celsius, Fahrenheit, Kelvin]])" class CartItemModel(GradioModel): product_name: str = Field(..., title="Name of the product", max_length=50) quantity: int = Field(..., title="Quantity of the product", ge=1) price_per_unit: float = Field(..., title="Price per unit", gt=0) class ShoppingCartModel(GradioModel): items: List[CartItemModel] answer: ClassVar = "Dict(items: List[Dict(product_name: str, quantity: int, price_per_unit: float)])" class CoordinateModel(GradioModel): latitude: float longitude: float class PathModel(GradioModel): coordinates: conlist(CoordinateModel, min_length=2, max_length=2) answer: ClassVar = ( "Dict(coordinates: List[Dict(latitude: float, longitude: float)])" ) class CreditCardModel(GradioModel): card_number: conint(ge=1, le=9999999999999999) answer: ClassVar = "Dict(card_number: int)" class TupleListModel(GradioModel): data: List[Tuple[int, str]] answer: ClassVar = "Dict(data: List[Tuple[int, str]]" class PathListModel(GradioModel): file_paths: List[Path] answer: ClassVar = "Dict(file_paths: List[str])" class PostModel(GradioModel): author: str content: str tags: List[str] likes: int = 0 answer: ClassVar = "Dict(author: str, content: str, tags: List[str], likes: int)" Person = namedtuple("Person", ["name", "age"]) class NamedTupleDictionaryModel(GradioModel): people: Dict[str, Person] answer: ClassVar = "Dict(people: Dict(str, Tuple[Any, Any]))" MODELS = [ StringModel, IntegerRootModel, FloatModel, ListModel, DictModel, DictModel2, OptionalModel, EnumRootModel, EmailModel, RootWithNestedModel, LessNestedModel, StatusModel, PointModel, UuidModel, UrlModel, CustomFieldModel, DurationModel, IPv4Model, DateTimeModel, SetModel, OrderModel, TemperatureConversionModel, ShoppingCartModel, PathModel, CreditCardModel, PathListModel, NamedTupleDictionaryModel, ] @pytest.mark.parametrize("model", MODELS) def test_api_info_for_model(model): assert json_schema_to_python_type(model.model_json_schema()) == model.answer