mirror of https://github.com/ynput/ayon-core.git, synced 2025-12-25 13:24:54 +01:00

removed fusion addon

This commit is contained in:
parent 552270dd0c
commit 348104c32d

118 changed files with 0 additions and 21923 deletions
@@ -1,17 +0,0 @@
from .version import __version__
from .addon import (
    get_fusion_version,
    FusionAddon,
    FUSION_ADDON_ROOT,
    FUSION_VERSIONS_DICT,
)


__all__ = (
    "__version__",

    "get_fusion_version",
    "FusionAddon",
    "FUSION_ADDON_ROOT",
    "FUSION_VERSIONS_DICT",
)
@@ -1,72 +0,0 @@
import os
import re
from ayon_core.addon import AYONAddon, IHostAddon
from ayon_core.lib import Logger

from .version import __version__

FUSION_ADDON_ROOT = os.path.dirname(os.path.abspath(__file__))

# FUSION_VERSIONS_DICT is used by the pre-launch hooks.
# The keys correspond to all currently supported Fusion versions.
# Each value is a tuple of the corresponding Python home variable and a
# profile number, which is used by the profile hook to set Fusion profile
# variables.
FUSION_VERSIONS_DICT = {
    9: ("FUSION_PYTHON36_HOME", 9),
    16: ("FUSION16_PYTHON36_HOME", 16),
    17: ("FUSION16_PYTHON36_HOME", 16),
    18: ("FUSION_PYTHON3_HOME", 16),
}


def get_fusion_version(app_name):
    """
    The function is triggered by the prelaunch hooks to get the Fusion
    version.

    `app_name` is obtained by the prelaunch hooks from
    `launch_context.env.get("AYON_APP_NAME")`.

    To get a correct Fusion version, a version number should be present
    in the `applications/fusion/variants` key
    of the Blackmagic Fusion Application Settings.
    """

    log = Logger.get_logger(__name__)

    if not app_name:
        return

    app_version_candidates = re.findall(r"\d+", app_name)
    if not app_version_candidates:
        return
    for app_version in app_version_candidates:
        if int(app_version) in FUSION_VERSIONS_DICT:
            return int(app_version)
        else:
            log.info(
                "Unsupported Fusion version: {app_version}".format(
                    app_version=app_version
                )
            )


class FusionAddon(AYONAddon, IHostAddon):
    name = "fusion"
    version = __version__
    host_name = "fusion"

    def get_launch_hook_paths(self, app):
        if app.host_name != self.host_name:
            return []
        return [os.path.join(FUSION_ADDON_ROOT, "hooks")]

    def add_implementation_envs(self, env, app):
        # Set default values if they are not already set via settings
        defaults = {"AYON_LOG_NO_COLORS": "1"}
        for key, value in defaults.items():
            if not env.get(key):
                env[key] = value

    def get_workfile_extensions(self):
        return [".comp"]
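
For reference, a minimal sketch of how a pre-launch hook might consume these helpers; the app name value is an assumed example, not taken from this commit:

# Hypothetical usage sketch ("fusion/18" is an assumed AYON_APP_NAME shape).
from ayon_fusion import FUSION_VERSIONS_DICT, get_fusion_version

app_name = "fusion/18"
version = get_fusion_version(app_name)  # -> 18
if version is not None:
    python_home_var, profile_number = FUSION_VERSIONS_DICT[version]
    # For Fusion 18 this yields ("FUSION_PYTHON3_HOME", 16)
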
@@ -1,39 +0,0 @@
from .pipeline import (
    FusionHost,
    ls,

    imprint_container,
    parse_container
)

from .lib import (
    maintained_selection,
    update_frame_range,
    set_current_context_framerange,
    get_current_comp,
    get_bmd_library,
    comp_lock_and_undo_chunk
)

from .menu import launch_ayon_menu


__all__ = [
    # pipeline
    "FusionHost",
    "ls",

    "imprint_container",
    "parse_container",

    # lib
    "maintained_selection",
    "update_frame_range",
    "set_current_context_framerange",
    "get_current_comp",
    "get_bmd_library",
    "comp_lock_and_undo_chunk",

    # menu
    "launch_ayon_menu",
]
@@ -1,112 +0,0 @@
import pyblish.api


from ayon_fusion.api.lib import get_current_comp
from ayon_core.pipeline.publish import get_errored_instances_from_context


class SelectInvalidAction(pyblish.api.Action):
    """Select invalid nodes in Fusion when plug-in failed.

    To retrieve the invalid nodes this assumes a static `get_invalid()`
    method is available on the plugin.

    """

    label = "Select invalid"
    on = "failed"  # This action is only available on a failed plug-in
    icon = "search"  # Icon from Awesome Icon

    def process(self, context, plugin):
        errored_instances = get_errored_instances_from_context(
            context,
            plugin=plugin,
        )

        # Get the invalid nodes for the plug-ins
        self.log.info("Finding invalid nodes..")
        invalid = list()
        for instance in errored_instances:
            invalid_nodes = plugin.get_invalid(instance)
            if invalid_nodes:
                if isinstance(invalid_nodes, (list, tuple)):
                    invalid.extend(invalid_nodes)
                else:
                    self.log.warning(
                        "Plug-in reported invalid, "
                        "but has no selectable nodes."
                    )

        if not invalid:
            # Assume relevant comp is current comp and clear selection
            self.log.info("No invalid tools found.")
            comp = get_current_comp()
            flow = comp.CurrentFrame.FlowView
            flow.Select()  # No args equals clearing selection
            return

        # Assume a single comp
        first_tool = invalid[0]
        comp = first_tool.Comp()
        flow = comp.CurrentFrame.FlowView
        flow.Select()  # No args equals clearing selection
        names = set()
        for tool in invalid:
            flow.Select(tool, True)
            comp.SetActiveTool(tool)
            names.add(tool.Name)
        self.log.info(
            "Selecting invalid tools: %s" % ", ".join(sorted(names))
        )


class SelectToolAction(pyblish.api.Action):
    """Select invalid output tool in Fusion when plug-in failed."""

    label = "Select saver"
    on = "failed"  # This action is only available on a failed plug-in
    icon = "search"  # Icon from Awesome Icon

    def process(self, context, plugin):
        errored_instances = get_errored_instances_from_context(
            context,
            plugin=plugin,
        )

        # Get the invalid nodes for the plug-ins
        self.log.info("Finding invalid nodes..")
        tools = []
        for instance in errored_instances:
            tool = instance.data.get("tool")
            if tool is not None:
                tools.append(tool)
            else:
                self.log.warning(
                    "Plug-in reported invalid, "
                    f"but has no saver for instance {instance.name}."
                )

        if not tools:
            # Assume relevant comp is current comp and clear selection
            self.log.info("No invalid tools found.")
            comp = get_current_comp()
            flow = comp.CurrentFrame.FlowView
            flow.Select()  # No args equals clearing selection
            return

        # Assume a single comp
        first_tool = tools[0]
        comp = first_tool.Comp()
        flow = comp.CurrentFrame.FlowView
        flow.Select()  # No args equals clearing selection
        names = set()
        for tool in tools:
            flow.Select(tool, True)
            comp.SetActiveTool(tool)
            names.add(tool.Name)
        self.log.info(
            "Selecting invalid tools: %s" % ", ".join(sorted(names))
        )
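
As a usage note, a validator offering these actions is assumed to expose the static `get_invalid()` method described in the docstring above; a minimal hypothetical sketch (plug-in name and failure condition are illustrative, not from this commit):

import pyblish.api

class ValidateSaverExample(pyblish.api.InstancePlugin):
    # Hypothetical validator, only to illustrate the assumed interface
    order = pyblish.api.ValidatorOrder
    label = "Validate Saver (example)"
    actions = [SelectInvalidAction]

    @staticmethod
    def get_invalid(instance):
        # Consider the saver invalid when it is set to pass-through
        # (an assumed failure condition for this sketch)
        tool = instance.data.get("tool")
        if tool is not None and tool.GetAttrs()["TOOLB_PassThrough"]:
            return [tool]
        return []

    def process(self, instance):
        if self.get_invalid(instance):
            raise RuntimeError("Saver is set to pass-through")
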
@@ -1,402 +0,0 @@
import os
import sys
import re
import contextlib

from ayon_core.lib import Logger, BoolDef, UILabelDef
from ayon_core.style import load_stylesheet
from ayon_core.pipeline import registered_host
from ayon_core.pipeline.create import CreateContext
from ayon_core.pipeline.context_tools import get_current_folder_entity

self = sys.modules[__name__]
self._project = None


def update_frame_range(start, end, comp=None, set_render_range=True,
                       handle_start=0, handle_end=0):
    """Set Fusion comp's start and end frame range

    Args:
        start (float, int): start frame
        end (float, int): end frame
        comp (object, Optional): comp object from fusion
        set_render_range (bool, Optional): When True this will also set the
            composition's render start and end frame.
        handle_start (float, int, Optional): frame handles before start frame
        handle_end (float, int, Optional): frame handles after end frame

    Returns:
        None

    """

    if not comp:
        comp = get_current_comp()

    # Convert any potential None type to zero
    handle_start = handle_start or 0
    handle_end = handle_end or 0

    attrs = {
        "COMPN_GlobalStart": start - handle_start,
        "COMPN_GlobalEnd": end + handle_end
    }

    # set frame range
    if set_render_range:
        attrs.update({
            "COMPN_RenderStart": start,
            "COMPN_RenderEnd": end
        })

    with comp_lock_and_undo_chunk(comp):
        comp.SetAttrs(attrs)


def set_current_context_framerange(folder_entity=None):
    """Set Comp's frame range based on current folder."""
    if folder_entity is None:
        folder_entity = get_current_folder_entity(
            fields={"attrib.frameStart",
                    "attrib.frameEnd",
                    "attrib.handleStart",
                    "attrib.handleEnd"})

    folder_attributes = folder_entity["attrib"]
    start = folder_attributes["frameStart"]
    end = folder_attributes["frameEnd"]
    handle_start = folder_attributes["handleStart"]
    handle_end = folder_attributes["handleEnd"]
    update_frame_range(start, end, set_render_range=True,
                       handle_start=handle_start,
                       handle_end=handle_end)


def set_current_context_fps(folder_entity=None):
    """Set Comp's frame rate (FPS) based on the current folder."""
    if folder_entity is None:
        folder_entity = get_current_folder_entity(fields={"attrib.fps"})

    fps = float(folder_entity["attrib"].get("fps", 24.0))
    comp = get_current_comp()
    comp.SetPrefs({
        "Comp.FrameFormat.Rate": fps,
    })


def set_current_context_resolution(folder_entity=None):
    """Set Comp's default resolution (width x height) based on current
    folder."""
    if folder_entity is None:
        folder_entity = get_current_folder_entity(
            fields={"attrib.resolutionWidth", "attrib.resolutionHeight"})

    folder_attributes = folder_entity["attrib"]
    width = folder_attributes["resolutionWidth"]
    height = folder_attributes["resolutionHeight"]
    comp = get_current_comp()

    print("Setting comp frame format resolution to {}x{}".format(width,
                                                                  height))
    comp.SetPrefs({
        "Comp.FrameFormat.Width": width,
        "Comp.FrameFormat.Height": height,
    })


def validate_comp_prefs(comp=None, force_repair=False):
    """Validate current comp defaults with folder settings.

    Validates fps, resolutionWidth, resolutionHeight, aspectRatio.

    This does *not* validate frameStart, frameEnd, handleStart and handleEnd.
    """

    if comp is None:
        comp = get_current_comp()

    log = Logger.get_logger("validate_comp_prefs")

    fields = {
        "path",
        "attrib.fps",
        "attrib.resolutionWidth",
        "attrib.resolutionHeight",
        "attrib.pixelAspect",
    }
    folder_entity = get_current_folder_entity(fields=fields)
    folder_path = folder_entity["path"]
    folder_attributes = folder_entity["attrib"]

    comp_frame_format_prefs = comp.GetPrefs("Comp.FrameFormat")

    # Pixel aspect ratio in Fusion is set as AspectX and AspectY so we convert
    # the data to something that is more sensible to Fusion
    folder_attributes["pixelAspectX"] = folder_attributes.pop("pixelAspect")
    folder_attributes["pixelAspectY"] = 1.0

    validations = [
        ("fps", "Rate", "FPS"),
        ("resolutionWidth", "Width", "Resolution Width"),
        ("resolutionHeight", "Height", "Resolution Height"),
        ("pixelAspectX", "AspectX", "Pixel Aspect Ratio X"),
        ("pixelAspectY", "AspectY", "Pixel Aspect Ratio Y")
    ]

    invalid = []
    for key, comp_key, label in validations:
        folder_value = folder_attributes[key]
        comp_value = comp_frame_format_prefs.get(comp_key)
        if folder_value != comp_value:
            invalid_msg = "{} {} should be {}".format(label,
                                                      comp_value,
                                                      folder_value)
            invalid.append(invalid_msg)

            if not force_repair:
                # Do not log warning if we force repair anyway
                log.warning(
                    "Comp {pref} {value} does not match folder "
                    "'{folder_path}' {pref} {folder_value}".format(
                        pref=label,
                        value=comp_value,
                        folder_path=folder_path,
                        folder_value=folder_value)
                )

    if invalid:

        def _on_repair():
            attributes = dict()
            for key, comp_key, _label in validations:
                value = folder_attributes[key]
                comp_key_full = "Comp.FrameFormat.{}".format(comp_key)
                attributes[comp_key_full] = value
            comp.SetPrefs(attributes)

        if force_repair:
            log.info("Applying default Comp preferences..")
            _on_repair()
            return

        from . import menu
        from ayon_core.tools.utils import SimplePopup
        dialog = SimplePopup(parent=menu.menu)
        dialog.setWindowTitle("Fusion comp has invalid configuration")

        msg = "Comp preferences mismatch '{}'".format(folder_path)
        msg += "\n" + "\n".join(invalid)
        dialog.set_message(msg)
        dialog.set_button_text("Repair")
        dialog.on_clicked.connect(_on_repair)
        dialog.show()
        dialog.raise_()
        dialog.activateWindow()
        dialog.setStyleSheet(load_stylesheet())


@contextlib.contextmanager
def maintained_selection(comp=None):
    """Reset comp selection from before the context after the context"""
    if comp is None:
        comp = get_current_comp()

    previous_selection = comp.GetToolList(True).values()
    try:
        yield
    finally:
        flow = comp.CurrentFrame.FlowView
        flow.Select()  # No args equals clearing selection
        if previous_selection:
            for tool in previous_selection:
                flow.Select(tool, True)


@contextlib.contextmanager
def maintained_comp_range(comp=None,
                          global_start=True,
                          global_end=True,
                          render_start=True,
                          render_end=True):
    """Reset comp frame ranges from before the context after the context"""
    if comp is None:
        comp = get_current_comp()

    comp_attrs = comp.GetAttrs()
    preserve_attrs = {}
    if global_start:
        preserve_attrs["COMPN_GlobalStart"] = comp_attrs["COMPN_GlobalStart"]
    if global_end:
        preserve_attrs["COMPN_GlobalEnd"] = comp_attrs["COMPN_GlobalEnd"]
    if render_start:
        preserve_attrs["COMPN_RenderStart"] = comp_attrs["COMPN_RenderStart"]
    if render_end:
        preserve_attrs["COMPN_RenderEnd"] = comp_attrs["COMPN_RenderEnd"]

    try:
        yield
    finally:
        comp.SetAttrs(preserve_attrs)


def get_frame_path(path):
    """Get filename for the Fusion Saver with padded number as '#'

    >>> get_frame_path("C:/test.exr")
    ('C:/test', 4, '.exr')

    >>> get_frame_path("filename.00.tif")
    ('filename.', 2, '.tif')

    >>> get_frame_path("foobar35.tif")
    ('foobar', 2, '.tif')

    Args:
        path (str): The path to render to.

    Returns:
        tuple: head, padding, tail (extension)

    """
    filename, ext = os.path.splitext(path)

    # Find a final number group
    match = re.match('.*?([0-9]+)$', filename)
    if match:
        padding = len(match.group(1))
        # remove number from end since fusion
        # will swap it with the frame number
        filename = filename[:-padding]
    else:
        padding = 4  # default Fusion padding

    return filename, padding, ext


def get_fusion_module():
    """Get current Fusion instance"""
    fusion = getattr(sys.modules["__main__"], "fusion", None)
    return fusion


def get_bmd_library():
    """Get bmd library"""
    bmd = getattr(sys.modules["__main__"], "bmd", None)
    return bmd


def get_current_comp():
    """Get current comp in this session"""
    fusion = get_fusion_module()
    if fusion is not None:
        comp = fusion.CurrentComp
        return comp


@contextlib.contextmanager
def comp_lock_and_undo_chunk(
    comp,
    undo_queue_name="Script CMD",
    keep_undo=True,
):
    """Lock comp and open an undo chunk during the context"""
    try:
        comp.Lock()
        comp.StartUndo(undo_queue_name)
        yield
    finally:
        comp.Unlock()
        comp.EndUndo(keep_undo)


def update_content_on_context_change():
    """Update all Creator instances to the current folder and task"""
    host = registered_host()
    context = host.get_current_context()

    folder_path = context["folder_path"]
    task = context["task_name"]

    create_context = CreateContext(host, reset=True)

    for instance in create_context.instances:
        instance_folder_path = instance.get("folderPath")
        if instance_folder_path and instance_folder_path != folder_path:
            instance["folderPath"] = folder_path
        instance_task = instance.get("task")
        if instance_task and instance_task != task:
            instance["task"] = task

    create_context.save_changes()


def prompt_reset_context():
    """Prompt the user what context settings to reset.

    This prompt is used on saving to a different task to allow the scene to
    get matched to the new context.
    """
    # TODO: Clean up this prototyped mess of imports and odd dialog
    from ayon_core.tools.attribute_defs.dialog import (
        AttributeDefinitionsDialog
    )
    from qtpy import QtCore

    definitions = [
        UILabelDef(
            label=(
                "You are saving your workfile into a different folder or task."
                "\n\n"
                "Would you like to update some settings to the new context?\n"
            )
        ),
        BoolDef(
            "fps",
            label="FPS",
            tooltip="Reset Comp FPS",
            default=True
        ),
        BoolDef(
            "frame_range",
            label="Frame Range",
            tooltip="Reset Comp start and end frame ranges",
            default=True
        ),
        BoolDef(
            "resolution",
            label="Comp Resolution",
            tooltip="Reset Comp resolution",
            default=True
        ),
        BoolDef(
            "instances",
            label="Publish instances",
            tooltip="Update all publish instances' folder and task to match "
                    "the new folder and task",
            default=True
        ),
    ]

    dialog = AttributeDefinitionsDialog(definitions)
    dialog.setWindowFlags(
        dialog.windowFlags() | QtCore.Qt.WindowStaysOnTopHint
    )
    dialog.setWindowTitle("Saving to different context.")
    dialog.setStyleSheet(load_stylesheet())
    if not dialog.exec_():
        return None

    options = dialog.get_values()
    folder_entity = get_current_folder_entity()
    if options["frame_range"]:
        set_current_context_framerange(folder_entity)

    if options["fps"]:
        set_current_context_fps(folder_entity)

    if options["resolution"]:
        set_current_context_resolution(folder_entity)

    if options["instances"]:
        update_content_on_context_change()

    dialog.deleteLater()
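
For illustration, a hedged sketch of how the context managers above compose; it assumes the code runs inside a Fusion session (where the `fusion` global exists) and the frame values are arbitrary examples:

from ayon_fusion.api import (
    get_current_comp,
    comp_lock_and_undo_chunk,
    maintained_selection,
)

comp = get_current_comp()
with comp_lock_and_undo_chunk(comp, undo_queue_name="Update frame range"):
    with maintained_selection(comp):
        # All changes land in a single undo chunk and the user's
        # tool selection is restored when the block exits.
        comp.SetAttrs({
            "COMPN_GlobalStart": 1001,  # example values
            "COMPN_GlobalEnd": 1100,
        })
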
@@ -1,190 +0,0 @@
import os
import sys

from qtpy import QtWidgets, QtCore, QtGui

from ayon_core.tools.utils import host_tools
from ayon_core.style import load_stylesheet
from ayon_core.lib import register_event_callback
from ayon_fusion.scripts import (
    duplicate_with_inputs,
)
from ayon_fusion.api.lib import (
    set_current_context_framerange,
    set_current_context_resolution,
)
from ayon_core.pipeline import get_current_folder_path
from ayon_core.resources import get_ayon_icon_filepath
from ayon_core.tools.utils import get_qt_app

from .pipeline import FusionEventHandler
from .pulse import FusionPulse


MENU_LABEL = os.environ["AYON_MENU_LABEL"]


self = sys.modules[__name__]
self.menu = None


class AYONMenu(QtWidgets.QWidget):
    def __init__(self, *args, **kwargs):
        super(AYONMenu, self).__init__(*args, **kwargs)

        self.setObjectName(f"{MENU_LABEL}Menu")

        icon_path = get_ayon_icon_filepath()
        icon = QtGui.QIcon(icon_path)
        self.setWindowIcon(icon)

        self.setWindowFlags(
            QtCore.Qt.Window
            | QtCore.Qt.CustomizeWindowHint
            | QtCore.Qt.WindowTitleHint
            | QtCore.Qt.WindowMinimizeButtonHint
            | QtCore.Qt.WindowCloseButtonHint
            | QtCore.Qt.WindowStaysOnTopHint
        )
        self.render_mode_widget = None
        self.setWindowTitle(MENU_LABEL)

        context_label = QtWidgets.QLabel("Context", self)
        context_label.setStyleSheet(
            """QLabel {
                font-size: 14px;
                font-weight: 600;
                color: #5f9fb8;
            }"""
        )
        context_label.setAlignment(QtCore.Qt.AlignHCenter)

        workfiles_btn = QtWidgets.QPushButton("Workfiles...", self)
        create_btn = QtWidgets.QPushButton("Create...", self)
        load_btn = QtWidgets.QPushButton("Load...", self)
        publish_btn = QtWidgets.QPushButton("Publish...", self)
        manager_btn = QtWidgets.QPushButton("Manage...", self)
        libload_btn = QtWidgets.QPushButton("Library...", self)
        set_framerange_btn = QtWidgets.QPushButton("Set Frame Range", self)
        set_resolution_btn = QtWidgets.QPushButton("Set Resolution", self)
        duplicate_with_inputs_btn = QtWidgets.QPushButton(
            "Duplicate with input connections", self
        )

        layout = QtWidgets.QVBoxLayout(self)
        layout.setContentsMargins(10, 20, 10, 20)

        layout.addWidget(context_label)

        layout.addSpacing(20)

        layout.addWidget(workfiles_btn)

        layout.addSpacing(20)

        layout.addWidget(create_btn)
        layout.addWidget(load_btn)
        layout.addWidget(publish_btn)
        layout.addWidget(manager_btn)

        layout.addSpacing(20)

        layout.addWidget(libload_btn)

        layout.addSpacing(20)

        layout.addWidget(set_framerange_btn)
        layout.addWidget(set_resolution_btn)

        layout.addSpacing(20)

        layout.addWidget(duplicate_with_inputs_btn)

        self.setLayout(layout)

        # Store reference so we can update the label
        self.context_label = context_label

        workfiles_btn.clicked.connect(self.on_workfile_clicked)
        create_btn.clicked.connect(self.on_create_clicked)
        publish_btn.clicked.connect(self.on_publish_clicked)
        load_btn.clicked.connect(self.on_load_clicked)
        manager_btn.clicked.connect(self.on_manager_clicked)
        libload_btn.clicked.connect(self.on_libload_clicked)
        duplicate_with_inputs_btn.clicked.connect(
            self.on_duplicate_with_inputs_clicked
        )
        set_resolution_btn.clicked.connect(self.on_set_resolution_clicked)
        set_framerange_btn.clicked.connect(self.on_set_framerange_clicked)

        self._callbacks = []
        self.register_callback("taskChanged", self.on_task_changed)
        self.on_task_changed()

        # Force close current process if Fusion is closed
        self._pulse = FusionPulse(parent=self)
        self._pulse.start()

        # Detect Fusion events as AYON events
        self._event_handler = FusionEventHandler(parent=self)
        self._event_handler.start()

    def on_task_changed(self):
        # Update current context label
        label = get_current_folder_path()
        self.context_label.setText(label)

    def register_callback(self, name, fn):
        # Create a wrapper callback that we only store
        # for as long as we want it to persist as callback
        def _callback(*args):
            fn()

        self._callbacks.append(_callback)
        register_event_callback(name, _callback)

    def deregister_all_callbacks(self):
        self._callbacks[:] = []

    def on_workfile_clicked(self):
        host_tools.show_workfiles()

    def on_create_clicked(self):
        host_tools.show_publisher(tab="create")

    def on_publish_clicked(self):
        host_tools.show_publisher(tab="publish")

    def on_load_clicked(self):
        host_tools.show_loader(use_context=True)

    def on_manager_clicked(self):
        host_tools.show_scene_inventory()

    def on_libload_clicked(self):
        host_tools.show_library_loader()

    def on_duplicate_with_inputs_clicked(self):
        duplicate_with_inputs.duplicate_with_input_connections()

    def on_set_resolution_clicked(self):
        set_current_context_resolution()

    def on_set_framerange_clicked(self):
        set_current_context_framerange()


def launch_ayon_menu():
    app = get_qt_app()

    ayon_menu = AYONMenu()

    stylesheet = load_stylesheet()
    ayon_menu.setStyleSheet(stylesheet)

    ayon_menu.show()
    self.menu = ayon_menu

    result = app.exec_()
    print("Shutting down..")
    sys.exit(result)
@@ -1,439 +0,0 @@
"""
Basic avalon integration
"""
import os
import sys
import logging
import contextlib
from pathlib import Path

import pyblish.api
from qtpy import QtCore

from ayon_core.lib import (
    Logger,
    register_event_callback,
    emit_event
)
from ayon_core.pipeline import (
    register_loader_plugin_path,
    register_creator_plugin_path,
    register_inventory_action_path,
    AVALON_CONTAINER_ID,
)
from ayon_core.pipeline.load import any_outdated_containers
from ayon_core.host import HostBase, IWorkfileHost, ILoadHost, IPublishHost
from ayon_core.tools.utils import host_tools
from ayon_fusion import FUSION_ADDON_ROOT


from .lib import (
    get_current_comp,
    validate_comp_prefs,
    prompt_reset_context
)

log = Logger.get_logger(__name__)

PLUGINS_DIR = os.path.join(FUSION_ADDON_ROOT, "plugins")

PUBLISH_PATH = os.path.join(PLUGINS_DIR, "publish")
LOAD_PATH = os.path.join(PLUGINS_DIR, "load")
CREATE_PATH = os.path.join(PLUGINS_DIR, "create")
INVENTORY_PATH = os.path.join(PLUGINS_DIR, "inventory")

# Track whether the workfile tool is about to save
_about_to_save = False


class FusionLogHandler(logging.Handler):
    # Keep a reference to fusion's Print function (Remote Object)
    _print = None

    @property
    def print(self):
        if self._print is not None:
            # Use cached
            return self._print

        _print = getattr(sys.modules["__main__"], "fusion").Print
        if _print is None:
            # Backwards compatibility: Print method on Fusion instance was
            # added around Fusion 17.4 and wasn't available on PyRemote Object
            # before
            _print = get_current_comp().Print
        self._print = _print
        return _print

    def emit(self, record):
        entry = self.format(record)
        self.print(entry)


class FusionHost(HostBase, IWorkfileHost, ILoadHost, IPublishHost):
    name = "fusion"

    def install(self):
        """Install fusion-specific functionality of AYON.

        This is where you install menus and register families, data
        and loaders into fusion.

        It is called automatically when installing via
        `ayon_core.pipeline.install_host(ayon_fusion.api)`

        See the Maya equivalent for inspiration on how to implement this.

        """
        # Remove all handlers associated with the root logger object, because
        # that one always logs as "warnings" incorrectly.
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)

        # Attach default logging handler that prints to active comp
        logger = logging.getLogger()
        formatter = logging.Formatter(fmt="%(message)s\n")
        handler = FusionLogHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)

        pyblish.api.register_host("fusion")
        pyblish.api.register_plugin_path(PUBLISH_PATH)
        log.info("Registering Fusion plug-ins..")

        register_loader_plugin_path(LOAD_PATH)
        register_creator_plugin_path(CREATE_PATH)
        register_inventory_action_path(INVENTORY_PATH)

        # Register events
        register_event_callback("open", on_after_open)
        register_event_callback("workfile.save.before", before_workfile_save)
        register_event_callback("save", on_save)
        register_event_callback("new", on_new)
        register_event_callback("taskChanged", on_task_changed)

    # region workfile io api
    def has_unsaved_changes(self):
        comp = get_current_comp()
        return comp.GetAttrs()["COMPB_Modified"]

    def get_workfile_extensions(self):
        return [".comp"]

    def save_workfile(self, dst_path=None):
        comp = get_current_comp()
        comp.Save(dst_path)

    def open_workfile(self, filepath):
        # Hack to get fusion, see
        # ayon_fusion.api.pipeline.get_current_comp()
        fusion = getattr(sys.modules["__main__"], "fusion", None)

        return fusion.LoadComp(filepath)

    def get_current_workfile(self):
        comp = get_current_comp()
        current_filepath = comp.GetAttrs()["COMPS_FileName"]
        if not current_filepath:
            return None

        return current_filepath

    def work_root(self, session):
        work_dir = session["AYON_WORKDIR"]
        scene_dir = session.get("AVALON_SCENEDIR")
        if scene_dir:
            return os.path.join(work_dir, scene_dir)
        else:
            return work_dir
    # endregion

    @contextlib.contextmanager
    def maintained_selection(self):
        from .lib import maintained_selection
        with maintained_selection():
            yield

    def get_containers(self):
        return ls()

    def update_context_data(self, data, changes):
        comp = get_current_comp()
        comp.SetData("openpype", data)

    def get_context_data(self):
        comp = get_current_comp()
        return comp.GetData("openpype") or {}


def on_new(event):
    comp = event["Rets"]["comp"]
    validate_comp_prefs(comp, force_repair=True)


def on_save(event):
    comp = event["sender"]
    validate_comp_prefs(comp)

    # We are now starting the actual save directly
    global _about_to_save
    _about_to_save = False


def on_task_changed():
    global _about_to_save
    print(f"Task changed: {_about_to_save}")
    # TODO: Only do this if not headless
    if _about_to_save:
        # Let's prompt the user to update the context settings or not
        prompt_reset_context()


def on_after_open(event):
    comp = event["sender"]
    validate_comp_prefs(comp)

    if any_outdated_containers():
        log.warning("Scene has outdated content.")

        # Find AYON menu to attach to
        from . import menu

        def _on_show_scene_inventory():
            # ensure that comp is active
            frame = comp.CurrentFrame
            if not frame:
                print("Comp is closed, skipping show scene inventory")
                return
            frame.ActivateFrame()  # raise comp window
            host_tools.show_scene_inventory()

        from ayon_core.tools.utils import SimplePopup
        from ayon_core.style import load_stylesheet
        dialog = SimplePopup(parent=menu.menu)
        dialog.setWindowTitle("Fusion comp has outdated content")
        dialog.set_message("There are outdated containers in "
                           "your Fusion comp.")
        dialog.on_clicked.connect(_on_show_scene_inventory)
        dialog.show()
        dialog.raise_()
        dialog.activateWindow()
        dialog.setStyleSheet(load_stylesheet())


def before_workfile_save(event):
    # Due to Fusion's external python process design we can't really
    # detect whether the current Fusion environment matches the one the
    # artist expects it to be. For example, our pipeline python process might
    # have been shut down, and restarted - which will restart it to the
    # environment Fusion started with; not necessarily where the artist
    # is currently working.
    # The `_about_to_save` var is used to detect context changes when
    # saving into another asset. If we keep it False it will be ignored
    # as a context change. As such, before we change tasks we will only
    # consider it a context change if the current filepath is within the
    # currently known AYON_WORKDIR. This way we avoid false positives of
    # thinking it's saving to another context and instead sometimes just
    # have false negatives where we fail to show the "Update on task change"
    # prompt.
    comp = get_current_comp()
    filepath = comp.GetAttrs()["COMPS_FileName"]
    workdir = os.environ.get("AYON_WORKDIR")
    if Path(workdir) in Path(filepath).parents:
        global _about_to_save
        _about_to_save = True


def ls():
    """List containers from active Fusion scene

    This is the host-equivalent of api.ls(), but instead of listing
    assets on disk, it lists assets already loaded in Fusion; once loaded
    they are called 'containers'

    Yields:
        dict: container

    """
    comp = get_current_comp()
    tools = comp.GetToolList(False).values()

    for tool in tools:
        container = parse_container(tool)
        if container:
            yield container


def imprint_container(tool,
                      name,
                      namespace,
                      context,
                      loader=None):
    """Imprint a Loader with metadata

    Containerisation enables tracking of version, author and origin
    for loaded assets.

    Arguments:
        tool (object): The node in Fusion to imprint as container, usually a
            Loader.
        name (str): Name of resulting assembly
        namespace (str): Namespace under which to host container
        context (dict): Asset information
        loader (str, optional): Name of loader used to produce this container.

    Returns:
        None

    """
    data = [
        ("schema", "openpype:container-2.0"),
        ("id", AVALON_CONTAINER_ID),
        ("name", str(name)),
        ("namespace", str(namespace)),
        ("loader", str(loader)),
        ("representation", context["representation"]["id"]),
    ]

    for key, value in data:
        tool.SetData("avalon.{}".format(key), value)


def parse_container(tool):
    """Returns imprinted container data of a tool

    This reads the imprinted data from `imprint_container`.

    """
    data = tool.GetData('avalon')
    if not isinstance(data, dict):
        return

    # If not all required data, return the empty container
    required = ['schema', 'id', 'name',
                'namespace', 'loader', 'representation']
    if not all(key in data for key in required):
        return

    container = {key: data[key] for key in required}

    # Store the tool's name
    container["objectName"] = tool.Name

    # Store reference to the tool object
    container["_tool"] = tool

    return container


class FusionEventThread(QtCore.QThread):
    """QThread which will periodically ping Fusion app for any events.

    The fusion.UIManager must be set up to be notified of events before
    they'll be reported by this thread, for example:
        fusion.UIManager.AddNotify("Comp_Save", None)

    """

    on_event = QtCore.Signal(dict)

    def run(self):

        app = getattr(sys.modules["__main__"], "app", None)
        if app is None:
            # No Fusion app found
            return

        # As optimization store the GetEvent method directly because every
        # getattr of UIManager.GetEvent tries to resolve the Remote Function
        # through the PyRemoteObject
        get_event = app.UIManager.GetEvent
        delay = int(os.environ.get("AYON_FUSION_CALLBACK_INTERVAL", 1000))
        while True:
            if self.isInterruptionRequested():
                return

            # Process all events that have been queued up until now
            while True:
                event = get_event(False)
                if not event:
                    break
                self.on_event.emit(event)

            # Wait some time before processing events again
            # to not keep blocking the UI
            self.msleep(delay)


class FusionEventHandler(QtCore.QObject):
    """Emits AYON events based on Fusion events captured in a QThread.

    This will emit the following AYON events based on Fusion actions:
        save: Comp_Save, Comp_SaveAs
        open: Comp_Opened
        new: Comp_New

    To use this you can attach it to your Qt UI so it runs in the background.
    E.g.
        >>> handler = FusionEventHandler(parent=window)
        >>> handler.start()

    """
    ACTION_IDS = [
        "Comp_Save",
        "Comp_SaveAs",
        "Comp_New",
        "Comp_Opened"
    ]

    def __init__(self, parent=None):
        super(FusionEventHandler, self).__init__(parent=parent)

        # Set up Fusion event callbacks
        fusion = getattr(sys.modules["__main__"], "fusion", None)
        ui = fusion.UIManager

        # Add notifications for the ones we want to listen to
        notifiers = []
        for action_id in self.ACTION_IDS:
            notifier = ui.AddNotify(action_id, None)
            notifiers.append(notifier)

        # TODO: Not entirely sure whether these must be kept to avoid
        #   garbage collection
        self._notifiers = notifiers

        self._event_thread = FusionEventThread(parent=self)
        self._event_thread.on_event.connect(self._on_event)

    def start(self):
        self._event_thread.start()

    def stop(self):
        self._event_thread.stop()

    def _on_event(self, event):
        """Handle Fusion events to emit AYON events"""
        if not event:
            return

        what = event["what"]

        # Comp Save
        if what in {"Comp_Save", "Comp_SaveAs"}:
            if not event["Rets"].get("success"):
                # If the Save action is cancelled it will still emit an
                # event but with "success": False so we ignore those cases
                return
            # Comp was saved
            emit_event("save", data=event)
            return

        # Comp New
        elif what in {"Comp_New"}:
            emit_event("new", data=event)

        # Comp Opened
        elif what in {"Comp_Opened"}:
            emit_event("open", data=event)
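
A short round-trip sketch of the container metadata above; the Loader tool, names and representation id are hypothetical placeholders, and `comp` is assumed to come from a live Fusion session:

from ayon_fusion.api import (
    get_current_comp,
    imprint_container,
    parse_container,
)

comp = get_current_comp()
loader_tool = comp.AddTool("Loader", -32768, -32768)  # hypothetical tool

imprint_container(
    loader_tool,
    name="renderMain",            # example name
    namespace="sh010",            # example namespace
    context={"representation": {"id": "00000000-0000"}},  # placeholder id
    loader="FusionLoadSequence",  # assumed loader name
)

container = parse_container(loader_tool)
# container["objectName"] == loader_tool.Name, container["_tool"] is the
# tool itself, and ls() would now yield this container.
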
@@ -1,278 +0,0 @@
from copy import deepcopy
import os

from ayon_fusion.api import (
    get_current_comp,
    comp_lock_and_undo_chunk,
)

from ayon_core.lib import (
    BoolDef,
    EnumDef,
)
from ayon_core.pipeline import (
    Creator,
    CreatedInstance,
    AVALON_INSTANCE_ID,
    AYON_INSTANCE_ID,
)
from ayon_core.pipeline.workfile import get_workdir
from ayon_api import (
    get_project,
    get_folder_by_path,
    get_task_by_name
)


class GenericCreateSaver(Creator):
    default_variants = ["Main", "Mask"]
    description = "Fusion Saver to generate image sequence"
    icon = "fa5.eye"

    instance_attributes = [
        "reviewable"
    ]

    settings_category = "fusion"

    image_format = "exr"

    # TODO: This should be renamed together with Nuke so it is aligned
    temp_rendering_path_template = (
        "{workdir}/renders/fusion/{product[name]}/"
        "{product[name]}.{frame}.{ext}"
    )

    def create(self, product_name, instance_data, pre_create_data):
        self.pass_pre_attributes_to_instance(instance_data, pre_create_data)

        instance = CreatedInstance(
            product_type=self.product_type,
            product_name=product_name,
            data=instance_data,
            creator=self,
        )
        data = instance.data_to_store()
        comp = get_current_comp()
        with comp_lock_and_undo_chunk(comp):
            args = (-32768, -32768)  # Magical position numbers
            saver = comp.AddTool("Saver", *args)

            self._update_tool_with_data(saver, data=data)

            # Register the CreatedInstance
            self._imprint(saver, data)

            # Insert the transient data
            instance.transient_data["tool"] = saver

            self._add_instance_to_context(instance)

        return instance

    def collect_instances(self):
        comp = get_current_comp()
        tools = comp.GetToolList(False, "Saver").values()
        for tool in tools:
            data = self.get_managed_tool_data(tool)
            if not data:
                continue

            # Add instance
            created_instance = CreatedInstance.from_existing(data, self)

            # Collect transient data
            created_instance.transient_data["tool"] = tool

            self._add_instance_to_context(created_instance)

    def update_instances(self, update_list):
        for created_inst, _changes in update_list:
            new_data = created_inst.data_to_store()
            tool = created_inst.transient_data["tool"]
            self._update_tool_with_data(tool, new_data)
            self._imprint(tool, new_data)

    def remove_instances(self, instances):
        for instance in instances:
            # Remove the tool from the scene
            tool = instance.transient_data["tool"]
            if tool:
                tool.Delete()

            # Remove the collected CreatedInstance to remove from UI directly
            self._remove_instance_from_context(instance)

    def _imprint(self, tool, data):
        # Store all data as "openpype.{key}" = value entries

        # Instance id is the tool's name so we don't need to imprint as data
        data.pop("instance_id", None)

        active = data.pop("active", None)
        if active is not None:
            # Use active value to set the passthrough state
            tool.SetAttrs({"TOOLB_PassThrough": not active})

        for key, value in data.items():
            tool.SetData(f"openpype.{key}", value)

    def _update_tool_with_data(self, tool, data):
        """Update tool node name and output path based on product data"""
        if "productName" not in data:
            return

        original_product_name = tool.GetData("openpype.productName")
        original_format = tool.GetData(
            "openpype.creator_attributes.image_format"
        )

        product_name = data["productName"]
        if (
            original_product_name != product_name
            or tool.GetData("openpype.task") != data["task"]
            or tool.GetData("openpype.folderPath") != data["folderPath"]
            or original_format != data["creator_attributes"]["image_format"]
        ):
            self._configure_saver_tool(data, tool, product_name)

    def _configure_saver_tool(self, data, tool, product_name):
        formatting_data = deepcopy(data)

        # get frame padding from anatomy templates
        frame_padding = self.project_anatomy.templates_obj.frame_padding

        # get output format
        ext = data["creator_attributes"]["image_format"]

        # Product change detected
        product_type = formatting_data["productType"]
        f_product_name = formatting_data["productName"]

        folder_path = formatting_data["folderPath"]
        folder_name = folder_path.rsplit("/", 1)[-1]

        # If the folder path and task do not match the current context then
        # the workdir is not just the `AYON_WORKDIR`. Hence, we need to
        # actually compute the resulting workdir
        if (
            data["folderPath"] == self.create_context.get_current_folder_path()
            and data["task"] == self.create_context.get_current_task_name()
        ):
            workdir = os.path.normpath(os.getenv("AYON_WORKDIR"))
        else:
            # TODO: Optimize this logic
            project_name = self.create_context.get_current_project_name()
            project_entity = get_project(project_name)
            folder_entity = get_folder_by_path(project_name,
                                               data["folderPath"])
            task_entity = get_task_by_name(project_name,
                                           folder_id=folder_entity["id"],
                                           task_name=data["task"])
            workdir = get_workdir(
                project_entity=project_entity,
                folder_entity=folder_entity,
                task_entity=task_entity,
                host_name=self.create_context.host_name,
            )

        formatting_data.update({
            "workdir": workdir,
            "frame": "0" * frame_padding,
            "ext": ext,
            "product": {
                "name": f_product_name,
                "type": product_type,
            },
            # TODO add more variants for 'folder' and 'task'
            "folder": {
                "name": folder_name,
            },
            "task": {
                "name": data["task"],
            },
            # Backwards compatibility
            "asset": folder_name,
            "subset": f_product_name,
            "family": product_type,
        })

        # build file path to render
        # TODO make sure the keys are available in 'formatting_data'
        temp_rendering_path_template = (
            self.temp_rendering_path_template
            .replace("{task}", "{task[name]}")
        )

        filepath = temp_rendering_path_template.format(**formatting_data)

        comp = get_current_comp()
        tool["Clip"] = comp.ReverseMapPath(os.path.normpath(filepath))

        # Rename tool
        if tool.Name != product_name:
            print(f"Renaming {tool.Name} -> {product_name}")
            tool.SetAttrs({"TOOLS_Name": product_name})

    def get_managed_tool_data(self, tool):
        """Return data of the tool if it matches creator identifier"""
        data = tool.GetData("openpype")
        if not isinstance(data, dict):
            return

        if (
            data.get("creator_identifier") != self.identifier
            or data.get("id") not in {
                AYON_INSTANCE_ID, AVALON_INSTANCE_ID
            }
        ):
            return

        # Get active state from the actual tool state
        attrs = tool.GetAttrs()
        passthrough = attrs["TOOLB_PassThrough"]
        data["active"] = not passthrough

        # Override publisher's UUID generation because tool names are
        # already unique in Fusion in a comp
        data["instance_id"] = tool.Name

        return data

    def get_instance_attr_defs(self):
        """Settings for publish page"""
        return self.get_pre_create_attr_defs()

    def pass_pre_attributes_to_instance(self, instance_data, pre_create_data):
        creator_attrs = instance_data["creator_attributes"] = {}
        for pass_key in pre_create_data.keys():
            creator_attrs[pass_key] = pre_create_data[pass_key]

    def _get_render_target_enum(self):
        rendering_targets = {
            "local": "Local machine rendering",
            "frames": "Use existing frames",
        }
        if "farm_rendering" in self.instance_attributes:
            rendering_targets["farm"] = "Farm rendering"

        return EnumDef(
            "render_target", items=rendering_targets, label="Render target"
        )

    def _get_reviewable_bool(self):
        return BoolDef(
            "review",
            default=("reviewable" in self.instance_attributes),
            label="Review",
        )

    def _get_image_format_enum(self):
        image_format_options = ["exr", "tga", "tif", "png", "jpg"]
        return EnumDef(
            "image_format",
            items=image_format_options,
            default=self.image_format,
            label="Output Image Format",
        )
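
To make the path template concrete, a worked expansion with hypothetical values (only the keys the template actually uses are shown):

template = (
    "{workdir}/renders/fusion/{product[name]}/"
    "{product[name]}.{frame}.{ext}"
)
formatting_data = {
    "workdir": "/proj/work/sh010/comp",  # example workdir
    "product": {"name": "renderMain"},   # example product name
    "frame": "0000",                     # i.e. "0" * frame_padding
    "ext": "exr",
}
print(template.format(**formatting_data))
# /proj/work/sh010/comp/renders/fusion/renderMain/renderMain.0000.exr
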
@@ -1,63 +0,0 @@
import os
import sys

from qtpy import QtCore


class PulseThread(QtCore.QThread):
    no_response = QtCore.Signal()

    def __init__(self, parent=None):
        super(PulseThread, self).__init__(parent=parent)

    def run(self):
        app = getattr(sys.modules["__main__"], "app", None)

        # Interval in milliseconds (cast to int since env vars are strings)
        interval = int(os.environ.get("AYON_FUSION_PULSE_INTERVAL", 1000))

        while True:
            if self.isInterruptionRequested():
                return

            # We don't need to call Test because PyRemoteObject of the app
            # will actually fail to even resolve the Test function if it has
            # gone down. So we can actually already just check by confirming
            # the method is still getting resolved. (Optimization)
            if app.Test is None:
                self.no_response.emit()

            self.msleep(interval)


class FusionPulse(QtCore.QObject):
    """A Timer that checks whether the host app is still alive.

    This checks whether the Fusion process is still active at a certain
    interval. This is useful due to how Fusion runs its scripts. Each script
    runs in its own environment and process (a `fusionscript` process each).
    If Fusion would go down and we have a UI process running at the same time
    then it can happen that the `fusionscript.exe` will remain running in the
    background in limbo due to e.g. a Qt interface's QApplication that keeps
    running infinitely.

    Warning:
        When the host is not detected this will automatically exit
        the current process.

    """

    def __init__(self, parent=None):
        super(FusionPulse, self).__init__(parent=parent)
        self._thread = PulseThread(parent=self)
        self._thread.no_response.connect(self.on_no_response)

    def on_no_response(self):
        print("Pulse detected no response from Fusion..")
        sys.exit(1)

    def start(self):
        self._thread.start()

    def stop(self):
        self._thread.requestInterruption()
@@ -1,6 +0,0 @@
### AYON deploy MenuScripts

Note that this `MenuScripts` folder is not an official Fusion folder.
AYON only uses this folder in `{fusion}/deploy/` to trigger the AYON menu actions.

The scripts are used by the actions defined in `.fu` files in `{fusion}/deploy/Config`.
@@ -1,29 +0,0 @@
# This is just a quick hack for users running Py3 locally but having no
# Qt library installed
import os
import subprocess
import importlib


try:
    from qtpy import API_NAME

    print(f"Qt binding: {API_NAME}")
    mod = importlib.import_module(API_NAME)
    print(f"Qt path: {mod.__file__}")
    print("Qt library found, nothing to do..")

except ImportError:
    print("Assuming no Qt library is installed..")
    print('Installing PySide2 for Python 3.6: '
          f'{os.environ["FUSION16_PYTHON36_HOME"]}')

    # Get full path to python executable
    exe = "python.exe" if os.name == 'nt' else "python"
    python = os.path.join(os.environ["FUSION16_PYTHON36_HOME"], exe)
    assert os.path.exists(python), f"Python doesn't exist: {python}"

    # Do python -m pip install PySide2
    args = [python, "-m", "pip", "install", "PySide2"]
    print(f"Args: {args}")
    subprocess.Popen(args)
@@ -1,47 +0,0 @@
import os
import sys

if sys.version_info < (3, 7):
    # hack to handle discrepancy between distributed libraries and Python 3.6
    # mostly because of a wrong version of urllib3
    # TODO remove when not necessary
    from ayon_fusion import FUSION_ADDON_ROOT

    vendor_path = os.path.join(FUSION_ADDON_ROOT, "vendor")
    if vendor_path not in sys.path:
        sys.path.insert(0, vendor_path)

    print(f"Added vendorized libraries from {vendor_path}")

from ayon_core.lib import Logger
from ayon_core.pipeline import (
    install_host,
    registered_host,
)


def main(env):
    # This script's working directory starts in the Fusion application folder.
    # However the contents of that folder can conflict with Qt library dlls
    # so we make sure to move out of it to avoid DLL Load Failed errors.
    os.chdir("..")
    from ayon_fusion.api import FusionHost
    from ayon_fusion.api import menu

    # Install the Fusion host integration
    install_host(FusionHost())

    log = Logger.get_logger(__name__)
    log.info(f"Registered host: {registered_host()}")

    menu.launch_ayon_menu()

    # Initiate a QTimer to check if Fusion is still alive every X interval
    # If Fusion is not found - kill itself
    # todo(roy): Implement timer that ensures UI doesn't remain when e.g.
    #    Fusion closes down


if __name__ == "__main__":
    result = main(os.environ)
    sys.exit(not bool(result))
@@ -1,60 +0,0 @@
{
    Action
    {
        ID = "AYON_Menu",
        Category = "AYON",
        Name = "AYON Menu",

        Targets =
        {
            Composition =
            {
                Execute = _Lua [=[
                    local scriptPath = app:MapPath("AYON:../MenuScripts/launch_menu.py")
                    if bmd.fileexists(scriptPath) == false then
                        print("[AYON Error] Can't run file: " .. scriptPath)
                    else
                        target:RunScript(scriptPath)
                    end
                ]=],
            },
        },
    },
    Action
    {
        ID = "AYON_Install_PySide2",
        Category = "AYON",
        Name = "Install PySide2",

        Targets =
        {
            Composition =
            {
                Execute = _Lua [=[
                    local scriptPath = app:MapPath("AYON:../MenuScripts/install_pyside2.py")
                    if bmd.fileexists(scriptPath) == false then
                        print("[AYON Error] Can't run file: " .. scriptPath)
                    else
                        target:RunScript(scriptPath)
                    end
                ]=],
            },
        },
    },
    Menus
    {
        Target = "ChildFrame",

        Before "Help"
        {
            Sub "AYON"
            {
                "AYON_Menu{}",
                "_",
                Sub "Admin" {
                    "AYON_Install_PySide2{}"
                }
            }
        },
    },
}
@@ -1,19 +0,0 @@
{
    Locked = true,
    Global = {
        Paths = {
            Map = {
                ["AYON:"] = "$(AYON_FUSION_ROOT)/deploy/ayon",
                ["Config:"] = "UserPaths:Config;AYON:Config",
                ["Scripts:"] = "UserPaths:Scripts;Reactor:System/Scripts",
            },
        },
        Script = {
            PythonVersion = 3,
            Python3Forced = true
        },
        UserInterface = {
            Language = "en_US"
        },
    },
}
@ -1,36 +0,0 @@
import os
from ayon_applications import PreLaunchHook
from ayon_fusion import FUSION_ADDON_ROOT


class FusionLaunchMenuHook(PreLaunchHook):
    """Launch AYON menu on start of Fusion"""
    app_groups = ["fusion"]
    order = 9

    def execute(self):
        # Prelaunch hook is optional
        settings = self.data["project_settings"][self.host_name]
        if not settings["hooks"]["FusionLaunchMenuHook"]["enabled"]:
            return

        variant = self.application.name
        if variant.isnumeric():
            version = int(variant)
            if version < 18:
                print("Skipping launch of AYON menu on Fusion start "
                      "because Fusion versions below 18.0 do not support "
                      "the /execute argument on launch. "
                      f"Version detected: {version}")
                return
        else:
            print(f"Application variant is not numeric: {variant}. "
                  "Validation for Fusion version 18+ for the /execute "
                  "prelaunch argument skipped.")

        path = os.path.join(FUSION_ADDON_ROOT,
                            "deploy",
                            "MenuScripts",
                            "launch_menu.py").replace("\\", "/")
        script = f"fusion:RunScript('{path}')"
        self.launch_context.launch_args.extend(["/execute", script])

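For reference, with FUSION_ADDON_ROOT assumed to be "C:/addons/ayon_fusion"
(a hypothetical path), the hook effectively extends the launch arguments to:

launch_args = [
    "/execute",
    "fusion:RunScript("
    "'C:/addons/ayon_fusion/deploy/MenuScripts/launch_menu.py')",
]
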
@ -1,169 +0,0 @@
import os
import shutil
import platform
from pathlib import Path
from ayon_fusion import (
    FUSION_ADDON_ROOT,
    FUSION_VERSIONS_DICT,
    get_fusion_version,
)
from ayon_applications import (
    PreLaunchHook,
    LaunchTypes,
    ApplicationLaunchFailed,
)


class FusionCopyPrefsPrelaunch(PreLaunchHook):
    """
    Prepares the local Fusion profile directory and copies the existing
    Fusion profile into it.
    This also sets the FUSION{VERSION}_MasterPrefs variable, which is used
    to apply the Master.prefs file overriding some Fusion profile settings:
    - enable the AYON menu
    - force Python 3 over Python 2
    - force English interface
    Master.prefs is defined in ayon_fusion/deploy/ayon/fusion_shared.prefs
    """

    app_groups = {"fusion"}
    order = 2
    launch_types = {LaunchTypes.local}

    def get_fusion_profile_name(self, profile_version) -> str:
        # Returns 'Default', unless FUSION16_PROFILE is set
        return os.getenv(f"FUSION{profile_version}_PROFILE", "Default")

    def get_fusion_profile_dir(self, profile_version) -> Path:
        # Get FUSION_PROFILE_DIR variable
        fusion_profile = self.get_fusion_profile_name(profile_version)
        fusion_var_prefs_dir = os.getenv(
            f"FUSION{profile_version}_PROFILE_DIR"
        )

        # Check if FUSION_PROFILE_DIR exists
        if fusion_var_prefs_dir and Path(fusion_var_prefs_dir).is_dir():
            fu_prefs_dir = Path(fusion_var_prefs_dir, fusion_profile)
            self.log.info(f"{fusion_var_prefs_dir} is set to {fu_prefs_dir}")
            return fu_prefs_dir

    def get_profile_source(self, profile_version) -> Path:
        """Get Fusion preferences profile location.
        See Per-User_Preferences_and_Paths on VFXpedia for reference.
        """
        fusion_profile = self.get_fusion_profile_name(profile_version)
        profile_source = self.get_fusion_profile_dir(profile_version)
        if profile_source:
            return profile_source
        # otherwise get the default location of the profile folder
        fu_prefs_dir = f"Blackmagic Design/Fusion/Profiles/{fusion_profile}"
        if platform.system() == "Windows":
            profile_source = Path(os.getenv("AppData"), fu_prefs_dir)
        elif platform.system() == "Darwin":
            profile_source = Path(
                "~/Library/Application Support/", fu_prefs_dir
            ).expanduser()
        elif platform.system() == "Linux":
            profile_source = Path("~/.fusion", fu_prefs_dir).expanduser()
        self.log.info(
            f"Locating source Fusion prefs directory: {profile_source}"
        )
        return profile_source

    def get_copy_fusion_prefs_settings(self):
        # Get copy preferences options from the global application settings

        copy_fusion_settings = self.data["project_settings"]["fusion"].get(
            "copy_fusion_settings", {}
        )
        if not copy_fusion_settings:
            self.log.error("Copy prefs settings not found")
        copy_status = copy_fusion_settings.get("copy_status", False)
        force_sync = copy_fusion_settings.get("force_sync", False)
        copy_path = copy_fusion_settings.get("copy_path") or None
        if copy_path:
            copy_path = Path(copy_path).expanduser()
        return copy_status, copy_path, force_sync

    def copy_fusion_profile(
        self, copy_from: Path, copy_to: Path, force_sync: bool
    ) -> None:
        """On the first Fusion launch, copy the contents of the Fusion
        profile directory to the predefined working location. If the AYON
        profile folder already exists, skip copying unless force re-sync
        is enabled.
        If the prefs were not copied on the first launch,
        a clean Fusion profile will be created in fu_profile_dir.
        """
        if copy_to.exists() and not force_sync:
            self.log.info(
                "Destination Fusion preferences folder already exists: "
                f"{copy_to} "
            )
            return
        self.log.info("Starting to copy Fusion preferences")
        self.log.debug(f"force_sync option is set to {force_sync}")
        try:
            copy_to.mkdir(exist_ok=True, parents=True)
        except PermissionError:
            self.log.warning(
                f"Creating the folder is not permitted at {copy_to}")
            return
        if not copy_from.exists():
            self.log.warning(f"Fusion preferences not found in {copy_from}")
            return
        for file in copy_from.iterdir():
            if file.suffix in (
                ".prefs",
                ".def",
                ".blocklist",
                ".fu",
                ".toolbars",
            ):
                # Convert Path to str for compatibility with older
                # shutil.copy implementations that do not accept Path
                shutil.copy(str(file), str(copy_to))
        self.log.info(
            f"Successfully copied preferences: {copy_from} to {copy_to}"
        )

    def execute(self):
        (
            copy_status,
            fu_profile_dir,
            force_sync,
        ) = self.get_copy_fusion_prefs_settings()

        # Get launched application context and return the correct app version
        app_name = self.launch_context.env.get("AYON_APP_NAME")
        app_version = get_fusion_version(app_name)
        if app_version is None:
            version_names = ", ".join(str(x) for x in FUSION_VERSIONS_DICT)
            raise ApplicationLaunchFailed(
                "Unable to detect valid Fusion version number from app "
                f"name: {app_name}.\nMake sure to include at least a digit "
                "to indicate the Fusion version like '18'.\n"
                f"Detectable Fusion versions are: {version_names}"
            )

        _, profile_version = FUSION_VERSIONS_DICT[app_version]
        fu_profile = self.get_fusion_profile_name(profile_version)

        # Copy the Fusion profile if the copy_status toggle is enabled
        if copy_status and fu_profile_dir is not None:
            profile_source = self.get_profile_source(profile_version)
            dest_folder = Path(fu_profile_dir, fu_profile)
            self.copy_fusion_profile(profile_source, dest_folder, force_sync)

        # Add temporary profile directory variables to customize Fusion
        # to define where it can read custom scripts and tools from
        fu_profile_dir_variable = f"FUSION{profile_version}_PROFILE_DIR"
        self.log.info(f"Setting {fu_profile_dir_variable}: {fu_profile_dir}")
        self.launch_context.env[fu_profile_dir_variable] = str(fu_profile_dir)

        # Add the custom Fusion Master Prefs variable, which points Fusion
        # at the AYON-managed preferences file
        master_prefs_variable = f"FUSION{profile_version}_MasterPrefs"

        master_prefs = Path(
            FUSION_ADDON_ROOT, "deploy", "ayon", "fusion_shared.prefs")

        self.log.info(f"Setting {master_prefs_variable}: {master_prefs}")
        self.launch_context.env[master_prefs_variable] = str(master_prefs)

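For reference, assuming Fusion 18 (which maps to profile number 16 in
FUSION_VERSIONS_DICT) and hypothetical paths, the hook leaves the launch
environment with values like:

launch_env = {
    # Hypothetical prefs dir taken from the 'copy_fusion_settings' copy_path
    "FUSION16_PROFILE_DIR": "C:/ayon/fusion_profiles",
    # Always points at the addon's bundled master prefs
    "FUSION16_MasterPrefs": (
        "C:/addons/ayon_fusion/deploy/ayon/fusion_shared.prefs"
    ),
}
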
@ -1,71 +0,0 @@
import os
from ayon_applications import (
    PreLaunchHook,
    LaunchTypes,
    ApplicationLaunchFailed,
)
from ayon_fusion import (
    FUSION_ADDON_ROOT,
    FUSION_VERSIONS_DICT,
    get_fusion_version,
)


class FusionPrelaunch(PreLaunchHook):
    """
    Prepares the AYON Fusion environment.
    Requires the correct Python home variable to be defined in the
    environment settings for Fusion, pointing at a valid Python 3 build
    for Fusion.
    Python 3 versions that are supported by Fusion:
        Fusion 9, 16, 17: Python 3.6
        Fusion 18:        Python 3.6 - 3.10
    """

    app_groups = {"fusion"}
    order = 1
    launch_types = {LaunchTypes.local}

    def execute(self):
        # Make sure Python 3 is installed at the provided path:
        # Py 3.3-3.10 for Fusion 18+ or Py 3.6 for Fusion 16-17
        app_data = self.launch_context.env.get("AYON_APP_NAME")
        app_version = get_fusion_version(app_data)
        if not app_version:
            raise ApplicationLaunchFailed(
                "Fusion version information not found in System settings.\n"
                "The key field in 'applications/fusion/variants' should "
                "contain a number corresponding to the major Fusion version."
            )
        py3_var, _ = FUSION_VERSIONS_DICT[app_version]
        fusion_python3_home = self.launch_context.env.get(py3_var, "")

        for path in fusion_python3_home.split(os.pathsep):
            # Allow defining multiple paths, separated by os.pathsep,
            # to allow "fallback" to another path.
            # But make sure to set only a single path as the final variable.
            py3_dir = os.path.normpath(path)
            if os.path.isdir(py3_dir):
                break
        else:
            raise ApplicationLaunchFailed(
                "Python 3 is not installed at the provided path.\n"
                "Make sure the environment in fusion settings has "
                "'FUSION_PYTHON3_HOME' set correctly and make sure "
                "Python 3 is installed in the given path."
                f"\n\nPYTHON PATH: {fusion_python3_home}"
            )

        self.log.info(f"Setting {py3_var}: '{py3_dir}'...")
        self.launch_context.env[py3_var] = py3_dir

        # Fusion 18+ requires FUSION_PYTHON3_HOME to also be on PATH
        if app_version >= 18:
            self.launch_context.env["PATH"] += os.pathsep + py3_dir

        # For the hook installing PySide2
        self.data["fusion_python3_home"] = py3_dir

        self.log.info(f"Setting AYON_FUSION_ROOT: {FUSION_ADDON_ROOT}")
        self.launch_context.env["AYON_FUSION_ROOT"] = FUSION_ADDON_ROOT

@ -1,185 +0,0 @@
import os
import subprocess
import platform
import uuid

from ayon_applications import PreLaunchHook, LaunchTypes


class InstallPySideToFusion(PreLaunchHook):
    """Automatically installs Qt binding to fusion's python packages.

    Checks if fusion has PySide2 installed and tries to install it if not.

    The pipeline implementation requires a Qt binding to be installed in
    fusion's python packages.
    """

    app_groups = {"fusion"}
    order = 2
    launch_types = {LaunchTypes.local}

    def execute(self):
        # Prelaunch hook is not crucial
        try:
            settings = self.data["project_settings"][self.host_name]
            if not settings["hooks"]["InstallPySideToFusion"]["enabled"]:
                return
            self.inner_execute()
        except Exception:
            self.log.warning(
                "Processing of {} crashed.".format(self.__class__.__name__),
                exc_info=True
            )

    def inner_execute(self):
        self.log.debug("Check for PySide2 installation.")

        fusion_python3_home = self.data.get("fusion_python3_home")
        if not fusion_python3_home:
            self.log.warning("'fusion_python3_home' was not provided. "
                             "Installation of PySide2 is not possible.")
            return

        if platform.system().lower() == "windows":
            exe_filenames = ["python.exe"]
        else:
            exe_filenames = ["python3", "python"]

        for exe_filename in exe_filenames:
            python_executable = os.path.join(
                fusion_python3_home, exe_filename)
            if os.path.exists(python_executable):
                break

        if not os.path.exists(python_executable):
            self.log.warning(
                "Couldn't find python executable for fusion. {}".format(
                    python_executable
                )
            )
            return

        # Check if PySide2 is installed and skip if yes
        if self._is_pyside_installed(python_executable):
            self.log.debug("Fusion already has PySide2 installed.")
            return

        self.log.debug("Installing PySide2.")
        # Install PySide2 in fusion's python
        if self._windows_require_permissions(
                os.path.dirname(python_executable)):
            result = self._install_pyside_windows(python_executable)
        else:
            result = self._install_pyside(python_executable)

        if result:
            self.log.info("Successfully installed PySide2 module to fusion.")
        else:
            self.log.warning("Failed to install PySide2 module to fusion.")

    def _install_pyside_windows(self, python_executable):
        """Install PySide2 python module to fusion's python.

        Installation requires administration rights, which is why the
        "pywin32" module is used; it can execute commands and request
        administration rights.
        """
        try:
            import win32con
            import win32process
            import win32event
            import pywintypes
            from win32comext.shell.shell import ShellExecuteEx
            from win32comext.shell import shellcon
        except Exception:
            self.log.warning("Couldn't import \"pywin32\" modules")
            return False

        try:
            # Parameters
            # - use "-m pip" as module pip to install PySide2 and argument
            #   "--ignore-installed" is to force install module to fusion's
            #   site-packages and make sure it is binary compatible
            parameters = "-m pip install --ignore-installed PySide2"

            # Execute command and ask for administrator's rights
            process_info = ShellExecuteEx(
                nShow=win32con.SW_SHOWNORMAL,
                fMask=shellcon.SEE_MASK_NOCLOSEPROCESS,
                lpVerb="runas",
                lpFile=python_executable,
                lpParameters=parameters,
                lpDirectory=os.path.dirname(python_executable)
            )
            process_handle = process_info["hProcess"]
            win32event.WaitForSingleObject(process_handle,
                                           win32event.INFINITE)
            returncode = win32process.GetExitCodeProcess(process_handle)
            return returncode == 0
        except pywintypes.error:
            return False

    def _install_pyside(self, python_executable):
        """Install PySide2 python module to fusion's python."""
        try:
            # Parameters
            # - use "-m pip" as module pip to install PySide2 and argument
            #   "--ignore-installed" is to force install module to fusion's
            #   site-packages and make sure it is binary compatible
            env = dict(os.environ)
            # Remove PYTHONPATH so fusion's python resolves its own
            # site-packages; pop() avoids a KeyError when it is unset
            env.pop("PYTHONPATH", None)
            args = [
                python_executable,
                "-m",
                "pip",
                "install",
                "--ignore-installed",
                "PySide2",
            ]
            process = subprocess.Popen(
                args, stdout=subprocess.PIPE, universal_newlines=True,
                env=env
            )
            process.communicate()
            return process.returncode == 0
        except PermissionError:
            self.log.warning(
                "Permission denied with command: "
                "\"{}\".".format(" ".join(args))
            )
        except OSError as error:
            self.log.warning(f"OS error has occurred: \"{error}\".")
        except subprocess.SubprocessError:
            pass

    def _is_pyside_installed(self, python_executable):
        """Check if a Qt binding is importable from fusion's python."""
        args = [python_executable, "-c", "from qtpy import QtWidgets"]
        process = subprocess.Popen(args,
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE)
        _, stderr = process.communicate()
        stderr = stderr.decode()
        if stderr:
            return False
        return True

    def _windows_require_permissions(self, dirpath):
        if platform.system().lower() != "windows":
            return False

        try:
            # Attempt to create a temporary file in the folder
            temp_file_path = os.path.join(dirpath, uuid.uuid4().hex)
            with open(temp_file_path, "w"):
                pass
            os.remove(temp_file_path)  # Clean up temporary file
            return False

        except PermissionError:
            return True

        except BaseException as exc:
            print(("Failed to determine if the directory requires elevated "
                   "permissions. Unexpected error: {}").format(exc))
            return False

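The non-elevated branch above boils down to a plain pip invocation in
Fusion's own interpreter. A standalone sketch, with a hypothetical
interpreter path:

import subprocess

python_executable = "C:/Fusion/python310/python.exe"  # hypothetical path
completed = subprocess.run(
    [python_executable, "-m", "pip", "install",
     "--ignore-installed", "PySide2"],
    capture_output=True,
    text=True,
)
print("Install OK" if completed.returncode == 0 else completed.stderr)
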
@ -1,63 +0,0 @@
from ayon_core.lib import NumberDef

from ayon_fusion.api.plugin import GenericCreateSaver


class CreateImageSaver(GenericCreateSaver):
    """Fusion Saver to generate a single image.

    Created to explicitly separate single frame ('image') and
    multi-frame ('render') outputs.

    This might be a temporary creator until 'alias' functionality is
    implemented to limit creation of additional product types with
    similar, but not identical, workflows.
    """
    identifier = "io.openpype.creators.fusion.imagesaver"
    label = "Image (saver)"
    name = "image"
    product_type = "image"
    description = "Fusion Saver to generate image"

    default_frame = 0

    def get_detail_description(self):
        return """Fusion Saver to generate a single image.

        This creator is intended for publishing a single frame of the
        'image' product type.

        The artist should provide a frame number (integer) to specify which
        frame should be published. It must be inside the global timeline
        frame range.

        Supports local and Deadline rendering.

        Supports selection from a predefined set of output file extensions:
        - exr
        - tga
        - png
        - tif
        - jpg

        Created to explicitly separate single frame ('image') and
        multi-frame ('render') outputs.
        """

    def get_pre_create_attr_defs(self):
        """Settings for create page"""
        attr_defs = [
            self._get_render_target_enum(),
            self._get_reviewable_bool(),
            self._get_frame_int(),
            self._get_image_format_enum(),
        ]
        return attr_defs

    def _get_frame_int(self):
        return NumberDef(
            "frame",
            default=self.default_frame,
            label="Frame",
            tooltip="Set the frame to be rendered; it must be inside the "
                    "global timeline range"
        )

@ -1,149 +0,0 @@
from ayon_core.lib import (
    UILabelDef,
    NumberDef,
    EnumDef
)

from ayon_fusion.api.plugin import GenericCreateSaver
from ayon_fusion.api.lib import get_current_comp


class CreateSaver(GenericCreateSaver):
    """Fusion Saver to generate an image sequence of 'render' product type.

    The original Saver creator targeted the 'render' product type. It keeps
    its original, not very descriptive, name because of existing values in
    Settings.
    """
    identifier = "io.openpype.creators.fusion.saver"
    label = "Render (saver)"
    name = "render"
    product_type = "render"
    description = "Fusion Saver to generate image sequence"

    default_frame_range_option = "current_folder"

    def get_detail_description(self):
        return """Fusion Saver to generate an image sequence.

        This creator is intended for publishing image sequences of the
        'render' product type. (It can also publish a single frame
        'render'.)

        Select the source of the render range:
        - "Current Folder context" - values set on the folder on AYON server
        - "From render in/out" - from the node itself
        - "From composition timeline" - from the timeline

        Supports local and farm rendering.

        Supports selection from a predefined set of output file extensions:
        - exr
        - tga
        - png
        - tif
        - jpg
        """

    def get_pre_create_attr_defs(self):
        """Settings for create page"""
        attr_defs = [
            self._get_render_target_enum(),
            self._get_reviewable_bool(),
            self._get_frame_range_enum(),
            self._get_image_format_enum(),
            *self._get_custom_frame_range_attribute_defs()
        ]
        return attr_defs

    def _get_frame_range_enum(self):
        frame_range_options = {
            "current_folder": "Current Folder context",
            "render_range": "From render in/out",
            "comp_range": "From composition timeline",
            "custom_range": "Custom frame range",
        }

        return EnumDef(
            "frame_range_source",
            items=frame_range_options,
            label="Frame range source",
            default=self.default_frame_range_option
        )

    @staticmethod
    def _get_custom_frame_range_attribute_defs() -> list:
        # Define custom frame range defaults based on current comp
        # timeline settings (if a comp is currently open)
        comp = get_current_comp()
        if comp is not None:
            attrs = comp.GetAttrs()
            frame_defaults = {
                "frameStart": int(attrs["COMPN_GlobalStart"]),
                "frameEnd": int(attrs["COMPN_GlobalEnd"]),
                "handleStart": int(
                    attrs["COMPN_RenderStart"] - attrs["COMPN_GlobalStart"]
                ),
                "handleEnd": int(
                    attrs["COMPN_GlobalEnd"] - attrs["COMPN_RenderEnd"]
                ),
            }
        else:
            frame_defaults = {
                "frameStart": 1001,
                "frameEnd": 1100,
                "handleStart": 0,
                "handleEnd": 0
            }

        return [
            UILabelDef(
                label="<br><b>Custom Frame Range</b><br>"
                      "<i>only used with 'Custom frame range' source</i>"
            ),
            NumberDef(
                "custom_frameStart",
                label="Frame Start",
                default=frame_defaults["frameStart"],
                minimum=0,
                decimals=0,
                tooltip=(
                    "Set the start frame for the export.\n"
                    "Only used if frame range source is 'Custom frame range'."
                )
            ),
            NumberDef(
                "custom_frameEnd",
                label="Frame End",
                default=frame_defaults["frameEnd"],
                minimum=0,
                decimals=0,
                tooltip=(
                    "Set the end frame for the export.\n"
                    "Only used if frame range source is 'Custom frame range'."
                )
            ),
            NumberDef(
                "custom_handleStart",
                label="Handle Start",
                default=frame_defaults["handleStart"],
                minimum=0,
                decimals=0,
                tooltip=(
                    "Set the start handles for the export; these will be "
                    "added before the start frame.\n"
                    "Only used if frame range source is 'Custom frame range'."
                )
            ),
            NumberDef(
                "custom_handleEnd",
                label="Handle End",
                default=frame_defaults["handleEnd"],
                minimum=0,
                decimals=0,
                tooltip=(
                    "Set the end handles for the export; these will be added "
                    "after the end frame.\n"
                    "Only used if frame range source is 'Custom frame range'."
                )
            )
        ]

@ -1,132 +0,0 @@
import ayon_api

from ayon_fusion.api import (
    get_current_comp
)
from ayon_core.pipeline import (
    AutoCreator,
    CreatedInstance,
)


class FusionWorkfileCreator(AutoCreator):
    identifier = "workfile"
    product_type = "workfile"
    label = "Workfile"
    icon = "fa5.file"

    default_variant = "Main"

    create_allow_context_change = False

    data_key = "openpype_workfile"

    def collect_instances(self):
        comp = get_current_comp()
        data = comp.GetData(self.data_key)
        if not data:
            return

        product_name = data.get("productName")
        if product_name is None:
            product_name = data["subset"]
        instance = CreatedInstance(
            product_type=self.product_type,
            product_name=product_name,
            data=data,
            creator=self
        )
        instance.transient_data["comp"] = comp

        self._add_instance_to_context(instance)

    def update_instances(self, update_list):
        for created_inst, _changes in update_list:
            comp = created_inst.transient_data["comp"]
            if not hasattr(comp, "SetData"):
                # Comp is not alive anymore, likely closed by the user
                self.log.error(
                    "Workfile comp not found for existing instance."
                    " Comp might have been closed in the meantime.")
                continue

            # Imprint data into the comp
            data = created_inst.data_to_store()
            comp.SetData(self.data_key, data)

    def create(self, options=None):
        comp = get_current_comp()
        if not comp:
            self.log.error("Unable to find current comp")
            return

        existing_instance = None
        for instance in self.create_context.instances:
            if instance.product_type == self.product_type:
                existing_instance = instance
                break

        project_name = self.create_context.get_current_project_name()
        folder_path = self.create_context.get_current_folder_path()
        task_name = self.create_context.get_current_task_name()
        host_name = self.create_context.host_name

        existing_folder_path = None
        if existing_instance is not None:
            existing_folder_path = existing_instance["folderPath"]

        if existing_instance is None:
            folder_entity = ayon_api.get_folder_by_path(
                project_name, folder_path
            )
            task_entity = ayon_api.get_task_by_name(
                project_name, folder_entity["id"], task_name
            )
            product_name = self.get_product_name(
                project_name,
                folder_entity,
                task_entity,
                self.default_variant,
                host_name,
            )
            data = {
                "folderPath": folder_path,
                "task": task_name,
                "variant": self.default_variant,
            }
            data.update(self.get_dynamic_data(
                project_name,
                folder_entity,
                task_entity,
                self.default_variant,
                host_name,
                None
            ))

            new_instance = CreatedInstance(
                self.product_type, product_name, data, self
            )
            new_instance.transient_data["comp"] = comp
            self._add_instance_to_context(new_instance)

        elif (
            existing_folder_path != folder_path
            or existing_instance["task"] != task_name
        ):
            folder_entity = ayon_api.get_folder_by_path(
                project_name, folder_path
            )
            task_entity = ayon_api.get_task_by_name(
                project_name, folder_entity["id"], task_name
            )
            product_name = self.get_product_name(
                project_name,
                folder_entity,
                task_entity,
                self.default_variant,
                host_name,
            )
            existing_instance["folderPath"] = folder_path
            existing_instance["task"] = task_name
            existing_instance["productName"] = product_name

@ -1,27 +0,0 @@
from ayon_core.pipeline import InventoryAction


class FusionSelectContainers(InventoryAction):

    label = "Select Containers"
    icon = "mouse-pointer"
    color = "#d8d8d8"

    def process(self, containers):
        from ayon_fusion.api import (
            get_current_comp,
            comp_lock_and_undo_chunk
        )

        tools = [i["_tool"] for i in containers]

        comp = get_current_comp()
        flow = comp.CurrentFrame.FlowView

        with comp_lock_and_undo_chunk(comp, self.label):
            # Clear selection
            flow.Select()

            # Select the container tools
            for tool in tools:
                flow.Select(tool)

@ -1,72 +0,0 @@
from qtpy import QtGui, QtWidgets

from ayon_core.pipeline import InventoryAction
from ayon_core import style
from ayon_fusion.api import (
    get_current_comp,
    comp_lock_and_undo_chunk
)


class FusionSetToolColor(InventoryAction):
    """Update the color of the selected tools"""

    label = "Set Tool Color"
    icon = "plus"
    color = "#d8d8d8"
    # Fallback to white; use fromRgbF so the floats are read as
    # 0.0-1.0 channel values instead of integer RGB
    _fallback_color = QtGui.QColor.fromRgbF(1.0, 1.0, 1.0)

    def process(self, containers):
        """Color all selected tools with a user-picked color"""

        result = []
        comp = get_current_comp()

        # Get tool color
        first = containers[0]
        tool = first["_tool"]
        color = tool.TileColor

        if color is not None:
            qcolor = QtGui.QColor().fromRgbF(
                color["R"], color["G"], color["B"])
        else:
            qcolor = self._fallback_color

        # Launch color picker
        picked_color = self.get_color_picker(qcolor)
        if not picked_color:
            return

        with comp_lock_and_undo_chunk(comp):
            for container in containers:
                # Convert color to RGB 0-1 floats
                rgb_f = picked_color.getRgbF()
                rgb_f_table = {"R": rgb_f[0], "G": rgb_f[1], "B": rgb_f[2]}

                # Update tool
                tool = container["_tool"]
                tool.TileColor = rgb_f_table

                result.append(container)

        return result

    def get_color_picker(self, color):
        """Launch color picker and return the chosen color

        Args:
            color(QtGui.QColor): Start color to display

        Returns:
            QtGui.QColor

        """
        color_dialog = QtWidgets.QColorDialog(color)
        color_dialog.setStyleSheet(style.load_stylesheet())

        accepted = color_dialog.exec_()
        if not accepted:
            return

        return color_dialog.selectedColor()

@ -1,81 +0,0 @@
"""A module containing generic loader actions that will display in the Loader.

"""

from ayon_core.pipeline import load


class FusionSetFrameRangeLoader(load.LoaderPlugin):
    """Set frame range excluding pre- and post-handles"""

    product_types = {
        "animation",
        "camera",
        "imagesequence",
        "render",
        "yeticache",
        "pointcache",
    }
    representations = {"*"}
    extensions = {"*"}

    label = "Set frame range"
    order = 11
    icon = "clock-o"
    color = "white"

    def load(self, context, name, namespace, data):

        from ayon_fusion.api import lib

        version_attributes = context["version"]["attrib"]

        start = version_attributes.get("frameStart", None)
        end = version_attributes.get("frameEnd", None)

        if start is None or end is None:
            print("Skipping setting frame range because start or "
                  "end frame data is missing...")
            return

        lib.update_frame_range(start, end)


class FusionSetFrameRangeWithHandlesLoader(load.LoaderPlugin):
    """Set frame range including pre- and post-handles"""

    product_types = {
        "animation",
        "camera",
        "imagesequence",
        "render",
        "yeticache",
        "pointcache",
    }
    representations = {"*"}

    label = "Set frame range (with handles)"
    order = 12
    icon = "clock-o"
    color = "white"

    def load(self, context, name, namespace, data):

        from ayon_fusion.api import lib

        version_attributes = context["version"]["attrib"]
        start = version_attributes.get("frameStart", None)
        end = version_attributes.get("frameEnd", None)

        if start is None or end is None:
            print("Skipping setting frame range because start or "
                  "end frame data is missing...")
            return

        # Include handles
        start -= version_attributes.get("handleStart", 0)
        end += version_attributes.get("handleEnd", 0)

        lib.update_frame_range(start, end)

@ -1,72 +0,0 @@
from ayon_core.pipeline import (
    load,
    get_representation_path,
)
from ayon_fusion.api import (
    imprint_container,
    get_current_comp,
    comp_lock_and_undo_chunk
)


class FusionLoadAlembicMesh(load.LoaderPlugin):
    """Load Alembic mesh into Fusion"""

    product_types = {"pointcache", "model"}
    representations = {"*"}
    extensions = {"abc"}

    label = "Load alembic mesh"
    order = -10
    icon = "code-fork"
    color = "orange"

    tool_type = "SurfaceAlembicMesh"

    def load(self, context, name, namespace, data):
        # Fallback to folder name when namespace is None
        if namespace is None:
            namespace = context["folder"]["name"]

        # Create the Loader with the filename path set
        comp = get_current_comp()
        with comp_lock_and_undo_chunk(comp, "Create tool"):

            path = self.filepath_from_context(context)

            args = (-32768, -32768)
            tool = comp.AddTool(self.tool_type, *args)
            tool["Filename"] = path

            imprint_container(tool,
                              name=name,
                              namespace=namespace,
                              context=context,
                              loader=self.__class__.__name__)

    def switch(self, container, context):
        self.update(container, context)

    def update(self, container, context):
        """Update the Alembic path"""

        tool = container["_tool"]
        assert tool.ID == self.tool_type, f"Must be {self.tool_type}"
        comp = tool.Comp()

        repre_entity = context["representation"]
        path = get_representation_path(repre_entity)

        with comp_lock_and_undo_chunk(comp, "Update tool"):
            tool["Filename"] = path

            # Update the imprinted representation
            tool.SetData("avalon.representation", repre_entity["id"])

    def remove(self, container):
        tool = container["_tool"]
        assert tool.ID == self.tool_type, f"Must be {self.tool_type}"
        comp = tool.Comp()

        with comp_lock_and_undo_chunk(comp, "Remove tool"):
            tool.Delete()

@ -1,87 +0,0 @@
from ayon_core.pipeline import (
    load,
    get_representation_path,
)
from ayon_fusion.api import (
    imprint_container,
    get_current_comp,
    comp_lock_and_undo_chunk,
)


class FusionLoadFBXMesh(load.LoaderPlugin):
    """Load FBX mesh into Fusion"""

    product_types = {"*"}
    representations = {"*"}
    extensions = {
        "3ds",
        "amc",
        "aoa",
        "asf",
        "bvh",
        "c3d",
        "dae",
        "dxf",
        "fbx",
        "htr",
        "mcd",
        "obj",
        "trc",
    }

    label = "Load FBX mesh"
    order = -10
    icon = "code-fork"
    color = "orange"

    tool_type = "SurfaceFBXMesh"

    def load(self, context, name, namespace, data):
        # Fallback to folder name when namespace is None
        if namespace is None:
            namespace = context["folder"]["name"]

        # Create the Loader with the filename path set
        comp = get_current_comp()
        with comp_lock_and_undo_chunk(comp, "Create tool"):
            path = self.filepath_from_context(context)

            args = (-32768, -32768)
            tool = comp.AddTool(self.tool_type, *args)
            tool["ImportFile"] = path

            imprint_container(
                tool,
                name=name,
                namespace=namespace,
                context=context,
                loader=self.__class__.__name__,
            )

    def switch(self, container, context):
        self.update(container, context)

    def update(self, container, context):
        """Update path"""

        tool = container["_tool"]
        assert tool.ID == self.tool_type, f"Must be {self.tool_type}"
        comp = tool.Comp()

        repre_entity = context["representation"]
        path = get_representation_path(repre_entity)

        with comp_lock_and_undo_chunk(comp, "Update tool"):
            tool["ImportFile"] = path

            # Update the imprinted representation
            tool.SetData("avalon.representation", repre_entity["id"])

    def remove(self, container):
        tool = container["_tool"]
        assert tool.ID == self.tool_type, f"Must be {self.tool_type}"
        comp = tool.Comp()

        with comp_lock_and_undo_chunk(comp, "Remove tool"):
            tool.Delete()

@ -1,291 +0,0 @@
import contextlib

import ayon_core.pipeline.load as load
from ayon_fusion.api import (
    imprint_container,
    get_current_comp,
    comp_lock_and_undo_chunk,
)
from ayon_core.lib.transcoding import IMAGE_EXTENSIONS, VIDEO_EXTENSIONS


@contextlib.contextmanager
def preserve_inputs(tool, inputs):
    """Preserve the tool's inputs after the context"""

    comp = tool.Comp()

    values = {}
    for name in inputs:
        tool_input = getattr(tool, name)
        value = tool_input[comp.TIME_UNDEFINED]
        values[name] = value

    try:
        yield
    finally:
        for name, value in values.items():
            tool_input = getattr(tool, name)
            tool_input[comp.TIME_UNDEFINED] = value


@contextlib.contextmanager
def preserve_trim(loader, log=None):
    """Preserve the relative trim of the Loader tool.

    This tries to preserve the loader's trim (trim in and trim out) after
    the context by reapplying the "amount" it trims on the clip's length at
    start and end.

    """

    # Get original trim as amount of "trimming" from length
    time = loader.Comp().TIME_UNDEFINED
    length = loader.GetAttrs()["TOOLIT_Clip_Length"][1] - 1
    trim_from_start = loader["ClipTimeStart"][time]
    trim_from_end = length - loader["ClipTimeEnd"][time]

    try:
        yield
    finally:
        length = loader.GetAttrs()["TOOLIT_Clip_Length"][1] - 1
        if trim_from_start > length:
            trim_from_start = length
            if log:
                log.warning(
                    "Reducing trim in to %d "
                    "(because of fewer frames)" % trim_from_start
                )

        remainder = length - trim_from_start
        if trim_from_end > remainder:
            trim_from_end = remainder
            if log:
                log.warning(
                    "Reducing trim out to %d "
                    "(because of fewer frames)" % trim_from_end
                )

        loader["ClipTimeStart"][time] = trim_from_start
        loader["ClipTimeEnd"][time] = length - trim_from_end


def loader_shift(loader, frame, relative=True):
    """Shift the loader's global in time, preserving duration.

    This moves the loader by `frame` frames while preserving the global
    duration. When `relative` is False it will instead shift the global in
    to the given start frame.

    Args:
        loader (tool): The fusion loader tool.
        frame (int): The amount of frames to move.
        relative (bool): When True the shift is relative, else the shift
            will change the global in to `frame`.

    Returns:
        int: The resulting relative frame change (how much it moved)

    """
    comp = loader.Comp()
    time = comp.TIME_UNDEFINED

    old_in = loader["GlobalIn"][time]
    old_out = loader["GlobalOut"][time]

    if relative:
        shift = frame
    else:
        shift = frame - old_in

    if not shift:
        return 0

    # Shifting global in will try to automatically compensate for the change
    # in the "ClipTimeStart" and "HoldFirstFrame" inputs, so we preserve those
    # input values to "just shift" the clip
    with preserve_inputs(
        loader,
        inputs=[
            "ClipTimeStart",
            "ClipTimeEnd",
            "HoldFirstFrame",
            "HoldLastFrame",
        ],
    ):
        # GlobalIn cannot be set past GlobalOut or vice versa
        # so we must apply them in the order of the shift.
        if shift > 0:
            loader["GlobalOut"][time] = old_out + shift
            loader["GlobalIn"][time] = old_in + shift
        else:
            loader["GlobalIn"][time] = old_in + shift
            loader["GlobalOut"][time] = old_out + shift

    return int(shift)


class FusionLoadSequence(load.LoaderPlugin):
    """Load image sequence into Fusion"""

    product_types = {
        "imagesequence",
        "review",
        "render",
        "plate",
        "image",
        "online",
    }
    representations = {"*"}
    extensions = set(
        ext.lstrip(".") for ext in IMAGE_EXTENSIONS.union(VIDEO_EXTENSIONS)
    )

    label = "Load sequence"
    order = -10
    icon = "code-fork"
    color = "orange"

    def load(self, context, name, namespace, data):
        # Fallback to folder name when namespace is None
        if namespace is None:
            namespace = context["folder"]["name"]

        # Use the first file for now
        path = self.filepath_from_context(context)

        # Create the Loader with the filename path set
        comp = get_current_comp()
        with comp_lock_and_undo_chunk(comp, "Create Loader"):
            args = (-32768, -32768)
            tool = comp.AddTool("Loader", *args)
            tool["Clip"] = comp.ReverseMapPath(path)

            # Set global in point to start frame (if in version.data)
            start = self._get_start(context["version"], tool)
            loader_shift(tool, start, relative=False)

            imprint_container(
                tool,
                name=name,
                namespace=namespace,
                context=context,
                loader=self.__class__.__name__,
            )

    def switch(self, container, context):
        self.update(container, context)

    def update(self, container, context):
        """Update the Loader's path

        Fusion automatically tries to reset some variables when changing
        the loader's path to a new file. These automatic changes are to its
        inputs:
            - ClipTimeStart: Fusion resets to 0 if duration changes
              - We keep the trim in as close as possible to the previous
                value. When there are fewer frames than the amount of trim
                we reduce it accordingly.

            - ClipTimeEnd: Fusion resets to 0 if duration changes
              - We keep the trim out as close as possible to the previous
                value within the new amount of frames after trim in
                (ClipTimeStart) has been set.

            - GlobalIn: Fusion resets to comp's global in if duration changes
              - We change it to the "frameStart"

            - GlobalEnd: Fusion resets to globalIn + length if duration
              changes
              - We do the same as Fusion - allow Fusion to take control.

            - HoldFirstFrame: Fusion resets this to 0
              - We preserve the value.

            - HoldLastFrame: Fusion resets this to 0
              - We preserve the value.

            - Reverse: Fusion resets to disabled if "Loop" is not enabled.
              - We preserve the value.

            - Depth: Fusion resets to "Format"
              - We preserve the value.

            - KeyCode: Fusion resets to ""
              - We preserve the value.

            - TimeCodeOffset: Fusion resets to 0
              - We preserve the value.

        """
        tool = container["_tool"]
        assert tool.ID == "Loader", "Must be Loader"
        comp = tool.Comp()

        repre_entity = context["representation"]
        path = self.filepath_from_context(context)

        # Get start frame from version data
        start = self._get_start(context["version"], tool)

        with comp_lock_and_undo_chunk(comp, "Update Loader"):
            # Update the loader's path whilst preserving some values
            with preserve_trim(tool, log=self.log):
                with preserve_inputs(
                    tool,
                    inputs=(
                        "HoldFirstFrame",
                        "HoldLastFrame",
                        "Reverse",
                        "Depth",
                        "KeyCode",
                        "TimeCodeOffset",
                    ),
                ):
                    tool["Clip"] = comp.ReverseMapPath(path)

            # Set the global in to the start frame of the sequence
            global_in_changed = loader_shift(tool, start, relative=False)
            if global_in_changed:
                # Log this change to the user
                self.log.debug(
                    "Changed '%s' global in: %d" % (tool.Name, start)
                )

            # Update the imprinted representation
            tool.SetData("avalon.representation", repre_entity["id"])

    def remove(self, container):
        tool = container["_tool"]
        assert tool.ID == "Loader", "Must be Loader"
        comp = tool.Comp()

        with comp_lock_and_undo_chunk(comp, "Remove Loader"):
            tool.Delete()

    def _get_start(self, version_entity, tool):
        """Return real start frame of published files (incl. handles)"""
        attributes = version_entity["attrib"]

        # Get start frame directly with handle if it's in data
        start = attributes.get("frameStartHandle")
        if start is not None:
            return start

        # Get frame start without handles
        start = attributes.get("frameStart")
        if start is None:
            self.log.warning(
                "Missing start frame for version; "
                "assuming it starts at frame 0 for: "
                "{}".format(tool.Name)
            )
            return 0

        # Use `handleStart` if the data is available
        handle_start = attributes.get("handleStart")
        if handle_start:
            start -= handle_start

        return start

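A minimal usage sketch for loader_shift, reusing names from the module
above; the tool name "Loader1" is a hypothetical example:

comp = get_current_comp()
loader = comp.FindTool("Loader1")  # hypothetical existing Loader tool
# Move the clip so its global in lands on frame 1001, trims preserved
moved = loader_shift(loader, 1001, relative=False)
print(f"Loader moved by {moved} frames")
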
@ -1,87 +0,0 @@
from ayon_core.pipeline import (
    load,
    get_representation_path,
)
from ayon_fusion.api import (
    imprint_container,
    get_current_comp,
    comp_lock_and_undo_chunk
)
from ayon_fusion.api.lib import get_fusion_module


class FusionLoadUSD(load.LoaderPlugin):
    """Load USD into Fusion

    Support for USD was added in Fusion 18.5
    """

    product_types = {"*"}
    representations = {"*"}
    extensions = {"usd", "usda", "usdz"}

    label = "Load USD"
    order = -10
    icon = "code-fork"
    color = "orange"

    tool_type = "uLoader"

    @classmethod
    def apply_settings(cls, project_settings):
        super(FusionLoadUSD, cls).apply_settings(project_settings)
        if cls.enabled:
            # Enable only in Fusion 18.5+
            fusion = get_fusion_module()
            version = fusion.GetVersion()
            major = version[1]
            minor = version[2]
            is_usd_supported = (major, minor) >= (18, 5)
            cls.enabled = is_usd_supported

    def load(self, context, name, namespace, data):
        # Fallback to folder name when namespace is None
        if namespace is None:
            namespace = context["folder"]["name"]

        # Create the Loader with the filename path set
        comp = get_current_comp()
        with comp_lock_and_undo_chunk(comp, "Create tool"):

            path = self.filepath_from_context(context)

            args = (-32768, -32768)
            tool = comp.AddTool(self.tool_type, *args)
            tool["Filename"] = path

            imprint_container(tool,
                              name=name,
                              namespace=namespace,
                              context=context,
                              loader=self.__class__.__name__)

    def switch(self, container, context):
        self.update(container, context)

    def update(self, container, context):
        tool = container["_tool"]
        assert tool.ID == self.tool_type, f"Must be {self.tool_type}"
        comp = tool.Comp()

        repre_entity = context["representation"]
        path = get_representation_path(repre_entity)

        with comp_lock_and_undo_chunk(comp, "Update tool"):
            tool["Filename"] = path

            # Update the imprinted representation
            tool.SetData("avalon.representation", repre_entity["id"])

    def remove(self, container):
        tool = container["_tool"]
        assert tool.ID == self.tool_type, f"Must be {self.tool_type}"
        comp = tool.Comp()

        with comp_lock_and_undo_chunk(comp, "Remove tool"):
            tool.Delete()

@ -1,33 +0,0 @@
"""Import workfiles into your current comp.
As all imported nodes are free-floating and will probably be changed, there
is no update or reload function added for this plugin.
"""

from ayon_core.pipeline import load

from ayon_fusion.api import (
    get_current_comp,
    get_bmd_library,
)


class FusionLoadWorkfile(load.LoaderPlugin):
    """Load the content of a workfile into Fusion"""

    product_types = {"workfile"}
    representations = {"*"}
    extensions = {"comp"}

    label = "Load Workfile"
    order = -10
    icon = "code-fork"
    color = "orange"

    def load(self, context, name, namespace, data):
        # Get needed elements
        bmd = get_bmd_library()
        comp = get_current_comp()
        path = self.filepath_from_context(context)

        # Paste the content of the file into the current comp
        comp.Paste(bmd.readfile(path))

@ -1,22 +0,0 @@
import pyblish.api

from ayon_fusion.api import get_current_comp


class CollectCurrentCompFusion(pyblish.api.ContextPlugin):
    """Collect current comp"""

    order = pyblish.api.CollectorOrder - 0.4
    label = "Collect Current Comp"
    hosts = ["fusion"]

    def process(self, context):
        """Collect the current comp and its file path"""

        current_comp = get_current_comp()
        assert current_comp, "Must have active Fusion composition"
        context.data["currentComp"] = current_comp

        # Store path to current file
        filepath = current_comp.GetAttrs().get("COMPS_FileName", "")
        context.data["currentFile"] = filepath

@ -1,44 +0,0 @@
import pyblish.api


def get_comp_render_range(comp):
    """Return comp's start-end render range and global start-end range."""
    comp_attrs = comp.GetAttrs()
    start = comp_attrs["COMPN_RenderStart"]
    end = comp_attrs["COMPN_RenderEnd"]
    global_start = comp_attrs["COMPN_GlobalStart"]
    global_end = comp_attrs["COMPN_GlobalEnd"]

    # Whenever render ranges are undefined fall back
    # to the comp's global start and end
    if start == -1000000000:
        start = global_start
    if end == -1000000000:
        end = global_end

    return start, end, global_start, global_end


class CollectFusionCompFrameRanges(pyblish.api.ContextPlugin):
    """Collect the comp's render and global frame ranges"""

    # We run this after CollectorOrder - 0.1 otherwise it gets
    # overridden by global plug-in `CollectContextEntities`
    order = pyblish.api.CollectorOrder - 0.05
    label = "Collect Comp Frame Ranges"
    hosts = ["fusion"]

    def process(self, context):
        """Store the comp frame ranges in the publish context"""

        comp = context.data["currentComp"]

        # Store comp render ranges
        start, end, global_start, global_end = get_comp_render_range(comp)

        context.data.update({
            "renderFrameStart": int(start),
            "renderFrameEnd": int(end),
            "compFrameStart": int(global_start),
            "compFrameEnd": int(global_end)
        })

@ -1,116 +0,0 @@
import pyblish.api

from ayon_core.pipeline import registered_host


def collect_input_containers(tools):
    """Collect containers that contain any of the tools in `tools`.

    This will return any loaded Avalon container that contains at least one
    of the given tools. As such, the Avalon container is an input for it.
    Or in short: those tools are member nodes of that container.

    Returns:
        list: Input avalon containers

    """

    # Lookup by tool names
    lookup = frozenset([tool.Name for tool in tools])

    containers = []
    host = registered_host()
    for container in host.ls():

        name = container["_tool"].Name

        # We currently assume no "groups" as containers but just single tools
        # like a single "Loader" operator. As such we just check whether the
        # Loader is part of the processing queue.
        if name in lookup:
            containers.append(container)

    return containers


def iter_upstream(tool):
    """Yields all upstream inputs for the current tool.

    Yields:
        tool: The input tools.

    """

    def get_connected_input_tools(tool):
        """Helper function that returns connected input tools for a tool."""
        inputs = []

        # Filter only to actual types that will have sensible upstream
        # connections. So we ignore just "Number" inputs as they can be
        # many to iterate, slowing things down quite a bit - and in practice
        # they don't have upstream connections.
        VALID_INPUT_TYPES = ["Image", "Particles", "Mask", "DataType3D"]
        for type_ in VALID_INPUT_TYPES:
            for input_ in tool.GetInputList(type_).values():
                output = input_.GetConnectedOutput()
                if output:
                    input_tool = output.GetTool()
                    inputs.append(input_tool)

        return inputs

    # Initialize process queue with the node's inputs itself
    queue = get_connected_input_tools(tool)

    # We keep track of which node names we have processed so far, to ensure
    # we don't process the same hierarchy again. We are not pushing the tool
    # itself into the set as that doesn't correctly recognize the same tool.
    # Since tool names are unique in a comp in Fusion we rely on that.
    collected = set(tool.Name for tool in queue)

    # Traverse upstream references for all nodes and yield them as we
    # process the queue.
    while queue:
        upstream_tool = queue.pop()
        yield upstream_tool

        # Find upstream tools that are not collected yet.
        upstream_inputs = get_connected_input_tools(upstream_tool)
        upstream_inputs = [t for t in upstream_inputs if
                           t.Name not in collected]

        queue.extend(upstream_inputs)
        collected.update(tool.Name for tool in upstream_inputs)


class CollectUpstreamInputs(pyblish.api.InstancePlugin):
    """Collect source input containers used for this publish.

    This will include `inputs` data of which loaded publishes were used in
    the generation of this publish. This leaves an upstream trace to what
    was used as input.

    """

    label = "Collect Inputs"
    order = pyblish.api.CollectorOrder + 0.2
    hosts = ["fusion"]
    families = ["render", "image"]

    def process(self, instance):

        # Get all upstream and include itself
        if not any(instance[:]):
            self.log.debug("No tool found in instance, skipping..")
            return

        tool = instance[0]
        nodes = list(iter_upstream(tool))
        nodes.append(tool)

        # Collect containers for the given set of nodes
        containers = collect_input_containers(nodes)

        inputs = [c["representation"] for c in containers]
        instance.data["inputRepresentations"] = inputs
        self.log.debug("Collected inputs: %s" % inputs)

@ -1,109 +0,0 @@
import pyblish.api


class CollectInstanceData(pyblish.api.InstancePlugin):
    """Collect Fusion saver instances

    This additionally stores the Comp start and end render range in the
    current context's data as "frameStart" and "frameEnd".

    """

    order = pyblish.api.CollectorOrder
    label = "Collect Instances Data"
    hosts = ["fusion"]

    def process(self, instance):
        """Collect all image sequence tools"""

        context = instance.context

        # Include creator attributes directly as instance data
        creator_attributes = instance.data["creator_attributes"]
        instance.data.update(creator_attributes)

        frame_range_source = creator_attributes.get("frame_range_source")
        instance.data["frame_range_source"] = frame_range_source

        # Default to the folder frame ranges for all instances; this covers
        # render product type instances with the `current_folder` frame
        # range source.
        start = context.data["frameStart"]
        end = context.data["frameEnd"]
        handle_start = context.data["handleStart"]
        handle_end = context.data["handleEnd"]
        start_with_handle = start - handle_start
        end_with_handle = end + handle_end

        # Conditions for render product type instances
        if frame_range_source == "render_range":
            # Use the comp render frame ranges
            start = context.data["renderFrameStart"]
            end = context.data["renderFrameEnd"]
            handle_start = 0
            handle_end = 0
            start_with_handle = start
            end_with_handle = end

        if frame_range_source == "comp_range":
            comp_start = context.data["compFrameStart"]
            comp_end = context.data["compFrameEnd"]
            render_start = context.data["renderFrameStart"]
            render_end = context.data["renderFrameEnd"]
            # Use the comp frame ranges; frames outside of the render range
            # are treated as handles.
            start = render_start
            end = render_end
            handle_start = render_start - comp_start
            handle_end = comp_end - render_end
            start_with_handle = comp_start
            end_with_handle = comp_end

        if frame_range_source == "custom_range":
            start = int(instance.data["custom_frameStart"])
            end = int(instance.data["custom_frameEnd"])
            handle_start = int(instance.data["custom_handleStart"])
            handle_end = int(instance.data["custom_handleEnd"])
            start_with_handle = start - handle_start
            end_with_handle = end + handle_end

        frame = instance.data["creator_attributes"].get("frame")
        # Explicitly publishing only a single frame
        if frame is not None:
            frame = int(frame)

            start = frame
            end = frame
            handle_start = 0
            handle_end = 0
            start_with_handle = frame
            end_with_handle = frame

        # Include start and end render frame in label
        product_name = instance.data["productName"]
        label = (
            "{product_name} ({start}-{end}) [{handle_start}-{handle_end}]"
        ).format(
            product_name=product_name,
            start=int(start),
            end=int(end),
            handle_start=int(handle_start),
            handle_end=int(handle_end)
        )

        instance.data.update({
            "label": label,

            # todo: Allow custom frame range per instance
            "frameStart": start,
            "frameEnd": end,
            "frameStartHandle": start_with_handle,
            "frameEndHandle": end_with_handle,
            "handleStart": handle_start,
            "handleEnd": handle_end,
            "fps": context.data["fps"],
        })

        # Add review family if the instance is marked as 'review'
        # This could be done through a 'review' Creator attribute.
        if instance.data.get("review", False):
            self.log.debug("Adding review family..")
            instance.data["families"].append("review")

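A worked sketch of the "comp_range" handle math above: the frames of the comp's global range that fall outside the render range become handles. The values are illustrative only:

# Illustrative values only
comp_start, comp_end = 1001, 1100      # comp global range
render_start, render_end = 1009, 1092  # comp render range

handle_start = render_start - comp_start  # 8
handle_end = comp_end - render_end        # 8

# The range with handles reconstructs the comp range exactly
assert (comp_start, comp_end) == (
    render_start - handle_start, render_end + handle_end
)
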
@ -1,208 +0,0 @@
import os
import attr
import pyblish.api

from ayon_core.pipeline import publish
from ayon_core.pipeline.publish import RenderInstance
from ayon_fusion.api.lib import get_frame_path


@attr.s
class FusionRenderInstance(RenderInstance):
    # Extend the generic render instance with Fusion specific data
    fps = attr.ib(default=None)
    projectEntity = attr.ib(default=None)
    stagingDir = attr.ib(default=None)
    app_version = attr.ib(default=None)
    tool = attr.ib(default=None)
    workfileComp = attr.ib(default=None)
    publish_attributes = attr.ib(default={})
    frameStartHandle = attr.ib(default=None)
    frameEndHandle = attr.ib(default=None)


class CollectFusionRender(
    publish.AbstractCollectRender,
    publish.ColormanagedPyblishPluginMixin
):

    order = pyblish.api.CollectorOrder + 0.09
    label = "Collect Fusion Render"
    hosts = ["fusion"]

    def get_instances(self, context):

        comp = context.data.get("currentComp")
        comp_frame_format_prefs = comp.GetPrefs("Comp.FrameFormat")
        aspect_x = comp_frame_format_prefs["AspectX"]
        aspect_y = comp_frame_format_prefs["AspectY"]

        current_file = context.data["currentFile"]
        version = context.data["version"]

        project_entity = context.data["projectEntity"]

        instances = []
        for inst in context:
            if not inst.data.get("active", True):
                continue

            product_type = inst.data["productType"]
            if product_type not in ["render", "image"]:
                continue

            task_name = inst.data["task"]
            tool = inst.data["transientData"]["tool"]

            instance_families = inst.data.get("families", [])
            product_name = inst.data["productName"]
            instance = FusionRenderInstance(
                tool=tool,
                workfileComp=comp,
                productType=product_type,
                family=product_type,
                families=instance_families,
                version=version,
                time="",
                source=current_file,
                label=inst.data["label"],
                productName=product_name,
                folderPath=inst.data["folderPath"],
                task=task_name,
                attachTo=False,
                setMembers='',
                publish=True,
                name=product_name,
                resolutionWidth=comp_frame_format_prefs.get("Width"),
                resolutionHeight=comp_frame_format_prefs.get("Height"),
                pixelAspect=aspect_x / aspect_y,
                tileRendering=False,
                tilesX=0,
                tilesY=0,
                review="review" in instance_families,
                frameStart=inst.data["frameStart"],
                frameEnd=inst.data["frameEnd"],
                handleStart=inst.data["handleStart"],
                handleEnd=inst.data["handleEnd"],
                frameStartHandle=inst.data["frameStartHandle"],
                frameEndHandle=inst.data["frameEndHandle"],
                frameStep=1,
                fps=comp_frame_format_prefs.get("Rate"),
                app_version=comp.GetApp().Version,
                publish_attributes=inst.data.get("publish_attributes", {}),

                # The source instance this render instance replaces
                source_instance=inst
            )

            render_target = inst.data["creator_attributes"]["render_target"]

            # Add render target family
            render_target_family = f"render.{render_target}"
            if render_target_family not in instance.families:
                instance.families.append(render_target_family)

            # Add render target specific data
            if render_target in {"local", "frames"}:
                instance.projectEntity = project_entity

            if render_target == "farm":
                fam = "render.farm"
                if fam not in instance.families:
                    instance.families.append(fam)
                instance.farm = True  # to skip integrate
                if "review" in instance.families:
                    # to skip ExtractReview locally
                    instance.families.remove("review")
                instance.deadline = inst.data.get("deadline")

            instances.append(instance)

        return instances

    def post_collecting_action(self):
        for instance in self._context:
            if "render.frames" in instance.data.get("families", []):
                # Add representation data to the instance
                self._update_for_frames(instance)

    def get_expected_files(self, render_instance):
        """Return the files the render is expected to create.

        These are the files expected to be rendered by Deadline. They are
        not published directly; they are the source for the later
        'submit_publish_job'.

        Args:
            render_instance (RenderInstance): Instance to pull the render
                path and frame range from.

        Returns:
            list: Absolute paths to the expected rendered files.
        """
        start = render_instance.frameStart - render_instance.handleStart
        end = render_instance.frameEnd + render_instance.handleEnd

        comp = render_instance.workfileComp
        path = comp.MapPath(
            render_instance.tool["Clip"][
                render_instance.workfileComp.TIME_UNDEFINED
            ]
        )
        output_dir = os.path.dirname(path)
        render_instance.outputDir = output_dir

        basename = os.path.basename(path)

        head, padding, ext = get_frame_path(basename)

        expected_files = []
        for frame in range(start, end + 1):
            expected_files.append(
                os.path.join(
                    output_dir,
                    f"{head}{str(frame).zfill(padding)}{ext}"
                )
            )

        return expected_files

    def _update_for_frames(self, instance):
        """Update the instance for the render.frames family.

        Adds representation data to the instance and sets the
        colorspaceData on the representation based on file rules.
        """

        expected_files = instance.data["expectedFiles"]

        start = instance.data["frameStart"] - instance.data["handleStart"]

        path = expected_files[0]
        basename = os.path.basename(path)
        staging_dir = os.path.dirname(path)
        _, padding, ext = get_frame_path(basename)

        repre = {
            "name": ext[1:],
            "ext": ext[1:],
            "frameStart": f"%0{padding}d" % start,
            "files": [os.path.basename(f) for f in expected_files],
            "stagingDir": staging_dir,
        }

        self.set_representation_colorspace(
            representation=repre,
            context=instance.context,
        )

        # Review representation
        if instance.data.get("review", False):
            repre["tags"] = ["review"]

        # Add the representation to the instance
        if "representations" not in instance.data:
            instance.data["representations"] = []
        instance.data["representations"].append(repre)

        return instance

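To make the expected-files construction concrete, a small sketch of the same padding logic in isolation. The values here assume `get_frame_path` splits a filename like `shot_0001.exr` into a head (`shot_`), a padding width (`4`) and an extension (`.exr`), matching how it is used above:

import os

# Hypothetical values mirroring what get_frame_path() is used for above
head, padding, ext = "shot_", 4, ".exr"
output_dir = "/renders/sh010"  # example directory

expected_files = [
    os.path.join(output_dir, f"{head}{str(frame).zfill(padding)}{ext}")
    for frame in range(1001, 1004)
]
# ['/renders/sh010/shot_1001.exr', '/renders/sh010/shot_1002.exr',
#  '/renders/sh010/shot_1003.exr']
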
@ -1,26 +0,0 @@
import os

import pyblish.api


class CollectFusionWorkfile(pyblish.api.InstancePlugin):
    """Collect Fusion workfile representation."""

    order = pyblish.api.CollectorOrder + 0.1
    label = "Collect Workfile"
    hosts = ["fusion"]
    families = ["workfile"]

    def process(self, instance):

        current_file = instance.context.data["currentFile"]

        folder, file = os.path.split(current_file)
        filename, ext = os.path.splitext(file)

        instance.data['representations'] = [{
            'name': ext.lstrip("."),
            'ext': ext.lstrip("."),
            'files': file,
            "stagingDir": folder,
        }]

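For reference, how the representation fields above are derived from the workfile path; the path is an example:

import os

current_file = "/projects/demo/work/sh010_comp_v012.comp"  # example path
folder, file = os.path.split(current_file)
filename, ext = os.path.splitext(file)
# folder -> '/projects/demo/work' (used as stagingDir)
# file   -> 'sh010_comp_v012.comp' (used as files)
# ext    -> '.comp', stored with the leading dot stripped as 'comp'
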
@ -1,207 +0,0 @@
import os
import logging
import contextlib
import collections
import pyblish.api

from ayon_core.pipeline import publish
from ayon_fusion.api import comp_lock_and_undo_chunk
from ayon_fusion.api.lib import get_frame_path, maintained_comp_range

log = logging.getLogger(__name__)


@contextlib.contextmanager
def enabled_savers(comp, savers):
    """Enable only the `savers` in Comp during the context.

    Any Saver tool in the passed composition that is not in the savers
    list will be set to passthrough during the context.

    Args:
        comp (object): Fusion composition object.
        savers (list): List of Saver tool objects.

    """
    passthrough_key = "TOOLB_PassThrough"
    original_states = {}
    enabled_saver_names = {saver.Name for saver in savers}

    all_savers = comp.GetToolList(False, "Saver").values()
    savers_by_name = {saver.Name: saver for saver in all_savers}

    try:
        for saver in all_savers:
            original_state = saver.GetAttrs()[passthrough_key]
            original_states[saver.Name] = original_state

            # The passthrough state we want to set (passthrough != enabled)
            state = saver.Name not in enabled_saver_names
            if state != original_state:
                saver.SetAttrs({passthrough_key: state})
        yield
    finally:
        for saver_name, original_state in original_states.items():
            saver = savers_by_name[saver_name]
            saver.SetAttrs({"TOOLB_PassThrough": original_state})


class FusionRenderLocal(
    pyblish.api.InstancePlugin,
    publish.ColormanagedPyblishPluginMixin
):
    """Render the current Fusion composition locally."""

    order = pyblish.api.ExtractorOrder - 0.2
    label = "Render Local"
    hosts = ["fusion"]
    families = ["render.local"]

    is_rendered_key = "_fusionrenderlocal_has_rendered"

    def process(self, instance):

        # Start render
        result = self.render(instance)
        if result is False:
            raise RuntimeError(f"Comp render failed for {instance}")

        self._add_representation(instance)

        # Log render status
        self.log.info(
            "Rendered '{}' for folder '{}' under the task '{}'".format(
                instance.data["name"],
                instance.data["folderPath"],
                instance.data["task"],
            )
        )

    def render(self, instance):
        """Render the instance.

        We render as few times as possible by combining instances that
        share a frame range into a single Fusion render. For the whole
        batch of instances we then store whether the render succeeded or
        failed.

        """

        if self.is_rendered_key in instance.data:
            # This instance was already processed in batch with another
            # instance, so we just return the render result directly
            self.log.debug(f"Instance {instance} was already rendered")
            return instance.data[self.is_rendered_key]

        instances_by_frame_range = self.get_render_instances_by_frame_range(
            instance.context
        )

        # Render the batch of instances that share the same frame range
        frame_range = self.get_instance_render_frame_range(instance)
        render_instances = instances_by_frame_range[frame_range]

        # Initialize the render state to False for the whole batch so
        # that, if an error below prevents the comp render result from
        # being stored, the instances are not mistakenly considered
        # rendered.
        for render_instance in render_instances:
            render_instance.data[self.is_rendered_key] = False

        savers_to_render = [inst.data["tool"] for inst in render_instances]
        current_comp = instance.context.data["currentComp"]
        frame_start, frame_end = frame_range

        self.log.info(
            f"Starting Fusion render frame range {frame_start}-{frame_end}"
        )
        saver_names = ", ".join(saver.Name for saver in savers_to_render)
        self.log.info(f"Rendering tools: {saver_names}")

        with comp_lock_and_undo_chunk(current_comp):
            with maintained_comp_range(current_comp):
                with enabled_savers(current_comp, savers_to_render):
                    result = current_comp.Render(
                        {
                            "Start": frame_start,
                            "End": frame_end,
                            "Wait": True,
                        }
                    )

        # Store the render state for all the rendered instances
        for render_instance in render_instances:
            render_instance.data[self.is_rendered_key] = bool(result)

        return result

    def _add_representation(self, instance):
        """Add representation to instance"""

        expected_files = instance.data["expectedFiles"]

        start = instance.data["frameStart"] - instance.data["handleStart"]

        path = expected_files[0]
        _, padding, ext = get_frame_path(path)

        staging_dir = os.path.dirname(path)

        files = [os.path.basename(f) for f in expected_files]
        if len(expected_files) == 1:
            files = files[0]

        repre = {
            "name": ext[1:],
            "ext": ext[1:],
            "frameStart": f"%0{padding}d" % start,
            "files": files,
            "stagingDir": staging_dir,
        }

        self.set_representation_colorspace(
            representation=repre,
            context=instance.context,
        )

        # Review representation
        if instance.data.get("review", False):
            repre["tags"] = ["review"]

        # Add the representation to the instance
        if "representations" not in instance.data:
            instance.data["representations"] = []
        instance.data["representations"].append(repre)

        return instance

    def get_render_instances_by_frame_range(self, context):
        """Return enabled render.local instances grouped by frame range.

        Arguments:
            context (pyblish.Context): The pyblish context

        Returns:
            dict: (start, end): instances mapping

        """

        instances_to_render = [
            instance for instance in context if
            # Only active instances
            instance.data.get("publish", True) and
            # Only render.local instances
            "render.local" in instance.data.get("families", [])
        ]

        # Instances by frame ranges
        instances_by_frame_range = collections.defaultdict(list)
        for instance in instances_to_render:
            start, end = self.get_instance_render_frame_range(instance)
            instances_by_frame_range[(start, end)].append(instance)

        return dict(instances_by_frame_range)

    def get_instance_render_frame_range(self, instance):
        start = instance.data["frameStartHandle"]
        end = instance.data["frameEndHandle"]
        return start, end

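A minimal sketch of using the `enabled_savers` context manager above on its own; `comp` is assumed to be the active Fusion composition and the saver name is an example:

# A minimal sketch, assuming `comp` is the active Fusion composition.
# Only "SaverMain" renders; every other Saver is set to passthrough for
# the duration of the `with` block and restored afterwards.
saver = comp.FindTool("SaverMain")  # example tool name

with enabled_savers(comp, [saver]):
    comp.Render({"Start": 1001, "End": 1010, "Wait": True})
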
@ -1,44 +0,0 @@
import pyblish.api

from ayon_core.pipeline import OptionalPyblishPluginMixin
from ayon_core.pipeline import KnownPublishError


class FusionIncrementCurrentFile(
    pyblish.api.ContextPlugin, OptionalPyblishPluginMixin
):
    """Increment the current file.

    Saves the current file with an increased version number.

    """

    label = "Increment workfile version"
    order = pyblish.api.IntegratorOrder + 9.0
    hosts = ["fusion"]
    optional = True

    def process(self, context):
        if not self.is_active(context.data):
            return

        from ayon_core.lib import version_up
        from ayon_core.pipeline.publish import get_errored_plugins_from_context

        errored_plugins = get_errored_plugins_from_context(context)
        if any(
            plugin.__name__ == "FusionSubmitDeadline"
            for plugin in errored_plugins
        ):
            raise KnownPublishError(
                "Skipping incrementing current file because "
                "submission to render farm failed."
            )

        comp = context.data.get("currentComp")
        assert comp, "Must have comp"

        current_filepath = context.data["currentFile"]
        new_filepath = version_up(current_filepath)

        comp.Save(new_filepath)

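The exact behavior of `ayon_core.lib.version_up` lives in ayon-core; as a rough illustration of the idea only (not the actual implementation), the trailing version token in the filename is bumped:

import re

def naive_version_up(filepath):
    """Illustrative stand-in for ayon_core.lib.version_up (not the real
    implementation): bump the last zero-padded "v###" token."""
    def bump(match):
        number = int(match.group(1)) + 1
        return "v{0:0{1}d}".format(number, len(match.group(1)))

    # Replace the final "v###" occurrence in the path
    return re.sub(r"v(\d+)(?!.*v\d+)", bump, filepath)

print(naive_version_up("shot010_comp_v003.comp"))
# shot010_comp_v004.comp
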
@ -1,21 +0,0 @@
import pyblish.api


class FusionSaveComp(pyblish.api.ContextPlugin):
    """Save current comp"""

    label = "Save current file"
    order = pyblish.api.ExtractorOrder - 0.49
    hosts = ["fusion"]
    families = ["render", "image", "workfile"]

    def process(self, context):

        comp = context.data.get("currentComp")
        assert comp, "Must have comp"

        current = comp.GetAttrs().get("COMPS_FileName", "")
        assert context.data['currentFile'] == current, (
            "Collected filepath is out of sync with the current comp"
        )

        self.log.info("Saving current file: {}".format(current))
        comp.Save()

@ -1,54 +0,0 @@
import pyblish.api

from ayon_core.pipeline import (
    publish,
    OptionalPyblishPluginMixin,
    PublishValidationError,
)

from ayon_fusion.api.action import SelectInvalidAction


class ValidateBackgroundDepth(
    pyblish.api.InstancePlugin, OptionalPyblishPluginMixin
):
    """Validate that all Background tools are set to 32-bit float depth."""

    order = pyblish.api.ValidatorOrder
    label = "Validate Background Depth 32 bit"
    hosts = ["fusion"]
    families = ["render", "image"]
    optional = True

    actions = [SelectInvalidAction, publish.RepairAction]

    @classmethod
    def get_invalid(cls, instance):
        context = instance.context
        comp = context.data.get("currentComp")
        assert comp, "Must have Comp object"

        backgrounds = comp.GetToolList(False, "Background").values()
        if not backgrounds:
            return []

        return [i for i in backgrounds if i.GetInput("Depth") != 4.0]

    def process(self, instance):
        if not self.is_active(instance.data):
            return

        invalid = self.get_invalid(instance)
        if invalid:
            raise PublishValidationError(
                "Found {} Background tools which"
                " are not set to float32".format(len(invalid)),
                title=self.label,
            )

    @classmethod
    def repair(cls, instance):
        comp = instance.context.data.get("currentComp")
        invalid = cls.get_invalid(instance)
        for i in invalid:
            i.SetInput("Depth", 4.0, comp.TIME_UNDEFINED)

@ -1,32 +0,0 @@
import os

import pyblish.api
from ayon_core.pipeline import PublishValidationError


class ValidateFusionCompSaved(pyblish.api.ContextPlugin):
    """Ensure current comp is saved"""

    order = pyblish.api.ValidatorOrder
    label = "Validate Comp Saved"
    families = ["render", "image"]
    hosts = ["fusion"]

    def process(self, context):

        comp = context.data.get("currentComp")
        assert comp, "Must have Comp object"
        attrs = comp.GetAttrs()

        filename = attrs["COMPS_FileName"]
        if not filename:
            raise PublishValidationError("Comp is not saved.",
                                         title=self.label)

        if not os.path.exists(filename):
            raise PublishValidationError(
                "Comp file does not exist: %s" % filename, title=self.label)

        if attrs["COMPB_Modified"]:
            self.log.warning("Comp is modified. Save your comp to ensure "
                             "your changes propagate correctly.")

@ -1,44 +0,0 @@
import pyblish.api

from ayon_core.pipeline.publish import RepairAction
from ayon_core.pipeline import PublishValidationError

from ayon_fusion.api.action import SelectInvalidAction


class ValidateCreateFolderChecked(pyblish.api.InstancePlugin):
    """Validate that the saver has its CreateDir input checked on.

    This attribute ensures that the folders the saver will write to are
    created during rendering.
    """

    order = pyblish.api.ValidatorOrder
    label = "Validate Create Folder Checked"
    families = ["render", "image"]
    hosts = ["fusion"]
    actions = [RepairAction, SelectInvalidAction]

    @classmethod
    def get_invalid(cls, instance):
        tool = instance.data["tool"]
        create_dir = tool.GetInput("CreateDir")
        if create_dir == 0.0:
            cls.log.error(
                "%s has Create Folder turned off" % tool.Name
            )
            return [tool]
        return []

    def process(self, instance):
        invalid = self.get_invalid(instance)
        if invalid:
            raise PublishValidationError(
                "Found Saver with Create Folder During Render checked off",
                title=self.label,
            )

    @classmethod
    def repair(cls, instance):
        invalid = cls.get_invalid(instance)
        for tool in invalid:
            tool.SetInput("CreateDir", 1.0)

@ -1,66 +0,0 @@
import os
import pyblish.api

from ayon_core.pipeline.publish import RepairAction
from ayon_core.pipeline import PublishValidationError

from ayon_fusion.api.action import SelectInvalidAction


class ValidateLocalFramesExistence(pyblish.api.InstancePlugin):
    """Check that the expected frames exist for savers that are set to
    publish already rendered frames.
    """

    order = pyblish.api.ValidatorOrder
    label = "Validate Expected Frames Exist"
    families = ["render.frames"]
    hosts = ["fusion"]
    actions = [RepairAction, SelectInvalidAction]

    @classmethod
    def get_invalid(cls, instance, non_existing_frames=None):
        if non_existing_frames is None:
            non_existing_frames = []

        tool = instance.data["tool"]

        expected_files = instance.data["expectedFiles"]

        for file in expected_files:
            if not os.path.exists(file):
                cls.log.error(
                    f"Missing file: {file}"
                )
                non_existing_frames.append(file)

        if len(non_existing_frames) > 0:
            cls.log.error(f"Some of {tool.Name}'s files do not exist")
            return [tool]
        return []

    def process(self, instance):
        non_existing_frames = []
        invalid = self.get_invalid(instance, non_existing_frames)
        if invalid:
            raise PublishValidationError(
                "{} is set to publish existing frames but "
                "some frames are missing. "
                "The missing file(s) are:\n\n{}".format(
                    invalid[0].Name,
                    "\n\n".join(non_existing_frames),
                ),
                title=self.label,
            )

    @classmethod
    def repair(cls, instance):
        invalid = cls.get_invalid(instance)
        if invalid:
            tool = instance.data["tool"]
            # Change render target to local to render locally
            tool.SetData("openpype.creator_attributes.render_target", "local")

            cls.log.info(
                f"Reload the publisher and {tool.Name} "
                "will be set to render locally"
            )

@ -1,41 +0,0 @@
import os

import pyblish.api
from ayon_core.pipeline import PublishValidationError

from ayon_fusion.api.action import SelectInvalidAction


class ValidateFilenameHasExtension(pyblish.api.InstancePlugin):
    """Ensure the Saver has an extension in the filename path.

    This disallows files written as `filename` instead of
    `filename.frame.ext`. Fusion does not always set an extension for
    your filename when changing the file format of the saver.

    """

    order = pyblish.api.ValidatorOrder
    label = "Validate Filename Has Extension"
    families = ["render", "image"]
    hosts = ["fusion"]
    actions = [SelectInvalidAction]

    def process(self, instance):
        invalid = self.get_invalid(instance)
        if invalid:
            raise PublishValidationError("Found Saver without an extension",
                                         title=self.label)

    @classmethod
    def get_invalid(cls, instance):

        path = instance.data["expectedFiles"][0]
        fname, ext = os.path.splitext(path)

        if not ext:
            tool = instance.data["tool"]
            cls.log.error("%s has no extension specified" % tool.Name)
            return [tool]

        return []

@ -1,27 +0,0 @@
import pyblish.api

from ayon_core.pipeline import PublishValidationError


class ValidateImageFrame(pyblish.api.InstancePlugin):
    """Validate that the `image` product type contains only a single frame."""

    order = pyblish.api.ValidatorOrder
    label = "Validate Image Frame"
    families = ["image"]
    hosts = ["fusion"]

    def process(self, instance):
        render_start = instance.data["frameStartHandle"]
        render_end = instance.data["frameEndHandle"]
        too_many_frames = (isinstance(instance.data["expectedFiles"], list)
                           and len(instance.data["expectedFiles"]) > 1)

        if render_end - render_start > 0 or too_many_frames:
            desc = ("Trying to render multiple frames. The 'image' product "
                    "type is meant for a single frame. Please use the "
                    "'render' creator instead.")
            raise PublishValidationError(
                title="Multiple frames in image product",
                message=desc,
                description=desc
            )

@ -1,41 +0,0 @@
import pyblish.api

from ayon_core.pipeline import PublishValidationError


class ValidateInstanceFrameRange(pyblish.api.InstancePlugin):
    """Validate instance frame range is within comp's global render range."""

    order = pyblish.api.ValidatorOrder
    label = "Validate Frame Range"
    families = ["render", "image"]
    hosts = ["fusion"]

    def process(self, instance):

        context = instance.context
        global_start = context.data["compFrameStart"]
        global_end = context.data["compFrameEnd"]

        render_start = instance.data["frameStartHandle"]
        render_end = instance.data["frameEndHandle"]

        if render_start < global_start or render_end > global_end:

            message = (
                f"Instance {instance} render frame range "
                f"({render_start}-{render_end}) is outside of the comp's "
                f"global render range ({global_start}-{global_end}) and thus "
                f"can't be rendered. "
            )
            description = (
                f"{message}\n\n"
                f"Either update the comp's global range or the instance's "
                f"frame range so that the comp's range includes the frame "
                f"range to be rendered for the instance."
            )
            raise PublishValidationError(
                title="Frame range outside of comp range",
                message=message,
                description=description
            )

@ -1,80 +0,0 @@
# -*- coding: utf-8 -*-
"""Validate if instance context is the same as publish context."""

import pyblish.api
from ayon_fusion.api.action import SelectToolAction
from ayon_core.pipeline.publish import (
    RepairAction,
    ValidateContentsOrder,
    PublishValidationError,
    OptionalPyblishPluginMixin
)


class ValidateInstanceInContextFusion(pyblish.api.InstancePlugin,
                                      OptionalPyblishPluginMixin):
    """Validator to check if instance context matches context of publish.

    When working in a per-shot style you always publish data in the
    context of the current asset (shot). This validator checks that this
    is the case. It is optional so it can be disabled when needed.
    """
    # Similar to the maya and houdini equivalent `ValidateInstanceInContext`

    order = ValidateContentsOrder
    label = "Instance in same Context"
    optional = True
    hosts = ["fusion"]
    actions = [SelectToolAction, RepairAction]

    def process(self, instance):
        if not self.is_active(instance.data):
            return

        instance_context = self.get_context(instance.data)
        context = self.get_context(instance.context.data)
        if instance_context != context:
            context_label = "{} > {}".format(*context)
            instance_label = "{} > {}".format(*instance_context)

            raise PublishValidationError(
                message=(
                    "Instance '{}' publishes to different asset than current "
                    "context: {}. Current context: {}".format(
                        instance.name, instance_label, context_label
                    )
                ),
                description=(
                    "## Publishing to a different asset\n"
                    "There are publish instances present which are publishing "
                    "into a different asset than your current context.\n\n"
                    "Usually this is not what you want but there can be cases "
                    "where you might want to publish into another asset or "
                    "shot. If that's the case you can disable the validation "
                    "on the instance to ignore it."
                )
            )

    @classmethod
    def repair(cls, instance):

        create_context = instance.context.data["create_context"]
        instance_id = instance.data.get("instance_id")
        created_instance = create_context.get_instance_by_id(
            instance_id
        )
        if created_instance is None:
            raise RuntimeError(
                f"No CreatedInstances found with id '{instance_id}' "
                f"in {create_context.instances_by_id}"
            )

        context_asset, context_task = cls.get_context(instance.context.data)
        created_instance["folderPath"] = context_asset
        created_instance["task"] = context_task
        create_context.save_changes()

    @staticmethod
    def get_context(data):
        """Return (folder path, task) from publishing context data."""
        return data["folderPath"], data["task"]

@ -1,36 +0,0 @@
import pyblish.api
from ayon_core.pipeline import PublishValidationError

from ayon_fusion.api.action import SelectInvalidAction


class ValidateSaverHasInput(pyblish.api.InstancePlugin):
    """Validate that the saver has an incoming connection.

    This ensures the Saver has at least one input connection.

    """

    order = pyblish.api.ValidatorOrder
    label = "Validate Saver Has Input"
    families = ["render", "image"]
    hosts = ["fusion"]
    actions = [SelectInvalidAction]

    @classmethod
    def get_invalid(cls, instance):

        saver = instance.data["tool"]
        if not saver.Input.GetConnectedOutput():
            return [saver]

        return []

    def process(self, instance):
        invalid = self.get_invalid(instance)
        if invalid:
            saver_name = invalid[0].Name
            raise PublishValidationError(
                "Saver has no incoming connection: {} ({})".format(instance,
                                                                   saver_name),
                title=self.label)

@ -1,49 +0,0 @@
import pyblish.api
from ayon_core.pipeline import PublishValidationError

from ayon_fusion.api.action import SelectInvalidAction


class ValidateSaverPassthrough(pyblish.api.ContextPlugin):
    """Validate saver passthrough state matches the Pyblish publish state"""

    order = pyblish.api.ValidatorOrder
    label = "Validate Saver Passthrough"
    families = ["render", "image"]
    hosts = ["fusion"]
    actions = [SelectInvalidAction]

    def process(self, context):

        # Workaround for ContextPlugin always running, even if no instance
        # is present with the family
        instances = pyblish.api.instances_by_plugin(instances=list(context),
                                                    plugin=self)
        if not instances:
            self.log.debug("Ignoring plugin.. (bugfix)")
            return

        invalid_instances = []
        for instance in instances:
            invalid = self.is_invalid(instance)
            if invalid:
                invalid_instances.append(instance)

        if invalid_instances:
            self.log.info("Reset pyblish to collect your current scene "
                          "state; that should fix the error.")
            raise PublishValidationError(
                "Invalid instances: {0}".format(invalid_instances),
                title=self.label)

    def is_invalid(self, instance):

        saver = instance.data["tool"]
        attr = saver.GetAttrs()
        active = not attr["TOOLB_PassThrough"]

        if active != instance.data.get("publish", True):
            self.log.info("Saver has a different passthrough state than "
                          "Pyblish: {} ({})".format(instance, saver.Name))
            return [saver]

        return []

@ -1,116 +0,0 @@
import pyblish.api
from ayon_core.pipeline import (
    PublishValidationError,
    OptionalPyblishPluginMixin,
)

from ayon_fusion.api.action import SelectInvalidAction
from ayon_fusion.api import comp_lock_and_undo_chunk


class ValidateSaverResolution(
    pyblish.api.InstancePlugin, OptionalPyblishPluginMixin
):
    """Validate that the saver input resolution matches the folder resolution"""

    order = pyblish.api.ValidatorOrder
    label = "Validate Folder Resolution"
    families = ["render", "image"]
    hosts = ["fusion"]
    optional = True
    actions = [SelectInvalidAction]

    def process(self, instance):
        if not self.is_active(instance.data):
            return

        resolution = self.get_resolution(instance)
        expected_resolution = self.get_expected_resolution(instance)
        if resolution != expected_resolution:
            raise PublishValidationError(
                "The input's resolution does not match "
                "the folder's resolution {}x{}.\n\n"
                "The input's resolution is {}x{}.".format(
                    expected_resolution[0], expected_resolution[1],
                    resolution[0], resolution[1]
                )
            )

    @classmethod
    def get_invalid(cls, instance):
        saver = instance.data["tool"]
        try:
            resolution = cls.get_resolution(instance)
        except PublishValidationError:
            resolution = None
        expected_resolution = cls.get_expected_resolution(instance)
        if resolution != expected_resolution:
            return [saver]
        return []

    @classmethod
    def get_resolution(cls, instance):
        saver = instance.data["tool"]
        first_frame = instance.data["frameStartHandle"]
        return cls.get_tool_resolution(saver, frame=first_frame)

    @classmethod
    def get_expected_resolution(cls, instance):
        attributes = instance.data["folderEntity"]["attrib"]
        return attributes["resolutionWidth"], attributes["resolutionHeight"]

    @classmethod
    def get_tool_resolution(cls, tool, frame):
        """Return the 2D input resolution of a Fusion tool.

        If the current tool hasn't been rendered, its input resolution
        hasn't been saved. To work around this, we temporarily set an
        expression on the tool's Comments field to read the resolution.

        Args:
            tool (Fusion Tool): The tool to query the input resolution of.
            frame (int): The frame to query the resolution on.

        Returns:
            tuple: width, height as 2-tuple of integers

        """
        comp = tool.Composition

        # Passing False as the undo argument excludes this chunk from the
        # undo list
        with comp_lock_and_undo_chunk(comp, "Read resolution", False):
            # Save old comment
            old_comment = ""
            has_expression = False

            if tool["Comments"][frame] not in ["", None]:
                if tool["Comments"].GetExpression() is not None:
                    has_expression = True
                    old_comment = tool["Comments"].GetExpression()
                    tool["Comments"].SetExpression(None)
                else:
                    old_comment = tool["Comments"][frame]
                    tool["Comments"][frame] = ""
            # Get input width
            tool["Comments"].SetExpression("self.Input.OriginalWidth")
            if tool["Comments"][frame] is None:
                raise PublishValidationError(
                    "Cannot get resolution info for frame '{}'.\n\n"
                    "Please check that the saver has a connected "
                    "input.".format(frame)
                )

            width = int(tool["Comments"][frame])

            # Get input height
            tool["Comments"].SetExpression("self.Input.OriginalHeight")
            height = int(tool["Comments"][frame])

            # Reset old comment
            tool["Comments"].SetExpression(None)
            if has_expression:
                tool["Comments"].SetExpression(old_comment)
            else:
                tool["Comments"][frame] = old_comment

            return width, height

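A minimal usage sketch of the Comments expression trick above, assuming an active Fusion session where `comp` is the current composition; the tool name and frame are examples:

# A minimal sketch, assuming `comp` is the active Fusion composition.
saver = comp.FindTool("Saver1")  # example tool name

width, height = ValidateSaverResolution.get_tool_resolution(
    saver, frame=1001
)
print(f"Saver input resolution: {width}x{height}")
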
@ -1,62 +0,0 @@
from collections import defaultdict

import pyblish.api
from ayon_core.pipeline import PublishValidationError

from ayon_fusion.api.action import SelectInvalidAction


class ValidateUniqueSubsets(pyblish.api.ContextPlugin):
    """Ensure all instances have a unique product name"""

    order = pyblish.api.ValidatorOrder
    label = "Validate Unique Products"
    families = ["render", "image"]
    hosts = ["fusion"]
    actions = [SelectInvalidAction]

    @classmethod
    def get_invalid(cls, context):

        # Collect instances per product per folder
        instances_per_product_folder = defaultdict(lambda: defaultdict(list))
        for instance in context:
            folder_path = instance.data["folderPath"]
            product_name = instance.data["productName"]
            instances_per_product_folder[folder_path][product_name].append(
                instance
            )

        # Find which folder + product combination has more than one
        # instance. Those are considered invalid because they'd integrate
        # to the same destination.
        invalid = []
        for folder_path, instances_per_product in (
            instances_per_product_folder.items()
        ):
            for product_name, instances in instances_per_product.items():
                if len(instances) > 1:
                    cls.log.warning(
                        (
                            "{folder_path} > {product_name} used by more than "
                            "one instance: {instances}"
                        ).format(
                            folder_path=folder_path,
                            product_name=product_name,
                            instances=instances
                        )
                    )
                    invalid.extend(instances)

        # Return tools for the invalid instances so they can be selected
        invalid = [instance.data["tool"] for instance in invalid]

        return invalid

    def process(self, context):
        invalid = self.get_invalid(context)
        if invalid:
            raise PublishValidationError(
                "Multiple instances are set to the same folder > product.",
                title=self.label
            )

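The nested `defaultdict` grouping above is a general pattern; a small self-contained sketch with dummy data standing in for pyblish instances:

from collections import defaultdict

# Group by (folder, product) and flag duplicates, mirroring the
# validator above with plain dicts standing in for pyblish instances.
instances = [
    {"folderPath": "/shots/sh010", "productName": "renderMain"},
    {"folderPath": "/shots/sh010", "productName": "renderMain"},
    {"folderPath": "/shots/sh020", "productName": "renderMain"},
]

grouped = defaultdict(lambda: defaultdict(list))
for inst in instances:
    grouped[inst["folderPath"]][inst["productName"]].append(inst)

duplicates = [
    (folder, product)
    for folder, per_product in grouped.items()
    for product, insts in per_product.items()
    if len(insts) > 1
]
print(duplicates)  # [('/shots/sh010', 'renderMain')]
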
@ -1,45 +0,0 @@
from ayon_fusion.api import (
    comp_lock_and_undo_chunk,
    get_current_comp
)


def is_connected(input_):
    """Return whether an input has an incoming connection"""
    return input_.GetAttrs()["INPB_Connected"]


def duplicate_with_input_connections():
    """Duplicate selected tools with incoming connections."""

    comp = get_current_comp()
    original_tools = comp.GetToolList(True).values()
    if not original_tools:
        return  # nothing selected

    with comp_lock_and_undo_chunk(
            comp, "Duplicate With Input Connections"):

        # Generate duplicates
        comp.Copy()
        comp.SetActiveTool()
        comp.Paste()
        duplicate_tools = comp.GetToolList(True).values()

        # Copy connections
        for original, new in zip(original_tools, duplicate_tools):

            original_inputs = original.GetInputList().values()
            new_inputs = new.GetInputList().values()
            assert len(original_inputs) == len(new_inputs)

            for original_input, new_input in zip(original_inputs, new_inputs):

                if is_connected(original_input):

                    if is_connected(new_input):
                        # Already connected if it is between the copied tools
                        continue

                    new_input.ConnectTo(original_input.GetConnectedOutput())
                    assert is_connected(new_input), "Must be connected now"

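A minimal usage sketch: with tools selected in the Flow view of the active comp, calling the function duplicates them and re-creates their incoming connections. The module path below is hypothetical:

# Module path is hypothetical; adjust to where the script is installed.
from ayon_fusion.scripts.duplicate_with_inputs import (
    duplicate_with_input_connections,
)

duplicate_with_input_connections()
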
@ -1,78 +0,0 @@
from __future__ import absolute_import, division, print_function

import sys

from functools import partial

from . import converters, exceptions, filters, setters, validators
from ._cmp import cmp_using
from ._config import get_run_validators, set_run_validators
from ._funcs import asdict, assoc, astuple, evolve, has, resolve_types
from ._make import (
    NOTHING,
    Attribute,
    Factory,
    attrib,
    attrs,
    fields,
    fields_dict,
    make_class,
    validate,
)
from ._version_info import VersionInfo


__version__ = "21.2.0"
__version_info__ = VersionInfo._from_version_string(__version__)

__title__ = "attrs"
__description__ = "Classes Without Boilerplate"
__url__ = "https://www.attrs.org/"
__uri__ = __url__
__doc__ = __description__ + " <" + __uri__ + ">"

__author__ = "Hynek Schlawack"
__email__ = "hs@ox.cx"

__license__ = "MIT"
__copyright__ = "Copyright (c) 2015 Hynek Schlawack"


s = attributes = attrs
ib = attr = attrib
dataclass = partial(attrs, auto_attribs=True)  # happy Easter ;)

__all__ = [
    "Attribute",
    "Factory",
    "NOTHING",
    "asdict",
    "assoc",
    "astuple",
    "attr",
    "attrib",
    "attributes",
    "attrs",
    "cmp_using",
    "converters",
    "evolve",
    "exceptions",
    "fields",
    "fields_dict",
    "filters",
    "get_run_validators",
    "has",
    "ib",
    "make_class",
    "resolve_types",
    "s",
    "set_run_validators",
    "setters",
    "validate",
    "validators",
]

if sys.version_info[:2] >= (3, 6):
    from ._next_gen import define, field, frozen, mutable

    # Extend with the exported names (strings), not the objects themselves
    __all__.extend(("define", "field", "frozen", "mutable"))

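This vendored `attrs` package is what powers the `@attr.s` / `attr.ib` declarations used by `FusionRenderInstance` earlier in this diff. A minimal self-contained sketch of that declaration style:

import attr

@attr.s
class RenderSettings(object):
    # attr.ib() declares a field; defaults keep the class constructible
    # without arguments, as in FusionRenderInstance above.
    fps = attr.ib(default=None)
    frame_start = attr.ib(default=1001)
    frame_end = attr.ib(default=1100)

settings = RenderSettings(fps=24)
print(attr.asdict(settings))
# {'fps': 24, 'frame_start': 1001, 'frame_end': 1100}
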
@ -1,475 +0,0 @@
import sys

from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
    overload,
)

# `import X as X` is required to make these public
from . import converters as converters
from . import exceptions as exceptions
from . import filters as filters
from . import setters as setters
from . import validators as validators
from ._version_info import VersionInfo


__version__: str
__version_info__: VersionInfo
__title__: str
__description__: str
__url__: str
__uri__: str
__author__: str
__email__: str
__license__: str
__copyright__: str

_T = TypeVar("_T")
_C = TypeVar("_C", bound=type)

_EqOrderType = Union[bool, Callable[[Any], Any]]
_ValidatorType = Callable[[Any, Attribute[_T], _T], Any]
_ConverterType = Callable[[Any], Any]
_FilterType = Callable[[Attribute[_T], _T], bool]
_ReprType = Callable[[Any], str]
_ReprArgType = Union[bool, _ReprType]
_OnSetAttrType = Callable[[Any, Attribute[Any], Any], Any]
_OnSetAttrArgType = Union[
    _OnSetAttrType, List[_OnSetAttrType], setters._NoOpType
]
_FieldTransformer = Callable[
    [type, List[Attribute[Any]]], List[Attribute[Any]]
]
# FIXME: in reality, if multiple validators are passed they must be in a list
# or tuple, but those are invariant and so would prevent subtypes of
# _ValidatorType from working when passed in a list or tuple.
_ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]]

# _make --

NOTHING: object

# NOTE: Factory lies about its return type to make this possible:
# `x: List[int] # = Factory(list)`
# Work around mypy issue #4554 in the common case by using an overload.
if sys.version_info >= (3, 8):
    from typing import Literal

    @overload
    def Factory(factory: Callable[[], _T]) -> _T: ...
    @overload
    def Factory(
        factory: Callable[[Any], _T],
        takes_self: Literal[True],
    ) -> _T: ...
    @overload
    def Factory(
        factory: Callable[[], _T],
        takes_self: Literal[False],
    ) -> _T: ...

else:
    @overload
    def Factory(factory: Callable[[], _T]) -> _T: ...
    @overload
    def Factory(
        factory: Union[Callable[[Any], _T], Callable[[], _T]],
        takes_self: bool = ...,
    ) -> _T: ...

# Static type inference support via __dataclass_transform__ implemented as per:
# https://github.com/microsoft/pyright/blob/1.1.135/specs/dataclass_transforms.md
# This annotation must be applied to all overloads of "define" and "attrs"
#
# NOTE: This is a typing construct and does not exist at runtime. Extensions
# wrapping attrs decorators should declare a separate __dataclass_transform__
# signature in the extension module using the specification linked above to
# provide pyright support.
def __dataclass_transform__(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]: ...

class Attribute(Generic[_T]):
    name: str
    default: Optional[_T]
    validator: Optional[_ValidatorType[_T]]
    repr: _ReprArgType
    cmp: _EqOrderType
    eq: _EqOrderType
    order: _EqOrderType
    hash: Optional[bool]
    init: bool
    converter: Optional[_ConverterType]
    metadata: Dict[Any, Any]
    type: Optional[Type[_T]]
    kw_only: bool
    on_setattr: _OnSetAttrType

    def evolve(self, **changes: Any) -> "Attribute[Any]": ...

# NOTE: We had several choices for the annotation to use for type arg:
# 1) Type[_T]
#   - Pros: Handles simple cases correctly
#   - Cons: Might produce less informative errors in the case of conflicting
#     TypeVars e.g. `attr.ib(default='bad', type=int)`
# 2) Callable[..., _T]
#   - Pros: Better error messages than #1 for conflicting TypeVars
#   - Cons: Terrible error messages for validator checks.
#   e.g. attr.ib(type=int, validator=validate_str)
#     -> error: Cannot infer function type argument
# 3) type (and do all of the work in the mypy plugin)
#   - Pros: Simple here, and we could customize the plugin with our own errors.
#   - Cons: Would need to write mypy plugin code to handle all the cases.
# We chose option #1.

# `attr` lies about its return type to make the following possible:
#     attr()    -> Any
#     attr(8)   -> int
#     attr(validator=<some callable>)  -> Whatever the callable expects.
# This makes this type of assignments possible:
#     x: int = attr(8)
#
# This form catches explicit None or no default but with no other arguments
# returns Any.
@overload
def attrib(
    default: None = ...,
    validator: None = ...,
    repr: _ReprArgType = ...,
    cmp: Optional[_EqOrderType] = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    metadata: Optional[Mapping[Any, Any]] = ...,
    type: None = ...,
    converter: None = ...,
    factory: None = ...,
    kw_only: bool = ...,
    eq: Optional[_EqOrderType] = ...,
    order: Optional[_EqOrderType] = ...,
    on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...

# This form catches an explicit None or no default and infers the type from the
# other arguments.
@overload
def attrib(
    default: None = ...,
    validator: Optional[_ValidatorArgType[_T]] = ...,
    repr: _ReprArgType = ...,
    cmp: Optional[_EqOrderType] = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    metadata: Optional[Mapping[Any, Any]] = ...,
    type: Optional[Type[_T]] = ...,
    converter: Optional[_ConverterType] = ...,
    factory: Optional[Callable[[], _T]] = ...,
    kw_only: bool = ...,
    eq: Optional[_EqOrderType] = ...,
    order: Optional[_EqOrderType] = ...,
    on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...

# This form catches an explicit default argument.
@overload
def attrib(
    default: _T,
    validator: Optional[_ValidatorArgType[_T]] = ...,
    repr: _ReprArgType = ...,
    cmp: Optional[_EqOrderType] = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    metadata: Optional[Mapping[Any, Any]] = ...,
    type: Optional[Type[_T]] = ...,
    converter: Optional[_ConverterType] = ...,
    factory: Optional[Callable[[], _T]] = ...,
    kw_only: bool = ...,
    eq: Optional[_EqOrderType] = ...,
    order: Optional[_EqOrderType] = ...,
    on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...

# This form covers type=non-Type: e.g. forward references (str), Any
@overload
def attrib(
    default: Optional[_T] = ...,
    validator: Optional[_ValidatorArgType[_T]] = ...,
    repr: _ReprArgType = ...,
    cmp: Optional[_EqOrderType] = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    metadata: Optional[Mapping[Any, Any]] = ...,
    type: object = ...,
    converter: Optional[_ConverterType] = ...,
    factory: Optional[Callable[[], _T]] = ...,
    kw_only: bool = ...,
    eq: Optional[_EqOrderType] = ...,
    order: Optional[_EqOrderType] = ...,
    on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
@overload
def field(
    *,
    default: None = ...,
    validator: None = ...,
    repr: _ReprArgType = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    metadata: Optional[Mapping[Any, Any]] = ...,
    converter: None = ...,
    factory: None = ...,
    kw_only: bool = ...,
    eq: Optional[bool] = ...,
    order: Optional[bool] = ...,
    on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...

# This form catches an explicit None or no default and infers the type from the
# other arguments.
@overload
def field(
    *,
    default: None = ...,
    validator: Optional[_ValidatorArgType[_T]] = ...,
    repr: _ReprArgType = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    metadata: Optional[Mapping[Any, Any]] = ...,
    converter: Optional[_ConverterType] = ...,
    factory: Optional[Callable[[], _T]] = ...,
    kw_only: bool = ...,
    eq: Optional[_EqOrderType] = ...,
    order: Optional[_EqOrderType] = ...,
    on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...

# This form catches an explicit default argument.
@overload
def field(
    *,
    default: _T,
    validator: Optional[_ValidatorArgType[_T]] = ...,
    repr: _ReprArgType = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    metadata: Optional[Mapping[Any, Any]] = ...,
    converter: Optional[_ConverterType] = ...,
    factory: Optional[Callable[[], _T]] = ...,
    kw_only: bool = ...,
    eq: Optional[_EqOrderType] = ...,
    order: Optional[_EqOrderType] = ...,
    on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...

# This form covers type=non-Type: e.g. forward references (str), Any
@overload
def field(
    *,
    default: Optional[_T] = ...,
    validator: Optional[_ValidatorArgType[_T]] = ...,
    repr: _ReprArgType = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    metadata: Optional[Mapping[Any, Any]] = ...,
    converter: Optional[_ConverterType] = ...,
    factory: Optional[Callable[[], _T]] = ...,
    kw_only: bool = ...,
    eq: Optional[_EqOrderType] = ...,
    order: Optional[_EqOrderType] = ...,
    on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
def attrs(
    maybe_cls: _C,
    these: Optional[Dict[str, Any]] = ...,
    repr_ns: Optional[str] = ...,
    repr: bool = ...,
    cmp: Optional[_EqOrderType] = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    slots: bool = ...,
    frozen: bool = ...,
    weakref_slot: bool = ...,
    str: bool = ...,
    auto_attribs: bool = ...,
    kw_only: bool = ...,
    cache_hash: bool = ...,
    auto_exc: bool = ...,
    eq: Optional[_EqOrderType] = ...,
    order: Optional[_EqOrderType] = ...,
    auto_detect: bool = ...,
    collect_by_mro: bool = ...,
    getstate_setstate: Optional[bool] = ...,
    on_setattr: Optional[_OnSetAttrArgType] = ...,
    field_transformer: Optional[_FieldTransformer] = ...,
) -> _C: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
def attrs(
    maybe_cls: None = ...,
    these: Optional[Dict[str, Any]] = ...,
    repr_ns: Optional[str] = ...,
    repr: bool = ...,
    cmp: Optional[_EqOrderType] = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    slots: bool = ...,
    frozen: bool = ...,
    weakref_slot: bool = ...,
    str: bool = ...,
    auto_attribs: bool = ...,
    kw_only: bool = ...,
    cache_hash: bool = ...,
    auto_exc: bool = ...,
    eq: Optional[_EqOrderType] = ...,
    order: Optional[_EqOrderType] = ...,
    auto_detect: bool = ...,
    collect_by_mro: bool = ...,
    getstate_setstate: Optional[bool] = ...,
    on_setattr: Optional[_OnSetAttrArgType] = ...,
    field_transformer: Optional[_FieldTransformer] = ...,
) -> Callable[[_C], _C]: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
def define(
    maybe_cls: _C,
    *,
    these: Optional[Dict[str, Any]] = ...,
    repr: bool = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    slots: bool = ...,
    frozen: bool = ...,
    weakref_slot: bool = ...,
    str: bool = ...,
    auto_attribs: bool = ...,
    kw_only: bool = ...,
    cache_hash: bool = ...,
    auto_exc: bool = ...,
    eq: Optional[bool] = ...,
    order: Optional[bool] = ...,
    auto_detect: bool = ...,
    getstate_setstate: Optional[bool] = ...,
    on_setattr: Optional[_OnSetAttrArgType] = ...,
    field_transformer: Optional[_FieldTransformer] = ...,
) -> _C: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
def define(
    maybe_cls: None = ...,
    *,
    these: Optional[Dict[str, Any]] = ...,
    repr: bool = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    slots: bool = ...,
    frozen: bool = ...,
    weakref_slot: bool = ...,
    str: bool = ...,
    auto_attribs: bool = ...,
    kw_only: bool = ...,
    cache_hash: bool = ...,
    auto_exc: bool = ...,
    eq: Optional[bool] = ...,
    order: Optional[bool] = ...,
    auto_detect: bool = ...,
    getstate_setstate: Optional[bool] = ...,
    on_setattr: Optional[_OnSetAttrArgType] = ...,
    field_transformer: Optional[_FieldTransformer] = ...,
) -> Callable[[_C], _C]: ...

mutable = define
frozen = define  # they differ only in their defaults

# TODO: add support for returning NamedTuple from the mypy plugin
class _Fields(Tuple[Attribute[Any], ...]):
    def __getattr__(self, name: str) -> Attribute[Any]: ...

def fields(cls: type) -> _Fields: ...
def fields_dict(cls: type) -> Dict[str, Attribute[Any]]: ...
def validate(inst: Any) -> None: ...
def resolve_types(
    cls: _C,
    globalns: Optional[Dict[str, Any]] = ...,
    localns: Optional[Dict[str, Any]] = ...,
    attribs: Optional[List[Attribute[Any]]] = ...,
) -> _C: ...

# TODO: add support for returning a proper attrs class from the mypy plugin
# we use Any instead of _CountingAttr so that e.g. `make_class('Foo',
# [attr.ib()])` is valid
def make_class(
    name: str,
    attrs: Union[List[str], Tuple[str, ...], Dict[str, Any]],
    bases: Tuple[type, ...] = ...,
    repr_ns: Optional[str] = ...,
    repr: bool = ...,
    cmp: Optional[_EqOrderType] = ...,
    hash: Optional[bool] = ...,
    init: bool = ...,
    slots: bool = ...,
    frozen: bool = ...,
    weakref_slot: bool = ...,
    str: bool = ...,
    auto_attribs: bool = ...,
    kw_only: bool = ...,
    cache_hash: bool = ...,
    auto_exc: bool = ...,
    eq: Optional[_EqOrderType] = ...,
    order: Optional[_EqOrderType] = ...,
    collect_by_mro: bool = ...,
    on_setattr: Optional[_OnSetAttrArgType] = ...,
    field_transformer: Optional[_FieldTransformer] = ...,
) -> type: ...

# _funcs --

# TODO: add support for returning TypedDict from the mypy plugin
# FIXME: asdict/astuple do not honor their factory args. Waiting on one of
# these:
# https://github.com/python/mypy/issues/4236
|
||||
# https://github.com/python/typing/issues/253
|
||||
def asdict(
|
||||
inst: Any,
|
||||
recurse: bool = ...,
|
||||
filter: Optional[_FilterType[Any]] = ...,
|
||||
dict_factory: Type[Mapping[Any, Any]] = ...,
|
||||
retain_collection_types: bool = ...,
|
||||
value_serializer: Optional[Callable[[type, Attribute[Any], Any], Any]] = ...,
|
||||
) -> Dict[str, Any]: ...
|
||||
|
||||
# TODO: add support for returning NamedTuple from the mypy plugin
|
||||
def astuple(
|
||||
inst: Any,
|
||||
recurse: bool = ...,
|
||||
filter: Optional[_FilterType[Any]] = ...,
|
||||
tuple_factory: Type[Sequence[Any]] = ...,
|
||||
retain_collection_types: bool = ...,
|
||||
) -> Tuple[Any, ...]: ...
|
||||
def has(cls: type) -> bool: ...
|
||||
def assoc(inst: _T, **changes: Any) -> _T: ...
|
||||
def evolve(inst: _T, **changes: Any) -> _T: ...
|
||||
|
||||
# _config --
|
||||
|
||||
def set_run_validators(run: bool) -> None: ...
|
||||
def get_run_validators() -> bool: ...
|
||||
|
||||
# aliases --
|
||||
|
||||
s = attributes = attrs
|
||||
ib = attr = attrib
|
||||
dataclass = attrs # Technically, partial(attrs, auto_attribs=True) ;)
|
||||
|
|
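For orientation, a minimal sketch of how the `fields`, `asdict`, and `evolve` entry points declared in the stub above are typically used (the `Point` class is a hypothetical example, not part of the removed addon):

import attr

@attr.s(auto_attribs=True)
class Point:
    x: int
    y: int = 0

p = Point(1)
assert [f.name for f in attr.fields(Point)] == ["x", "y"]
assert attr.asdict(p) == {"x": 1, "y": 0}
assert attr.evolve(p, y=5) == Point(1, 5)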
@@ -1,152 +0,0 @@
from __future__ import absolute_import, division, print_function

import functools

from ._compat import new_class
from ._make import _make_ne


_operation_names = {"eq": "==", "lt": "<", "le": "<=", "gt": ">", "ge": ">="}


def cmp_using(
    eq=None,
    lt=None,
    le=None,
    gt=None,
    ge=None,
    require_same_type=True,
    class_name="Comparable",
):
    """
    Create a class that can be passed into `attr.ib`'s ``eq``, ``order``, and
    ``cmp`` arguments to customize field comparison.

    The resulting class will have a full set of ordering methods if
    at least one of ``{lt, le, gt, ge}`` and ``eq`` are provided.

    :param Optional[callable] eq: `callable` used to evaluate equality
        of two objects.
    :param Optional[callable] lt: `callable` used to evaluate whether
        one object is less than another object.
    :param Optional[callable] le: `callable` used to evaluate whether
        one object is less than or equal to another object.
    :param Optional[callable] gt: `callable` used to evaluate whether
        one object is greater than another object.
    :param Optional[callable] ge: `callable` used to evaluate whether
        one object is greater than or equal to another object.

    :param bool require_same_type: When `True`, equality and ordering methods
        will return `NotImplemented` if objects are not of the same type.

    :param Optional[str] class_name: Name of class. Defaults to 'Comparable'.

    See `comparison` for more details.

    .. versionadded:: 21.1.0
    """

    body = {
        "__slots__": ["value"],
        "__init__": _make_init(),
        "_requirements": [],
        "_is_comparable_to": _is_comparable_to,
    }

    # Add operations.
    num_order_functions = 0
    has_eq_function = False

    if eq is not None:
        has_eq_function = True
        body["__eq__"] = _make_operator("eq", eq)
        body["__ne__"] = _make_ne()

    if lt is not None:
        num_order_functions += 1
        body["__lt__"] = _make_operator("lt", lt)

    if le is not None:
        num_order_functions += 1
        body["__le__"] = _make_operator("le", le)

    if gt is not None:
        num_order_functions += 1
        body["__gt__"] = _make_operator("gt", gt)

    if ge is not None:
        num_order_functions += 1
        body["__ge__"] = _make_operator("ge", ge)

    type_ = new_class(class_name, (object,), {}, lambda ns: ns.update(body))

    # Add same type requirement.
    if require_same_type:
        type_._requirements.append(_check_same_type)

    # Add total ordering if at least one operation was defined.
    if 0 < num_order_functions < 4:
        if not has_eq_function:
            # functools.total_ordering requires __eq__ to be defined,
            # so raise early error here to keep a nice stack.
            raise ValueError(
                "eq must be defined in order to complete ordering from "
                "lt, le, gt, ge."
            )
        type_ = functools.total_ordering(type_)

    return type_


def _make_init():
    """
    Create __init__ method.
    """

    def __init__(self, value):
        """
        Initialize object with *value*.
        """
        self.value = value

    return __init__


def _make_operator(name, func):
    """
    Create operator method.
    """

    def method(self, other):
        if not self._is_comparable_to(other):
            return NotImplemented

        result = func(self.value, other.value)
        if result is NotImplemented:
            return NotImplemented

        return result

    method.__name__ = "__%s__" % (name,)
    method.__doc__ = "Return a %s b. Computed by attrs." % (
        _operation_names[name],
    )

    return method


def _is_comparable_to(self, other):
    """
    Check whether `other` is comparable to `self`.
    """
    for func in self._requirements:
        if not func(self, other):
            return False
    return True


def _check_same_type(self, other):
    """
    Return True if *self* and *other* are of the same type, False otherwise.
    """
    return other.value.__class__ is self.value.__class__
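A minimal sketch of what `cmp_using` is for, assuming attrs >= 21.1.0 (where it and the `attr.ib(eq=...)` hookup were introduced); the `CaseInsensitive` class is a hypothetical example:

import attr

@attr.s
class CaseInsensitive(object):
    # cmp_using builds the comparator class consumed by attr.ib(eq=...).
    name = attr.ib(eq=attr.cmp_using(eq=lambda a, b: a.lower() == b.lower()))

assert CaseInsensitive("Fusion") == CaseInsensitive("fusion")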
@@ -1,14 +0,0 @@
from typing import Optional, Type

from . import _CompareWithType


def cmp_using(
    eq: Optional[_CompareWithType],
    lt: Optional[_CompareWithType],
    le: Optional[_CompareWithType],
    gt: Optional[_CompareWithType],
    ge: Optional[_CompareWithType],
    require_same_type: bool,
    class_name: str,
) -> Type: ...
@@ -1,242 +0,0 @@
from __future__ import absolute_import, division, print_function

import platform
import sys
import types
import warnings


PY2 = sys.version_info[0] == 2
PYPY = platform.python_implementation() == "PyPy"


if PYPY or sys.version_info[:2] >= (3, 6):
    ordered_dict = dict
else:
    from collections import OrderedDict

    ordered_dict = OrderedDict


if PY2:
    from collections import Mapping, Sequence

    from UserDict import IterableUserDict

    # We 'bundle' isclass instead of using inspect as importing inspect is
    # fairly expensive (order of 10-15 ms for a modern machine in 2016)
    def isclass(klass):
        return isinstance(klass, (type, types.ClassType))

    def new_class(name, bases, kwds, exec_body):
        """
        A minimal stub of types.new_class that we need for make_class.
        """
        ns = {}
        exec_body(ns)

        return type(name, bases, ns)

    # TYPE is used in exceptions, repr(int) is different on Python 2 and 3.
    TYPE = "type"

    def iteritems(d):
        return d.iteritems()

    # Python 2 is bereft of a read-only dict proxy, so we make one!
    class ReadOnlyDict(IterableUserDict):
        """
        Best-effort read-only dict wrapper.
        """

        def __setitem__(self, key, val):
            # We gently pretend we're a Python 3 mappingproxy.
            raise TypeError(
                "'mappingproxy' object does not support item assignment"
            )

        def update(self, _):
            # We gently pretend we're a Python 3 mappingproxy.
            raise AttributeError(
                "'mappingproxy' object has no attribute 'update'"
            )

        def __delitem__(self, _):
            # We gently pretend we're a Python 3 mappingproxy.
            raise TypeError(
                "'mappingproxy' object does not support item deletion"
            )

        def clear(self):
            # We gently pretend we're a Python 3 mappingproxy.
            raise AttributeError(
                "'mappingproxy' object has no attribute 'clear'"
            )

        def pop(self, key, default=None):
            # We gently pretend we're a Python 3 mappingproxy.
            raise AttributeError(
                "'mappingproxy' object has no attribute 'pop'"
            )

        def popitem(self):
            # We gently pretend we're a Python 3 mappingproxy.
            raise AttributeError(
                "'mappingproxy' object has no attribute 'popitem'"
            )

        def setdefault(self, key, default=None):
            # We gently pretend we're a Python 3 mappingproxy.
            raise AttributeError(
                "'mappingproxy' object has no attribute 'setdefault'"
            )

        def __repr__(self):
            # Override to be identical to the Python 3 version.
            return "mappingproxy(" + repr(self.data) + ")"

    def metadata_proxy(d):
        res = ReadOnlyDict()
        res.data.update(d)  # We blocked update, so we have to do it like this.
        return res

    def just_warn(*args, **kw):  # pragma: no cover
        """
        We only warn on Python 3 because we are not aware of any concrete
        consequences of not setting the cell on Python 2.
        """


else:  # Python 3 and later.
    from collections.abc import Mapping, Sequence  # noqa

    def just_warn(*args, **kw):
        """
        We only warn on Python 3 because we are not aware of any concrete
        consequences of not setting the cell on Python 2.
        """
        warnings.warn(
            "Running interpreter doesn't sufficiently support code object "
            "introspection. Some features like bare super() or accessing "
            "__class__ will not work with slotted classes.",
            RuntimeWarning,
            stacklevel=2,
        )

    def isclass(klass):
        return isinstance(klass, type)

    TYPE = "class"

    def iteritems(d):
        return d.items()

    new_class = types.new_class

    def metadata_proxy(d):
        return types.MappingProxyType(dict(d))


def make_set_closure_cell():
    """Return a function of two arguments (cell, value) which sets
    the value stored in the closure cell `cell` to `value`.
    """
    # pypy makes this easy. (It also supports the logic below, but
    # why not do the easy/fast thing?)
    if PYPY:

        def set_closure_cell(cell, value):
            cell.__setstate__((value,))

        return set_closure_cell

    # Otherwise gotta do it the hard way.

    # Create a function that will set its first cellvar to `value`.
    def set_first_cellvar_to(value):
        x = value
        return

        # This function will be eliminated as dead code, but
        # not before its reference to `x` forces `x` to be
        # represented as a closure cell rather than a local.
        def force_x_to_be_a_cell():  # pragma: no cover
            return x

    try:
        # Extract the code object and make sure our assumptions about
        # the closure behavior are correct.
        if PY2:
            co = set_first_cellvar_to.func_code
        else:
            co = set_first_cellvar_to.__code__
        if co.co_cellvars != ("x",) or co.co_freevars != ():
            raise AssertionError  # pragma: no cover

        # Convert this code object to a code object that sets the
        # function's first _freevar_ (not cellvar) to the argument.
        if sys.version_info >= (3, 8):
            # CPython 3.8+ has an incompatible CodeType signature
            # (added a posonlyargcount argument) but also added
            # CodeType.replace() to do this without counting parameters.
            set_first_freevar_code = co.replace(
                co_cellvars=co.co_freevars, co_freevars=co.co_cellvars
            )
        else:
            args = [co.co_argcount]
            if not PY2:
                args.append(co.co_kwonlyargcount)
            args.extend(
                [
                    co.co_nlocals,
                    co.co_stacksize,
                    co.co_flags,
                    co.co_code,
                    co.co_consts,
                    co.co_names,
                    co.co_varnames,
                    co.co_filename,
                    co.co_name,
                    co.co_firstlineno,
                    co.co_lnotab,
                    # These two arguments are reversed:
                    co.co_cellvars,
                    co.co_freevars,
                ]
            )
            set_first_freevar_code = types.CodeType(*args)

        def set_closure_cell(cell, value):
            # Create a function using the set_first_freevar_code,
            # whose first closure cell is `cell`. Calling it will
            # change the value of that cell.
            setter = types.FunctionType(
                set_first_freevar_code, {}, "setter", (), (cell,)
            )
            # And call it to set the cell.
            setter(value)

        # Make sure it works on this interpreter:
        def make_func_with_cell():
            x = None

            def func():
                return x  # pragma: no cover

            return func

        if PY2:
            cell = make_func_with_cell().func_closure[0]
        else:
            cell = make_func_with_cell().__closure__[0]
        set_closure_cell(cell, 100)
        if cell.cell_contents != 100:
            raise AssertionError  # pragma: no cover

    except Exception:
        return just_warn
    else:
        return set_closure_cell


set_closure_cell = make_set_closure_cell()
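The cell-rewriting machinery above exists because closure cells were read-only on older interpreters; on CPython 3.7+, `cell_contents` is directly assignable, which is the effect `set_closure_cell` emulates. A minimal sketch of that effect (the `make_getter` function is a hypothetical example, assuming CPython 3.7+):

def make_getter():
    captured = 0

    def get():
        return captured  # closes over `captured`

    return get

g = make_getter()
g.__closure__[0].cell_contents = 41  # rebind the captured variable
assert g() == 41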
@@ -1,23 +0,0 @@
from __future__ import absolute_import, division, print_function


__all__ = ["set_run_validators", "get_run_validators"]

_run_validators = True


def set_run_validators(run):
    """
    Set whether or not validators are run. By default, they are run.
    """
    if not isinstance(run, bool):
        raise TypeError("'run' must be bool.")
    global _run_validators
    _run_validators = run


def get_run_validators():
    """
    Return whether or not validators are run.
    """
    return _run_validators
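A minimal sketch of how this global switch is used, e.g. to skip validation when bulk-loading data that is already known to be valid (the `C` class is a hypothetical example):

import attr

@attr.s
class C(object):
    x = attr.ib(validator=attr.validators.instance_of(int))

attr.set_run_validators(False)  # skip validators from here on
C("not an int")                 # accepted without raising
attr.set_run_validators(True)   # restore normal behavior
assert attr.get_run_validators() is True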
@@ -1,395 +0,0 @@
from __future__ import absolute_import, division, print_function

import copy

from ._compat import iteritems
from ._make import NOTHING, _obj_setattr, fields
from .exceptions import AttrsAttributeNotFoundError


def asdict(
    inst,
    recurse=True,
    filter=None,
    dict_factory=dict,
    retain_collection_types=False,
    value_serializer=None,
):
    """
    Return the ``attrs`` attribute values of *inst* as a dict.

    Optionally recurse into other ``attrs``-decorated classes.

    :param inst: Instance of an ``attrs``-decorated class.
    :param bool recurse: Recurse into classes that are also
        ``attrs``-decorated.
    :param callable filter: A callable whose return code determines whether an
        attribute or element is included (``True``) or dropped (``False``). Is
        called with the `attr.Attribute` as the first argument and the
        value as the second argument.
    :param callable dict_factory: A callable to produce dictionaries from. For
        example, to produce ordered dictionaries instead of normal Python
        dictionaries, pass in ``collections.OrderedDict``.
    :param bool retain_collection_types: Do not convert to ``list`` when
        encountering an attribute whose type is ``tuple`` or ``set``. Only
        meaningful if ``recurse`` is ``True``.
    :param Optional[callable] value_serializer: A hook that is called for every
        attribute or dict key/value. It receives the current instance, field
        and value and must return the (updated) value. The hook is run *after*
        the optional *filter* has been applied.

    :rtype: return type of *dict_factory*

    :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
        class.

    .. versionadded:: 16.0.0 *dict_factory*
    .. versionadded:: 16.1.0 *retain_collection_types*
    .. versionadded:: 20.3.0 *value_serializer*
    """
    attrs = fields(inst.__class__)
    rv = dict_factory()
    for a in attrs:
        v = getattr(inst, a.name)
        if filter is not None and not filter(a, v):
            continue

        if value_serializer is not None:
            v = value_serializer(inst, a, v)

        if recurse is True:
            if has(v.__class__):
                rv[a.name] = asdict(
                    v,
                    True,
                    filter,
                    dict_factory,
                    retain_collection_types,
                    value_serializer,
                )
            elif isinstance(v, (tuple, list, set, frozenset)):
                cf = v.__class__ if retain_collection_types is True else list
                rv[a.name] = cf(
                    [
                        _asdict_anything(
                            i,
                            filter,
                            dict_factory,
                            retain_collection_types,
                            value_serializer,
                        )
                        for i in v
                    ]
                )
            elif isinstance(v, dict):
                df = dict_factory
                rv[a.name] = df(
                    (
                        _asdict_anything(
                            kk,
                            filter,
                            df,
                            retain_collection_types,
                            value_serializer,
                        ),
                        _asdict_anything(
                            vv,
                            filter,
                            df,
                            retain_collection_types,
                            value_serializer,
                        ),
                    )
                    for kk, vv in iteritems(v)
                )
            else:
                rv[a.name] = v
        else:
            rv[a.name] = v
    return rv


def _asdict_anything(
    val,
    filter,
    dict_factory,
    retain_collection_types,
    value_serializer,
):
    """
    ``asdict`` only works on attrs instances, this works on anything.
    """
    if getattr(val.__class__, "__attrs_attrs__", None) is not None:
        # Attrs class.
        rv = asdict(
            val,
            True,
            filter,
            dict_factory,
            retain_collection_types,
            value_serializer,
        )
    elif isinstance(val, (tuple, list, set, frozenset)):
        cf = val.__class__ if retain_collection_types is True else list
        rv = cf(
            [
                _asdict_anything(
                    i,
                    filter,
                    dict_factory,
                    retain_collection_types,
                    value_serializer,
                )
                for i in val
            ]
        )
    elif isinstance(val, dict):
        df = dict_factory
        rv = df(
            (
                _asdict_anything(
                    kk, filter, df, retain_collection_types, value_serializer
                ),
                _asdict_anything(
                    vv, filter, df, retain_collection_types, value_serializer
                ),
            )
            for kk, vv in iteritems(val)
        )
    else:
        rv = val
        if value_serializer is not None:
            rv = value_serializer(None, None, rv)

    return rv


def astuple(
    inst,
    recurse=True,
    filter=None,
    tuple_factory=tuple,
    retain_collection_types=False,
):
    """
    Return the ``attrs`` attribute values of *inst* as a tuple.

    Optionally recurse into other ``attrs``-decorated classes.

    :param inst: Instance of an ``attrs``-decorated class.
    :param bool recurse: Recurse into classes that are also
        ``attrs``-decorated.
    :param callable filter: A callable whose return code determines whether an
        attribute or element is included (``True``) or dropped (``False``). Is
        called with the `attr.Attribute` as the first argument and the
        value as the second argument.
    :param callable tuple_factory: A callable to produce tuples from. For
        example, to produce lists instead of tuples.
    :param bool retain_collection_types: Do not convert to ``list``
        or ``dict`` when encountering an attribute which type is
        ``tuple``, ``dict`` or ``set``. Only meaningful if ``recurse`` is
        ``True``.

    :rtype: return type of *tuple_factory*

    :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
        class.

    .. versionadded:: 16.2.0
    """
    attrs = fields(inst.__class__)
    rv = []
    retain = retain_collection_types  # Very long. :/
    for a in attrs:
        v = getattr(inst, a.name)
        if filter is not None and not filter(a, v):
            continue
        if recurse is True:
            if has(v.__class__):
                rv.append(
                    astuple(
                        v,
                        recurse=True,
                        filter=filter,
                        tuple_factory=tuple_factory,
                        retain_collection_types=retain,
                    )
                )
            elif isinstance(v, (tuple, list, set, frozenset)):
                cf = v.__class__ if retain is True else list
                rv.append(
                    cf(
                        [
                            astuple(
                                j,
                                recurse=True,
                                filter=filter,
                                tuple_factory=tuple_factory,
                                retain_collection_types=retain,
                            )
                            if has(j.__class__)
                            else j
                            for j in v
                        ]
                    )
                )
            elif isinstance(v, dict):
                df = v.__class__ if retain is True else dict
                rv.append(
                    df(
                        (
                            astuple(
                                kk,
                                tuple_factory=tuple_factory,
                                retain_collection_types=retain,
                            )
                            if has(kk.__class__)
                            else kk,
                            astuple(
                                vv,
                                tuple_factory=tuple_factory,
                                retain_collection_types=retain,
                            )
                            if has(vv.__class__)
                            else vv,
                        )
                        for kk, vv in iteritems(v)
                    )
                )
            else:
                rv.append(v)
        else:
            rv.append(v)

    return rv if tuple_factory is list else tuple_factory(rv)


def has(cls):
    """
    Check whether *cls* is a class with ``attrs`` attributes.

    :param type cls: Class to introspect.
    :raise TypeError: If *cls* is not a class.

    :rtype: bool
    """
    return getattr(cls, "__attrs_attrs__", None) is not None


def assoc(inst, **changes):
    """
    Copy *inst* and apply *changes*.

    :param inst: Instance of a class with ``attrs`` attributes.
    :param changes: Keyword changes in the new copy.

    :return: A copy of inst with *changes* incorporated.

    :raise attr.exceptions.AttrsAttributeNotFoundError: If *attr_name* couldn't
        be found on *cls*.
    :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
        class.

    .. deprecated:: 17.1.0
        Use `evolve` instead.
    """
    import warnings

    warnings.warn(
        "assoc is deprecated and will be removed after 2018/01.",
        DeprecationWarning,
        stacklevel=2,
    )
    new = copy.copy(inst)
    attrs = fields(inst.__class__)
    for k, v in iteritems(changes):
        a = getattr(attrs, k, NOTHING)
        if a is NOTHING:
            raise AttrsAttributeNotFoundError(
                "{k} is not an attrs attribute on {cl}.".format(
                    k=k, cl=new.__class__
                )
            )
        _obj_setattr(new, k, v)
    return new


def evolve(inst, **changes):
    """
    Create a new instance, based on *inst* with *changes* applied.

    :param inst: Instance of a class with ``attrs`` attributes.
    :param changes: Keyword changes in the new copy.

    :return: A copy of inst with *changes* incorporated.

    :raise TypeError: If *attr_name* couldn't be found in the class
        ``__init__``.
    :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
        class.

    .. versionadded:: 17.1.0
    """
    cls = inst.__class__
    attrs = fields(cls)
    for a in attrs:
        if not a.init:
            continue
        attr_name = a.name  # To deal with private attributes.
        init_name = attr_name if attr_name[0] != "_" else attr_name[1:]
        if init_name not in changes:
            changes[init_name] = getattr(inst, attr_name)

    return cls(**changes)


def resolve_types(cls, globalns=None, localns=None, attribs=None):
    """
    Resolve any strings and forward annotations in type annotations.

    This is only required if you need concrete types in `Attribute`'s *type*
    field. In other words, you don't need to resolve your types if you only
    use them for static type checking.

    With no arguments, names will be looked up in the module in which the class
    was created. If this is not what you want, e.g. if the name only exists
    inside a method, you may pass *globalns* or *localns* to specify other
    dictionaries in which to look up these names. See the docs of
    `typing.get_type_hints` for more details.

    :param type cls: Class to resolve.
    :param Optional[dict] globalns: Dictionary containing global variables.
    :param Optional[dict] localns: Dictionary containing local variables.
    :param Optional[list] attribs: List of attribs for the given class.
        This is necessary when calling from inside a ``field_transformer``
        since *cls* is not an ``attrs`` class yet.

    :raise TypeError: If *cls* is not a class.
    :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
        class and you didn't pass any attribs.
    :raise NameError: If types cannot be resolved because of missing variables.

    :returns: *cls* so you can use this function also as a class decorator.
        Please note that you have to apply it **after** `attr.s`. That means
        the decorator has to come in the line **before** `attr.s`.

    .. versionadded:: 20.1.0
    .. versionadded:: 21.1.0 *attribs*

    """
    try:
        # Since calling get_type_hints is expensive we cache whether we've
        # done it already.
        cls.__attrs_types_resolved__
    except AttributeError:
        import typing

        hints = typing.get_type_hints(cls, globalns=globalns, localns=localns)
        for field in fields(cls) if attribs is None else attribs:
            if field.name in hints:
                # Since fields have been frozen we must work around it.
                _obj_setattr(field, "type", hints[field.name])
        cls.__attrs_types_resolved__ = True

    # Return the class so you can use it as a decorator too.
    return cls
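A minimal sketch of `asdict`/`astuple` with a *filter*, the pattern the functions above implement (the `Credentials` class is a hypothetical example):

import attr

@attr.s(auto_attribs=True)
class Credentials:
    user: str
    token: str

c = Credentials("admin", "secret")
public = attr.asdict(c, filter=lambda field, value: field.name != "token")
assert public == {"user": "admin"}
assert attr.astuple(c) == ("admin", "secret")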
File diff suppressed because it is too large
@@ -1,158 +0,0 @@
"""
These are Python 3.6+-only and keyword-only APIs that call `attr.s` and
`attr.ib` with different default values.
"""

from functools import partial

from attr.exceptions import UnannotatedAttributeError

from . import setters
from ._make import NOTHING, _frozen_setattrs, attrib, attrs


def define(
    maybe_cls=None,
    *,
    these=None,
    repr=None,
    hash=None,
    init=None,
    slots=True,
    frozen=False,
    weakref_slot=True,
    str=False,
    auto_attribs=None,
    kw_only=False,
    cache_hash=False,
    auto_exc=True,
    eq=None,
    order=False,
    auto_detect=True,
    getstate_setstate=None,
    on_setattr=None,
    field_transformer=None,
):
    r"""
    The only behavioral differences are the handling of the *auto_attribs*
    option:

    :param Optional[bool] auto_attribs: If set to `True` or `False`, it behaves
       exactly like `attr.s`. If left `None`, `attr.s` will try to guess:

       1. If any attributes are annotated and no unannotated `attr.ib`\ s
          are found, it assumes *auto_attribs=True*.
       2. Otherwise it assumes *auto_attribs=False* and tries to collect
          `attr.ib`\ s.

    and that mutable classes (``frozen=False``) validate on ``__setattr__``.

    .. versionadded:: 20.1.0
    """

    def do_it(cls, auto_attribs):
        return attrs(
            maybe_cls=cls,
            these=these,
            repr=repr,
            hash=hash,
            init=init,
            slots=slots,
            frozen=frozen,
            weakref_slot=weakref_slot,
            str=str,
            auto_attribs=auto_attribs,
            kw_only=kw_only,
            cache_hash=cache_hash,
            auto_exc=auto_exc,
            eq=eq,
            order=order,
            auto_detect=auto_detect,
            collect_by_mro=True,
            getstate_setstate=getstate_setstate,
            on_setattr=on_setattr,
            field_transformer=field_transformer,
        )

    def wrap(cls):
        """
        Making this a wrapper ensures this code runs during class creation.

        We also ensure that frozen-ness of classes is inherited.
        """
        nonlocal frozen, on_setattr

        had_on_setattr = on_setattr not in (None, setters.NO_OP)

        # By default, mutable classes validate on setattr.
        if frozen is False and on_setattr is None:
            on_setattr = setters.validate

        # However, if we subclass a frozen class, we inherit the immutability
        # and disable on_setattr.
        for base_cls in cls.__bases__:
            if base_cls.__setattr__ is _frozen_setattrs:
                if had_on_setattr:
                    raise ValueError(
                        "Frozen classes can't use on_setattr "
                        "(frozen-ness was inherited)."
                    )

                on_setattr = setters.NO_OP
                break

        if auto_attribs is not None:
            return do_it(cls, auto_attribs)

        try:
            return do_it(cls, True)
        except UnannotatedAttributeError:
            return do_it(cls, False)

    # maybe_cls's type depends on the usage of the decorator. It's a class
    # if it's used as `@attrs` but ``None`` if used as `@attrs()`.
    if maybe_cls is None:
        return wrap
    else:
        return wrap(maybe_cls)


mutable = define
frozen = partial(define, frozen=True, on_setattr=None)


def field(
    *,
    default=NOTHING,
    validator=None,
    repr=True,
    hash=None,
    init=True,
    metadata=None,
    converter=None,
    factory=None,
    kw_only=False,
    eq=None,
    order=None,
    on_setattr=None,
):
    """
    Identical to `attr.ib`, except keyword-only and with some arguments
    removed.

    .. versionadded:: 20.1.0
    """
    return attrib(
        default=default,
        validator=validator,
        repr=repr,
        hash=hash,
        init=init,
        metadata=metadata,
        converter=converter,
        factory=factory,
        kw_only=kw_only,
        eq=eq,
        order=order,
        on_setattr=on_setattr,
    )
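A minimal sketch of this next-generation API, assuming attrs >= 20.1.0 on Python 3.6+ (the `Task` class is a hypothetical example); per the `wrap` logic above, mutable `define` classes run validators on assignment:

import attr

@attr.define  # slotted by default; auto_attribs is guessed from annotations
class Task:
    name: str
    retries: int = attr.field(default=3)

t = Task("render")
t.retries = 1  # validated/converted on __setattr__ for mutable classes
assert (t.name, t.retries) == ("render", 1)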
@@ -1,85 +0,0 @@
from __future__ import absolute_import, division, print_function

from functools import total_ordering

from ._funcs import astuple
from ._make import attrib, attrs


@total_ordering
@attrs(eq=False, order=False, slots=True, frozen=True)
class VersionInfo(object):
    """
    A version object that can be compared to tuple of length 1--4:

    >>> attr.VersionInfo(19, 1, 0, "final") <= (19, 2)
    True
    >>> attr.VersionInfo(19, 1, 0, "final") < (19, 1, 1)
    True
    >>> vi = attr.VersionInfo(19, 2, 0, "final")
    >>> vi < (19, 1, 1)
    False
    >>> vi < (19,)
    False
    >>> vi == (19, 2,)
    True
    >>> vi == (19, 2, 1)
    False

    .. versionadded:: 19.2
    """

    year = attrib(type=int)
    minor = attrib(type=int)
    micro = attrib(type=int)
    releaselevel = attrib(type=str)

    @classmethod
    def _from_version_string(cls, s):
        """
        Parse *s* and return a _VersionInfo.
        """
        v = s.split(".")
        if len(v) == 3:
            v.append("final")

        return cls(
            year=int(v[0]), minor=int(v[1]), micro=int(v[2]), releaselevel=v[3]
        )

    def _ensure_tuple(self, other):
        """
        Ensure *other* is a tuple of a valid length.

        Returns a possibly transformed *other* and ourselves as a tuple of
        the same length as *other*.
        """

        if self.__class__ is other.__class__:
            other = astuple(other)

        if not isinstance(other, tuple):
            raise NotImplementedError

        if not (1 <= len(other) <= 4):
            raise NotImplementedError

        return astuple(self)[: len(other)], other

    def __eq__(self, other):
        try:
            us, them = self._ensure_tuple(other)
        except NotImplementedError:
            return NotImplemented

        return us == them

    def __lt__(self, other):
        try:
            us, them = self._ensure_tuple(other)
        except NotImplementedError:
            return NotImplemented

        # Since alphabetically "dev0" < "final" < "post1" < "post2", we don't
        # have to do anything special with releaselevel for now.
        return us < them
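A minimal sketch of what `VersionInfo` enables: comparing the installed attrs version against plain tuples without string parsing (assuming attrs >= 19.2.0, where `__version_info__` was added):

import attr

if attr.__version_info__ >= (20, 1):
    from attr import define  # next-gen API is available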
@@ -1,9 +0,0 @@
class VersionInfo:
    @property
    def year(self) -> int: ...
    @property
    def minor(self) -> int: ...
    @property
    def micro(self) -> int: ...
    @property
    def releaselevel(self) -> str: ...
@@ -1,111 +0,0 @@
"""
Commonly useful converters.
"""

from __future__ import absolute_import, division, print_function

from ._compat import PY2
from ._make import NOTHING, Factory, pipe


if not PY2:
    import inspect
    import typing


__all__ = [
    "pipe",
    "optional",
    "default_if_none",
]


def optional(converter):
    """
    A converter that allows an attribute to be optional. An optional attribute
    is one which can be set to ``None``.

    Type annotations will be inferred from the wrapped converter's, if it
    has any.

    :param callable converter: the converter that is used for non-``None``
        values.

    .. versionadded:: 17.1.0
    """

    def optional_converter(val):
        if val is None:
            return None
        return converter(val)

    if not PY2:
        sig = None
        try:
            sig = inspect.signature(converter)
        except (ValueError, TypeError):  # inspect failed
            pass
        if sig:
            params = list(sig.parameters.values())
            if params and params[0].annotation is not inspect.Parameter.empty:
                optional_converter.__annotations__["val"] = typing.Optional[
                    params[0].annotation
                ]
            if sig.return_annotation is not inspect.Signature.empty:
                optional_converter.__annotations__["return"] = typing.Optional[
                    sig.return_annotation
                ]

    return optional_converter


def default_if_none(default=NOTHING, factory=None):
    """
    A converter that allows to replace ``None`` values by *default* or the
    result of *factory*.

    :param default: Value to be used if ``None`` is passed. Passing an instance
       of `attr.Factory` is supported, however the ``takes_self`` option
       is *not*.
    :param callable factory: A callable that takes no parameters whose result
       is used if ``None`` is passed.

    :raises TypeError: If **neither** *default* or *factory* is passed.
    :raises TypeError: If **both** *default* and *factory* are passed.
    :raises ValueError: If an instance of `attr.Factory` is passed with
       ``takes_self=True``.

    .. versionadded:: 18.2.0
    """
    if default is NOTHING and factory is None:
        raise TypeError("Must pass either `default` or `factory`.")

    if default is not NOTHING and factory is not None:
        raise TypeError(
            "Must pass either `default` or `factory` but not both."
        )

    if factory is not None:
        default = Factory(factory)

    if isinstance(default, Factory):
        if default.takes_self:
            raise ValueError(
                "`takes_self` is not supported by default_if_none."
            )

        def default_if_none_converter(val):
            if val is not None:
                return val

            return default.factory()

    else:

        def default_if_none_converter(val):
            if val is not None:
                return val

            return default

    return default_if_none_converter
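A minimal sketch of the converters above combined with `pipe` (the `Retry` class is a hypothetical example):

import attr

@attr.s
class Retry(object):
    # Accept "5", 5, or None; None falls back to 0 before int() runs.
    count = attr.ib(
        converter=attr.converters.pipe(
            attr.converters.default_if_none(0),
            int,
        )
    )

assert Retry(None).count == 0
assert Retry("5").count == 5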
@@ -1,13 +0,0 @@
from typing import Callable, Optional, TypeVar, overload

from . import _ConverterType


_T = TypeVar("_T")


def pipe(*validators: _ConverterType) -> _ConverterType: ...
def optional(converter: _ConverterType) -> _ConverterType: ...
@overload
def default_if_none(default: _T) -> _ConverterType: ...
@overload
def default_if_none(*, factory: Callable[[], _T]) -> _ConverterType: ...
@@ -1,92 +0,0 @@
from __future__ import absolute_import, division, print_function


class FrozenError(AttributeError):
    """
    A frozen/immutable instance or attribute has been attempted to be
    modified.

    It mirrors the behavior of ``namedtuples`` by using the same error message
    and subclassing `AttributeError`.

    .. versionadded:: 20.1.0
    """

    msg = "can't set attribute"
    args = [msg]


class FrozenInstanceError(FrozenError):
    """
    A frozen instance has been attempted to be modified.

    .. versionadded:: 16.1.0
    """


class FrozenAttributeError(FrozenError):
    """
    A frozen attribute has been attempted to be modified.

    .. versionadded:: 20.1.0
    """


class AttrsAttributeNotFoundError(ValueError):
    """
    An ``attrs`` function couldn't find an attribute that the user asked for.

    .. versionadded:: 16.2.0
    """


class NotAnAttrsClassError(ValueError):
    """
    A non-``attrs`` class has been passed into an ``attrs`` function.

    .. versionadded:: 16.2.0
    """


class DefaultAlreadySetError(RuntimeError):
    """
    A default has been set using ``attr.ib()`` and is attempted to be reset
    using the decorator.

    .. versionadded:: 17.1.0
    """


class UnannotatedAttributeError(RuntimeError):
    """
    A class with ``auto_attribs=True`` has an ``attr.ib()`` without a type
    annotation.

    .. versionadded:: 17.3.0
    """


class PythonTooOldError(RuntimeError):
    """
    It was attempted to use an ``attrs`` feature that requires a newer Python
    version.

    .. versionadded:: 18.2.0
    """


class NotCallableError(TypeError):
    """
    An ``attr.ib()`` requiring a callable has been set with a value
    that is not callable.

    .. versionadded:: 19.2.0
    """

    def __init__(self, msg, value):
        super(TypeError, self).__init__(msg, value)
        self.msg = msg
        self.value = value

    def __str__(self):
        return str(self.msg)
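A minimal sketch of the most commonly encountered of these exceptions (the `Locked` class is a hypothetical example):

import attr

@attr.s(frozen=True)
class Locked(object):
    x = attr.ib()

try:
    Locked(1).x = 2
except attr.exceptions.FrozenInstanceError as e:
    assert e.msg == "can't set attribute"  # mirrors namedtuple's message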
@@ -1,18 +0,0 @@
from typing import Any


class FrozenError(AttributeError):
    msg: str = ...

class FrozenInstanceError(FrozenError): ...
class FrozenAttributeError(FrozenError): ...
class AttrsAttributeNotFoundError(ValueError): ...
class NotAnAttrsClassError(ValueError): ...
class DefaultAlreadySetError(RuntimeError): ...
class UnannotatedAttributeError(RuntimeError): ...
class PythonTooOldError(RuntimeError): ...

class NotCallableError(TypeError):
    msg: str = ...
    value: Any = ...
    def __init__(self, msg: str, value: Any) -> None: ...
@@ -1,52 +0,0 @@
"""
Commonly useful filters for `attr.asdict`.
"""

from __future__ import absolute_import, division, print_function

from ._compat import isclass
from ._make import Attribute


def _split_what(what):
    """
    Returns a tuple of `frozenset`s of classes and attributes.
    """
    return (
        frozenset(cls for cls in what if isclass(cls)),
        frozenset(cls for cls in what if isinstance(cls, Attribute)),
    )


def include(*what):
    """
    Whitelist *what*.

    :param what: What to whitelist.
    :type what: `list` of `type` or `attr.Attribute`\\ s

    :rtype: `callable`
    """
    cls, attrs = _split_what(what)

    def include_(attribute, value):
        return value.__class__ in cls or attribute in attrs

    return include_


def exclude(*what):
    """
    Blacklist *what*.

    :param what: What to blacklist.
    :type what: `list` of classes or `attr.Attribute`\\ s.

    :rtype: `callable`
    """
    cls, attrs = _split_what(what)

    def exclude_(attribute, value):
        return value.__class__ not in cls and attribute not in attrs

    return exclude_
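A minimal sketch of these filters in combination with `attr.asdict` (the `User` class is a hypothetical example):

import attr

@attr.s(auto_attribs=True)
class User:
    name: str
    password: str

u = User("jane", "hunter2")
safe = attr.asdict(u, filter=attr.filters.exclude(attr.fields(User).password))
assert safe == {"name": "jane"}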
@@ -1,7 +0,0 @@
from typing import Any, Union

from . import Attribute, _FilterType


def include(*what: Union[type, Attribute[Any]]) -> _FilterType[Any]: ...
def exclude(*what: Union[type, Attribute[Any]]) -> _FilterType[Any]: ...
@@ -1,77 +0,0 @@
"""
Commonly used hooks for on_setattr.
"""

from __future__ import absolute_import, division, print_function

from . import _config
from .exceptions import FrozenAttributeError


def pipe(*setters):
    """
    Run all *setters* and return the return value of the last one.

    .. versionadded:: 20.1.0
    """

    def wrapped_pipe(instance, attrib, new_value):
        rv = new_value

        for setter in setters:
            rv = setter(instance, attrib, rv)

        return rv

    return wrapped_pipe


def frozen(_, __, ___):
    """
    Prevent an attribute from being modified.

    .. versionadded:: 20.1.0
    """
    raise FrozenAttributeError()


def validate(instance, attrib, new_value):
    """
    Run *attrib*'s validator on *new_value* if it has one.

    .. versionadded:: 20.1.0
    """
    if _config._run_validators is False:
        return new_value

    v = attrib.validator
    if not v:
        return new_value

    v(instance, attrib, new_value)

    return new_value


def convert(instance, attrib, new_value):
    """
    Run *attrib*'s converter -- if it has one -- on *new_value* and return the
    result.

    .. versionadded:: 20.1.0
    """
    c = attrib.converter
    if c:
        return c(new_value)

    return new_value


NO_OP = object()
"""
Sentinel for disabling class-wide *on_setattr* hooks for certain attributes.

Does not work in `pipe` or within lists.

.. versionadded:: 20.1.0
"""
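A minimal sketch of chaining these hooks on a single attribute, assuming attrs >= 20.1.0 (the `Job` class is a hypothetical example):

import attr

@attr.s
class Job(object):
    priority = attr.ib(
        converter=int,
        validator=attr.validators.instance_of(int),
        # Re-run converter and validator on every assignment.
        on_setattr=attr.setters.pipe(
            attr.setters.convert, attr.setters.validate
        ),
    )

j = Job("5")      # converted during __init__
j.priority = "7"  # converted and validated again on assignment
assert j.priority == 7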
@@ -1,20 +0,0 @@
from typing import Any, NewType, NoReturn, TypeVar, cast

from . import Attribute, _OnSetAttrType


_T = TypeVar("_T")


def frozen(
    instance: Any, attribute: Attribute[Any], new_value: Any
) -> NoReturn: ...
def pipe(*setters: _OnSetAttrType) -> _OnSetAttrType: ...
def validate(instance: Any, attribute: Attribute[_T], new_value: _T) -> _T: ...

# convert is allowed to return Any, because they can be chained using pipe.
def convert(
    instance: Any, attribute: Attribute[Any], new_value: Any
) -> Any: ...

_NoOpType = NewType("_NoOpType", object)
NO_OP: _NoOpType
@ -1,379 +0,0 @@
|
|||
"""
|
||||
Commonly useful validators.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import re
|
||||
|
||||
from ._make import _AndValidator, and_, attrib, attrs
|
||||
from .exceptions import NotCallableError
|
||||
|
||||
|
||||
__all__ = [
|
||||
"and_",
|
||||
"deep_iterable",
|
||||
"deep_mapping",
|
||||
"in_",
|
||||
"instance_of",
|
||||
"is_callable",
|
||||
"matches_re",
|
||||
"optional",
|
||||
"provides",
|
||||
]
|
||||
|
||||
|
||||
@attrs(repr=False, slots=True, hash=True)
|
||||
class _InstanceOfValidator(object):
|
||||
type = attrib()
|
||||
|
||||
def __call__(self, inst, attr, value):
|
||||
"""
|
||||
We use a callable class to be able to change the ``__repr__``.
|
||||
"""
|
||||
if not isinstance(value, self.type):
|
||||
raise TypeError(
|
||||
"'{name}' must be {type!r} (got {value!r} that is a "
|
||||
"{actual!r}).".format(
|
||||
name=attr.name,
|
||||
type=self.type,
|
||||
actual=value.__class__,
|
||||
value=value,
|
||||
),
|
||||
attr,
|
||||
self.type,
|
||||
value,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<instance_of validator for type {type!r}>".format(
|
||||
type=self.type
|
||||
)
|
||||
|
||||
|
||||
def instance_of(type):
|
||||
"""
|
||||
A validator that raises a `TypeError` if the initializer is called
|
||||
with a wrong type for this particular attribute (checks are performed using
|
||||
`isinstance` therefore it's also valid to pass a tuple of types).
|
||||
|
||||
:param type: The type to check for.
|
||||
:type type: type or tuple of types
|
||||
|
||||
:raises TypeError: With a human readable error message, the attribute
|
||||
(of type `attr.Attribute`), the expected type, and the value it
|
||||
got.
|
||||
"""
|
||||
return _InstanceOfValidator(type)
|
||||
|
||||
|
||||
@attrs(repr=False, frozen=True, slots=True)
|
||||
class _MatchesReValidator(object):
|
||||
regex = attrib()
|
||||
flags = attrib()
|
||||
match_func = attrib()
|
||||
|
||||
def __call__(self, inst, attr, value):
|
||||
"""
|
||||
We use a callable class to be able to change the ``__repr__``.
|
||||
"""
|
||||
if not self.match_func(value):
|
||||
raise ValueError(
|
||||
"'{name}' must match regex {regex!r}"
|
||||
" ({value!r} doesn't)".format(
|
||||
name=attr.name, regex=self.regex.pattern, value=value
|
||||
),
|
||||
attr,
|
||||
self.regex,
|
||||
value,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<matches_re validator for pattern {regex!r}>".format(
|
||||
regex=self.regex
|
||||
)
|
||||
|
||||
|
||||
def matches_re(regex, flags=0, func=None):
|
||||
r"""
|
||||
A validator that raises `ValueError` if the initializer is called
|
||||
with a string that doesn't match *regex*.
|
||||
|
||||
:param str regex: a regex string to match against
|
||||
:param int flags: flags that will be passed to the underlying re function
|
||||
(default 0)
|
||||
:param callable func: which underlying `re` function to call (options
|
||||
are `re.fullmatch`, `re.search`, `re.match`, default
|
||||
is ``None`` which means either `re.fullmatch` or an emulation of
|
||||
it on Python 2). For performance reasons, they won't be used directly
|
||||
but on a pre-`re.compile`\ ed pattern.
|
||||
|
||||
.. versionadded:: 19.2.0
|
||||
"""
|
||||
fullmatch = getattr(re, "fullmatch", None)
|
||||
valid_funcs = (fullmatch, None, re.search, re.match)
|
||||
if func not in valid_funcs:
|
||||
raise ValueError(
|
||||
"'func' must be one of %s."
|
||||
% (
|
||||
", ".join(
|
||||
sorted(
|
||||
e and e.__name__ or "None" for e in set(valid_funcs)
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
pattern = re.compile(regex, flags)
|
||||
if func is re.match:
|
||||
match_func = pattern.match
|
||||
elif func is re.search:
|
||||
match_func = pattern.search
|
||||
else:
|
||||
if fullmatch:
|
||||
match_func = pattern.fullmatch
|
||||
else:
|
||||
pattern = re.compile(r"(?:{})\Z".format(regex), flags)
|
||||
match_func = pattern.match
|
||||
|
||||
return _MatchesReValidator(pattern, flags, match_func)
|
||||
|
||||
|
||||
@attrs(repr=False, slots=True, hash=True)
|
||||
class _ProvidesValidator(object):
|
||||
interface = attrib()
|
||||
|
||||
def __call__(self, inst, attr, value):
|
||||
"""
|
||||
We use a callable class to be able to change the ``__repr__``.
|
||||
"""
|
||||
if not self.interface.providedBy(value):
|
||||
raise TypeError(
|
||||
"'{name}' must provide {interface!r} which {value!r} "
|
||||
"doesn't.".format(
|
||||
name=attr.name, interface=self.interface, value=value
|
||||
),
|
||||
attr,
|
||||
self.interface,
|
||||
value,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<provides validator for interface {interface!r}>".format(
|
||||
interface=self.interface
|
||||
)
|
||||
|
||||
|
||||
def provides(interface):
|
||||
"""
|
||||
A validator that raises a `TypeError` if the initializer is called
|
||||
with an object that does not provide the requested *interface* (checks are
|
||||
performed using ``interface.providedBy(value)`` (see `zope.interface
|
||||
<https://zopeinterface.readthedocs.io/en/latest/>`_).
|
||||
|
||||
:param interface: The interface to check for.
|
||||
:type interface: ``zope.interface.Interface``
|
||||
|
||||
:raises TypeError: With a human readable error message, the attribute
|
||||
(of type `attr.Attribute`), the expected interface, and the
|
||||
value it got.
|
||||
"""
|
||||
return _ProvidesValidator(interface)
|
||||
|
||||
|
||||
@attrs(repr=False, slots=True, hash=True)
|
||||
class _OptionalValidator(object):
|
||||
validator = attrib()
|
||||
|
||||
def __call__(self, inst, attr, value):
|
||||
if value is None:
|
||||
return
|
||||
|
||||
self.validator(inst, attr, value)
|
||||
|
||||
def __repr__(self):
|
||||
return "<optional validator for {what} or None>".format(
|
||||
what=repr(self.validator)
|
||||
)
|
||||
|
||||
|
||||
def optional(validator):
|
||||
"""
|
||||
A validator that makes an attribute optional. An optional attribute is one
|
||||
which can be set to ``None`` in addition to satisfying the requirements of
|
||||
the sub-validator.
|
||||
|
||||
:param validator: A validator (or a list of validators) that is used for
|
||||
non-``None`` values.
|
||||
:type validator: callable or `list` of callables.
|
||||
|
||||
.. versionadded:: 15.1.0
|
||||
.. versionchanged:: 17.1.0 *validator* can be a list of validators.
|
||||
"""
|
||||
if isinstance(validator, list):
|
||||
return _OptionalValidator(_AndValidator(validator))
|
||||
return _OptionalValidator(validator)
|
||||
|
||||
|
||||
@attrs(repr=False, slots=True, hash=True)
|
||||
class _InValidator(object):
|
||||
options = attrib()
|
||||
|
||||
def __call__(self, inst, attr, value):
|
||||
try:
|
||||
in_options = value in self.options
|
||||
except TypeError: # e.g. `1 in "abc"`
|
||||
in_options = False
|
||||
|
||||
if not in_options:
|
||||
raise ValueError(
|
||||
"'{name}' must be in {options!r} (got {value!r})".format(
|
||||
name=attr.name, options=self.options, value=value
|
||||
)
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<in_ validator with options {options!r}>".format(
|
||||
options=self.options
|
||||
)
|
||||
|
||||
|
||||
def in_(options):
|
||||
"""
|
||||
A validator that raises a `ValueError` if the initializer is called
|
||||
with a value that does not belong in the options provided. The check is
|
||||
performed using ``value in options``.
|
||||
|
||||
:param options: Allowed options.
|
||||
:type options: list, tuple, `enum.Enum`, ...
|
||||
|
||||
:raises ValueError: With a human readable error message, the attribute (of
|
||||
type `attr.Attribute`), the expected options, and the value it
|
||||
got.
|
||||
|
||||
.. versionadded:: 17.1.0
|
||||
"""
|
||||
return _InValidator(options)


@attrs(repr=False, slots=False, hash=True)
class _IsCallableValidator(object):
    def __call__(self, inst, attr, value):
        """
        We use a callable class to be able to change the ``__repr__``.
        """
        if not callable(value):
            message = (
                "'{name}' must be callable "
                "(got {value!r} that is a {actual!r})."
            )
            raise NotCallableError(
                msg=message.format(
                    name=attr.name, value=value, actual=value.__class__
                ),
                value=value,
            )

    def __repr__(self):
        return "<is_callable validator>"


def is_callable():
    """
    A validator that raises a `attr.exceptions.NotCallableError` if the
    initializer is called with a value for this particular attribute
    that is not callable.

    .. versionadded:: 19.1.0

    :raises `attr.exceptions.NotCallableError`: With a human readable error
        message containing the attribute (`attr.Attribute`) name,
        and the value it got.
    """
    return _IsCallableValidator()


@attrs(repr=False, slots=True, hash=True)
class _DeepIterable(object):
    member_validator = attrib(validator=is_callable())
    iterable_validator = attrib(
        default=None, validator=optional(is_callable())
    )

    def __call__(self, inst, attr, value):
        """
        We use a callable class to be able to change the ``__repr__``.
        """
        if self.iterable_validator is not None:
            self.iterable_validator(inst, attr, value)

        for member in value:
            self.member_validator(inst, attr, member)

    def __repr__(self):
        iterable_identifier = (
            ""
            if self.iterable_validator is None
            else " {iterable!r}".format(iterable=self.iterable_validator)
        )
        return (
            "<deep_iterable validator for{iterable_identifier}"
            " iterables of {member!r}>"
        ).format(
            iterable_identifier=iterable_identifier,
            member=self.member_validator,
        )


def deep_iterable(member_validator, iterable_validator=None):
    """
    A validator that performs deep validation of an iterable.

    :param member_validator: Validator to apply to iterable members
    :param iterable_validator: Validator to apply to iterable itself
        (optional)

    .. versionadded:: 19.1.0

    :raises TypeError: if any sub-validators fail
    """
    return _DeepIterable(member_validator, iterable_validator)
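

# Illustrative usage sketch (editor's addition): validate both the container
# type and every member of the iterable.
def _example_deep_iterable():
    @attrs
    class Pool(object):
        hosts = attrib(
            validator=deep_iterable(
                member_validator=instance_of(str),
                iterable_validator=instance_of(list),
            )
        )

    Pool(["a.example", "b.example"])  # ok
    # Pool(("a.example",)) would fail the iterable_validator (tuple, not list)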


@attrs(repr=False, slots=True, hash=True)
class _DeepMapping(object):
    key_validator = attrib(validator=is_callable())
    value_validator = attrib(validator=is_callable())
    mapping_validator = attrib(default=None, validator=optional(is_callable()))

    def __call__(self, inst, attr, value):
        """
        We use a callable class to be able to change the ``__repr__``.
        """
        if self.mapping_validator is not None:
            self.mapping_validator(inst, attr, value)

        for key in value:
            self.key_validator(inst, attr, key)
            self.value_validator(inst, attr, value[key])

    def __repr__(self):
        return (
            "<deep_mapping validator for objects mapping {key!r} to {value!r}>"
        ).format(key=self.key_validator, value=self.value_validator)


def deep_mapping(key_validator, value_validator, mapping_validator=None):
    """
    A validator that performs deep validation of a dictionary.

    :param key_validator: Validator to apply to dictionary keys
    :param value_validator: Validator to apply to dictionary values
    :param mapping_validator: Validator to apply to top-level mapping
        attribute (optional)

    .. versionadded:: 19.1.0

    :raises TypeError: if any sub-validators fail
    """
    return _DeepMapping(key_validator, value_validator, mapping_validator)
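

# Illustrative usage sketch (editor's addition): validate keys and values of a
# mapping independently.
def _example_deep_mapping():
    @attrs
    class Headers(object):
        fields = attrib(
            validator=deep_mapping(
                key_validator=instance_of(str),
                value_validator=instance_of(str),
            )
        )

    Headers({"Accept": "*/*"})  # ok
    # Headers({"Retries": 3}) would raise TypeError from the value_validator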

@@ -1,68 +0,0 @@
from typing import (
    Any,
    AnyStr,
    Callable,
    Container,
    Iterable,
    List,
    Mapping,
    Match,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    overload,
)

from . import _ValidatorType


_T = TypeVar("_T")
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_T3 = TypeVar("_T3")
_I = TypeVar("_I", bound=Iterable)
_K = TypeVar("_K")
_V = TypeVar("_V")
_M = TypeVar("_M", bound=Mapping)

# To be more precise on instance_of use some overloads.
# If there are more than 3 items in the tuple then we fall back to Any
@overload
def instance_of(type: Type[_T]) -> _ValidatorType[_T]: ...
@overload
def instance_of(type: Tuple[Type[_T]]) -> _ValidatorType[_T]: ...
@overload
def instance_of(
    type: Tuple[Type[_T1], Type[_T2]]
) -> _ValidatorType[Union[_T1, _T2]]: ...
@overload
def instance_of(
    type: Tuple[Type[_T1], Type[_T2], Type[_T3]]
) -> _ValidatorType[Union[_T1, _T2, _T3]]: ...
@overload
def instance_of(type: Tuple[type, ...]) -> _ValidatorType[Any]: ...
def provides(interface: Any) -> _ValidatorType[Any]: ...
def optional(
    validator: Union[_ValidatorType[_T], List[_ValidatorType[_T]]]
) -> _ValidatorType[Optional[_T]]: ...
def in_(options: Container[_T]) -> _ValidatorType[_T]: ...
def and_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ...
def matches_re(
    regex: AnyStr,
    flags: int = ...,
    func: Optional[
        Callable[[AnyStr, AnyStr, int], Optional[Match[AnyStr]]]
    ] = ...,
) -> _ValidatorType[AnyStr]: ...
def deep_iterable(
    member_validator: _ValidatorType[_T],
    iterable_validator: Optional[_ValidatorType[_I]] = ...,
) -> _ValidatorType[_I]: ...
def deep_mapping(
    key_validator: _ValidatorType[_K],
    value_validator: _ValidatorType[_V],
    mapping_validator: Optional[_ValidatorType[_M]] = ...,
) -> _ValidatorType[_M]: ...
def is_callable() -> _ValidatorType[_T]: ...

@@ -1,85 +0,0 @@
"""
Python HTTP library with thread-safe connection pooling, file post support, user friendly, and more
"""
from __future__ import absolute_import

# Set default logging handler to avoid "No handler found" warnings.
import logging
import warnings
from logging import NullHandler

from . import exceptions
from ._version import __version__
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, connection_from_url
from .filepost import encode_multipart_formdata
from .poolmanager import PoolManager, ProxyManager, proxy_from_url
from .response import HTTPResponse
from .util.request import make_headers
from .util.retry import Retry
from .util.timeout import Timeout
from .util.url import get_host

__author__ = "Andrey Petrov (andrey.petrov@shazow.net)"
__license__ = "MIT"
__version__ = __version__

__all__ = (
    "HTTPConnectionPool",
    "HTTPSConnectionPool",
    "PoolManager",
    "ProxyManager",
    "HTTPResponse",
    "Retry",
    "Timeout",
    "add_stderr_logger",
    "connection_from_url",
    "disable_warnings",
    "encode_multipart_formdata",
    "get_host",
    "make_headers",
    "proxy_from_url",
)

logging.getLogger(__name__).addHandler(NullHandler())


def add_stderr_logger(level=logging.DEBUG):
    """
    Helper for quickly adding a StreamHandler to the logger. Useful for
    debugging.

    Returns the handler after adding it.
    """
    # This method needs to be in this __init__.py to get the __name__ correct
    # even if urllib3 is vendored within another package.
    logger = logging.getLogger(__name__)
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
    logger.addHandler(handler)
    logger.setLevel(level)
    logger.debug("Added a stderr logging handler to logger: %s", __name__)
    return handler
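

# Illustrative usage sketch (editor's addition): attach the debug handler
# while diagnosing connection problems, then detach it when done. The logger
# name is this module's ``__name__`` (plain "urllib3" when not vendored).
def _example_debug_logging():
    handler = add_stderr_logger(logging.DEBUG)
    # ... perform requests here ...
    logging.getLogger(__name__).removeHandler(handler)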


# ... Clean up.
del NullHandler


# All warning filters *must* be appended unless you're really certain that they
# shouldn't be: otherwise, it's very hard for users to use most Python
# mechanisms to silence them.
# SecurityWarning's always go off by default.
warnings.simplefilter("always", exceptions.SecurityWarning, append=True)
# SubjectAltNameWarning's should go off once per host
warnings.simplefilter("default", exceptions.SubjectAltNameWarning, append=True)
# InsecurePlatformWarning's don't vary between requests, so we keep it default.
warnings.simplefilter("default", exceptions.InsecurePlatformWarning, append=True)
# SNIMissingWarnings should go off only once.
warnings.simplefilter("default", exceptions.SNIMissingWarning, append=True)


def disable_warnings(category=exceptions.HTTPWarning):
    """
    Helper for quickly disabling all urllib3 warnings.
    """
    warnings.simplefilter("ignore", category)
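

# Illustrative usage sketch (editor's addition): silence one specific warning
# category rather than every HTTPWarning.
def _example_silence_insecure_request_warning():
    disable_warnings(exceptions.InsecureRequestWarning)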

@@ -1,337 +0,0 @@
from __future__ import absolute_import

try:
    from collections.abc import Mapping, MutableMapping
except ImportError:
    from collections import Mapping, MutableMapping
try:
    from threading import RLock
except ImportError:  # Platform-specific: No threads available

    class RLock:
        def __enter__(self):
            pass

        def __exit__(self, exc_type, exc_value, traceback):
            pass


from collections import OrderedDict

from .exceptions import InvalidHeader
from .packages import six
from .packages.six import iterkeys, itervalues

__all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"]


_Null = object()


class RecentlyUsedContainer(MutableMapping):
    """
    Provides a thread-safe dict-like container which maintains up to
    ``maxsize`` keys while throwing away the least-recently-used keys beyond
    ``maxsize``.

    :param maxsize:
        Maximum number of recent elements to retain.

    :param dispose_func:
        Every time an item is evicted from the container,
        ``dispose_func(value)`` is called with the evicted value.
    """

    ContainerCls = OrderedDict

    def __init__(self, maxsize=10, dispose_func=None):
        self._maxsize = maxsize
        self.dispose_func = dispose_func

        self._container = self.ContainerCls()
        self.lock = RLock()

    def __getitem__(self, key):
        # Re-insert the item, moving it to the end of the eviction line.
        with self.lock:
            item = self._container.pop(key)
            self._container[key] = item
            return item

    def __setitem__(self, key, value):
        evicted_value = _Null
        with self.lock:
            # Possibly evict the existing value of 'key'
            evicted_value = self._container.get(key, _Null)
            self._container[key] = value

            # If we didn't evict an existing value, we might have to evict the
            # least recently used item from the beginning of the container.
            if len(self._container) > self._maxsize:
                _key, evicted_value = self._container.popitem(last=False)

        if self.dispose_func and evicted_value is not _Null:
            self.dispose_func(evicted_value)

    def __delitem__(self, key):
        with self.lock:
            value = self._container.pop(key)

        if self.dispose_func:
            self.dispose_func(value)

    def __len__(self):
        with self.lock:
            return len(self._container)

    def __iter__(self):
        raise NotImplementedError(
            "Iteration over this class is unlikely to be threadsafe."
        )

    def clear(self):
        with self.lock:
            # Copy pointers to all values, then wipe the mapping
            values = list(itervalues(self._container))
            self._container.clear()

        if self.dispose_func:
            for value in values:
                self.dispose_func(value)

    def keys(self):
        with self.lock:
            return list(iterkeys(self._container))
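

# Illustrative usage sketch (editor's addition): the container evicts the
# least-recently-used value once ``maxsize`` is exceeded and passes the
# evicted value to ``dispose_func``.
def _example_lru_eviction():
    evicted = []
    lru = RecentlyUsedContainer(maxsize=2, dispose_func=evicted.append)
    lru["a"] = 1
    lru["b"] = 2
    lru["a"]      # touching "a" makes "b" the eviction candidate
    lru["c"] = 3  # evicts "b"
    assert evicted == [2]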


class HTTPHeaderDict(MutableMapping):
    """
    :param headers:
        An iterable of field-value pairs. Must not contain multiple field names
        when compared case-insensitively.

    :param kwargs:
        Additional field-value pairs to pass in to ``dict.update``.

    A ``dict`` like container for storing HTTP Headers.

    Field names are stored and compared case-insensitively in compliance with
    RFC 7230. Iteration provides the first case-sensitive key seen for each
    case-insensitive pair.

    Using ``__setitem__`` syntax overwrites fields that compare equal
    case-insensitively in order to maintain ``dict``'s api. For fields that
    compare equal, instead create a new ``HTTPHeaderDict`` and use ``.add``
    in a loop.

    If multiple fields that are equal case-insensitively are passed to the
    constructor or ``.update``, the behavior is undefined and some will be
    lost.

    >>> headers = HTTPHeaderDict()
    >>> headers.add('Set-Cookie', 'foo=bar')
    >>> headers.add('set-cookie', 'baz=quxx')
    >>> headers['content-length'] = '7'
    >>> headers['SET-cookie']
    'foo=bar, baz=quxx'
    >>> headers['Content-Length']
    '7'
    """

    def __init__(self, headers=None, **kwargs):
        super(HTTPHeaderDict, self).__init__()
        self._container = OrderedDict()
        if headers is not None:
            if isinstance(headers, HTTPHeaderDict):
                self._copy_from(headers)
            else:
                self.extend(headers)
        if kwargs:
            self.extend(kwargs)

    def __setitem__(self, key, val):
        self._container[key.lower()] = [key, val]
        return self._container[key.lower()]

    def __getitem__(self, key):
        val = self._container[key.lower()]
        return ", ".join(val[1:])

    def __delitem__(self, key):
        del self._container[key.lower()]

    def __contains__(self, key):
        return key.lower() in self._container

    def __eq__(self, other):
        if not isinstance(other, Mapping) and not hasattr(other, "keys"):
            return False
        if not isinstance(other, type(self)):
            other = type(self)(other)
        return dict((k.lower(), v) for k, v in self.itermerged()) == dict(
            (k.lower(), v) for k, v in other.itermerged()
        )

    def __ne__(self, other):
        return not self.__eq__(other)

    if six.PY2:  # Python 2
        iterkeys = MutableMapping.iterkeys
        itervalues = MutableMapping.itervalues

    __marker = object()

    def __len__(self):
        return len(self._container)

    def __iter__(self):
        # Only provide the originally cased names
        for vals in self._container.values():
            yield vals[0]

    def pop(self, key, default=__marker):
        """D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
        If key is not found, d is returned if given, otherwise KeyError is raised.
        """
        # Using the MutableMapping function directly fails due to the private marker.
        # Using ordinary dict.pop would expose the internal structures.
        # So let's reinvent the wheel.
        try:
            value = self[key]
        except KeyError:
            if default is self.__marker:
                raise
            return default
        else:
            del self[key]
            return value

    def discard(self, key):
        try:
            del self[key]
        except KeyError:
            pass

    def add(self, key, val):
        """Adds a (name, value) pair, doesn't overwrite the value if it already
        exists.

        >>> headers = HTTPHeaderDict(foo='bar')
        >>> headers.add('Foo', 'baz')
        >>> headers['foo']
        'bar, baz'
        """
        key_lower = key.lower()
        new_vals = [key, val]
        # Keep the common case aka no item present as fast as possible
        vals = self._container.setdefault(key_lower, new_vals)
        if new_vals is not vals:
            vals.append(val)

    def extend(self, *args, **kwargs):
        """Generic import function for any type of header-like object.
        Adapted version of MutableMapping.update in order to insert items
        with self.add instead of self.__setitem__
        """
        if len(args) > 1:
            raise TypeError(
                "extend() takes at most 1 positional "
                "arguments ({0} given)".format(len(args))
            )
        other = args[0] if len(args) >= 1 else ()

        if isinstance(other, HTTPHeaderDict):
            for key, val in other.iteritems():
                self.add(key, val)
        elif isinstance(other, Mapping):
            for key in other:
                self.add(key, other[key])
        elif hasattr(other, "keys"):
            for key in other.keys():
                self.add(key, other[key])
        else:
            for key, value in other:
                self.add(key, value)

        for key, value in kwargs.items():
            self.add(key, value)

    def getlist(self, key, default=__marker):
        """Returns a list of all the values for the named field. Returns an
        empty list if the key doesn't exist."""
        try:
            vals = self._container[key.lower()]
        except KeyError:
            if default is self.__marker:
                return []
            return default
        else:
            return vals[1:]

    # Backwards compatibility for httplib
    getheaders = getlist
    getallmatchingheaders = getlist
    iget = getlist

    # Backwards compatibility for http.cookiejar
    get_all = getlist

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, dict(self.itermerged()))

    def _copy_from(self, other):
        for key in other:
            val = other.getlist(key)
            if isinstance(val, list):
                # Don't need to convert tuples
                val = list(val)
            self._container[key.lower()] = [key] + val

    def copy(self):
        clone = type(self)()
        clone._copy_from(self)
        return clone

    def iteritems(self):
        """Iterate over all header lines, including duplicate ones."""
        for key in self:
            vals = self._container[key.lower()]
            for val in vals[1:]:
                yield vals[0], val

    def itermerged(self):
        """Iterate over all headers, merging duplicate ones together."""
        for key in self:
            val = self._container[key.lower()]
            yield val[0], ", ".join(val[1:])

    def items(self):
        return list(self.iteritems())

    @classmethod
    def from_httplib(cls, message):  # Python 2
        """Read headers from a Python 2 httplib message object."""
        # python2.7 does not expose a proper API for exporting multiheaders
        # efficiently. This function re-reads raw lines from the message
        # object and extracts the multiheaders properly.
        obs_fold_continued_leaders = (" ", "\t")
        headers = []

        for line in message.headers:
            if line.startswith(obs_fold_continued_leaders):
                if not headers:
                    # We received a header line that starts with OWS as described
                    # in RFC-7230 S3.2.4. This indicates a multiline header, but
                    # there exists no previous header to which we can attach it.
                    raise InvalidHeader(
                        "Header continuation with no previous header: %s" % line
                    )
                else:
                    key, value = headers[-1]
                    headers[-1] = (key, value + " " + line.strip())
                    continue

            key, value = line.split(":", 1)
            headers.append((key, value.strip()))

        return cls(headers)
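

# Illustrative usage sketch (editor's addition): ``add`` accumulates repeated
# fields, ``getlist`` recovers the individual values, and iteration preserves
# the first-seen casing.
def _example_multi_value_headers():
    headers = HTTPHeaderDict()
    headers.add("Set-Cookie", "a=1")
    headers.add("set-cookie", "b=2")
    assert headers["SET-COOKIE"] == "a=1, b=2"
    assert headers.getlist("Set-Cookie") == ["a=1", "b=2"]
    assert list(headers) == ["Set-Cookie"]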

@@ -1,2 +0,0 @@
# This file is protected via CODEOWNERS
__version__ = "1.26.6"

@@ -1,539 +0,0 @@
from __future__ import absolute_import

import datetime
import logging
import os
import re
import socket
import warnings
from socket import error as SocketError
from socket import timeout as SocketTimeout

from .packages import six
from .packages.six.moves.http_client import HTTPConnection as _HTTPConnection
from .packages.six.moves.http_client import HTTPException  # noqa: F401
from .util.proxy import create_proxy_ssl_context

try:  # Compiled with SSL?
    import ssl

    BaseSSLError = ssl.SSLError
except (ImportError, AttributeError):  # Platform-specific: No SSL.
    ssl = None

    class BaseSSLError(BaseException):
        pass


try:
    # Python 3: not a no-op, we're adding this to the namespace so it can be imported.
    ConnectionError = ConnectionError
except NameError:
    # Python 2
    class ConnectionError(Exception):
        pass


try:  # Python 3:
    # Not a no-op, we're adding this to the namespace so it can be imported.
    BrokenPipeError = BrokenPipeError
except NameError:  # Python 2:

    class BrokenPipeError(Exception):
        pass


from ._collections import HTTPHeaderDict  # noqa (historical, removed in v2)
from ._version import __version__
from .exceptions import (
    ConnectTimeoutError,
    NewConnectionError,
    SubjectAltNameWarning,
    SystemTimeWarning,
)
from .packages.ssl_match_hostname import CertificateError, match_hostname
from .util import SKIP_HEADER, SKIPPABLE_HEADERS, connection
from .util.ssl_ import (
    assert_fingerprint,
    create_urllib3_context,
    resolve_cert_reqs,
    resolve_ssl_version,
    ssl_wrap_socket,
)

log = logging.getLogger(__name__)

port_by_scheme = {"http": 80, "https": 443}

# When it comes time to update this value as a part of regular maintenance
# (ie test_recent_date is failing) update it to ~6 months before the current date.
RECENT_DATE = datetime.date(2020, 7, 1)

_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]")


class HTTPConnection(_HTTPConnection, object):
    """
    Based on :class:`http.client.HTTPConnection` but provides an extra constructor
    backwards-compatibility layer between older and newer Pythons.

    Additional keyword parameters are used to configure attributes of the connection.
    Accepted parameters include:

    - ``strict``: See the documentation on :class:`urllib3.connectionpool.HTTPConnectionPool`
    - ``source_address``: Set the source address for the current connection.
    - ``socket_options``: Set specific options on the underlying socket. If not specified, then
      defaults are loaded from ``HTTPConnection.default_socket_options`` which includes disabling
      Nagle's algorithm (sets TCP_NODELAY to 1) unless the connection is behind a proxy.

      For example, if you wish to enable TCP Keep Alive in addition to the defaults,
      you might pass:

      .. code-block:: python

         HTTPConnection.default_socket_options + [
             (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
         ]

      Or you may want to disable the defaults by passing an empty list (e.g., ``[]``).
    """

    default_port = port_by_scheme["http"]

    #: Disable Nagle's algorithm by default.
    #: ``[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]``
    default_socket_options = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]

    #: Whether this connection verifies the host's certificate.
    is_verified = False

    def __init__(self, *args, **kw):
        if not six.PY2:
            kw.pop("strict", None)

        # Pre-set source_address.
        self.source_address = kw.get("source_address")

        #: The socket options provided by the user. If no options are
        #: provided, we use the default options.
        self.socket_options = kw.pop("socket_options", self.default_socket_options)

        # Proxy options provided by the user.
        self.proxy = kw.pop("proxy", None)
        self.proxy_config = kw.pop("proxy_config", None)

        _HTTPConnection.__init__(self, *args, **kw)

    @property
    def host(self):
        """
        Getter method to remove any trailing dots that indicate the hostname is an FQDN.

        In general, SSL certificates don't include the trailing dot indicating a
        fully-qualified domain name, and thus, they don't validate properly when
        checked against a domain name that includes the dot. In addition, some
        servers may not expect to receive the trailing dot when provided.

        However, the hostname with trailing dot is critical to DNS resolution; doing a
        lookup with the trailing dot will properly only resolve the appropriate FQDN,
        whereas a lookup without a trailing dot will search the system's search domain
        list. Thus, it's important to keep the original host around for use only in
        those cases where it's appropriate (i.e., when doing DNS lookup to establish the
        actual TCP connection across which we're going to send HTTP requests).
        """
        return self._dns_host.rstrip(".")

    @host.setter
    def host(self, value):
        """
        Setter for the `host` property.

        We assume that only urllib3 uses the _dns_host attribute; httplib itself
        only uses `host`, and it seems reasonable that other libraries follow suit.
        """
        self._dns_host = value
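
    # Illustrative note (editor's addition, not in the original file): the
    # property pair above lets the two attributes disagree by a trailing dot:
    #
    #     conn = HTTPConnection("example.com.", 80)
    #     conn.host       # "example.com"  -- used for Host headers and TLS
    #     conn._dns_host  # "example.com." -- used when opening the socket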

    def _new_conn(self):
        """Establish a socket connection and set nodelay settings on it.

        :return: New socket connection.
        """
        extra_kw = {}
        if self.source_address:
            extra_kw["source_address"] = self.source_address

        if self.socket_options:
            extra_kw["socket_options"] = self.socket_options

        try:
            conn = connection.create_connection(
                (self._dns_host, self.port), self.timeout, **extra_kw
            )

        except SocketTimeout:
            raise ConnectTimeoutError(
                self,
                "Connection to %s timed out. (connect timeout=%s)"
                % (self.host, self.timeout),
            )

        except SocketError as e:
            raise NewConnectionError(
                self, "Failed to establish a new connection: %s" % e
            )

        return conn

    def _is_using_tunnel(self):
        # Google App Engine's httplib does not define _tunnel_host
        return getattr(self, "_tunnel_host", None)

    def _prepare_conn(self, conn):
        self.sock = conn
        if self._is_using_tunnel():
            # TODO: Fix tunnel so it doesn't depend on self.sock state.
            self._tunnel()
            # Mark this connection as not reusable
            self.auto_open = 0

    def connect(self):
        conn = self._new_conn()
        self._prepare_conn(conn)

    def putrequest(self, method, url, *args, **kwargs):
        """ """
        # Empty docstring because the indentation of CPython's implementation
        # is broken but we don't want this method in our documentation.
        match = _CONTAINS_CONTROL_CHAR_RE.search(method)
        if match:
            raise ValueError(
                "Method cannot contain non-token characters %r (found at least %r)"
                % (method, match.group())
            )

        return _HTTPConnection.putrequest(self, method, url, *args, **kwargs)

    def putheader(self, header, *values):
        """ """
        if not any(isinstance(v, str) and v == SKIP_HEADER for v in values):
            _HTTPConnection.putheader(self, header, *values)
        elif six.ensure_str(header.lower()) not in SKIPPABLE_HEADERS:
            raise ValueError(
                "urllib3.util.SKIP_HEADER only supports '%s'"
                % ("', '".join(map(str.title, sorted(SKIPPABLE_HEADERS))),)
            )

    def request(self, method, url, body=None, headers=None):
        if headers is None:
            headers = {}
        else:
            # Avoid modifying the headers passed into .request()
            headers = headers.copy()
        if "user-agent" not in (six.ensure_str(k.lower()) for k in headers):
            headers["User-Agent"] = _get_default_user_agent()
        super(HTTPConnection, self).request(method, url, body=body, headers=headers)

    def request_chunked(self, method, url, body=None, headers=None):
        """
        Alternative to the common request method, which sends the
        body with chunked encoding and not as one block
        """
        headers = headers or {}
        header_keys = set([six.ensure_str(k.lower()) for k in headers])
        skip_accept_encoding = "accept-encoding" in header_keys
        skip_host = "host" in header_keys
        self.putrequest(
            method, url, skip_accept_encoding=skip_accept_encoding, skip_host=skip_host
        )
        if "user-agent" not in header_keys:
            self.putheader("User-Agent", _get_default_user_agent())
        for header, value in headers.items():
            self.putheader(header, value)
        if "transfer-encoding" not in header_keys:
            self.putheader("Transfer-Encoding", "chunked")
        self.endheaders()

        if body is not None:
            stringish_types = six.string_types + (bytes,)
            if isinstance(body, stringish_types):
                body = (body,)
            for chunk in body:
                if not chunk:
                    continue
                if not isinstance(chunk, bytes):
                    chunk = chunk.encode("utf8")
                len_str = hex(len(chunk))[2:]
                to_send = bytearray(len_str.encode())
                to_send += b"\r\n"
                to_send += chunk
                to_send += b"\r\n"
                self.send(to_send)

        # After the if clause, to always have a closed body
        self.send(b"0\r\n\r\n")


class HTTPSConnection(HTTPConnection):
    """
    Many of the parameters to this constructor are passed to the underlying SSL
    socket by means of :py:func:`urllib3.util.ssl_wrap_socket`.
    """

    default_port = port_by_scheme["https"]

    cert_reqs = None
    ca_certs = None
    ca_cert_dir = None
    ca_cert_data = None
    ssl_version = None
    assert_fingerprint = None
    tls_in_tls_required = False

    def __init__(
        self,
        host,
        port=None,
        key_file=None,
        cert_file=None,
        key_password=None,
        strict=None,
        timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
        ssl_context=None,
        server_hostname=None,
        **kw
    ):

        HTTPConnection.__init__(self, host, port, strict=strict, timeout=timeout, **kw)

        self.key_file = key_file
        self.cert_file = cert_file
        self.key_password = key_password
        self.ssl_context = ssl_context
        self.server_hostname = server_hostname

        # Required property for Google AppEngine 1.9.0 which otherwise causes
        # HTTPS requests to go out as HTTP. (See Issue #356)
        self._protocol = "https"

    def set_cert(
        self,
        key_file=None,
        cert_file=None,
        cert_reqs=None,
        key_password=None,
        ca_certs=None,
        assert_hostname=None,
        assert_fingerprint=None,
        ca_cert_dir=None,
        ca_cert_data=None,
    ):
        """
        This method should only be called once, before the connection is used.
        """
        # If cert_reqs is not provided we'll assume CERT_REQUIRED unless we also
        # have an SSLContext object in which case we'll use its verify_mode.
        if cert_reqs is None:
            if self.ssl_context is not None:
                cert_reqs = self.ssl_context.verify_mode
            else:
                cert_reqs = resolve_cert_reqs(None)

        self.key_file = key_file
        self.cert_file = cert_file
        self.cert_reqs = cert_reqs
        self.key_password = key_password
        self.assert_hostname = assert_hostname
        self.assert_fingerprint = assert_fingerprint
        self.ca_certs = ca_certs and os.path.expanduser(ca_certs)
        self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir)
        self.ca_cert_data = ca_cert_data

    def connect(self):
        # Add certificate verification
        conn = self._new_conn()
        hostname = self.host
        tls_in_tls = False

        if self._is_using_tunnel():
            if self.tls_in_tls_required:
                conn = self._connect_tls_proxy(hostname, conn)
                tls_in_tls = True

            self.sock = conn

            # Calls self._set_hostport(), so self.host is
            # self._tunnel_host below.
            self._tunnel()
            # Mark this connection as not reusable
            self.auto_open = 0

            # Override the host with the one we're requesting data from.
            hostname = self._tunnel_host

        server_hostname = hostname
        if self.server_hostname is not None:
            server_hostname = self.server_hostname

        is_time_off = datetime.date.today() < RECENT_DATE
        if is_time_off:
            warnings.warn(
                (
                    "System time is way off (before {0}). This will probably "
                    "lead to SSL verification errors"
                ).format(RECENT_DATE),
                SystemTimeWarning,
            )

        # Wrap socket using verification with the root certs in
        # trusted_root_certs
        default_ssl_context = False
        if self.ssl_context is None:
            default_ssl_context = True
            self.ssl_context = create_urllib3_context(
                ssl_version=resolve_ssl_version(self.ssl_version),
                cert_reqs=resolve_cert_reqs(self.cert_reqs),
            )

        context = self.ssl_context
        context.verify_mode = resolve_cert_reqs(self.cert_reqs)

        # Try to load OS default certs if none are given.
        # Works well on Windows (requires Python3.4+)
        if (
            not self.ca_certs
            and not self.ca_cert_dir
            and not self.ca_cert_data
            and default_ssl_context
            and hasattr(context, "load_default_certs")
        ):
            context.load_default_certs()

        self.sock = ssl_wrap_socket(
            sock=conn,
            keyfile=self.key_file,
            certfile=self.cert_file,
            key_password=self.key_password,
            ca_certs=self.ca_certs,
            ca_cert_dir=self.ca_cert_dir,
            ca_cert_data=self.ca_cert_data,
            server_hostname=server_hostname,
            ssl_context=context,
            tls_in_tls=tls_in_tls,
        )

        # If we're using all defaults and the connection
        # is TLSv1 or TLSv1.1 we throw a DeprecationWarning
        # for the host.
        if (
            default_ssl_context
            and self.ssl_version is None
            and hasattr(self.sock, "version")
            and self.sock.version() in {"TLSv1", "TLSv1.1"}
        ):
            warnings.warn(
                "Negotiating TLSv1/TLSv1.1 by default is deprecated "
                "and will be disabled in urllib3 v2.0.0. Connecting to "
                "'%s' with '%s' can be enabled by explicitly opting-in "
                "with 'ssl_version'" % (self.host, self.sock.version()),
                DeprecationWarning,
            )

        if self.assert_fingerprint:
            assert_fingerprint(
                self.sock.getpeercert(binary_form=True), self.assert_fingerprint
            )
        elif (
            context.verify_mode != ssl.CERT_NONE
            and not getattr(context, "check_hostname", False)
            and self.assert_hostname is not False
        ):
            # While urllib3 attempts to always turn off hostname matching from
            # the TLS library, this cannot always be done. So we check whether
            # the TLS Library still thinks it's matching hostnames.
            cert = self.sock.getpeercert()
            if not cert.get("subjectAltName", ()):
                warnings.warn(
                    (
                        "Certificate for {0} has no `subjectAltName`, falling back to check for a "
                        "`commonName` for now. This feature is being removed by major browsers and "
                        "deprecated by RFC 2818. (See https://github.com/urllib3/urllib3/issues/497 "
                        "for details.)".format(hostname)
                    ),
                    SubjectAltNameWarning,
                )
            _match_hostname(cert, self.assert_hostname or server_hostname)

        self.is_verified = (
            context.verify_mode == ssl.CERT_REQUIRED
            or self.assert_fingerprint is not None
        )

    def _connect_tls_proxy(self, hostname, conn):
        """
        Establish a TLS connection to the proxy using the provided SSL context.
        """
        proxy_config = self.proxy_config
        ssl_context = proxy_config.ssl_context
        if ssl_context:
            # If the user provided a proxy context, we assume CA and client
            # certificates have already been set
            return ssl_wrap_socket(
                sock=conn,
                server_hostname=hostname,
                ssl_context=ssl_context,
            )

        ssl_context = create_proxy_ssl_context(
            self.ssl_version,
            self.cert_reqs,
            self.ca_certs,
            self.ca_cert_dir,
            self.ca_cert_data,
        )
        # By default urllib3's SSLContext disables `check_hostname` and uses
        # a custom check. For proxies we're good with relying on the default
        # verification.
        ssl_context.check_hostname = True

        # If no cert was provided, use only the default options for server
        # certificate validation
        return ssl_wrap_socket(
            sock=conn,
            ca_certs=self.ca_certs,
            ca_cert_dir=self.ca_cert_dir,
            ca_cert_data=self.ca_cert_data,
            server_hostname=hostname,
            ssl_context=ssl_context,
        )


def _match_hostname(cert, asserted_hostname):
    try:
        match_hostname(cert, asserted_hostname)
    except CertificateError as e:
        log.warning(
            "Certificate did not match expected hostname: %s. Certificate: %s",
            asserted_hostname,
            cert,
        )
        # Add cert to exception and reraise so client code can inspect
        # the cert when catching the exception, if they want to
        e._peer_cert = cert
        raise


def _get_default_user_agent():
    return "python-urllib3/%s" % __version__


class DummyConnection(object):
    """Used to detect a failed ConnectionCls import."""

    pass


if not ssl:
    HTTPSConnection = DummyConnection  # noqa: F811


VerifiedHTTPSConnection = HTTPSConnection

File diff suppressed because it is too large

@@ -1,36 +0,0 @@
"""
This module provides means to detect the App Engine environment.
"""

import os


def is_appengine():
    return is_local_appengine() or is_prod_appengine()


def is_appengine_sandbox():
    """Reports if the app is running in the first generation sandbox.

    The second generation runtimes are technically still in a sandbox, but it
    is much less restrictive, so generally you shouldn't need to check for it.
    see https://cloud.google.com/appengine/docs/standard/runtimes
    """
    return is_appengine() and os.environ["APPENGINE_RUNTIME"] == "python27"


def is_local_appengine():
    return "APPENGINE_RUNTIME" in os.environ and os.environ.get(
        "SERVER_SOFTWARE", ""
    ).startswith("Development/")


def is_prod_appengine():
    return "APPENGINE_RUNTIME" in os.environ and os.environ.get(
        "SERVER_SOFTWARE", ""
    ).startswith("Google App Engine/")


def is_prod_appengine_mvms():
    """Deprecated."""
    return False
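

# Illustrative usage sketch (editor's addition): callers typically gate
# transport selection on these detectors, e.g.:
#
#     if is_appengine_sandbox():
#         ...  # fall back to an urlfetch-based transport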

@@ -1,519 +0,0 @@
"""
This module uses ctypes to bind a whole bunch of functions and constants from
SecureTransport. The goal here is to provide the low-level API to
SecureTransport. These are essentially the C-level functions and constants, and
they're pretty gross to work with.

This code is a bastardised version of the code found in Will Bond's oscrypto
library. An enormous debt is owed to him for blazing this trail for us. For
that reason, this code should be considered to be covered both by urllib3's
license and by oscrypto's:

    Copyright (c) 2015-2016 Will Bond <will@wbond.net>

    Permission is hereby granted, free of charge, to any person obtaining a
    copy of this software and associated documentation files (the "Software"),
    to deal in the Software without restriction, including without limitation
    the rights to use, copy, modify, merge, publish, distribute, sublicense,
    and/or sell copies of the Software, and to permit persons to whom the
    Software is furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in
    all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
    DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import

import platform
from ctypes import (
    CDLL,
    CFUNCTYPE,
    POINTER,
    c_bool,
    c_byte,
    c_char_p,
    c_int32,
    c_long,
    c_size_t,
    c_uint32,
    c_ulong,
    c_void_p,
)
from ctypes.util import find_library

from urllib3.packages.six import raise_from

if platform.system() != "Darwin":
    raise ImportError("Only macOS is supported")

version = platform.mac_ver()[0]
version_info = tuple(map(int, version.split(".")))
if version_info < (10, 8):
    raise OSError(
        "Only OS X 10.8 and newer are supported, not %s.%s"
        % (version_info[0], version_info[1])
    )


def load_cdll(name, macos10_16_path):
    """Loads a CDLL by name, falling back to known path on 10.16+"""
    try:
        # Big Sur is technically 11 but we use 10.16 due to the Big Sur
        # beta being labeled as 10.16.
        if version_info >= (10, 16):
            path = macos10_16_path
        else:
            path = find_library(name)
        if not path:
            raise OSError  # Caught and reraised as 'ImportError'
        return CDLL(path, use_errno=True)
    except OSError:
        raise_from(ImportError("The library %s failed to load" % name), None)


Security = load_cdll(
    "Security", "/System/Library/Frameworks/Security.framework/Security"
)
CoreFoundation = load_cdll(
    "CoreFoundation",
    "/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation",
)


Boolean = c_bool
CFIndex = c_long
CFStringEncoding = c_uint32
CFData = c_void_p
CFString = c_void_p
CFArray = c_void_p
CFMutableArray = c_void_p
CFDictionary = c_void_p
CFError = c_void_p
CFType = c_void_p
CFTypeID = c_ulong

CFTypeRef = POINTER(CFType)
CFAllocatorRef = c_void_p

OSStatus = c_int32

CFDataRef = POINTER(CFData)
CFStringRef = POINTER(CFString)
CFArrayRef = POINTER(CFArray)
CFMutableArrayRef = POINTER(CFMutableArray)
CFDictionaryRef = POINTER(CFDictionary)
CFArrayCallBacks = c_void_p
CFDictionaryKeyCallBacks = c_void_p
CFDictionaryValueCallBacks = c_void_p

SecCertificateRef = POINTER(c_void_p)
SecExternalFormat = c_uint32
SecExternalItemType = c_uint32
SecIdentityRef = POINTER(c_void_p)
SecItemImportExportFlags = c_uint32
SecItemImportExportKeyParameters = c_void_p
SecKeychainRef = POINTER(c_void_p)
SSLProtocol = c_uint32
SSLCipherSuite = c_uint32
SSLContextRef = POINTER(c_void_p)
SecTrustRef = POINTER(c_void_p)
SSLConnectionRef = c_uint32
SecTrustResultType = c_uint32
SecTrustOptionFlags = c_uint32
SSLProtocolSide = c_uint32
SSLConnectionType = c_uint32
SSLSessionOption = c_uint32


try:
    Security.SecItemImport.argtypes = [
        CFDataRef,
        CFStringRef,
        POINTER(SecExternalFormat),
        POINTER(SecExternalItemType),
        SecItemImportExportFlags,
        POINTER(SecItemImportExportKeyParameters),
        SecKeychainRef,
        POINTER(CFArrayRef),
    ]
    Security.SecItemImport.restype = OSStatus

    Security.SecCertificateGetTypeID.argtypes = []
    Security.SecCertificateGetTypeID.restype = CFTypeID

    Security.SecIdentityGetTypeID.argtypes = []
    Security.SecIdentityGetTypeID.restype = CFTypeID

    Security.SecKeyGetTypeID.argtypes = []
    Security.SecKeyGetTypeID.restype = CFTypeID

    Security.SecCertificateCreateWithData.argtypes = [CFAllocatorRef, CFDataRef]
    Security.SecCertificateCreateWithData.restype = SecCertificateRef

    Security.SecCertificateCopyData.argtypes = [SecCertificateRef]
    Security.SecCertificateCopyData.restype = CFDataRef

    Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
    Security.SecCopyErrorMessageString.restype = CFStringRef

    Security.SecIdentityCreateWithCertificate.argtypes = [
        CFTypeRef,
        SecCertificateRef,
        POINTER(SecIdentityRef),
    ]
    Security.SecIdentityCreateWithCertificate.restype = OSStatus

    Security.SecKeychainCreate.argtypes = [
        c_char_p,
        c_uint32,
        c_void_p,
        Boolean,
        c_void_p,
        POINTER(SecKeychainRef),
    ]
    Security.SecKeychainCreate.restype = OSStatus

    Security.SecKeychainDelete.argtypes = [SecKeychainRef]
    Security.SecKeychainDelete.restype = OSStatus

    Security.SecPKCS12Import.argtypes = [
        CFDataRef,
        CFDictionaryRef,
        POINTER(CFArrayRef),
    ]
    Security.SecPKCS12Import.restype = OSStatus

    SSLReadFunc = CFUNCTYPE(OSStatus, SSLConnectionRef, c_void_p, POINTER(c_size_t))
    SSLWriteFunc = CFUNCTYPE(
        OSStatus, SSLConnectionRef, POINTER(c_byte), POINTER(c_size_t)
    )

    Security.SSLSetIOFuncs.argtypes = [SSLContextRef, SSLReadFunc, SSLWriteFunc]
    Security.SSLSetIOFuncs.restype = OSStatus

    Security.SSLSetPeerID.argtypes = [SSLContextRef, c_char_p, c_size_t]
    Security.SSLSetPeerID.restype = OSStatus

    Security.SSLSetCertificate.argtypes = [SSLContextRef, CFArrayRef]
    Security.SSLSetCertificate.restype = OSStatus

    Security.SSLSetCertificateAuthorities.argtypes = [SSLContextRef, CFTypeRef, Boolean]
    Security.SSLSetCertificateAuthorities.restype = OSStatus

    Security.SSLSetConnection.argtypes = [SSLContextRef, SSLConnectionRef]
    Security.SSLSetConnection.restype = OSStatus

    Security.SSLSetPeerDomainName.argtypes = [SSLContextRef, c_char_p, c_size_t]
    Security.SSLSetPeerDomainName.restype = OSStatus

    Security.SSLHandshake.argtypes = [SSLContextRef]
    Security.SSLHandshake.restype = OSStatus

    Security.SSLRead.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
    Security.SSLRead.restype = OSStatus

    Security.SSLWrite.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
    Security.SSLWrite.restype = OSStatus

    Security.SSLClose.argtypes = [SSLContextRef]
    Security.SSLClose.restype = OSStatus

    Security.SSLGetNumberSupportedCiphers.argtypes = [SSLContextRef, POINTER(c_size_t)]
    Security.SSLGetNumberSupportedCiphers.restype = OSStatus

    Security.SSLGetSupportedCiphers.argtypes = [
        SSLContextRef,
        POINTER(SSLCipherSuite),
        POINTER(c_size_t),
    ]
    Security.SSLGetSupportedCiphers.restype = OSStatus

    Security.SSLSetEnabledCiphers.argtypes = [
        SSLContextRef,
        POINTER(SSLCipherSuite),
        c_size_t,
    ]
    Security.SSLSetEnabledCiphers.restype = OSStatus

    Security.SSLGetNumberEnabledCiphers.argtypes = [SSLContextRef, POINTER(c_size_t)]
    Security.SSLGetNumberEnabledCiphers.restype = OSStatus

    Security.SSLGetEnabledCiphers.argtypes = [
        SSLContextRef,
        POINTER(SSLCipherSuite),
        POINTER(c_size_t),
    ]
    Security.SSLGetEnabledCiphers.restype = OSStatus

    Security.SSLGetNegotiatedCipher.argtypes = [SSLContextRef, POINTER(SSLCipherSuite)]
    Security.SSLGetNegotiatedCipher.restype = OSStatus

    Security.SSLGetNegotiatedProtocolVersion.argtypes = [
        SSLContextRef,
        POINTER(SSLProtocol),
    ]
    Security.SSLGetNegotiatedProtocolVersion.restype = OSStatus

    Security.SSLCopyPeerTrust.argtypes = [SSLContextRef, POINTER(SecTrustRef)]
    Security.SSLCopyPeerTrust.restype = OSStatus

    Security.SecTrustSetAnchorCertificates.argtypes = [SecTrustRef, CFArrayRef]
    Security.SecTrustSetAnchorCertificates.restype = OSStatus

    Security.SecTrustSetAnchorCertificatesOnly.argtypes = [SecTrustRef, Boolean]
    Security.SecTrustSetAnchorCertificatesOnly.restype = OSStatus

    Security.SecTrustEvaluate.argtypes = [SecTrustRef, POINTER(SecTrustResultType)]
    Security.SecTrustEvaluate.restype = OSStatus

    Security.SecTrustGetCertificateCount.argtypes = [SecTrustRef]
    Security.SecTrustGetCertificateCount.restype = CFIndex

    Security.SecTrustGetCertificateAtIndex.argtypes = [SecTrustRef, CFIndex]
    Security.SecTrustGetCertificateAtIndex.restype = SecCertificateRef

    Security.SSLCreateContext.argtypes = [
        CFAllocatorRef,
        SSLProtocolSide,
        SSLConnectionType,
    ]
    Security.SSLCreateContext.restype = SSLContextRef

    Security.SSLSetSessionOption.argtypes = [SSLContextRef, SSLSessionOption, Boolean]
    Security.SSLSetSessionOption.restype = OSStatus

    Security.SSLSetProtocolVersionMin.argtypes = [SSLContextRef, SSLProtocol]
    Security.SSLSetProtocolVersionMin.restype = OSStatus

    Security.SSLSetProtocolVersionMax.argtypes = [SSLContextRef, SSLProtocol]
    Security.SSLSetProtocolVersionMax.restype = OSStatus

    try:
        Security.SSLSetALPNProtocols.argtypes = [SSLContextRef, CFArrayRef]
        Security.SSLSetALPNProtocols.restype = OSStatus
    except AttributeError:
        # Supported only in 10.12+
        pass

    Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
    Security.SecCopyErrorMessageString.restype = CFStringRef

    Security.SSLReadFunc = SSLReadFunc
    Security.SSLWriteFunc = SSLWriteFunc
    Security.SSLContextRef = SSLContextRef
    Security.SSLProtocol = SSLProtocol
    Security.SSLCipherSuite = SSLCipherSuite
    Security.SecIdentityRef = SecIdentityRef
    Security.SecKeychainRef = SecKeychainRef
    Security.SecTrustRef = SecTrustRef
    Security.SecTrustResultType = SecTrustResultType
    Security.SecExternalFormat = SecExternalFormat
    Security.OSStatus = OSStatus

    Security.kSecImportExportPassphrase = CFStringRef.in_dll(
        Security, "kSecImportExportPassphrase"
    )
    Security.kSecImportItemIdentity = CFStringRef.in_dll(
        Security, "kSecImportItemIdentity"
    )

    # CoreFoundation time!
    CoreFoundation.CFRetain.argtypes = [CFTypeRef]
    CoreFoundation.CFRetain.restype = CFTypeRef

    CoreFoundation.CFRelease.argtypes = [CFTypeRef]
    CoreFoundation.CFRelease.restype = None

    CoreFoundation.CFGetTypeID.argtypes = [CFTypeRef]
    CoreFoundation.CFGetTypeID.restype = CFTypeID

    CoreFoundation.CFStringCreateWithCString.argtypes = [
        CFAllocatorRef,
        c_char_p,
        CFStringEncoding,
    ]
    CoreFoundation.CFStringCreateWithCString.restype = CFStringRef

    CoreFoundation.CFStringGetCStringPtr.argtypes = [CFStringRef, CFStringEncoding]
    CoreFoundation.CFStringGetCStringPtr.restype = c_char_p

    CoreFoundation.CFStringGetCString.argtypes = [
        CFStringRef,
        c_char_p,
        CFIndex,
        CFStringEncoding,
    ]
    CoreFoundation.CFStringGetCString.restype = c_bool

    CoreFoundation.CFDataCreate.argtypes = [CFAllocatorRef, c_char_p, CFIndex]
    CoreFoundation.CFDataCreate.restype = CFDataRef

    CoreFoundation.CFDataGetLength.argtypes = [CFDataRef]
    CoreFoundation.CFDataGetLength.restype = CFIndex

    CoreFoundation.CFDataGetBytePtr.argtypes = [CFDataRef]
    CoreFoundation.CFDataGetBytePtr.restype = c_void_p

    CoreFoundation.CFDictionaryCreate.argtypes = [
        CFAllocatorRef,
        POINTER(CFTypeRef),
        POINTER(CFTypeRef),
        CFIndex,
        CFDictionaryKeyCallBacks,
        CFDictionaryValueCallBacks,
    ]
    CoreFoundation.CFDictionaryCreate.restype = CFDictionaryRef

    CoreFoundation.CFDictionaryGetValue.argtypes = [CFDictionaryRef, CFTypeRef]
    CoreFoundation.CFDictionaryGetValue.restype = CFTypeRef

    CoreFoundation.CFArrayCreate.argtypes = [
        CFAllocatorRef,
        POINTER(CFTypeRef),
        CFIndex,
        CFArrayCallBacks,
    ]
    CoreFoundation.CFArrayCreate.restype = CFArrayRef

    CoreFoundation.CFArrayCreateMutable.argtypes = [
        CFAllocatorRef,
        CFIndex,
        CFArrayCallBacks,
    ]
    CoreFoundation.CFArrayCreateMutable.restype = CFMutableArrayRef

    CoreFoundation.CFArrayAppendValue.argtypes = [CFMutableArrayRef, c_void_p]
    CoreFoundation.CFArrayAppendValue.restype = None

    CoreFoundation.CFArrayGetCount.argtypes = [CFArrayRef]
    CoreFoundation.CFArrayGetCount.restype = CFIndex

    CoreFoundation.CFArrayGetValueAtIndex.argtypes = [CFArrayRef, CFIndex]
    CoreFoundation.CFArrayGetValueAtIndex.restype = c_void_p

    CoreFoundation.kCFAllocatorDefault = CFAllocatorRef.in_dll(
        CoreFoundation, "kCFAllocatorDefault"
    )
    CoreFoundation.kCFTypeArrayCallBacks = c_void_p.in_dll(
        CoreFoundation, "kCFTypeArrayCallBacks"
    )
    CoreFoundation.kCFTypeDictionaryKeyCallBacks = c_void_p.in_dll(
        CoreFoundation, "kCFTypeDictionaryKeyCallBacks"
    )
    CoreFoundation.kCFTypeDictionaryValueCallBacks = c_void_p.in_dll(
        CoreFoundation, "kCFTypeDictionaryValueCallBacks"
    )

    CoreFoundation.CFTypeRef = CFTypeRef
    CoreFoundation.CFArrayRef = CFArrayRef
    CoreFoundation.CFStringRef = CFStringRef
    CoreFoundation.CFDictionaryRef = CFDictionaryRef

except (AttributeError):
    raise ImportError("Error initializing ctypes")


class CFConst(object):
    """
    A class object that acts as essentially a namespace for CoreFoundation
    constants.
    """

    kCFStringEncodingUTF8 = CFStringEncoding(0x08000100)


class SecurityConst(object):
    """
    A class object that acts as essentially a namespace for Security constants.
    """

    kSSLSessionOptionBreakOnServerAuth = 0

    kSSLProtocol2 = 1
    kSSLProtocol3 = 2
    kTLSProtocol1 = 4
    kTLSProtocol11 = 7
    kTLSProtocol12 = 8
    # SecureTransport does not support TLS 1.3 even if there's a constant for it
    kTLSProtocol13 = 10
    kTLSProtocolMaxSupported = 999

    kSSLClientSide = 1
    kSSLStreamType = 0

    kSecFormatPEMSequence = 10

    kSecTrustResultInvalid = 0
    kSecTrustResultProceed = 1
    # This gap is present on purpose: this was kSecTrustResultConfirm, which
    # is deprecated.
    kSecTrustResultDeny = 3
    kSecTrustResultUnspecified = 4
    kSecTrustResultRecoverableTrustFailure = 5
    kSecTrustResultFatalTrustFailure = 6
    kSecTrustResultOtherError = 7

    errSSLProtocol = -9800
    errSSLWouldBlock = -9803
    errSSLClosedGraceful = -9805
    errSSLClosedNoNotify = -9816
    errSSLClosedAbort = -9806

    errSSLXCertChainInvalid = -9807
    errSSLCrypto = -9809
    errSSLInternal = -9810
    errSSLCertExpired = -9814
    errSSLCertNotYetValid = -9815
    errSSLUnknownRootCert = -9812
    errSSLNoRootCert = -9813
    errSSLHostNameMismatch = -9843
    errSSLPeerHandshakeFail = -9824
    errSSLPeerUserCancelled = -9839
    errSSLWeakPeerEphemeralDHKey = -9850
    errSSLServerAuthCompleted = -9841
    errSSLRecordOverflow = -9847

    errSecVerifyFailed = -67808
    errSecNoTrustSettings = -25263
    errSecItemNotFound = -25300
    errSecInvalidTrustSettings = -25262

    # Cipher suites. We only pick the ones our default cipher string allows.
    # Source: https://developer.apple.com/documentation/security/1550981-ssl_cipher_suite_values
    TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xC02C
    TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xC030
    TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xC02B
    TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F
    TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA9
    TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA8
    TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 = 0x009F
    TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 = 0x009E
    TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xC024
    TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xC028
    TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A
    TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = 0xC014
    TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 = 0x006B
    TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039
    TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xC023
    TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027
    TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xC009
    TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xC013
    TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 = 0x0067
    TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033
    TLS_RSA_WITH_AES_256_GCM_SHA384 = 0x009D
    TLS_RSA_WITH_AES_128_GCM_SHA256 = 0x009C
    TLS_RSA_WITH_AES_256_CBC_SHA256 = 0x003D
    TLS_RSA_WITH_AES_128_CBC_SHA256 = 0x003C
    TLS_RSA_WITH_AES_256_CBC_SHA = 0x0035
    TLS_RSA_WITH_AES_128_CBC_SHA = 0x002F
    TLS_AES_128_GCM_SHA256 = 0x1301
    TLS_AES_256_GCM_SHA384 = 0x1302
    TLS_AES_128_CCM_8_SHA256 = 0x1305
    TLS_AES_128_CCM_SHA256 = 0x1304
|
||||
|
|
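
# --- Editor's note: illustrative sketch, not part of the original file. ---
# The argtypes/restype declarations above exist so that plain Python calls
# line up with the C ABI. A minimal, hedged demonstration of the intended
# call pattern (macOS only; relies on the CoreFoundation and CFConst names
# defined earlier in this module; the _demo_* name is hypothetical):


def _demo_cfarray_roundtrip():
    import ctypes

    # Build an empty mutable CFArray with the standard retain/release
    # callbacks, append one CFString, and read the count back.
    arr = CoreFoundation.CFArrayCreateMutable(
        CoreFoundation.kCFAllocatorDefault,
        0,
        ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
    )
    cf_str = CoreFoundation.CFStringCreateWithCString(
        CoreFoundation.kCFAllocatorDefault, b"hello", CFConst.kCFStringEncodingUTF8
    )
    try:
        CoreFoundation.CFArrayAppendValue(arr, cf_str)  # array CFRetains cf_str
        assert CoreFoundation.CFArrayGetCount(arr) == 1
    finally:
        # We own both references, so both must be released.
        CoreFoundation.CFRelease(cf_str)
        CoreFoundation.CFRelease(arr)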
@ -1,396 +0,0 @@
"""
Low-level helpers for the SecureTransport bindings.

These are Python functions that are not directly related to the high-level APIs
but are necessary to get them to work. They include a whole bunch of low-level
CoreFoundation messing about and memory management. The concerns in this module
are almost entirely about trying to avoid memory leaks and providing
appropriate and useful assistance to the higher-level code.
"""
import base64
import ctypes
import itertools
import os
import re
import ssl
import struct
import tempfile

from .bindings import CFConst, CoreFoundation, Security

# This regular expression is used to grab PEM data out of a PEM bundle.
_PEM_CERTS_RE = re.compile(
    b"-----BEGIN CERTIFICATE-----\n(.*?)\n-----END CERTIFICATE-----", re.DOTALL
)


def _cf_data_from_bytes(bytestring):
    """
    Given a bytestring, create a CFData object from it. This CFData object must
    be CFReleased by the caller.
    """
    return CoreFoundation.CFDataCreate(
        CoreFoundation.kCFAllocatorDefault, bytestring, len(bytestring)
    )


def _cf_dictionary_from_tuples(tuples):
    """
    Given a list of Python tuples, create an associated CFDictionary.
    """
    dictionary_size = len(tuples)

    # We need to get the dictionary keys and values out in the same order.
    keys = (t[0] for t in tuples)
    values = (t[1] for t in tuples)
    cf_keys = (CoreFoundation.CFTypeRef * dictionary_size)(*keys)
    cf_values = (CoreFoundation.CFTypeRef * dictionary_size)(*values)

    return CoreFoundation.CFDictionaryCreate(
        CoreFoundation.kCFAllocatorDefault,
        cf_keys,
        cf_values,
        dictionary_size,
        CoreFoundation.kCFTypeDictionaryKeyCallBacks,
        CoreFoundation.kCFTypeDictionaryValueCallBacks,
    )


def _cfstr(py_bstr):
    """
    Given Python binary data, create a CFString.
    The string must be CFReleased by the caller.
    """
    c_str = ctypes.c_char_p(py_bstr)
    cf_str = CoreFoundation.CFStringCreateWithCString(
        CoreFoundation.kCFAllocatorDefault,
        c_str,
        CFConst.kCFStringEncodingUTF8,
    )
    return cf_str


def _create_cfstring_array(lst):
    """
    Given a list of Python binary data, create an associated CFMutableArray.
    The array must be CFReleased by the caller.

    Raises an ssl.SSLError on failure.
    """
    cf_arr = None
    try:
        cf_arr = CoreFoundation.CFArrayCreateMutable(
            CoreFoundation.kCFAllocatorDefault,
            0,
            ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
        )
        if not cf_arr:
            raise MemoryError("Unable to allocate memory!")
        for item in lst:
            cf_str = _cfstr(item)
            if not cf_str:
                raise MemoryError("Unable to allocate memory!")
            try:
                CoreFoundation.CFArrayAppendValue(cf_arr, cf_str)
            finally:
                CoreFoundation.CFRelease(cf_str)
    except BaseException as e:
        if cf_arr:
            CoreFoundation.CFRelease(cf_arr)
        raise ssl.SSLError("Unable to allocate array: %s" % (e,))
    return cf_arr


def _cf_string_to_unicode(value):
    """
    Creates a Unicode string from a CFString object. Used entirely for error
    reporting.

    Yes, it annoys me quite a lot that this function is this complex.
    """
    value_as_void_p = ctypes.cast(value, ctypes.POINTER(ctypes.c_void_p))

    string = CoreFoundation.CFStringGetCStringPtr(
        value_as_void_p, CFConst.kCFStringEncodingUTF8
    )
    if string is None:
        buffer = ctypes.create_string_buffer(1024)
        result = CoreFoundation.CFStringGetCString(
            value_as_void_p, buffer, 1024, CFConst.kCFStringEncodingUTF8
        )
        if not result:
            raise OSError("Error copying C string from CFStringRef")
        string = buffer.value
    if string is not None:
        string = string.decode("utf-8")
    return string


def _assert_no_error(error, exception_class=None):
    """
    Checks the return code and throws an exception if there is an error to
    report.
    """
    if error == 0:
        return

    cf_error_string = Security.SecCopyErrorMessageString(error, None)
    output = _cf_string_to_unicode(cf_error_string)
    CoreFoundation.CFRelease(cf_error_string)

    if output is None or output == u"":
        output = u"OSStatus %s" % error

    if exception_class is None:
        exception_class = ssl.SSLError

    raise exception_class(output)
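
# --- Editor's note: illustrative sketch, not part of the original file. ---
# The helpers above are consumed in one recurring pattern: call a Security
# function, then feed its OSStatus return straight into _assert_no_error().
# On macOS, a non-zero status becomes an ssl.SSLError carrying the system's
# own message text (the _demo_* name below is hypothetical):


def _demo_assert_no_error():
    _assert_no_error(0)  # OSStatus 0 is success: returns silently.
    try:
        # errSecItemNotFound; message comes from SecCopyErrorMessageString.
        _assert_no_error(-25300)
    except ssl.SSLError as exc:
        return exc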
def _cert_array_from_pem(pem_bundle):
    """
    Given a bundle of certs in PEM format, turns them into a CFArray of certs
    that can be used to validate a cert chain.
    """
    # Normalize the PEM bundle's line endings.
    pem_bundle = pem_bundle.replace(b"\r\n", b"\n")

    der_certs = [
        base64.b64decode(match.group(1)) for match in _PEM_CERTS_RE.finditer(pem_bundle)
    ]
    if not der_certs:
        raise ssl.SSLError("No root certificates specified")

    cert_array = CoreFoundation.CFArrayCreateMutable(
        CoreFoundation.kCFAllocatorDefault,
        0,
        ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
    )
    if not cert_array:
        raise ssl.SSLError("Unable to allocate memory!")

    try:
        for der_bytes in der_certs:
            certdata = _cf_data_from_bytes(der_bytes)
            if not certdata:
                raise ssl.SSLError("Unable to allocate memory!")
            cert = Security.SecCertificateCreateWithData(
                CoreFoundation.kCFAllocatorDefault, certdata
            )
            CoreFoundation.CFRelease(certdata)
            if not cert:
                raise ssl.SSLError("Unable to build cert object!")

            CoreFoundation.CFArrayAppendValue(cert_array, cert)
            CoreFoundation.CFRelease(cert)
    except Exception:
        # We need to free the array before the exception bubbles further.
        # We only want to do that if an error occurs: otherwise, the caller
        # should free.
        CoreFoundation.CFRelease(cert_array)
        raise

    return cert_array


def _is_cert(item):
    """
    Returns True if a given CFTypeRef is a certificate.
    """
    expected = Security.SecCertificateGetTypeID()
    return CoreFoundation.CFGetTypeID(item) == expected


def _is_identity(item):
    """
    Returns True if a given CFTypeRef is an identity.
    """
    expected = Security.SecIdentityGetTypeID()
    return CoreFoundation.CFGetTypeID(item) == expected


def _temporary_keychain():
    """
    This function creates a temporary Mac keychain that we can use to work with
    credentials. This keychain uses a one-time password and a temporary file to
    store the data. We expect to have one keychain per socket. The returned
    SecKeychainRef must be freed by the caller, including calling
    SecKeychainDelete.

    Returns a tuple of the SecKeychainRef and the path to the temporary
    directory that contains it.
    """
    # Unfortunately, SecKeychainCreate requires a path to a keychain. This
    # means we cannot use mkstemp to use a generic temporary file. Instead,
    # we're going to create a temporary directory and a filename to use there.
    # This filename will be 8 random bytes expanded into base64. We also need
    # some random bytes to password-protect the keychain we're creating, so we
    # ask for 40 random bytes.
    random_bytes = os.urandom(40)
    filename = base64.b16encode(random_bytes[:8]).decode("utf-8")
    password = base64.b16encode(random_bytes[8:])  # Must be valid UTF-8
    tempdirectory = tempfile.mkdtemp()

    keychain_path = os.path.join(tempdirectory, filename).encode("utf-8")

    # We now want to create the keychain itself.
    keychain = Security.SecKeychainRef()
    status = Security.SecKeychainCreate(
        keychain_path, len(password), password, False, None, ctypes.byref(keychain)
    )
    _assert_no_error(status)

    # Having created the keychain, we want to pass it off to the caller.
    return keychain, tempdirectory
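
# --- Editor's note: illustrative sketch, not part of the original file. ---
# The ownership contract in the docstring above, spelled out: the caller must
# delete the keychain *and* remove the temporary directory when done (macOS
# only; Security.SecKeychainDelete is declared in .bindings):


def _demo_temporary_keychain_cleanup():
    import shutil

    keychain, tempdir = _temporary_keychain()
    try:
        pass  # ... import items, build a client identity, handshake ...
    finally:
        Security.SecKeychainDelete(keychain)
        CoreFoundation.CFRelease(keychain)
        shutil.rmtree(tempdir, ignore_errors=True)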
def _load_items_from_file(keychain, path):
    """
    Given a single file, loads all the trust objects from it into arrays and
    the keychain.
    Returns a tuple of lists: the first list is a list of identities, the
    second a list of certs.
    """
    certificates = []
    identities = []
    result_array = None

    with open(path, "rb") as f:
        raw_filedata = f.read()

    try:
        filedata = CoreFoundation.CFDataCreate(
            CoreFoundation.kCFAllocatorDefault, raw_filedata, len(raw_filedata)
        )
        result_array = CoreFoundation.CFArrayRef()
        result = Security.SecItemImport(
            filedata,  # cert data
            None,  # Filename, leaving it out for now
            None,  # What the type of the file is, we don't care
            None,  # what's in the file, we don't care
            0,  # import flags
            None,  # key params, can include passphrase in the future
            keychain,  # The keychain to insert into
            ctypes.byref(result_array),  # Results
        )
        _assert_no_error(result)

        # A CFArray is not very useful to us as an intermediary
        # representation, so we are going to extract the objects we want
        # and then free the array. We don't need to keep hold of keys: the
        # keychain already has them!
        result_count = CoreFoundation.CFArrayGetCount(result_array)
        for index in range(result_count):
            item = CoreFoundation.CFArrayGetValueAtIndex(result_array, index)
            item = ctypes.cast(item, CoreFoundation.CFTypeRef)

            if _is_cert(item):
                CoreFoundation.CFRetain(item)
                certificates.append(item)
            elif _is_identity(item):
                CoreFoundation.CFRetain(item)
                identities.append(item)
    finally:
        if result_array:
            CoreFoundation.CFRelease(result_array)

        CoreFoundation.CFRelease(filedata)

    return (identities, certificates)


def _load_client_cert_chain(keychain, *paths):
    """
    Load certificates and maybe keys from a number of files. Has the end goal
    of returning a CFArray containing one SecIdentityRef, and then zero or more
    SecCertificateRef objects, suitable for use as a client certificate trust
    chain.
    """
    # Ok, the strategy.
    #
    # This relies on knowing that macOS will not give you a SecIdentityRef
    # unless you have imported a key into a keychain. This is a somewhat
    # artificial limitation of macOS (for example, it doesn't necessarily
    # affect iOS), but there is nothing inside Security.framework that lets you
    # get a SecIdentityRef without having a key in a keychain.
    #
    # So the policy here is we take all the files and iterate them in order.
    # Each one will use SecItemImport to have one or more objects loaded from
    # it. We will also point at a keychain that macOS can use to work with the
    # private key.
    #
    # Once we have all the objects, we'll check what we actually have. If we
    # already have a SecIdentityRef in hand, fab: we'll use that. Otherwise,
    # we'll take the first certificate (which we assume to be our leaf) and
    # ask the keychain to give us a SecIdentityRef with that cert's associated
    # key.
    #
    # We'll then return a CFArray containing the trust chain: one
    # SecIdentityRef and then zero-or-more SecCertificateRef objects. The
    # responsibility for freeing this CFArray will be with the caller. This
    # CFArray must remain alive for the entire connection, so in practice it
    # will be stored with a single SSLSocket, along with the reference to the
    # keychain.
    certificates = []
    identities = []

    # Filter out bad paths.
    paths = (path for path in paths if path)

    try:
        for file_path in paths:
            new_identities, new_certs = _load_items_from_file(keychain, file_path)
            identities.extend(new_identities)
            certificates.extend(new_certs)

        # Ok, we have everything. The question is: do we have an identity? If
        # not, we want to grab one from the first cert we have.
        if not identities:
            new_identity = Security.SecIdentityRef()
            status = Security.SecIdentityCreateWithCertificate(
                keychain, certificates[0], ctypes.byref(new_identity)
            )
            _assert_no_error(status)
            identities.append(new_identity)

            # We now want to release the original certificate, as we no longer
            # need it.
            CoreFoundation.CFRelease(certificates.pop(0))

        # We now need to build a new CFArray that holds the trust chain.
        trust_chain = CoreFoundation.CFArrayCreateMutable(
            CoreFoundation.kCFAllocatorDefault,
            0,
            ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
        )
        for item in itertools.chain(identities, certificates):
            # ArrayAppendValue does a CFRetain on the item. That's fine,
            # because the finally block will release our other refs to them.
            CoreFoundation.CFArrayAppendValue(trust_chain, item)

        return trust_chain
    finally:
        for obj in itertools.chain(identities, certificates):
            CoreFoundation.CFRelease(obj)


TLS_PROTOCOL_VERSIONS = {
    "SSLv2": (0, 2),
    "SSLv3": (3, 0),
    "TLSv1": (3, 1),
    "TLSv1.1": (3, 2),
    "TLSv1.2": (3, 3),
}


def _build_tls_unknown_ca_alert(version):
    """
    Builds a TLS alert record for an unknown CA.
    """
    ver_maj, ver_min = TLS_PROTOCOL_VERSIONS[version]
    severity_fatal = 0x02
    description_unknown_ca = 0x30
    msg = struct.pack(">BB", severity_fatal, description_unknown_ca)
    msg_len = len(msg)
    record_type_alert = 0x15
    record = struct.pack(">BBBH", record_type_alert, ver_maj, ver_min, msg_len) + msg
    return record
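
# --- Editor's note: illustrative sketch, not part of the original file. ---
# Worked example of the record built above for "TLSv1.2": content type 0x15
# (alert), version bytes 3/3, a two-byte payload length of 2, then severity
# 0x02 (fatal) and description 0x30 (unknown CA), seven bytes in total:


def _demo_unknown_ca_alert():
    record = _build_tls_unknown_ca_alert("TLSv1.2")
    assert record == b"\x15\x03\x03\x00\x02\x02\x30"
    return record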
@ -1,314 +0,0 @@
"""
This module provides a pool manager that uses Google App Engine's
`URLFetch Service <https://cloud.google.com/appengine/docs/python/urlfetch>`_.

Example usage::

    from urllib3 import PoolManager
    from urllib3.contrib.appengine import AppEngineManager, is_appengine_sandbox

    if is_appengine_sandbox():
        # AppEngineManager uses AppEngine's URLFetch API behind the scenes
        http = AppEngineManager()
    else:
        # PoolManager uses a socket-level API behind the scenes
        http = PoolManager()

    r = http.request('GET', 'https://google.com/')

There are `limitations <https://cloud.google.com/appengine/docs/python/\
urlfetch/#Python_Quotas_and_limits>`_ to the URLFetch service and it may not be
the best choice for your application. There are three options for using
urllib3 on Google App Engine:

1. You can use :class:`AppEngineManager` with URLFetch. URLFetch is
   cost-effective in many circumstances as long as your usage is within the
   limitations.
2. You can use a normal :class:`~urllib3.PoolManager` by enabling sockets.
   Sockets also have `limitations and restrictions
   <https://cloud.google.com/appengine/docs/python/sockets/\
   #limitations-and-restrictions>`_ and have a lower free quota than URLFetch.
   To use sockets, be sure to specify the following in your ``app.yaml``::

       env_variables:
           GAE_USE_SOCKETS_HTTPLIB : 'true'

3. If you are using `App Engine Flexible
   <https://cloud.google.com/appengine/docs/flexible/>`_, you can use the
   standard :class:`PoolManager` without any configuration or special
   environment variables.
"""

from __future__ import absolute_import

import io
import logging
import warnings

from ..exceptions import (
    HTTPError,
    HTTPWarning,
    MaxRetryError,
    ProtocolError,
    SSLError,
    TimeoutError,
)
from ..packages.six.moves.urllib.parse import urljoin
from ..request import RequestMethods
from ..response import HTTPResponse
from ..util.retry import Retry
from ..util.timeout import Timeout
from . import _appengine_environ

try:
    from google.appengine.api import urlfetch
except ImportError:
    urlfetch = None


log = logging.getLogger(__name__)


class AppEnginePlatformWarning(HTTPWarning):
    pass


class AppEnginePlatformError(HTTPError):
    pass


class AppEngineManager(RequestMethods):
    """
    Connection manager for Google App Engine sandbox applications.

    This manager uses the URLFetch service directly instead of using the
    emulated httplib, and is subject to URLFetch limitations as described in
    the App Engine documentation `here
    <https://cloud.google.com/appengine/docs/python/urlfetch>`_.

    Notably it will raise an :class:`AppEnginePlatformError` if:
        * URLFetch is not available.
        * You attempt to use this on App Engine Flexible, as full socket
          support is available.
        * A request size is more than 10 megabytes.
        * A response size is more than 32 megabytes.
        * You use an unsupported request method such as OPTIONS.

    Beyond those cases, it will raise normal urllib3 errors.
    """

    def __init__(
        self,
        headers=None,
        retries=None,
        validate_certificate=True,
        urlfetch_retries=True,
    ):
        if not urlfetch:
            raise AppEnginePlatformError(
                "URLFetch is not available in this environment."
            )

        warnings.warn(
            "urllib3 is using URLFetch on Google App Engine sandbox instead "
            "of sockets. To use sockets directly instead of URLFetch see "
            "https://urllib3.readthedocs.io/en/1.26.x/reference/urllib3.contrib.html.",
            AppEnginePlatformWarning,
        )

        RequestMethods.__init__(self, headers)
        self.validate_certificate = validate_certificate
        self.urlfetch_retries = urlfetch_retries

        self.retries = retries or Retry.DEFAULT

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Return False to re-raise any potential exceptions
        return False

    def urlopen(
        self,
        method,
        url,
        body=None,
        headers=None,
        retries=None,
        redirect=True,
        timeout=Timeout.DEFAULT_TIMEOUT,
        **response_kw
    ):

        retries = self._get_retries(retries, redirect)

        try:
            follow_redirects = redirect and retries.redirect != 0 and retries.total
            response = urlfetch.fetch(
                url,
                payload=body,
                method=method,
                headers=headers or {},
                allow_truncated=False,
                follow_redirects=self.urlfetch_retries and follow_redirects,
                deadline=self._get_absolute_timeout(timeout),
                validate_certificate=self.validate_certificate,
            )
        except urlfetch.DeadlineExceededError as e:
            raise TimeoutError(self, e)

        except urlfetch.InvalidURLError as e:
            if "too large" in str(e):
                raise AppEnginePlatformError(
                    "URLFetch request too large, URLFetch only "
                    "supports requests up to 10mb in size.",
                    e,
                )
            raise ProtocolError(e)

        except urlfetch.DownloadError as e:
            if "Too many redirects" in str(e):
                raise MaxRetryError(self, url, reason=e)
            raise ProtocolError(e)

        except urlfetch.ResponseTooLargeError as e:
            raise AppEnginePlatformError(
                "URLFetch response too large, URLFetch only supports "
                "responses up to 32mb in size.",
                e,
            )

        except urlfetch.SSLCertificateError as e:
            raise SSLError(e)

        except urlfetch.InvalidMethodError as e:
            raise AppEnginePlatformError(
                "URLFetch does not support method: %s" % method, e
            )

        http_response = self._urlfetch_response_to_http_response(
            response, retries=retries, **response_kw
        )

        # Handle redirect?
        redirect_location = redirect and http_response.get_redirect_location()
        if redirect_location:
            # Check for redirect response
            if self.urlfetch_retries and retries.raise_on_redirect:
                raise MaxRetryError(self, url, "too many redirects")
            else:
                if http_response.status == 303:
                    method = "GET"

                try:
                    retries = retries.increment(
                        method, url, response=http_response, _pool=self
                    )
                except MaxRetryError:
                    if retries.raise_on_redirect:
                        raise MaxRetryError(self, url, "too many redirects")
                    return http_response

                retries.sleep_for_retry(http_response)
                log.debug("Redirecting %s -> %s", url, redirect_location)
                redirect_url = urljoin(url, redirect_location)
                return self.urlopen(
                    method,
                    redirect_url,
                    body,
                    headers,
                    retries=retries,
                    redirect=redirect,
                    timeout=timeout,
                    **response_kw
                )

        # Check if we should retry the HTTP response.
        has_retry_after = bool(http_response.getheader("Retry-After"))
        if retries.is_retry(method, http_response.status, has_retry_after):
            retries = retries.increment(method, url, response=http_response, _pool=self)
            log.debug("Retry: %s", url)
            retries.sleep(http_response)
            return self.urlopen(
                method,
                url,
                body=body,
                headers=headers,
                retries=retries,
                redirect=redirect,
                timeout=timeout,
                **response_kw
            )

        return http_response

    def _urlfetch_response_to_http_response(self, urlfetch_resp, **response_kw):

        if is_prod_appengine():
            # Production GAE handles deflate encoding automatically, but does
            # not remove the encoding header.
            content_encoding = urlfetch_resp.headers.get("content-encoding")

            if content_encoding == "deflate":
                del urlfetch_resp.headers["content-encoding"]

        transfer_encoding = urlfetch_resp.headers.get("transfer-encoding")
        # We have a full response's content,
        # so let's make sure we don't report ourselves as chunked data.
        if transfer_encoding == "chunked":
            encodings = transfer_encoding.split(",")
            encodings.remove("chunked")
            urlfetch_resp.headers["transfer-encoding"] = ",".join(encodings)

        original_response = HTTPResponse(
            # In order for decoding to work, we must present the content as
            # a file-like object.
            body=io.BytesIO(urlfetch_resp.content),
            msg=urlfetch_resp.header_msg,
            headers=urlfetch_resp.headers,
            status=urlfetch_resp.status_code,
            **response_kw
        )

        return HTTPResponse(
            body=io.BytesIO(urlfetch_resp.content),
            headers=urlfetch_resp.headers,
            status=urlfetch_resp.status_code,
            original_response=original_response,
            **response_kw
        )

    def _get_absolute_timeout(self, timeout):
        if timeout is Timeout.DEFAULT_TIMEOUT:
            return None  # Defer to URLFetch's default.
        if isinstance(timeout, Timeout):
            if timeout._read is not None or timeout._connect is not None:
                warnings.warn(
                    "URLFetch does not support granular timeout settings, "
                    "reverting to total or default URLFetch timeout.",
                    AppEnginePlatformWarning,
                )
            return timeout.total
        return timeout

    def _get_retries(self, retries, redirect):
        if not isinstance(retries, Retry):
            retries = Retry.from_int(retries, redirect=redirect, default=self.retries)

        if retries.connect or retries.read or retries.redirect:
            warnings.warn(
                "URLFetch only supports total retries and does not "
                "recognize connect, read, or redirect retry parameters.",
                AppEnginePlatformWarning,
            )

        return retries


# Alias methods from _appengine_environ to maintain public API interface.

is_appengine = _appengine_environ.is_appengine
is_appengine_sandbox = _appengine_environ.is_appengine_sandbox
is_local_appengine = _appengine_environ.is_local_appengine
is_prod_appengine = _appengine_environ.is_prod_appengine
is_prod_appengine_mvms = _appengine_environ.is_prod_appengine_mvms
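
# --- Editor's note: illustrative sketch, not part of the original file. ---
# Minimal usage under the constraints documented above: only *total* retries
# are honored and only a single total deadline in seconds. Runnable only
# inside the App Engine sandbox, where urlfetch is importable (the _demo_*
# name and the URL are hypothetical):


def _demo_appengine_manager():
    manager = AppEngineManager(retries=Retry(total=3))
    # Granular connect/read timeouts would trigger AppEnginePlatformWarning,
    # so a plain float is the idiomatic timeout here.
    return manager.urlopen("GET", "https://example.com/", timeout=10.0)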
@ -1,130 +0,0 @@
"""
NTLM authenticating pool, contributed by erikcederstran

Issue #10, see: http://code.google.com/p/urllib3/issues/detail?id=10
"""
from __future__ import absolute_import

import warnings
from logging import getLogger

from ntlm import ntlm

from .. import HTTPSConnectionPool
from ..packages.six.moves.http_client import HTTPSConnection

warnings.warn(
    "The 'urllib3.contrib.ntlmpool' module is deprecated and will be removed "
    "in urllib3 v2.0 release, urllib3 is not able to support it properly due "
    "to reasons listed in issue: https://github.com/urllib3/urllib3/issues/2282. "
    "If you are a user of this module please comment in the mentioned issue.",
    DeprecationWarning,
)

log = getLogger(__name__)


class NTLMConnectionPool(HTTPSConnectionPool):
    """
    Implements an NTLM authentication version of an urllib3 connection pool
    """

    scheme = "https"

    def __init__(self, user, pw, authurl, *args, **kwargs):
        """
        authurl is a random URL on the server that is protected by NTLM.
        user is the Windows user, probably in the DOMAIN\\username format.
        pw is the password for the user.
        """
        super(NTLMConnectionPool, self).__init__(*args, **kwargs)
        self.authurl = authurl
        self.rawuser = user
        user_parts = user.split("\\", 1)
        self.domain = user_parts[0].upper()
        self.user = user_parts[1]
        self.pw = pw

    def _new_conn(self):
        # Performs the NTLM handshake that secures the connection. The socket
        # must be kept open while requests are performed.
        self.num_connections += 1
        log.debug(
            "Starting NTLM HTTPS connection no. %d: https://%s%s",
            self.num_connections,
            self.host,
            self.authurl,
        )

        headers = {"Connection": "Keep-Alive"}
        req_header = "Authorization"
        resp_header = "www-authenticate"

        conn = HTTPSConnection(host=self.host, port=self.port)

        # Send negotiation message
        headers[req_header] = "NTLM %s" % ntlm.create_NTLM_NEGOTIATE_MESSAGE(
            self.rawuser
        )
        log.debug("Request headers: %s", headers)
        conn.request("GET", self.authurl, None, headers)
        res = conn.getresponse()
        reshdr = dict(res.getheaders())
        log.debug("Response status: %s %s", res.status, res.reason)
        log.debug("Response headers: %s", reshdr)
        log.debug("Response data: %s [...]", res.read(100))

        # Remove the reference to the socket, so that it can not be closed by
        # the response object (we want to keep the socket open)
        res.fp = None

        # Server should respond with a challenge message
        auth_header_values = reshdr[resp_header].split(", ")
        auth_header_value = None
        for s in auth_header_values:
            if s[:5] == "NTLM ":
                auth_header_value = s[5:]
        if auth_header_value is None:
            raise Exception(
                "Unexpected %s response header: %s" % (resp_header, reshdr[resp_header])
            )

        # Send authentication message
        ServerChallenge, NegotiateFlags = ntlm.parse_NTLM_CHALLENGE_MESSAGE(
            auth_header_value
        )
        auth_msg = ntlm.create_NTLM_AUTHENTICATE_MESSAGE(
            ServerChallenge, self.user, self.domain, self.pw, NegotiateFlags
        )
        headers[req_header] = "NTLM %s" % auth_msg
        log.debug("Request headers: %s", headers)
        conn.request("GET", self.authurl, None, headers)
        res = conn.getresponse()
        log.debug("Response status: %s %s", res.status, res.reason)
        log.debug("Response headers: %s", dict(res.getheaders()))
        log.debug("Response data: %s [...]", res.read()[:100])
        if res.status != 200:
            if res.status == 401:
                raise Exception("Server rejected request: wrong username or password")
            raise Exception("Wrong server response: %s %s" % (res.status, res.reason))

        res.fp = None
        log.debug("Connection established")
        return conn

    def urlopen(
        self,
        method,
        url,
        body=None,
        headers=None,
        retries=3,
        redirect=True,
        assert_same_host=True,
    ):
        if headers is None:
            headers = {}
        headers["Connection"] = "Keep-Alive"
        return super(NTLMConnectionPool, self).urlopen(
            method, url, body, headers, retries, redirect, assert_same_host
        )
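
# --- Editor's note: illustrative sketch, not part of the original file. ---
# Hypothetical usage of the pool above (hostname and credentials are made up;
# requires the third-party `ntlm` package):


def _demo_ntlm_pool():
    pool = NTLMConnectionPool(
        "EXAMPLE\\alice",  # user, in DOMAIN\username form
        "hunter2",  # pw
        "/",  # authurl: any NTLM-protected URL on the host
        host="server.example",
        port=443,
    )
    # Every request is sent with Connection: Keep-Alive so that the
    # authenticated socket set up in _new_conn() stays usable.
    return pool.urlopen("GET", "/protected")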
@ -1,511 +0,0 @@
"""
TLS with SNI_-support for Python 2. Follow these instructions if you would
like to verify TLS certificates in Python 2. Note, the default libraries do
*not* do certificate checking; you need to do additional work to validate
certificates yourself.

This needs the following packages installed:

* `pyOpenSSL`_ (tested with 16.0.0)
* `cryptography`_ (minimum 1.3.4, from pyopenssl)
* `idna`_ (minimum 2.0, from cryptography)

However, pyopenssl depends on cryptography, which depends on idna, so while we
use all three directly here we end up having relatively few packages required.

You can install them with the following command:

.. code-block:: bash

    $ python -m pip install pyopenssl cryptography idna

To activate certificate checking, call
:func:`~urllib3.contrib.pyopenssl.inject_into_urllib3` from your Python code
before you begin making HTTP requests. This can be done in a ``sitecustomize``
module, or at any other time before your application begins using ``urllib3``,
like this:

.. code-block:: python

    try:
        import urllib3.contrib.pyopenssl
        urllib3.contrib.pyopenssl.inject_into_urllib3()
    except ImportError:
        pass

Now you can use :mod:`urllib3` as you normally would, and it will support SNI
when the required modules are installed.

Activating this module also has the positive side effect of disabling SSL/TLS
compression in Python 2 (see `CRIME attack`_).

.. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication
.. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit)
.. _pyopenssl: https://www.pyopenssl.org
.. _cryptography: https://cryptography.io
.. _idna: https://github.com/kjd/idna
"""
from __future__ import absolute_import

import OpenSSL.SSL
from cryptography import x509
from cryptography.hazmat.backends.openssl import backend as openssl_backend
from cryptography.hazmat.backends.openssl.x509 import _Certificate

try:
    from cryptography.x509 import UnsupportedExtension
except ImportError:
    # UnsupportedExtension is gone in cryptography >= 2.1.0
    class UnsupportedExtension(Exception):
        pass


from io import BytesIO
from socket import error as SocketError
from socket import timeout

try:  # Platform-specific: Python 2
    from socket import _fileobject
except ImportError:  # Platform-specific: Python 3
    _fileobject = None
    from ..packages.backports.makefile import backport_makefile

import logging
import ssl
import sys

from .. import util
from ..packages import six
from ..util.ssl_ import PROTOCOL_TLS_CLIENT

__all__ = ["inject_into_urllib3", "extract_from_urllib3"]

# SNI always works.
HAS_SNI = True

# Map from urllib3 to PyOpenSSL compatible parameter-values.
_openssl_versions = {
    util.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD,
    PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD,
    ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
}

if hasattr(ssl, "PROTOCOL_SSLv3") and hasattr(OpenSSL.SSL, "SSLv3_METHOD"):
    _openssl_versions[ssl.PROTOCOL_SSLv3] = OpenSSL.SSL.SSLv3_METHOD

if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"):
    _openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD

if hasattr(ssl, "PROTOCOL_TLSv1_2") and hasattr(OpenSSL.SSL, "TLSv1_2_METHOD"):
    _openssl_versions[ssl.PROTOCOL_TLSv1_2] = OpenSSL.SSL.TLSv1_2_METHOD


_stdlib_to_openssl_verify = {
    ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE,
    ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER,
    ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER
    + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}
_openssl_to_stdlib_verify = dict((v, k) for k, v in _stdlib_to_openssl_verify.items())

# OpenSSL will only write 16K at a time
SSL_WRITE_BLOCKSIZE = 16384

orig_util_HAS_SNI = util.HAS_SNI
orig_util_SSLContext = util.ssl_.SSLContext


log = logging.getLogger(__name__)


def inject_into_urllib3():
    "Monkey-patch urllib3 with PyOpenSSL-backed SSL-support."

    _validate_dependencies_met()

    util.SSLContext = PyOpenSSLContext
    util.ssl_.SSLContext = PyOpenSSLContext
    util.HAS_SNI = HAS_SNI
    util.ssl_.HAS_SNI = HAS_SNI
    util.IS_PYOPENSSL = True
    util.ssl_.IS_PYOPENSSL = True


def extract_from_urllib3():
    "Undo monkey-patching by :func:`inject_into_urllib3`."

    util.SSLContext = orig_util_SSLContext
    util.ssl_.SSLContext = orig_util_SSLContext
    util.HAS_SNI = orig_util_HAS_SNI
    util.ssl_.HAS_SNI = orig_util_HAS_SNI
    util.IS_PYOPENSSL = False
    util.ssl_.IS_PYOPENSSL = False


def _validate_dependencies_met():
    """
    Verifies that PyOpenSSL's package-level dependencies have been met.
    Throws `ImportError` if they are not met.
    """
    # Method added in `cryptography==1.1`; not available in older versions
    from cryptography.x509.extensions import Extensions

    if getattr(Extensions, "get_extension_for_class", None) is None:
        raise ImportError(
            "'cryptography' module missing required functionality. "
            "Try upgrading to v1.3.4 or newer."
        )

    # pyOpenSSL 0.14 and above use cryptography for OpenSSL bindings. The _x509
    # attribute is only present on those versions.
    from OpenSSL.crypto import X509

    x509 = X509()
    if getattr(x509, "_x509", None) is None:
        raise ImportError(
            "'pyOpenSSL' module missing required functionality. "
            "Try upgrading to v0.14 or newer."
        )


def _dnsname_to_stdlib(name):
    """
    Converts a dNSName SubjectAlternativeName field to the form used by the
    standard library on the given Python version.

    Cryptography produces a dNSName as a unicode string that was idna-decoded
    from ASCII bytes. We need to idna-encode that string to get it back, and
    then on Python 3 we also need to convert to unicode via UTF-8 (the stdlib
    uses PyUnicode_FromStringAndSize on it, which decodes via UTF-8).

    If the name cannot be idna-encoded then we return None signalling that
    the name given should be skipped.
    """

    def idna_encode(name):
        """
        Borrowed wholesale from the Python Cryptography Project. It turns out
        that we can't just safely call `idna.encode`: it can explode for
        wildcard names. This avoids that problem.
        """
        import idna

        try:
            for prefix in [u"*.", u"."]:
                if name.startswith(prefix):
                    name = name[len(prefix) :]
                    return prefix.encode("ascii") + idna.encode(name)
            return idna.encode(name)
        except idna.core.IDNAError:
            return None

    # Don't send IPv6 addresses through the IDNA encoder.
    if ":" in name:
        return name

    name = idna_encode(name)
    if name is None:
        return None
    elif sys.version_info >= (3, 0):
        name = name.decode("utf-8")
    return name
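
# --- Editor's note: illustrative sketch, not part of the original file. ---
# A worked example of the conversion above (assumes the `idna` package is
# installed): wildcard labels are split off before IDNA encoding, and IPv6
# literals bypass the encoder entirely.


def _demo_dnsname_to_stdlib():
    # u"bücher" IDNA-encodes to "xn--bcher-kva"; the "*." prefix survives.
    assert _dnsname_to_stdlib(u"*.b\u00fccher.example") == "*.xn--bcher-kva.example"
    # IPv6 literal: returned untouched.
    assert _dnsname_to_stdlib(u"::1") == u"::1"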
def get_subj_alt_name(peer_cert):
    """
    Given a PyOpenSSL certificate, provides all the subject alternative names.
    """
    # Pass the cert to cryptography, which has much better APIs for this.
    if hasattr(peer_cert, "to_cryptography"):
        cert = peer_cert.to_cryptography()
    else:
        # This is technically using private APIs, but should work across all
        # relevant versions before PyOpenSSL got a proper API for this.
        cert = _Certificate(openssl_backend, peer_cert._x509)

    # We want to find the SAN extension. Ask Cryptography to locate it (it's
    # faster than looping in Python)
    try:
        ext = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
    except x509.ExtensionNotFound:
        # No such extension, return the empty list.
        return []
    except (
        x509.DuplicateExtension,
        UnsupportedExtension,
        x509.UnsupportedGeneralNameType,
        UnicodeError,
    ) as e:
        # A problem has been found with the quality of the certificate. Assume
        # no SAN field is present.
        log.warning(
            "A problem was encountered with the certificate that prevented "
            "urllib3 from finding the SubjectAlternativeName field. This can "
            "affect certificate validation. The error was %s",
            e,
        )
        return []

    # We want to return dNSName and iPAddress fields. We need to cast the IPs
    # back to strings because the match_hostname function wants them as
    # strings.
    # Sadly the DNS names need to be idna encoded and then, on Python 3, UTF-8
    # decoded. This is pretty frustrating, but that's what the standard library
    # does with certificates, and so we need to attempt to do the same.
    # We also want to skip over names which cannot be idna encoded.
    names = [
        ("DNS", name)
        for name in map(_dnsname_to_stdlib, ext.get_values_for_type(x509.DNSName))
        if name is not None
    ]
    names.extend(
        ("IP Address", str(name)) for name in ext.get_values_for_type(x509.IPAddress)
    )

    return names


class WrappedSocket(object):
    """API-compatibility wrapper for Python OpenSSL's Connection-class.

    Note: _makefile_refs, _drop() and _reuse() are needed for the garbage
    collector of pypy.
    """

    def __init__(self, connection, socket, suppress_ragged_eofs=True):
        self.connection = connection
        self.socket = socket
        self.suppress_ragged_eofs = suppress_ragged_eofs
        self._makefile_refs = 0
        self._closed = False

    def fileno(self):
        return self.socket.fileno()

    # Copy-pasted from Python 3.5 source code
    def _decref_socketios(self):
        if self._makefile_refs > 0:
            self._makefile_refs -= 1
        if self._closed:
            self.close()

    def recv(self, *args, **kwargs):
        try:
            data = self.connection.recv(*args, **kwargs)
        except OpenSSL.SSL.SysCallError as e:
            if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
                return b""
            else:
                raise SocketError(str(e))
        except OpenSSL.SSL.ZeroReturnError:
            if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
                return b""
            else:
                raise
        except OpenSSL.SSL.WantReadError:
            if not util.wait_for_read(self.socket, self.socket.gettimeout()):
                raise timeout("The read operation timed out")
            else:
                return self.recv(*args, **kwargs)

        # TLS 1.3 post-handshake authentication
        except OpenSSL.SSL.Error as e:
            raise ssl.SSLError("read error: %r" % e)
        else:
            return data

    def recv_into(self, *args, **kwargs):
        try:
            return self.connection.recv_into(*args, **kwargs)
        except OpenSSL.SSL.SysCallError as e:
            if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
                return 0
            else:
                raise SocketError(str(e))
        except OpenSSL.SSL.ZeroReturnError:
            if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
                return 0
            else:
                raise
        except OpenSSL.SSL.WantReadError:
            if not util.wait_for_read(self.socket, self.socket.gettimeout()):
                raise timeout("The read operation timed out")
            else:
                return self.recv_into(*args, **kwargs)

        # TLS 1.3 post-handshake authentication
        except OpenSSL.SSL.Error as e:
            raise ssl.SSLError("read error: %r" % e)

    def settimeout(self, timeout):
        return self.socket.settimeout(timeout)

    def _send_until_done(self, data):
        while True:
            try:
                return self.connection.send(data)
            except OpenSSL.SSL.WantWriteError:
                if not util.wait_for_write(self.socket, self.socket.gettimeout()):
                    raise timeout()
                continue
            except OpenSSL.SSL.SysCallError as e:
                raise SocketError(str(e))

    def sendall(self, data):
        total_sent = 0
        while total_sent < len(data):
            sent = self._send_until_done(
                data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE]
            )
            total_sent += sent

    def shutdown(self):
        # FIXME rethrow compatible exceptions should we ever use this
        self.connection.shutdown()

    def close(self):
        if self._makefile_refs < 1:
            try:
                self._closed = True
                return self.connection.close()
            except OpenSSL.SSL.Error:
                return
        else:
            self._makefile_refs -= 1

    def getpeercert(self, binary_form=False):
        x509 = self.connection.get_peer_certificate()

        if not x509:
            return x509

        if binary_form:
            return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509)

        return {
            "subject": ((("commonName", x509.get_subject().CN),),),
            "subjectAltName": get_subj_alt_name(x509),
        }

    def version(self):
        return self.connection.get_protocol_version_name()

    def _reuse(self):
        self._makefile_refs += 1

    def _drop(self):
        if self._makefile_refs < 1:
            self.close()
        else:
            self._makefile_refs -= 1


if _fileobject:  # Platform-specific: Python 2

    def makefile(self, mode, bufsize=-1):
        self._makefile_refs += 1
        return _fileobject(self, mode, bufsize, close=True)


else:  # Platform-specific: Python 3

    makefile = backport_makefile

WrappedSocket.makefile = makefile


class PyOpenSSLContext(object):
    """
    I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible
    for translating the interface of the standard library ``SSLContext`` object
    to calls into PyOpenSSL.
    """

    def __init__(self, protocol):
        self.protocol = _openssl_versions[protocol]
        self._ctx = OpenSSL.SSL.Context(self.protocol)
        self._options = 0
        self.check_hostname = False

    @property
    def options(self):
        return self._options

    @options.setter
    def options(self, value):
        self._options = value
        self._ctx.set_options(value)

    @property
    def verify_mode(self):
        return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()]

    @verify_mode.setter
    def verify_mode(self, value):
        self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback)

    def set_default_verify_paths(self):
        self._ctx.set_default_verify_paths()

    def set_ciphers(self, ciphers):
        if isinstance(ciphers, six.text_type):
            ciphers = ciphers.encode("utf-8")
        self._ctx.set_cipher_list(ciphers)

    def load_verify_locations(self, cafile=None, capath=None, cadata=None):
        if cafile is not None:
            cafile = cafile.encode("utf-8")
        if capath is not None:
            capath = capath.encode("utf-8")
        try:
            self._ctx.load_verify_locations(cafile, capath)
            if cadata is not None:
                self._ctx.load_verify_locations(BytesIO(cadata))
        except OpenSSL.SSL.Error as e:
            raise ssl.SSLError("unable to load trusted certificates: %r" % e)

    def load_cert_chain(self, certfile, keyfile=None, password=None):
        self._ctx.use_certificate_chain_file(certfile)
        if password is not None:
            if not isinstance(password, six.binary_type):
                password = password.encode("utf-8")
            self._ctx.set_passwd_cb(lambda *_: password)
        self._ctx.use_privatekey_file(keyfile or certfile)

    def set_alpn_protocols(self, protocols):
        protocols = [six.ensure_binary(p) for p in protocols]
        return self._ctx.set_alpn_protos(protocols)

    def wrap_socket(
        self,
        sock,
        server_side=False,
        do_handshake_on_connect=True,
        suppress_ragged_eofs=True,
        server_hostname=None,
    ):
        cnx = OpenSSL.SSL.Connection(self._ctx, sock)

        if isinstance(server_hostname, six.text_type):  # Platform-specific: Python 3
            server_hostname = server_hostname.encode("utf-8")

        if server_hostname is not None:
            cnx.set_tlsext_host_name(server_hostname)

        cnx.set_connect_state()

        while True:
            try:
                cnx.do_handshake()
            except OpenSSL.SSL.WantReadError:
                if not util.wait_for_read(sock, sock.gettimeout()):
                    raise timeout("select timed out")
                continue
            except OpenSSL.SSL.Error as e:
                raise ssl.SSLError("bad handshake: %r" % e)
            break

        return WrappedSocket(cnx, sock)


def _verify_callback(cnx, x509, err_no, err_depth, return_code):
    return err_no == 0
@ -1,922 +0,0 @@
"""
SecureTransport support for urllib3 via ctypes.

This makes platform-native TLS available to urllib3 users on macOS without the
use of a compiler. This is an important feature because the Python Package
Index is moving to become a TLSv1.2-or-higher server, and the default OpenSSL
that ships with macOS is not capable of doing TLSv1.2. The only way to resolve
this is to give macOS users an alternative solution to the problem, and that
solution is to use SecureTransport.

We use ctypes here because this solution must not require a compiler. That's
because pip is not allowed to require a compiler either.

This is not intended to be a seriously long-term solution to this problem.
The hope is that PEP 543 will eventually solve this issue for us, at which
point we can retire this contrib module. But in the short term, we need to
solve the impending tire fire that is Python on Mac without this kind of
contrib module. So...here we are.

To use this module, simply import and inject it::

    import urllib3.contrib.securetransport
    urllib3.contrib.securetransport.inject_into_urllib3()

Happy TLSing!

This code is a bastardised version of the code found in Will Bond's oscrypto
library. An enormous debt is owed to him for blazing this trail for us. For
that reason, this code should be considered to be covered both by urllib3's
license and by oscrypto's:

.. code-block::

    Copyright (c) 2015-2016 Will Bond <will@wbond.net>

    Permission is hereby granted, free of charge, to any person obtaining a
    copy of this software and associated documentation files (the "Software"),
    to deal in the Software without restriction, including without limitation
    the rights to use, copy, modify, merge, publish, distribute, sublicense,
    and/or sell copies of the Software, and to permit persons to whom the
    Software is furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in
    all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
    DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import

import contextlib
import ctypes
import errno
import os.path
import shutil
import socket
import ssl
import struct
import threading
import weakref

import six

from .. import util
from ..util.ssl_ import PROTOCOL_TLS_CLIENT
from ._securetransport.bindings import CoreFoundation, Security, SecurityConst
from ._securetransport.low_level import (
    _assert_no_error,
    _build_tls_unknown_ca_alert,
    _cert_array_from_pem,
    _create_cfstring_array,
    _load_client_cert_chain,
    _temporary_keychain,
)

try:  # Platform-specific: Python 2
    from socket import _fileobject
except ImportError:  # Platform-specific: Python 3
    _fileobject = None
    from ..packages.backports.makefile import backport_makefile

__all__ = ["inject_into_urllib3", "extract_from_urllib3"]

# SNI always works
HAS_SNI = True

orig_util_HAS_SNI = util.HAS_SNI
orig_util_SSLContext = util.ssl_.SSLContext

# This dictionary is used by the read callback to obtain a handle to the
# calling wrapped socket. This is a pretty silly approach, but for now it'll
# do. I feel like I should be able to smuggle a handle to the wrapped socket
# directly in the SSLConnectionRef, but for now this approach will work I
# guess.
#
# We need to lock around this structure for inserts, but we don't do it for
# reads/writes in the callbacks. The reasoning here goes as follows:
#
#    1. It is not possible to call into the callbacks before the dictionary is
#       populated, so once in the callback the id must be in the dictionary.
#    2. The callbacks don't mutate the dictionary, they only read from it, and
#       so cannot conflict with any of the insertions.
#
# This is good: if we had to lock in the callbacks we'd drastically slow down
# the performance of this code.
_connection_refs = weakref.WeakValueDictionary()
_connection_ref_lock = threading.Lock()
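
# --- Editor's note: illustrative sketch, not part of the original file. ---
# The registration pattern the comment above describes (the real version
# lives in WrappedSocket.handshake, further down the original file): pick a
# free integer handle under the lock; the C-level read/write callbacks can
# then look the socket up lock-free.


def _demo_register_connection(wrapped_socket):
    with _connection_ref_lock:
        handle = id(wrapped_socket) % 2147483647
        while handle in _connection_refs:
            handle = (handle + 1) % 2147483647
        _connection_refs[handle] = wrapped_socket
    return handle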
# Limit writes to 16kB. This is OpenSSL's limit, but we'll cargo-cult it over
|
||||
# for no better reason than we need *a* limit, and this one is right there.
|
||||
SSL_WRITE_BLOCKSIZE = 16384
|
||||
|
||||
# This is our equivalent of util.ssl_.DEFAULT_CIPHERS, but expanded out to
|
||||
# individual cipher suites. We need to do this because this is how
|
||||
# SecureTransport wants them.
|
||||
CIPHER_SUITES = [
|
||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
SecurityConst.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
SecurityConst.TLS_DHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
SecurityConst.TLS_DHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
|
||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
|
||||
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384,
|
||||
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
||||
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA256,
|
||||
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
|
||||
SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
SecurityConst.TLS_AES_256_GCM_SHA384,
|
||||
SecurityConst.TLS_AES_128_GCM_SHA256,
|
||||
SecurityConst.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
SecurityConst.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
SecurityConst.TLS_AES_128_CCM_8_SHA256,
|
||||
SecurityConst.TLS_AES_128_CCM_SHA256,
|
||||
SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA256,
|
||||
SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA256,
|
||||
SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
]
|
||||
|
||||
# Basically this is simple: for PROTOCOL_SSLv23 we turn it into a low of
|
||||
# TLSv1 and a high of TLSv1.2. For everything else, we pin to that version.
|
||||
# TLSv1 to 1.2 are supported on macOS 10.8+
|
||||
_protocol_to_min_max = {
|
||||
util.PROTOCOL_TLS: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12),
|
||||
PROTOCOL_TLS_CLIENT: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12),
|
||||
}
|
||||
|
||||
if hasattr(ssl, "PROTOCOL_SSLv2"):
|
||||
_protocol_to_min_max[ssl.PROTOCOL_SSLv2] = (
|
||||
SecurityConst.kSSLProtocol2,
|
||||
SecurityConst.kSSLProtocol2,
|
||||
)
|
||||
if hasattr(ssl, "PROTOCOL_SSLv3"):
|
||||
_protocol_to_min_max[ssl.PROTOCOL_SSLv3] = (
|
||||
SecurityConst.kSSLProtocol3,
|
||||
SecurityConst.kSSLProtocol3,
|
||||
)
|
||||
if hasattr(ssl, "PROTOCOL_TLSv1"):
|
||||
_protocol_to_min_max[ssl.PROTOCOL_TLSv1] = (
|
||||
SecurityConst.kTLSProtocol1,
|
||||
SecurityConst.kTLSProtocol1,
|
||||
)
|
||||
if hasattr(ssl, "PROTOCOL_TLSv1_1"):
|
||||
_protocol_to_min_max[ssl.PROTOCOL_TLSv1_1] = (
|
||||
SecurityConst.kTLSProtocol11,
|
||||
SecurityConst.kTLSProtocol11,
|
||||
)
|
||||
if hasattr(ssl, "PROTOCOL_TLSv1_2"):
|
||||
_protocol_to_min_max[ssl.PROTOCOL_TLSv1_2] = (
|
||||
SecurityConst.kTLSProtocol12,
|
||||
SecurityConst.kTLSProtocol12,
|
||||
)
|
||||
|
||||
|
||||
def inject_into_urllib3():
    """
    Monkey-patch urllib3 with SecureTransport-backed SSL-support.
    """
    util.SSLContext = SecureTransportContext
    util.ssl_.SSLContext = SecureTransportContext
    util.HAS_SNI = HAS_SNI
    util.ssl_.HAS_SNI = HAS_SNI
    util.IS_SECURETRANSPORT = True
    util.ssl_.IS_SECURETRANSPORT = True


def extract_from_urllib3():
    """
    Undo monkey-patching by :func:`inject_into_urllib3`.
    """
    util.SSLContext = orig_util_SSLContext
    util.ssl_.SSLContext = orig_util_SSLContext
    util.HAS_SNI = orig_util_HAS_SNI
    util.ssl_.HAS_SNI = orig_util_HAS_SNI
    util.IS_SECURETRANSPORT = False
    util.ssl_.IS_SECURETRANSPORT = False

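
# --- Editor's example (not part of the original file) ---
# Minimal usage sketch for the two functions above: inject swaps urllib3's
# SSLContext for the SecureTransport-backed one process-wide, and extract
# restores the originals captured in orig_util_* at import time. The import
# path assumes the standalone urllib3 distribution rather than this vendored
# copy; the URL is illustrative.
import urllib3
from urllib3.contrib import securetransport

securetransport.inject_into_urllib3()
try:
    http = urllib3.PoolManager()
    response = http.request("GET", "https://example.com/")
    print(response.status)
finally:
    securetransport.extract_from_urllib3()
# --- end example ---
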
def _read_callback(connection_id, data_buffer, data_length_pointer):
    """
    SecureTransport read callback. This is called by ST to request that data
    be returned from the socket.
    """
    wrapped_socket = None
    try:
        wrapped_socket = _connection_refs.get(connection_id)
        if wrapped_socket is None:
            return SecurityConst.errSSLInternal
        base_socket = wrapped_socket.socket

        requested_length = data_length_pointer[0]

        timeout = wrapped_socket.gettimeout()
        error = None
        read_count = 0

        try:
            while read_count < requested_length:
                if timeout is None or timeout >= 0:
                    if not util.wait_for_read(base_socket, timeout):
                        raise socket.error(errno.EAGAIN, "timed out")

                remaining = requested_length - read_count
                buffer = (ctypes.c_char * remaining).from_address(
                    data_buffer + read_count
                )
                chunk_size = base_socket.recv_into(buffer, remaining)
                read_count += chunk_size
                if not chunk_size:
                    if not read_count:
                        return SecurityConst.errSSLClosedGraceful
                    break
        except (socket.error) as e:
            error = e.errno

            if error is not None and error != errno.EAGAIN:
                data_length_pointer[0] = read_count
                if error == errno.ECONNRESET or error == errno.EPIPE:
                    return SecurityConst.errSSLClosedAbort
                raise

        data_length_pointer[0] = read_count

        if read_count != requested_length:
            return SecurityConst.errSSLWouldBlock

        return 0
    except Exception as e:
        if wrapped_socket is not None:
            wrapped_socket._exception = e
        return SecurityConst.errSSLInternal


def _write_callback(connection_id, data_buffer, data_length_pointer):
    """
    SecureTransport write callback. This is called by ST to request that data
    actually be sent on the network.
    """
    wrapped_socket = None
    try:
        wrapped_socket = _connection_refs.get(connection_id)
        if wrapped_socket is None:
            return SecurityConst.errSSLInternal
        base_socket = wrapped_socket.socket

        bytes_to_write = data_length_pointer[0]
        data = ctypes.string_at(data_buffer, bytes_to_write)

        timeout = wrapped_socket.gettimeout()
        error = None
        sent = 0

        try:
            while sent < bytes_to_write:
                if timeout is None or timeout >= 0:
                    if not util.wait_for_write(base_socket, timeout):
                        raise socket.error(errno.EAGAIN, "timed out")
                chunk_sent = base_socket.send(data)
                sent += chunk_sent

                # This has some needless copying here, but I'm not sure there's
                # much value in optimising this data path.
                data = data[chunk_sent:]
        except (socket.error) as e:
            error = e.errno

            if error is not None and error != errno.EAGAIN:
                data_length_pointer[0] = sent
                if error == errno.ECONNRESET or error == errno.EPIPE:
                    return SecurityConst.errSSLClosedAbort
                raise

        data_length_pointer[0] = sent

        if sent != bytes_to_write:
            return SecurityConst.errSSLWouldBlock

        return 0
    except Exception as e:
        if wrapped_socket is not None:
            wrapped_socket._exception = e
        return SecurityConst.errSSLInternal


# We need to keep these two object references alive: if they get GC'd while
# in use then SecureTransport could attempt to call a function that is in freed
# memory. That would be...uh...bad. Yeah, that's the word. Bad.
_read_callback_pointer = Security.SSLReadFunc(_read_callback)
_write_callback_pointer = Security.SSLWriteFunc(_write_callback)

class WrappedSocket(object):
    """
    API-compatibility wrapper for Python's OpenSSL wrapped socket object.

    Note: _makefile_refs, _drop(), and _reuse() are needed for the garbage
    collector of PyPy.
    """

    def __init__(self, socket):
        self.socket = socket
        self.context = None
        self._makefile_refs = 0
        self._closed = False
        self._exception = None
        self._keychain = None
        self._keychain_dir = None
        self._client_cert_chain = None

        # We save off the previously-configured timeout and then set it to
        # zero. This is done because we use select and friends to handle the
        # timeouts, but if we leave the timeout set on the lower socket then
        # Python will "kindly" call select on that socket again for us. Avoid
        # that by forcing the timeout to zero.
        self._timeout = self.socket.gettimeout()
        self.socket.settimeout(0)

    @contextlib.contextmanager
    def _raise_on_error(self):
        """
        A context manager that can be used to wrap calls that do I/O from
        SecureTransport. If any of the I/O callbacks hit an exception, this
        context manager will correctly propagate the exception after the fact.
        This avoids silently swallowing those exceptions.

        It also correctly forces the socket closed.
        """
        self._exception = None

        # We explicitly don't catch around this yield because in the unlikely
        # event that an exception was hit in the block we don't want to swallow
        # it.
        yield
        if self._exception is not None:
            exception, self._exception = self._exception, None
            self.close()
            raise exception

    def _set_ciphers(self):
        """
        Sets up the allowed ciphers. By default this matches the set in
        util.ssl_.DEFAULT_CIPHERS, at least as supported by macOS. This is done
        custom and doesn't allow changing at this time, mostly because parsing
        OpenSSL cipher strings is going to be a freaking nightmare.
        """
        ciphers = (Security.SSLCipherSuite * len(CIPHER_SUITES))(*CIPHER_SUITES)
        result = Security.SSLSetEnabledCiphers(
            self.context, ciphers, len(CIPHER_SUITES)
        )
        _assert_no_error(result)

    def _set_alpn_protocols(self, protocols):
        """
        Sets up the ALPN protocols on the context.
        """
        if not protocols:
            return
        protocols_arr = _create_cfstring_array(protocols)
        try:
            result = Security.SSLSetALPNProtocols(self.context, protocols_arr)
            _assert_no_error(result)
        finally:
            CoreFoundation.CFRelease(protocols_arr)

    def _custom_validate(self, verify, trust_bundle):
        """
        Called when we have set custom validation. We do this in two cases:
        first, when cert validation is entirely disabled; and second, when
        using a custom trust DB.
        Raises an SSLError if the connection is not trusted.
        """
        # If we disabled cert validation, just say: cool.
        if not verify:
            return

        successes = (
            SecurityConst.kSecTrustResultUnspecified,
            SecurityConst.kSecTrustResultProceed,
        )
        try:
            trust_result = self._evaluate_trust(trust_bundle)
            if trust_result in successes:
                return
            reason = "error code: %d" % (trust_result,)
        except Exception as e:
            # Do not trust on error
            reason = "exception: %r" % (e,)

        # SecureTransport does not send an alert, nor does it shut down the
        # connection.
        rec = _build_tls_unknown_ca_alert(self.version())
        self.socket.sendall(rec)
        # close the connection immediately
        # l_onoff = 1, activate linger
        # l_linger = 0, linger for 0 seconds
        opts = struct.pack("ii", 1, 0)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, opts)
        self.close()
        raise ssl.SSLError("certificate verify failed, %s" % reason)

    def _evaluate_trust(self, trust_bundle):
        # We want data in memory, so load it up.
        if os.path.isfile(trust_bundle):
            with open(trust_bundle, "rb") as f:
                trust_bundle = f.read()

        cert_array = None
        trust = Security.SecTrustRef()

        try:
            # Get a CFArray that contains the certs we want.
            cert_array = _cert_array_from_pem(trust_bundle)

            # Ok, now the hard part. We want to get the SecTrustRef that ST has
            # created for this connection, shove our CAs into it, tell ST to
            # ignore everything else it knows, and then ask if it can build a
            # chain. This is a buuuunch of code.
            result = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
            _assert_no_error(result)
            if not trust:
                raise ssl.SSLError("Failed to copy trust reference")

            result = Security.SecTrustSetAnchorCertificates(trust, cert_array)
            _assert_no_error(result)

            result = Security.SecTrustSetAnchorCertificatesOnly(trust, True)
            _assert_no_error(result)

            trust_result = Security.SecTrustResultType()
            result = Security.SecTrustEvaluate(trust, ctypes.byref(trust_result))
            _assert_no_error(result)
        finally:
            if trust:
                CoreFoundation.CFRelease(trust)

            if cert_array is not None:
                CoreFoundation.CFRelease(cert_array)

        return trust_result.value

    def handshake(
        self,
        server_hostname,
        verify,
        trust_bundle,
        min_version,
        max_version,
        client_cert,
        client_key,
        client_key_passphrase,
        alpn_protocols,
    ):
        """
        Actually performs the TLS handshake. This is run automatically by
        wrapped socket, and shouldn't be needed in user code.
        """
        # First, we do the initial bits of connection setup. We need to create
        # a context, set its I/O funcs, and set the connection reference.
        self.context = Security.SSLCreateContext(
            None, SecurityConst.kSSLClientSide, SecurityConst.kSSLStreamType
        )
        result = Security.SSLSetIOFuncs(
            self.context, _read_callback_pointer, _write_callback_pointer
        )
        _assert_no_error(result)

        # Here we need to compute the handle to use. We do this by taking the
        # id of self modulo 2**31 - 1. If this is already in the dictionary, we
        # just keep incrementing by one until we find a free space.
        with _connection_ref_lock:
            handle = id(self) % 2147483647
            while handle in _connection_refs:
                handle = (handle + 1) % 2147483647
            _connection_refs[handle] = self

        result = Security.SSLSetConnection(self.context, handle)
        _assert_no_error(result)

        # If we have a server hostname, we should set that too.
        if server_hostname:
            if not isinstance(server_hostname, bytes):
                server_hostname = server_hostname.encode("utf-8")

            result = Security.SSLSetPeerDomainName(
                self.context, server_hostname, len(server_hostname)
            )
            _assert_no_error(result)

        # Setup the ciphers.
        self._set_ciphers()

        # Setup the ALPN protocols.
        self._set_alpn_protocols(alpn_protocols)

        # Set the minimum and maximum TLS versions.
        result = Security.SSLSetProtocolVersionMin(self.context, min_version)
        _assert_no_error(result)

        result = Security.SSLSetProtocolVersionMax(self.context, max_version)
        _assert_no_error(result)

        # If there's a trust DB, we need to use it. We do that by telling
        # SecureTransport to break on server auth. We also do that if we don't
        # want to validate the certs at all: we just won't actually do any
        # authing in that case.
        if not verify or trust_bundle is not None:
            result = Security.SSLSetSessionOption(
                self.context, SecurityConst.kSSLSessionOptionBreakOnServerAuth, True
            )
            _assert_no_error(result)

        # If there's a client cert, we need to use it.
        if client_cert:
            self._keychain, self._keychain_dir = _temporary_keychain()
            self._client_cert_chain = _load_client_cert_chain(
                self._keychain, client_cert, client_key
            )
            result = Security.SSLSetCertificate(self.context, self._client_cert_chain)
            _assert_no_error(result)

        while True:
            with self._raise_on_error():
                result = Security.SSLHandshake(self.context)

                if result == SecurityConst.errSSLWouldBlock:
                    raise socket.timeout("handshake timed out")
                elif result == SecurityConst.errSSLServerAuthCompleted:
                    self._custom_validate(verify, trust_bundle)
                    continue
                else:
                    _assert_no_error(result)
                    break

    def fileno(self):
        return self.socket.fileno()

    # Copy-pasted from Python 3.5 source code
    def _decref_socketios(self):
        if self._makefile_refs > 0:
            self._makefile_refs -= 1
        if self._closed:
            self.close()

    def recv(self, bufsiz):
        buffer = ctypes.create_string_buffer(bufsiz)
        bytes_read = self.recv_into(buffer, bufsiz)
        data = buffer[:bytes_read]
        return data

    def recv_into(self, buffer, nbytes=None):
        # Read short on EOF.
        if self._closed:
            return 0

        if nbytes is None:
            nbytes = len(buffer)

        buffer = (ctypes.c_char * nbytes).from_buffer(buffer)
        processed_bytes = ctypes.c_size_t(0)

        with self._raise_on_error():
            result = Security.SSLRead(
                self.context, buffer, nbytes, ctypes.byref(processed_bytes)
            )

        # There are some result codes that we want to treat as "not always
        # errors". Specifically, those are errSSLWouldBlock,
        # errSSLClosedGraceful, and errSSLClosedNoNotify.
        if result == SecurityConst.errSSLWouldBlock:
            # If we didn't process any bytes, then this was just a time out.
            # However, we can get errSSLWouldBlock in situations when we *did*
            # read some data, and in those cases we should just read "short"
            # and return.
            if processed_bytes.value == 0:
                # Timed out, no data read.
                raise socket.timeout("recv timed out")
        elif result in (
            SecurityConst.errSSLClosedGraceful,
            SecurityConst.errSSLClosedNoNotify,
        ):
            # The remote peer has closed this connection. We should do so as
            # well. Note that we don't actually return here because in
            # principle this could actually be fired along with return data.
            # It's unlikely though.
            self.close()
        else:
            _assert_no_error(result)

        # Ok, we read and probably succeeded. We should return whatever data
        # was actually read.
        return processed_bytes.value

    def settimeout(self, timeout):
        self._timeout = timeout

    def gettimeout(self):
        return self._timeout

    def send(self, data):
        processed_bytes = ctypes.c_size_t(0)

        with self._raise_on_error():
            result = Security.SSLWrite(
                self.context, data, len(data), ctypes.byref(processed_bytes)
            )

        if result == SecurityConst.errSSLWouldBlock and processed_bytes.value == 0:
            # Timed out
            raise socket.timeout("send timed out")
        else:
            _assert_no_error(result)

        # We sent, and probably succeeded. Tell them how much we sent.
        return processed_bytes.value

    def sendall(self, data):
        total_sent = 0
        while total_sent < len(data):
            sent = self.send(data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE])
            total_sent += sent

    def shutdown(self):
        with self._raise_on_error():
            Security.SSLClose(self.context)

    def close(self):
        # TODO: should I do clean shutdown here? Do I have to?
        if self._makefile_refs < 1:
            self._closed = True
            if self.context:
                CoreFoundation.CFRelease(self.context)
                self.context = None
            if self._client_cert_chain:
                CoreFoundation.CFRelease(self._client_cert_chain)
                self._client_cert_chain = None
            if self._keychain:
                Security.SecKeychainDelete(self._keychain)
                CoreFoundation.CFRelease(self._keychain)
                shutil.rmtree(self._keychain_dir)
                self._keychain = self._keychain_dir = None
            return self.socket.close()
        else:
            self._makefile_refs -= 1

    def getpeercert(self, binary_form=False):
        # Urgh, annoying.
        #
        # Here's how we do this:
        #
        # 1. Call SSLCopyPeerTrust to get hold of the trust object for this
        #    connection.
        # 2. Call SecTrustGetCertificateAtIndex for index 0 to get the leaf.
        # 3. To get the CN, call SecCertificateCopyCommonName and process that
        #    string so that it's of the appropriate type.
        # 4. To get the SAN, we need to do something a bit more complex:
        #    a. Call SecCertificateCopyValues to get the data, requesting
        #       kSecOIDSubjectAltName.
        #    b. Mess about with this dictionary to try to get the SANs out.
        #
        # This is gross. Really gross. It's going to be a few hundred LoC extra
        # just to repeat something that SecureTransport can *already do*. So my
        # operating assumption at this time is that what we want to do is
        # instead to just flag to urllib3 that it shouldn't do its own hostname
        # validation when using SecureTransport.
        if not binary_form:
            raise ValueError("SecureTransport only supports dumping binary certs")
        trust = Security.SecTrustRef()
        certdata = None
        der_bytes = None

        try:
            # Grab the trust store.
            result = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
            _assert_no_error(result)
            if not trust:
                # Probably we haven't done the handshake yet. No biggie.
                return None

            cert_count = Security.SecTrustGetCertificateCount(trust)
            if not cert_count:
                # Also a case that might happen if we haven't handshaked.
                # Handshook? Handshaken?
                return None

            leaf = Security.SecTrustGetCertificateAtIndex(trust, 0)
            assert leaf

            # Ok, now we want the DER bytes.
            certdata = Security.SecCertificateCopyData(leaf)
            assert certdata

            data_length = CoreFoundation.CFDataGetLength(certdata)
            data_buffer = CoreFoundation.CFDataGetBytePtr(certdata)
            der_bytes = ctypes.string_at(data_buffer, data_length)
        finally:
            if certdata:
                CoreFoundation.CFRelease(certdata)
            if trust:
                CoreFoundation.CFRelease(trust)

        return der_bytes

    def version(self):
        protocol = Security.SSLProtocol()
        result = Security.SSLGetNegotiatedProtocolVersion(
            self.context, ctypes.byref(protocol)
        )
        _assert_no_error(result)
        if protocol.value == SecurityConst.kTLSProtocol13:
            raise ssl.SSLError("SecureTransport does not support TLS 1.3")
        elif protocol.value == SecurityConst.kTLSProtocol12:
            return "TLSv1.2"
        elif protocol.value == SecurityConst.kTLSProtocol11:
            return "TLSv1.1"
        elif protocol.value == SecurityConst.kTLSProtocol1:
            return "TLSv1"
        elif protocol.value == SecurityConst.kSSLProtocol3:
            return "SSLv3"
        elif protocol.value == SecurityConst.kSSLProtocol2:
            return "SSLv2"
        else:
            raise ssl.SSLError("Unknown TLS version: %r" % protocol)

    def _reuse(self):
        self._makefile_refs += 1

    def _drop(self):
        if self._makefile_refs < 1:
            self.close()
        else:
            self._makefile_refs -= 1


if _fileobject:  # Platform-specific: Python 2

    def makefile(self, mode, bufsize=-1):
        self._makefile_refs += 1
        return _fileobject(self, mode, bufsize, close=True)


else:  # Platform-specific: Python 3

    def makefile(self, mode="r", buffering=None, *args, **kwargs):
        # We disable buffering with SecureTransport because it conflicts with
        # the buffering that ST does internally (see issue #1153 for more).
        buffering = 0
        return backport_makefile(self, mode, buffering, *args, **kwargs)


WrappedSocket.makefile = makefile

class SecureTransportContext(object):
    """
    I am a wrapper class for the SecureTransport library, to translate the
    interface of the standard library ``SSLContext`` object to calls into
    SecureTransport.
    """

    def __init__(self, protocol):
        self._min_version, self._max_version = _protocol_to_min_max[protocol]
        self._options = 0
        self._verify = False
        self._trust_bundle = None
        self._client_cert = None
        self._client_key = None
        self._client_key_passphrase = None
        self._alpn_protocols = None

    @property
    def check_hostname(self):
        """
        SecureTransport cannot have its hostname checking disabled. For more,
        see the comment on getpeercert() in this file.
        """
        return True

    @check_hostname.setter
    def check_hostname(self, value):
        """
        SecureTransport cannot have its hostname checking disabled. For more,
        see the comment on getpeercert() in this file.
        """
        pass

    @property
    def options(self):
        # TODO: Well, crap.
        #
        # So this is the bit of the code that is the most likely to cause us
        # trouble. Essentially we need to enumerate all of the SSL options that
        # users might want to use and try to see if we can sensibly translate
        # them, or whether we should just ignore them.
        return self._options

    @options.setter
    def options(self, value):
        # TODO: Update in line with above.
        self._options = value

    @property
    def verify_mode(self):
        return ssl.CERT_REQUIRED if self._verify else ssl.CERT_NONE

    @verify_mode.setter
    def verify_mode(self, value):
        self._verify = True if value == ssl.CERT_REQUIRED else False

    def set_default_verify_paths(self):
        # So, this has to do something a bit weird. Specifically, what it does
        # is nothing.
        #
        # This means that, if we had previously had load_verify_locations
        # called, this does not undo that. We need to do that because it turns
        # out that the rest of the urllib3 code will attempt to load the
        # default verify paths if it hasn't been told about any paths, even if
        # the context itself was configured sometime earlier. We resolve that
        # by just ignoring it.
        pass

    def load_default_certs(self):
        return self.set_default_verify_paths()

    def set_ciphers(self, ciphers):
        # For now, we just require the default cipher string.
        if ciphers != util.ssl_.DEFAULT_CIPHERS:
            raise ValueError("SecureTransport doesn't support custom cipher strings")

    def load_verify_locations(self, cafile=None, capath=None, cadata=None):
        # OK, we only really support cadata and cafile.
        if capath is not None:
            raise ValueError("SecureTransport does not support cert directories")

        # Raise if cafile does not exist.
        if cafile is not None:
            with open(cafile):
                pass

        self._trust_bundle = cafile or cadata

    def load_cert_chain(self, certfile, keyfile=None, password=None):
        self._client_cert = certfile
        self._client_key = keyfile
        self._client_cert_passphrase = password

    def set_alpn_protocols(self, protocols):
        """
        Sets the ALPN protocols that will later be set on the context.

        Raises a NotImplementedError if ALPN is not supported.
        """
        if not hasattr(Security, "SSLSetALPNProtocols"):
            raise NotImplementedError(
                "SecureTransport supports ALPN only in macOS 10.12+"
            )
        self._alpn_protocols = [six.ensure_binary(p) for p in protocols]

    def wrap_socket(
        self,
        sock,
        server_side=False,
        do_handshake_on_connect=True,
        suppress_ragged_eofs=True,
        server_hostname=None,
    ):
        # So, what do we do here? Firstly, we assert some properties. This is a
        # stripped down shim, so there is some functionality we don't support.
        # See PEP 543 for the real deal.
        assert not server_side
        assert do_handshake_on_connect
        assert suppress_ragged_eofs

        # Ok, we're good to go. Now we want to create the wrapped socket object
        # and store it in the appropriate place.
        wrapped_socket = WrappedSocket(sock)

        # Now we can handshake
        wrapped_socket.handshake(
            server_hostname,
            self._verify,
            self._trust_bundle,
            self._min_version,
            self._max_version,
            self._client_cert,
            self._client_key,
            self._client_key_passphrase,
            self._alpn_protocols,
        )
        return wrapped_socket

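
# --- Editor's example (not part of the original file) ---
# A sketch of driving SecureTransportContext directly, mirroring the stdlib
# SSLContext API it shims. The import path assumes the standalone urllib3
# distribution rather than this vendored copy; the host and CA bundle path
# are illustrative.
import socket
import ssl

from urllib3.contrib.securetransport import SecureTransportContext

context = SecureTransportContext(ssl.PROTOCOL_TLS)
context.verify_mode = ssl.CERT_REQUIRED
context.load_verify_locations(cafile="/path/to/ca-bundle.pem")  # illustrative

raw = socket.create_connection(("example.com", 443))
tls = context.wrap_socket(raw, server_hostname="example.com")
print(tls.version())  # e.g. "TLSv1.2"
tls.close()
# --- end example ---
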
@@ -1,216 +0,0 @@
# -*- coding: utf-8 -*-
"""
This module contains provisional support for SOCKS proxies from within
urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
SOCKS5. To enable its functionality, either install PySocks or install this
module with the ``socks`` extra.

The SOCKS implementation supports the full range of urllib3 features. It also
supports the following SOCKS features:

- SOCKS4A (``proxy_url='socks4a://...'``)
- SOCKS4 (``proxy_url='socks4://...'``)
- SOCKS5 with remote DNS (``proxy_url='socks5h://...'``)
- SOCKS5 with local DNS (``proxy_url='socks5://...'``)
- Usernames and passwords for the SOCKS proxy

.. note::
   It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in
   your ``proxy_url`` to ensure that DNS resolution is done from the remote
   server instead of client-side when connecting to a domain name.

SOCKS4 supports IPv4 and domain names with the SOCKS4A extension. SOCKS5
supports IPv4, IPv6, and domain names.

When connecting to a SOCKS4 proxy the ``username`` portion of the ``proxy_url``
will be sent as the ``userid`` section of the SOCKS request:

.. code-block:: python

    proxy_url="socks4a://<userid>@proxy-host"

When connecting to a SOCKS5 proxy the ``username`` and ``password`` portion
of the ``proxy_url`` will be sent as the username/password to authenticate
with the proxy:

.. code-block:: python

    proxy_url="socks5h://<username>:<password>@proxy-host"

"""
from __future__ import absolute_import

try:
    import socks
except ImportError:
    import warnings

    from ..exceptions import DependencyWarning

    warnings.warn(
        (
            "SOCKS support in urllib3 requires the installation of optional "
            "dependencies: specifically, PySocks. For more information, see "
            "https://urllib3.readthedocs.io/en/1.26.x/contrib.html#socks-proxies"
        ),
        DependencyWarning,
    )
    raise

from socket import error as SocketError
from socket import timeout as SocketTimeout

from ..connection import HTTPConnection, HTTPSConnection
from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool
from ..exceptions import ConnectTimeoutError, NewConnectionError
from ..poolmanager import PoolManager
from ..util.url import parse_url

try:
    import ssl
except ImportError:
    ssl = None


class SOCKSConnection(HTTPConnection):
    """
    A plain-text HTTP connection that connects via a SOCKS proxy.
    """

    def __init__(self, *args, **kwargs):
        self._socks_options = kwargs.pop("_socks_options")
        super(SOCKSConnection, self).__init__(*args, **kwargs)

    def _new_conn(self):
        """
        Establish a new connection via the SOCKS proxy.
        """
        extra_kw = {}
        if self.source_address:
            extra_kw["source_address"] = self.source_address

        if self.socket_options:
            extra_kw["socket_options"] = self.socket_options

        try:
            conn = socks.create_connection(
                (self.host, self.port),
                proxy_type=self._socks_options["socks_version"],
                proxy_addr=self._socks_options["proxy_host"],
                proxy_port=self._socks_options["proxy_port"],
                proxy_username=self._socks_options["username"],
                proxy_password=self._socks_options["password"],
                proxy_rdns=self._socks_options["rdns"],
                timeout=self.timeout,
                **extra_kw
            )

        except SocketTimeout:
            raise ConnectTimeoutError(
                self,
                "Connection to %s timed out. (connect timeout=%s)"
                % (self.host, self.timeout),
            )

        except socks.ProxyError as e:
            # This is fragile as hell, but it seems to be the only way to raise
            # useful errors here.
            if e.socket_err:
                error = e.socket_err
                if isinstance(error, SocketTimeout):
                    raise ConnectTimeoutError(
                        self,
                        "Connection to %s timed out. (connect timeout=%s)"
                        % (self.host, self.timeout),
                    )
                else:
                    raise NewConnectionError(
                        self, "Failed to establish a new connection: %s" % error
                    )
            else:
                raise NewConnectionError(
                    self, "Failed to establish a new connection: %s" % e
                )

        except SocketError as e:  # Defensive: PySocks should catch all these.
            raise NewConnectionError(
                self, "Failed to establish a new connection: %s" % e
            )

        return conn

# We don't need to duplicate the Verified/Unverified distinction from
# urllib3/connection.py here because the HTTPSConnection will already have been
# correctly set to either the Verified or Unverified form by that module. This
# means the SOCKSHTTPSConnection will automatically be the correct type.
class SOCKSHTTPSConnection(SOCKSConnection, HTTPSConnection):
    pass


class SOCKSHTTPConnectionPool(HTTPConnectionPool):
    ConnectionCls = SOCKSConnection


class SOCKSHTTPSConnectionPool(HTTPSConnectionPool):
    ConnectionCls = SOCKSHTTPSConnection


class SOCKSProxyManager(PoolManager):
    """
    A version of the urllib3 ProxyManager that routes connections via the
    defined SOCKS proxy.
    """

    pool_classes_by_scheme = {
        "http": SOCKSHTTPConnectionPool,
        "https": SOCKSHTTPSConnectionPool,
    }

    def __init__(
        self,
        proxy_url,
        username=None,
        password=None,
        num_pools=10,
        headers=None,
        **connection_pool_kw
    ):
        parsed = parse_url(proxy_url)

        if username is None and password is None and parsed.auth is not None:
            split = parsed.auth.split(":")
            if len(split) == 2:
                username, password = split
        if parsed.scheme == "socks5":
            socks_version = socks.PROXY_TYPE_SOCKS5
            rdns = False
        elif parsed.scheme == "socks5h":
            socks_version = socks.PROXY_TYPE_SOCKS5
            rdns = True
        elif parsed.scheme == "socks4":
            socks_version = socks.PROXY_TYPE_SOCKS4
            rdns = False
        elif parsed.scheme == "socks4a":
            socks_version = socks.PROXY_TYPE_SOCKS4
            rdns = True
        else:
            raise ValueError("Unable to determine SOCKS version from %s" % proxy_url)

        self.proxy_url = proxy_url

        socks_options = {
            "socks_version": socks_version,
            "proxy_host": parsed.host,
            "proxy_port": parsed.port,
            "username": username,
            "password": password,
            "rdns": rdns,
        }
        connection_pool_kw["_socks_options"] = socks_options

        super(SOCKSProxyManager, self).__init__(
            num_pools, headers, **connection_pool_kw
        )

        self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme

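
# --- Editor's example (not part of the original file) ---
# Minimal usage sketch for SOCKSProxyManager. The proxy address and target URL
# are illustrative; socks5h:// keeps DNS resolution on the proxy, as the
# module docstring recommends.
from urllib3.contrib.socks import SOCKSProxyManager

proxy = SOCKSProxyManager("socks5h://localhost:1080/")
response = proxy.request("GET", "https://example.com/")
print(response.status)
# --- end example ---
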
@@ -1,323 +0,0 @@
from __future__ import absolute_import

from .packages.six.moves.http_client import IncompleteRead as httplib_IncompleteRead

# Base Exceptions


class HTTPError(Exception):
    """Base exception used by this module."""

    pass


class HTTPWarning(Warning):
    """Base warning used by this module."""

    pass


class PoolError(HTTPError):
    """Base exception for errors caused within a pool."""

    def __init__(self, pool, message):
        self.pool = pool
        HTTPError.__init__(self, "%s: %s" % (pool, message))

    def __reduce__(self):
        # For pickling purposes.
        return self.__class__, (None, None)


class RequestError(PoolError):
    """Base exception for PoolErrors that have associated URLs."""

    def __init__(self, pool, url, message):
        self.url = url
        PoolError.__init__(self, pool, message)

    def __reduce__(self):
        # For pickling purposes.
        return self.__class__, (None, self.url, None)


class SSLError(HTTPError):
    """Raised when SSL certificate fails in an HTTPS connection."""

    pass


class ProxyError(HTTPError):
    """Raised when the connection to a proxy fails."""

    def __init__(self, message, error, *args):
        super(ProxyError, self).__init__(message, error, *args)
        self.original_error = error


class DecodeError(HTTPError):
    """Raised when automatic decoding based on Content-Type fails."""

    pass


class ProtocolError(HTTPError):
    """Raised when something unexpected happens mid-request/response."""

    pass


#: Renamed to ProtocolError but aliased for backwards compatibility.
ConnectionError = ProtocolError


# Leaf Exceptions


class MaxRetryError(RequestError):
    """Raised when the maximum number of retries is exceeded.

    :param pool: The connection pool
    :type pool: :class:`~urllib3.connectionpool.HTTPConnectionPool`
    :param string url: The requested Url
    :param exceptions.Exception reason: The underlying error

    """

    def __init__(self, pool, url, reason=None):
        self.reason = reason

        message = "Max retries exceeded with url: %s (Caused by %r)" % (url, reason)

        RequestError.__init__(self, pool, url, message)


class HostChangedError(RequestError):
    """Raised when an existing pool gets a request for a foreign host."""

    def __init__(self, pool, url, retries=3):
        message = "Tried to open a foreign host with url: %s" % url
        RequestError.__init__(self, pool, url, message)
        self.retries = retries


class TimeoutStateError(HTTPError):
    """Raised when passing an invalid state to a timeout"""

    pass


class TimeoutError(HTTPError):
    """Raised when a socket timeout error occurs.

    Catching this error will catch both :exc:`ReadTimeoutErrors
    <ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`.
    """

    pass


class ReadTimeoutError(TimeoutError, RequestError):
    """Raised when a socket timeout occurs while receiving data from a server"""

    pass


# This timeout error does not have a URL attached and needs to inherit from the
# base HTTPError
class ConnectTimeoutError(TimeoutError):
    """Raised when a socket timeout occurs while connecting to a server"""

    pass


class NewConnectionError(ConnectTimeoutError, PoolError):
    """Raised when we fail to establish a new connection. Usually ECONNREFUSED."""

    pass


class EmptyPoolError(PoolError):
    """Raised when a pool runs out of connections and no more are allowed."""

    pass


class ClosedPoolError(PoolError):
    """Raised when a request enters a pool after the pool has been closed."""

    pass


class LocationValueError(ValueError, HTTPError):
    """Raised when there is something wrong with a given URL input."""

    pass


class LocationParseError(LocationValueError):
    """Raised when get_host or similar fails to parse the URL input."""

    def __init__(self, location):
        message = "Failed to parse: %s" % location
        HTTPError.__init__(self, message)

        self.location = location


class URLSchemeUnknown(LocationValueError):
    """Raised when a URL input has an unsupported scheme."""

    def __init__(self, scheme):
        message = "Not supported URL scheme %s" % scheme
        super(URLSchemeUnknown, self).__init__(message)

        self.scheme = scheme


class ResponseError(HTTPError):
    """Used as a container for an error reason supplied in a MaxRetryError."""

    GENERIC_ERROR = "too many error responses"
    SPECIFIC_ERROR = "too many {status_code} error responses"


class SecurityWarning(HTTPWarning):
    """Warned when performing security reducing actions"""

    pass


class SubjectAltNameWarning(SecurityWarning):
    """Warned when connecting to a host with a certificate missing a SAN."""

    pass


class InsecureRequestWarning(SecurityWarning):
    """Warned when making an unverified HTTPS request."""

    pass


class SystemTimeWarning(SecurityWarning):
    """Warned when system time is suspected to be wrong"""

    pass


class InsecurePlatformWarning(SecurityWarning):
    """Warned when certain TLS/SSL configuration is not available on a platform."""

    pass


class SNIMissingWarning(HTTPWarning):
    """Warned when making a HTTPS request without SNI available."""

    pass


class DependencyWarning(HTTPWarning):
    """
    Warned when an attempt is made to import a module with missing optional
    dependencies.
    """

    pass


class ResponseNotChunked(ProtocolError, ValueError):
    """Response needs to be chunked in order to read it as chunks."""

    pass


class BodyNotHttplibCompatible(HTTPError):
    """
    Body should be :class:`http.client.HTTPResponse` like
    (have an fp attribute which returns raw chunks) for read_chunked().
    """

    pass


class IncompleteRead(HTTPError, httplib_IncompleteRead):
    """
    Response length doesn't match expected Content-Length

    Subclass of :class:`http.client.IncompleteRead` to allow int value
    for ``partial`` to avoid creating large objects on streamed reads.
    """

    def __init__(self, partial, expected):
        super(IncompleteRead, self).__init__(partial, expected)

    def __repr__(self):
        return "IncompleteRead(%i bytes read, %i more expected)" % (
            self.partial,
            self.expected,
        )


class InvalidChunkLength(HTTPError, httplib_IncompleteRead):
    """Invalid chunk length in a chunked response."""

    def __init__(self, response, length):
        super(InvalidChunkLength, self).__init__(
            response.tell(), response.length_remaining
        )
        self.response = response
        self.length = length

    def __repr__(self):
        return "InvalidChunkLength(got length %r, %i bytes read)" % (
            self.length,
            self.partial,
        )


class InvalidHeader(HTTPError):
    """The header provided was somehow invalid."""

    pass


class ProxySchemeUnknown(AssertionError, URLSchemeUnknown):
    """ProxyManager does not support the supplied scheme"""

    # TODO(t-8ch): Stop inheriting from AssertionError in v2.0.

    def __init__(self, scheme):
        # 'localhost' is here because our URL parser parses
        # localhost:8080 -> scheme=localhost, remove if we fix this.
        if scheme == "localhost":
            scheme = None
        if scheme is None:
            message = "Proxy URL had no scheme, should start with http:// or https://"
        else:
            message = (
                "Proxy URL had unsupported scheme %s, should use http:// or https://"
                % scheme
            )
        super(ProxySchemeUnknown, self).__init__(message)


class ProxySchemeUnsupported(ValueError):
    """Fetching HTTPS resources through HTTPS proxies is unsupported"""

    pass


class HeaderParsingError(HTTPError):
    """Raised by assert_header_parsing, but we convert it to a log.warning statement."""

    def __init__(self, defects, unparsed_data):
        message = "%s, unparsed data: %r" % (defects or "Unknown", unparsed_data)
        super(HeaderParsingError, self).__init__(message)


class UnrewindableBodyError(HTTPError):
    """urllib3 encountered an error when trying to rewind a body"""

    pass

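
# --- Editor's example (not part of the original file) ---
# Sketch of catching the hierarchy above: TimeoutError covers both the connect
# and read timeout leaves, and MaxRetryError wraps the underlying failure in
# its .reason attribute. The URL is illustrative.
import urllib3
from urllib3.exceptions import MaxRetryError, TimeoutError

http = urllib3.PoolManager(retries=urllib3.Retry(total=1))
try:
    http.request("GET", "https://example.invalid/", timeout=1.0)
except TimeoutError:
    print("connect or read timed out")
except MaxRetryError as exc:
    print("gave up after retries, caused by:", exc.reason)
# --- end example ---
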
@@ -1,274 +0,0 @@
from __future__ import absolute_import

import email.utils
import mimetypes
import re

from .packages import six


def guess_content_type(filename, default="application/octet-stream"):
    """
    Guess the "Content-Type" of a file.

    :param filename:
        The filename to guess the "Content-Type" of using :mod:`mimetypes`.
    :param default:
        If no "Content-Type" can be guessed, default to `default`.
    """
    if filename:
        return mimetypes.guess_type(filename)[0] or default
    return default


def format_header_param_rfc2231(name, value):
    """
    Helper function to format and quote a single header parameter using the
    strategy defined in RFC 2231.

    Particularly useful for header parameters which might contain
    non-ASCII values, like file names. This follows
    `RFC 2388 Section 4.4 <https://tools.ietf.org/html/rfc2388#section-4.4>`_.

    :param name:
        The name of the parameter, a string expected to be ASCII only.
    :param value:
        The value of the parameter, provided as ``bytes`` or ``str``.
    :ret:
        An RFC-2231-formatted unicode string.
    """
    if isinstance(value, six.binary_type):
        value = value.decode("utf-8")

    if not any(ch in value for ch in '"\\\r\n'):
        result = u'%s="%s"' % (name, value)
        try:
            result.encode("ascii")
        except (UnicodeEncodeError, UnicodeDecodeError):
            pass
        else:
            return result

    if six.PY2:  # Python 2:
        value = value.encode("utf-8")

    # encode_rfc2231 accepts an encoded string and returns an ascii-encoded
    # string in Python 2 but accepts and returns unicode strings in Python 3
    value = email.utils.encode_rfc2231(value, "utf-8")
    value = "%s*=%s" % (name, value)

    if six.PY2:  # Python 2:
        value = value.decode("utf-8")

    return value


_HTML5_REPLACEMENTS = {
    u"\u0022": u"%22",
    # Replace "\" with "\\".
    u"\u005C": u"\u005C\u005C",
}

# All control characters from 0x00 to 0x1F *except* 0x1B.
_HTML5_REPLACEMENTS.update(
    {
        six.unichr(cc): u"%{:02X}".format(cc)
        for cc in range(0x00, 0x1F + 1)
        if cc not in (0x1B,)
    }
)


def _replace_multiple(value, needles_and_replacements):
    def replacer(match):
        return needles_and_replacements[match.group(0)]

    pattern = re.compile(
        r"|".join([re.escape(needle) for needle in needles_and_replacements.keys()])
    )

    result = pattern.sub(replacer, value)

    return result


def format_header_param_html5(name, value):
    """
    Helper function to format and quote a single header parameter using the
    HTML5 strategy.

    Particularly useful for header parameters which might contain
    non-ASCII values, like file names. This follows the `HTML5 Working Draft
    Section 4.10.22.7`_ and matches the behavior of curl and modern browsers.

    .. _HTML5 Working Draft Section 4.10.22.7:
        https://w3c.github.io/html/sec-forms.html#multipart-form-data

    :param name:
        The name of the parameter, a string expected to be ASCII only.
    :param value:
        The value of the parameter, provided as ``bytes`` or ``str``.
    :ret:
        A unicode string, stripped of troublesome characters.
    """
    if isinstance(value, six.binary_type):
        value = value.decode("utf-8")

    value = _replace_multiple(value, _HTML5_REPLACEMENTS)

    return u'%s="%s"' % (name, value)


# For backwards-compatibility.
format_header_param = format_header_param_html5


class RequestField(object):
    """
    A data container for request body parameters.

    :param name:
        The name of this request field. Must be unicode.
    :param data:
        The data/value body.
    :param filename:
        An optional filename of the request field. Must be unicode.
    :param headers:
        An optional dict-like object of headers to initially use for the field.
    :param header_formatter:
        An optional callable that is used to encode and format the headers. By
        default, this is :func:`format_header_param_html5`.
    """

    def __init__(
        self,
        name,
        data,
        filename=None,
        headers=None,
        header_formatter=format_header_param_html5,
    ):
        self._name = name
        self._filename = filename
        self.data = data
        self.headers = {}
        if headers:
            self.headers = dict(headers)
        self.header_formatter = header_formatter

    @classmethod
    def from_tuples(cls, fieldname, value, header_formatter=format_header_param_html5):
        """
        A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters.

        Supports constructing :class:`~urllib3.fields.RequestField` from
        parameter of key/value strings AND key/filetuple. A filetuple is a
        (filename, data, MIME type) tuple where the MIME type is optional.
        For example::

            'foo': 'bar',
            'fakefile': ('foofile.txt', 'contents of foofile'),
            'realfile': ('barfile.txt', open('realfile').read()),
            'typedfile': ('bazfile.bin', open('bazfile').read(), 'image/jpeg'),
            'nonamefile': 'contents of nonamefile field',

        Field names and filenames must be unicode.
        """
        if isinstance(value, tuple):
            if len(value) == 3:
                filename, data, content_type = value
            else:
                filename, data = value
                content_type = guess_content_type(filename)
        else:
            filename = None
            content_type = None
            data = value

        request_param = cls(
            fieldname, data, filename=filename, header_formatter=header_formatter
        )
        request_param.make_multipart(content_type=content_type)

        return request_param

    def _render_part(self, name, value):
        """
        Overridable helper function to format a single header parameter. By
        default, this calls ``self.header_formatter``.

        :param name:
            The name of the parameter, a string expected to be ASCII only.
        :param value:
            The value of the parameter, provided as a unicode string.
        """

        return self.header_formatter(name, value)

    def _render_parts(self, header_parts):
        """
        Helper function to format and quote a single header.

        Useful for single headers that are composed of multiple items. E.g.,
        'Content-Disposition' fields.

        :param header_parts:
            A sequence of (k, v) tuples or a :class:`dict` of (k, v) to format
            as `k1="v1"; k2="v2"; ...`.
        """
        parts = []
        iterable = header_parts
        if isinstance(header_parts, dict):
            iterable = header_parts.items()

        for name, value in iterable:
            if value is not None:
                parts.append(self._render_part(name, value))

        return u"; ".join(parts)

    def render_headers(self):
        """
        Renders the headers for this request field.
        """
        lines = []

        sort_keys = ["Content-Disposition", "Content-Type", "Content-Location"]
        for sort_key in sort_keys:
            if self.headers.get(sort_key, False):
                lines.append(u"%s: %s" % (sort_key, self.headers[sort_key]))

        for header_name, header_value in self.headers.items():
            if header_name not in sort_keys:
                if header_value:
                    lines.append(u"%s: %s" % (header_name, header_value))

        lines.append(u"\r\n")
        return u"\r\n".join(lines)

    def make_multipart(
        self, content_disposition=None, content_type=None, content_location=None
    ):
        """
        Makes this request field into a multipart request field.

        This method overrides "Content-Disposition", "Content-Type" and
        "Content-Location" headers to the request parameter.

        :param content_type:
            The 'Content-Type' of the request body.
        :param content_location:
            The 'Content-Location' of the request body.

        """
        self.headers["Content-Disposition"] = content_disposition or u"form-data"
        self.headers["Content-Disposition"] += u"; ".join(
            [
                u"",
                self._render_parts(
                    ((u"name", self._name), (u"filename", self._filename))
                ),
            ]
        )
        self.headers["Content-Type"] = content_type
        self.headers["Content-Location"] = content_location

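
# --- Editor's example (not part of the original file) ---
# Sketch of building one multipart field by hand with the class above; the
# field name and payload are illustrative.
from urllib3.fields import RequestField

field = RequestField(u"attachment", b"hello world", filename=u"hello.txt")
field.make_multipart(content_type=u"text/plain")
print(field.render_headers())
# Content-Disposition: form-data; name="attachment"; filename="hello.txt"
# Content-Type: text/plain
# --- end example ---
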
@@ -1,98 +0,0 @@
from __future__ import absolute_import

import binascii
import codecs
import os
from io import BytesIO

from .fields import RequestField
from .packages import six
from .packages.six import b

writer = codecs.lookup("utf-8")[3]


def choose_boundary():
    """
    Our embarrassingly-simple replacement for mimetools.choose_boundary.
    """
    boundary = binascii.hexlify(os.urandom(16))
    if not six.PY2:
        boundary = boundary.decode("ascii")
    return boundary


def iter_field_objects(fields):
    """
    Iterate over fields.

    Supports list of (k, v) tuples and dicts, and lists of
    :class:`~urllib3.fields.RequestField`.

    """
    if isinstance(fields, dict):
        i = six.iteritems(fields)
    else:
        i = iter(fields)

    for field in i:
        if isinstance(field, RequestField):
            yield field
        else:
            yield RequestField.from_tuples(*field)


def iter_fields(fields):
    """
    .. deprecated:: 1.6

    Iterate over fields.

    The addition of :class:`~urllib3.fields.RequestField` makes this function
    obsolete. Instead, use :func:`iter_field_objects`, which returns
    :class:`~urllib3.fields.RequestField` objects.

    Supports list of (k, v) tuples and dicts.
    """
    if isinstance(fields, dict):
        return ((k, v) for k, v in six.iteritems(fields))

    return ((k, v) for k, v in fields)


def encode_multipart_formdata(fields, boundary=None):
    """
    Encode a dictionary of ``fields`` using the multipart/form-data MIME format.

    :param fields:
        Dictionary of fields or list of (key, :class:`~urllib3.fields.RequestField`).

    :param boundary:
        If not specified, then a random boundary will be generated using
        :func:`urllib3.filepost.choose_boundary`.
    """
    body = BytesIO()
    if boundary is None:
        boundary = choose_boundary()

    for field in iter_field_objects(fields):
        body.write(b("--%s\r\n" % (boundary)))

        writer(body).write(field.render_headers())
        data = field.data

        if isinstance(data, int):
            data = str(data)  # Backwards compatibility

        if isinstance(data, six.text_type):
            writer(body).write(data)
        else:
            body.write(data)

        body.write(b"\r\n")

    body.write(b("--%s--\r\n" % (boundary)))

    content_type = str("multipart/form-data; boundary=%s" % boundary)

    return body.getvalue(), content_type

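
# --- Editor's example (not part of the original file) ---
# Sketch of encoding a form body with encode_multipart_formdata; the field
# values are illustrative, and a fixed boundary is passed only to make the
# output predictable.
from urllib3.filepost import encode_multipart_formdata

body, content_type = encode_multipart_formdata(
    {"name": "Ada", "upload": ("notes.txt", b"file contents", "text/plain")},
    boundary="deadbeef",
)
print(content_type)          # multipart/form-data; boundary=deadbeef
print(body.decode("utf-8"))  # --deadbeef ... --deadbeef--
# --- end example ---
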
@@ -1,5 +0,0 @@
from __future__ import absolute_import

from . import ssl_match_hostname

__all__ = ("ssl_match_hostname",)

@ -1,51 +0,0 @@
# -*- coding: utf-8 -*-
"""
backports.makefile
~~~~~~~~~~~~~~~~~~

Backports the Python 3 ``socket.makefile`` method for use with anything that
wants to create a "fake" socket object.
"""
import io
from socket import SocketIO


def backport_makefile(
    self, mode="r", buffering=None, encoding=None, errors=None, newline=None
):
    """
    Backport of ``socket.makefile`` from Python 3.5.
    """
    if not set(mode) <= {"r", "w", "b"}:
        raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))
    writing = "w" in mode
    reading = "r" in mode or not writing
    assert reading or writing
    binary = "b" in mode
    rawmode = ""
    if reading:
        rawmode += "r"
    if writing:
        rawmode += "w"
    raw = SocketIO(self, rawmode)
    self._makefile_refs += 1
    if buffering is None:
        buffering = -1
    if buffering < 0:
        buffering = io.DEFAULT_BUFFER_SIZE
    if buffering == 0:
        if not binary:
            raise ValueError("unbuffered streams must be binary")
        return raw
    if reading and writing:
        buffer = io.BufferedRWPair(raw, raw, buffering)
    elif reading:
        buffer = io.BufferedReader(raw, buffering)
    else:
        assert writing
        buffer = io.BufferedWriter(raw, buffering)
    if binary:
        return buffer
    text = io.TextIOWrapper(buffer, encoding, errors, newline)
    text.mode = mode
    return text
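
The "fake" socket use case the docstring mentions can be sketched as follows; FakeSocket is hypothetical and implements only the duck-typed methods that SocketIO actually calls, so the snippet runs alongside the backport_makefile definition above:

import io

class FakeSocket(object):
    """Hypothetical in-memory 'socket' exposing just what SocketIO needs."""

    def __init__(self, read_data=b""):
        self._read = io.BytesIO(read_data)
        self.sent = bytearray()
        self._makefile_refs = 0  # incremented by backport_makefile

    def recv_into(self, buf):
        data = self._read.read(len(buf))
        buf[: len(data)] = data
        return len(data)

    def send(self, data):
        self.sent += data
        return len(data)

    def _decref_socketios(self):
        # Called by the SocketIO layer when the file object is closed.
        self._makefile_refs -= 1

    makefile = backport_makefile

sock = FakeSocket(b"HTTP/1.1 200 OK\r\n")
fp = sock.makefile("rb")
print(fp.readline())  # b'HTTP/1.1 200 OK\r\n'
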
File diff suppressed because it is too large
@ -1,24 +0,0 @@
import sys

try:
    # Our match_hostname function is the same as 3.10's, so we only want to
    # import the match_hostname function if it's at least that good.
    # We also fallback on Python 3.10+ because our code doesn't emit
    # deprecation warnings and is the same as Python 3.10 otherwise.
    if sys.version_info < (3, 5) or sys.version_info >= (3, 10):
        raise ImportError("Fallback to vendored code")

    from ssl import CertificateError, match_hostname
except ImportError:
    try:
        # Backport of the function from a pypi module
        from backports.ssl_match_hostname import (  # type: ignore
            CertificateError,
            match_hostname,
        )
    except ImportError:
        # Our vendored copy
        from ._implementation import CertificateError, match_hostname  # type: ignore

# Not needed, but documenting what we provide.
__all__ = ("CertificateError", "match_hostname")
@ -1,160 +0,0 @@
"""The match_hostname() function from Python 3.3.3, essential when using SSL."""

# Note: This file is under the PSF license as the code comes from the python
# stdlib. http://docs.python.org/3/license.html

import re
import sys

# ipaddress has been backported to 2.6+ in pypi. If it is installed on the
# system, use it to handle IPAddress ServerAltnames (this was added in
# python-3.5) otherwise only do DNS matching. This allows
# backports.ssl_match_hostname to continue to be used in Python 2.7.
try:
    import ipaddress
except ImportError:
    ipaddress = None

__version__ = "3.5.0.1"


class CertificateError(ValueError):
    pass


def _dnsname_match(dn, hostname, max_wildcards=1):
    """Matching according to RFC 6125, section 6.4.3

    http://tools.ietf.org/html/rfc6125#section-6.4.3
    """
    pats = []
    if not dn:
        return False

    # Ported from python3-syntax:
    # leftmost, *remainder = dn.split(r'.')
    parts = dn.split(r".")
    leftmost = parts[0]
    remainder = parts[1:]

    wildcards = leftmost.count("*")
    if wildcards > max_wildcards:
        # Issue #17980: avoid denials of service by refusing more
        # than one wildcard per fragment. A survey of established
        # policy among SSL implementations showed it to be a
        # reasonable choice.
        raise CertificateError(
            "too many wildcards in certificate DNS name: " + repr(dn)
        )

    # speed up common case w/o wildcards
    if not wildcards:
        return dn.lower() == hostname.lower()

    # RFC 6125, section 6.4.3, subitem 1.
    # The client SHOULD NOT attempt to match a presented identifier in which
    # the wildcard character comprises a label other than the left-most label.
    if leftmost == "*":
        # When '*' is a fragment by itself, it matches a non-empty dotless
        # fragment.
        pats.append("[^.]+")
    elif leftmost.startswith("xn--") or hostname.startswith("xn--"):
        # RFC 6125, section 6.4.3, subitem 3.
        # The client SHOULD NOT attempt to match a presented identifier
        # where the wildcard character is embedded within an A-label or
        # U-label of an internationalized domain name.
        pats.append(re.escape(leftmost))
    else:
        # Otherwise, '*' matches any dotless string, e.g. www*
        pats.append(re.escape(leftmost).replace(r"\*", "[^.]*"))

    # add the remaining fragments, ignore any wildcards
    for frag in remainder:
        pats.append(re.escape(frag))

    pat = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE)
    return pat.match(hostname)


def _to_unicode(obj):
    if isinstance(obj, str) and sys.version_info < (3,):
        obj = unicode(obj, encoding="ascii", errors="strict")
    return obj


def _ipaddress_match(ipname, host_ip):
    """Exact matching of IP addresses.

    RFC 6125 explicitly doesn't define an algorithm for this
    (section 1.7.2 - "Out of Scope").
    """
    # OpenSSL may add a trailing newline to a subjectAltName's IP address
    # Divergence from upstream: ipaddress can't handle byte str
    ip = ipaddress.ip_address(_to_unicode(ipname).rstrip())
    return ip == host_ip


def match_hostname(cert, hostname):
    """Verify that *cert* (in decoded format as returned by
    SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125
    rules are followed, but IP addresses are not accepted for *hostname*.

    CertificateError is raised on failure. On success, the function
    returns nothing.
    """
    if not cert:
        raise ValueError(
            "empty or no certificate, match_hostname needs a "
            "SSL socket or SSL context with either "
            "CERT_OPTIONAL or CERT_REQUIRED"
        )
    try:
        # Divergence from upstream: ipaddress can't handle byte str
        host_ip = ipaddress.ip_address(_to_unicode(hostname))
    except ValueError:
        # Not an IP address (common case)
        host_ip = None
    except UnicodeError:
        # Divergence from upstream: Have to deal with ipaddress not taking
        # byte strings. addresses should be all ascii, so we consider it not
        # an ipaddress in this case
        host_ip = None
    except AttributeError:
        # Divergence from upstream: Make ipaddress library optional
        if ipaddress is None:
            host_ip = None
        else:
            raise
    dnsnames = []
    san = cert.get("subjectAltName", ())
    for key, value in san:
        if key == "DNS":
            if host_ip is None and _dnsname_match(value, hostname):
                return
            dnsnames.append(value)
        elif key == "IP Address":
            if host_ip is not None and _ipaddress_match(value, host_ip):
                return
            dnsnames.append(value)
    if not dnsnames:
        # The subject is only checked when there is no dNSName entry
        # in subjectAltName
        for sub in cert.get("subject", ()):
            for key, value in sub:
                # XXX according to RFC 2818, the most specific Common Name
                # must be used.
                if key == "commonName":
                    if _dnsname_match(value, hostname):
                        return
                    dnsnames.append(value)
    if len(dnsnames) > 1:
        raise CertificateError(
            "hostname %r "
            "doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames)))
        )
    elif len(dnsnames) == 1:
        raise CertificateError("hostname %r doesn't match %r" % (hostname, dnsnames[0]))
    else:
        raise CertificateError(
            "no appropriate commonName or subjectAltName fields were found"
        )
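
For reference, the public entry point above behaves like this; a minimal sketch with a fabricated certificate dict in the SSLSocket.getpeercert() format and example hostnames:

cert = {
    "subject": ((("commonName", "example.com"),),),
    "subjectAltName": (("DNS", "example.com"), ("DNS", "*.example.com")),
}

match_hostname(cert, "www.example.com")  # matched by the wildcard SAN; returns None

try:
    match_hostname(cert, "example.org")
except CertificateError as exc:
    print(exc)  # hostname 'example.org' doesn't match either of ...
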
@ -1,536 +0,0 @@
from __future__ import absolute_import

import collections
import functools
import logging

from ._collections import RecentlyUsedContainer
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, port_by_scheme
from .exceptions import (
    LocationValueError,
    MaxRetryError,
    ProxySchemeUnknown,
    ProxySchemeUnsupported,
    URLSchemeUnknown,
)
from .packages import six
from .packages.six.moves.urllib.parse import urljoin
from .request import RequestMethods
from .util.proxy import connection_requires_http_tunnel
from .util.retry import Retry
from .util.url import parse_url

__all__ = ["PoolManager", "ProxyManager", "proxy_from_url"]


log = logging.getLogger(__name__)

SSL_KEYWORDS = (
    "key_file",
    "cert_file",
    "cert_reqs",
    "ca_certs",
    "ssl_version",
    "ca_cert_dir",
    "ssl_context",
    "key_password",
)

# All known keyword arguments that could be provided to the pool manager, its
# pools, or the underlying connections. This is used to construct a pool key.
_key_fields = (
    "key_scheme",  # str
    "key_host",  # str
    "key_port",  # int
    "key_timeout",  # int or float or Timeout
    "key_retries",  # int or Retry
    "key_strict",  # bool
    "key_block",  # bool
    "key_source_address",  # str
    "key_key_file",  # str
    "key_key_password",  # str
    "key_cert_file",  # str
    "key_cert_reqs",  # str
    "key_ca_certs",  # str
    "key_ssl_version",  # str
    "key_ca_cert_dir",  # str
    "key_ssl_context",  # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext
    "key_maxsize",  # int
    "key_headers",  # dict
    "key__proxy",  # parsed proxy url
    "key__proxy_headers",  # dict
    "key__proxy_config",  # class
    "key_socket_options",  # list of (level (int), optname (int), value (int or str)) tuples
    "key__socks_options",  # dict
    "key_assert_hostname",  # bool or string
    "key_assert_fingerprint",  # str
    "key_server_hostname",  # str
)

#: The namedtuple class used to construct keys for the connection pool.
#: All custom key schemes should include the fields in this key at a minimum.
PoolKey = collections.namedtuple("PoolKey", _key_fields)

_proxy_config_fields = ("ssl_context", "use_forwarding_for_https")
ProxyConfig = collections.namedtuple("ProxyConfig", _proxy_config_fields)


def _default_key_normalizer(key_class, request_context):
    """
    Create a pool key out of a request context dictionary.

    According to RFC 3986, both the scheme and host are case-insensitive.
    Therefore, this function normalizes both before constructing the pool
    key for an HTTPS request. If you wish to change this behaviour, provide
    alternate callables to ``key_fn_by_scheme``.

    :param key_class:
        The class to use when constructing the key. This should be a namedtuple
        with the ``scheme`` and ``host`` keys at a minimum.
    :type key_class: namedtuple
    :param request_context:
        A dictionary-like object that contain the context for a request.
    :type request_context: dict

    :return: A namedtuple that can be used as a connection pool key.
    :rtype: PoolKey
    """
    # Since we mutate the dictionary, make a copy first
    context = request_context.copy()
    context["scheme"] = context["scheme"].lower()
    context["host"] = context["host"].lower()

    # These are both dictionaries and need to be transformed into frozensets
    for key in ("headers", "_proxy_headers", "_socks_options"):
        if key in context and context[key] is not None:
            context[key] = frozenset(context[key].items())

    # The socket_options key may be a list and needs to be transformed into a
    # tuple.
    socket_opts = context.get("socket_options")
    if socket_opts is not None:
        context["socket_options"] = tuple(socket_opts)

    # Map the kwargs to the names in the namedtuple - this is necessary since
    # namedtuples can't have fields starting with '_'.
    for key in list(context.keys()):
        context["key_" + key] = context.pop(key)

    # Default to ``None`` for keys missing from the context
    for field in key_class._fields:
        if field not in context:
            context[field] = None

    return key_class(**context)


#: A dictionary that maps a scheme to a callable that creates a pool key.
#: This can be used to alter the way pool keys are constructed, if desired.
#: Each PoolManager makes a copy of this dictionary so they can be configured
#: globally here, or individually on the instance.
key_fn_by_scheme = {
    "http": functools.partial(_default_key_normalizer, PoolKey),
    "https": functools.partial(_default_key_normalizer, PoolKey),
}

pool_classes_by_scheme = {"http": HTTPConnectionPool, "https": HTTPSConnectionPool}


class PoolManager(RequestMethods):
    """
    Allows for arbitrary requests while transparently keeping track of
    necessary connection pools for you.

    :param num_pools:
        Number of connection pools to cache before discarding the least
        recently used pool.

    :param headers:
        Headers to include with all requests, unless other headers are given
        explicitly.

    :param \\**connection_pool_kw:
        Additional parameters are used to create fresh
        :class:`urllib3.connectionpool.ConnectionPool` instances.

    Example::

        >>> manager = PoolManager(num_pools=2)
        >>> r = manager.request('GET', 'http://google.com/')
        >>> r = manager.request('GET', 'http://google.com/mail')
        >>> r = manager.request('GET', 'http://yahoo.com/')
        >>> len(manager.pools)
        2

    """

    proxy = None
    proxy_config = None

    def __init__(self, num_pools=10, headers=None, **connection_pool_kw):
        RequestMethods.__init__(self, headers)
        self.connection_pool_kw = connection_pool_kw
        self.pools = RecentlyUsedContainer(num_pools, dispose_func=lambda p: p.close())

        # Locally set the pool classes and keys so other PoolManagers can
        # override them.
        self.pool_classes_by_scheme = pool_classes_by_scheme
        self.key_fn_by_scheme = key_fn_by_scheme.copy()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.clear()
        # Return False to re-raise any potential exceptions
        return False

    def _new_pool(self, scheme, host, port, request_context=None):
        """
        Create a new :class:`urllib3.connectionpool.ConnectionPool` based on host, port, scheme, and
        any additional pool keyword arguments.

        If ``request_context`` is provided, it is provided as keyword arguments
        to the pool class used. This method is used to actually create the
        connection pools handed out by :meth:`connection_from_url` and
        companion methods. It is intended to be overridden for customization.
        """
        pool_cls = self.pool_classes_by_scheme[scheme]
        if request_context is None:
            request_context = self.connection_pool_kw.copy()

        # Although the context has everything necessary to create the pool,
        # this function has historically only used the scheme, host, and port
        # in the positional args. When an API change is acceptable these can
        # be removed.
        for key in ("scheme", "host", "port"):
            request_context.pop(key, None)

        if scheme == "http":
            for kw in SSL_KEYWORDS:
                request_context.pop(kw, None)

        return pool_cls(host, port, **request_context)

    def clear(self):
        """
        Empty our store of pools and direct them all to close.

        This will not affect in-flight connections, but they will not be
        re-used after completion.
        """
        self.pools.clear()

    def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
        """
        Get a :class:`urllib3.connectionpool.ConnectionPool` based on the host, port, and scheme.

        If ``port`` isn't given, it will be derived from the ``scheme`` using
        ``urllib3.connectionpool.port_by_scheme``. If ``pool_kwargs`` is
        provided, it is merged with the instance's ``connection_pool_kw``
        variable and used to create the new connection pool, if one is
        needed.
        """

        if not host:
            raise LocationValueError("No host specified.")

        request_context = self._merge_pool_kwargs(pool_kwargs)
        request_context["scheme"] = scheme or "http"
        if not port:
            port = port_by_scheme.get(request_context["scheme"].lower(), 80)
        request_context["port"] = port
        request_context["host"] = host

        return self.connection_from_context(request_context)

    def connection_from_context(self, request_context):
        """
        Get a :class:`urllib3.connectionpool.ConnectionPool` based on the request context.

        ``request_context`` must at least contain the ``scheme`` key and its
        value must be a key in ``key_fn_by_scheme`` instance variable.
        """
        scheme = request_context["scheme"].lower()
        pool_key_constructor = self.key_fn_by_scheme.get(scheme)
        if not pool_key_constructor:
            raise URLSchemeUnknown(scheme)
        pool_key = pool_key_constructor(request_context)

        return self.connection_from_pool_key(pool_key, request_context=request_context)

    def connection_from_pool_key(self, pool_key, request_context=None):
        """
        Get a :class:`urllib3.connectionpool.ConnectionPool` based on the provided pool key.

        ``pool_key`` should be a namedtuple that only contains immutable
        objects. At a minimum it must have the ``scheme``, ``host``, and
        ``port`` fields.
        """
        with self.pools.lock:
            # If the scheme, host, or port doesn't match existing open
            # connections, open a new ConnectionPool.
            pool = self.pools.get(pool_key)
            if pool:
                return pool

            # Make a fresh ConnectionPool of the desired type
            scheme = request_context["scheme"]
            host = request_context["host"]
            port = request_context["port"]
            pool = self._new_pool(scheme, host, port, request_context=request_context)
            self.pools[pool_key] = pool

        return pool

    def connection_from_url(self, url, pool_kwargs=None):
        """
        Similar to :func:`urllib3.connectionpool.connection_from_url`.

        If ``pool_kwargs`` is not provided and a new pool needs to be
        constructed, ``self.connection_pool_kw`` is used to initialize
        the :class:`urllib3.connectionpool.ConnectionPool`. If ``pool_kwargs``
        is provided, it is used instead. Note that if a new pool does not
        need to be created for the request, the provided ``pool_kwargs`` are
        not used.
        """
        u = parse_url(url)
        return self.connection_from_host(
            u.host, port=u.port, scheme=u.scheme, pool_kwargs=pool_kwargs
        )

    def _merge_pool_kwargs(self, override):
        """
        Merge a dictionary of override values for self.connection_pool_kw.

        This does not modify self.connection_pool_kw and returns a new dict.
        Any keys in the override dictionary with a value of ``None`` are
        removed from the merged dictionary.
        """
        base_pool_kwargs = self.connection_pool_kw.copy()
        if override:
            for key, value in override.items():
                if value is None:
                    try:
                        del base_pool_kwargs[key]
                    except KeyError:
                        pass
                else:
                    base_pool_kwargs[key] = value
        return base_pool_kwargs

    def _proxy_requires_url_absolute_form(self, parsed_url):
        """
        Indicates if the proxy requires the complete destination URL in the
        request. Normally this is only needed when not using an HTTP CONNECT
        tunnel.
        """
        if self.proxy is None:
            return False

        return not connection_requires_http_tunnel(
            self.proxy, self.proxy_config, parsed_url.scheme
        )

    def _validate_proxy_scheme_url_selection(self, url_scheme):
        """
        Validates that we're not attempting to do TLS in TLS connections on
        Python2 or with unsupported SSL implementations.
        """
        if self.proxy is None or url_scheme != "https":
            return

        if self.proxy.scheme != "https":
            return

        if six.PY2 and not self.proxy_config.use_forwarding_for_https:
            raise ProxySchemeUnsupported(
                "Contacting HTTPS destinations through HTTPS proxies "
                "'via CONNECT tunnels' is not supported in Python 2"
            )

    def urlopen(self, method, url, redirect=True, **kw):
        """
        Same as :meth:`urllib3.HTTPConnectionPool.urlopen`
        with custom cross-host redirect logic and only sends the request-uri
        portion of the ``url``.

        The given ``url`` parameter must be absolute, such that an appropriate
        :class:`urllib3.connectionpool.ConnectionPool` can be chosen for it.
        """
        u = parse_url(url)
        self._validate_proxy_scheme_url_selection(u.scheme)

        conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme)

        kw["assert_same_host"] = False
        kw["redirect"] = False

        if "headers" not in kw:
            kw["headers"] = self.headers.copy()

        if self._proxy_requires_url_absolute_form(u):
            response = conn.urlopen(method, url, **kw)
        else:
            response = conn.urlopen(method, u.request_uri, **kw)

        redirect_location = redirect and response.get_redirect_location()
        if not redirect_location:
            return response

        # Support relative URLs for redirecting.
        redirect_location = urljoin(url, redirect_location)

        # RFC 7231, Section 6.4.4
        if response.status == 303:
            method = "GET"

        retries = kw.get("retries")
        if not isinstance(retries, Retry):
            retries = Retry.from_int(retries, redirect=redirect)

        # Strip headers marked as unsafe to forward to the redirected location.
        # Check remove_headers_on_redirect to avoid a potential network call within
        # conn.is_same_host() which may use socket.gethostbyname() in the future.
        if retries.remove_headers_on_redirect and not conn.is_same_host(
            redirect_location
        ):
            headers = list(six.iterkeys(kw["headers"]))
            for header in headers:
                if header.lower() in retries.remove_headers_on_redirect:
                    kw["headers"].pop(header, None)

        try:
            retries = retries.increment(method, url, response=response, _pool=conn)
        except MaxRetryError:
            if retries.raise_on_redirect:
                response.drain_conn()
                raise
            return response

        kw["retries"] = retries
        kw["redirect"] = redirect

        log.info("Redirecting %s -> %s", url, redirect_location)

        response.drain_conn()
        return self.urlopen(method, redirect_location, **kw)


class ProxyManager(PoolManager):
    """
    Behaves just like :class:`PoolManager`, but sends all requests through
    the defined proxy, using the CONNECT method for HTTPS URLs.

    :param proxy_url:
        The URL of the proxy to be used.

    :param proxy_headers:
        A dictionary containing headers that will be sent to the proxy. In case
        of HTTP they are being sent with each request, while in the
        HTTPS/CONNECT case they are sent only once. Could be used for proxy
        authentication.

    :param proxy_ssl_context:
        The proxy SSL context is used to establish the TLS connection to the
        proxy when using HTTPS proxies.

    :param use_forwarding_for_https:
        (Defaults to False) If set to True will forward requests to the HTTPS
        proxy to be made on behalf of the client instead of creating a TLS
        tunnel via the CONNECT method. **Enabling this flag means that request
        and response headers and content will be visible from the HTTPS proxy**
        whereas tunneling keeps request and response headers and content
        private. IP address, target hostname, SNI, and port are always visible
        to an HTTPS proxy even when this flag is disabled.

    Example:
        >>> proxy = urllib3.ProxyManager('http://localhost:3128/')
        >>> r1 = proxy.request('GET', 'http://google.com/')
        >>> r2 = proxy.request('GET', 'http://httpbin.org/')
        >>> len(proxy.pools)
        1
        >>> r3 = proxy.request('GET', 'https://httpbin.org/')
        >>> r4 = proxy.request('GET', 'https://twitter.com/')
        >>> len(proxy.pools)
        3

    """

    def __init__(
        self,
        proxy_url,
        num_pools=10,
        headers=None,
        proxy_headers=None,
        proxy_ssl_context=None,
        use_forwarding_for_https=False,
        **connection_pool_kw
    ):

        if isinstance(proxy_url, HTTPConnectionPool):
            proxy_url = "%s://%s:%i" % (
                proxy_url.scheme,
                proxy_url.host,
                proxy_url.port,
            )
        proxy = parse_url(proxy_url)

        if proxy.scheme not in ("http", "https"):
            raise ProxySchemeUnknown(proxy.scheme)

        if not proxy.port:
            port = port_by_scheme.get(proxy.scheme, 80)
            proxy = proxy._replace(port=port)

        self.proxy = proxy
        self.proxy_headers = proxy_headers or {}
        self.proxy_ssl_context = proxy_ssl_context
        self.proxy_config = ProxyConfig(proxy_ssl_context, use_forwarding_for_https)

        connection_pool_kw["_proxy"] = self.proxy
        connection_pool_kw["_proxy_headers"] = self.proxy_headers
        connection_pool_kw["_proxy_config"] = self.proxy_config

        super(ProxyManager, self).__init__(num_pools, headers, **connection_pool_kw)

    def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
        if scheme == "https":
            return super(ProxyManager, self).connection_from_host(
                host, port, scheme, pool_kwargs=pool_kwargs
            )

        return super(ProxyManager, self).connection_from_host(
            self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs
        )

    def _set_proxy_headers(self, url, headers=None):
        """
        Sets headers needed by proxies: specifically, the Accept and Host
        headers. Only sets headers not provided by the user.
        """
        headers_ = {"Accept": "*/*"}

        netloc = parse_url(url).netloc
        if netloc:
            headers_["Host"] = netloc

        if headers:
            headers_.update(headers)
        return headers_

    def urlopen(self, method, url, redirect=True, **kw):
        "Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute."
        u = parse_url(url)
        if not connection_requires_http_tunnel(self.proxy, self.proxy_config, u.scheme):
            # For connections using HTTP CONNECT, httplib sets the necessary
            # headers on the CONNECT to the proxy. If we're not using CONNECT,
            # we'll definitely need to set 'Host' at the very least.
            headers = kw.get("headers", self.headers)
            kw["headers"] = self._set_proxy_headers(url, headers)

        return super(ProxyManager, self).urlopen(method, url, redirect=redirect, **kw)


def proxy_from_url(url, **kw):
    return ProxyManager(proxy_url=url, **kw)
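
The _default_key_normalizer docstring invites alternate callables in key_fn_by_scheme; here is a hedged sketch of what that customization could look like (headerless_key_normalizer is hypothetical), reusing the default normalizer but excluding headers from the pool key so pools are shared across requests that differ only in default headers:

def headerless_key_normalizer(key_class, request_context):
    # Drop headers before delegating, so they never become part of the key.
    context = dict(request_context)
    context.pop("headers", None)
    return _default_key_normalizer(key_class, context)

manager = PoolManager(num_pools=4)
manager.key_fn_by_scheme["https"] = functools.partial(
    headerless_key_normalizer, PoolKey
)
# HTTPS requests that differ only in headers now produce the same PoolKey,
# and therefore reuse the same HTTPSConnectionPool.
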
@ -1,170 +0,0 @@
from __future__ import absolute_import

from .filepost import encode_multipart_formdata
from .packages.six.moves.urllib.parse import urlencode

__all__ = ["RequestMethods"]


class RequestMethods(object):
    """
    Convenience mixin for classes who implement a :meth:`urlopen` method, such
    as :class:`urllib3.HTTPConnectionPool` and
    :class:`urllib3.PoolManager`.

    Provides behavior for making common types of HTTP request methods and
    decides which type of request field encoding to use.

    Specifically,

    :meth:`.request_encode_url` is for sending requests whose fields are
    encoded in the URL (such as GET, HEAD, DELETE).

    :meth:`.request_encode_body` is for sending requests whose fields are
    encoded in the *body* of the request using multipart or www-form-urlencoded
    (such as for POST, PUT, PATCH).

    :meth:`.request` is for making any kind of request, it will look up the
    appropriate encoding format and use one of the above two methods to make
    the request.

    Initializer parameters:

    :param headers:
        Headers to include with all requests, unless other headers are given
        explicitly.
    """

    _encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"}

    def __init__(self, headers=None):
        self.headers = headers or {}

    def urlopen(
        self,
        method,
        url,
        body=None,
        headers=None,
        encode_multipart=True,
        multipart_boundary=None,
        **kw
    ):  # Abstract
        raise NotImplementedError(
            "Classes extending RequestMethods must implement "
            "their own ``urlopen`` method."
        )

    def request(self, method, url, fields=None, headers=None, **urlopen_kw):
        """
        Make a request using :meth:`urlopen` with the appropriate encoding of
        ``fields`` based on the ``method`` used.

        This is a convenience method that requires the least amount of manual
        effort. It can be used in most situations, while still having the
        option to drop down to more specific methods when necessary, such as
        :meth:`request_encode_url`, :meth:`request_encode_body`,
        or even the lowest level :meth:`urlopen`.
        """
        method = method.upper()

        urlopen_kw["request_url"] = url

        if method in self._encode_url_methods:
            return self.request_encode_url(
                method, url, fields=fields, headers=headers, **urlopen_kw
            )
        else:
            return self.request_encode_body(
                method, url, fields=fields, headers=headers, **urlopen_kw
            )

    def request_encode_url(self, method, url, fields=None, headers=None, **urlopen_kw):
        """
        Make a request using :meth:`urlopen` with the ``fields`` encoded in
        the url. This is useful for request methods like GET, HEAD, DELETE, etc.
        """
        if headers is None:
            headers = self.headers

        extra_kw = {"headers": headers}
        extra_kw.update(urlopen_kw)

        if fields:
            url += "?" + urlencode(fields)

        return self.urlopen(method, url, **extra_kw)

    def request_encode_body(
        self,
        method,
        url,
        fields=None,
        headers=None,
        encode_multipart=True,
        multipart_boundary=None,
        **urlopen_kw
    ):
        """
        Make a request using :meth:`urlopen` with the ``fields`` encoded in
        the body. This is useful for request methods like POST, PUT, PATCH, etc.

        When ``encode_multipart=True`` (default), then
        :func:`urllib3.encode_multipart_formdata` is used to encode
        the payload with the appropriate content type. Otherwise
        :func:`urllib.parse.urlencode` is used with the
        'application/x-www-form-urlencoded' content type.

        Multipart encoding must be used when posting files, and it's reasonably
        safe to use it in other times too. However, it may break request
        signing, such as with OAuth.

        Supports an optional ``fields`` parameter of key/value strings AND
        key/filetuple. A filetuple is a (filename, data, MIME type) tuple where
        the MIME type is optional. For example::

            fields = {
                'foo': 'bar',
                'fakefile': ('foofile.txt', 'contents of foofile'),
                'realfile': ('barfile.txt', open('realfile').read()),
                'typedfile': ('bazfile.bin', open('bazfile').read(),
                              'image/jpeg'),
                'nonamefile': 'contents of nonamefile field',
            }

        When uploading a file, providing a filename (the first parameter of the
        tuple) is optional but recommended to best mimic behavior of browsers.

        Note that if ``headers`` are supplied, the 'Content-Type' header will
        be overwritten because it depends on the dynamic random boundary string
        which is used to compose the body of the request. The random boundary
        string can be explicitly set with the ``multipart_boundary`` parameter.
        """
        if headers is None:
            headers = self.headers

        extra_kw = {"headers": {}}

        if fields:
            if "body" in urlopen_kw:
                raise TypeError(
                    "request got values for both 'fields' and 'body', can only specify one."
                )

            if encode_multipart:
                body, content_type = encode_multipart_formdata(
                    fields, boundary=multipart_boundary
                )
            else:
                body, content_type = (
                    urlencode(fields),
                    "application/x-www-form-urlencoded",
                )

            extra_kw["body"] = body
            extra_kw["headers"] = {"Content-Type": content_type}

        extra_kw["headers"].update(headers)
        extra_kw.update(urlopen_kw)

        return self.urlopen(method, url, **extra_kw)
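
Since RequestMethods is a mixin, it is exercised through a concrete subclass such as PoolManager; a short sketch against the public urllib3 API, with illustrative URLs and field values:

import urllib3

http = urllib3.PoolManager()

# GET: fields are URL-encoded into the query string (request_encode_url).
r = http.request("GET", "https://httpbin.org/get", fields={"q": "fusion"})

# POST: fields are multipart-encoded into the body (request_encode_body).
r = http.request(
    "POST",
    "https://httpbin.org/post",
    fields={"attachment": ("notes.txt", b"hello", "text/plain")},
)
print(r.status)
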
Some files were not shown because too many files have changed in this diff