From b64e0340d976ee35ff5d3bb4d71b0e8189696ee2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20Samohel?= Date: Tue, 12 Nov 2024 23:40:59 +0100 Subject: [PATCH] :recycle: refactor to remove circular imports add `FileLocations` validations --- client/ayon_core/pipeline/traits/__init__.py | 3 +- client/ayon_core/pipeline/traits/content.py | 25 +- .../pipeline/traits/representation.py | 609 ++++++++++++++++++ client/ayon_core/pipeline/traits/time.py | 30 +- client/ayon_core/pipeline/traits/trait.py | 597 +---------------- client/ayon_core/pipeline/traits/utils.py | 14 +- tests/__init__.py | 1 + .../ayon_core/pipeline/traits/test_traits.py | 27 +- tests/conftest.py | 3 +- 9 files changed, 668 insertions(+), 641 deletions(-) create mode 100644 client/ayon_core/pipeline/traits/representation.py create mode 100644 tests/__init__.py diff --git a/client/ayon_core/pipeline/traits/__init__.py b/client/ayon_core/pipeline/traits/__init__.py index d8b74a4c70..16ee7d6975 100644 --- a/client/ayon_core/pipeline/traits/__init__.py +++ b/client/ayon_core/pipeline/traits/__init__.py @@ -13,6 +13,7 @@ from .content import ( from .cryptography import DigitallySigned, GPGSigned from .lifecycle import Persistent, Transient from .meta import Tagged, TemplatePath +from .representation import Representation from .three_dimensional import Geometry, IESProfile, Lighting, Shader, Spatial from .time import ( FrameRanged, @@ -22,7 +23,7 @@ from .time import ( SMPTETimecode, Static, ) -from .trait import MissingTraitError, Representation, TraitBase +from .trait import MissingTraitError, TraitBase from .two_dimensional import ( UDIM, Deep, diff --git a/client/ayon_core/pipeline/traits/content.py b/client/ayon_core/pipeline/traits/content.py index eb156f44e1..1563d3b96a 100644 --- a/client/ayon_core/pipeline/traits/content.py +++ b/client/ayon_core/pipeline/traits/content.py @@ -7,14 +7,14 @@ from typing import ClassVar, Optional from pydantic import Field -from .time import Sequence +from .representation import Representation +from .time import FrameRanged from .trait import ( MissingTraitError, - Representation, TraitBase, TraitValidationError, - get_sequence_from_files, ) +from .utils import get_sequence_from_files class MimeType(TraitBase): @@ -125,26 +125,25 @@ class FileLocations(TraitBase): msg = "No file locations defined (empty list)" raise TraitValidationError(self.name, msg) - tmp_seq: Sequence = get_sequence_from_files( + tmp_frame_ranged: FrameRanged = get_sequence_from_files( [f.file_path for f in self.file_paths]) - if len(self.file_paths) != \ - tmp_seq.frame_end - tmp_seq.frame_start: + if len(self.file_paths) - 1 != \ + tmp_frame_ranged.frame_end - tmp_frame_ranged.frame_start: # If the number of file paths does not match the frame range, # the trait is invalid msg = ( - f"Number of file locations ({len(self.file_paths)}) " + f"Number of file locations ({len(self.file_paths) - 1}) " "does not match frame range " - f"({tmp_seq.frame_end - tmp_seq.frame_start})" + f"({tmp_frame_ranged.frame_end - tmp_frame_ranged.frame_start})" ) raise TraitValidationError(self.name, msg) try: - sequence: Sequence = representation.get_trait(Sequence) + sequence: FrameRanged = representation.get_trait(FrameRanged) - if sequence.frame_start != tmp_seq.frame_start or \ - sequence.frame_end != tmp_seq.frame_end or \ - sequence.frame_padding != tmp_seq.frame_padding: + if sequence.frame_start != tmp_frame_ranged.frame_start or \ + sequence.frame_end != tmp_frame_ranged.frame_end: # If the frame range does not match the sequence trait, the # trait is invalid. Note that we don't check the frame rate # because it is not stored in the file paths and is not @@ -154,7 +153,7 @@ class FileLocations(TraitBase): f"({sequence.frame_start}-{sequence.frame_end}) " "in sequence trait does not match " "frame range " - f"({tmp_seq.frame_start}-{tmp_seq.frame_end}) " + f"({tmp_frame_ranged.frame_start}-{tmp_frame_ranged.frame_end}) " "defined in files." ) raise TraitValidationError(self.name, msg) diff --git a/client/ayon_core/pipeline/traits/representation.py b/client/ayon_core/pipeline/traits/representation.py new file mode 100644 index 0000000000..acd78a6ce5 --- /dev/null +++ b/client/ayon_core/pipeline/traits/representation.py @@ -0,0 +1,609 @@ +"""Defines the base trait model and representation.""" +from __future__ import annotations + +import inspect +import re +import sys +import uuid +from functools import lru_cache +from typing import ClassVar, Optional, Type, TypeVar, Union + +from .trait import ( + IncompatibleTraitVersionError, + LooseMatchingTraitError, + MissingTraitError, + TraitBase, + UpgradableTraitError, +) + +T = TypeVar("T", bound=TraitBase) + + +def _get_version_from_id(_id: str) -> int: + """Get version from ID. + + Args: + _id (str): ID. + + Returns: + int: Version. + + """ + match = re.search(r"v(\d+)$", _id) + return int(match[1]) if match else None + + +class Representation: + """Representation of products. + + Representation defines collection of individual properties that describe + the specific "form" of the product. Each property is represented by a + trait therefore the Representation is a collection of traits. + + It holds methods to add, remove, get, and check for the existence of a + trait in the representation. It also provides a method to get all the + + Arguments: + name (str): Representation name. Must be unique within instance. + representation_id (str): Representation ID. + + """ + _data: dict + _module_blacklist: ClassVar[list[str]] = [ + "_", "builtins", "pydantic"] + name: str + representation_id: str + + def __hash__(self): + """Return hash of the representation ID.""" + return hash(self.representation_id) + + def add_trait(self, trait: TraitBase, *, exists_ok: bool=False) -> None: + """Add a trait to the Representation. + + Args: + trait (TraitBase): Trait to add. + exists_ok (bool, optional): If True, do not raise an error if the + trait already exists. Defaults to False. + + Raises: + ValueError: If the trait ID is not provided or the trait already + exists. + + """ + if not hasattr(trait, "id"): + error_msg = f"Invalid trait {trait} - ID is required." + raise ValueError(error_msg) + if trait.id in self._data and not exists_ok: + error_msg = f"Trait with ID {trait.id} already exists." + raise ValueError(error_msg) + self._data[trait.id] = trait + + def add_traits( + self, traits: list[TraitBase], *, exists_ok: bool=False) -> None: + """Add a list of traits to the Representation. + + Args: + traits (list[TraitBase]): List of traits to add. + exists_ok (bool, optional): If True, do not raise an error if the + trait already exists. Defaults to False. + + """ + for trait in traits: + self.add_trait(trait, exists_ok=exists_ok) + + def remove_trait(self, trait: Type[TraitBase]) -> None: + """Remove a trait from the data. + + Args: + trait (TraitBase, optional): Trait class. + + Raises: + ValueError: If the trait is not found. + + """ + try: + self._data.pop(trait.id) + except KeyError as e: + error_msg = f"Trait with ID {trait.id} not found." + raise ValueError(error_msg) from e + + def remove_trait_by_id(self, trait_id: str) -> None: + """Remove a trait from the data by its ID. + + Args: + trait_id (str): Trait ID. + + Raises: + ValueError: If the trait is not found. + + """ + try: + self._data.pop(trait_id) + except KeyError as e: + error_msg = f"Trait with ID {trait_id} not found." + raise ValueError(error_msg) from e + + def remove_traits(self, traits: list[Type[TraitBase]]) -> None: + """Remove a list of traits from the Representation. + + If no trait IDs or traits are provided, all traits will be removed. + + Args: + traits (list[TraitBase]): List of trait classes. + + """ + if not traits: + self._data = {} + return + + for trait in traits: + self.remove_trait(trait) + + def remove_traits_by_id(self, trait_ids: list[str]) -> None: + """Remove a list of traits from the Representation by their ID. + + If no trait IDs or traits are provided, all traits will be removed. + + Args: + trait_ids (list[str], optional): List of trait IDs. + + """ + for trait_id in trait_ids: + self.remove_trait_by_id(trait_id) + + + def has_traits(self) -> bool: + """Check if the Representation has any traits. + + Returns: + bool: True if the Representation has any traits, False otherwise. + + """ + return bool(self._data) + + def contains_trait(self, trait: Type[TraitBase]) -> bool: + """Check if the trait exists in the Representation. + + Args: + trait (TraitBase): Trait class. + + Returns: + bool: True if the trait exists, False otherwise. + + """ + return bool(self._data.get(trait.id)) + + def contains_trait_by_id(self, trait_id: str) -> bool: + """Check if the trait exists using trait id. + + Args: + trait_id (str): Trait ID. + + Returns: + bool: True if the trait exists, False otherwise. + + """ + return bool(self._data.get(trait_id)) + + def contains_traits(self, traits: list[Type[TraitBase]]) -> bool: + """Check if the traits exist. + + Args: + traits (list[TraitBase], optional): List of trait classes. + + Returns: + bool: True if all traits exist, False otherwise. + + """ + return all(self.contains_trait(trait=trait) for trait in traits) + + def contains_traits_by_id(self, trait_ids: list[str]) -> bool: + """Check if the traits exist by id. + + If no trait IDs or traits are provided, it will check if the + representation has any traits. + + Args: + trait_ids (list[str]): List of trait IDs. + + Returns: + bool: True if all traits exist, False otherwise. + + """ + return all( + self.contains_trait_by_id(trait_id) for trait_id in trait_ids + ) + + def get_trait(self, trait: Type[T]) -> Union[T]: + """Get a trait from the representation. + + Args: + trait (TraitBase, optional): Trait class. + + Returns: + TraitBase: Trait instance. + + Raises: + MissingTraitError: If the trait is not found. + + """ + try: + return self._data[trait.id] + except KeyError as e: + msg = f"Trait with ID {trait.id} not found." + raise MissingTraitError(msg) from e + + def get_trait_by_id(self, trait_id: str) -> Union[T]: + # sourcery skip: use-named-expression + """Get a trait from the representation by id. + + Args: + trait_id (str): Trait ID. + + Returns: + TraitBase: Trait instance. + + Raises: + MissingTraitError: If the trait is not found. + + """ + version = _get_version_from_id(trait_id) + if version: + try: + return self._data[trait_id] + except KeyError as e: + msg = f"Trait with ID {trait_id} not found." + raise MissingTraitError(msg) from e + + result = next( + ( + self._data.get(trait_id) + for trait_id in self._data + if trait_id.startswith(trait_id) + ), + None, + ) + if not result: + msg = f"Trait with ID {trait_id} not found." + raise MissingTraitError(msg) + return result + + def get_traits(self, + traits: Optional[list[Type[TraitBase]]]=None + ) -> dict[str, T]: + """Get a list of traits from the representation. + + If no trait IDs or traits are provided, all traits will be returned. + + Args: + traits (list[TraitBase], optional): List of trait classes. + + Returns: + dict: Dictionary of traits. + + """ + result = {} + if not traits: + for trait_id in self._data: + result[trait_id] = self.get_trait_by_id(trait_id=trait_id) + return result + + for trait in traits: + result[trait.id] = self.get_trait(trait=trait) + return result + + def get_traits_by_ids(self, trait_ids: list[str]) -> dict[str, T]: + """Get a list of traits from the representation by their id. + + If no trait IDs or traits are provided, all traits will be returned. + + Args: + trait_ids (list[str]): List of trait IDs. + + Returns: + dict: Dictionary of traits. + + """ + return { + trait_id: self.get_trait_by_id(trait_id) + for trait_id in trait_ids + } + + def traits_as_dict(self) -> dict: + """Return the traits from Representation data as a dictionary. + + Returns: + dict: Traits data dictionary. + + """ + return { + trait_id: trait.model_dump() + for trait_id, trait in self._data.items() + if trait and trait_id + } + + def __len__(self): + """Return the length of the data.""" + return len(self._data) + + def __init__( + self, + name: str, + representation_id: Optional[str]=None, + traits: Optional[list[TraitBase]]=None): + """Initialize the data. + + Args: + name (str): Representation name. Must be unique within instance. + representation_id (str, optional): Representation ID. + traits (list[TraitBase], optional): List of traits. + """ + self.name = name + self.representation_id = representation_id or uuid.uuid4().hex + self._data = {} + if traits: + for trait in traits: + self.add_trait(trait) + + @staticmethod + def _get_version_from_id(trait_id: str) -> Union[int, None]: + # sourcery skip: use-named-expression + """Check if the trait has version specified. + + Args: + trait_id (str): Trait ID. + + Returns: + int: Trait version. + None: If the trait id does not have a version. + + """ + version_regex = r"v(\d+)$" + match = re.search(version_regex, trait_id) + return int(match[1]) if match else None + + def __eq__(self, other: Representation) -> bool: # noqa: PLR0911 + """Check if the representation is equal to another. + + Args: + other (Representation): Representation to compare. + + Returns: + bool: True if the representations are equal, False otherwise. + + """ + if self.representation_id != other.representation_id: + return False + + if not isinstance(other, Representation): + return False + + if self.name != other.name: + return False + + # number of traits + if len(self) != len(other): + return False + + for trait_id, trait in self._data.items(): + if trait_id not in other._data: + return False + if trait != other._data[trait_id]: + return False + for key, value in trait.model_dump().items(): + if value != other._data[trait_id].model_dump().get(key): + return False + + return True + + @classmethod + @lru_cache(maxsize=64) + def _get_possible_trait_classes_from_modules( + cls, + trait_id: str) -> set[type[TraitBase]]: + """Get possible trait classes from modules. + + Args: + trait_id (str): Trait ID. + + Returns: + set[type[TraitBase]]: Set of trait classes. + + """ + modules = sys.modules.copy() + filtered_modules = modules.copy() + for module_name in modules: + for bl_module in cls._module_blacklist: + if module_name.startswith(bl_module): + filtered_modules.pop(module_name) + + trait_candidates = set() + for module in filtered_modules.values(): + if not module: + continue + for _, klass in inspect.getmembers(module, inspect.isclass): + if inspect.isclass(klass) \ + and issubclass(klass, TraitBase) \ + and str(klass.id).startswith(trait_id): + trait_candidates.add(klass) + return trait_candidates + + @classmethod + @lru_cache(maxsize=64) + def _get_trait_class( + cls, trait_id: str) -> Union[Type[TraitBase], None]: + """Get the trait class with corresponding to given ID. + + This method will search for the trait class in all the modules except + the blacklisted modules. There is some issue in Pydantic where + ``issubclass`` is not working properly so we are excluding explicitly + modules with offending classes. This list can be updated as needed to + speed up the search. + + Args: + trait_id (str): Trait ID. + + Returns: + Type[TraitBase]: Trait class. + + Raises: + LooseMatchingTraitError: If the trait is found with a loose + matching criteria. This exception will include the trait + class that was found and the expected trait ID. Additional + downstream logic must decide how to handle this error. + + """ + version = cls._get_version_from_id(trait_id) + + trait_candidates = cls._get_possible_trait_classes_from_modules( + trait_id + ) + + for trait_class in trait_candidates: + if trait_class.id == trait_id: + # we found direct match + return trait_class + + # if we didn't find direct match, we will search for the highest + # version of the trait. + if not version: + # sourcery skip: use-named-expression + trait_versions = [ + trait_class for trait_class in trait_candidates + if re.match( + rf"{trait_id}.v(\d+)$", str(trait_class.id)) + ] + if trait_versions: + def _get_version_by_id(trait_klass: Type[TraitBase]) -> int: + match = re.search(r"v(\d+)$", str(trait_klass.id)) + return int(match[1]) if match else 0 + + error = LooseMatchingTraitError( + "Found trait that might match.") + error.found_trait = max( + trait_versions, key=_get_version_by_id) + error.expected_id = trait_id + raise error + + return None + + @classmethod + def get_trait_class_by_trait_id(cls, trait_id: str) -> type[TraitBase]: + """Get the trait class for the given trait ID. + + Args: + trait_id (str): Trait ID. + + Returns: + type[TraitBase]: Trait class. + + Raises: + IncompatibleTraitVersionError: If the trait version is incompatible + with the current version of the trait. + UpgradableTraitError: If the trait can upgrade existing data + meant for older versions of the trait. + ValueError: If the trait model with the given ID is not found. + + """ + trait_class = None + try: + trait_class = cls._get_trait_class(trait_id=trait_id) + except LooseMatchingTraitError as e: + requested_version = _get_version_from_id(trait_id) + found_version = _get_version_from_id(e.found_trait.id) + + if not requested_version: + trait_class = e.found_trait + + else: + if requested_version > found_version: + error_msg = ( + f"Requested trait version {requested_version} is " + f"higher than the found trait version {found_version}." + ) + raise IncompatibleTraitVersionError(error_msg) from e + + if requested_version < found_version and hasattr( + e.found_trait, "upgrade"): + error_msg = ( + "Requested trait version " + f"{requested_version} is lower " + f"than the found trait version {found_version}." + ) + error = UpgradableTraitError(error_msg) + error.trait = e.found_trait + raise error from e + return trait_class + + @classmethod + def from_dict( + cls, + name: str, + representation_id: Optional[str]=None, + trait_data: Optional[dict] = None) -> Representation: + """Create a representation from a dictionary. + + Args: + name (str): Representation name. + representation_id (str, optional): Representation ID. + trait_data (dict): Representation data. Dictionary with keys + as trait ids and values as trait data. Example:: + + { + "ayon.2d.PixelBased.v1": { + "display_window_width": 1920, + "display_window_height": 1080 + }, + "ayon.2d.Planar.v1": { + "channels": 3 + } + } + + Returns: + Representation: Representation instance. + + """ + traits = [] + for trait_id, value in trait_data.items(): + if not isinstance(value, dict): + msg = ( + f"Invalid trait data for trait ID {trait_id}. " + "Trait data must be a dictionary." + ) + raise TypeError(msg) + + try: + trait_class = cls.get_trait_class_by_trait_id(trait_id) + except UpgradableTraitError as e: + # we found newer version of trait, we will upgrade the data + if hasattr(e.trait, "upgrade"): + traits.append(e.trait.upgrade(value)) + else: + msg = ( + f"Newer version of trait {e.trait.id} found " + f"for requested {trait_id} but without " + "upgrade method." + ) + raise IncompatibleTraitVersionError(msg) from e + else: + if not trait_class: + error_msg = f"Trait model with ID {trait_id} not found." + raise ValueError(error_msg) + + traits.append(trait_class(**value)) + + return cls( + name=name, representation_id=representation_id, traits=traits) + + + def validate(self) -> bool: + """Validate the representation. + + This method will validate all the traits in the representation. + + Returns: + bool: True if the representation is valid, False otherwise. + + """ + return all(trait.validate(self) for trait in self._data.values()) diff --git a/client/ayon_core/pipeline/traits/time.py b/client/ayon_core/pipeline/traits/time.py index 1b2f211468..16aeaba2d1 100644 --- a/client/ayon_core/pipeline/traits/time.py +++ b/client/ayon_core/pipeline/traits/time.py @@ -2,16 +2,15 @@ from __future__ import annotations from enum import Enum, auto -from typing import TYPE_CHECKING, ClassVar, Optional, Union +from typing import TYPE_CHECKING, Annotated, ClassVar, Optional, Union -from pydantic import Field +from pydantic import Field, PlainSerializer -from .content import FileLocations -from .trait import MissingTraitError, Representation, TraitBase +from .representation import Representation +from .trait import MissingTraitError, TraitBase if TYPE_CHECKING: from decimal import Decimal - from fractions import Fraction class GapPolicy(Enum): @@ -42,6 +41,13 @@ class FrameRanged(TraitBase): * frame_end -> end_frame ... + Note: frames_per_second is a string to allow various precision + formats. FPS is a floating point number, but it can be also + represented as a fraction (e.g. "30000/1001") or as a decimal + or even as irrational number. We need to support all these + formats. To work with FPS, we'll need some helper function + to convert FPS to Decimal from string. + Attributes: name (str): Trait name. description (str): Trait description. @@ -50,7 +56,7 @@ class FrameRanged(TraitBase): frame_end (int): Frame end. frame_in (int): Frame in. frame_out (int): Frame out. - frames_per_second (float, Fraction, Decimal): Frames per second. + frames_per_second (str): Frames per second. step (int): Step. """ @@ -63,8 +69,7 @@ class FrameRanged(TraitBase): ..., title="Frame Start") frame_in: Optional[int] = Field(None, title="In Frame") frame_out: Optional[int] = Field(None, title="Out Frame") - frames_per_second: Union[float, Fraction, Decimal] = Field( - ..., title="Frames Per Second") + frames_per_second: str = Field(..., title="Frames Per Second") step: Optional[int] = Field(1, title="Step") @@ -83,9 +88,9 @@ class Handles(TraitBase): frame_end_handle (int): Frame end handle. """ - name: ClassVar[str] = "Clip" - description: ClassVar[str] = "Clip Trait" - id: ClassVar[str] = "ayon.time.Clip.v1" + name: ClassVar[str] = "Handles" + description: ClassVar[str] = "Handles Trait" + id: ClassVar[str] = "ayon.time.Handles.v1" inclusive: Optional[bool] = Field( False, title="Handles are inclusive") # noqa: FBT003 frame_start_handle: Optional[int] = Field( @@ -93,7 +98,7 @@ class Handles(TraitBase): frame_end_handle: Optional[int] = Field( 0, title="Frame End Handle") -class Sequence(FrameRanged, Handles): +class Sequence(TraitBase): """Sequence trait model. This model represents a sequence trait. Based on the FrameRanged trait @@ -130,6 +135,7 @@ class Sequence(FrameRanged, Handles): # if there is FileLocations trait, run validation # on it as well try: + from .content import FileLocations file_locs: FileLocations = representation.get_trait( FileLocations) file_locs.validate(representation) diff --git a/client/ayon_core/pipeline/traits/trait.py b/client/ayon_core/pipeline/traits/trait.py index 15e48b6f5a..1f0c72cd9d 100644 --- a/client/ayon_core/pipeline/traits/trait.py +++ b/client/ayon_core/pipeline/traits/trait.py @@ -1,13 +1,9 @@ """Defines the base trait model and representation.""" from __future__ import annotations -import inspect import re -import sys -import uuid from abc import ABC, abstractmethod -from functools import lru_cache -from typing import ClassVar, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, Optional import pydantic.alias_generators from pydantic import ( @@ -16,19 +12,8 @@ from pydantic import ( ConfigDict, ) - -def _get_version_from_id(_id: str) -> int: - """Get version from ID. - - Args: - _id (str): ID. - - Returns: - int: Version. - - """ - match = re.search(r"v(\d+)$", _id) - return int(match[1]) if match else None +if TYPE_CHECKING: + from .representation import Representation class TraitBase(ABC, BaseModel): @@ -106,584 +91,8 @@ class TraitBase(ABC, BaseModel): return re.sub(r"\.v\d+$", "", str(cls.id)) -T = TypeVar("T", bound=TraitBase) -class Representation: - """Representation of products. - - Representation defines collection of individual properties that describe - the specific "form" of the product. Each property is represented by a - trait therefore the Representation is a collection of traits. - - It holds methods to add, remove, get, and check for the existence of a - trait in the representation. It also provides a method to get all the - - Arguments: - name (str): Representation name. Must be unique within instance. - representation_id (str): Representation ID. - - """ - _data: dict - _module_blacklist: ClassVar[list[str]] = [ - "_", "builtins", "pydantic"] - name: str - representation_id: str - - def __hash__(self): - """Return hash of the representation ID.""" - return hash(self.representation_id) - - def add_trait(self, trait: TraitBase, *, exists_ok: bool=False) -> None: - """Add a trait to the Representation. - - Args: - trait (TraitBase): Trait to add. - exists_ok (bool, optional): If True, do not raise an error if the - trait already exists. Defaults to False. - - Raises: - ValueError: If the trait ID is not provided or the trait already - exists. - - """ - if not hasattr(trait, "id"): - error_msg = f"Invalid trait {trait} - ID is required." - raise ValueError(error_msg) - if trait.id in self._data and not exists_ok: - error_msg = f"Trait with ID {trait.id} already exists." - raise ValueError(error_msg) - self._data[trait.id] = trait - - def add_traits( - self, traits: list[TraitBase], *, exists_ok: bool=False) -> None: - """Add a list of traits to the Representation. - - Args: - traits (list[TraitBase]): List of traits to add. - exists_ok (bool, optional): If True, do not raise an error if the - trait already exists. Defaults to False. - - """ - for trait in traits: - self.add_trait(trait, exists_ok=exists_ok) - - def remove_trait(self, trait: Type[TraitBase]) -> None: - """Remove a trait from the data. - - Args: - trait (TraitBase, optional): Trait class. - - Raises: - ValueError: If the trait is not found. - - """ - try: - self._data.pop(trait.id) - except KeyError as e: - error_msg = f"Trait with ID {trait.id} not found." - raise ValueError(error_msg) from e - - def remove_trait_by_id(self, trait_id: str) -> None: - """Remove a trait from the data by its ID. - - Args: - trait_id (str): Trait ID. - - Raises: - ValueError: If the trait is not found. - - """ - try: - self._data.pop(trait_id) - except KeyError as e: - error_msg = f"Trait with ID {trait_id} not found." - raise ValueError(error_msg) from e - - def remove_traits(self, traits: list[Type[TraitBase]]) -> None: - """Remove a list of traits from the Representation. - - If no trait IDs or traits are provided, all traits will be removed. - - Args: - traits (list[TraitBase]): List of trait classes. - - """ - if not traits: - self._data = {} - return - - for trait in traits: - self.remove_trait(trait) - - def remove_traits_by_id(self, trait_ids: list[str]) -> None: - """Remove a list of traits from the Representation by their ID. - - If no trait IDs or traits are provided, all traits will be removed. - - Args: - trait_ids (list[str], optional): List of trait IDs. - - """ - for trait_id in trait_ids: - self.remove_trait_by_id(trait_id) - - - def has_traits(self) -> bool: - """Check if the Representation has any traits. - - Returns: - bool: True if the Representation has any traits, False otherwise. - - """ - return bool(self._data) - - def contains_trait(self, trait: Type[TraitBase]) -> bool: - """Check if the trait exists in the Representation. - - Args: - trait (TraitBase): Trait class. - - Returns: - bool: True if the trait exists, False otherwise. - - """ - return bool(self._data.get(trait.id)) - - def contains_trait_by_id(self, trait_id: str) -> bool: - """Check if the trait exists using trait id. - - Args: - trait_id (str): Trait ID. - - Returns: - bool: True if the trait exists, False otherwise. - - """ - return bool(self._data.get(trait_id)) - - def contains_traits(self, traits: list[Type[TraitBase]]) -> bool: - """Check if the traits exist. - - Args: - traits (list[TraitBase], optional): List of trait classes. - - Returns: - bool: True if all traits exist, False otherwise. - - """ - return all(self.contains_trait(trait=trait) for trait in traits) - - def contains_traits_by_id(self, trait_ids: list[str]) -> bool: - """Check if the traits exist by id. - - If no trait IDs or traits are provided, it will check if the - representation has any traits. - - Args: - trait_ids (list[str]): List of trait IDs. - - Returns: - bool: True if all traits exist, False otherwise. - - """ - return all( - self.contains_trait_by_id(trait_id) for trait_id in trait_ids - ) - - def get_trait(self, trait: Type[T]) -> Union[T]: - """Get a trait from the representation. - - Args: - trait (TraitBase, optional): Trait class. - - Returns: - TraitBase: Trait instance. - - Raises: - MissingTraitError: If the trait is not found. - - """ - try: - return self._data[trait.id] - except KeyError as e: - msg = f"Trait with ID {trait.id} not found." - raise MissingTraitError(msg) from e - - def get_trait_by_id(self, trait_id: str) -> Union[T]: - # sourcery skip: use-named-expression - """Get a trait from the representation by id. - - Args: - trait_id (str): Trait ID. - - Returns: - TraitBase: Trait instance. - - Raises: - MissingTraitError: If the trait is not found. - - """ - version = _get_version_from_id(trait_id) - if version: - try: - return self._data[trait_id] - except KeyError as e: - msg = f"Trait with ID {trait_id} not found." - raise MissingTraitError(msg) from e - - result = next( - ( - self._data.get(trait_id) - for trait_id in self._data - if trait_id.startswith(trait_id) - ), - None, - ) - if not result: - msg = f"Trait with ID {trait_id} not found." - raise MissingTraitError(msg) - return result - - def get_traits(self, - traits: Optional[list[Type[TraitBase]]]=None - ) -> dict[str, T]: - """Get a list of traits from the representation. - - If no trait IDs or traits are provided, all traits will be returned. - - Args: - traits (list[TraitBase], optional): List of trait classes. - - Returns: - dict: Dictionary of traits. - - """ - result = {} - if not traits: - for trait_id in self._data: - result[trait_id] = self.get_trait_by_id(trait_id=trait_id) - return result - - for trait in traits: - result[trait.id] = self.get_trait(trait=trait) - return result - - def get_traits_by_ids(self, trait_ids: list[str]) -> dict[str, T]: - """Get a list of traits from the representation by their id. - - If no trait IDs or traits are provided, all traits will be returned. - - Args: - trait_ids (list[str]): List of trait IDs. - - Returns: - dict: Dictionary of traits. - - """ - return { - trait_id: self.get_trait_by_id(trait_id) - for trait_id in trait_ids - } - - def traits_as_dict(self) -> dict: - """Return the traits from Representation data as a dictionary. - - Returns: - dict: Traits data dictionary. - - """ - return { - trait_id: trait.dict() - for trait_id, trait in self._data.items() - if trait and trait_id - } - - def __len__(self): - """Return the length of the data.""" - return len(self._data) - - def __init__( - self, - name: str, - representation_id: Optional[str]=None, - traits: Optional[list[TraitBase]]=None): - """Initialize the data. - - Args: - name (str): Representation name. Must be unique within instance. - representation_id (str, optional): Representation ID. - traits (list[TraitBase], optional): List of traits. - """ - self.name = name - self.representation_id = representation_id or uuid.uuid4().hex - self._data = {} - if traits: - for trait in traits: - self.add_trait(trait) - - @staticmethod - def _get_version_from_id(trait_id: str) -> Union[int, None]: - # sourcery skip: use-named-expression - """Check if the trait has version specified. - - Args: - trait_id (str): Trait ID. - - Returns: - int: Trait version. - None: If the trait id does not have a version. - - """ - version_regex = r"v(\d+)$" - match = re.search(version_regex, trait_id) - return int(match[1]) if match else None - - def __eq__(self, other: Representation) -> bool: # noqa: PLR0911 - """Check if the representation is equal to another. - - Args: - other (Representation): Representation to compare. - - Returns: - bool: True if the representations are equal, False otherwise. - - """ - if self.representation_id != other.representation_id: - return False - - if not isinstance(other, Representation): - return False - - if self.name != other.name: - return False - - # number of traits - if len(self) != len(other): - return False - - for trait_id, trait in self._data.items(): - if trait_id not in other._data: - return False - if trait != other._data[trait_id]: - return False - for key, value in trait.dict().items(): - if value != other._data[trait_id].dict().get(key): - return False - - return True - - @classmethod - @lru_cache(maxsize=64) - def _get_possible_trait_classes_from_modules( - cls, - trait_id: str) -> set[type[TraitBase]]: - """Get possible trait classes from modules. - - Args: - trait_id (str): Trait ID. - - Returns: - set[type[TraitBase]]: Set of trait classes. - - """ - modules = sys.modules.copy() - filtered_modules = modules.copy() - for module_name in modules: - for bl_module in cls._module_blacklist: - if module_name.startswith(bl_module): - filtered_modules.pop(module_name) - - trait_candidates = set() - for module in filtered_modules.values(): - if not module: - continue - for _, klass in inspect.getmembers(module, inspect.isclass): - if inspect.isclass(klass) \ - and issubclass(klass, TraitBase) \ - and str(klass.id).startswith(trait_id): - trait_candidates.add(klass) - return trait_candidates - - @classmethod - @lru_cache(maxsize=64) - def _get_trait_class( - cls, trait_id: str) -> Union[Type[TraitBase], None]: - """Get the trait class with corresponding to given ID. - - This method will search for the trait class in all the modules except - the blacklisted modules. There is some issue in Pydantic where - ``issubclass`` is not working properly so we are excluding explicitly - modules with offending classes. This list can be updated as needed to - speed up the search. - - Args: - trait_id (str): Trait ID. - - Returns: - Type[TraitBase]: Trait class. - - Raises: - LooseMatchingTraitError: If the trait is found with a loose - matching criteria. This exception will include the trait - class that was found and the expected trait ID. Additional - downstream logic must decide how to handle this error. - - """ - version = cls._get_version_from_id(trait_id) - - trait_candidates = cls._get_possible_trait_classes_from_modules( - trait_id - ) - - for trait_class in trait_candidates: - if trait_class.id == trait_id: - # we found direct match - return trait_class - - # if we didn't find direct match, we will search for the highest - # version of the trait. - if not version: - # sourcery skip: use-named-expression - trait_versions = [ - trait_class for trait_class in trait_candidates - if re.match( - rf"{trait_id}.v(\d+)$", str(trait_class.id)) - ] - if trait_versions: - def _get_version_by_id(trait_klass: Type[TraitBase]) -> int: - match = re.search(r"v(\d+)$", str(trait_klass.id)) - return int(match[1]) if match else 0 - - error = LooseMatchingTraitError( - "Found trait that might match.") - error.found_trait = max( - trait_versions, key=_get_version_by_id) - error.expected_id = trait_id - raise error - - return None - - @classmethod - def get_trait_class_by_trait_id(cls, trait_id: str) -> type[TraitBase]: - """Get the trait class for the given trait ID. - - Args: - trait_id (str): Trait ID. - - Returns: - type[TraitBase]: Trait class. - - Raises: - IncompatibleTraitVersionError: If the trait version is incompatible - with the current version of the trait. - UpgradableTraitError: If the trait can upgrade existing data - meant for older versions of the trait. - ValueError: If the trait model with the given ID is not found. - - """ - trait_class = None - try: - trait_class = cls._get_trait_class(trait_id=trait_id) - except LooseMatchingTraitError as e: - requested_version = _get_version_from_id(trait_id) - found_version = _get_version_from_id(e.found_trait.id) - - if not requested_version: - trait_class = e.found_trait - - else: - if requested_version > found_version: - error_msg = ( - f"Requested trait version {requested_version} is " - f"higher than the found trait version {found_version}." - ) - raise IncompatibleTraitVersionError(error_msg) from e - - if requested_version < found_version and hasattr( - e.found_trait, "upgrade"): - error_msg = ( - "Requested trait version " - f"{requested_version} is lower " - f"than the found trait version {found_version}." - ) - error = UpgradableTraitError(error_msg) - error.trait = e.found_trait - raise error from e - return trait_class - - @classmethod - def from_dict( - cls, - name: str, - representation_id: Optional[str]=None, - trait_data: Optional[dict] = None) -> Representation: - """Create a representation from a dictionary. - - Args: - name (str): Representation name. - representation_id (str, optional): Representation ID. - trait_data (dict): Representation data. Dictionary with keys - as trait ids and values as trait data. Example:: - - { - "ayon.2d.PixelBased.v1": { - "display_window_width": 1920, - "display_window_height": 1080 - }, - "ayon.2d.Planar.v1": { - "channels": 3 - } - } - - Returns: - Representation: Representation instance. - - """ - traits = [] - for trait_id, value in trait_data.items(): - if not isinstance(value, dict): - msg = ( - f"Invalid trait data for trait ID {trait_id}. " - "Trait data must be a dictionary." - ) - raise TypeError(msg) - - try: - trait_class = cls.get_trait_class_by_trait_id(trait_id) - except UpgradableTraitError as e: - # we found newer version of trait, we will upgrade the data - if hasattr(e.trait, "upgrade"): - traits.append(e.trait.upgrade(value)) - else: - msg = ( - f"Newer version of trait {e.trait.id} found " - f"for requested {trait_id} but without " - "upgrade method." - ) - raise IncompatibleTraitVersionError(msg) from e - else: - if not trait_class: - error_msg = f"Trait model with ID {trait_id} not found." - raise ValueError(error_msg) - - traits.append(trait_class(**value)) - - return cls( - name=name, representation_id=representation_id, traits=traits) - - - def validate(self) -> bool: - """Validate the representation. - - This method will validate all the traits in the representation. - - Returns: - bool: True if the representation is valid, False otherwise. - - """ - return all(trait.validate(self) for trait in self._data.values()) - class IncompatibleTraitVersionError(Exception): diff --git a/client/ayon_core/pipeline/traits/utils.py b/client/ayon_core/pipeline/traits/utils.py index cd75443cd0..54386fe8ca 100644 --- a/client/ayon_core/pipeline/traits/utils.py +++ b/client/ayon_core/pipeline/traits/utils.py @@ -5,13 +5,13 @@ from typing import TYPE_CHECKING from clique import assemble -from ayon_core.pipeline.traits import Sequence +from ayon_core.pipeline.traits.time import FrameRanged if TYPE_CHECKING: from pathlib import Path -def get_sequence_from_files(paths: list[Path]) -> Sequence: +def get_sequence_from_files(paths: list[Path]) -> FrameRanged: """Get original frame range from files. Note that this cannot guess frame rate, so it's set to 25. @@ -20,7 +20,7 @@ def get_sequence_from_files(paths: list[Path]) -> Sequence: paths (list[Path]): List of file paths. Returns: - Sequence: Sequence trait. + FrameRanged: FrameRanged trait. """ col = assemble([path.as_posix() for path in paths])[0][0] @@ -30,9 +30,9 @@ def get_sequence_from_files(paths: list[Path]) -> Sequence: # Get last frame for padding last_frame = sorted_frames[-1] # Use padding from collection of length of last frame as string - padding = max(col.padding, len(str(last_frame))) + # padding = max(col.padding, len(str(last_frame))) - return Sequence( - frame_start=first_frame, frame_end=last_frame, frame_padding=padding, - frames_per_second=25 + return FrameRanged( + frame_start=first_frame, frame_end=last_frame, + frames_per_second="25.0" ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..d420712d8b --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests.""" diff --git a/tests/client/ayon_core/pipeline/traits/test_traits.py b/tests/client/ayon_core/pipeline/traits/test_traits.py index 7b183a1104..f72a1169c0 100644 --- a/tests/client/ayon_core/pipeline/traits/test_traits.py +++ b/tests/client/ayon_core/pipeline/traits/test_traits.py @@ -8,15 +8,17 @@ from ayon_core.pipeline.traits import ( Bundle, FileLocation, FileLocations, + FrameRanged, Image, MimeType, + Overscan, PixelBased, Planar, Representation, Sequence, TraitBase, ) -from pipeline.traits import Overscan +from ayon_core.pipeline.traits.trait import TraitValidationError REPRESENTATION_DATA = { FileLocation.id: { @@ -340,7 +342,7 @@ def test_file_locations_validation() -> None: file_size=1024, file_hash=None, ) - for frame in range(1001, 1050) + for frame in range(1001, 1051) ] representation = Representation(name="test", traits=[ @@ -351,39 +353,40 @@ def test_file_locations_validation() -> None: file_paths=file_locations_list) # this should be valid trait - assert file_locations_trait.validate(representation) is True + file_locations_trait.validate(representation) - # add valid sequence trait - sequence_trait = Sequence( + # add valid FrameRanged trait + sequence_trait = FrameRanged( frame_start=1001, frame_end=1050, frame_padding=4, - frames_per_second=25 + frames_per_second="25" ) representation.add_trait(sequence_trait) # it should still validate fine - assert file_locations_trait.validate(representation) is True + file_locations_trait.validate(representation) # create empty file locations trait empty_file_locations_trait = FileLocations(file_paths=[]) representation = Representation(name="test", traits=[ empty_file_locations_trait ]) - assert empty_file_locations_trait.validate( - representation) is False + with pytest.raises(TraitValidationError): + empty_file_locations_trait.validate(representation) # create valid file locations trait but with not matching sequence # trait representation = Representation(name="test", traits=[ FileLocations(file_paths=file_locations_list) ]) - invalid_sequence_trait = Sequence( + invalid_sequence_trait = FrameRanged( frame_start=1001, frame_end=1051, frame_padding=4, - frames_per_second=25 + frames_per_second="25" ) representation.add_trait(invalid_sequence_trait) - assert file_locations_trait.validate(representation) is False + with pytest.raises(TraitValidationError): + file_locations_trait.validate(representation) diff --git a/tests/conftest.py b/tests/conftest.py index a3c46a9dd7..33c29d13f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +"""conftest.py: pytest configuration file.""" import sys from pathlib import Path @@ -5,5 +6,3 @@ client_path = Path(__file__).resolve().parent.parent / "client" # add client path to sys.path sys.path.append(str(client_path)) - -print(f"Added {client_path} to sys.path")