Merge remote-tracking branch 'origin/feature/909-define-basic-trait-type-using-dataclasses' into feature/911-new-traits-based-integrator

This commit is contained in:
Ondřej Samohel 2024-11-12 23:41:42 +01:00
commit 8d9d9de153
No known key found for this signature in database
GPG key ID: 02376E18990A97C6
9 changed files with 668 additions and 641 deletions

View file

@ -13,6 +13,7 @@ from .content import (
from .cryptography import DigitallySigned, GPGSigned
from .lifecycle import Persistent, Transient
from .meta import Tagged, TemplatePath
from .representation import Representation
from .three_dimensional import Geometry, IESProfile, Lighting, Shader, Spatial
from .time import (
FrameRanged,
@ -22,7 +23,7 @@ from .time import (
SMPTETimecode,
Static,
)
from .trait import MissingTraitError, Representation, TraitBase
from .trait import MissingTraitError, TraitBase
from .two_dimensional import (
UDIM,
Deep,

View file

@ -7,14 +7,14 @@ from typing import ClassVar, Optional
from pydantic import Field
from .time import Sequence
from .representation import Representation
from .time import FrameRanged
from .trait import (
MissingTraitError,
Representation,
TraitBase,
TraitValidationError,
get_sequence_from_files,
)
from .utils import get_sequence_from_files
class MimeType(TraitBase):
@ -125,26 +125,25 @@ class FileLocations(TraitBase):
msg = "No file locations defined (empty list)"
raise TraitValidationError(self.name, msg)
tmp_seq: Sequence = get_sequence_from_files(
tmp_frame_ranged: FrameRanged = get_sequence_from_files(
[f.file_path for f in self.file_paths])
if len(self.file_paths) != \
tmp_seq.frame_end - tmp_seq.frame_start:
if len(self.file_paths) - 1 != \
tmp_frame_ranged.frame_end - tmp_frame_ranged.frame_start:
# If the number of file paths does not match the frame range,
# the trait is invalid
msg = (
f"Number of file locations ({len(self.file_paths)}) "
f"Number of file locations ({len(self.file_paths) - 1}) "
"does not match frame range "
f"({tmp_seq.frame_end - tmp_seq.frame_start})"
f"({tmp_frame_ranged.frame_end - tmp_frame_ranged.frame_start})"
)
raise TraitValidationError(self.name, msg)
try:
sequence: Sequence = representation.get_trait(Sequence)
sequence: FrameRanged = representation.get_trait(FrameRanged)
if sequence.frame_start != tmp_seq.frame_start or \
sequence.frame_end != tmp_seq.frame_end or \
sequence.frame_padding != tmp_seq.frame_padding:
if sequence.frame_start != tmp_frame_ranged.frame_start or \
sequence.frame_end != tmp_frame_ranged.frame_end:
# If the frame range does not match the sequence trait, the
# trait is invalid. Note that we don't check the frame rate
# because it is not stored in the file paths and is not
@ -154,7 +153,7 @@ class FileLocations(TraitBase):
f"({sequence.frame_start}-{sequence.frame_end}) "
"in sequence trait does not match "
"frame range "
f"({tmp_seq.frame_start}-{tmp_seq.frame_end}) "
f"({tmp_frame_ranged.frame_start}-{tmp_frame_ranged.frame_end}) "
"defined in files."
)
raise TraitValidationError(self.name, msg)

View file

@ -0,0 +1,609 @@
"""Defines the base trait model and representation."""
from __future__ import annotations
import inspect
import re
import sys
import uuid
from functools import lru_cache
from typing import ClassVar, Optional, Type, TypeVar, Union
from .trait import (
IncompatibleTraitVersionError,
LooseMatchingTraitError,
MissingTraitError,
TraitBase,
UpgradableTraitError,
)
T = TypeVar("T", bound=TraitBase)
def _get_version_from_id(_id: str) -> int:
"""Get version from ID.
Args:
_id (str): ID.
Returns:
int: Version.
"""
match = re.search(r"v(\d+)$", _id)
return int(match[1]) if match else None
class Representation:
"""Representation of products.
Representation defines collection of individual properties that describe
the specific "form" of the product. Each property is represented by a
trait therefore the Representation is a collection of traits.
It holds methods to add, remove, get, and check for the existence of a
trait in the representation. It also provides a method to get all the
Arguments:
name (str): Representation name. Must be unique within instance.
representation_id (str): Representation ID.
"""
_data: dict
_module_blacklist: ClassVar[list[str]] = [
"_", "builtins", "pydantic"]
name: str
representation_id: str
def __hash__(self):
"""Return hash of the representation ID."""
return hash(self.representation_id)
def add_trait(self, trait: TraitBase, *, exists_ok: bool=False) -> None:
"""Add a trait to the Representation.
Args:
trait (TraitBase): Trait to add.
exists_ok (bool, optional): If True, do not raise an error if the
trait already exists. Defaults to False.
Raises:
ValueError: If the trait ID is not provided or the trait already
exists.
"""
if not hasattr(trait, "id"):
error_msg = f"Invalid trait {trait} - ID is required."
raise ValueError(error_msg)
if trait.id in self._data and not exists_ok:
error_msg = f"Trait with ID {trait.id} already exists."
raise ValueError(error_msg)
self._data[trait.id] = trait
def add_traits(
self, traits: list[TraitBase], *, exists_ok: bool=False) -> None:
"""Add a list of traits to the Representation.
Args:
traits (list[TraitBase]): List of traits to add.
exists_ok (bool, optional): If True, do not raise an error if the
trait already exists. Defaults to False.
"""
for trait in traits:
self.add_trait(trait, exists_ok=exists_ok)
def remove_trait(self, trait: Type[TraitBase]) -> None:
"""Remove a trait from the data.
Args:
trait (TraitBase, optional): Trait class.
Raises:
ValueError: If the trait is not found.
"""
try:
self._data.pop(trait.id)
except KeyError as e:
error_msg = f"Trait with ID {trait.id} not found."
raise ValueError(error_msg) from e
def remove_trait_by_id(self, trait_id: str) -> None:
"""Remove a trait from the data by its ID.
Args:
trait_id (str): Trait ID.
Raises:
ValueError: If the trait is not found.
"""
try:
self._data.pop(trait_id)
except KeyError as e:
error_msg = f"Trait with ID {trait_id} not found."
raise ValueError(error_msg) from e
def remove_traits(self, traits: list[Type[TraitBase]]) -> None:
"""Remove a list of traits from the Representation.
If no trait IDs or traits are provided, all traits will be removed.
Args:
traits (list[TraitBase]): List of trait classes.
"""
if not traits:
self._data = {}
return
for trait in traits:
self.remove_trait(trait)
def remove_traits_by_id(self, trait_ids: list[str]) -> None:
"""Remove a list of traits from the Representation by their ID.
If no trait IDs or traits are provided, all traits will be removed.
Args:
trait_ids (list[str], optional): List of trait IDs.
"""
for trait_id in trait_ids:
self.remove_trait_by_id(trait_id)
def has_traits(self) -> bool:
"""Check if the Representation has any traits.
Returns:
bool: True if the Representation has any traits, False otherwise.
"""
return bool(self._data)
def contains_trait(self, trait: Type[TraitBase]) -> bool:
"""Check if the trait exists in the Representation.
Args:
trait (TraitBase): Trait class.
Returns:
bool: True if the trait exists, False otherwise.
"""
return bool(self._data.get(trait.id))
def contains_trait_by_id(self, trait_id: str) -> bool:
"""Check if the trait exists using trait id.
Args:
trait_id (str): Trait ID.
Returns:
bool: True if the trait exists, False otherwise.
"""
return bool(self._data.get(trait_id))
def contains_traits(self, traits: list[Type[TraitBase]]) -> bool:
"""Check if the traits exist.
Args:
traits (list[TraitBase], optional): List of trait classes.
Returns:
bool: True if all traits exist, False otherwise.
"""
return all(self.contains_trait(trait=trait) for trait in traits)
def contains_traits_by_id(self, trait_ids: list[str]) -> bool:
"""Check if the traits exist by id.
If no trait IDs or traits are provided, it will check if the
representation has any traits.
Args:
trait_ids (list[str]): List of trait IDs.
Returns:
bool: True if all traits exist, False otherwise.
"""
return all(
self.contains_trait_by_id(trait_id) for trait_id in trait_ids
)
def get_trait(self, trait: Type[T]) -> Union[T]:
"""Get a trait from the representation.
Args:
trait (TraitBase, optional): Trait class.
Returns:
TraitBase: Trait instance.
Raises:
MissingTraitError: If the trait is not found.
"""
try:
return self._data[trait.id]
except KeyError as e:
msg = f"Trait with ID {trait.id} not found."
raise MissingTraitError(msg) from e
def get_trait_by_id(self, trait_id: str) -> Union[T]:
# sourcery skip: use-named-expression
"""Get a trait from the representation by id.
Args:
trait_id (str): Trait ID.
Returns:
TraitBase: Trait instance.
Raises:
MissingTraitError: If the trait is not found.
"""
version = _get_version_from_id(trait_id)
if version:
try:
return self._data[trait_id]
except KeyError as e:
msg = f"Trait with ID {trait_id} not found."
raise MissingTraitError(msg) from e
result = next(
(
self._data.get(trait_id)
for trait_id in self._data
if trait_id.startswith(trait_id)
),
None,
)
if not result:
msg = f"Trait with ID {trait_id} not found."
raise MissingTraitError(msg)
return result
def get_traits(self,
traits: Optional[list[Type[TraitBase]]]=None
) -> dict[str, T]:
"""Get a list of traits from the representation.
If no trait IDs or traits are provided, all traits will be returned.
Args:
traits (list[TraitBase], optional): List of trait classes.
Returns:
dict: Dictionary of traits.
"""
result = {}
if not traits:
for trait_id in self._data:
result[trait_id] = self.get_trait_by_id(trait_id=trait_id)
return result
for trait in traits:
result[trait.id] = self.get_trait(trait=trait)
return result
def get_traits_by_ids(self, trait_ids: list[str]) -> dict[str, T]:
"""Get a list of traits from the representation by their id.
If no trait IDs or traits are provided, all traits will be returned.
Args:
trait_ids (list[str]): List of trait IDs.
Returns:
dict: Dictionary of traits.
"""
return {
trait_id: self.get_trait_by_id(trait_id)
for trait_id in trait_ids
}
def traits_as_dict(self) -> dict:
"""Return the traits from Representation data as a dictionary.
Returns:
dict: Traits data dictionary.
"""
return {
trait_id: trait.model_dump()
for trait_id, trait in self._data.items()
if trait and trait_id
}
def __len__(self):
"""Return the length of the data."""
return len(self._data)
def __init__(
self,
name: str,
representation_id: Optional[str]=None,
traits: Optional[list[TraitBase]]=None):
"""Initialize the data.
Args:
name (str): Representation name. Must be unique within instance.
representation_id (str, optional): Representation ID.
traits (list[TraitBase], optional): List of traits.
"""
self.name = name
self.representation_id = representation_id or uuid.uuid4().hex
self._data = {}
if traits:
for trait in traits:
self.add_trait(trait)
@staticmethod
def _get_version_from_id(trait_id: str) -> Union[int, None]:
# sourcery skip: use-named-expression
"""Check if the trait has version specified.
Args:
trait_id (str): Trait ID.
Returns:
int: Trait version.
None: If the trait id does not have a version.
"""
version_regex = r"v(\d+)$"
match = re.search(version_regex, trait_id)
return int(match[1]) if match else None
def __eq__(self, other: Representation) -> bool: # noqa: PLR0911
"""Check if the representation is equal to another.
Args:
other (Representation): Representation to compare.
Returns:
bool: True if the representations are equal, False otherwise.
"""
if self.representation_id != other.representation_id:
return False
if not isinstance(other, Representation):
return False
if self.name != other.name:
return False
# number of traits
if len(self) != len(other):
return False
for trait_id, trait in self._data.items():
if trait_id not in other._data:
return False
if trait != other._data[trait_id]:
return False
for key, value in trait.model_dump().items():
if value != other._data[trait_id].model_dump().get(key):
return False
return True
@classmethod
@lru_cache(maxsize=64)
def _get_possible_trait_classes_from_modules(
cls,
trait_id: str) -> set[type[TraitBase]]:
"""Get possible trait classes from modules.
Args:
trait_id (str): Trait ID.
Returns:
set[type[TraitBase]]: Set of trait classes.
"""
modules = sys.modules.copy()
filtered_modules = modules.copy()
for module_name in modules:
for bl_module in cls._module_blacklist:
if module_name.startswith(bl_module):
filtered_modules.pop(module_name)
trait_candidates = set()
for module in filtered_modules.values():
if not module:
continue
for _, klass in inspect.getmembers(module, inspect.isclass):
if inspect.isclass(klass) \
and issubclass(klass, TraitBase) \
and str(klass.id).startswith(trait_id):
trait_candidates.add(klass)
return trait_candidates
@classmethod
@lru_cache(maxsize=64)
def _get_trait_class(
cls, trait_id: str) -> Union[Type[TraitBase], None]:
"""Get the trait class with corresponding to given ID.
This method will search for the trait class in all the modules except
the blacklisted modules. There is some issue in Pydantic where
``issubclass`` is not working properly so we are excluding explicitly
modules with offending classes. This list can be updated as needed to
speed up the search.
Args:
trait_id (str): Trait ID.
Returns:
Type[TraitBase]: Trait class.
Raises:
LooseMatchingTraitError: If the trait is found with a loose
matching criteria. This exception will include the trait
class that was found and the expected trait ID. Additional
downstream logic must decide how to handle this error.
"""
version = cls._get_version_from_id(trait_id)
trait_candidates = cls._get_possible_trait_classes_from_modules(
trait_id
)
for trait_class in trait_candidates:
if trait_class.id == trait_id:
# we found direct match
return trait_class
# if we didn't find direct match, we will search for the highest
# version of the trait.
if not version:
# sourcery skip: use-named-expression
trait_versions = [
trait_class for trait_class in trait_candidates
if re.match(
rf"{trait_id}.v(\d+)$", str(trait_class.id))
]
if trait_versions:
def _get_version_by_id(trait_klass: Type[TraitBase]) -> int:
match = re.search(r"v(\d+)$", str(trait_klass.id))
return int(match[1]) if match else 0
error = LooseMatchingTraitError(
"Found trait that might match.")
error.found_trait = max(
trait_versions, key=_get_version_by_id)
error.expected_id = trait_id
raise error
return None
@classmethod
def get_trait_class_by_trait_id(cls, trait_id: str) -> type[TraitBase]:
"""Get the trait class for the given trait ID.
Args:
trait_id (str): Trait ID.
Returns:
type[TraitBase]: Trait class.
Raises:
IncompatibleTraitVersionError: If the trait version is incompatible
with the current version of the trait.
UpgradableTraitError: If the trait can upgrade existing data
meant for older versions of the trait.
ValueError: If the trait model with the given ID is not found.
"""
trait_class = None
try:
trait_class = cls._get_trait_class(trait_id=trait_id)
except LooseMatchingTraitError as e:
requested_version = _get_version_from_id(trait_id)
found_version = _get_version_from_id(e.found_trait.id)
if not requested_version:
trait_class = e.found_trait
else:
if requested_version > found_version:
error_msg = (
f"Requested trait version {requested_version} is "
f"higher than the found trait version {found_version}."
)
raise IncompatibleTraitVersionError(error_msg) from e
if requested_version < found_version and hasattr(
e.found_trait, "upgrade"):
error_msg = (
"Requested trait version "
f"{requested_version} is lower "
f"than the found trait version {found_version}."
)
error = UpgradableTraitError(error_msg)
error.trait = e.found_trait
raise error from e
return trait_class
@classmethod
def from_dict(
cls,
name: str,
representation_id: Optional[str]=None,
trait_data: Optional[dict] = None) -> Representation:
"""Create a representation from a dictionary.
Args:
name (str): Representation name.
representation_id (str, optional): Representation ID.
trait_data (dict): Representation data. Dictionary with keys
as trait ids and values as trait data. Example::
{
"ayon.2d.PixelBased.v1": {
"display_window_width": 1920,
"display_window_height": 1080
},
"ayon.2d.Planar.v1": {
"channels": 3
}
}
Returns:
Representation: Representation instance.
"""
traits = []
for trait_id, value in trait_data.items():
if not isinstance(value, dict):
msg = (
f"Invalid trait data for trait ID {trait_id}. "
"Trait data must be a dictionary."
)
raise TypeError(msg)
try:
trait_class = cls.get_trait_class_by_trait_id(trait_id)
except UpgradableTraitError as e:
# we found newer version of trait, we will upgrade the data
if hasattr(e.trait, "upgrade"):
traits.append(e.trait.upgrade(value))
else:
msg = (
f"Newer version of trait {e.trait.id} found "
f"for requested {trait_id} but without "
"upgrade method."
)
raise IncompatibleTraitVersionError(msg) from e
else:
if not trait_class:
error_msg = f"Trait model with ID {trait_id} not found."
raise ValueError(error_msg)
traits.append(trait_class(**value))
return cls(
name=name, representation_id=representation_id, traits=traits)
def validate(self) -> bool:
"""Validate the representation.
This method will validate all the traits in the representation.
Returns:
bool: True if the representation is valid, False otherwise.
"""
return all(trait.validate(self) for trait in self._data.values())

View file

@ -2,16 +2,15 @@
from __future__ import annotations
from enum import Enum, auto
from typing import TYPE_CHECKING, ClassVar, Optional, Union
from typing import TYPE_CHECKING, Annotated, ClassVar, Optional, Union
from pydantic import Field
from pydantic import Field, PlainSerializer
from .content import FileLocations
from .trait import MissingTraitError, Representation, TraitBase
from .representation import Representation
from .trait import MissingTraitError, TraitBase
if TYPE_CHECKING:
from decimal import Decimal
from fractions import Fraction
class GapPolicy(Enum):
@ -42,6 +41,13 @@ class FrameRanged(TraitBase):
* frame_end -> end_frame
...
Note: frames_per_second is a string to allow various precision
formats. FPS is a floating point number, but it can be also
represented as a fraction (e.g. "30000/1001") or as a decimal
or even as irrational number. We need to support all these
formats. To work with FPS, we'll need some helper function
to convert FPS to Decimal from string.
Attributes:
name (str): Trait name.
description (str): Trait description.
@ -50,7 +56,7 @@ class FrameRanged(TraitBase):
frame_end (int): Frame end.
frame_in (int): Frame in.
frame_out (int): Frame out.
frames_per_second (float, Fraction, Decimal): Frames per second.
frames_per_second (str): Frames per second.
step (int): Step.
"""
@ -63,8 +69,7 @@ class FrameRanged(TraitBase):
..., title="Frame Start")
frame_in: Optional[int] = Field(None, title="In Frame")
frame_out: Optional[int] = Field(None, title="Out Frame")
frames_per_second: Union[float, Fraction, Decimal] = Field(
..., title="Frames Per Second")
frames_per_second: str = Field(..., title="Frames Per Second")
step: Optional[int] = Field(1, title="Step")
@ -83,9 +88,9 @@ class Handles(TraitBase):
frame_end_handle (int): Frame end handle.
"""
name: ClassVar[str] = "Clip"
description: ClassVar[str] = "Clip Trait"
id: ClassVar[str] = "ayon.time.Clip.v1"
name: ClassVar[str] = "Handles"
description: ClassVar[str] = "Handles Trait"
id: ClassVar[str] = "ayon.time.Handles.v1"
inclusive: Optional[bool] = Field(
False, title="Handles are inclusive") # noqa: FBT003
frame_start_handle: Optional[int] = Field(
@ -93,7 +98,7 @@ class Handles(TraitBase):
frame_end_handle: Optional[int] = Field(
0, title="Frame End Handle")
class Sequence(FrameRanged, Handles):
class Sequence(TraitBase):
"""Sequence trait model.
This model represents a sequence trait. Based on the FrameRanged trait
@ -130,6 +135,7 @@ class Sequence(FrameRanged, Handles):
# if there is FileLocations trait, run validation
# on it as well
try:
from .content import FileLocations
file_locs: FileLocations = representation.get_trait(
FileLocations)
file_locs.validate(representation)

View file

@ -1,13 +1,9 @@
"""Defines the base trait model and representation."""
from __future__ import annotations
import inspect
import re
import sys
import uuid
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import ClassVar, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Optional
import pydantic.alias_generators
from pydantic import (
@ -16,19 +12,8 @@ from pydantic import (
ConfigDict,
)
def _get_version_from_id(_id: str) -> int:
"""Get version from ID.
Args:
_id (str): ID.
Returns:
int: Version.
"""
match = re.search(r"v(\d+)$", _id)
return int(match[1]) if match else None
if TYPE_CHECKING:
from .representation import Representation
class TraitBase(ABC, BaseModel):
@ -106,584 +91,8 @@ class TraitBase(ABC, BaseModel):
return re.sub(r"\.v\d+$", "", str(cls.id))
T = TypeVar("T", bound=TraitBase)
class Representation:
"""Representation of products.
Representation defines collection of individual properties that describe
the specific "form" of the product. Each property is represented by a
trait therefore the Representation is a collection of traits.
It holds methods to add, remove, get, and check for the existence of a
trait in the representation. It also provides a method to get all the
Arguments:
name (str): Representation name. Must be unique within instance.
representation_id (str): Representation ID.
"""
_data: dict
_module_blacklist: ClassVar[list[str]] = [
"_", "builtins", "pydantic"]
name: str
representation_id: str
def __hash__(self):
"""Return hash of the representation ID."""
return hash(self.representation_id)
def add_trait(self, trait: TraitBase, *, exists_ok: bool=False) -> None:
"""Add a trait to the Representation.
Args:
trait (TraitBase): Trait to add.
exists_ok (bool, optional): If True, do not raise an error if the
trait already exists. Defaults to False.
Raises:
ValueError: If the trait ID is not provided or the trait already
exists.
"""
if not hasattr(trait, "id"):
error_msg = f"Invalid trait {trait} - ID is required."
raise ValueError(error_msg)
if trait.id in self._data and not exists_ok:
error_msg = f"Trait with ID {trait.id} already exists."
raise ValueError(error_msg)
self._data[trait.id] = trait
def add_traits(
self, traits: list[TraitBase], *, exists_ok: bool=False) -> None:
"""Add a list of traits to the Representation.
Args:
traits (list[TraitBase]): List of traits to add.
exists_ok (bool, optional): If True, do not raise an error if the
trait already exists. Defaults to False.
"""
for trait in traits:
self.add_trait(trait, exists_ok=exists_ok)
def remove_trait(self, trait: Type[TraitBase]) -> None:
"""Remove a trait from the data.
Args:
trait (TraitBase, optional): Trait class.
Raises:
ValueError: If the trait is not found.
"""
try:
self._data.pop(trait.id)
except KeyError as e:
error_msg = f"Trait with ID {trait.id} not found."
raise ValueError(error_msg) from e
def remove_trait_by_id(self, trait_id: str) -> None:
"""Remove a trait from the data by its ID.
Args:
trait_id (str): Trait ID.
Raises:
ValueError: If the trait is not found.
"""
try:
self._data.pop(trait_id)
except KeyError as e:
error_msg = f"Trait with ID {trait_id} not found."
raise ValueError(error_msg) from e
def remove_traits(self, traits: list[Type[TraitBase]]) -> None:
"""Remove a list of traits from the Representation.
If no trait IDs or traits are provided, all traits will be removed.
Args:
traits (list[TraitBase]): List of trait classes.
"""
if not traits:
self._data = {}
return
for trait in traits:
self.remove_trait(trait)
def remove_traits_by_id(self, trait_ids: list[str]) -> None:
"""Remove a list of traits from the Representation by their ID.
If no trait IDs or traits are provided, all traits will be removed.
Args:
trait_ids (list[str], optional): List of trait IDs.
"""
for trait_id in trait_ids:
self.remove_trait_by_id(trait_id)
def has_traits(self) -> bool:
"""Check if the Representation has any traits.
Returns:
bool: True if the Representation has any traits, False otherwise.
"""
return bool(self._data)
def contains_trait(self, trait: Type[TraitBase]) -> bool:
"""Check if the trait exists in the Representation.
Args:
trait (TraitBase): Trait class.
Returns:
bool: True if the trait exists, False otherwise.
"""
return bool(self._data.get(trait.id))
def contains_trait_by_id(self, trait_id: str) -> bool:
"""Check if the trait exists using trait id.
Args:
trait_id (str): Trait ID.
Returns:
bool: True if the trait exists, False otherwise.
"""
return bool(self._data.get(trait_id))
def contains_traits(self, traits: list[Type[TraitBase]]) -> bool:
"""Check if the traits exist.
Args:
traits (list[TraitBase], optional): List of trait classes.
Returns:
bool: True if all traits exist, False otherwise.
"""
return all(self.contains_trait(trait=trait) for trait in traits)
def contains_traits_by_id(self, trait_ids: list[str]) -> bool:
"""Check if the traits exist by id.
If no trait IDs or traits are provided, it will check if the
representation has any traits.
Args:
trait_ids (list[str]): List of trait IDs.
Returns:
bool: True if all traits exist, False otherwise.
"""
return all(
self.contains_trait_by_id(trait_id) for trait_id in trait_ids
)
def get_trait(self, trait: Type[T]) -> Union[T]:
"""Get a trait from the representation.
Args:
trait (TraitBase, optional): Trait class.
Returns:
TraitBase: Trait instance.
Raises:
MissingTraitError: If the trait is not found.
"""
try:
return self._data[trait.id]
except KeyError as e:
msg = f"Trait with ID {trait.id} not found."
raise MissingTraitError(msg) from e
def get_trait_by_id(self, trait_id: str) -> Union[T]:
# sourcery skip: use-named-expression
"""Get a trait from the representation by id.
Args:
trait_id (str): Trait ID.
Returns:
TraitBase: Trait instance.
Raises:
MissingTraitError: If the trait is not found.
"""
version = _get_version_from_id(trait_id)
if version:
try:
return self._data[trait_id]
except KeyError as e:
msg = f"Trait with ID {trait_id} not found."
raise MissingTraitError(msg) from e
result = next(
(
self._data.get(trait_id)
for trait_id in self._data
if trait_id.startswith(trait_id)
),
None,
)
if not result:
msg = f"Trait with ID {trait_id} not found."
raise MissingTraitError(msg)
return result
def get_traits(self,
traits: Optional[list[Type[TraitBase]]]=None
) -> dict[str, T]:
"""Get a list of traits from the representation.
If no trait IDs or traits are provided, all traits will be returned.
Args:
traits (list[TraitBase], optional): List of trait classes.
Returns:
dict: Dictionary of traits.
"""
result = {}
if not traits:
for trait_id in self._data:
result[trait_id] = self.get_trait_by_id(trait_id=trait_id)
return result
for trait in traits:
result[trait.id] = self.get_trait(trait=trait)
return result
def get_traits_by_ids(self, trait_ids: list[str]) -> dict[str, T]:
"""Get a list of traits from the representation by their id.
If no trait IDs or traits are provided, all traits will be returned.
Args:
trait_ids (list[str]): List of trait IDs.
Returns:
dict: Dictionary of traits.
"""
return {
trait_id: self.get_trait_by_id(trait_id)
for trait_id in trait_ids
}
def traits_as_dict(self) -> dict:
"""Return the traits from Representation data as a dictionary.
Returns:
dict: Traits data dictionary.
"""
return {
trait_id: trait.dict()
for trait_id, trait in self._data.items()
if trait and trait_id
}
def __len__(self):
"""Return the length of the data."""
return len(self._data)
def __init__(
self,
name: str,
representation_id: Optional[str]=None,
traits: Optional[list[TraitBase]]=None):
"""Initialize the data.
Args:
name (str): Representation name. Must be unique within instance.
representation_id (str, optional): Representation ID.
traits (list[TraitBase], optional): List of traits.
"""
self.name = name
self.representation_id = representation_id or uuid.uuid4().hex
self._data = {}
if traits:
for trait in traits:
self.add_trait(trait)
@staticmethod
def _get_version_from_id(trait_id: str) -> Union[int, None]:
# sourcery skip: use-named-expression
"""Check if the trait has version specified.
Args:
trait_id (str): Trait ID.
Returns:
int: Trait version.
None: If the trait id does not have a version.
"""
version_regex = r"v(\d+)$"
match = re.search(version_regex, trait_id)
return int(match[1]) if match else None
def __eq__(self, other: Representation) -> bool: # noqa: PLR0911
"""Check if the representation is equal to another.
Args:
other (Representation): Representation to compare.
Returns:
bool: True if the representations are equal, False otherwise.
"""
if self.representation_id != other.representation_id:
return False
if not isinstance(other, Representation):
return False
if self.name != other.name:
return False
# number of traits
if len(self) != len(other):
return False
for trait_id, trait in self._data.items():
if trait_id not in other._data:
return False
if trait != other._data[trait_id]:
return False
for key, value in trait.dict().items():
if value != other._data[trait_id].dict().get(key):
return False
return True
@classmethod
@lru_cache(maxsize=64)
def _get_possible_trait_classes_from_modules(
cls,
trait_id: str) -> set[type[TraitBase]]:
"""Get possible trait classes from modules.
Args:
trait_id (str): Trait ID.
Returns:
set[type[TraitBase]]: Set of trait classes.
"""
modules = sys.modules.copy()
filtered_modules = modules.copy()
for module_name in modules:
for bl_module in cls._module_blacklist:
if module_name.startswith(bl_module):
filtered_modules.pop(module_name)
trait_candidates = set()
for module in filtered_modules.values():
if not module:
continue
for _, klass in inspect.getmembers(module, inspect.isclass):
if inspect.isclass(klass) \
and issubclass(klass, TraitBase) \
and str(klass.id).startswith(trait_id):
trait_candidates.add(klass)
return trait_candidates
@classmethod
@lru_cache(maxsize=64)
def _get_trait_class(
cls, trait_id: str) -> Union[Type[TraitBase], None]:
"""Get the trait class with corresponding to given ID.
This method will search for the trait class in all the modules except
the blacklisted modules. There is some issue in Pydantic where
``issubclass`` is not working properly so we are excluding explicitly
modules with offending classes. This list can be updated as needed to
speed up the search.
Args:
trait_id (str): Trait ID.
Returns:
Type[TraitBase]: Trait class.
Raises:
LooseMatchingTraitError: If the trait is found with a loose
matching criteria. This exception will include the trait
class that was found and the expected trait ID. Additional
downstream logic must decide how to handle this error.
"""
version = cls._get_version_from_id(trait_id)
trait_candidates = cls._get_possible_trait_classes_from_modules(
trait_id
)
for trait_class in trait_candidates:
if trait_class.id == trait_id:
# we found direct match
return trait_class
# if we didn't find direct match, we will search for the highest
# version of the trait.
if not version:
# sourcery skip: use-named-expression
trait_versions = [
trait_class for trait_class in trait_candidates
if re.match(
rf"{trait_id}.v(\d+)$", str(trait_class.id))
]
if trait_versions:
def _get_version_by_id(trait_klass: Type[TraitBase]) -> int:
match = re.search(r"v(\d+)$", str(trait_klass.id))
return int(match[1]) if match else 0
error = LooseMatchingTraitError(
"Found trait that might match.")
error.found_trait = max(
trait_versions, key=_get_version_by_id)
error.expected_id = trait_id
raise error
return None
@classmethod
def get_trait_class_by_trait_id(cls, trait_id: str) -> type[TraitBase]:
"""Get the trait class for the given trait ID.
Args:
trait_id (str): Trait ID.
Returns:
type[TraitBase]: Trait class.
Raises:
IncompatibleTraitVersionError: If the trait version is incompatible
with the current version of the trait.
UpgradableTraitError: If the trait can upgrade existing data
meant for older versions of the trait.
ValueError: If the trait model with the given ID is not found.
"""
trait_class = None
try:
trait_class = cls._get_trait_class(trait_id=trait_id)
except LooseMatchingTraitError as e:
requested_version = _get_version_from_id(trait_id)
found_version = _get_version_from_id(e.found_trait.id)
if not requested_version:
trait_class = e.found_trait
else:
if requested_version > found_version:
error_msg = (
f"Requested trait version {requested_version} is "
f"higher than the found trait version {found_version}."
)
raise IncompatibleTraitVersionError(error_msg) from e
if requested_version < found_version and hasattr(
e.found_trait, "upgrade"):
error_msg = (
"Requested trait version "
f"{requested_version} is lower "
f"than the found trait version {found_version}."
)
error = UpgradableTraitError(error_msg)
error.trait = e.found_trait
raise error from e
return trait_class
@classmethod
def from_dict(
cls,
name: str,
representation_id: Optional[str]=None,
trait_data: Optional[dict] = None) -> Representation:
"""Create a representation from a dictionary.
Args:
name (str): Representation name.
representation_id (str, optional): Representation ID.
trait_data (dict): Representation data. Dictionary with keys
as trait ids and values as trait data. Example::
{
"ayon.2d.PixelBased.v1": {
"display_window_width": 1920,
"display_window_height": 1080
},
"ayon.2d.Planar.v1": {
"channels": 3
}
}
Returns:
Representation: Representation instance.
"""
traits = []
for trait_id, value in trait_data.items():
if not isinstance(value, dict):
msg = (
f"Invalid trait data for trait ID {trait_id}. "
"Trait data must be a dictionary."
)
raise TypeError(msg)
try:
trait_class = cls.get_trait_class_by_trait_id(trait_id)
except UpgradableTraitError as e:
# we found newer version of trait, we will upgrade the data
if hasattr(e.trait, "upgrade"):
traits.append(e.trait.upgrade(value))
else:
msg = (
f"Newer version of trait {e.trait.id} found "
f"for requested {trait_id} but without "
"upgrade method."
)
raise IncompatibleTraitVersionError(msg) from e
else:
if not trait_class:
error_msg = f"Trait model with ID {trait_id} not found."
raise ValueError(error_msg)
traits.append(trait_class(**value))
return cls(
name=name, representation_id=representation_id, traits=traits)
def validate(self) -> bool:
"""Validate the representation.
This method will validate all the traits in the representation.
Returns:
bool: True if the representation is valid, False otherwise.
"""
return all(trait.validate(self) for trait in self._data.values())
class IncompatibleTraitVersionError(Exception):

View file

@ -5,13 +5,13 @@ from typing import TYPE_CHECKING
from clique import assemble
from ayon_core.pipeline.traits import Sequence
from ayon_core.pipeline.traits.time import FrameRanged
if TYPE_CHECKING:
from pathlib import Path
def get_sequence_from_files(paths: list[Path]) -> Sequence:
def get_sequence_from_files(paths: list[Path]) -> FrameRanged:
"""Get original frame range from files.
Note that this cannot guess frame rate, so it's set to 25.
@ -20,7 +20,7 @@ def get_sequence_from_files(paths: list[Path]) -> Sequence:
paths (list[Path]): List of file paths.
Returns:
Sequence: Sequence trait.
FrameRanged: FrameRanged trait.
"""
col = assemble([path.as_posix() for path in paths])[0][0]
@ -30,9 +30,9 @@ def get_sequence_from_files(paths: list[Path]) -> Sequence:
# Get last frame for padding
last_frame = sorted_frames[-1]
# Use padding from collection of length of last frame as string
padding = max(col.padding, len(str(last_frame)))
# padding = max(col.padding, len(str(last_frame)))
return Sequence(
frame_start=first_frame, frame_end=last_frame, frame_padding=padding,
frames_per_second=25
return FrameRanged(
frame_start=first_frame, frame_end=last_frame,
frames_per_second="25.0"
)

1
tests/__init__.py Normal file
View file

@ -0,0 +1 @@
"""Tests."""

View file

@ -8,15 +8,17 @@ from ayon_core.pipeline.traits import (
Bundle,
FileLocation,
FileLocations,
FrameRanged,
Image,
MimeType,
Overscan,
PixelBased,
Planar,
Representation,
Sequence,
TraitBase,
)
from pipeline.traits import Overscan
from ayon_core.pipeline.traits.trait import TraitValidationError
REPRESENTATION_DATA = {
FileLocation.id: {
@ -340,7 +342,7 @@ def test_file_locations_validation() -> None:
file_size=1024,
file_hash=None,
)
for frame in range(1001, 1050)
for frame in range(1001, 1051)
]
representation = Representation(name="test", traits=[
@ -351,39 +353,40 @@ def test_file_locations_validation() -> None:
file_paths=file_locations_list)
# this should be valid trait
assert file_locations_trait.validate(representation) is True
file_locations_trait.validate(representation)
# add valid sequence trait
sequence_trait = Sequence(
# add valid FrameRanged trait
sequence_trait = FrameRanged(
frame_start=1001,
frame_end=1050,
frame_padding=4,
frames_per_second=25
frames_per_second="25"
)
representation.add_trait(sequence_trait)
# it should still validate fine
assert file_locations_trait.validate(representation) is True
file_locations_trait.validate(representation)
# create empty file locations trait
empty_file_locations_trait = FileLocations(file_paths=[])
representation = Representation(name="test", traits=[
empty_file_locations_trait
])
assert empty_file_locations_trait.validate(
representation) is False
with pytest.raises(TraitValidationError):
empty_file_locations_trait.validate(representation)
# create valid file locations trait but with not matching sequence
# trait
representation = Representation(name="test", traits=[
FileLocations(file_paths=file_locations_list)
])
invalid_sequence_trait = Sequence(
invalid_sequence_trait = FrameRanged(
frame_start=1001,
frame_end=1051,
frame_padding=4,
frames_per_second=25
frames_per_second="25"
)
representation.add_trait(invalid_sequence_trait)
assert file_locations_trait.validate(representation) is False
with pytest.raises(TraitValidationError):
file_locations_trait.validate(representation)

View file

@ -1,3 +1,4 @@
"""conftest.py: pytest configuration file."""
import sys
from pathlib import Path
@ -5,5 +6,3 @@ client_path = Path(__file__).resolve().parent.parent / "client"
# add client path to sys.path
sys.path.append(str(client_path))
print(f"Added {client_path} to sys.path")