🎨 added id and name to representation

also added versionless trait id processing and trait validation
This commit is contained in:
Ondřej Samohel 2024-10-25 17:12:18 +02:00
parent 4b3469c5ae
commit edefade158
No known key found for this signature in database
GPG key ID: 02376E18990A97C6
5 changed files with 490 additions and 50 deletions

View file

@ -113,3 +113,23 @@ class Bundle(TraitBase):
def to_representation(self) -> Representation:
"""Convert to a representation."""
return Representation(traits=self.items)
class Fragment(TraitBase):
"""Fragment trait model.
This model represents a fragment trait. A fragment is a part of
a larger entity that is represented by a representation.
Attributes:
name (str): Trait name.
description (str): Trait description.
id (str): id should be namespaced trait name with version
parent (str): Parent representation id.
"""
name: ClassVar[str] = "Fragment"
description: ClassVar[str] = "Fragment Trait"
id: ClassVar[str] = "ayon.content.Fragment.v1"
parent: str = Field(..., title="Parent Representation Id")

View file

@ -2,15 +2,33 @@
from __future__ import annotations
import inspect
import re
import sys
import uuid
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import lru_cache
from functools import cached_property, lru_cache
from typing import ClassVar, Optional, Type, Union
import pydantic.alias_generators
from pydantic import AliasGenerator, BaseModel, ConfigDict
from pydantic import (
AliasGenerator,
BaseModel,
ConfigDict,
)
def _get_version_from_id(_id: str) -> int:
"""Get version from ID.
Args:
_id (str): ID.
Returns:
int: Version.
"""
match = re.search(r"v(\d+)$", _id)
return int(match[1]) if match else None
class TraitBase(ABC, BaseModel):
@ -47,6 +65,35 @@ class TraitBase(ABC, BaseModel):
"""Abstract attribute for description."""
...
@property
@cached_property
def version(self) -> Union[int, None]:
# sourcery skip: use-named-expression
"""Get trait version from ID.
This assumes Trait ID ends with `.v{version}`. If not, it will
return None.
"""
version_regex = r"v(\d+)$"
match = re.search(version_regex, self.id)
return int(match[1]) if match else None
def validate(self, representation: Representation) -> bool:
"""Validate the trait.
This method should be implemented in the derived classes to validate
the trait data. It can be used by traits to validate against other
traits in the representation.
Args:
representation (Representation): Representation instance.
Raises:
ValueError: If the trait is invalid within representation.
"""
return True
class Representation:
@ -70,40 +117,9 @@ class Representation:
name: str
representation_id: str
@lru_cache(maxsize=64) # noqa: B019
def _get_trait_class(self, trait_id: str) -> Union[Type[TraitBase], None]:
"""Get the trait class with corresponding to given ID.
This method will search for the trait class in all the modules except
the blacklisted modules. There is some issue in Pydantic where
``issubclass`` is not working properly so we are excluding explicitly
modules with offending classes. This list can be updated as needed to
speed up the search.
Args:
trait_id (str): Trait ID.
Returns:
Type[TraitBase]: Trait class.
"""
modules = sys.modules.copy()
filtered_modules = modules.copy()
for module_name in modules:
for bl_module in self._module_blacklist:
if module_name.startswith(bl_module):
filtered_modules.pop(module_name)
for module in filtered_modules.values():
if not module:
continue
for _, klass in inspect.getmembers(module, inspect.isclass):
if inspect.isclass(klass) and \
issubclass(klass, TraitBase) and \
klass.id == trait_id:
return klass
return None
def __hash__(self):
"""Return hash of the representation ID."""
return hash(self.representation_id)
def add_trait(self, trait: TraitBase, *, exists_ok: bool=False) -> None:
"""Add a trait to the Representation.
@ -275,6 +291,7 @@ class Representation:
return self._data[trait.id] if self._data.get(trait.id) else None
def get_trait_by_id(self, trait_id: str) -> Union[TraitBase, None]:
# sourcery skip: use-named-expression
"""Get a trait from the representation by id.
Args:
@ -284,12 +301,18 @@ class Representation:
TraitBase: Trait instance.
"""
trait_class = self._get_trait_class(trait_id)
if not trait_class:
error_msg = f"Trait model with ID {trait_id} not found."
raise ValueError(error_msg)
version = _get_version_from_id(trait_id)
if version:
return self._data.get(trait_id)
return self._data[trait_id] if self._data.get(trait_id) else None
return next(
(
self._data.get(trait_id)
for trait_id in self._data
if trait_id.startswith(trait_id)
),
None,
)
def get_traits(self,
traits: Optional[list[Type[TraitBase]]]=None) -> dict:
@ -338,13 +361,11 @@ class Representation:
dict: Traits data dictionary.
"""
result = OrderedDict()
for trait_id, trait in self._data.items():
if not trait or not trait_id:
continue
result[trait_id] = OrderedDict(trait.dict())
return result
return {
trait_id: trait.dict()
for trait_id, trait in self._data.items()
if trait and trait_id
}
def __len__(self):
"""Return the length of the data."""
@ -368,3 +389,287 @@ class Representation:
if traits:
for trait in traits:
self.add_trait(trait)
@staticmethod
def _get_version_from_id(trait_id: str) -> Union[int, None]:
# sourcery skip: use-named-expression
"""Check if the trait has version specified.
Args:
trait_id (str): Trait ID.
Returns:
int: Trait version.
None: If the trait id does not have a version.
"""
version_regex = r"v(\d+)$"
match = re.search(version_regex, trait_id)
return int(match[1]) if match else None
def __eq__(self, other: Representation) -> bool: # noqa: PLR0911
"""Check if the representation is equal to another.
Args:
other (Representation): Representation to compare.
Returns:
bool: True if the representations are equal, False otherwise.
"""
if not isinstance(other, Representation):
return False
if self.name != other.name:
return False
# number of traits
if len(self) != len(other):
return False
for trait_id, trait in self._data.items():
if trait_id not in other._data:
return False
if trait != other._data[trait_id]:
return False
for key, value in trait.dict().items():
if value != other._data[trait_id].dict().get(key):
return False
return True
@classmethod
@lru_cache(maxsize=64)
def _get_possible_trait_classes_from_modules(
cls,
trait_id: str) -> set[type[TraitBase]]:
"""Get possible trait classes from modules.
Args:
trait_id (str): Trait ID.
Returns:
set[type[TraitBase]]: Set of trait classes.
"""
modules = sys.modules.copy()
filtered_modules = modules.copy()
for module_name in modules:
for bl_module in cls._module_blacklist:
if module_name.startswith(bl_module):
filtered_modules.pop(module_name)
trait_candidates = set()
for module in filtered_modules.values():
if not module:
continue
for _, klass in inspect.getmembers(module, inspect.isclass):
if inspect.isclass(klass) \
and issubclass(klass, TraitBase) \
and str(klass.id).startswith(trait_id):
trait_candidates.add(klass)
return trait_candidates
@classmethod
@lru_cache(maxsize=64)
def _get_trait_class(
cls, trait_id: str) -> Union[Type[TraitBase], None]:
"""Get the trait class with corresponding to given ID.
This method will search for the trait class in all the modules except
the blacklisted modules. There is some issue in Pydantic where
``issubclass`` is not working properly so we are excluding explicitly
modules with offending classes. This list can be updated as needed to
speed up the search.
Args:
trait_id (str): Trait ID.
Returns:
Type[TraitBase]: Trait class.
Raises:
LooseMatchingTraitError: If the trait is found with a loose
matching criteria. This exception will include the trait
class that was found and the expected trait ID. Additional
downstream logic must decide how to handle this error.
"""
version = cls._get_version_from_id(trait_id)
trait_candidates = cls._get_possible_trait_classes_from_modules(
trait_id
)
for trait_class in trait_candidates:
if trait_class.id == trait_id:
# we found direct match
return trait_class
# if we didn't find direct match, we will search for the highest
# version of the trait.
if not version:
# sourcery skip: use-named-expression
trait_versions = [
trait_class for trait_class in trait_candidates
if re.match(
rf"{trait_id}.v(\d+)$", str(trait_class.id))
]
if trait_versions:
def _get_version_by_id(trait_klass: Type[TraitBase]) -> int:
match = re.search(r"v(\d+)$", str(trait_klass.id))
return int(match[1]) if match else 0
error = LooseMatchingTraitError(
"Found trait that might match.")
error.found_trait = max(
trait_versions, key=_get_version_by_id)
error.expected_id = trait_id
raise error
return None
@classmethod
def get_trait_class_by_trait_id(cls, trait_id: str) -> type[TraitBase]:
"""Get the trait class for the given trait ID.
Args:
trait_id (str): Trait ID.
Returns:
type[TraitBase]: Trait class.
Raises:
IncompatibleTraitVersionError: If the trait version is incompatible
with the current version of the trait.
UpgradableTraitError: If the trait can upgrade existing data
meant for older versions of the trait.
ValueError: If the trait model with the given ID is not found.
"""
trait_class = None
try:
trait_class = cls._get_trait_class(trait_id=trait_id)
except LooseMatchingTraitError as e:
requested_version = _get_version_from_id(trait_id)
found_version = _get_version_from_id(e.found_trait.id)
if not requested_version:
trait_class = e.found_trait
else:
if requested_version > found_version:
error_msg = (
f"Requested trait version {requested_version} is "
f"higher than the found trait version {found_version}."
)
raise IncompatibleTraitVersionError(error_msg) from e
if requested_version < found_version and hasattr(
e.found_trait, "upgrade"):
error_msg = (
"Requested trait version "
f"{requested_version} is lower "
f"than the found trait version {found_version}."
)
error = UpgradableTraitError(error_msg)
error.trait = e.found_trait
raise error from e
return trait_class
@classmethod
def from_dict(
cls,
name: str,
representation_id: Optional[str]=None,
trait_data: Optional[dict] = None) -> Representation:
"""Create a representation from a dictionary.
Args:
name (str): Representation name.
representation_id (str, optional): Representation ID.
trait_data (dict): Representation data. Dictionary with keys
as trait ids and values as trait data. Example::
{
"ayon.2d.PixelBased.v1": {
"display_window_width": 1920,
"display_window_height": 1080
},
"ayon.2d.Planar.v1": {
"channels": 3
}
}
Returns:
Representation: Representation instance.
"""
traits = []
for trait_id, value in trait_data.items():
if not isinstance(value, dict):
msg = (
f"Invalid trait data for trait ID {trait_id}. "
"Trait data must be a dictionary."
)
raise TypeError(msg)
try:
trait_class = cls.get_trait_class_by_trait_id(trait_id)
except UpgradableTraitError as e:
# we found newer version of trait, we will upgrade the data
if hasattr(e.trait, "upgrade"):
traits.append(e.trait.upgrade(value))
else:
msg = (
f"Newer version of trait {e.trait.id} found "
f"for requested {trait_id} but without "
"upgrade method."
)
raise IncompatibleTraitVersionError(msg) from e
else:
if not trait_class:
error_msg = f"Trait model with ID {trait_id} not found."
raise ValueError(error_msg)
traits.append(trait_class(**value))
return cls(
name=name, representation_id=representation_id, traits=traits)
class IncompatibleTraitVersionError(Exception):
"""Incompatible trait version exception.
This exception is raised when the trait version is incompatible with the
current version of the trait.
"""
class UpgradableTraitError(Exception):
"""Upgradable trait version exception.
This exception is raised when the trait can upgrade existing data
meant for older versions of the trait. It must implement `upgrade`
method that will take old trait data as argument to handle the upgrade.
"""
trait: TraitBase
old_data: dict
class LooseMatchingTraitError(Exception):
"""Loose matching trait exception.
This exception is raised when the trait is found with a loose matching
criteria.
"""
found_trait: TraitBase
expected_id: str
class TraitValidationError(Exception):
"""Trait validation error exception.
This exception is raised when the trait validation fails.
"""

