♻️ refactor TraitsData to Representation

added few helper methods to query/set/remove bunch of traits at once
This commit is contained in:
Ondřej Samohel 2024-10-10 14:28:53 +02:00
parent 3981a2e4da
commit 6d07307898
No known key found for this signature in database
GPG key ID: 02376E18990A97C6
3 changed files with 161 additions and 55 deletions

View file

@ -1,7 +1,7 @@
"""Trait classes for the pipeline."""
from .content import Compressed, FileLocation, RootlessLocation
from .three_dimensional import Spatial
from .trait import TraitBase, TraitsData
from .trait import Representation, TraitBase
from .two_dimensional import (
Deep,
Image,
@ -13,7 +13,7 @@ from .two_dimensional import (
__all__ = [
# base
"TraitBase",
"TraitsData",
"Representation",
# content
"FileLocation",
"RootlessLocation",

View file

@ -1,4 +1,4 @@
"""Defines the base trait model."""
"""Defines the base trait model and representation."""
from __future__ import annotations
import inspect
@ -16,6 +16,9 @@ class TraitBase(ABC, BaseModel):
"""Base trait model.
This model must be used as a base for all trait models.
It is using Pydantic BaseModel for serialization and validation.
``id``, ``name``, and ``description`` are abstract attributes that must be
implemented in the derived classes.
"""
@ -45,10 +48,15 @@ class TraitBase(ABC, BaseModel):
class TraitsData:
"""Traits data container.
class Representation:
"""Representation of products.
This model represents the data of a trait.
Representation defines collection of individual properties that describe
the specific "form" of the product. Each property is represented by a
trait therefore the Representation is a collection of traits.
It holds methods to add, remove, get, and check for the existence of a
trait in the representation. It also provides a method to get all the
"""
_data: dict
@ -90,8 +98,8 @@ class TraitsData:
return None
def add(self, trait: TraitBase, *, exists_ok: bool=False) -> None:
"""Add a trait to the data.
def add_trait(self, trait: TraitBase, *, exists_ok: bool=False) -> None:
"""Add a trait to the Representation.
Args:
trait (TraitBase): Trait to add.
@ -111,9 +119,22 @@ class TraitsData:
raise ValueError(error_msg)
self._data[trait.id] = trait
def remove(self,
trait_id: Optional[str],
trait: Optional[Type[TraitBase]]) -> None:
def add_traits(
self, traits: list[TraitBase], *, exists_ok: bool=False) -> None:
"""Add a list of traits to the Representation.
Args:
traits (list[TraitBase]): List of traits to add.
exists_ok (bool, optional): If True, do not raise an error if the
trait already exists. Defaults to False.
"""
for trait in traits:
self.add_trait(trait, exists_ok=exists_ok)
def remove_trait(self,
trait_id: Optional[str]=None,
trait: Optional[Type[TraitBase]]=None) -> None:
"""Remove a trait from the data.
Args:
@ -126,6 +147,23 @@ class TraitsData:
elif trait:
self._data.pop(trait.id)
def remove_traits(self,
trait_ids: Optional[list[str]]=None,
traits: Optional[list[Type[TraitBase]]]=None) -> None:
"""Remove a list of traits from the Representation.
Args:
trait_ids (list[str], optional): List of trait IDs.
traits (list[TraitBase], optional): List of trait classes.
"""
if trait_ids:
for trait_id in trait_ids:
self.remove_trait(trait_id=trait_id)
elif traits:
for trait in traits:
self.remove_trait(trait=trait)
def has_trait(self,
trait_id: Optional[str]=None,
trait: Optional[Type[TraitBase]]=None) -> bool:
@ -143,10 +181,34 @@ class TraitsData:
trait_id = trait.id
return hasattr(self, trait_id)
def get(self,
trait_id: Optional[str]=None,
trait: Optional[Type[TraitBase]]=None) -> Union[TraitBase, None]:
"""Get a trait from the data.
def has_traits(self,
trait_ids: Optional[list[str]]=None,
traits: Optional[list[Type[TraitBase]]]=None) -> bool:
"""Check if the traits exist.
Args:
trait_ids (list[str], optional): List of trait IDs.
traits (list[TraitBase], optional): List of trait classes.
Returns:
bool: True if all traits exist, False otherwise.
"""
if trait_ids:
for trait_id in trait_ids:
if not self.has_trait(trait_id=trait_id):
return False
elif traits:
for trait in traits:
if not self.has_trait(trait=trait):
return False
return True
def get_trait(self,
trait_id: Optional[str]=None,
trait: Optional[Type[TraitBase]]=None
) -> Union[TraitBase, None]:
"""Get a trait from the representation.
Args:
trait_id (str, optional): Trait ID.
@ -173,11 +235,33 @@ class TraitsData:
return self._data[trait_id] if self._data.get(trait_id) else None
def as_dict(self) -> dict:
"""Return the data as a dictionary.
def get_traits(self,
trait_ids: Optional[list[str]]=None,
traits: Optional[list[Type[TraitBase]]]=None) -> dict:
"""Get a list of traits from the representation.
Args:
trait_ids (list[str], optional): List of trait IDs.
traits (list[TraitBase], optional): List of trait classes.
Returns:
dict: Dictionary of traits.
"""
result = {}
if trait_ids:
for trait_id in trait_ids:
result[trait_id] = self.get_trait(trait_id=trait_id)
elif traits:
for trait in traits:
result[trait.id] = self.get_trait(trait=trait)
return result
def traits_as_dict(self) -> dict:
"""Return the traits from Representation data as a dictionary.
Returns:
dict: Data dictionary.
dict: Traits data dictionary.
"""
result = OrderedDict()
@ -197,4 +281,4 @@ class TraitsData:
self._data = {}
if traits:
for trait in traits:
self.add(trait)
self.add_trait(trait)

View file

@ -9,11 +9,11 @@ from ayon_core.pipeline.traits import (
Image,
PixelBased,
Planar,
Representation,
TraitBase,
TraitsData,
)
TRAITS_DATA = {
REPRESENTATION_DATA = {
FileLocation.id: {
"file_path": Path("/path/to/file"),
"file_size": 1024,
@ -32,52 +32,74 @@ TRAITS_DATA = {
@pytest.fixture()
def traits_data() -> TraitsData:
def representation() -> Representation:
"""Return a traits data instance."""
return TraitsData(traits=[
FileLocation(**TRAITS_DATA[FileLocation.id]),
return Representation(traits=[
FileLocation(**REPRESENTATION_DATA[FileLocation.id]),
Image(),
PixelBased(**TRAITS_DATA[PixelBased.id]),
Planar(**TRAITS_DATA[Planar.id]),
PixelBased(**REPRESENTATION_DATA[PixelBased.id]),
Planar(**REPRESENTATION_DATA[Planar.id]),
])
def test_traits_data(traits_data: TraitsData) -> None:
def test_representation_traits(representation: Representation) -> None:
"""Test setting and getting traits."""
assert len(traits_data) == len(TRAITS_DATA)
assert traits_data.get(trait_id=FileLocation.id)
assert traits_data.get(trait_id=Image.id)
assert traits_data.get(trait_id=PixelBased.id)
assert traits_data.get(trait_id=Planar.id)
assert len(representation) == len(REPRESENTATION_DATA)
assert representation.get_trait(trait_id=FileLocation.id)
assert representation.get_trait(trait_id=Image.id)
assert representation.get_trait(trait_id=PixelBased.id)
assert representation.get_trait(trait_id=Planar.id)
assert traits_data.get(trait=FileLocation)
assert traits_data.get(trait=Image)
assert traits_data.get(trait=PixelBased)
assert traits_data.get(trait=Planar)
assert representation.get_trait(trait=FileLocation)
assert representation.get_trait(trait=Image)
assert representation.get_trait(trait=PixelBased)
assert representation.get_trait(trait=Planar)
assert issubclass(type(traits_data.get(trait=FileLocation)), TraitBase)
assert issubclass(
type(representation.get_trait(trait=FileLocation)), TraitBase)
assert traits_data.get(
trait=FileLocation) == traits_data.get(trait_id=FileLocation.id)
assert traits_data.get(
trait=Image) == traits_data.get(trait_id=Image.id)
assert traits_data.get(
trait=PixelBased) == traits_data.get(trait_id=PixelBased.id)
assert traits_data.get(
trait=Planar) == traits_data.get(trait_id=Planar.id)
assert representation.get_trait(
trait=FileLocation) == representation.get_trait(
trait_id=FileLocation.id)
assert representation.get_trait(
trait=Image) == representation.get_trait(
trait_id=Image.id)
assert representation.get_trait(
trait=PixelBased) == representation.get_trait(
trait_id=PixelBased.id)
assert representation.get_trait(
trait=Planar) == representation.get_trait(
trait_id=Planar.id)
assert traits_data.get(trait_id="ayon.2d.Image.v1")
assert traits_data.get(trait_id="ayon.2d.PixelBased.v1")
assert traits_data.get(trait_id="ayon.2d.Planar.v1")
assert representation.get_trait(trait_id="ayon.2d.Image.v1")
assert representation.get_trait(trait_id="ayon.2d.PixelBased.v1")
assert representation.get_trait(trait_id="ayon.2d.Planar.v1")
assert traits_data.get(
assert representation.get_trait(
trait_id="ayon.2d.PixelBased.v1").display_window_width == \
TRAITS_DATA[PixelBased.id]["display_window_width"]
assert traits_data.get(
REPRESENTATION_DATA[PixelBased.id]["display_window_width"]
assert representation.get_trait(
trait=PixelBased).display_window_height == \
TRAITS_DATA[PixelBased.id]["display_window_height"]
REPRESENTATION_DATA[PixelBased.id]["display_window_height"]
def test_getting_traits_data(representation: Representation) -> None:
"""Test getting a batch of traits."""
result = representation.get_traits(
trait_ids=[FileLocation.id, Image.id, PixelBased.id, Planar.id])
assert result == {
"ayon.2d.Image.v1": Image(),
"ayon.2d.PixelBased.v1": PixelBased(
display_window_width=1920,
display_window_height=1080,
pixel_aspect_ratio=1.0),
"ayon.2d.Planar.v1": Planar(planar_configuration="RGB"),
"ayon.content.FileLocation.v1": FileLocation(
file_path=Path("/path/to/file"),
file_size=1024,
file_hash=None)
}
def test_traits_data_to_dict(traits_data: TraitsData) -> None:
def test_traits_data_to_dict(representation: Representation) -> None:
"""Test converting traits data to dictionary."""
result = traits_data.as_dict()
assert result == TRAITS_DATA
result = representation.traits_as_dict()
assert result == REPRESENTATION_DATA