added typehints to create model

This commit is contained in:
Jakub Trllo 2024-06-21 18:22:01 +02:00
parent 24f3fa28a8
commit 76071c4b87

View file

@ -1,28 +1,37 @@
import logging
import re
from typing import Union, List, Dict, Tuple, Any, Optional, Iterable, Pattern
from ayon_core.lib.attribute_definitions import (
serialize_attr_defs,
deserialize_attr_defs,
AbstractAttrDef,
)
from ayon_core.lib.profiles_filtering import filter_profiles
from ayon_core.lib.attribute_definitions import UIDef
from ayon_core.pipeline.create import (
BaseCreator,
AutoCreator,
HiddenCreator,
Creator,
CreateContext,
CreatedInstance,
)
from ayon_core.pipeline.create.context import (
CreatorsOperationFailed,
ConvertorsOperationFailed,
ConvertorItem,
)
from ..abstract import CardMessageTypes
from ayon_core.tools.publisher.abstract import (
AbstractPublisherController,
CardMessageTypes,
)
CREATE_EVENT_SOURCE = "publisher.create.model"
class CreatorType:
def __init__(self, name):
self.name = name
def __init__(self, name: str):
self.name: str = name
def __str__(self):
return self.name
@ -42,7 +51,7 @@ class CreatorTypes:
artist = CreatorType("artist")
@classmethod
def from_str(cls, value):
def from_str(cls, value: str) -> CreatorType:
for creator_type in (
cls.base,
cls.auto,
@ -62,49 +71,52 @@ class CreatorItem:
def __init__(
self,
identifier,
creator_type,
product_type,
label,
group_label,
icon,
description,
detailed_description,
default_variant,
default_variants,
create_allow_context_change,
create_allow_thumbnail,
show_order,
pre_create_attributes_defs,
identifier: str,
creator_type: CreatorType,
product_type: str,
label: str,
group_label: str,
icon: Union[str, Dict[str, Any], None],
description: Union[str, None],
detailed_description: Union[str, None],
default_variant: Union[str, None],
default_variants: Union[List[str], None],
create_allow_context_change: Union[bool, None],
create_allow_thumbnail: Union[bool, None],
show_order: int,
pre_create_attributes_defs: List[AbstractAttrDef],
):
self.identifier = identifier
self.creator_type = creator_type
self.product_type = product_type
self.label = label
self.group_label = group_label
self.icon = icon
self.description = description
self.detailed_description = detailed_description
self.default_variant = default_variant
self.default_variants = default_variants
self.create_allow_context_change = create_allow_context_change
self.create_allow_thumbnail = create_allow_thumbnail
self.show_order = show_order
self.pre_create_attributes_defs = pre_create_attributes_defs
self.identifier: str = identifier
self.creator_type: CreatorType = creator_type
self.product_type: str = product_type
self.label: str = label
self.group_label: str = group_label
self.icon: Union[str, Dict[str, Any], None] = icon
self.description: Union[str, None] = description
self.detailed_description: Union[bool, None] = detailed_description
self.default_variant: Union[bool, None] = default_variant
self.default_variants: Union[List[str], None] = default_variants
self.create_allow_context_change: Union[bool, None] = (
create_allow_context_change
)
self.create_allow_thumbnail: Union[bool, None] = create_allow_thumbnail
self.show_order: int = show_order
self.pre_create_attributes_defs: List[AbstractAttrDef] = (
pre_create_attributes_defs
)
def get_group_label(self):
def get_group_label(self) -> str:
return self.group_label
@classmethod
def from_creator(cls, creator):
def from_creator(cls, creator: BaseCreator):
creator_type: CreatorType = CreatorTypes.base
if isinstance(creator, AutoCreator):
creator_type = CreatorTypes.auto
elif isinstance(creator, HiddenCreator):
creator_type = CreatorTypes.hidden
elif isinstance(creator, Creator):
creator_type = CreatorTypes.artist
else:
creator_type = CreatorTypes.base
description = None
detail_description = None
@ -142,7 +154,7 @@ class CreatorItem:
pre_create_attr_defs,
)
def to_data(self):
def to_data(self) -> Dict[str, Any]:
pre_create_attributes_defs = None
if self.pre_create_attributes_defs is not None:
pre_create_attributes_defs = serialize_attr_defs(
@ -167,7 +179,7 @@ class CreatorItem:
}
@classmethod
def from_data(cls, data):
def from_data(cls, data: Dict[str, Any]) -> "CreatorItem":
pre_create_attributes_defs = data["pre_create_attributes_defs"]
if pre_create_attributes_defs is not None:
data["pre_create_attributes_defs"] = deserialize_attr_defs(
@ -179,7 +191,7 @@ class CreatorItem:
class CreateModel:
def __init__(self, controller):
def __init__(self, controller: AbstractPublisherController):
self._log = None
self._controller = controller
@ -192,18 +204,18 @@ class CreateModel:
self._creator_items = None
@property
def log(self):
def log(self) -> logging.Logger:
if self._log is None:
self._log = logging.getLogger(self.__class__.__name__)
return self._log
def is_host_valid(self):
def is_host_valid(self) -> bool:
return self._create_context.host_is_valid
def get_create_context(self):
def get_create_context(self) -> CreateContext:
return self._create_context
def get_current_project_name(self):
def get_current_project_name(self) -> Union[str, None]:
"""Current project context defined by host.
Returns:
@ -212,7 +224,7 @@ class CreateModel:
"""
return self._create_context.get_current_project_name()
def get_current_folder_path(self):
def get_current_folder_path(self) -> Union[str, None]:
"""Current context folder path defined by host.
Returns:
@ -221,7 +233,7 @@ class CreateModel:
return self._create_context.get_current_folder_path()
def get_current_task_name(self):
def get_current_task_name(self) -> Union[str, None]:
"""Current context task name defined by host.
Returns:
@ -230,51 +242,61 @@ class CreateModel:
return self._create_context.get_current_task_name()
def host_context_has_changed(self):
def host_context_has_changed(self) -> bool:
return self._create_context.context_has_changed
def reset(self):
self._creator_items = None
self._create_context.reset_preparation()
# Reset current context
self._create_context.reset_current_context()
self._reset_plugins()
self._create_context.reset_plugins()
# Reset creator items
self._creator_items = None
self._reset_instances()
self._create_context.reset_finalization()
def get_creator_items(self):
def get_creator_items(self) -> Dict[str, CreatorItem]:
"""Creators that can be shown in create dialog."""
if self._creator_items is None:
self._creator_items = self._collect_creator_items()
return self._creator_items
def get_creator_item_by_id(self, identifier):
def get_creator_item_by_id(
self, identifier: str
) -> Union[CreatorItem, None]:
items = self.get_creator_items()
return items.get(identifier)
def get_creator_icon(self, identifier):
def get_creator_icon(
self, identifier: str
) -> Union[str, Dict[str, Any], None]:
"""Function to receive icon for creator identifier.
Args:
str: Creator's identifier for which should be icon returned.
"""
identifier (str): Creator's identifier for which should
be icon returned.
"""
creator_item = self.get_creator_item_by_id(identifier)
if creator_item is not None:
return creator_item.icon
return None
def get_instances(self):
def get_instances(self) -> List[CreatedInstance]:
"""Current instances in create context."""
return list(self._create_context.instances_by_id.values())
def get_instance_by_id(self, instance_id):
def get_instance_by_id(
self, instance_id: str
) -> Union[CreatedInstance, None]:
return self._create_context.instances_by_id.get(instance_id)
def get_instances_by_id(self, instance_ids=None):
def get_instances_by_id(
self, instance_ids: Optional[Iterable[str]] = None
) -> Dict[str, Union[CreatedInstance, None]]:
if instance_ids is None:
instance_ids = self._create_context.instances_by_id.keys()
return {
@ -282,17 +304,17 @@ class CreateModel:
for instance_id in instance_ids
}
def get_convertor_items(self):
def get_convertor_items(self) -> Dict[str, ConvertorItem]:
return self._create_context.convertor_items_by_id
def get_product_name(
self,
creator_identifier,
variant,
task_name,
folder_path,
instance_id=None
):
creator_identifier: str,
variant: str,
task_name: Union[str, None],
folder_path: Union[str, None],
instance_id: Optional[str] = None
) -> str:
"""Get product name based on passed data.
Args:
@ -323,7 +345,10 @@ class CreateModel:
project_name, folder_item.entity_id
)
task_item = self._controller.get_task_item_by_name(
project_name, folder_item.entity_id, task_name, "controller"
project_name,
folder_item.entity_id,
task_name,
CREATE_EVENT_SOURCE
)
if task_item is not None:
@ -340,7 +365,11 @@ class CreateModel:
)
def create(
self, creator_identifier, product_name, instance_data, options
self,
creator_identifier: str,
product_name: str,
instance_data: Dict[str, Any],
options: Dict[str, Any],
):
"""Trigger creation and refresh of instances in UI."""
@ -363,7 +392,7 @@ class CreateModel:
self._on_create_instance_change()
return success
def trigger_convertor_items(self, convertor_identifiers):
def trigger_convertor_items(self, convertor_identifiers: List[str]):
"""Trigger legacy item convertors.
This functionality requires to save and reset CreateContext. The reset
@ -398,7 +427,7 @@ class CreateModel:
CardMessageTypes.error
)
def save_changes(self, show_message=True):
def save_changes(self, show_message: Optional[bool] = True):
"""Save changes happened during creation.
Trigger save of changes using host api. This functionality does not
@ -437,7 +466,7 @@ class CreateModel:
return False
def remove_instances(self, instance_ids):
def remove_instances(self, instance_ids: List[str]):
"""Remove instances based on instance ids.
Args:
@ -450,11 +479,13 @@ class CreateModel:
self._on_create_instance_change()
def get_creator_attribute_definitions(self, instances):
def get_creator_attribute_definitions(
self, instances: List[CreatedInstance]
) -> List[Tuple[AbstractAttrDef, List[CreatedInstance], List[Any]]]:
"""Collect creator attribute definitions for multuple instances.
Args:
instances(List[CreatedInstance]): List of created instances for
instances (List[CreatedInstance]): List of created instances for
which should be attribute definitions returned.
"""
@ -483,15 +514,21 @@ class CreateModel:
item[2].append(value)
return output
def get_publish_attribute_definitions(self, instances, include_context):
def get_publish_attribute_definitions(
self, instances: List[CreatedInstance], include_context: bool
) -> List[Tuple[
str,
List[AbstractAttrDef],
Dict[str, List[Tuple[CreatedInstance, Any]]]
]]:
"""Collect publish attribute definitions for passed instances.
Args:
instances(list<CreatedInstance>): List of created instances for
instances (list[CreatedInstance]): List of created instances for
which should be attribute definitions returned.
include_context(bool): Add context specific attribute definitions.
"""
include_context (bool): Add context specific attribute definitions.
"""
_tmp_items = []
if include_context:
_tmp_items.append(self._create_context)
@ -510,17 +547,13 @@ class CreateModel:
if plugin_name not in all_defs_by_plugin_name:
all_defs_by_plugin_name[plugin_name] = attr_val.attr_defs
if plugin_name not in all_plugin_values:
all_plugin_values[plugin_name] = {}
plugin_values = all_plugin_values[plugin_name]
plugin_values = all_plugin_values.setdefault(plugin_name, {})
for attr_def in attr_defs:
if isinstance(attr_def, UIDef):
continue
if attr_def.key not in plugin_values:
plugin_values[attr_def.key] = []
attr_values = plugin_values[attr_def.key]
attr_values = plugin_values.setdefault(attr_def.key, [])
value = attr_val[attr_def.key]
attr_values.append((item, value))
@ -537,7 +570,9 @@ class CreateModel:
))
return output
def get_thumbnail_paths_for_instances(self, instance_ids):
def get_thumbnail_paths_for_instances(
self, instance_ids: List[str]
) -> Dict[str, Union[str, None]]:
thumbnail_paths_by_instance_id = (
self._create_context.thumbnail_paths_by_instance_id
)
@ -546,7 +581,9 @@ class CreateModel:
for instance_id in instance_ids
}
def set_thumbnail_paths_for_instances(self, thumbnail_path_mapping):
def set_thumbnail_paths_for_instances(
self, thumbnail_path_mapping: Dict[str, str]
):
thumbnail_paths_by_instance_id = (
self._create_context.thumbnail_paths_by_instance_id
)
@ -560,10 +597,10 @@ class CreateModel:
}
)
def _emit_event(self, topic, data=None):
def _emit_event(self, topic: str, data: Optional[Dict[str, Any]] = None):
self._controller.emit_event(topic, data)
def _get_current_project_settings(self):
def _get_current_project_settings(self) -> Dict[str, Any]:
"""Current project settings.
Returns:
@ -573,17 +610,11 @@ class CreateModel:
return self._create_context.get_current_project_settings()
@property
def _creators(self):
def _creators(self) -> Dict[str, BaseCreator]:
"""All creators loaded in create context."""
return self._create_context.creators
def _reset_plugins(self):
"""Reset to initial state."""
self._create_context.reset_plugins()
# Reset creator items
self._creator_items = None
def _reset_instances(self):
"""Reset create instances."""
@ -625,7 +656,7 @@ class CreateModel:
self._on_create_instance_change()
def _remove_instances_from_context(self, instance_ids):
def _remove_instances_from_context(self, instance_ids: List[str]):
instances_by_id = self._create_context.instances_by_id
instances = [
instances_by_id[instance_id]
@ -645,7 +676,7 @@ class CreateModel:
def _on_create_instance_change(self):
self._emit_event("instances.refresh.finished")
def _collect_creator_items(self):
def _collect_creator_items(self) -> Dict[str, CreatorItem]:
# TODO add crashed initialization of create plugins to report
output = {}
allowed_creator_pattern = self._get_allowed_creators_pattern()
@ -666,7 +697,7 @@ class CreateModel:
return output
def _get_allowed_creators_pattern(self):
def _get_allowed_creators_pattern(self) -> Union[Pattern, None]:
"""Provide regex pattern for configured creator labels in this context
If no profile matches current context, it shows all creators.
@ -709,7 +740,11 @@ class CreateModel:
re.compile("|".join(allowed_creator_labels)))
return allowed_creator_pattern
def _is_label_allowed(self, label, allowed_labels_regex):
def _is_label_allowed(
self,
label: str,
allowed_labels_regex: Union[Pattern, None]
) -> bool:
"""Implement regex support for allowed labels.
Args: