diff --git a/client/ayon_core/pipeline/traits/content.py b/client/ayon_core/pipeline/traits/content.py index b2bf62805f..480d1657f5 100644 --- a/client/ayon_core/pipeline/traits/content.py +++ b/client/ayon_core/pipeline/traits/content.py @@ -113,3 +113,23 @@ class Bundle(TraitBase): def to_representation(self) -> Representation: """Convert to a representation.""" return Representation(traits=self.items) + + +class Fragment(TraitBase): + """Fragment trait model. + + This model represents a fragment trait. A fragment is a part of + a larger entity that is represented by a representation. + + Attributes: + name (str): Trait name. + description (str): Trait description. + id (str): id should be namespaced trait name with version + parent (str): Parent representation id. + + """ + + name: ClassVar[str] = "Fragment" + description: ClassVar[str] = "Fragment Trait" + id: ClassVar[str] = "ayon.content.Fragment.v1" + parent: str = Field(..., title="Parent Representation Id") diff --git a/client/ayon_core/pipeline/traits/trait.py b/client/ayon_core/pipeline/traits/trait.py index 90be38225b..b97ec08cc6 100644 --- a/client/ayon_core/pipeline/traits/trait.py +++ b/client/ayon_core/pipeline/traits/trait.py @@ -2,15 +2,33 @@ from __future__ import annotations import inspect +import re import sys import uuid from abc import ABC, abstractmethod -from collections import OrderedDict -from functools import lru_cache +from functools import cached_property, lru_cache from typing import ClassVar, Optional, Type, Union import pydantic.alias_generators -from pydantic import AliasGenerator, BaseModel, ConfigDict +from pydantic import ( + AliasGenerator, + BaseModel, + ConfigDict, +) + + +def _get_version_from_id(_id: str) -> int: + """Get version from ID. + + Args: + _id (str): ID. + + Returns: + int: Version. + + """ + match = re.search(r"v(\d+)$", _id) + return int(match[1]) if match else None class TraitBase(ABC, BaseModel): @@ -47,6 +65,35 @@ class TraitBase(ABC, BaseModel): """Abstract attribute for description.""" ... + @property + @cached_property + def version(self) -> Union[int, None]: + # sourcery skip: use-named-expression + """Get trait version from ID. + + This assumes Trait ID ends with `.v{version}`. If not, it will + return None. + + """ + version_regex = r"v(\d+)$" + match = re.search(version_regex, self.id) + return int(match[1]) if match else None + + def validate(self, representation: Representation) -> bool: + """Validate the trait. + + This method should be implemented in the derived classes to validate + the trait data. It can be used by traits to validate against other + traits in the representation. + + Args: + representation (Representation): Representation instance. + + Raises: + ValueError: If the trait is invalid within representation. + + """ + return True class Representation: @@ -70,40 +117,9 @@ class Representation: name: str representation_id: str - @lru_cache(maxsize=64) # noqa: B019 - def _get_trait_class(self, trait_id: str) -> Union[Type[TraitBase], None]: - """Get the trait class with corresponding to given ID. - - This method will search for the trait class in all the modules except - the blacklisted modules. There is some issue in Pydantic where - ``issubclass`` is not working properly so we are excluding explicitly - modules with offending classes. This list can be updated as needed to - speed up the search. - - Args: - trait_id (str): Trait ID. - - Returns: - Type[TraitBase]: Trait class. - - """ - modules = sys.modules.copy() - filtered_modules = modules.copy() - for module_name in modules: - for bl_module in self._module_blacklist: - if module_name.startswith(bl_module): - filtered_modules.pop(module_name) - - for module in filtered_modules.values(): - if not module: - continue - for _, klass in inspect.getmembers(module, inspect.isclass): - if inspect.isclass(klass) and \ - issubclass(klass, TraitBase) and \ - klass.id == trait_id: - return klass - return None - + def __hash__(self): + """Return hash of the representation ID.""" + return hash(self.representation_id) def add_trait(self, trait: TraitBase, *, exists_ok: bool=False) -> None: """Add a trait to the Representation. @@ -275,6 +291,7 @@ class Representation: return self._data[trait.id] if self._data.get(trait.id) else None def get_trait_by_id(self, trait_id: str) -> Union[TraitBase, None]: + # sourcery skip: use-named-expression """Get a trait from the representation by id. Args: @@ -284,12 +301,18 @@ class Representation: TraitBase: Trait instance. """ - trait_class = self._get_trait_class(trait_id) - if not trait_class: - error_msg = f"Trait model with ID {trait_id} not found." - raise ValueError(error_msg) + version = _get_version_from_id(trait_id) + if version: + return self._data.get(trait_id) - return self._data[trait_id] if self._data.get(trait_id) else None + return next( + ( + self._data.get(trait_id) + for trait_id in self._data + if trait_id.startswith(trait_id) + ), + None, + ) def get_traits(self, traits: Optional[list[Type[TraitBase]]]=None) -> dict: @@ -338,13 +361,11 @@ class Representation: dict: Traits data dictionary. """ - result = OrderedDict() - for trait_id, trait in self._data.items(): - if not trait or not trait_id: - continue - result[trait_id] = OrderedDict(trait.dict()) - - return result + return { + trait_id: trait.dict() + for trait_id, trait in self._data.items() + if trait and trait_id + } def __len__(self): """Return the length of the data.""" @@ -368,3 +389,287 @@ class Representation: if traits: for trait in traits: self.add_trait(trait) + + @staticmethod + def _get_version_from_id(trait_id: str) -> Union[int, None]: + # sourcery skip: use-named-expression + """Check if the trait has version specified. + + Args: + trait_id (str): Trait ID. + + Returns: + int: Trait version. + None: If the trait id does not have a version. + + """ + version_regex = r"v(\d+)$" + match = re.search(version_regex, trait_id) + return int(match[1]) if match else None + + def __eq__(self, other: Representation) -> bool: # noqa: PLR0911 + """Check if the representation is equal to another. + + Args: + other (Representation): Representation to compare. + + Returns: + bool: True if the representations are equal, False otherwise. + + """ + if not isinstance(other, Representation): + return False + + if self.name != other.name: + return False + + # number of traits + if len(self) != len(other): + return False + + for trait_id, trait in self._data.items(): + if trait_id not in other._data: + return False + if trait != other._data[trait_id]: + return False + for key, value in trait.dict().items(): + if value != other._data[trait_id].dict().get(key): + return False + + return True + + @classmethod + @lru_cache(maxsize=64) + def _get_possible_trait_classes_from_modules( + cls, + trait_id: str) -> set[type[TraitBase]]: + """Get possible trait classes from modules. + + Args: + trait_id (str): Trait ID. + + Returns: + set[type[TraitBase]]: Set of trait classes. + + """ + modules = sys.modules.copy() + filtered_modules = modules.copy() + for module_name in modules: + for bl_module in cls._module_blacklist: + if module_name.startswith(bl_module): + filtered_modules.pop(module_name) + + trait_candidates = set() + for module in filtered_modules.values(): + if not module: + continue + for _, klass in inspect.getmembers(module, inspect.isclass): + if inspect.isclass(klass) \ + and issubclass(klass, TraitBase) \ + and str(klass.id).startswith(trait_id): + trait_candidates.add(klass) + return trait_candidates + + @classmethod + @lru_cache(maxsize=64) + def _get_trait_class( + cls, trait_id: str) -> Union[Type[TraitBase], None]: + """Get the trait class with corresponding to given ID. + + This method will search for the trait class in all the modules except + the blacklisted modules. There is some issue in Pydantic where + ``issubclass`` is not working properly so we are excluding explicitly + modules with offending classes. This list can be updated as needed to + speed up the search. + + Args: + trait_id (str): Trait ID. + + Returns: + Type[TraitBase]: Trait class. + + Raises: + LooseMatchingTraitError: If the trait is found with a loose + matching criteria. This exception will include the trait + class that was found and the expected trait ID. Additional + downstream logic must decide how to handle this error. + + """ + version = cls._get_version_from_id(trait_id) + + trait_candidates = cls._get_possible_trait_classes_from_modules( + trait_id + ) + + for trait_class in trait_candidates: + if trait_class.id == trait_id: + # we found direct match + return trait_class + + # if we didn't find direct match, we will search for the highest + # version of the trait. + if not version: + # sourcery skip: use-named-expression + trait_versions = [ + trait_class for trait_class in trait_candidates + if re.match( + rf"{trait_id}.v(\d+)$", str(trait_class.id)) + ] + if trait_versions: + def _get_version_by_id(trait_klass: Type[TraitBase]) -> int: + match = re.search(r"v(\d+)$", str(trait_klass.id)) + return int(match[1]) if match else 0 + + error = LooseMatchingTraitError( + "Found trait that might match.") + error.found_trait = max( + trait_versions, key=_get_version_by_id) + error.expected_id = trait_id + raise error + + return None + + @classmethod + def get_trait_class_by_trait_id(cls, trait_id: str) -> type[TraitBase]: + """Get the trait class for the given trait ID. + + Args: + trait_id (str): Trait ID. + + Returns: + type[TraitBase]: Trait class. + + Raises: + IncompatibleTraitVersionError: If the trait version is incompatible + with the current version of the trait. + UpgradableTraitError: If the trait can upgrade existing data + meant for older versions of the trait. + ValueError: If the trait model with the given ID is not found. + + """ + trait_class = None + try: + trait_class = cls._get_trait_class(trait_id=trait_id) + except LooseMatchingTraitError as e: + requested_version = _get_version_from_id(trait_id) + found_version = _get_version_from_id(e.found_trait.id) + + if not requested_version: + trait_class = e.found_trait + + else: + if requested_version > found_version: + error_msg = ( + f"Requested trait version {requested_version} is " + f"higher than the found trait version {found_version}." + ) + raise IncompatibleTraitVersionError(error_msg) from e + + if requested_version < found_version and hasattr( + e.found_trait, "upgrade"): + error_msg = ( + "Requested trait version " + f"{requested_version} is lower " + f"than the found trait version {found_version}." + ) + error = UpgradableTraitError(error_msg) + error.trait = e.found_trait + raise error from e + return trait_class + + @classmethod + def from_dict( + cls, + name: str, + representation_id: Optional[str]=None, + trait_data: Optional[dict] = None) -> Representation: + """Create a representation from a dictionary. + + Args: + name (str): Representation name. + representation_id (str, optional): Representation ID. + trait_data (dict): Representation data. Dictionary with keys + as trait ids and values as trait data. Example:: + + { + "ayon.2d.PixelBased.v1": { + "display_window_width": 1920, + "display_window_height": 1080 + }, + "ayon.2d.Planar.v1": { + "channels": 3 + } + } + + Returns: + Representation: Representation instance. + + """ + traits = [] + for trait_id, value in trait_data.items(): + if not isinstance(value, dict): + msg = ( + f"Invalid trait data for trait ID {trait_id}. " + "Trait data must be a dictionary." + ) + raise TypeError(msg) + + try: + trait_class = cls.get_trait_class_by_trait_id(trait_id) + except UpgradableTraitError as e: + # we found newer version of trait, we will upgrade the data + if hasattr(e.trait, "upgrade"): + traits.append(e.trait.upgrade(value)) + else: + msg = ( + f"Newer version of trait {e.trait.id} found " + f"for requested {trait_id} but without " + "upgrade method." + ) + raise IncompatibleTraitVersionError(msg) from e + else: + if not trait_class: + error_msg = f"Trait model with ID {trait_id} not found." + raise ValueError(error_msg) + + traits.append(trait_class(**value)) + + return cls( + name=name, representation_id=representation_id, traits=traits) + + + +class IncompatibleTraitVersionError(Exception): + """Incompatible trait version exception. + + This exception is raised when the trait version is incompatible with the + current version of the trait. + """ + + +class UpgradableTraitError(Exception): + """Upgradable trait version exception. + + This exception is raised when the trait can upgrade existing data + meant for older versions of the trait. It must implement `upgrade` + method that will take old trait data as argument to handle the upgrade. + """ + + trait: TraitBase + old_data: dict + +class LooseMatchingTraitError(Exception): + """Loose matching trait exception. + + This exception is raised when the trait is found with a loose matching + criteria. + """ + + found_trait: TraitBase + expected_id: str + +class TraitValidationError(Exception): + """Trait validation error exception. + + This exception is raised when the trait validation fails. + """ diff --git a/pyproject.toml b/pyproject.toml index 641faf2536..fdaec51a58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,7 @@ exclude = [ [tool.ruff.lint.per-file-ignores] "client/ayon_core/lib/__init__.py" = ["E402"] -"tests/*.py" = ["S101"] +"tests/*.py" = ["S101", "PLR2004"] # allow asserts and magical values [tool.ruff.format] # Like Black, use double quotes for strings. diff --git a/tests/client/ayon_core/pipeline/traits/lib/__init__.py b/tests/client/ayon_core/pipeline/traits/lib/__init__.py new file mode 100644 index 0000000000..d7ea7ae0ad --- /dev/null +++ b/tests/client/ayon_core/pipeline/traits/lib/__init__.py @@ -0,0 +1,25 @@ +"""Metadata traits.""" +from typing import ClassVar + +from ayon_core.pipeline.traits import TraitBase + + +class NewTestTrait(TraitBase): + """New Test trait model. + + This model represents a tagged trait. + + Attributes: + name (str): Trait name. + description (str): Trait description. + id (str): id should be namespaced trait name with version + """ + + name: ClassVar[str] = "New Test Trait" + description: ClassVar[str] = ( + "This test trait is used for testing updating." + ) + id: ClassVar[str] = "ayon.test.NewTestTrait.v999" + + +__all__ = ["NewTestTrait"] diff --git a/tests/client/ayon_core/pipeline/traits/test_traits.py b/tests/client/ayon_core/pipeline/traits/test_traits.py index fc38e8fb56..8a48d6eef8 100644 --- a/tests/client/ayon_core/pipeline/traits/test_traits.py +++ b/tests/client/ayon_core/pipeline/traits/test_traits.py @@ -14,6 +14,7 @@ from ayon_core.pipeline.traits import ( Representation, TraitBase, ) +from pipeline.traits import Overscan REPRESENTATION_DATA = { FileLocation.id: { @@ -32,6 +33,15 @@ REPRESENTATION_DATA = { }, } +class UpgradedImage(Image): + """Upgraded image class.""" + id = "ayon.2d.Image.v2" + + @classmethod + def upgrade(cls, data: dict) -> UpgradedImage: # noqa: ARG003 + """Upgrade the trait.""" + return cls() + class InvalidTrait: """Invalid trait class.""" foo = "bar" @@ -62,6 +72,8 @@ def test_representation_errors(representation: Representation) -> None: def test_representation_traits(representation: Representation) -> None: """Test setting and getting traits.""" + assert representation.get_trait_by_id("ayon.2d.PixelBased").version == 1 + assert len(representation) == len(REPRESENTATION_DATA) assert representation.get_trait_by_id(FileLocation.id) assert representation.get_trait_by_id(Image.id) @@ -152,6 +164,7 @@ def test_trait_removing(representation: Representation) -> None: representation.remove_trait(Image) + def test_getting_traits_data(representation: Representation) -> None: """Test getting a batch of traits.""" result = representation.get_traits_by_ids( @@ -218,3 +231,80 @@ def test_bundles() -> None: assert sub_representation.get_trait(trait=MimeType).mime_type in [ "image/jpeg", "image/tiff" ] + +def test_get_version_from_id() -> None: + """Test getting version from trait ID.""" + assert Image().version == 1 + + class TestOverscan(Overscan): + id = "ayon.2d.Overscan.v2" + + assert TestOverscan( + left=0, + right=0, + top=0, + bottom=0 + ).version == 2 + + class TestMimeType(MimeType): + id = "ayon.content.MimeType" + + assert TestMimeType(mime_type="foo/bar").version is None + + +def test_from_dict() -> None: + """Test creating representation from dictionary.""" + traits_data = { + "ayon.content.FileLocation.v1": { + "file_path": "/path/to/file", + "file_size": 1024, + "file_hash": None, + }, + "ayon.2d.Image.v1": {}, + } + + representation = Representation.from_dict( + "test", trait_data=traits_data) + + assert len(representation) == 2 + assert representation.get_trait_by_id("ayon.content.FileLocation.v1") + assert representation.get_trait_by_id("ayon.2d.Image.v1") + + traits_data = { + "ayon.content.FileLocation.v999": { + "file_path": "/path/to/file", + "file_size": 1024, + "file_hash": None, + }, + } + + with pytest.raises(ValueError, match="Trait model with ID .* not found."): + representation = Representation.from_dict( + "test", trait_data=traits_data) + + traits_data = { + "ayon.content.FileLocation": { + "file_path": "/path/to/file", + "file_size": 1024, + "file_hash": None, + }, + } + + representation = Representation.from_dict( + "test", trait_data=traits_data) + + assert len(representation) == 1 + assert representation.get_trait_by_id("ayon.content.FileLocation.v1") + + # this won't work right now because we would need to somewhat mock + # the import + """ + from .lib import NewTestTrait + + traits_data = { + "ayon.test.NewTestTrait.v1": {}, + } + + representation = Representation.from_dict( + "test", trait_data=traits_data) + """