Merge branch 'feature/909-define-basic-trait-type-using-dataclasses' into feature/911-new-traits-based-integrator

This commit is contained in:
Ondřej Samohel 2024-11-08 17:49:41 +01:00
commit 098cc383aa
No known key found for this signature in database
GPG key ID: 02376E18990A97C6

View file

@ -7,7 +7,7 @@ import sys
import uuid
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import ClassVar, Optional, Type, Union
from typing import ClassVar, Optional, Type, TypeVar, Union
import pydantic.alias_generators
from pydantic import (
@ -105,6 +105,9 @@ class TraitBase(ABC, BaseModel):
return re.sub(r"\.v\d+$", "", str(cls.id))
T = TypeVar("T", bound=TraitBase)
class Representation:
"""Representation of products.
@ -287,7 +290,7 @@ class Representation:
self.contains_trait_by_id(trait_id) for trait_id in trait_ids
)
def get_trait(self, trait: Type[TraitBase]) -> Union[TraitBase, None]:
def get_trait(self, trait: Type[T]) -> Union[T]:
"""Get a trait from the representation.
Args:
@ -296,10 +299,17 @@ class Representation:
Returns:
TraitBase: Trait instance.
"""
return self._data[trait.id] if self._data.get(trait.id) else None
Raises:
ValueError: If the trait is not found.
def get_trait_by_id(self, trait_id: str) -> Union[TraitBase, None]:
"""
try:
return self._data[trait.id]
except KeyError as e:
msg = f"Trait with ID {trait.id} not found."
raise ValueError(msg) from e
def get_trait_by_id(self, trait_id: str) -> Union[T]:
# sourcery skip: use-named-expression
"""Get a trait from the representation by id.
@ -309,12 +319,19 @@ class Representation:
Returns:
TraitBase: Trait instance.
Raises:
ValueError: If the trait is not found.
"""
version = _get_version_from_id(trait_id)
if version:
return self._data.get(trait_id)
try:
return self._data[trait_id]
except KeyError as e:
msg = f"Trait with ID {trait_id} not found."
raise ValueError(msg) from e
return next(
result = next(
(
self._data.get(trait_id)
for trait_id in self._data
@ -322,9 +339,14 @@ class Representation:
),
None,
)
if not result:
msg = f"Trait with ID {trait_id} not found."
raise ValueError(msg)
return result
def get_traits(self,
traits: Optional[list[Type[TraitBase]]]=None) -> dict:
traits: Optional[list[Type[TraitBase]]]=None
) -> dict[str, T]:
"""Get a list of traits from the representation.
If no trait IDs or traits are provided, all traits will be returned.
@ -346,7 +368,7 @@ class Representation:
result[trait.id] = self.get_trait(trait=trait)
return result
def get_traits_by_ids(self, trait_ids: list[str]) -> dict:
def get_traits_by_ids(self, trait_ids: list[str]) -> dict[str, T]:
"""Get a list of traits from the representation by their id.
If no trait IDs or traits are provided, all traits will be returned.