diff --git a/halogen/appel.py b/halogen/appel.py new file mode 100644 index 0000000..34f4d3e --- /dev/null +++ b/halogen/appel.py @@ -0,0 +1,3 @@ +from .test import PaginationResponse + +x = PaginationResponse.serialize({"test"}) diff --git a/halogen/schema.py b/halogen/schema.py index fe76c38..0288952 100644 --- a/halogen/schema.py +++ b/halogen/schema.py @@ -2,7 +2,18 @@ import inspect from collections import OrderedDict, namedtuple -from typing import Iterable, Optional, Union +from typing import ( + Any, + Iterable, + Optional, + Union, + TypeVar, + Generic, + overload, + TYPE_CHECKING, + TypedDict, + List as TypingList, +) from cached_property import cached_property @@ -59,7 +70,9 @@ def get(self, obj, **kwargs): if callable(self.getter): return self.getter(obj, **_get_context(self._getter_argspec, kwargs)) - assert isinstance(self.getter, str), "Accessor must be a function or a dot-separated string." + assert isinstance(self.getter, str), ( + "Accessor must be a function or a dot-separated string." + ) if obj is None: return None @@ -89,7 +102,9 @@ def set(self, obj, value): if callable(self.setter): return self.setter(obj, value) - assert isinstance(self.setter, str), "Accessor must be a function or a dot-separated string." + assert isinstance(self.setter, str), ( + "Accessor must be a function or a dot-separated string." + ) def _set(obj, attr, value): if isinstance(obj, dict): @@ -106,15 +121,57 @@ def _set(obj, attr, value): def __repr__(self): """Accessor representation.""" - return "<{0} getter='{1}', setter='{2}'>".format(self.__class__.__name__, self.getter, self.setter) + return "<{0} getter='{1}', setter='{2}'>".format( + self.__class__.__name__, self.getter, self.setter + ) -class Attr(object): +T = TypeVar("T") + + +class Attr(Generic[T]): """Schema attribute.""" creation_counter = 0 - def __init__(self, attr_type=None, attr=None, required: bool = True, exclude: Optional[Iterable] = None, **kwargs): + @overload + def __init__( + self, + attr_type: "types.Type[T]" = ..., + attr=None, + required: bool = True, + exclude: Optional[Iterable] = None, + **kwargs, + ) -> None: ... + + @overload + def __init__( + self, + attr_type: type["_Schema"] = ..., + attr=None, + required: bool = True, + exclude: Optional[Iterable] = None, + **kwargs, + ) -> None: ... + + @overload + def __init__( + self, + attr_type: object = ..., + attr=None, + required: bool = True, + exclude: Optional[Iterable] = None, + **kwargs, + ) -> None: ... + + def __init__( + self, + attr_type=None, + attr=None, + required: bool = True, + exclude: Optional[Iterable] = None, + **kwargs, + ): """Attribute constructor. :param attr_type: Type, Schema or constant that does the type conversion of the attribute. @@ -132,6 +189,20 @@ def __init__(self, attr_type=None, attr=None, required: bool = True, exclude: Op self.creation_counter = Attr.creation_counter Attr.creation_counter += 1 + @overload + def __get__(self, obj: None, objtype: Optional[type] = None) -> "Attr[T]": ... + + @overload + def __get__(self, obj: object, objtype: Optional[type] = None) -> "Attr[T]": ... + + def __get__(self, obj: Optional[object], objtype: Optional[type] = None): + # Attr instances are removed from Schema classes at runtime, so this is a typing-only aid. + return self + + def __set__(self, obj: object, value: T) -> None: + # Typing-only descriptor hook. + raise AttributeError("Attr descriptors are not set on instances.") + @property def compartment(self): """The key of the compartment this attribute will be placed into (for example: _links or _embedded).""" @@ -184,8 +255,12 @@ def serialize(self, value, **kwargs): raise value = self._default() - value = self.attr_type.serialize(value, **_get_context(self._attr_type_serialize_argspec, kwargs)) - value = self._default() if value is None and hasattr(self, "default") else value + value = self.attr_type.serialize( + value, **_get_context(self._attr_type_serialize_argspec, kwargs) + ) + value = ( + self._default() if value is None and hasattr(self, "default") else value + ) if value in self.exclude: raise ExcludedValueException() return value @@ -282,7 +357,6 @@ def __init__( the target resource. """ if not types.Type.is_type(attr_type): - if attr_type is not None: attr = BYPASS @@ -342,7 +416,9 @@ def __init__(self, attr_type=None, attr=None, required=True, curie=None): :param required: Is this list of links required to be present. :param curie: Link namespace prefix (e.g. ":") or Curie object. """ - super(LinkList, self).__init__(attr_type=attr_type, attr=attr, required=required, curie=curie) + super(LinkList, self).__init__( + attr_type=attr_type, attr=attr, required=required, curie=curie + ) self.attr_type = types.List(self.attr_type) @@ -370,14 +446,22 @@ def __init__(self, name, href, templated=None, type=None): class Embedded(Attr): """Embedded attribute of schema.""" - def __init__(self, attr_type: Union["halogen.Schema", "halogen.types.List"], attr=None, curie=None, required=True): + def __init__( + self, + attr_type: Union["halogen.Schema", "halogen.types.List"], + attr=None, + curie=None, + required=True, + ): """Embedded constructor. :param attr_type: Type, Schema or constant that does the type conversion of the attribute. :param attr: Attribute name, dot-separated attribute path or an `Accessor` instance. :param curie: The curie used for this embedded attribute. """ - super(Embedded, self).__init__(attr_type=attr_type, attr=attr, required=required) + super(Embedded, self).__init__( + attr_type=attr_type, attr=attr, required=required + ) self.curie = curie self.validate() @@ -407,7 +491,9 @@ def validate(self): # Validate self link class_attributes = attribute_type.__dict__.get("__attrs__") if class_attributes is not None and "self" not in class_attributes.keys(): - raise InvalidSchemaDefinition("Invalid HAL standard definition, need `self` link") + raise InvalidSchemaDefinition( + "Invalid HAL standard definition, need `self` link" + ) class _Schema(types.Type): @@ -469,7 +555,9 @@ def deserialize(cls, value, output=None, **kwargs): errors.append(e) except (KeyError, AttributeError): if attr.required: - errors.append(exceptions.ValidationError("Missing attribute.", attr.name)) + errors.append( + exceptions.ValidationError("Missing attribute.", attr.name) + ) if errors: raise exceptions.ValidationError(errors) @@ -489,7 +577,9 @@ def __init__(cls, name, bases, clsattrs): cls.__class_attrs__ = OrderedDict() curies = set([]) - attrs = [(key, value) for key, value in clsattrs.items() if isinstance(value, Attr)] + attrs = [ + (key, value) for key, value in clsattrs.items() if isinstance(value, Attr) + ] attrs.sort(key=lambda attr: attr[1].creation_counter) # Collect the attributes and set their names. @@ -508,7 +598,12 @@ def __init__(cls, name, bases, clsattrs): if curies: link = LinkList( - Schema(href=Attr(), name=Attr(), templated=Attr(required=False), type=Attr(required=False)), + Schema( + href=Attr(), + name=Attr(), + templated=Attr(required=False), + type=Attr(required=False), + ), attr=lambda value: list(curies), required=False, ) @@ -520,6 +615,108 @@ def __init__(cls, name, bases, clsattrs): for base in reversed(cls.__mro__): cls.__attrs__.update(getattr(base, "__class_attrs__", OrderedDict())) + cls.__output_type__ = _build_schema_output_type(cls) + + +def _build_schema_output_type(schema_cls): + fields = {} + optional_fields = set() + compartment_fields = {} + + for attr in schema_cls.__attrs__.values(): + target = fields + if attr.compartment is not None: + target = compartment_fields.setdefault(attr.compartment, {}) + + key = attr.key + target[key] = _python_type_for_attr_type(attr.attr_type) + if not attr.required: + optional_fields.add((attr.compartment, key)) + + for compartment, comp_fields in compartment_fields.items(): + fields[compartment] = _build_typed_dict_for_fields( + f"{schema_cls.__name__}{compartment.title().replace('_', '')}", + comp_fields, + {k for (c, k) in optional_fields if c == compartment}, + ) + + return _build_typed_dict_for_fields( + f"{schema_cls.__name__}Serialized", + fields, + {k for (c, k) in optional_fields if c is None}, + ) + + +def _build_typed_dict_for_fields(name, fields, optional_keys): + try: + from typing import NotRequired + except ImportError: + NotRequired = None + + if NotRequired is not None and optional_keys: + annotations = {} + for key, value in fields.items(): + if key in optional_keys: + annotations[key] = NotRequired[value] + else: + annotations[key] = value + return TypedDict(name, annotations, total=True) + + total = not optional_keys + return TypedDict(name, dict(fields), total=total) + + +def _python_type_for_attr_type(attr_type): + if isinstance(attr_type, halogen.schema._SchemaType): + return getattr(attr_type, "__output_type__", dict) + + if isinstance(attr_type, halogen.types.List): + return TypingList[_python_type_for_attr_type(attr_type.item_type)] + + if isinstance(attr_type, halogen.types.Nullable): + return Optional[_python_type_for_attr_type(attr_type.nested_type)] + + if isinstance(attr_type, halogen.types.String): + return str + + if isinstance(attr_type, halogen.types.Int): + return int + + if isinstance(attr_type, halogen.types.Boolean): + return bool + + if isinstance( + attr_type, + ( + halogen.types.ISODateTime, + halogen.types.ISOUTCDateTime, + halogen.types.ISOUTCDate, + ), + ): + return str + + if isinstance(attr_type, halogen.types.Amount): + return halogen.types.AmountSerialized + + if isinstance(attr_type, halogen.types.Enum): + return str if not attr_type.use_values else object + + if isinstance(attr_type, halogen.types.Type): + return Any + + return type(attr_type) + + +if TYPE_CHECKING: + + class Schema(_Schema, metaclass=_SchemaType): + """Typing-only schema base class.""" + + @classmethod + def serialize(cls, value, **kwargs): ... -Schema = _SchemaType("Schema", (_Schema,), {"__doc__": _Schema.__doc__}) -"""Schema is the basic class used for setting up schemas.""" + @classmethod + def deserialize(cls, value, output=None, **kwargs): ... +else: + Schema = _SchemaType("Schema", (_Schema,), {"__doc__": _Schema.__doc__}) + """Schema is the basic class used for setting up schemas.""" diff --git a/halogen/stubgen.py b/halogen/stubgen.py new file mode 100644 index 0000000..c195248 --- /dev/null +++ b/halogen/stubgen.py @@ -0,0 +1,155 @@ +"""Generate .pyi stubs for Schema subclasses in a module.""" + +from __future__ import annotations + +import argparse +import importlib +import inspect +from pathlib import Path +from typing import Any, get_args, get_origin + +import halogen + + +def _is_typed_dict(obj: Any) -> bool: + return ( + isinstance(obj, type) + and issubclass(obj, dict) + and hasattr(obj, "__annotations__") + and hasattr(obj, "__total__") + ) + + +def _type_to_str(tp: Any) -> str: + if isinstance(tp, str): + return tp + + if _is_typed_dict(tp): + return tp.__name__ + + origin = get_origin(tp) + if origin is None: + if tp is None or tp is type(None): + return "None" + if isinstance(tp, type): + return tp.__name__ + return repr(tp) + + args = get_args(tp) + + if origin is list: + return f"list[{_type_to_str(args[0])}]" + if origin is dict: + return f"dict[{_type_to_str(args[0])}, {_type_to_str(args[1])}]" + if origin is tuple: + return f"tuple[{', '.join(_type_to_str(a) for a in args)}]" + + if str(origin).endswith("typing.Optional"): + return f"Optional[{_type_to_str(args[0])}]" + + if str(origin).endswith("typing.NotRequired"): + return f"NotRequired[{_type_to_str(args[0])}]" + + if str(origin).endswith("typing.Required"): + return f"Required[{_type_to_str(args[0])}]" + + return f"{origin}[{', '.join(_type_to_str(a) for a in args)}]" + + +def _collect_typed_dicts(tp: Any, acc: dict[str, type]) -> None: + if _is_typed_dict(tp): + if tp.__name__ in acc: + return + acc[tp.__name__] = tp + for value in tp.__annotations__.values(): + _collect_typed_dicts(value, acc) + return + + origin = get_origin(tp) + if origin is not None: + for arg in get_args(tp): + _collect_typed_dicts(arg, acc) + + +def _emit_typed_dict(td: type) -> str: + lines = [f"class {td.__name__}(TypedDict, total={td.__total__}):"] + if not td.__annotations__: + lines.append(" pass") + return "\n".join(lines) + + for name, tp in td.__annotations__.items(): + lines.append(f" {name}: {_type_to_str(tp)}") + return "\n".join(lines) + + +def generate(module_name: str, output_path: str | None = None) -> Path: + module = importlib.import_module(module_name) + + schema_classes = [] + for _, obj in inspect.getmembers(module, inspect.isclass): + if obj.__module__ != module.__name__: + continue + if issubclass(obj, halogen.Schema): + schema_classes.append(obj) + + if not schema_classes: + raise SystemExit(f"No Schema subclasses found in {module_name}.") + + if output_path is None: + output_path = str(Path(module.__file__).with_suffix(".pyi")) + + typed_dicts: dict[str, type] = {} + for schema_cls in schema_classes: + td = getattr(schema_cls, "__output_type__", None) + if td is not None: + _collect_typed_dicts(td, typed_dicts) + + lines = [ + "from __future__ import annotations", + "from typing import TypedDict, Optional, NotRequired, Required", + "import halogen", + "", + ] + + for td in typed_dicts.values(): + lines.append(_emit_typed_dict(td)) + lines.append("") + + for schema_cls in schema_classes: + lines.append(f"class {schema_cls.__name__}(halogen.Schema):") + td = getattr(schema_cls, "__output_type__", None) + if td is None: + lines.append(" ...") + lines.append("") + continue + + lines.append(" @classmethod") + lines.append( + f" def serialize(cls, value, **kwargs) -> {td.__name__}: ..." + ) + lines.append(" @classmethod") + lines.append( + f" def deserialize(cls, value, output=None, **kwargs) -> dict: ..." + ) + lines.append("") + + content = "\n".join(lines).rstrip() + "\n" + path = Path(output_path) + path.write_text(content, encoding="utf-8") + return path + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate .pyi stubs for Schema subclasses in a module." + ) + parser.add_argument("module", help="Module path containing Schema subclasses") + parser.add_argument("--out", help="Output .pyi path (default: module path)") + args = parser.parse_args() + + path = generate(args.module, args.out) + print(path) + + +if __name__ == "__main__": + main() diff --git a/halogen/test.py b/halogen/test.py new file mode 100644 index 0000000..b881ebe --- /dev/null +++ b/halogen/test.py @@ -0,0 +1,12 @@ +from typing import final +from halogen import Attr, Schema +from halogen.types import Int, List, Nullable + + +@final +class PaginationResponse(Schema): + """Pagination details.""" + + total = Attr(Int()) + test = Attr(Nullable(Int())) + x = Attr(List(Int())) diff --git a/halogen/test.pyi b/halogen/test.pyi new file mode 100644 index 0000000..19a801e --- /dev/null +++ b/halogen/test.pyi @@ -0,0 +1,14 @@ +from __future__ import annotations +from typing import TypedDict, Optional, NotRequired, Required +import halogen + +class PaginationResponseSerialized(TypedDict, total=True): + total: int + test: typing.Union[int, None] + x: list[int] + +class PaginationResponse(halogen.Schema): + @classmethod + def serialize(cls, value, **kwargs) -> PaginationResponseSerialized: ... + @classmethod + def deserialize(cls, value, output=None, **kwargs) -> dict: ... diff --git a/halogen/types.py b/halogen/types.py index ac9142e..c7db134 100644 --- a/halogen/types.py +++ b/halogen/types.py @@ -4,7 +4,7 @@ import decimal import enum import typing -from typing import Union, Optional, Any +from typing import Union, Optional, Any, TypeVar, Generic, Protocol, TypedDict import dateutil.parser import isodate @@ -16,7 +16,10 @@ from .schema import _Schema -class Type(object): +T = TypeVar("T") + + +class Type(Generic[T]): """Base class for creating types.""" def __init__(self, validators=None, *args, **kwargs): @@ -28,11 +31,11 @@ def __init__(self, validators=None, *args, **kwargs): """ self.validators = [] if validators is None else list(validators) - def serialize(self, value, **kwargs): + def serialize(self, value: T, **kwargs) -> Any: """Serialization of value.""" return value - def deserialize(self, value, **kwargs): + def deserialize(self, value: Any, **kwargs) -> T: """Deserialization of value. :return: Deserialized value. @@ -58,10 +61,16 @@ def is_type(value): return isinstance(value, Type) -class List(Type): +class List(Type[list[T]]): """List type for Halogen schema attribute.""" - def __init__(self, item_type=None, allow_scalar=False, *args, **kwargs): + def __init__( + self, + item_type: Optional["Type[T]"] = None, + allow_scalar: bool = False, + *args, + **kwargs, + ): """Create a new List. :param item_type: Item type or schema. @@ -75,7 +84,9 @@ def serialize(self, value, **kwargs): """Serialize every item of the list.""" if value is None: raise ValueError("None passed, use Nullable type for nullable values") - return super().serialize([self.item_type.serialize(val, **kwargs) for val in value], **kwargs) + return super().serialize( + [self.item_type.serialize(val, **kwargs) for val in value], **kwargs + ) def deserialize(self, value, **kwargs): """Deserialize every item of the list.""" @@ -102,7 +113,7 @@ def deserialize(self, value, **kwargs): return super().deserialize(result, **kwargs) -class ISODateTime(Type): +class ISODateTime(Type[datetime.datetime]): """ISO-8601 datetime schema type.""" type = "datetime" @@ -128,7 +139,7 @@ def deserialize(self, value, **kwargs): return super().deserialize(value) -class ISOUTCDateTime(Type): +class ISOUTCDateTime(Type[datetime.datetime]): """ISO-8601 datetime schema type in UTC timezone.""" type = "datetime" @@ -169,7 +180,7 @@ class ISOUTCDate(ISOUTCDateTime): message = "'{val}' is not a valid ISO-8601 date" -class String(Type): +class String(Type[str]): """String schema type.""" def serialize(self, value, **kwargs): @@ -183,7 +194,7 @@ def deserialize(self, value, **kwargs): return super().deserialize(str(value), **kwargs) -class Int(Type): +class Int(Type[int]): """Int schema type.""" def serialize(self, value, **kwargs): @@ -201,7 +212,7 @@ def deserialize(self, value, **kwargs): return super().deserialize(value, **kwargs) -class Boolean(Type): +class Boolean(Type[bool]): """Boolean schema type.""" def serialize(self, value, **kwargs): @@ -234,12 +245,31 @@ def deserialize(self, value: Union[str, int, bool, None], **kwargs): return super().deserialize(value, **kwargs) -class Amount(Type): +class AmountSerialized(TypedDict): + amount: str + currency: str + + +class AmountLike(Protocol): + currency: str + amount: decimal.Decimal + + def as_quantized(self, *, digits: int) -> "AmountLike": ... + + def as_tuple(self) -> tuple[str, decimal.Decimal]: ... + + +AmountValueT = TypeVar("AmountValueT", bound=AmountLike) + + +class Amount(Type[AmountValueT]): """Amount (money) schema type.""" err_unknown_currency = "'{currency}' is not a valid currency." - def __init__(self, currencies, amount_class, **kwargs): + def __init__( + self, currencies: list[str], amount_class: type[AmountValueT], **kwargs + ): """Initialize new instance of Amount. :param currencies: list of all possible currency codes. @@ -249,7 +279,9 @@ def __init__(self, currencies, amount_class, **kwargs): self.amount_class = amount_class super().__init__(**kwargs) - def amount_object_to_dict(self, amount) -> dict[str, str]: + def amount_object_to_dict( + self, amount: Union[AmountValueT, AmountSerialized] + ) -> AmountSerialized: """Return the dictionary representation of an Amount object. Amount object must have amount and currency properties and as_tuple method which will return (currency, amount) @@ -271,7 +303,9 @@ def amount_object_to_dict(self, amount) -> dict[str, str]: "currency": str(currency), } - def serialize(self, value, **kwargs): + def serialize( + self, value: Union[AmountValueT, AmountSerialized], **kwargs + ) -> AmountSerialized: """Serialize amount. :param value: Amount value. @@ -283,7 +317,9 @@ def serialize(self, value, **kwargs): return super().serialize(self.amount_object_to_dict(value), **kwargs) - def deserialize(self, value, **kwargs): + def deserialize( + self, value: Union[str, AmountSerialized], **kwargs + ) -> AmountValueT: """Deserialize the amount. :param value: Amount in CURRENCYAMOUNT or {"currency": CURRENCY, "amount": AMOUNT} format. For example EUR35.50 @@ -301,7 +337,9 @@ def deserialize(self, value, **kwargs): amount = value[3:] elif isinstance(value, dict): if set(value.keys()) != set(("currency", "amount")): - raise ValueError("Amount object has to have currency and amount fields.") + raise ValueError( + "Amount object has to have currency and amount fields." + ) amount = value["amount"] currency = value["currency"] else: @@ -313,28 +351,32 @@ def deserialize(self, value, **kwargs): try: amount = decimal.Decimal(amount).normalize() except decimal.InvalidOperation: - raise ValueError("'{amount}' cannot be parsed to decimal.".format(amount=amount)) + raise ValueError( + "'{amount}' cannot be parsed to decimal.".format(amount=amount) + ) if amount.as_tuple().exponent < -2: - raise ValueError("'{amount}' has more than 2 decimal places.".format(amount=amount)) + raise ValueError( + "'{amount}' has more than 2 decimal places.".format(amount=amount) + ) value = self.amount_class(currency=currency, amount=amount) return super().deserialize(value) -class Nullable(Type): +class Nullable(Type[Optional[T]]): """Nullable type.""" - def __init__(self, nested_type: Union[type[Type], Type, "_Schema"], *args, **kwargs): + def __init__(self, nested_type: Union["Type[T]", type["_Schema"]], *args, **kwargs): self.nested_type = nested_type super().__init__(*args, **kwargs) - def serialize(self, value: Optional[Any], **kwargs): + def serialize(self, value: Optional[Any], **kwargs) -> Optional[T]: if value is None: return None return self.nested_type.serialize(value, **kwargs) - def deserialize(self, value: Optional[Any], **kwargs): + def deserialize(self, value: Optional[Any], **kwargs) -> Optional[T]: if value is None: return None return self.nested_type.deserialize(value, **kwargs)