Merge pull request #995 from ynput/enhancement/993-add-type-hints-to-attribute-definitions

Chore: Add type hints to attribute definitions
This commit is contained in:
Jakub Trllo 2024-11-07 11:51:41 +01:00 committed by GitHub
commit f88c466562
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -6,82 +6,58 @@ import json
import copy
import warnings
from abc import ABCMeta, abstractmethod
from typing import Any, Optional
import typing
from typing import (
Any,
Optional,
List,
Set,
Dict,
Iterable,
TypeVar,
)
import clique
if typing.TYPE_CHECKING:
from typing import Self, Tuple, Union, TypedDict, Pattern
class EnumItemDict(TypedDict):
label: str
value: Any
EnumItemsInputType = Union[
Dict[Any, str],
List[Tuple[Any, str]],
List[Any],
List[EnumItemDict]
]
class FileDefItemDict(TypedDict):
directory: str
filenames: List[str]
frames: Optional[List[int]]
template: Optional[str]
is_sequence: Optional[bool]
# Global variable which store attribute definitions by type
# - default types are registered on import
_attr_defs_by_type = {}
def register_attr_def_class(cls):
"""Register attribute definition.
Currently registered definitions are used to deserialize data to objects.
Attrs:
cls (AbstractAttrDef): Non-abstract class to be registered with unique
'type' attribute.
Raises:
KeyError: When type was already registered.
"""
if cls.type in _attr_defs_by_type:
raise KeyError("Type \"{}\" was already registered".format(cls.type))
_attr_defs_by_type[cls.type] = cls
def get_attributes_keys(attribute_definitions):
"""Collect keys from list of attribute definitions.
Args:
attribute_definitions (List[AbstractAttrDef]): Objects of attribute
definitions.
Returns:
Set[str]: Keys that will be created using passed attribute definitions.
"""
keys = set()
if not attribute_definitions:
return keys
for attribute_def in attribute_definitions:
if not isinstance(attribute_def, UIDef):
keys.add(attribute_def.key)
return keys
def get_default_values(attribute_definitions):
"""Receive default values for attribute definitions.
Args:
attribute_definitions (List[AbstractAttrDef]): Attribute definitions
for which default values should be collected.
Returns:
Dict[str, Any]: Default values for passed attribute definitions.
"""
output = {}
if not attribute_definitions:
return output
for attr_def in attribute_definitions:
# Skip UI definitions
if not isinstance(attr_def, UIDef):
output[attr_def.key] = attr_def.default
return output
# Type hint helpers
IntFloatType = "Union[int, float]"
class AbstractAttrDefMeta(ABCMeta):
"""Metaclass to validate the existence of 'key' attribute.
Each object of `AbstractAttrDef` must have defined 'key' attribute.
"""
"""
def __call__(cls, *args, **kwargs):
obj = super(AbstractAttrDefMeta, cls).__call__(*args, **kwargs)
init_class = getattr(obj, "__init__class__", None)
@ -93,8 +69,12 @@ class AbstractAttrDefMeta(ABCMeta):
def _convert_reversed_attr(
main_value, depr_value, main_label, depr_label, default
):
main_value: Any,
depr_value: Any,
main_label: str,
depr_label: str,
default: Any,
) -> Any:
if main_value is not None and depr_value is not None:
if main_value == depr_value:
print(
@ -140,8 +120,8 @@ class AbstractAttrDef(metaclass=AbstractAttrDefMeta):
enabled (Optional[bool]): Item is enabled (for UI purposes).
hidden (Optional[bool]): DEPRECATED: Use 'visible' instead.
disabled (Optional[bool]): DEPRECATED: Use 'enabled' instead.
"""
"""
type_attributes = []
is_value_def = True
@ -183,7 +163,7 @@ class AbstractAttrDef(metaclass=AbstractAttrDefMeta):
def id(self) -> str:
return self._id
def clone(self):
def clone(self) -> "Self":
data = self.serialize()
data.pop("type")
return self.deserialize(data)
@ -251,28 +231,28 @@ class AbstractAttrDef(metaclass=AbstractAttrDefMeta):
Returns:
str: Type of attribute definition.
"""
"""
pass
@abstractmethod
def convert_value(self, value):
def convert_value(self, value: Any) -> Any:
"""Convert value to a valid one.
Convert passed value to a valid type. Use default if value can't be
converted.
"""
"""
pass
def serialize(self):
def serialize(self) -> Dict[str, Any]:
"""Serialize object to data so it's possible to recreate it.
Returns:
Dict[str, Any]: Serialized object that can be passed to
'deserialize' method.
"""
"""
data = {
"type": self.type,
"key": self.key,
@ -288,7 +268,7 @@ class AbstractAttrDef(metaclass=AbstractAttrDefMeta):
return data
@classmethod
def deserialize(cls, data):
def deserialize(cls, data: Dict[str, Any]) -> "Self":
"""Recreate object from data.
Data can be received using 'serialize' method.
@ -299,10 +279,12 @@ class AbstractAttrDef(metaclass=AbstractAttrDefMeta):
return cls(**data)
def _def_type_compare(self, other: "AbstractAttrDef") -> bool:
def _def_type_compare(self, other: "Self") -> bool:
return True
AttrDefType = TypeVar("AttrDefType", bound=AbstractAttrDef)
# -----------------------------------------
# UI attribute definitions won't hold value
# -----------------------------------------
@ -310,13 +292,19 @@ class AbstractAttrDef(metaclass=AbstractAttrDefMeta):
class UIDef(AbstractAttrDef):
is_value_def = False
def __init__(self, key=None, default=None, *args, **kwargs):
def __init__(
self,
key: Optional[str] = None,
default: Optional[Any] = None,
*args,
**kwargs
):
super().__init__(key, default, *args, **kwargs)
def is_value_valid(self, value: Any) -> bool:
return True
def convert_value(self, value):
def convert_value(self, value: Any) -> Any:
return value
@ -343,18 +331,18 @@ class UnknownDef(AbstractAttrDef):
This attribute can be used to keep existing data unchanged but does not
have known definition of type.
"""
"""
type = "unknown"
def __init__(self, key, default=None, **kwargs):
def __init__(self, key: str, default: Optional[Any] = None, **kwargs):
kwargs["default"] = default
super().__init__(key, **kwargs)
def is_value_valid(self, value: Any) -> bool:
return True
def convert_value(self, value):
def convert_value(self, value: Any) -> Any:
return value
@ -365,11 +353,11 @@ class HiddenDef(AbstractAttrDef):
to other attributes (e.g. in multi-page UIs).
Keep in mind the value should be possible to parse by json parser.
"""
"""
type = "hidden"
def __init__(self, key, default=None, **kwargs):
def __init__(self, key: str, default: Optional[Any] = None, **kwargs):
kwargs["default"] = default
kwargs["visible"] = False
super().__init__(key, **kwargs)
@ -377,7 +365,7 @@ class HiddenDef(AbstractAttrDef):
def is_value_valid(self, value: Any) -> bool:
return True
def convert_value(self, value):
def convert_value(self, value: Any) -> Any:
return value
@ -392,8 +380,8 @@ class NumberDef(AbstractAttrDef):
maximum(int, float): Maximum possible value.
decimals(int): Maximum decimal points of value.
default(int, float): Default value for conversion.
"""
"""
type = "number"
type_attributes = [
"minimum",
@ -402,7 +390,12 @@ class NumberDef(AbstractAttrDef):
]
def __init__(
self, key, minimum=None, maximum=None, decimals=None, default=None,
self,
key: str,
minimum: Optional[IntFloatType] = None,
maximum: Optional[IntFloatType] = None,
decimals: Optional[int] = None,
default: Optional[IntFloatType] = None,
**kwargs
):
minimum = 0 if minimum is None else minimum
@ -428,9 +421,9 @@ class NumberDef(AbstractAttrDef):
super().__init__(key, default=default, **kwargs)
self.minimum = minimum
self.maximum = maximum
self.decimals = 0 if decimals is None else decimals
self.minimum: IntFloatType = minimum
self.maximum: IntFloatType = maximum
self.decimals: int = 0 if decimals is None else decimals
def is_value_valid(self, value: Any) -> bool:
if self.decimals == 0:
@ -442,7 +435,7 @@ class NumberDef(AbstractAttrDef):
return False
return True
def convert_value(self, value):
def convert_value(self, value: Any) -> IntFloatType:
if isinstance(value, str):
try:
value = float(value)
@ -477,8 +470,8 @@ class TextDef(AbstractAttrDef):
regex(str, re.Pattern): Regex validation.
placeholder(str): UI placeholder for attribute.
default(str, None): Default value. Empty string used when not defined.
"""
"""
type = "text"
type_attributes = [
"multiline",
@ -486,7 +479,12 @@ class TextDef(AbstractAttrDef):
]
def __init__(
self, key, multiline=None, regex=None, placeholder=None, default=None,
self,
key: str,
multiline: Optional[bool] = None,
regex: Optional[str] = None,
placeholder: Optional[str] = None,
default: Optional[str] = None,
**kwargs
):
if default is None:
@ -505,9 +503,9 @@ class TextDef(AbstractAttrDef):
if isinstance(regex, str):
regex = re.compile(regex)
self.multiline = multiline
self.placeholder = placeholder
self.regex = regex
self.multiline: bool = multiline
self.placeholder: Optional[str] = placeholder
self.regex: Optional["Pattern"] = regex
def is_value_valid(self, value: Any) -> bool:
if not isinstance(value, str):
@ -516,12 +514,12 @@ class TextDef(AbstractAttrDef):
return False
return True
def convert_value(self, value):
def convert_value(self, value: Any) -> str:
if isinstance(value, str):
return value
return self.default
def serialize(self):
def serialize(self) -> Dict[str, Any]:
data = super().serialize()
regex = None
if self.regex is not None:
@ -545,18 +543,24 @@ class EnumDef(AbstractAttrDef):
is enabled.
Args:
items (Union[list[str], list[dict[str, Any]]): Items definition that
can be converted using 'prepare_enum_items'.
key (str): Key under which value is stored.
items (EnumItemsInputType): Items definition that can be converted
using 'prepare_enum_items'.
default (Optional[Any]): Default value. Must be one key(value) from
passed items or list of values for multiselection.
multiselection (Optional[bool]): If True, multiselection is allowed.
Output is list of selected items.
"""
"""
type = "enum"
def __init__(
self, key, items, default=None, multiselection=False, **kwargs
self,
key: str,
items: "EnumItemsInputType",
default: "Union[str, List[Any]]" = None,
multiselection: Optional[bool] = False,
**kwargs
):
if not items:
raise ValueError((
@ -567,6 +571,9 @@ class EnumDef(AbstractAttrDef):
items = self.prepare_enum_items(items)
item_values = [item["value"] for item in items]
item_values_set = set(item_values)
if multiselection is None:
multiselection = False
if multiselection:
if default is None:
default = []
@ -577,9 +584,9 @@ class EnumDef(AbstractAttrDef):
super().__init__(key, default=default, **kwargs)
self.items = items
self._item_values = item_values_set
self.multiselection = multiselection
self.items: List["EnumItemDict"] = items
self._item_values: Set[Any] = item_values_set
self.multiselection: bool = multiselection
def convert_value(self, value):
if not self.multiselection:
@ -609,7 +616,7 @@ class EnumDef(AbstractAttrDef):
return data
@staticmethod
def prepare_enum_items(items):
def prepare_enum_items(items: "EnumItemsInputType") -> List["EnumItemDict"]:
"""Convert items to unified structure.
Output is a list where each item is dictionary with 'value'
@ -625,13 +632,12 @@ class EnumDef(AbstractAttrDef):
```
Args:
items (Union[Dict[str, Any], List[Any], List[Dict[str, Any]]): The
items to convert.
items (EnumItemsInputType): The items to convert.
Returns:
List[Dict[str, Any]]: Unified structure of items.
"""
List[EnumItemDict]: Unified structure of items.
"""
output = []
if isinstance(items, dict):
for value, label in items.items():
@ -682,11 +688,11 @@ class BoolDef(AbstractAttrDef):
Args:
default(bool): Default value. Set to `False` if not defined.
"""
"""
type = "bool"
def __init__(self, key, default=None, **kwargs):
def __init__(self, key: str, default: Optional[bool] = None, **kwargs):
if default is None:
default = False
super().__init__(key, default=default, **kwargs)
@ -694,7 +700,7 @@ class BoolDef(AbstractAttrDef):
def is_value_valid(self, value: Any) -> bool:
return isinstance(value, bool)
def convert_value(self, value):
def convert_value(self, value: Any) -> bool:
if isinstance(value, bool):
return value
return self.default
@ -702,7 +708,11 @@ class BoolDef(AbstractAttrDef):
class FileDefItem:
def __init__(
self, directory, filenames, frames=None, template=None
self,
directory: str,
filenames: List[str],
frames: Optional[List[int]] = None,
template: Optional[str] = None,
):
self.directory = directory
@ -731,7 +741,7 @@ class FileDefItem:
)
@property
def label(self):
def label(self) -> Optional[str]:
if self.is_empty:
return None
@ -774,7 +784,7 @@ class FileDefItem:
filename_template, ",".join(ranges)
)
def split_sequence(self):
def split_sequence(self) -> List["Self"]:
if not self.is_sequence:
raise ValueError("Cannot split single file item")
@ -785,7 +795,7 @@ class FileDefItem:
return self.from_paths(paths, False)
@property
def ext(self):
def ext(self) -> Optional[str]:
if self.is_empty:
return None
_, ext = os.path.splitext(self.filenames[0])
@ -794,14 +804,14 @@ class FileDefItem:
return None
@property
def lower_ext(self):
def lower_ext(self) -> Optional[str]:
ext = self.ext
if ext is not None:
return ext.lower()
return ext
@property
def is_dir(self):
def is_dir(self) -> bool:
if self.is_empty:
return False
@ -810,10 +820,15 @@ class FileDefItem:
return False
return True
def set_directory(self, directory):
def set_directory(self, directory: str):
self.directory = directory
def set_filenames(self, filenames, frames=None, template=None):
def set_filenames(
self,
filenames: List[str],
frames: Optional[List[int]] = None,
template: Optional[str] = None,
):
if frames is None:
frames = []
is_sequence = False
@ -830,17 +845,21 @@ class FileDefItem:
self.is_sequence = is_sequence
@classmethod
def create_empty_item(cls):
def create_empty_item(cls) -> "Self":
return cls("", "")
@classmethod
def from_value(cls, value, allow_sequences):
def from_value(
cls,
value: "Union[List[FileDefItemDict], FileDefItemDict]",
allow_sequences: bool,
) -> List["Self"]:
"""Convert passed value to FileDefItem objects.
Returns:
list: Created FileDefItem objects.
"""
"""
# Convert single item to iterable
if not isinstance(value, (list, tuple, set)):
value = [value]
@ -872,7 +891,7 @@ class FileDefItem:
return output
@classmethod
def from_dict(cls, data):
def from_dict(cls, data: "FileDefItemDict") -> "Self":
return cls(
data["directory"],
data["filenames"],
@ -881,7 +900,11 @@ class FileDefItem:
)
@classmethod
def from_paths(cls, paths, allow_sequences):
def from_paths(
cls,
paths: List[str],
allow_sequences: bool,
) -> List["Self"]:
filenames_by_dir = collections.defaultdict(list)
for path in paths:
normalized = os.path.normpath(path)
@ -910,7 +933,7 @@ class FileDefItem:
return output
def to_dict(self):
def to_dict(self) -> "FileDefItemDict":
output = {
"is_sequence": self.is_sequence,
"directory": self.directory,
@ -948,8 +971,15 @@ class FileDef(AbstractAttrDef):
]
def __init__(
self, key, single_item=True, folders=None, extensions=None,
allow_sequences=True, extensions_label=None, default=None, **kwargs
self,
key: str,
single_item: Optional[bool] = True,
folders: Optional[bool] = None,
extensions: Optional[Iterable[str]] = None,
allow_sequences: Optional[bool] = True,
extensions_label: Optional[str] = None,
default: Optional["Union[FileDefItemDict, List[str]]"] = None,
**kwargs
):
if folders is None and extensions is None:
folders = True
@ -985,14 +1015,14 @@ class FileDef(AbstractAttrDef):
if is_label_horizontal is None:
kwargs["is_label_horizontal"] = False
self.single_item = single_item
self.folders = folders
self.extensions = set(extensions)
self.allow_sequences = allow_sequences
self.extensions_label = extensions_label
self.single_item: bool = single_item
self.folders: bool = folders
self.extensions: Set[str] = set(extensions)
self.allow_sequences: bool = allow_sequences
self.extensions_label: Optional[str] = extensions_label
super().__init__(key, default=default, **kwargs)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if not super().__eq__(other):
return False
@ -1026,7 +1056,9 @@ class FileDef(AbstractAttrDef):
return False
return True
def convert_value(self, value):
def convert_value(
self, value: Any
) -> "Union[FileDefItemDict, List[FileDefItemDict]]":
if isinstance(value, (str, dict)):
value = [value]
@ -1062,55 +1094,124 @@ class FileDef(AbstractAttrDef):
return []
def serialize_attr_def(attr_def):
def register_attr_def_class(cls: AttrDefType):
"""Register attribute definition.
Currently registered definitions are used to deserialize data to objects.
Attrs:
cls (AttrDefType): Non-abstract class to be registered with unique
'type' attribute.
Raises:
KeyError: When type was already registered.
"""
if cls.type in _attr_defs_by_type:
raise KeyError("Type \"{}\" was already registered".format(cls.type))
_attr_defs_by_type[cls.type] = cls
def get_attributes_keys(
attribute_definitions: List[AttrDefType]
) -> Set[str]:
"""Collect keys from list of attribute definitions.
Args:
attribute_definitions (List[AttrDefType]): Objects of attribute
definitions.
Returns:
Set[str]: Keys that will be created using passed attribute definitions.
"""
keys = set()
if not attribute_definitions:
return keys
for attribute_def in attribute_definitions:
if not isinstance(attribute_def, UIDef):
keys.add(attribute_def.key)
return keys
def get_default_values(
attribute_definitions: List[AttrDefType]
) -> Dict[str, Any]:
"""Receive default values for attribute definitions.
Args:
attribute_definitions (List[AttrDefType]): Attribute definitions
for which default values should be collected.
Returns:
Dict[str, Any]: Default values for passed attribute definitions.
"""
output = {}
if not attribute_definitions:
return output
for attr_def in attribute_definitions:
# Skip UI definitions
if not isinstance(attr_def, UIDef):
output[attr_def.key] = attr_def.default
return output
def serialize_attr_def(attr_def: AttrDefType) -> Dict[str, Any]:
"""Serialize attribute definition to data.
Args:
attr_def (AbstractAttrDef): Attribute definition to serialize.
attr_def (AttrDefType): Attribute definition to serialize.
Returns:
Dict[str, Any]: Serialized data.
"""
"""
return attr_def.serialize()
def serialize_attr_defs(attr_defs):
def serialize_attr_defs(
attr_defs: List[AttrDefType]
) -> List[Dict[str, Any]]:
"""Serialize attribute definitions to data.
Args:
attr_defs (List[AbstractAttrDef]): Attribute definitions to serialize.
attr_defs (List[AttrDefType]): Attribute definitions to serialize.
Returns:
List[Dict[str, Any]]: Serialized data.
"""
"""
return [
serialize_attr_def(attr_def)
for attr_def in attr_defs
]
def deserialize_attr_def(attr_def_data):
def deserialize_attr_def(attr_def_data: Dict[str, Any]) -> AttrDefType:
"""Deserialize attribute definition from data.
Args:
attr_def_data (Dict[str, Any]): Attribute definition data to
deserialize.
"""
"""
attr_type = attr_def_data.pop("type")
cls = _attr_defs_by_type[attr_type]
return cls.deserialize(attr_def_data)
def deserialize_attr_defs(attr_defs_data):
def deserialize_attr_defs(
attr_defs_data: List[Dict[str, Any]]
) -> List[AttrDefType]:
"""Deserialize attribute definitions.
Args:
List[Dict[str, Any]]: List of attribute definitions.
"""
"""
return [
deserialize_attr_def(attr_def_data)
for attr_def_data in attr_defs_data