View file

@ -106,7 +106,7 @@ exclude = [
[tool.ruff.lint.per-file-ignores]
"client/ayon_core/lib/__init__.py" = ["E402"]
"tests/*.py" = ["S101"]
"tests/*.py" = ["S101", "PLR2004"] # allow asserts and magical values
[tool.ruff.format]
# Like Black, use double quotes for strings.

View file

@ -0,0 +1,25 @@
"""Metadata traits."""
from typing import ClassVar
from ayon_core.pipeline.traits import TraitBase
class NewTestTrait(TraitBase):
"""New Test trait model.
This model represents a tagged trait.
Attributes:
name (str): Trait name.
description (str): Trait description.
id (str): id should be namespaced trait name with version
"""
name: ClassVar[str] = "New Test Trait"
description: ClassVar[str] = (
"This test trait is used for testing updating."
)
id: ClassVar[str] = "ayon.test.NewTestTrait.v999"
__all__ = ["NewTestTrait"]

View file

@ -14,6 +14,7 @@ from ayon_core.pipeline.traits import (
Representation,
TraitBase,
)
from pipeline.traits import Overscan
REPRESENTATION_DATA = {
FileLocation.id: {
@ -32,6 +33,15 @@ REPRESENTATION_DATA = {
},
}
class UpgradedImage(Image):
"""Upgraded image class."""
id = "ayon.2d.Image.v2"
@classmethod
def upgrade(cls, data: dict) -> UpgradedImage: # noqa: ARG003
"""Upgrade the trait."""
return cls()
class InvalidTrait:
"""Invalid trait class."""
foo = "bar"
@ -62,6 +72,8 @@ def test_representation_errors(representation: Representation) -> None:
def test_representation_traits(representation: Representation) -> None:
"""Test setting and getting traits."""
assert representation.get_trait_by_id("ayon.2d.PixelBased").version == 1
assert len(representation) == len(REPRESENTATION_DATA)
assert representation.get_trait_by_id(FileLocation.id)
assert representation.get_trait_by_id(Image.id)
@ -152,6 +164,7 @@ def test_trait_removing(representation: Representation) -> None:
representation.remove_trait(Image)
def test_getting_traits_data(representation: Representation) -> None:
"""Test getting a batch of traits."""
result = representation.get_traits_by_ids(
@ -218,3 +231,80 @@ def test_bundles() -> None:
assert sub_representation.get_trait(trait=MimeType).mime_type in [
"image/jpeg", "image/tiff"
]
def test_get_version_from_id() -> None:
"""Test getting version from trait ID."""
assert Image().version == 1
class TestOverscan(Overscan):
id = "ayon.2d.Overscan.v2"
assert TestOverscan(
left=0,
right=0,
top=0,
bottom=0
).version == 2
class TestMimeType(MimeType):
id = "ayon.content.MimeType"
assert TestMimeType(mime_type="foo/bar").version is None
def test_from_dict() -> None:
"""Test creating representation from dictionary."""
traits_data = {
"ayon.content.FileLocation.v1": {
"file_path": "/path/to/file",
"file_size": 1024,
"file_hash": None,
},
"ayon.2d.Image.v1": {},
}
representation = Representation.from_dict(
"test", trait_data=traits_data)
assert len(representation) == 2
assert representation.get_trait_by_id("ayon.content.FileLocation.v1")
assert representation.get_trait_by_id("ayon.2d.Image.v1")
traits_data = {
"ayon.content.FileLocation.v999": {
"file_path": "/path/to/file",
"file_size": 1024,
"file_hash": None,
},
}
with pytest.raises(ValueError, match="Trait model with ID .* not found."):
representation = Representation.from_dict(
"test", trait_data=traits_data)
traits_data = {
"ayon.content.FileLocation": {
"file_path": "/path/to/file",
"file_size": 1024,
"file_hash": None,
},
}
representation = Representation.from_dict(
"test", trait_data=traits_data)
assert len(representation) == 1
assert representation.get_trait_by_id("ayon.content.FileLocation.v1")
# this won't work right now because we would need to somewhat mock
# the import
"""
from .lib import NewTestTrait
traits_data = {
"ayon.test.NewTestTrait.v1": {},
}
representation = Representation.from_dict(
"test", trait_data=traits_data)
"""