diff --git a/client/ayon_core/pipeline/traits/trait.py b/client/ayon_core/pipeline/traits/trait.py index 22e7fc6d64..20ad9a1316 100644 --- a/client/ayon_core/pipeline/traits/trait.py +++ b/client/ayon_core/pipeline/traits/trait.py @@ -7,7 +7,7 @@ import sys import uuid from abc import ABC, abstractmethod from functools import lru_cache -from typing import ClassVar, Optional, Type, Union +from typing import ClassVar, Optional, Type, TypeVar, Union import pydantic.alias_generators from pydantic import ( @@ -105,6 +105,9 @@ class TraitBase(ABC, BaseModel): return re.sub(r"\.v\d+$", "", str(cls.id)) +T = TypeVar("T", bound=TraitBase) + + class Representation: """Representation of products. @@ -287,7 +290,7 @@ class Representation: self.contains_trait_by_id(trait_id) for trait_id in trait_ids ) - def get_trait(self, trait: Type[TraitBase]) -> Union[TraitBase, None]: + def get_trait(self, trait: Type[T]) -> Union[T]: """Get a trait from the representation. Args: @@ -296,10 +299,17 @@ class Representation: Returns: TraitBase: Trait instance. - """ - return self._data[trait.id] if self._data.get(trait.id) else None + Raises: + ValueError: If the trait is not found. - def get_trait_by_id(self, trait_id: str) -> Union[TraitBase, None]: + """ + try: + return self._data[trait.id] + except KeyError as e: + msg = f"Trait with ID {trait.id} not found." + raise ValueError(msg) from e + + def get_trait_by_id(self, trait_id: str) -> Union[T]: # sourcery skip: use-named-expression """Get a trait from the representation by id. @@ -309,12 +319,19 @@ class Representation: Returns: TraitBase: Trait instance. + Raises: + ValueError: If the trait is not found. + """ version = _get_version_from_id(trait_id) if version: - return self._data.get(trait_id) + try: + return self._data[trait_id] + except KeyError as e: + msg = f"Trait with ID {trait_id} not found." + raise ValueError(msg) from e - return next( + result = next( ( self._data.get(trait_id) for trait_id in self._data @@ -322,9 +339,14 @@ class Representation: ), None, ) + if not result: + msg = f"Trait with ID {trait_id} not found." + raise ValueError(msg) + return result def get_traits(self, - traits: Optional[list[Type[TraitBase]]]=None) -> dict: + traits: Optional[list[Type[TraitBase]]]=None + ) -> dict[str, T]: """Get a list of traits from the representation. If no trait IDs or traits are provided, all traits will be returned. @@ -346,7 +368,7 @@ class Representation: result[trait.id] = self.get_trait(trait=trait) return result - def get_traits_by_ids(self, trait_ids: list[str]) -> dict: + def get_traits_by_ids(self, trait_ids: list[str]) -> dict[str, T]: """Get a list of traits from the representation by their id. If no trait IDs or traits are provided, all traits will be returned.