added typehints to create model

This commit is contained in:
Jakub Trllo 2024-06-21 18:22:01 +02:00
parent 24f3fa28a8
commit 76071c4b87

View file

@ -1,28 +1,37 @@
import logging import logging
import re import re
from typing import Union, List, Dict, Tuple, Any, Optional, Iterable, Pattern
from ayon_core.lib.attribute_definitions import ( from ayon_core.lib.attribute_definitions import (
serialize_attr_defs, serialize_attr_defs,
deserialize_attr_defs, deserialize_attr_defs,
AbstractAttrDef,
) )
from ayon_core.lib.profiles_filtering import filter_profiles from ayon_core.lib.profiles_filtering import filter_profiles
from ayon_core.lib.attribute_definitions import UIDef from ayon_core.lib.attribute_definitions import UIDef
from ayon_core.pipeline.create import ( from ayon_core.pipeline.create import (
BaseCreator,
AutoCreator, AutoCreator,
HiddenCreator, HiddenCreator,
Creator, Creator,
CreateContext, CreateContext,
CreatedInstance,
) )
from ayon_core.pipeline.create.context import ( from ayon_core.pipeline.create.context import (
CreatorsOperationFailed, CreatorsOperationFailed,
ConvertorsOperationFailed, ConvertorsOperationFailed,
ConvertorItem,
) )
from ..abstract import CardMessageTypes from ayon_core.tools.publisher.abstract import (
AbstractPublisherController,
CardMessageTypes,
)
CREATE_EVENT_SOURCE = "publisher.create.model"
class CreatorType: class CreatorType:
def __init__(self, name): def __init__(self, name: str):
self.name = name self.name: str = name
def __str__(self): def __str__(self):
return self.name return self.name
@ -42,7 +51,7 @@ class CreatorTypes:
artist = CreatorType("artist") artist = CreatorType("artist")
@classmethod @classmethod
def from_str(cls, value): def from_str(cls, value: str) -> CreatorType:
for creator_type in ( for creator_type in (
cls.base, cls.base,
cls.auto, cls.auto,
@ -62,49 +71,52 @@ class CreatorItem:
def __init__( def __init__(
self, self,
identifier, identifier: str,
creator_type, creator_type: CreatorType,
product_type, product_type: str,
label, label: str,
group_label, group_label: str,
icon, icon: Union[str, Dict[str, Any], None],
description, description: Union[str, None],
detailed_description, detailed_description: Union[str, None],
default_variant, default_variant: Union[str, None],
default_variants, default_variants: Union[List[str], None],
create_allow_context_change, create_allow_context_change: Union[bool, None],
create_allow_thumbnail, create_allow_thumbnail: Union[bool, None],
show_order, show_order: int,
pre_create_attributes_defs, pre_create_attributes_defs: List[AbstractAttrDef],
): ):
self.identifier = identifier self.identifier: str = identifier
self.creator_type = creator_type self.creator_type: CreatorType = creator_type
self.product_type = product_type self.product_type: str = product_type
self.label = label self.label: str = label
self.group_label = group_label self.group_label: str = group_label
self.icon = icon self.icon: Union[str, Dict[str, Any], None] = icon
self.description = description self.description: Union[str, None] = description
self.detailed_description = detailed_description self.detailed_description: Union[bool, None] = detailed_description
self.default_variant = default_variant self.default_variant: Union[bool, None] = default_variant
self.default_variants = default_variants self.default_variants: Union[List[str], None] = default_variants
self.create_allow_context_change = create_allow_context_change self.create_allow_context_change: Union[bool, None] = (
self.create_allow_thumbnail = create_allow_thumbnail create_allow_context_change
self.show_order = show_order )
self.pre_create_attributes_defs = pre_create_attributes_defs self.create_allow_thumbnail: Union[bool, None] = create_allow_thumbnail
self.show_order: int = show_order
self.pre_create_attributes_defs: List[AbstractAttrDef] = (
pre_create_attributes_defs
)
def get_group_label(self): def get_group_label(self) -> str:
return self.group_label return self.group_label
@classmethod @classmethod
def from_creator(cls, creator): def from_creator(cls, creator: BaseCreator):
creator_type: CreatorType = CreatorTypes.base
if isinstance(creator, AutoCreator): if isinstance(creator, AutoCreator):
creator_type = CreatorTypes.auto creator_type = CreatorTypes.auto
elif isinstance(creator, HiddenCreator): elif isinstance(creator, HiddenCreator):
creator_type = CreatorTypes.hidden creator_type = CreatorTypes.hidden
elif isinstance(creator, Creator): elif isinstance(creator, Creator):
creator_type = CreatorTypes.artist creator_type = CreatorTypes.artist
else:
creator_type = CreatorTypes.base
description = None description = None
detail_description = None detail_description = None
@ -142,7 +154,7 @@ class CreatorItem:
pre_create_attr_defs, pre_create_attr_defs,
) )
def to_data(self): def to_data(self) -> Dict[str, Any]:
pre_create_attributes_defs = None pre_create_attributes_defs = None
if self.pre_create_attributes_defs is not None: if self.pre_create_attributes_defs is not None:
pre_create_attributes_defs = serialize_attr_defs( pre_create_attributes_defs = serialize_attr_defs(
@ -167,7 +179,7 @@ class CreatorItem:
} }
@classmethod @classmethod
def from_data(cls, data): def from_data(cls, data: Dict[str, Any]) -> "CreatorItem":
pre_create_attributes_defs = data["pre_create_attributes_defs"] pre_create_attributes_defs = data["pre_create_attributes_defs"]
if pre_create_attributes_defs is not None: if pre_create_attributes_defs is not None:
data["pre_create_attributes_defs"] = deserialize_attr_defs( data["pre_create_attributes_defs"] = deserialize_attr_defs(
@ -179,7 +191,7 @@ class CreatorItem:
class CreateModel: class CreateModel:
def __init__(self, controller): def __init__(self, controller: AbstractPublisherController):
self._log = None self._log = None
self._controller = controller self._controller = controller
@ -192,18 +204,18 @@ class CreateModel:
self._creator_items = None self._creator_items = None
@property @property
def log(self): def log(self) -> logging.Logger:
if self._log is None: if self._log is None:
self._log = logging.getLogger(self.__class__.__name__) self._log = logging.getLogger(self.__class__.__name__)
return self._log return self._log
def is_host_valid(self): def is_host_valid(self) -> bool:
return self._create_context.host_is_valid return self._create_context.host_is_valid
def get_create_context(self): def get_create_context(self) -> CreateContext:
return self._create_context return self._create_context
def get_current_project_name(self): def get_current_project_name(self) -> Union[str, None]:
"""Current project context defined by host. """Current project context defined by host.
Returns: Returns:
@ -212,7 +224,7 @@ class CreateModel:
""" """
return self._create_context.get_current_project_name() return self._create_context.get_current_project_name()
def get_current_folder_path(self): def get_current_folder_path(self) -> Union[str, None]:
"""Current context folder path defined by host. """Current context folder path defined by host.
Returns: Returns:
@ -221,7 +233,7 @@ class CreateModel:
return self._create_context.get_current_folder_path() return self._create_context.get_current_folder_path()
def get_current_task_name(self): def get_current_task_name(self) -> Union[str, None]:
"""Current context task name defined by host. """Current context task name defined by host.
Returns: Returns:
@ -230,51 +242,61 @@ class CreateModel:
return self._create_context.get_current_task_name() return self._create_context.get_current_task_name()
def host_context_has_changed(self): def host_context_has_changed(self) -> bool:
return self._create_context.context_has_changed return self._create_context.context_has_changed
def reset(self): def reset(self):
self._creator_items = None
self._create_context.reset_preparation() self._create_context.reset_preparation()
# Reset current context # Reset current context
self._create_context.reset_current_context() self._create_context.reset_current_context()
self._reset_plugins() self._create_context.reset_plugins()
# Reset creator items
self._creator_items = None
self._reset_instances() self._reset_instances()
self._create_context.reset_finalization() self._create_context.reset_finalization()
def get_creator_items(self): def get_creator_items(self) -> Dict[str, CreatorItem]:
"""Creators that can be shown in create dialog.""" """Creators that can be shown in create dialog."""
if self._creator_items is None: if self._creator_items is None:
self._creator_items = self._collect_creator_items() self._creator_items = self._collect_creator_items()
return self._creator_items return self._creator_items
def get_creator_item_by_id(self, identifier): def get_creator_item_by_id(
self, identifier: str
) -> Union[CreatorItem, None]:
items = self.get_creator_items() items = self.get_creator_items()
return items.get(identifier) return items.get(identifier)
def get_creator_icon(self, identifier): def get_creator_icon(
self, identifier: str
) -> Union[str, Dict[str, Any], None]:
"""Function to receive icon for creator identifier. """Function to receive icon for creator identifier.
Args: Args:
str: Creator's identifier for which should be icon returned. identifier (str): Creator's identifier for which should
""" be icon returned.
"""
creator_item = self.get_creator_item_by_id(identifier) creator_item = self.get_creator_item_by_id(identifier)
if creator_item is not None: if creator_item is not None:
return creator_item.icon return creator_item.icon
return None return None
def get_instances(self): def get_instances(self) -> List[CreatedInstance]:
"""Current instances in create context.""" """Current instances in create context."""
return list(self._create_context.instances_by_id.values()) return list(self._create_context.instances_by_id.values())
def get_instance_by_id(self, instance_id): def get_instance_by_id(
self, instance_id: str
) -> Union[CreatedInstance, None]:
return self._create_context.instances_by_id.get(instance_id) return self._create_context.instances_by_id.get(instance_id)
def get_instances_by_id(self, instance_ids=None): def get_instances_by_id(
self, instance_ids: Optional[Iterable[str]] = None
) -> Dict[str, Union[CreatedInstance, None]]:
if instance_ids is None: if instance_ids is None:
instance_ids = self._create_context.instances_by_id.keys() instance_ids = self._create_context.instances_by_id.keys()
return { return {
@ -282,17 +304,17 @@ class CreateModel:
for instance_id in instance_ids for instance_id in instance_ids
} }
def get_convertor_items(self): def get_convertor_items(self) -> Dict[str, ConvertorItem]:
return self._create_context.convertor_items_by_id return self._create_context.convertor_items_by_id
def get_product_name( def get_product_name(
self, self,
creator_identifier, creator_identifier: str,
variant, variant: str,
task_name, task_name: Union[str, None],
folder_path, folder_path: Union[str, None],
instance_id=None instance_id: Optional[str] = None
): ) -> str:
"""Get product name based on passed data. """Get product name based on passed data.
Args: Args:
@ -323,7 +345,10 @@ class CreateModel:
project_name, folder_item.entity_id project_name, folder_item.entity_id
) )
task_item = self._controller.get_task_item_by_name( task_item = self._controller.get_task_item_by_name(
project_name, folder_item.entity_id, task_name, "controller" project_name,
folder_item.entity_id,
task_name,
CREATE_EVENT_SOURCE
) )
if task_item is not None: if task_item is not None:
@ -340,7 +365,11 @@ class CreateModel:
) )
def create( def create(
self, creator_identifier, product_name, instance_data, options self,
creator_identifier: str,
product_name: str,
instance_data: Dict[str, Any],
options: Dict[str, Any],
): ):
"""Trigger creation and refresh of instances in UI.""" """Trigger creation and refresh of instances in UI."""
@ -363,7 +392,7 @@ class CreateModel:
self._on_create_instance_change() self._on_create_instance_change()
return success return success
def trigger_convertor_items(self, convertor_identifiers): def trigger_convertor_items(self, convertor_identifiers: List[str]):
"""Trigger legacy item convertors. """Trigger legacy item convertors.
This functionality requires to save and reset CreateContext. The reset This functionality requires to save and reset CreateContext. The reset
@ -398,7 +427,7 @@ class CreateModel:
CardMessageTypes.error CardMessageTypes.error
) )
def save_changes(self, show_message=True): def save_changes(self, show_message: Optional[bool] = True):
"""Save changes happened during creation. """Save changes happened during creation.
Trigger save of changes using host api. This functionality does not Trigger save of changes using host api. This functionality does not
@ -437,7 +466,7 @@ class CreateModel:
return False return False
def remove_instances(self, instance_ids): def remove_instances(self, instance_ids: List[str]):
"""Remove instances based on instance ids. """Remove instances based on instance ids.
Args: Args:
@ -450,11 +479,13 @@ class CreateModel:
self._on_create_instance_change() self._on_create_instance_change()
def get_creator_attribute_definitions(self, instances): def get_creator_attribute_definitions(
self, instances: List[CreatedInstance]
) -> List[Tuple[AbstractAttrDef, List[CreatedInstance], List[Any]]]:
"""Collect creator attribute definitions for multuple instances. """Collect creator attribute definitions for multuple instances.
Args: Args:
instances(List[CreatedInstance]): List of created instances for instances (List[CreatedInstance]): List of created instances for
which should be attribute definitions returned. which should be attribute definitions returned.
""" """
@ -483,15 +514,21 @@ class CreateModel:
item[2].append(value) item[2].append(value)
return output return output
def get_publish_attribute_definitions(self, instances, include_context): def get_publish_attribute_definitions(
self, instances: List[CreatedInstance], include_context: bool
) -> List[Tuple[
str,
List[AbstractAttrDef],
Dict[str, List[Tuple[CreatedInstance, Any]]]
]]:
"""Collect publish attribute definitions for passed instances. """Collect publish attribute definitions for passed instances.
Args: Args:
instances(list<CreatedInstance>): List of created instances for instances (list[CreatedInstance]): List of created instances for
which should be attribute definitions returned. which should be attribute definitions returned.
include_context(bool): Add context specific attribute definitions. include_context (bool): Add context specific attribute definitions.
"""
"""
_tmp_items = [] _tmp_items = []
if include_context: if include_context:
_tmp_items.append(self._create_context) _tmp_items.append(self._create_context)
@ -510,17 +547,13 @@ class CreateModel:
if plugin_name not in all_defs_by_plugin_name: if plugin_name not in all_defs_by_plugin_name:
all_defs_by_plugin_name[plugin_name] = attr_val.attr_defs all_defs_by_plugin_name[plugin_name] = attr_val.attr_defs
if plugin_name not in all_plugin_values: plugin_values = all_plugin_values.setdefault(plugin_name, {})
all_plugin_values[plugin_name] = {}
plugin_values = all_plugin_values[plugin_name]
for attr_def in attr_defs: for attr_def in attr_defs:
if isinstance(attr_def, UIDef): if isinstance(attr_def, UIDef):
continue continue
if attr_def.key not in plugin_values:
plugin_values[attr_def.key] = [] attr_values = plugin_values.setdefault(attr_def.key, [])
attr_values = plugin_values[attr_def.key]
value = attr_val[attr_def.key] value = attr_val[attr_def.key]
attr_values.append((item, value)) attr_values.append((item, value))
@ -537,7 +570,9 @@ class CreateModel:
)) ))
return output return output
def get_thumbnail_paths_for_instances(self, instance_ids): def get_thumbnail_paths_for_instances(
self, instance_ids: List[str]
) -> Dict[str, Union[str, None]]:
thumbnail_paths_by_instance_id = ( thumbnail_paths_by_instance_id = (
self._create_context.thumbnail_paths_by_instance_id self._create_context.thumbnail_paths_by_instance_id
) )
@ -546,7 +581,9 @@ class CreateModel:
for instance_id in instance_ids for instance_id in instance_ids
} }
def set_thumbnail_paths_for_instances(self, thumbnail_path_mapping): def set_thumbnail_paths_for_instances(
self, thumbnail_path_mapping: Dict[str, str]
):
thumbnail_paths_by_instance_id = ( thumbnail_paths_by_instance_id = (
self._create_context.thumbnail_paths_by_instance_id self._create_context.thumbnail_paths_by_instance_id
) )
@ -560,10 +597,10 @@ class CreateModel:
} }
) )
def _emit_event(self, topic, data=None): def _emit_event(self, topic: str, data: Optional[Dict[str, Any]] = None):
self._controller.emit_event(topic, data) self._controller.emit_event(topic, data)
def _get_current_project_settings(self): def _get_current_project_settings(self) -> Dict[str, Any]:
"""Current project settings. """Current project settings.
Returns: Returns:
@ -573,17 +610,11 @@ class CreateModel:
return self._create_context.get_current_project_settings() return self._create_context.get_current_project_settings()
@property @property
def _creators(self): def _creators(self) -> Dict[str, BaseCreator]:
"""All creators loaded in create context.""" """All creators loaded in create context."""
return self._create_context.creators return self._create_context.creators
def _reset_plugins(self):
"""Reset to initial state."""
self._create_context.reset_plugins()
# Reset creator items
self._creator_items = None
def _reset_instances(self): def _reset_instances(self):
"""Reset create instances.""" """Reset create instances."""
@ -625,7 +656,7 @@ class CreateModel:
self._on_create_instance_change() self._on_create_instance_change()
def _remove_instances_from_context(self, instance_ids): def _remove_instances_from_context(self, instance_ids: List[str]):
instances_by_id = self._create_context.instances_by_id instances_by_id = self._create_context.instances_by_id
instances = [ instances = [
instances_by_id[instance_id] instances_by_id[instance_id]
@ -645,7 +676,7 @@ class CreateModel:
def _on_create_instance_change(self): def _on_create_instance_change(self):
self._emit_event("instances.refresh.finished") self._emit_event("instances.refresh.finished")
def _collect_creator_items(self): def _collect_creator_items(self) -> Dict[str, CreatorItem]:
# TODO add crashed initialization of create plugins to report # TODO add crashed initialization of create plugins to report
output = {} output = {}
allowed_creator_pattern = self._get_allowed_creators_pattern() allowed_creator_pattern = self._get_allowed_creators_pattern()
@ -666,7 +697,7 @@ class CreateModel:
return output return output
def _get_allowed_creators_pattern(self): def _get_allowed_creators_pattern(self) -> Union[Pattern, None]:
"""Provide regex pattern for configured creator labels in this context """Provide regex pattern for configured creator labels in this context
If no profile matches current context, it shows all creators. If no profile matches current context, it shows all creators.
@ -709,7 +740,11 @@ class CreateModel:
re.compile("|".join(allowed_creator_labels))) re.compile("|".join(allowed_creator_labels)))
return allowed_creator_pattern return allowed_creator_pattern
def _is_label_allowed(self, label, allowed_labels_regex): def _is_label_allowed(
self,
label: str,
allowed_labels_regex: Union[Pattern, None]
) -> bool:
"""Implement regex support for allowed labels. """Implement regex support for allowed labels.
Args: Args: