♻️ raise exception if trait not found instead of returning None

Raising exception is more pythonic than returning just None. Also some changes in return type annotations.
This commit is contained in:
Ondřej Samohel 2024-11-08 17:49:11 +01:00
parent db5d997ce7
commit e4377e8f07
No known key found for this signature in database
GPG key ID: 02376E18990A97C6

View file

@ -7,7 +7,7 @@ import sys
import uuid
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import ClassVar, Optional, Type, Union
from typing import ClassVar, Optional, Type, TypeVar, Union
import pydantic.alias_generators
from pydantic import (
@ -105,6 +105,9 @@ class TraitBase(ABC, BaseModel):
return re.sub(r"\.v\d+$", "", str(cls.id))
T = TypeVar("T", bound=TraitBase)
class Representation:
"""Representation of products.
@ -287,7 +290,7 @@ class Representation:
self.contains_trait_by_id(trait_id) for trait_id in trait_ids
)
def get_trait(self, trait: Type[TraitBase]) -> Union[TraitBase, None]:
def get_trait(self, trait: Type[T]) -> Union[T]:
"""Get a trait from the representation.
Args:
@ -296,10 +299,17 @@ class Representation:
Returns:
TraitBase: Trait instance.
"""
return self._data[trait.id] if self._data.get(trait.id) else None
Raises:
ValueError: If the trait is not found.
def get_trait_by_id(self, trait_id: str) -> Union[TraitBase, None]:
"""
try:
return self._data[trait.id]
except KeyError as e:
msg = f"Trait with ID {trait.id} not found."
raise ValueError(msg) from e
def get_trait_by_id(self, trait_id: str) -> Union[T]:
# sourcery skip: use-named-expression
"""Get a trait from the representation by id.
@ -309,12 +319,19 @@ class Representation:
Returns:
TraitBase: Trait instance.
Raises:
ValueError: If the trait is not found.
"""
version = _get_version_from_id(trait_id)
if version:
return self._data.get(trait_id)
try:
return self._data[trait_id]
except KeyError as e:
msg = f"Trait with ID {trait_id} not found."
raise ValueError(msg) from e
return next(
result = next(
(
self._data.get(trait_id)
for trait_id in self._data
@ -322,9 +339,14 @@ class Representation:
),
None,
)
if not result:
msg = f"Trait with ID {trait_id} not found."
raise ValueError(msg)
return result
def get_traits(self,
traits: Optional[list[Type[TraitBase]]]=None) -> dict:
traits: Optional[list[Type[TraitBase]]]=None
) -> dict[str, T]:
"""Get a list of traits from the representation.
If no trait IDs or traits are provided, all traits will be returned.
@ -346,7 +368,7 @@ class Representation:
result[trait.id] = self.get_trait(trait=trait)
return result
def get_traits_by_ids(self, trait_ids: list[str]) -> dict:
def get_traits_by_ids(self, trait_ids: list[str]) -> dict[str, T]:
"""Get a list of traits from the representation by their id.
If no trait IDs or traits are provided, all traits will be returned.