removed fusion addon

This commit is contained in:
Jakub Trllo 2024-06-27 14:08:49 +02:00
parent 552270dd0c
commit 348104c32d
118 changed files with 0 additions and 21923 deletions

View file

@ -1,17 +0,0 @@
from .version import __version__
from .addon import (
get_fusion_version,
FusionAddon,
FUSION_ADDON_ROOT,
FUSION_VERSIONS_DICT,
)
__all__ = (
"__version__",
"get_fusion_version",
"FusionAddon",
"FUSION_ADDON_ROOT",
"FUSION_VERSIONS_DICT",
)

View file

@ -1,72 +0,0 @@
import os
import re
from ayon_core.addon import AYONAddon, IHostAddon
from ayon_core.lib import Logger
from .version import __version__
FUSION_ADDON_ROOT = os.path.dirname(os.path.abspath(__file__))
# FUSION_VERSIONS_DICT is used by the pre-launch hooks
# The keys correspond to all currently supported Fusion versions
# Each value is a list of corresponding Python home variables and a profile
# number, which is used by the profile hook to set Fusion profile variables.
FUSION_VERSIONS_DICT = {
9: ("FUSION_PYTHON36_HOME", 9),
16: ("FUSION16_PYTHON36_HOME", 16),
17: ("FUSION16_PYTHON36_HOME", 16),
18: ("FUSION_PYTHON3_HOME", 16),
}
def get_fusion_version(app_name):
"""
The function is triggered by the prelaunch hooks to get the fusion version.
`app_name` is obtained by prelaunch hooks from the
`launch_context.env.get("AYON_APP_NAME")`.
To get a correct Fusion version, a version number should be present
in the `applications/fusion/variants` key
of the Blackmagic Fusion Application Settings.
"""
log = Logger.get_logger(__name__)
if not app_name:
return
app_version_candidates = re.findall(r"\d+", app_name)
if not app_version_candidates:
return
for app_version in app_version_candidates:
if int(app_version) in FUSION_VERSIONS_DICT:
return int(app_version)
else:
log.info(
"Unsupported Fusion version: {app_version}".format(
app_version=app_version
)
)
class FusionAddon(AYONAddon, IHostAddon):
name = "fusion"
version = __version__
host_name = "fusion"
def get_launch_hook_paths(self, app):
if app.host_name != self.host_name:
return []
return [os.path.join(FUSION_ADDON_ROOT, "hooks")]
def add_implementation_envs(self, env, app):
# Set default values if are not already set via settings
defaults = {"AYON_LOG_NO_COLORS": "1"}
for key, value in defaults.items():
if not env.get(key):
env[key] = value
def get_workfile_extensions(self):
return [".comp"]

View file

@ -1,39 +0,0 @@
from .pipeline import (
FusionHost,
ls,
imprint_container,
parse_container
)
from .lib import (
maintained_selection,
update_frame_range,
set_current_context_framerange,
get_current_comp,
get_bmd_library,
comp_lock_and_undo_chunk
)
from .menu import launch_ayon_menu
__all__ = [
# pipeline
"FusionHost",
"ls",
"imprint_container",
"parse_container",
# lib
"maintained_selection",
"update_frame_range",
"set_current_context_framerange",
"get_current_comp",
"get_bmd_library",
"comp_lock_and_undo_chunk",
# menu
"launch_ayon_menu",
]

View file

@ -1,112 +0,0 @@
import pyblish.api
from ayon_fusion.api.lib import get_current_comp
from ayon_core.pipeline.publish import get_errored_instances_from_context
class SelectInvalidAction(pyblish.api.Action):
"""Select invalid nodes in Fusion when plug-in failed.
To retrieve the invalid nodes this assumes a static `get_invalid()`
method is available on the plugin.
"""
label = "Select invalid"
on = "failed" # This action is only available on a failed plug-in
icon = "search" # Icon from Awesome Icon
def process(self, context, plugin):
errored_instances = get_errored_instances_from_context(
context,
plugin=plugin,
)
# Get the invalid nodes for the plug-ins
self.log.info("Finding invalid nodes..")
invalid = list()
for instance in errored_instances:
invalid_nodes = plugin.get_invalid(instance)
if invalid_nodes:
if isinstance(invalid_nodes, (list, tuple)):
invalid.extend(invalid_nodes)
else:
self.log.warning(
"Plug-in returned to be invalid, "
"but has no selectable nodes."
)
if not invalid:
# Assume relevant comp is current comp and clear selection
self.log.info("No invalid tools found.")
comp = get_current_comp()
flow = comp.CurrentFrame.FlowView
flow.Select() # No args equals clearing selection
return
# Assume a single comp
first_tool = invalid[0]
comp = first_tool.Comp()
flow = comp.CurrentFrame.FlowView
flow.Select() # No args equals clearing selection
names = set()
for tool in invalid:
flow.Select(tool, True)
comp.SetActiveTool(tool)
names.add(tool.Name)
self.log.info(
"Selecting invalid tools: %s" % ", ".join(sorted(names))
)
class SelectToolAction(pyblish.api.Action):
"""Select invalid output tool in Fusion when plug-in failed.
"""
label = "Select saver"
on = "failed" # This action is only available on a failed plug-in
icon = "search" # Icon from Awesome Icon
def process(self, context, plugin):
errored_instances = get_errored_instances_from_context(
context,
plugin=plugin,
)
# Get the invalid nodes for the plug-ins
self.log.info("Finding invalid nodes..")
tools = []
for instance in errored_instances:
tool = instance.data.get("tool")
if tool is not None:
tools.append(tool)
else:
self.log.warning(
"Plug-in returned to be invalid, "
f"but has no saver for instance {instance.name}."
)
if not tools:
# Assume relevant comp is current comp and clear selection
self.log.info("No invalid tools found.")
comp = get_current_comp()
flow = comp.CurrentFrame.FlowView
flow.Select() # No args equals clearing selection
return
# Assume a single comp
first_tool = tools[0]
comp = first_tool.Comp()
flow = comp.CurrentFrame.FlowView
flow.Select() # No args equals clearing selection
names = set()
for tool in tools:
flow.Select(tool, True)
comp.SetActiveTool(tool)
names.add(tool.Name)
self.log.info(
"Selecting invalid tools: %s" % ", ".join(sorted(names))
)

View file

@ -1,402 +0,0 @@
import os
import sys
import re
import contextlib
from ayon_core.lib import Logger, BoolDef, UILabelDef
from ayon_core.style import load_stylesheet
from ayon_core.pipeline import registered_host
from ayon_core.pipeline.create import CreateContext
from ayon_core.pipeline.context_tools import get_current_folder_entity
self = sys.modules[__name__]
self._project = None
def update_frame_range(start, end, comp=None, set_render_range=True,
handle_start=0, handle_end=0):
"""Set Fusion comp's start and end frame range
Args:
start (float, int): start frame
end (float, int): end frame
comp (object, Optional): comp object from fusion
set_render_range (bool, Optional): When True this will also set the
composition's render start and end frame.
handle_start (float, int, Optional): frame handles before start frame
handle_end (float, int, Optional): frame handles after end frame
Returns:
None
"""
if not comp:
comp = get_current_comp()
# Convert any potential none type to zero
handle_start = handle_start or 0
handle_end = handle_end or 0
attrs = {
"COMPN_GlobalStart": start - handle_start,
"COMPN_GlobalEnd": end + handle_end
}
# set frame range
if set_render_range:
attrs.update({
"COMPN_RenderStart": start,
"COMPN_RenderEnd": end
})
with comp_lock_and_undo_chunk(comp):
comp.SetAttrs(attrs)
def set_current_context_framerange(folder_entity=None):
"""Set Comp's frame range based on current folder."""
if folder_entity is None:
folder_entity = get_current_folder_entity(
fields={"attrib.frameStart",
"attrib.frameEnd",
"attrib.handleStart",
"attrib.handleEnd"})
folder_attributes = folder_entity["attrib"]
start = folder_attributes["frameStart"]
end = folder_attributes["frameEnd"]
handle_start = folder_attributes["handleStart"]
handle_end = folder_attributes["handleEnd"]
update_frame_range(start, end, set_render_range=True,
handle_start=handle_start,
handle_end=handle_end)
def set_current_context_fps(folder_entity=None):
"""Set Comp's frame rate (FPS) to based on current asset"""
if folder_entity is None:
folder_entity = get_current_folder_entity(fields={"attrib.fps"})
fps = float(folder_entity["attrib"].get("fps", 24.0))
comp = get_current_comp()
comp.SetPrefs({
"Comp.FrameFormat.Rate": fps,
})
def set_current_context_resolution(folder_entity=None):
"""Set Comp's resolution width x height default based on current folder"""
if folder_entity is None:
folder_entity = get_current_folder_entity(
fields={"attrib.resolutionWidth", "attrib.resolutionHeight"})
folder_attributes = folder_entity["attrib"]
width = folder_attributes["resolutionWidth"]
height = folder_attributes["resolutionHeight"]
comp = get_current_comp()
print("Setting comp frame format resolution to {}x{}".format(width,
height))
comp.SetPrefs({
"Comp.FrameFormat.Width": width,
"Comp.FrameFormat.Height": height,
})
def validate_comp_prefs(comp=None, force_repair=False):
"""Validate current comp defaults with folder settings.
Validates fps, resolutionWidth, resolutionHeight, aspectRatio.
This does *not* validate frameStart, frameEnd, handleStart and handleEnd.
"""
if comp is None:
comp = get_current_comp()
log = Logger.get_logger("validate_comp_prefs")
fields = {
"path",
"attrib.fps",
"attrib.resolutionWidth",
"attrib.resolutionHeight",
"attrib.pixelAspect",
}
folder_entity = get_current_folder_entity(fields=fields)
folder_path = folder_entity["path"]
folder_attributes = folder_entity["attrib"]
comp_frame_format_prefs = comp.GetPrefs("Comp.FrameFormat")
# Pixel aspect ratio in Fusion is set as AspectX and AspectY so we convert
# the data to something that is more sensible to Fusion
folder_attributes["pixelAspectX"] = folder_attributes.pop("pixelAspect")
folder_attributes["pixelAspectY"] = 1.0
validations = [
("fps", "Rate", "FPS"),
("resolutionWidth", "Width", "Resolution Width"),
("resolutionHeight", "Height", "Resolution Height"),
("pixelAspectX", "AspectX", "Pixel Aspect Ratio X"),
("pixelAspectY", "AspectY", "Pixel Aspect Ratio Y")
]
invalid = []
for key, comp_key, label in validations:
folder_value = folder_attributes[key]
comp_value = comp_frame_format_prefs.get(comp_key)
if folder_value != comp_value:
invalid_msg = "{} {} should be {}".format(label,
comp_value,
folder_value)
invalid.append(invalid_msg)
if not force_repair:
# Do not log warning if we force repair anyway
log.warning(
"Comp {pref} {value} does not match folder "
"'{folder_path}' {pref} {folder_value}".format(
pref=label,
value=comp_value,
folder_path=folder_path,
folder_value=folder_value)
)
if invalid:
def _on_repair():
attributes = dict()
for key, comp_key, _label in validations:
value = folder_attributes[key]
comp_key_full = "Comp.FrameFormat.{}".format(comp_key)
attributes[comp_key_full] = value
comp.SetPrefs(attributes)
if force_repair:
log.info("Applying default Comp preferences..")
_on_repair()
return
from . import menu
from ayon_core.tools.utils import SimplePopup
dialog = SimplePopup(parent=menu.menu)
dialog.setWindowTitle("Fusion comp has invalid configuration")
msg = "Comp preferences mismatches '{}'".format(folder_path)
msg += "\n" + "\n".join(invalid)
dialog.set_message(msg)
dialog.set_button_text("Repair")
dialog.on_clicked.connect(_on_repair)
dialog.show()
dialog.raise_()
dialog.activateWindow()
dialog.setStyleSheet(load_stylesheet())
@contextlib.contextmanager
def maintained_selection(comp=None):
"""Reset comp selection from before the context after the context"""
if comp is None:
comp = get_current_comp()
previous_selection = comp.GetToolList(True).values()
try:
yield
finally:
flow = comp.CurrentFrame.FlowView
flow.Select() # No args equals clearing selection
if previous_selection:
for tool in previous_selection:
flow.Select(tool, True)
@contextlib.contextmanager
def maintained_comp_range(comp=None,
global_start=True,
global_end=True,
render_start=True,
render_end=True):
"""Reset comp frame ranges from before the context after the context"""
if comp is None:
comp = get_current_comp()
comp_attrs = comp.GetAttrs()
preserve_attrs = {}
if global_start:
preserve_attrs["COMPN_GlobalStart"] = comp_attrs["COMPN_GlobalStart"]
if global_end:
preserve_attrs["COMPN_GlobalEnd"] = comp_attrs["COMPN_GlobalEnd"]
if render_start:
preserve_attrs["COMPN_RenderStart"] = comp_attrs["COMPN_RenderStart"]
if render_end:
preserve_attrs["COMPN_RenderEnd"] = comp_attrs["COMPN_RenderEnd"]
try:
yield
finally:
comp.SetAttrs(preserve_attrs)
def get_frame_path(path):
"""Get filename for the Fusion Saver with padded number as '#'
>>> get_frame_path("C:/test.exr")
('C:/test', 4, '.exr')
>>> get_frame_path("filename.00.tif")
('filename.', 2, '.tif')
>>> get_frame_path("foobar35.tif")
('foobar', 2, '.tif')
Args:
path (str): The path to render to.
Returns:
tuple: head, padding, tail (extension)
"""
filename, ext = os.path.splitext(path)
# Find a final number group
match = re.match('.*?([0-9]+)$', filename)
if match:
padding = len(match.group(1))
# remove number from end since fusion
# will swap it with the frame number
filename = filename[:-padding]
else:
padding = 4 # default Fusion padding
return filename, padding, ext
def get_fusion_module():
"""Get current Fusion instance"""
fusion = getattr(sys.modules["__main__"], "fusion", None)
return fusion
def get_bmd_library():
"""Get bmd library"""
bmd = getattr(sys.modules["__main__"], "bmd", None)
return bmd
def get_current_comp():
"""Get current comp in this session"""
fusion = get_fusion_module()
if fusion is not None:
comp = fusion.CurrentComp
return comp
@contextlib.contextmanager
def comp_lock_and_undo_chunk(
comp,
undo_queue_name="Script CMD",
keep_undo=True,
):
"""Lock comp and open an undo chunk during the context"""
try:
comp.Lock()
comp.StartUndo(undo_queue_name)
yield
finally:
comp.Unlock()
comp.EndUndo(keep_undo)
def update_content_on_context_change():
"""Update all Creator instances to current asset"""
host = registered_host()
context = host.get_current_context()
folder_path = context["folder_path"]
task = context["task_name"]
create_context = CreateContext(host, reset=True)
for instance in create_context.instances:
instance_folder_path = instance.get("folderPath")
if instance_folder_path and instance_folder_path != folder_path:
instance["folderPath"] = folder_path
instance_task = instance.get("task")
if instance_task and instance_task != task:
instance["task"] = task
create_context.save_changes()
def prompt_reset_context():
"""Prompt the user what context settings to reset.
This prompt is used on saving to a different task to allow the scene to
get matched to the new context.
"""
# TODO: Cleanup this prototyped mess of imports and odd dialog
from ayon_core.tools.attribute_defs.dialog import (
AttributeDefinitionsDialog
)
from qtpy import QtCore
definitions = [
UILabelDef(
label=(
"You are saving your workfile into a different folder or task."
"\n\n"
"Would you like to update some settings to the new context?\n"
)
),
BoolDef(
"fps",
label="FPS",
tooltip="Reset Comp FPS",
default=True
),
BoolDef(
"frame_range",
label="Frame Range",
tooltip="Reset Comp start and end frame ranges",
default=True
),
BoolDef(
"resolution",
label="Comp Resolution",
tooltip="Reset Comp resolution",
default=True
),
BoolDef(
"instances",
label="Publish instances",
tooltip="Update all publish instance's folder and task to match "
"the new folder and task",
default=True
),
]
dialog = AttributeDefinitionsDialog(definitions)
dialog.setWindowFlags(
dialog.windowFlags() | QtCore.Qt.WindowStaysOnTopHint
)
dialog.setWindowTitle("Saving to different context.")
dialog.setStyleSheet(load_stylesheet())
if not dialog.exec_():
return None
options = dialog.get_values()
folder_entity = get_current_folder_entity()
if options["frame_range"]:
set_current_context_framerange(folder_entity)
if options["fps"]:
set_current_context_fps(folder_entity)
if options["resolution"]:
set_current_context_resolution(folder_entity)
if options["instances"]:
update_content_on_context_change()
dialog.deleteLater()

View file

@ -1,190 +0,0 @@
import os
import sys
from qtpy import QtWidgets, QtCore, QtGui
from ayon_core.tools.utils import host_tools
from ayon_core.style import load_stylesheet
from ayon_core.lib import register_event_callback
from ayon_fusion.scripts import (
duplicate_with_inputs,
)
from ayon_fusion.api.lib import (
set_current_context_framerange,
set_current_context_resolution,
)
from ayon_core.pipeline import get_current_folder_path
from ayon_core.resources import get_ayon_icon_filepath
from ayon_core.tools.utils import get_qt_app
from .pipeline import FusionEventHandler
from .pulse import FusionPulse
MENU_LABEL = os.environ["AYON_MENU_LABEL"]
self = sys.modules[__name__]
self.menu = None
class AYONMenu(QtWidgets.QWidget):
def __init__(self, *args, **kwargs):
super(AYONMenu, self).__init__(*args, **kwargs)
self.setObjectName(f"{MENU_LABEL}Menu")
icon_path = get_ayon_icon_filepath()
icon = QtGui.QIcon(icon_path)
self.setWindowIcon(icon)
self.setWindowFlags(
QtCore.Qt.Window
| QtCore.Qt.CustomizeWindowHint
| QtCore.Qt.WindowTitleHint
| QtCore.Qt.WindowMinimizeButtonHint
| QtCore.Qt.WindowCloseButtonHint
| QtCore.Qt.WindowStaysOnTopHint
)
self.render_mode_widget = None
self.setWindowTitle(MENU_LABEL)
context_label = QtWidgets.QLabel("Context", self)
context_label.setStyleSheet(
"""QLabel {
font-size: 14px;
font-weight: 600;
color: #5f9fb8;
}"""
)
context_label.setAlignment(QtCore.Qt.AlignHCenter)
workfiles_btn = QtWidgets.QPushButton("Workfiles...", self)
create_btn = QtWidgets.QPushButton("Create...", self)
load_btn = QtWidgets.QPushButton("Load...", self)
publish_btn = QtWidgets.QPushButton("Publish...", self)
manager_btn = QtWidgets.QPushButton("Manage...", self)
libload_btn = QtWidgets.QPushButton("Library...", self)
set_framerange_btn = QtWidgets.QPushButton("Set Frame Range", self)
set_resolution_btn = QtWidgets.QPushButton("Set Resolution", self)
duplicate_with_inputs_btn = QtWidgets.QPushButton(
"Duplicate with input connections", self
)
layout = QtWidgets.QVBoxLayout(self)
layout.setContentsMargins(10, 20, 10, 20)
layout.addWidget(context_label)
layout.addSpacing(20)
layout.addWidget(workfiles_btn)
layout.addSpacing(20)
layout.addWidget(create_btn)
layout.addWidget(load_btn)
layout.addWidget(publish_btn)
layout.addWidget(manager_btn)
layout.addSpacing(20)
layout.addWidget(libload_btn)
layout.addSpacing(20)
layout.addWidget(set_framerange_btn)
layout.addWidget(set_resolution_btn)
layout.addSpacing(20)
layout.addWidget(duplicate_with_inputs_btn)
self.setLayout(layout)
# Store reference so we can update the label
self.context_label = context_label
workfiles_btn.clicked.connect(self.on_workfile_clicked)
create_btn.clicked.connect(self.on_create_clicked)
publish_btn.clicked.connect(self.on_publish_clicked)
load_btn.clicked.connect(self.on_load_clicked)
manager_btn.clicked.connect(self.on_manager_clicked)
libload_btn.clicked.connect(self.on_libload_clicked)
duplicate_with_inputs_btn.clicked.connect(
self.on_duplicate_with_inputs_clicked
)
set_resolution_btn.clicked.connect(self.on_set_resolution_clicked)
set_framerange_btn.clicked.connect(self.on_set_framerange_clicked)
self._callbacks = []
self.register_callback("taskChanged", self.on_task_changed)
self.on_task_changed()
# Force close current process if Fusion is closed
self._pulse = FusionPulse(parent=self)
self._pulse.start()
# Detect Fusion events as AYON events
self._event_handler = FusionEventHandler(parent=self)
self._event_handler.start()
def on_task_changed(self):
# Update current context label
label = get_current_folder_path()
self.context_label.setText(label)
def register_callback(self, name, fn):
# Create a wrapper callback that we only store
# for as long as we want it to persist as callback
def _callback(*args):
fn()
self._callbacks.append(_callback)
register_event_callback(name, _callback)
def deregister_all_callbacks(self):
self._callbacks[:] = []
def on_workfile_clicked(self):
host_tools.show_workfiles()
def on_create_clicked(self):
host_tools.show_publisher(tab="create")
def on_publish_clicked(self):
host_tools.show_publisher(tab="publish")
def on_load_clicked(self):
host_tools.show_loader(use_context=True)
def on_manager_clicked(self):
host_tools.show_scene_inventory()
def on_libload_clicked(self):
host_tools.show_library_loader()
def on_duplicate_with_inputs_clicked(self):
duplicate_with_inputs.duplicate_with_input_connections()
def on_set_resolution_clicked(self):
set_current_context_resolution()
def on_set_framerange_clicked(self):
set_current_context_framerange()
def launch_ayon_menu():
app = get_qt_app()
ayon_menu = AYONMenu()
stylesheet = load_stylesheet()
ayon_menu.setStyleSheet(stylesheet)
ayon_menu.show()
self.menu = ayon_menu
result = app.exec_()
print("Shutting down..")
sys.exit(result)

View file

@ -1,439 +0,0 @@
"""
Basic avalon integration
"""
import os
import sys
import logging
import contextlib
from pathlib import Path
import pyblish.api
from qtpy import QtCore
from ayon_core.lib import (
Logger,
register_event_callback,
emit_event
)
from ayon_core.pipeline import (
register_loader_plugin_path,
register_creator_plugin_path,
register_inventory_action_path,
AVALON_CONTAINER_ID,
)
from ayon_core.pipeline.load import any_outdated_containers
from ayon_core.host import HostBase, IWorkfileHost, ILoadHost, IPublishHost
from ayon_core.tools.utils import host_tools
from ayon_fusion import FUSION_ADDON_ROOT
from .lib import (
get_current_comp,
validate_comp_prefs,
prompt_reset_context
)
log = Logger.get_logger(__name__)
PLUGINS_DIR = os.path.join(FUSION_ADDON_ROOT, "plugins")
PUBLISH_PATH = os.path.join(PLUGINS_DIR, "publish")
LOAD_PATH = os.path.join(PLUGINS_DIR, "load")
CREATE_PATH = os.path.join(PLUGINS_DIR, "create")
INVENTORY_PATH = os.path.join(PLUGINS_DIR, "inventory")
# Track whether the workfile tool is about to save
_about_to_save = False
class FusionLogHandler(logging.Handler):
# Keep a reference to fusion's Print function (Remote Object)
_print = None
@property
def print(self):
if self._print is not None:
# Use cached
return self._print
_print = getattr(sys.modules["__main__"], "fusion").Print
if _print is None:
# Backwards compatibility: Print method on Fusion instance was
# added around Fusion 17.4 and wasn't available on PyRemote Object
# before
_print = get_current_comp().Print
self._print = _print
return _print
def emit(self, record):
entry = self.format(record)
self.print(entry)
class FusionHost(HostBase, IWorkfileHost, ILoadHost, IPublishHost):
name = "fusion"
def install(self):
"""Install fusion-specific functionality of AYON.
This is where you install menus and register families, data
and loaders into fusion.
It is called automatically when installing via
`ayon_core.pipeline.install_host(ayon_fusion.api)`
See the Maya equivalent for inspiration on how to implement this.
"""
# Remove all handlers associated with the root logger object, because
# that one always logs as "warnings" incorrectly.
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
# Attach default logging handler that prints to active comp
logger = logging.getLogger()
formatter = logging.Formatter(fmt="%(message)s\n")
handler = FusionLogHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
pyblish.api.register_host("fusion")
pyblish.api.register_plugin_path(PUBLISH_PATH)
log.info("Registering Fusion plug-ins..")
register_loader_plugin_path(LOAD_PATH)
register_creator_plugin_path(CREATE_PATH)
register_inventory_action_path(INVENTORY_PATH)
# Register events
register_event_callback("open", on_after_open)
register_event_callback("workfile.save.before", before_workfile_save)
register_event_callback("save", on_save)
register_event_callback("new", on_new)
register_event_callback("taskChanged", on_task_changed)
# region workfile io api
def has_unsaved_changes(self):
comp = get_current_comp()
return comp.GetAttrs()["COMPB_Modified"]
def get_workfile_extensions(self):
return [".comp"]
def save_workfile(self, dst_path=None):
comp = get_current_comp()
comp.Save(dst_path)
def open_workfile(self, filepath):
# Hack to get fusion, see
# ayon_fusion.api.pipeline.get_current_comp()
fusion = getattr(sys.modules["__main__"], "fusion", None)
return fusion.LoadComp(filepath)
def get_current_workfile(self):
comp = get_current_comp()
current_filepath = comp.GetAttrs()["COMPS_FileName"]
if not current_filepath:
return None
return current_filepath
def work_root(self, session):
work_dir = session["AYON_WORKDIR"]
scene_dir = session.get("AVALON_SCENEDIR")
if scene_dir:
return os.path.join(work_dir, scene_dir)
else:
return work_dir
# endregion
@contextlib.contextmanager
def maintained_selection(self):
from .lib import maintained_selection
return maintained_selection()
def get_containers(self):
return ls()
def update_context_data(self, data, changes):
comp = get_current_comp()
comp.SetData("openpype", data)
def get_context_data(self):
comp = get_current_comp()
return comp.GetData("openpype") or {}
def on_new(event):
comp = event["Rets"]["comp"]
validate_comp_prefs(comp, force_repair=True)
def on_save(event):
comp = event["sender"]
validate_comp_prefs(comp)
# We are now starting the actual save directly
global _about_to_save
_about_to_save = False
def on_task_changed():
global _about_to_save
print(f"Task changed: {_about_to_save}")
# TODO: Only do this if not headless
if _about_to_save:
# Let's prompt the user to update the context settings or not
prompt_reset_context()
def on_after_open(event):
comp = event["sender"]
validate_comp_prefs(comp)
if any_outdated_containers():
log.warning("Scene has outdated content.")
# Find AYON menu to attach to
from . import menu
def _on_show_scene_inventory():
# ensure that comp is active
frame = comp.CurrentFrame
if not frame:
print("Comp is closed, skipping show scene inventory")
return
frame.ActivateFrame() # raise comp window
host_tools.show_scene_inventory()
from ayon_core.tools.utils import SimplePopup
from ayon_core.style import load_stylesheet
dialog = SimplePopup(parent=menu.menu)
dialog.setWindowTitle("Fusion comp has outdated content")
dialog.set_message("There are outdated containers in "
"your Fusion comp.")
dialog.on_clicked.connect(_on_show_scene_inventory)
dialog.show()
dialog.raise_()
dialog.activateWindow()
dialog.setStyleSheet(load_stylesheet())
def before_workfile_save(event):
# Due to Fusion's external python process design we can't really
# detect whether the current Fusion environment matches the one the artists
# expects it to be. For example, our pipeline python process might
# have been shut down, and restarted - which will restart it to the
# environment Fusion started with; not necessarily where the artist
# is currently working.
# The `_about_to_save` var is used to detect context changes when
# saving into another asset. If we keep it False it will be ignored
# as context change. As such, before we change tasks we will only
# consider it the current filepath is within the currently known
# AVALON_WORKDIR. This way we avoid false positives of thinking it's
# saving to another context and instead sometimes just have false negatives
# where we fail to show the "Update on task change" prompt.
comp = get_current_comp()
filepath = comp.GetAttrs()["COMPS_FileName"]
workdir = os.environ.get("AYON_WORKDIR")
if Path(workdir) in Path(filepath).parents:
global _about_to_save
_about_to_save = True
def ls():
"""List containers from active Fusion scene
This is the host-equivalent of api.ls(), but instead of listing
assets on disk, it lists assets already loaded in Fusion; once loaded
they are called 'containers'
Yields:
dict: container
"""
comp = get_current_comp()
tools = comp.GetToolList(False).values()
for tool in tools:
container = parse_container(tool)
if container:
yield container
def imprint_container(tool,
name,
namespace,
context,
loader=None):
"""Imprint a Loader with metadata
Containerisation enables a tracking of version, author and origin
for loaded assets.
Arguments:
tool (object): The node in Fusion to imprint as container, usually a
Loader.
name (str): Name of resulting assembly
namespace (str): Namespace under which to host container
context (dict): Asset information
loader (str, optional): Name of loader used to produce this container.
Returns:
None
"""
data = [
("schema", "openpype:container-2.0"),
("id", AVALON_CONTAINER_ID),
("name", str(name)),
("namespace", str(namespace)),
("loader", str(loader)),
("representation", context["representation"]["id"]),
]
for key, value in data:
tool.SetData("avalon.{}".format(key), value)
def parse_container(tool):
"""Returns imprinted container data of a tool
This reads the imprinted data from `imprint_container`.
"""
data = tool.GetData('avalon')
if not isinstance(data, dict):
return
# If not all required data return the empty container
required = ['schema', 'id', 'name',
'namespace', 'loader', 'representation']
if not all(key in data for key in required):
return
container = {key: data[key] for key in required}
# Store the tool's name
container["objectName"] = tool.Name
# Store reference to the tool object
container["_tool"] = tool
return container
class FusionEventThread(QtCore.QThread):
"""QThread which will periodically ping Fusion app for any events.
The fusion.UIManager must be set up to be notified of events before they'll
be reported by this thread, for example:
fusion.UIManager.AddNotify("Comp_Save", None)
"""
on_event = QtCore.Signal(dict)
def run(self):
app = getattr(sys.modules["__main__"], "app", None)
if app is None:
# No Fusion app found
return
# As optimization store the GetEvent method directly because every
# getattr of UIManager.GetEvent tries to resolve the Remote Function
# through the PyRemoteObject
get_event = app.UIManager.GetEvent
delay = int(os.environ.get("AYON_FUSION_CALLBACK_INTERVAL", 1000))
while True:
if self.isInterruptionRequested():
return
# Process all events that have been queued up until now
while True:
event = get_event(False)
if not event:
break
self.on_event.emit(event)
# Wait some time before processing events again
# to not keep blocking the UI
self.msleep(delay)
class FusionEventHandler(QtCore.QObject):
"""Emits AYON events based on Fusion events captured in a QThread.
This will emit the following AYON events based on Fusion actions:
save: Comp_Save, Comp_SaveAs
open: Comp_Opened
new: Comp_New
To use this you can attach it to you Qt UI so it runs in the background.
E.g.
>>> handler = FusionEventHandler(parent=window)
>>> handler.start()
"""
ACTION_IDS = [
"Comp_Save",
"Comp_SaveAs",
"Comp_New",
"Comp_Opened"
]
def __init__(self, parent=None):
super(FusionEventHandler, self).__init__(parent=parent)
# Set up Fusion event callbacks
fusion = getattr(sys.modules["__main__"], "fusion", None)
ui = fusion.UIManager
# Add notifications for the ones we want to listen to
notifiers = []
for action_id in self.ACTION_IDS:
notifier = ui.AddNotify(action_id, None)
notifiers.append(notifier)
# TODO: Not entirely sure whether these must be kept to avoid
# garbage collection
self._notifiers = notifiers
self._event_thread = FusionEventThread(parent=self)
self._event_thread.on_event.connect(self._on_event)
def start(self):
self._event_thread.start()
def stop(self):
self._event_thread.stop()
def _on_event(self, event):
"""Handle Fusion events to emit AYON events"""
if not event:
return
what = event["what"]
# Comp Save
if what in {"Comp_Save", "Comp_SaveAs"}:
if not event["Rets"].get("success"):
# If the Save action is cancelled it will still emit an
# event but with "success": False so we ignore those cases
return
# Comp was saved
emit_event("save", data=event)
return
# Comp New
elif what in {"Comp_New"}:
emit_event("new", data=event)
# Comp Opened
elif what in {"Comp_Opened"}:
emit_event("open", data=event)

View file

@ -1,278 +0,0 @@
from copy import deepcopy
import os
from ayon_fusion.api import (
get_current_comp,
comp_lock_and_undo_chunk,
)
from ayon_core.lib import (
BoolDef,
EnumDef,
)
from ayon_core.pipeline import (
Creator,
CreatedInstance,
AVALON_INSTANCE_ID,
AYON_INSTANCE_ID,
)
from ayon_core.pipeline.workfile import get_workdir
from ayon_api import (
get_project,
get_folder_by_path,
get_task_by_name
)
class GenericCreateSaver(Creator):
default_variants = ["Main", "Mask"]
description = "Fusion Saver to generate image sequence"
icon = "fa5.eye"
instance_attributes = [
"reviewable"
]
settings_category = "fusion"
image_format = "exr"
# TODO: This should be renamed together with Nuke so it is aligned
temp_rendering_path_template = (
"{workdir}/renders/fusion/{product[name]}/"
"{product[name]}.{frame}.{ext}"
)
def create(self, product_name, instance_data, pre_create_data):
self.pass_pre_attributes_to_instance(instance_data, pre_create_data)
instance = CreatedInstance(
product_type=self.product_type,
product_name=product_name,
data=instance_data,
creator=self,
)
data = instance.data_to_store()
comp = get_current_comp()
with comp_lock_and_undo_chunk(comp):
args = (-32768, -32768) # Magical position numbers
saver = comp.AddTool("Saver", *args)
self._update_tool_with_data(saver, data=data)
# Register the CreatedInstance
self._imprint(saver, data)
# Insert the transient data
instance.transient_data["tool"] = saver
self._add_instance_to_context(instance)
return instance
def collect_instances(self):
comp = get_current_comp()
tools = comp.GetToolList(False, "Saver").values()
for tool in tools:
data = self.get_managed_tool_data(tool)
if not data:
continue
# Add instance
created_instance = CreatedInstance.from_existing(data, self)
# Collect transient data
created_instance.transient_data["tool"] = tool
self._add_instance_to_context(created_instance)
def update_instances(self, update_list):
for created_inst, _changes in update_list:
new_data = created_inst.data_to_store()
tool = created_inst.transient_data["tool"]
self._update_tool_with_data(tool, new_data)
self._imprint(tool, new_data)
def remove_instances(self, instances):
for instance in instances:
# Remove the tool from the scene
tool = instance.transient_data["tool"]
if tool:
tool.Delete()
# Remove the collected CreatedInstance to remove from UI directly
self._remove_instance_from_context(instance)
def _imprint(self, tool, data):
# Save all data in a "openpype.{key}" = value data
# Instance id is the tool's name so we don't need to imprint as data
data.pop("instance_id", None)
active = data.pop("active", None)
if active is not None:
# Use active value to set the passthrough state
tool.SetAttrs({"TOOLB_PassThrough": not active})
for key, value in data.items():
tool.SetData(f"openpype.{key}", value)
def _update_tool_with_data(self, tool, data):
"""Update tool node name and output path based on product data"""
if "productName" not in data:
return
original_product_name = tool.GetData("openpype.productName")
original_format = tool.GetData(
"openpype.creator_attributes.image_format"
)
product_name = data["productName"]
if (
original_product_name != product_name
or tool.GetData("openpype.task") != data["task"]
or tool.GetData("openpype.folderPath") != data["folderPath"]
or original_format != data["creator_attributes"]["image_format"]
):
self._configure_saver_tool(data, tool, product_name)
def _configure_saver_tool(self, data, tool, product_name):
formatting_data = deepcopy(data)
# get frame padding from anatomy templates
frame_padding = self.project_anatomy.templates_obj.frame_padding
# get output format
ext = data["creator_attributes"]["image_format"]
# Product change detected
product_type = formatting_data["productType"]
f_product_name = formatting_data["productName"]
folder_path = formatting_data["folderPath"]
folder_name = folder_path.rsplit("/", 1)[-1]
# If the folder path and task do not match the current context then the
# workdir is not just the `AYON_WORKDIR`. Hence, we need to actually
# compute the resulting workdir
if (
data["folderPath"] == self.create_context.get_current_folder_path()
and data["task"] == self.create_context.get_current_task_name()
):
workdir = os.path.normpath(os.getenv("AYON_WORKDIR"))
else:
# TODO: Optimize this logic
project_name = self.create_context.get_current_project_name()
project_entity = get_project(project_name)
folder_entity = get_folder_by_path(project_name,
data["folderPath"])
task_entity = get_task_by_name(project_name,
folder_id=folder_entity["id"],
task_name=data["task"])
workdir = get_workdir(
project_entity=project_entity,
folder_entity=folder_entity,
task_entity=task_entity,
host_name=self.create_context.host_name,
)
formatting_data.update({
"workdir": workdir,
"frame": "0" * frame_padding,
"ext": ext,
"product": {
"name": f_product_name,
"type": product_type,
},
# TODO add more variants for 'folder' and 'task'
"folder": {
"name": folder_name,
},
"task": {
"name": data["task"],
},
# Backwards compatibility
"asset": folder_name,
"subset": f_product_name,
"family": product_type,
})
# build file path to render
# TODO make sure the keys are available in 'formatting_data'
temp_rendering_path_template = (
self.temp_rendering_path_template
.replace("{task}", "{task[name]}")
)
filepath = temp_rendering_path_template.format(**formatting_data)
comp = get_current_comp()
tool["Clip"] = comp.ReverseMapPath(os.path.normpath(filepath))
# Rename tool
if tool.Name != product_name:
print(f"Renaming {tool.Name} -> {product_name}")
tool.SetAttrs({"TOOLS_Name": product_name})
def get_managed_tool_data(self, tool):
"""Return data of the tool if it matches creator identifier"""
data = tool.GetData("openpype")
if not isinstance(data, dict):
return
if (
data.get("creator_identifier") != self.identifier
or data.get("id") not in {
AYON_INSTANCE_ID, AVALON_INSTANCE_ID
}
):
return
# Get active state from the actual tool state
attrs = tool.GetAttrs()
passthrough = attrs["TOOLB_PassThrough"]
data["active"] = not passthrough
# Override publisher's UUID generation because tool names are
# already unique in Fusion in a comp
data["instance_id"] = tool.Name
return data
def get_instance_attr_defs(self):
"""Settings for publish page"""
return self.get_pre_create_attr_defs()
def pass_pre_attributes_to_instance(self, instance_data, pre_create_data):
creator_attrs = instance_data["creator_attributes"] = {}
for pass_key in pre_create_data.keys():
creator_attrs[pass_key] = pre_create_data[pass_key]
def _get_render_target_enum(self):
rendering_targets = {
"local": "Local machine rendering",
"frames": "Use existing frames",
}
if "farm_rendering" in self.instance_attributes:
rendering_targets["farm"] = "Farm rendering"
return EnumDef(
"render_target", items=rendering_targets, label="Render target"
)
def _get_reviewable_bool(self):
return BoolDef(
"review",
default=("reviewable" in self.instance_attributes),
label="Review",
)
def _get_image_format_enum(self):
image_format_options = ["exr", "tga", "tif", "png", "jpg"]
return EnumDef(
"image_format",
items=image_format_options,
default=self.image_format,
label="Output Image Format",
)

View file

@ -1,63 +0,0 @@
import os
import sys
from qtpy import QtCore
class PulseThread(QtCore.QThread):
no_response = QtCore.Signal()
def __init__(self, parent=None):
super(PulseThread, self).__init__(parent=parent)
def run(self):
app = getattr(sys.modules["__main__"], "app", None)
# Interval in milliseconds
interval = os.environ.get("AYON_FUSION_PULSE_INTERVAL", 1000)
while True:
if self.isInterruptionRequested():
return
# We don't need to call Test because PyRemoteObject of the app
# will actually fail to even resolve the Test function if it has
# gone down. So we can actually already just check by confirming
# the method is still getting resolved. (Optimization)
if app.Test is None:
self.no_response.emit()
self.msleep(interval)
class FusionPulse(QtCore.QObject):
"""A Timer that checks whether host app is still alive.
This checks whether the Fusion process is still active at a certain
interval. This is useful due to how Fusion runs its scripts. Each script
runs in its own environment and process (a `fusionscript` process each).
If Fusion would go down and we have a UI process running at the same time
then it can happen that the `fusionscript.exe` will remain running in the
background in limbo due to e.g. a Qt interface's QApplication that keeps
running infinitely.
Warning:
When the host is not detected this will automatically exit
the current process.
"""
def __init__(self, parent=None):
super(FusionPulse, self).__init__(parent=parent)
self._thread = PulseThread(parent=self)
self._thread.no_response.connect(self.on_no_response)
def on_no_response(self):
print("Pulse detected no response from Fusion..")
sys.exit(1)
def start(self):
self._thread.start()
def stop(self):
self._thread.requestInterruption()

View file

@ -1,6 +0,0 @@
### AYON deploy MenuScripts
Note that this `MenuScripts` is not an official Fusion folder.
AYON only uses this folder in `{fusion}/deploy/` to trigger the AYON menu actions.
They are used in the actions defined in `.fu` files in `{fusion}/deploy/Config`.

View file

@ -1,29 +0,0 @@
# This is just a quick hack for users running Py3 locally but having no
# Qt library installed
import os
import subprocess
import importlib
try:
from qtpy import API_NAME
print(f"Qt binding: {API_NAME}")
mod = importlib.import_module(API_NAME)
print(f"Qt path: {mod.__file__}")
print("Qt library found, nothing to do..")
except ImportError:
print("Assuming no Qt library is installed..")
print('Installing PySide2 for Python 3.6: '
f'{os.environ["FUSION16_PYTHON36_HOME"]}')
# Get full path to python executable
exe = "python.exe" if os.name == 'nt' else "python"
python = os.path.join(os.environ["FUSION16_PYTHON36_HOME"], exe)
assert os.path.exists(python), f"Python doesn't exist: {python}"
# Do python -m pip install PySide2
args = [python, "-m", "pip", "install", "PySide2"]
print(f"Args: {args}")
subprocess.Popen(args)

View file

@ -1,47 +0,0 @@
import os
import sys
if sys.version_info < (3, 7):
# hack to handle discrepancy between distributed libraries and Python 3.6
# mostly because wrong version of urllib3
# TODO remove when not necessary
from ayon_fusion import FUSION_ADDON_ROOT
vendor_path = os.path.join(FUSION_ADDON_ROOT, "vendor")
if vendor_path not in sys.path:
sys.path.insert(0, vendor_path)
print(f"Added vendorized libraries from {vendor_path}")
from ayon_core.lib import Logger
from ayon_core.pipeline import (
install_host,
registered_host,
)
def main(env):
# This script working directory starts in Fusion application folder.
# However the contents of that folder can conflict with Qt library dlls
# so we make sure to move out of it to avoid DLL Load Failed errors.
os.chdir("..")
from ayon_fusion.api import FusionHost
from ayon_fusion.api import menu
# activate resolve from pype
install_host(FusionHost())
log = Logger.get_logger(__name__)
log.info(f"Registered host: {registered_host()}")
menu.launch_ayon_menu()
# Initiate a QTimer to check if Fusion is still alive every X interval
# If Fusion is not found - kill itself
# todo(roy): Implement timer that ensures UI doesn't remain when e.g.
# Fusion closes down
if __name__ == "__main__":
result = main(os.environ)
sys.exit(not bool(result))

View file

@ -1,60 +0,0 @@
{
Action
{
ID = "AYON_Menu",
Category = "AYON",
Name = "AYON Menu",
Targets =
{
Composition =
{
Execute = _Lua [=[
local scriptPath = app:MapPath("AYON:../MenuScripts/launch_menu.py")
if bmd.fileexists(scriptPath) == false then
print("[AYON Error] Can't run file: " .. scriptPath)
else
target:RunScript(scriptPath)
end
]=],
},
},
},
Action
{
ID = "AYON_Install_PySide2",
Category = "AYON",
Name = "Install PySide2",
Targets =
{
Composition =
{
Execute = _Lua [=[
local scriptPath = app:MapPath("AYON:../MenuScripts/install_pyside2.py")
if bmd.fileexists(scriptPath) == false then
print("[AYON Error] Can't run file: " .. scriptPath)
else
target:RunScript(scriptPath)
end
]=],
},
},
},
Menus
{
Target = "ChildFrame",
Before "Help"
{
Sub "AYON"
{
"AYON_Menu{}",
"_",
Sub "Admin" {
"AYON_Install_PySide2{}"
}
}
},
},
}

View file

@ -1,19 +0,0 @@
{
Locked = true,
Global = {
Paths = {
Map = {
["AYON:"] = "$(AYON_FUSION_ROOT)/deploy/ayon",
["Config:"] = "UserPaths:Config;AYON:Config",
["Scripts:"] = "UserPaths:Scripts;Reactor:System/Scripts",
},
},
Script = {
PythonVersion = 3,
Python3Forced = true
},
UserInterface = {
Language = "en_US"
},
},
}

View file

@ -1,36 +0,0 @@
import os
from ayon_applications import PreLaunchHook
from ayon_fusion import FUSION_ADDON_ROOT
class FusionLaunchMenuHook(PreLaunchHook):
"""Launch AYON menu on start of Fusion"""
app_groups = ["fusion"]
order = 9
def execute(self):
# Prelaunch hook is optional
settings = self.data["project_settings"][self.host_name]
if not settings["hooks"]["FusionLaunchMenuHook"]["enabled"]:
return
variant = self.application.name
if variant.isnumeric():
version = int(variant)
if version < 18:
print("Skipping launch of OpenPype menu on Fusion start "
"because Fusion version below 18.0 does not support "
"/execute argument on launch. "
f"Version detected: {version}")
return
else:
print(f"Application variant is not numeric: {variant}. "
"Validation for Fusion version 18+ for /execute "
"prelaunch argument skipped.")
path = os.path.join(FUSION_ADDON_ROOT,
"deploy",
"MenuScripts",
"launch_menu.py").replace("\\", "/")
script = f"fusion:RunScript('{path}')"
self.launch_context.launch_args.extend(["/execute", script])

View file

@ -1,169 +0,0 @@
import os
import shutil
import platform
from pathlib import Path
from ayon_fusion import (
FUSION_ADDON_ROOT,
FUSION_VERSIONS_DICT,
get_fusion_version,
)
from ayon_applications import (
PreLaunchHook,
LaunchTypes,
ApplicationLaunchFailed,
)
class FusionCopyPrefsPrelaunch(PreLaunchHook):
"""
Prepares local Fusion profile directory, copies existing Fusion profile.
This also sets FUSION MasterPrefs variable, which is used
to apply Master.prefs file to override some Fusion profile settings to:
- enable the AYON menu
- force Python 3 over Python 2
- force English interface
Master.prefs is defined in openpype/hosts/fusion/deploy/fusion_shared.prefs
"""
app_groups = {"fusion"}
order = 2
launch_types = {LaunchTypes.local}
def get_fusion_profile_name(self, profile_version) -> str:
# Returns 'Default', unless FUSION16_PROFILE is set
return os.getenv(f"FUSION{profile_version}_PROFILE", "Default")
def get_fusion_profile_dir(self, profile_version) -> Path:
# Get FUSION_PROFILE_DIR variable
fusion_profile = self.get_fusion_profile_name(profile_version)
fusion_var_prefs_dir = os.getenv(
f"FUSION{profile_version}_PROFILE_DIR"
)
# Check if FUSION_PROFILE_DIR exists
if fusion_var_prefs_dir and Path(fusion_var_prefs_dir).is_dir():
fu_prefs_dir = Path(fusion_var_prefs_dir, fusion_profile)
self.log.info(f"{fusion_var_prefs_dir} is set to {fu_prefs_dir}")
return fu_prefs_dir
def get_profile_source(self, profile_version) -> Path:
"""Get Fusion preferences profile location.
See Per-User_Preferences_and_Paths on VFXpedia for reference.
"""
fusion_profile = self.get_fusion_profile_name(profile_version)
profile_source = self.get_fusion_profile_dir(profile_version)
if profile_source:
return profile_source
# otherwise get default location of the profile folder
fu_prefs_dir = f"Blackmagic Design/Fusion/Profiles/{fusion_profile}"
if platform.system() == "Windows":
profile_source = Path(os.getenv("AppData"), fu_prefs_dir)
elif platform.system() == "Darwin":
profile_source = Path(
"~/Library/Application Support/", fu_prefs_dir
).expanduser()
elif platform.system() == "Linux":
profile_source = Path("~/.fusion", fu_prefs_dir).expanduser()
self.log.info(
f"Locating source Fusion prefs directory: {profile_source}"
)
return profile_source
def get_copy_fusion_prefs_settings(self):
# Get copy preferences options from the global application settings
copy_fusion_settings = self.data["project_settings"]["fusion"].get(
"copy_fusion_settings", {}
)
if not copy_fusion_settings:
self.log.error("Copy prefs settings not found")
copy_status = copy_fusion_settings.get("copy_status", False)
force_sync = copy_fusion_settings.get("force_sync", False)
copy_path = copy_fusion_settings.get("copy_path") or None
if copy_path:
copy_path = Path(copy_path).expanduser()
return copy_status, copy_path, force_sync
def copy_fusion_profile(
self, copy_from: Path, copy_to: Path, force_sync: bool
) -> None:
"""On the first Fusion launch copy the contents of Fusion profile
directory to the working predefined location. If the Openpype profile
folder exists, skip copying, unless re-sync is checked.
If the prefs were not copied on the first launch,
clean Fusion profile will be created in fu_profile_dir.
"""
if copy_to.exists() and not force_sync:
self.log.info(
"Destination Fusion preferences folder already exists: "
f"{copy_to} "
)
return
self.log.info("Starting copying Fusion preferences")
self.log.debug(f"force_sync option is set to {force_sync}")
try:
copy_to.mkdir(exist_ok=True, parents=True)
except PermissionError:
self.log.warning(f"Creating the folder not permitted at {copy_to}")
return
if not copy_from.exists():
self.log.warning(f"Fusion preferences not found in {copy_from}")
return
for file in copy_from.iterdir():
if file.suffix in (
".prefs",
".def",
".blocklist",
".fu",
".toolbars",
):
# convert Path to str to be compatible with Python 3.6+
shutil.copy(str(file), str(copy_to))
self.log.info(
f"Successfully copied preferences: {copy_from} to {copy_to}"
)
def execute(self):
(
copy_status,
fu_profile_dir,
force_sync,
) = self.get_copy_fusion_prefs_settings()
# Get launched application context and return correct app version
app_name = self.launch_context.env.get("AYON_APP_NAME")
app_version = get_fusion_version(app_name)
if app_version is None:
version_names = ", ".join(str(x) for x in FUSION_VERSIONS_DICT)
raise ApplicationLaunchFailed(
"Unable to detect valid Fusion version number from app "
f"name: {app_name}.\nMake sure to include at least a digit "
"to indicate the Fusion version like '18'.\n"
f"Detectable Fusion versions are: {version_names}"
)
_, profile_version = FUSION_VERSIONS_DICT[app_version]
fu_profile = self.get_fusion_profile_name(profile_version)
# do a copy of Fusion profile if copy_status toggle is enabled
if copy_status and fu_profile_dir is not None:
profile_source = self.get_profile_source(profile_version)
dest_folder = Path(fu_profile_dir, fu_profile)
self.copy_fusion_profile(profile_source, dest_folder, force_sync)
# Add temporary profile directory variables to customize Fusion
# to define where it can read custom scripts and tools from
fu_profile_dir_variable = f"FUSION{profile_version}_PROFILE_DIR"
self.log.info(f"Setting {fu_profile_dir_variable}: {fu_profile_dir}")
self.launch_context.env[fu_profile_dir_variable] = str(fu_profile_dir)
# Add custom Fusion Master Prefs and the temporary
# profile directory variables to customize Fusion
# to define where it can read custom scripts and tools from
master_prefs_variable = f"FUSION{profile_version}_MasterPrefs"
master_prefs = Path(
FUSION_ADDON_ROOT, "deploy", "ayon", "fusion_shared.prefs")
self.log.info(f"Setting {master_prefs_variable}: {master_prefs}")
self.launch_context.env[master_prefs_variable] = str(master_prefs)

View file

@ -1,71 +0,0 @@
import os
from ayon_applications import (
PreLaunchHook,
LaunchTypes,
ApplicationLaunchFailed,
)
from ayon_fusion import (
FUSION_ADDON_ROOT,
FUSION_VERSIONS_DICT,
get_fusion_version,
)
class FusionPrelaunch(PreLaunchHook):
"""
Prepares AYON Fusion environment.
Requires correct Python home variable to be defined in the environment
settings for Fusion to point at a valid Python 3 build for Fusion.
Python3 versions that are supported by Fusion:
Fusion 9, 16, 17 : Python 3.6
Fusion 18 : Python 3.6 - 3.10
"""
app_groups = {"fusion"}
order = 1
launch_types = {LaunchTypes.local}
def execute(self):
# making sure python 3 is installed at provided path
# Py 3.3-3.10 for Fusion 18+ or Py 3.6 for Fu 16-17
app_data = self.launch_context.env.get("AYON_APP_NAME")
app_version = get_fusion_version(app_data)
if not app_version:
raise ApplicationLaunchFailed(
"Fusion version information not found in System settings.\n"
"The key field in the 'applications/fusion/variants' should "
"consist a number, corresponding to major Fusion version."
)
py3_var, _ = FUSION_VERSIONS_DICT[app_version]
fusion_python3_home = self.launch_context.env.get(py3_var, "")
for path in fusion_python3_home.split(os.pathsep):
# Allow defining multiple paths, separated by os.pathsep,
# to allow "fallback" to other path.
# But make to set only a single path as final variable.
py3_dir = os.path.normpath(path)
if os.path.isdir(py3_dir):
break
else:
raise ApplicationLaunchFailed(
"Python 3 is not installed at the provided path.\n"
"Make sure the environment in fusion settings has "
"'FUSION_PYTHON3_HOME' set correctly and make sure "
"Python 3 is installed in the given path."
f"\n\nPYTHON PATH: {fusion_python3_home}"
)
self.log.info(f"Setting {py3_var}: '{py3_dir}'...")
self.launch_context.env[py3_var] = py3_dir
# Fusion 18+ requires FUSION_PYTHON3_HOME to also be on PATH
if app_version >= 18:
self.launch_context.env["PATH"] += os.pathsep + py3_dir
self.launch_context.env[py3_var] = py3_dir
# for hook installing PySide2
self.data["fusion_python3_home"] = py3_dir
self.log.info(f"Setting AYON_FUSION_ROOT: {FUSION_ADDON_ROOT}")
self.launch_context.env["AYON_FUSION_ROOT"] = FUSION_ADDON_ROOT

View file

@ -1,185 +0,0 @@
import os
import subprocess
import platform
import uuid
from ayon_applications import PreLaunchHook, LaunchTypes
class InstallPySideToFusion(PreLaunchHook):
"""Automatically installs Qt binding to fusion's python packages.
Check if fusion has installed PySide2 and will try to install if not.
For pipeline implementation is required to have Qt binding installed in
fusion's python packages.
"""
app_groups = {"fusion"}
order = 2
launch_types = {LaunchTypes.local}
def execute(self):
# Prelaunch hook is not crucial
try:
settings = self.data["project_settings"][self.host_name]
if not settings["hooks"]["InstallPySideToFusion"]["enabled"]:
return
self.inner_execute()
except Exception:
self.log.warning(
"Processing of {} crashed.".format(self.__class__.__name__),
exc_info=True
)
def inner_execute(self):
self.log.debug("Check for PySide2 installation.")
fusion_python3_home = self.data.get("fusion_python3_home")
if not fusion_python3_home:
self.log.warning("'fusion_python3_home' was not provided. "
"Installation of PySide2 not possible")
return
if platform.system().lower() == "windows":
exe_filenames = ["python.exe"]
else:
exe_filenames = ["python3", "python"]
for exe_filename in exe_filenames:
python_executable = os.path.join(fusion_python3_home, exe_filename)
if os.path.exists(python_executable):
break
if not os.path.exists(python_executable):
self.log.warning(
"Couldn't find python executable for fusion. {}".format(
python_executable
)
)
return
# Check if PySide2 is installed and skip if yes
if self._is_pyside_installed(python_executable):
self.log.debug("Fusion has already installed PySide2.")
return
self.log.debug("Installing PySide2.")
# Install PySide2 in fusion's python
if self._windows_require_permissions(
os.path.dirname(python_executable)):
result = self._install_pyside_windows(python_executable)
else:
result = self._install_pyside(python_executable)
if result:
self.log.info("Successfully installed PySide2 module to fusion.")
else:
self.log.warning("Failed to install PySide2 module to fusion.")
def _install_pyside_windows(self, python_executable):
"""Install PySide2 python module to fusion's python.
Installation requires administration rights that's why it is required
to use "pywin32" module which can execute command's and ask for
administration rights.
"""
try:
import win32con
import win32process
import win32event
import pywintypes
from win32comext.shell.shell import ShellExecuteEx
from win32comext.shell import shellcon
except Exception:
self.log.warning("Couldn't import \"pywin32\" modules")
return False
try:
# Parameters
# - use "-m pip" as module pip to install PySide2 and argument
# "--ignore-installed" is to force install module to fusion's
# site-packages and make sure it is binary compatible
parameters = "-m pip install --ignore-installed PySide2"
# Execute command and ask for administrator's rights
process_info = ShellExecuteEx(
nShow=win32con.SW_SHOWNORMAL,
fMask=shellcon.SEE_MASK_NOCLOSEPROCESS,
lpVerb="runas",
lpFile=python_executable,
lpParameters=parameters,
lpDirectory=os.path.dirname(python_executable)
)
process_handle = process_info["hProcess"]
win32event.WaitForSingleObject(process_handle,
win32event.INFINITE)
returncode = win32process.GetExitCodeProcess(process_handle)
return returncode == 0
except pywintypes.error:
return False
def _install_pyside(self, python_executable):
"""Install PySide2 python module to fusion's python."""
try:
# Parameters
# - use "-m pip" as module pip to install PySide2 and argument
# "--ignore-installed" is to force install module to fusion's
# site-packages and make sure it is binary compatible
env = dict(os.environ)
del env['PYTHONPATH']
args = [
python_executable,
"-m",
"pip",
"install",
"--ignore-installed",
"PySide2",
]
process = subprocess.Popen(
args, stdout=subprocess.PIPE, universal_newlines=True,
env=env
)
process.communicate()
return process.returncode == 0
except PermissionError:
self.log.warning(
"Permission denied with command:"
"\"{}\".".format(" ".join(args))
)
except OSError as error:
self.log.warning(f"OS error has occurred: \"{error}\".")
except subprocess.SubprocessError:
pass
def _is_pyside_installed(self, python_executable):
"""Check if PySide2 module is in fusion's pip list."""
args = [python_executable, "-c", "from qtpy import QtWidgets"]
process = subprocess.Popen(args,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
_, stderr = process.communicate()
stderr = stderr.decode()
if stderr:
return False
return True
def _windows_require_permissions(self, dirpath):
if platform.system().lower() != "windows":
return False
try:
# Attempt to create a temporary file in the folder
temp_file_path = os.path.join(dirpath, uuid.uuid4().hex)
with open(temp_file_path, "w"):
pass
os.remove(temp_file_path) # Clean up temporary file
return False
except PermissionError:
return True
except BaseException as exc:
print(("Failed to determine if root requires permissions."
"Unexpected error: {}").format(exc))
return False

View file

@ -1,63 +0,0 @@
from ayon_core.lib import NumberDef
from ayon_fusion.api.plugin import GenericCreateSaver
class CreateImageSaver(GenericCreateSaver):
"""Fusion Saver to generate single image.
Created to explicitly separate single ('image') or
multi frame('render) outputs.
This might be temporary creator until 'alias' functionality will be
implemented to limit creation of additional product types with similar, but
not the same workflows.
"""
identifier = "io.openpype.creators.fusion.imagesaver"
label = "Image (saver)"
name = "image"
product_type = "image"
description = "Fusion Saver to generate image"
default_frame = 0
def get_detail_description(self):
return """Fusion Saver to generate single image.
This creator is expected for publishing of single frame `image` product
type.
Artist should provide frame number (integer) to specify which frame
should be published. It must be inside of global timeline frame range.
Supports local and deadline rendering.
Supports selection from predefined set of output file extensions:
- exr
- tga
- png
- tif
- jpg
Created to explicitly separate single frame ('image') or
multi frame ('render') outputs.
"""
def get_pre_create_attr_defs(self):
"""Settings for create page"""
attr_defs = [
self._get_render_target_enum(),
self._get_reviewable_bool(),
self._get_frame_int(),
self._get_image_format_enum(),
]
return attr_defs
def _get_frame_int(self):
return NumberDef(
"frame",
default=self.default_frame,
label="Frame",
tooltip="Set frame to be rendered, must be inside of global "
"timeline range"
)

View file

@ -1,149 +0,0 @@
from ayon_core.lib import (
UILabelDef,
NumberDef,
EnumDef
)
from ayon_fusion.api.plugin import GenericCreateSaver
from ayon_fusion.api.lib import get_current_comp
class CreateSaver(GenericCreateSaver):
"""Fusion Saver to generate image sequence of 'render' product type.
Original Saver creator targeted for 'render' product type. It uses
original not to descriptive name because of values in Settings.
"""
identifier = "io.openpype.creators.fusion.saver"
label = "Render (saver)"
name = "render"
product_type = "render"
description = "Fusion Saver to generate image sequence"
default_frame_range_option = "current_folder"
def get_detail_description(self):
return """Fusion Saver to generate image sequence.
This creator is expected for publishing of image sequences for 'render'
product type. (But can publish even single frame 'render'.)
Select what should be source of render range:
- "Current Folder context" - values set on folder on AYON server
- "From render in/out" - from node itself
- "From composition timeline" - from timeline
Supports local and farm rendering.
Supports selection from predefined set of output file extensions:
- exr
- tga
- png
- tif
- jpg
"""
def get_pre_create_attr_defs(self):
"""Settings for create page"""
attr_defs = [
self._get_render_target_enum(),
self._get_reviewable_bool(),
self._get_frame_range_enum(),
self._get_image_format_enum(),
*self._get_custom_frame_range_attribute_defs()
]
return attr_defs
def _get_frame_range_enum(self):
frame_range_options = {
"current_folder": "Current Folder context",
"render_range": "From render in/out",
"comp_range": "From composition timeline",
"custom_range": "Custom frame range",
}
return EnumDef(
"frame_range_source",
items=frame_range_options,
label="Frame range source",
default=self.default_frame_range_option
)
@staticmethod
def _get_custom_frame_range_attribute_defs() -> list:
# Define custom frame range defaults based on current comp
# timeline settings (if a comp is currently open)
comp = get_current_comp()
if comp is not None:
attrs = comp.GetAttrs()
frame_defaults = {
"frameStart": int(attrs["COMPN_GlobalStart"]),
"frameEnd": int(attrs["COMPN_GlobalEnd"]),
"handleStart": int(
attrs["COMPN_RenderStart"] - attrs["COMPN_GlobalStart"]
),
"handleEnd": int(
attrs["COMPN_GlobalEnd"] - attrs["COMPN_RenderEnd"]
),
}
else:
frame_defaults = {
"frameStart": 1001,
"frameEnd": 1100,
"handleStart": 0,
"handleEnd": 0
}
return [
UILabelDef(
label="<br><b>Custom Frame Range</b><br>"
"<i>only used with 'Custom frame range' source</i>"
),
NumberDef(
"custom_frameStart",
label="Frame Start",
default=frame_defaults["frameStart"],
minimum=0,
decimals=0,
tooltip=(
"Set the start frame for the export.\n"
"Only used if frame range source is 'Custom frame range'."
)
),
NumberDef(
"custom_frameEnd",
label="Frame End",
default=frame_defaults["frameEnd"],
minimum=0,
decimals=0,
tooltip=(
"Set the end frame for the export.\n"
"Only used if frame range source is 'Custom frame range'."
)
),
NumberDef(
"custom_handleStart",
label="Handle Start",
default=frame_defaults["handleStart"],
minimum=0,
decimals=0,
tooltip=(
"Set the start handles for the export, this will be "
"added before the start frame.\n"
"Only used if frame range source is 'Custom frame range'."
)
),
NumberDef(
"custom_handleEnd",
label="Handle End",
default=frame_defaults["handleEnd"],
minimum=0,
decimals=0,
tooltip=(
"Set the end handles for the export, this will be added "
"after the end frame.\n"
"Only used if frame range source is 'Custom frame range'."
)
)
]

View file

@ -1,132 +0,0 @@
import ayon_api
from ayon_fusion.api import (
get_current_comp
)
from ayon_core.pipeline import (
AutoCreator,
CreatedInstance,
)
class FusionWorkfileCreator(AutoCreator):
identifier = "workfile"
product_type = "workfile"
label = "Workfile"
icon = "fa5.file"
default_variant = "Main"
create_allow_context_change = False
data_key = "openpype_workfile"
def collect_instances(self):
comp = get_current_comp()
data = comp.GetData(self.data_key)
if not data:
return
product_name = data.get("productName")
if product_name is None:
product_name = data["subset"]
instance = CreatedInstance(
product_type=self.product_type,
product_name=product_name,
data=data,
creator=self
)
instance.transient_data["comp"] = comp
self._add_instance_to_context(instance)
def update_instances(self, update_list):
for created_inst, _changes in update_list:
comp = created_inst.transient_data["comp"]
if not hasattr(comp, "SetData"):
# Comp is not alive anymore, likely closed by the user
self.log.error("Workfile comp not found for existing instance."
" Comp might have been closed in the meantime.")
continue
# Imprint data into the comp
data = created_inst.data_to_store()
comp.SetData(self.data_key, data)
def create(self, options=None):
comp = get_current_comp()
if not comp:
self.log.error("Unable to find current comp")
return
existing_instance = None
for instance in self.create_context.instances:
if instance.product_type == self.product_type:
existing_instance = instance
break
project_name = self.create_context.get_current_project_name()
folder_path = self.create_context.get_current_folder_path()
task_name = self.create_context.get_current_task_name()
host_name = self.create_context.host_name
existing_folder_path = None
if existing_instance is not None:
existing_folder_path = existing_instance["folderPath"]
if existing_instance is None:
folder_entity = ayon_api.get_folder_by_path(
project_name, folder_path
)
task_entity = ayon_api.get_task_by_name(
project_name, folder_entity["id"], task_name
)
product_name = self.get_product_name(
project_name,
folder_entity,
task_entity,
self.default_variant,
host_name,
)
data = {
"folderPath": folder_path,
"task": task_name,
"variant": self.default_variant,
}
data.update(self.get_dynamic_data(
project_name,
folder_entity,
task_entity,
self.default_variant,
host_name,
None
))
new_instance = CreatedInstance(
self.product_type, product_name, data, self
)
new_instance.transient_data["comp"] = comp
self._add_instance_to_context(new_instance)
elif (
existing_folder_path != folder_path
or existing_instance["task"] != task_name
):
folder_entity = ayon_api.get_folder_by_path(
project_name, folder_path
)
task_entity = ayon_api.get_task_by_name(
project_name, folder_entity["id"], task_name
)
product_name = self.get_product_name(
project_name,
folder_entity,
task_entity,
self.default_variant,
host_name,
)
existing_instance["folderPath"] = folder_path
existing_instance["task"] = task_name
existing_instance["productName"] = product_name

View file

@ -1,27 +0,0 @@
from ayon_core.pipeline import InventoryAction
class FusionSelectContainers(InventoryAction):
label = "Select Containers"
icon = "mouse-pointer"
color = "#d8d8d8"
def process(self, containers):
from ayon_fusion.api import (
get_current_comp,
comp_lock_and_undo_chunk
)
tools = [i["_tool"] for i in containers]
comp = get_current_comp()
flow = comp.CurrentFrame.FlowView
with comp_lock_and_undo_chunk(comp, self.label):
# Clear selection
flow.Select()
# Select tool
for tool in tools:
flow.Select(tool)

View file

@ -1,72 +0,0 @@
from qtpy import QtGui, QtWidgets
from ayon_core.pipeline import InventoryAction
from ayon_core import style
from ayon_fusion.api import (
get_current_comp,
comp_lock_and_undo_chunk
)
class FusionSetToolColor(InventoryAction):
"""Update the color of the selected tools"""
label = "Set Tool Color"
icon = "plus"
color = "#d8d8d8"
_fallback_color = QtGui.QColor(1.0, 1.0, 1.0)
def process(self, containers):
"""Color all selected tools the selected colors"""
result = []
comp = get_current_comp()
# Get tool color
first = containers[0]
tool = first["_tool"]
color = tool.TileColor
if color is not None:
qcolor = QtGui.QColor().fromRgbF(color["R"], color["G"], color["B"])
else:
qcolor = self._fallback_color
# Launch pick color
picked_color = self.get_color_picker(qcolor)
if not picked_color:
return
with comp_lock_and_undo_chunk(comp):
for container in containers:
# Convert color to RGB 0-1 floats
rgb_f = picked_color.getRgbF()
rgb_f_table = {"R": rgb_f[0], "G": rgb_f[1], "B": rgb_f[2]}
# Update tool
tool = container["_tool"]
tool.TileColor = rgb_f_table
result.append(container)
return result
def get_color_picker(self, color):
"""Launch color picker and return chosen color
Args:
color(QtGui.QColor): Start color to display
Returns:
QtGui.QColor
"""
color_dialog = QtWidgets.QColorDialog(color)
color_dialog.setStyleSheet(style.load_stylesheet())
accepted = color_dialog.exec_()
if not accepted:
return
return color_dialog.selectedColor()

View file

@ -1,81 +0,0 @@
"""A module containing generic loader actions that will display in the Loader.
"""
from ayon_core.pipeline import load
class FusionSetFrameRangeLoader(load.LoaderPlugin):
"""Set frame range excluding pre- and post-handles"""
product_types = {
"animation",
"camera",
"imagesequence",
"render",
"yeticache",
"pointcache",
"render",
}
representations = {"*"}
extensions = {"*"}
label = "Set frame range"
order = 11
icon = "clock-o"
color = "white"
def load(self, context, name, namespace, data):
from ayon_fusion.api import lib
version_attributes = context["version"]["attrib"]
start = version_attributes.get("frameStart", None)
end = version_attributes.get("frameEnd", None)
if start is None or end is None:
print("Skipping setting frame range because start or "
"end frame data is missing..")
return
lib.update_frame_range(start, end)
class FusionSetFrameRangeWithHandlesLoader(load.LoaderPlugin):
"""Set frame range including pre- and post-handles"""
product_types = {
"animation",
"camera",
"imagesequence",
"render",
"yeticache",
"pointcache",
"render",
}
representations = {"*"}
label = "Set frame range (with handles)"
order = 12
icon = "clock-o"
color = "white"
def load(self, context, name, namespace, data):
from ayon_fusion.api import lib
version_attributes = context["version"]["attrib"]
start = version_attributes.get("frameStart", None)
end = version_attributes.get("frameEnd", None)
if start is None or end is None:
print("Skipping setting frame range because start or "
"end frame data is missing..")
return
# Include handles
start -= version_attributes.get("handleStart", 0)
end += version_attributes.get("handleEnd", 0)
lib.update_frame_range(start, end)

View file

@ -1,72 +0,0 @@
from ayon_core.pipeline import (
load,
get_representation_path,
)
from ayon_fusion.api import (
imprint_container,
get_current_comp,
comp_lock_and_undo_chunk
)
class FusionLoadAlembicMesh(load.LoaderPlugin):
"""Load Alembic mesh into Fusion"""
product_types = {"pointcache", "model"}
representations = {"*"}
extensions = {"abc"}
label = "Load alembic mesh"
order = -10
icon = "code-fork"
color = "orange"
tool_type = "SurfaceAlembicMesh"
def load(self, context, name, namespace, data):
# Fallback to folder name when namespace is None
if namespace is None:
namespace = context["folder"]["name"]
# Create the Loader with the filename path set
comp = get_current_comp()
with comp_lock_and_undo_chunk(comp, "Create tool"):
path = self.filepath_from_context(context)
args = (-32768, -32768)
tool = comp.AddTool(self.tool_type, *args)
tool["Filename"] = path
imprint_container(tool,
name=name,
namespace=namespace,
context=context,
loader=self.__class__.__name__)
def switch(self, container, context):
self.update(container, context)
def update(self, container, context):
"""Update Alembic path"""
tool = container["_tool"]
assert tool.ID == self.tool_type, f"Must be {self.tool_type}"
comp = tool.Comp()
repre_entity = context["representation"]
path = get_representation_path(repre_entity)
with comp_lock_and_undo_chunk(comp, "Update tool"):
tool["Filename"] = path
# Update the imprinted representation
tool.SetData("avalon.representation", repre_entity["id"])
def remove(self, container):
tool = container["_tool"]
assert tool.ID == self.tool_type, f"Must be {self.tool_type}"
comp = tool.Comp()
with comp_lock_and_undo_chunk(comp, "Remove tool"):
tool.Delete()

View file

@ -1,87 +0,0 @@
from ayon_core.pipeline import (
load,
get_representation_path,
)
from ayon_fusion.api import (
imprint_container,
get_current_comp,
comp_lock_and_undo_chunk,
)
class FusionLoadFBXMesh(load.LoaderPlugin):
"""Load FBX mesh into Fusion"""
product_types = {"*"}
representations = {"*"}
extensions = {
"3ds",
"amc",
"aoa",
"asf",
"bvh",
"c3d",
"dae",
"dxf",
"fbx",
"htr",
"mcd",
"obj",
"trc",
}
label = "Load FBX mesh"
order = -10
icon = "code-fork"
color = "orange"
tool_type = "SurfaceFBXMesh"
def load(self, context, name, namespace, data):
# Fallback to folder name when namespace is None
if namespace is None:
namespace = context["folder"]["name"]
# Create the Loader with the filename path set
comp = get_current_comp()
with comp_lock_and_undo_chunk(comp, "Create tool"):
path = self.filepath_from_context(context)
args = (-32768, -32768)
tool = comp.AddTool(self.tool_type, *args)
tool["ImportFile"] = path
imprint_container(
tool,
name=name,
namespace=namespace,
context=context,
loader=self.__class__.__name__,
)
def switch(self, container, context):
self.update(container, context)
def update(self, container, context):
"""Update path"""
tool = container["_tool"]
assert tool.ID == self.tool_type, f"Must be {self.tool_type}"
comp = tool.Comp()
repre_entity = context["representation"]
path = get_representation_path(repre_entity)
with comp_lock_and_undo_chunk(comp, "Update tool"):
tool["ImportFile"] = path
# Update the imprinted representation
tool.SetData("avalon.representation", repre_entity["id"])
def remove(self, container):
tool = container["_tool"]
assert tool.ID == self.tool_type, f"Must be {self.tool_type}"
comp = tool.Comp()
with comp_lock_and_undo_chunk(comp, "Remove tool"):
tool.Delete()

View file

@ -1,291 +0,0 @@
import contextlib
import ayon_core.pipeline.load as load
from ayon_fusion.api import (
imprint_container,
get_current_comp,
comp_lock_and_undo_chunk,
)
from ayon_core.lib.transcoding import IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
comp = get_current_comp()
@contextlib.contextmanager
def preserve_inputs(tool, inputs):
"""Preserve the tool's inputs after context"""
comp = tool.Comp()
values = {}
for name in inputs:
tool_input = getattr(tool, name)
value = tool_input[comp.TIME_UNDEFINED]
values[name] = value
try:
yield
finally:
for name, value in values.items():
tool_input = getattr(tool, name)
tool_input[comp.TIME_UNDEFINED] = value
@contextlib.contextmanager
def preserve_trim(loader, log=None):
"""Preserve the relative trim of the Loader tool.
This tries to preserve the loader's trim (trim in and trim out) after
the context by reapplying the "amount" it trims on the clip's length at
start and end.
"""
# Get original trim as amount of "trimming" from length
time = loader.Comp().TIME_UNDEFINED
length = loader.GetAttrs()["TOOLIT_Clip_Length"][1] - 1
trim_from_start = loader["ClipTimeStart"][time]
trim_from_end = length - loader["ClipTimeEnd"][time]
try:
yield
finally:
length = loader.GetAttrs()["TOOLIT_Clip_Length"][1] - 1
if trim_from_start > length:
trim_from_start = length
if log:
log.warning(
"Reducing trim in to %d "
"(because of less frames)" % trim_from_start
)
remainder = length - trim_from_start
if trim_from_end > remainder:
trim_from_end = remainder
if log:
log.warning(
"Reducing trim in to %d "
"(because of less frames)" % trim_from_end
)
loader["ClipTimeStart"][time] = trim_from_start
loader["ClipTimeEnd"][time] = length - trim_from_end
def loader_shift(loader, frame, relative=True):
"""Shift global in time by i preserving duration
This moves the loader by i frames preserving global duration. When relative
is False it will shift the global in to the start frame.
Args:
loader (tool): The fusion loader tool.
frame (int): The amount of frames to move.
relative (bool): When True the shift is relative, else the shift will
change the global in to frame.
Returns:
int: The resulting relative frame change (how much it moved)
"""
comp = loader.Comp()
time = comp.TIME_UNDEFINED
old_in = loader["GlobalIn"][time]
old_out = loader["GlobalOut"][time]
if relative:
shift = frame
else:
shift = frame - old_in
if not shift:
return 0
# Shifting global in will try to automatically compensate for the change
# in the "ClipTimeStart" and "HoldFirstFrame" inputs, so we preserve those
# input values to "just shift" the clip
with preserve_inputs(
loader,
inputs=[
"ClipTimeStart",
"ClipTimeEnd",
"HoldFirstFrame",
"HoldLastFrame",
],
):
# GlobalIn cannot be set past GlobalOut or vice versa
# so we must apply them in the order of the shift.
if shift > 0:
loader["GlobalOut"][time] = old_out + shift
loader["GlobalIn"][time] = old_in + shift
else:
loader["GlobalIn"][time] = old_in + shift
loader["GlobalOut"][time] = old_out + shift
return int(shift)
class FusionLoadSequence(load.LoaderPlugin):
"""Load image sequence into Fusion"""
product_types = {
"imagesequence",
"review",
"render",
"plate",
"image",
"online",
}
representations = {"*"}
extensions = set(
ext.lstrip(".") for ext in IMAGE_EXTENSIONS.union(VIDEO_EXTENSIONS)
)
label = "Load sequence"
order = -10
icon = "code-fork"
color = "orange"
def load(self, context, name, namespace, data):
# Fallback to folder name when namespace is None
if namespace is None:
namespace = context["folder"]["name"]
# Use the first file for now
path = self.filepath_from_context(context)
# Create the Loader with the filename path set
comp = get_current_comp()
with comp_lock_and_undo_chunk(comp, "Create Loader"):
args = (-32768, -32768)
tool = comp.AddTool("Loader", *args)
tool["Clip"] = comp.ReverseMapPath(path)
# Set global in point to start frame (if in version.data)
start = self._get_start(context["version"], tool)
loader_shift(tool, start, relative=False)
imprint_container(
tool,
name=name,
namespace=namespace,
context=context,
loader=self.__class__.__name__,
)
def switch(self, container, context):
self.update(container, context)
def update(self, container, context):
"""Update the Loader's path
Fusion automatically tries to reset some variables when changing
the loader's path to a new file. These automatic changes are to its
inputs:
- ClipTimeStart: Fusion reset to 0 if duration changes
- We keep the trim in as close as possible to the previous value.
When there are less frames then the amount of trim we reduce
it accordingly.
- ClipTimeEnd: Fusion reset to 0 if duration changes
- We keep the trim out as close as possible to the previous value
within new amount of frames after trim in (ClipTimeStart) has
been set.
- GlobalIn: Fusion reset to comp's global in if duration changes
- We change it to the "frameStart"
- GlobalEnd: Fusion resets to globalIn + length if duration changes
- We do the same like Fusion - allow fusion to take control.
- HoldFirstFrame: Fusion resets this to 0
- We preserve the value.
- HoldLastFrame: Fusion resets this to 0
- We preserve the value.
- Reverse: Fusion resets to disabled if "Loop" is not enabled.
- We preserve the value.
- Depth: Fusion resets to "Format"
- We preserve the value.
- KeyCode: Fusion resets to ""
- We preserve the value.
- TimeCodeOffset: Fusion resets to 0
- We preserve the value.
"""
tool = container["_tool"]
assert tool.ID == "Loader", "Must be Loader"
comp = tool.Comp()
repre_entity = context["representation"]
path = self.filepath_from_context(context)
# Get start frame from version data
start = self._get_start(context["version"], tool)
with comp_lock_and_undo_chunk(comp, "Update Loader"):
# Update the loader's path whilst preserving some values
with preserve_trim(tool, log=self.log):
with preserve_inputs(
tool,
inputs=(
"HoldFirstFrame",
"HoldLastFrame",
"Reverse",
"Depth",
"KeyCode",
"TimeCodeOffset",
),
):
tool["Clip"] = comp.ReverseMapPath(path)
# Set the global in to the start frame of the sequence
global_in_changed = loader_shift(tool, start, relative=False)
if global_in_changed:
# Log this change to the user
self.log.debug(
"Changed '%s' global in: %d" % (tool.Name, start)
)
# Update the imprinted representation
tool.SetData("avalon.representation", repre_entity["id"])
def remove(self, container):
tool = container["_tool"]
assert tool.ID == "Loader", "Must be Loader"
comp = tool.Comp()
with comp_lock_and_undo_chunk(comp, "Remove Loader"):
tool.Delete()
def _get_start(self, version_entity, tool):
"""Return real start frame of published files (incl. handles)"""
attributes = version_entity["attrib"]
# Get start frame directly with handle if it's in data
start = attributes.get("frameStartHandle")
if start is not None:
return start
# Get frame start without handles
start = attributes.get("frameStart")
if start is None:
self.log.warning(
"Missing start frame for version "
"assuming starts at frame 0 for: "
"{}".format(tool.Name)
)
return 0
# Use `handleStart` if the data is available
handle_start = attributes.get("handleStart")
if handle_start:
start -= handle_start
return start

View file

@ -1,87 +0,0 @@
from ayon_core.pipeline import (
load,
get_representation_path,
)
from ayon_fusion.api import (
imprint_container,
get_current_comp,
comp_lock_and_undo_chunk
)
from ayon_fusion.api.lib import get_fusion_module
class FusionLoadUSD(load.LoaderPlugin):
"""Load USD into Fusion
Support for USD was added since Fusion 18.5
"""
product_types = {"*"}
representations = {"*"}
extensions = {"usd", "usda", "usdz"}
label = "Load USD"
order = -10
icon = "code-fork"
color = "orange"
tool_type = "uLoader"
@classmethod
def apply_settings(cls, project_settings):
super(FusionLoadUSD, cls).apply_settings(project_settings)
if cls.enabled:
# Enable only in Fusion 18.5+
fusion = get_fusion_module()
version = fusion.GetVersion()
major = version[1]
minor = version[2]
is_usd_supported = (major, minor) >= (18, 5)
cls.enabled = is_usd_supported
def load(self, context, name, namespace, data):
# Fallback to folder name when namespace is None
if namespace is None:
namespace = context["folder"]["name"]
# Create the Loader with the filename path set
comp = get_current_comp()
with comp_lock_and_undo_chunk(comp, "Create tool"):
path = self.fname
args = (-32768, -32768)
tool = comp.AddTool(self.tool_type, *args)
tool["Filename"] = path
imprint_container(tool,
name=name,
namespace=namespace,
context=context,
loader=self.__class__.__name__)
def switch(self, container, context):
self.update(container, context)
def update(self, container, context):
tool = container["_tool"]
assert tool.ID == self.tool_type, f"Must be {self.tool_type}"
comp = tool.Comp()
repre_entity = context["representation"]
path = get_representation_path(repre_entity)
with comp_lock_and_undo_chunk(comp, "Update tool"):
tool["Filename"] = path
# Update the imprinted representation
tool.SetData("avalon.representation", repre_entity["id"])
def remove(self, container):
tool = container["_tool"]
assert tool.ID == self.tool_type, f"Must be {self.tool_type}"
comp = tool.Comp()
with comp_lock_and_undo_chunk(comp, "Remove tool"):
tool.Delete()

View file

@ -1,33 +0,0 @@
"""Import workfiles into your current comp.
As all imported nodes are free floating and will probably be changed there
is no update or reload function added for this plugin
"""
from ayon_core.pipeline import load
from ayon_fusion.api import (
get_current_comp,
get_bmd_library,
)
class FusionLoadWorkfile(load.LoaderPlugin):
"""Load the content of a workfile into Fusion"""
product_types = {"workfile"}
representations = {"*"}
extensions = {"comp"}
label = "Load Workfile"
order = -10
icon = "code-fork"
color = "orange"
def load(self, context, name, namespace, data):
# Get needed elements
bmd = get_bmd_library()
comp = get_current_comp()
path = self.filepath_from_context(context)
# Paste the content of the file into the current comp
comp.Paste(bmd.readfile(path))

View file

@ -1,22 +0,0 @@
import pyblish.api
from ayon_fusion.api import get_current_comp
class CollectCurrentCompFusion(pyblish.api.ContextPlugin):
"""Collect current comp"""
order = pyblish.api.CollectorOrder - 0.4
label = "Collect Current Comp"
hosts = ["fusion"]
def process(self, context):
"""Collect all image sequence tools"""
current_comp = get_current_comp()
assert current_comp, "Must have active Fusion composition"
context.data["currentComp"] = current_comp
# Store path to current file
filepath = current_comp.GetAttrs().get("COMPS_FileName", "")
context.data['currentFile'] = filepath

View file

@ -1,44 +0,0 @@
import pyblish.api
def get_comp_render_range(comp):
"""Return comp's start-end render range and global start-end range."""
comp_attrs = comp.GetAttrs()
start = comp_attrs["COMPN_RenderStart"]
end = comp_attrs["COMPN_RenderEnd"]
global_start = comp_attrs["COMPN_GlobalStart"]
global_end = comp_attrs["COMPN_GlobalEnd"]
# Whenever render ranges are undefined fall back
# to the comp's global start and end
if start == -1000000000:
start = global_start
if end == -1000000000:
end = global_end
return start, end, global_start, global_end
class CollectFusionCompFrameRanges(pyblish.api.ContextPlugin):
"""Collect current comp"""
# We run this after CollectorOrder - 0.1 otherwise it gets
# overridden by global plug-in `CollectContextEntities`
order = pyblish.api.CollectorOrder - 0.05
label = "Collect Comp Frame Ranges"
hosts = ["fusion"]
def process(self, context):
"""Collect all image sequence tools"""
comp = context.data["currentComp"]
# Store comp render ranges
start, end, global_start, global_end = get_comp_render_range(comp)
context.data.update({
"renderFrameStart": int(start),
"renderFrameEnd": int(end),
"compFrameStart": int(global_start),
"compFrameEnd": int(global_end)
})

View file

@ -1,116 +0,0 @@
import pyblish.api
from ayon_core.pipeline import registered_host
def collect_input_containers(tools):
"""Collect containers that contain any of the node in `nodes`.
This will return any loaded Avalon container that contains at least one of
the nodes. As such, the Avalon container is an input for it. Or in short,
there are member nodes of that container.
Returns:
list: Input avalon containers
"""
# Lookup by node ids
lookup = frozenset([tool.Name for tool in tools])
containers = []
host = registered_host()
for container in host.ls():
name = container["_tool"].Name
# We currently assume no "groups" as containers but just single tools
# like a single "Loader" operator. As such we just check whether the
# Loader is part of the processing queue.
if name in lookup:
containers.append(container)
return containers
def iter_upstream(tool):
"""Yields all upstream inputs for the current tool.
Yields:
tool: The input tools.
"""
def get_connected_input_tools(tool):
"""Helper function that returns connected input tools for a tool."""
inputs = []
# Filter only to actual types that will have sensible upstream
# connections. So we ignore just "Number" inputs as they can be
# many to iterate, slowing things down quite a bit - and in practice
# they don't have upstream connections.
VALID_INPUT_TYPES = ['Image', 'Particles', 'Mask', 'DataType3D']
for type_ in VALID_INPUT_TYPES:
for input_ in tool.GetInputList(type_).values():
output = input_.GetConnectedOutput()
if output:
input_tool = output.GetTool()
inputs.append(input_tool)
return inputs
# Initialize process queue with the node's inputs itself
queue = get_connected_input_tools(tool)
# We keep track of which node names we have processed so far, to ensure we
# don't process the same hierarchy again. We are not pushing the tool
# itself into the set as that doesn't correctly recognize the same tool.
# Since tool names are unique in a comp in Fusion we rely on that.
collected = set(tool.Name for tool in queue)
# Traverse upstream references for all nodes and yield them as we
# process the queue.
while queue:
upstream_tool = queue.pop()
yield upstream_tool
# Find upstream tools that are not collected yet.
upstream_inputs = get_connected_input_tools(upstream_tool)
upstream_inputs = [t for t in upstream_inputs if
t.Name not in collected]
queue.extend(upstream_inputs)
collected.update(tool.Name for tool in upstream_inputs)
class CollectUpstreamInputs(pyblish.api.InstancePlugin):
"""Collect source input containers used for this publish.
This will include `inputs` data of which loaded publishes were used in the
generation of this publish. This leaves an upstream trace to what was used
as input.
"""
label = "Collect Inputs"
order = pyblish.api.CollectorOrder + 0.2
hosts = ["fusion"]
families = ["render", "image"]
def process(self, instance):
# Get all upstream and include itself
if not any(instance[:]):
self.log.debug("No tool found in instance, skipping..")
return
tool = instance[0]
nodes = list(iter_upstream(tool))
nodes.append(tool)
# Collect containers for the given set of nodes
containers = collect_input_containers(nodes)
inputs = [c["representation"] for c in containers]
instance.data["inputRepresentations"] = inputs
self.log.debug("Collected inputs: %s" % inputs)

View file

@ -1,109 +0,0 @@
import pyblish.api
class CollectInstanceData(pyblish.api.InstancePlugin):
"""Collect Fusion saver instances
This additionally stores the Comp start and end render range in the
current context's data as "frameStart" and "frameEnd".
"""
order = pyblish.api.CollectorOrder
label = "Collect Instances Data"
hosts = ["fusion"]
def process(self, instance):
"""Collect all image sequence tools"""
context = instance.context
# Include creator attributes directly as instance data
creator_attributes = instance.data["creator_attributes"]
instance.data.update(creator_attributes)
frame_range_source = creator_attributes.get("frame_range_source")
instance.data["frame_range_source"] = frame_range_source
# get folder frame ranges to all instances
# render product type instances `current_folder` render target
start = context.data["frameStart"]
end = context.data["frameEnd"]
handle_start = context.data["handleStart"]
handle_end = context.data["handleEnd"]
start_with_handle = start - handle_start
end_with_handle = end + handle_end
# conditions for render product type instances
if frame_range_source == "render_range":
# set comp render frame ranges
start = context.data["renderFrameStart"]
end = context.data["renderFrameEnd"]
handle_start = 0
handle_end = 0
start_with_handle = start
end_with_handle = end
if frame_range_source == "comp_range":
comp_start = context.data["compFrameStart"]
comp_end = context.data["compFrameEnd"]
render_start = context.data["renderFrameStart"]
render_end = context.data["renderFrameEnd"]
# set comp frame ranges
start = render_start
end = render_end
handle_start = render_start - comp_start
handle_end = comp_end - render_end
start_with_handle = comp_start
end_with_handle = comp_end
if frame_range_source == "custom_range":
start = int(instance.data["custom_frameStart"])
end = int(instance.data["custom_frameEnd"])
handle_start = int(instance.data["custom_handleStart"])
handle_end = int(instance.data["custom_handleEnd"])
start_with_handle = start - handle_start
end_with_handle = end + handle_end
frame = instance.data["creator_attributes"].get("frame")
# explicitly publishing only single frame
if frame is not None:
frame = int(frame)
start = frame
end = frame
handle_start = 0
handle_end = 0
start_with_handle = frame
end_with_handle = frame
# Include start and end render frame in label
product_name = instance.data["productName"]
label = (
"{product_name} ({start}-{end}) [{handle_start}-{handle_end}]"
).format(
product_name=product_name,
start=int(start),
end=int(end),
handle_start=int(handle_start),
handle_end=int(handle_end)
)
instance.data.update({
"label": label,
# todo: Allow custom frame range per instance
"frameStart": start,
"frameEnd": end,
"frameStartHandle": start_with_handle,
"frameEndHandle": end_with_handle,
"handleStart": handle_start,
"handleEnd": handle_end,
"fps": context.data["fps"],
})
# Add review family if the instance is marked as 'review'
# This could be done through a 'review' Creator attribute.
if instance.data.get("review", False):
self.log.debug("Adding review family..")
instance.data["families"].append("review")

View file

@ -1,208 +0,0 @@
import os
import attr
import pyblish.api
from ayon_core.pipeline import publish
from ayon_core.pipeline.publish import RenderInstance
from ayon_fusion.api.lib import get_frame_path
@attr.s
class FusionRenderInstance(RenderInstance):
# extend generic, composition name is needed
fps = attr.ib(default=None)
projectEntity = attr.ib(default=None)
stagingDir = attr.ib(default=None)
app_version = attr.ib(default=None)
tool = attr.ib(default=None)
workfileComp = attr.ib(default=None)
publish_attributes = attr.ib(default={})
frameStartHandle = attr.ib(default=None)
frameEndHandle = attr.ib(default=None)
class CollectFusionRender(
publish.AbstractCollectRender,
publish.ColormanagedPyblishPluginMixin
):
order = pyblish.api.CollectorOrder + 0.09
label = "Collect Fusion Render"
hosts = ["fusion"]
def get_instances(self, context):
comp = context.data.get("currentComp")
comp_frame_format_prefs = comp.GetPrefs("Comp.FrameFormat")
aspect_x = comp_frame_format_prefs["AspectX"]
aspect_y = comp_frame_format_prefs["AspectY"]
current_file = context.data["currentFile"]
version = context.data["version"]
project_entity = context.data["projectEntity"]
instances = []
for inst in context:
if not inst.data.get("active", True):
continue
product_type = inst.data["productType"]
if product_type not in ["render", "image"]:
continue
task_name = inst.data["task"]
tool = inst.data["transientData"]["tool"]
instance_families = inst.data.get("families", [])
product_name = inst.data["productName"]
instance = FusionRenderInstance(
tool=tool,
workfileComp=comp,
productType=product_type,
family=product_type,
families=instance_families,
version=version,
time="",
source=current_file,
label=inst.data["label"],
productName=product_name,
folderPath=inst.data["folderPath"],
task=task_name,
attachTo=False,
setMembers='',
publish=True,
name=product_name,
resolutionWidth=comp_frame_format_prefs.get("Width"),
resolutionHeight=comp_frame_format_prefs.get("Height"),
pixelAspect=aspect_x / aspect_y,
tileRendering=False,
tilesX=0,
tilesY=0,
review="review" in instance_families,
frameStart=inst.data["frameStart"],
frameEnd=inst.data["frameEnd"],
handleStart=inst.data["handleStart"],
handleEnd=inst.data["handleEnd"],
frameStartHandle=inst.data["frameStartHandle"],
frameEndHandle=inst.data["frameEndHandle"],
frameStep=1,
fps=comp_frame_format_prefs.get("Rate"),
app_version=comp.GetApp().Version,
publish_attributes=inst.data.get("publish_attributes", {}),
# The source instance this render instance replaces
source_instance=inst
)
render_target = inst.data["creator_attributes"]["render_target"]
# Add render target family
render_target_family = f"render.{render_target}"
if render_target_family not in instance.families:
instance.families.append(render_target_family)
# Add render target specific data
if render_target in {"local", "frames"}:
instance.projectEntity = project_entity
if render_target == "farm":
fam = "render.farm"
if fam not in instance.families:
instance.families.append(fam)
instance.farm = True # to skip integrate
if "review" in instance.families:
# to skip ExtractReview locally
instance.families.remove("review")
instance.deadline = inst.data.get("deadline")
instances.append(instance)
return instances
def post_collecting_action(self):
for instance in self._context:
if "render.frames" in instance.data.get("families", []):
# adding representation data to the instance
self._update_for_frames(instance)
def get_expected_files(self, render_instance):
"""
Returns list of rendered files that should be created by
Deadline. These are not published directly, they are source
for later 'submit_publish_job'.
Args:
render_instance (RenderInstance): to pull anatomy and parts used
in url
Returns:
(list) of absolute urls to rendered file
"""
start = render_instance.frameStart - render_instance.handleStart
end = render_instance.frameEnd + render_instance.handleEnd
comp = render_instance.workfileComp
path = comp.MapPath(
render_instance.tool["Clip"][
render_instance.workfileComp.TIME_UNDEFINED
]
)
output_dir = os.path.dirname(path)
render_instance.outputDir = output_dir
basename = os.path.basename(path)
head, padding, ext = get_frame_path(basename)
expected_files = []
for frame in range(start, end + 1):
expected_files.append(
os.path.join(
output_dir,
f"{head}{str(frame).zfill(padding)}{ext}"
)
)
return expected_files
def _update_for_frames(self, instance):
"""Updating instance for render.frames family
Adding representation data to the instance. Also setting
colorspaceData to the representation based on file rules.
"""
expected_files = instance.data["expectedFiles"]
start = instance.data["frameStart"] - instance.data["handleStart"]
path = expected_files[0]
basename = os.path.basename(path)
staging_dir = os.path.dirname(path)
_, padding, ext = get_frame_path(basename)
repre = {
"name": ext[1:],
"ext": ext[1:],
"frameStart": f"%0{padding}d" % start,
"files": [os.path.basename(f) for f in expected_files],
"stagingDir": staging_dir,
}
self.set_representation_colorspace(
representation=repre,
context=instance.context,
)
# review representation
if instance.data.get("review", False):
repre["tags"] = ["review"]
# add the repre to the instance
if "representations" not in instance.data:
instance.data["representations"] = []
instance.data["representations"].append(repre)
return instance

View file

@ -1,26 +0,0 @@
import os
import pyblish.api
class CollectFusionWorkfile(pyblish.api.InstancePlugin):
"""Collect Fusion workfile representation."""
order = pyblish.api.CollectorOrder + 0.1
label = "Collect Workfile"
hosts = ["fusion"]
families = ["workfile"]
def process(self, instance):
current_file = instance.context.data["currentFile"]
folder, file = os.path.split(current_file)
filename, ext = os.path.splitext(file)
instance.data['representations'] = [{
'name': ext.lstrip("."),
'ext': ext.lstrip("."),
'files': file,
"stagingDir": folder,
}]

View file

@ -1,207 +0,0 @@
import os
import logging
import contextlib
import collections
import pyblish.api
from ayon_core.pipeline import publish
from ayon_fusion.api import comp_lock_and_undo_chunk
from ayon_fusion.api.lib import get_frame_path, maintained_comp_range
log = logging.getLogger(__name__)
@contextlib.contextmanager
def enabled_savers(comp, savers):
"""Enable only the `savers` in Comp during the context.
Any Saver tool in the passed composition that is not in the savers list
will be set to passthrough during the context.
Args:
comp (object): Fusion composition object.
savers (list): List of Saver tool objects.
"""
passthrough_key = "TOOLB_PassThrough"
original_states = {}
enabled_saver_names = {saver.Name for saver in savers}
all_savers = comp.GetToolList(False, "Saver").values()
savers_by_name = {saver.Name: saver for saver in all_savers}
try:
for saver in all_savers:
original_state = saver.GetAttrs()[passthrough_key]
original_states[saver.Name] = original_state
# The passthrough state we want to set (passthrough != enabled)
state = saver.Name not in enabled_saver_names
if state != original_state:
saver.SetAttrs({passthrough_key: state})
yield
finally:
for saver_name, original_state in original_states.items():
saver = savers_by_name[saver_name]
saver.SetAttrs({"TOOLB_PassThrough": original_state})
class FusionRenderLocal(
pyblish.api.InstancePlugin,
publish.ColormanagedPyblishPluginMixin
):
"""Render the current Fusion composition locally."""
order = pyblish.api.ExtractorOrder - 0.2
label = "Render Local"
hosts = ["fusion"]
families = ["render.local"]
is_rendered_key = "_fusionrenderlocal_has_rendered"
def process(self, instance):
# Start render
result = self.render(instance)
if result is False:
raise RuntimeError(f"Comp render failed for {instance}")
self._add_representation(instance)
# Log render status
self.log.info(
"Rendered '{}' for folder '{}' under the task '{}'".format(
instance.data["name"],
instance.data["folderPath"],
instance.data["task"],
)
)
def render(self, instance):
"""Render instance.
We try to render the minimal amount of times by combining the instances
that have a matching frame range in one Fusion render. Then for the
batch of instances we store whether the render succeeded or failed.
"""
if self.is_rendered_key in instance.data:
# This instance was already processed in batch with another
# instance, so we just return the render result directly
self.log.debug(f"Instance {instance} was already rendered")
return instance.data[self.is_rendered_key]
instances_by_frame_range = self.get_render_instances_by_frame_range(
instance.context
)
# Render matching batch of instances that share the same frame range
frame_range = self.get_instance_render_frame_range(instance)
render_instances = instances_by_frame_range[frame_range]
# We initialize render state false to indicate it wasn't successful
# yet to keep track of whether Fusion succeeded. This is for cases
# where an error below this might cause the comp render result not
# to be stored for the instances of this batch
for render_instance in render_instances:
render_instance.data[self.is_rendered_key] = False
savers_to_render = [inst.data["tool"] for inst in render_instances]
current_comp = instance.context.data["currentComp"]
frame_start, frame_end = frame_range
self.log.info(
f"Starting Fusion render frame range {frame_start}-{frame_end}"
)
saver_names = ", ".join(saver.Name for saver in savers_to_render)
self.log.info(f"Rendering tools: {saver_names}")
with comp_lock_and_undo_chunk(current_comp):
with maintained_comp_range(current_comp):
with enabled_savers(current_comp, savers_to_render):
result = current_comp.Render(
{
"Start": frame_start,
"End": frame_end,
"Wait": True,
}
)
# Store the render state for all the rendered instances
for render_instance in render_instances:
render_instance.data[self.is_rendered_key] = bool(result)
return result
def _add_representation(self, instance):
"""Add representation to instance"""
expected_files = instance.data["expectedFiles"]
start = instance.data["frameStart"] - instance.data["handleStart"]
path = expected_files[0]
_, padding, ext = get_frame_path(path)
staging_dir = os.path.dirname(path)
files = [os.path.basename(f) for f in expected_files]
if len(expected_files) == 1:
files = files[0]
repre = {
"name": ext[1:],
"ext": ext[1:],
"frameStart": f"%0{padding}d" % start,
"files": files,
"stagingDir": staging_dir,
}
self.set_representation_colorspace(
representation=repre,
context=instance.context,
)
# review representation
if instance.data.get("review", False):
repre["tags"] = ["review"]
# add the repre to the instance
if "representations" not in instance.data:
instance.data["representations"] = []
instance.data["representations"].append(repre)
return instance
def get_render_instances_by_frame_range(self, context):
"""Return enabled render.local instances grouped by their frame range.
Arguments:
context (pyblish.Context): The pyblish context
Returns:
dict: (start, end): instances mapping
"""
instances_to_render = [
instance for instance in context if
# Only active instances
instance.data.get("publish", True) and
# Only render.local instances
"render.local" in instance.data.get("families", [])
]
# Instances by frame ranges
instances_by_frame_range = collections.defaultdict(list)
for instance in instances_to_render:
start, end = self.get_instance_render_frame_range(instance)
instances_by_frame_range[(start, end)].append(instance)
return dict(instances_by_frame_range)
def get_instance_render_frame_range(self, instance):
start = instance.data["frameStartHandle"]
end = instance.data["frameEndHandle"]
return start, end

View file

@ -1,44 +0,0 @@
import pyblish.api
from ayon_core.pipeline import OptionalPyblishPluginMixin
from ayon_core.pipeline import KnownPublishError
class FusionIncrementCurrentFile(
pyblish.api.ContextPlugin, OptionalPyblishPluginMixin
):
"""Increment the current file.
Saves the current file with an increased version number.
"""
label = "Increment workfile version"
order = pyblish.api.IntegratorOrder + 9.0
hosts = ["fusion"]
optional = True
def process(self, context):
if not self.is_active(context.data):
return
from ayon_core.lib import version_up
from ayon_core.pipeline.publish import get_errored_plugins_from_context
errored_plugins = get_errored_plugins_from_context(context)
if any(
plugin.__name__ == "FusionSubmitDeadline"
for plugin in errored_plugins
):
raise KnownPublishError(
"Skipping incrementing current file because "
"submission to render farm failed."
)
comp = context.data.get("currentComp")
assert comp, "Must have comp"
current_filepath = context.data["currentFile"]
new_filepath = version_up(current_filepath)
comp.Save(new_filepath)

View file

@ -1,21 +0,0 @@
import pyblish.api
class FusionSaveComp(pyblish.api.ContextPlugin):
"""Save current comp"""
label = "Save current file"
order = pyblish.api.ExtractorOrder - 0.49
hosts = ["fusion"]
families = ["render", "image", "workfile"]
def process(self, context):
comp = context.data.get("currentComp")
assert comp, "Must have comp"
current = comp.GetAttrs().get("COMPS_FileName", "")
assert context.data['currentFile'] == current
self.log.info("Saving current file: {}".format(current))
comp.Save()

View file

@ -1,54 +0,0 @@
import pyblish.api
from ayon_core.pipeline import (
publish,
OptionalPyblishPluginMixin,
PublishValidationError,
)
from ayon_fusion.api.action import SelectInvalidAction
class ValidateBackgroundDepth(
pyblish.api.InstancePlugin, OptionalPyblishPluginMixin
):
"""Validate if all Background tool are set to float32 bit"""
order = pyblish.api.ValidatorOrder
label = "Validate Background Depth 32 bit"
hosts = ["fusion"]
families = ["render", "image"]
optional = True
actions = [SelectInvalidAction, publish.RepairAction]
@classmethod
def get_invalid(cls, instance):
context = instance.context
comp = context.data.get("currentComp")
assert comp, "Must have Comp object"
backgrounds = comp.GetToolList(False, "Background").values()
if not backgrounds:
return []
return [i for i in backgrounds if i.GetInput("Depth") != 4.0]
def process(self, instance):
if not self.is_active(instance.data):
return
invalid = self.get_invalid(instance)
if invalid:
raise PublishValidationError(
"Found {} Backgrounds tools which"
" are not set to float32".format(len(invalid)),
title=self.label,
)
@classmethod
def repair(cls, instance):
comp = instance.context.data.get("currentComp")
invalid = cls.get_invalid(instance)
for i in invalid:
i.SetInput("Depth", 4.0, comp.TIME_UNDEFINED)

View file

@ -1,32 +0,0 @@
import os
import pyblish.api
from ayon_core.pipeline import PublishValidationError
class ValidateFusionCompSaved(pyblish.api.ContextPlugin):
"""Ensure current comp is saved"""
order = pyblish.api.ValidatorOrder
label = "Validate Comp Saved"
families = ["render", "image"]
hosts = ["fusion"]
def process(self, context):
comp = context.data.get("currentComp")
assert comp, "Must have Comp object"
attrs = comp.GetAttrs()
filename = attrs["COMPS_FileName"]
if not filename:
raise PublishValidationError("Comp is not saved.",
title=self.label)
if not os.path.exists(filename):
raise PublishValidationError(
"Comp file does not exist: %s" % filename, title=self.label)
if attrs["COMPB_Modified"]:
self.log.warning("Comp is modified. Save your comp to ensure your "
"changes propagate correctly.")

View file

@ -1,44 +0,0 @@
import pyblish.api
from ayon_core.pipeline.publish import RepairAction
from ayon_core.pipeline import PublishValidationError
from ayon_fusion.api.action import SelectInvalidAction
class ValidateCreateFolderChecked(pyblish.api.InstancePlugin):
"""Valid if all savers have the input attribute CreateDir checked on
This attribute ensures that the folders to which the saver will write
will be created.
"""
order = pyblish.api.ValidatorOrder
label = "Validate Create Folder Checked"
families = ["render", "image"]
hosts = ["fusion"]
actions = [RepairAction, SelectInvalidAction]
@classmethod
def get_invalid(cls, instance):
tool = instance.data["tool"]
create_dir = tool.GetInput("CreateDir")
if create_dir == 0.0:
cls.log.error(
"%s has Create Folder turned off" % instance[0].Name
)
return [tool]
def process(self, instance):
invalid = self.get_invalid(instance)
if invalid:
raise PublishValidationError(
"Found Saver with Create Folder During Render checked off",
title=self.label,
)
@classmethod
def repair(cls, instance):
invalid = cls.get_invalid(instance)
for tool in invalid:
tool.SetInput("CreateDir", 1.0)

View file

@ -1,66 +0,0 @@
import os
import pyblish.api
from ayon_core.pipeline.publish import RepairAction
from ayon_core.pipeline import PublishValidationError
from ayon_fusion.api.action import SelectInvalidAction
class ValidateLocalFramesExistence(pyblish.api.InstancePlugin):
"""Checks if files for savers that's set
to publish expected frames exists
"""
order = pyblish.api.ValidatorOrder
label = "Validate Expected Frames Exists"
families = ["render.frames"]
hosts = ["fusion"]
actions = [RepairAction, SelectInvalidAction]
@classmethod
def get_invalid(cls, instance, non_existing_frames=None):
if non_existing_frames is None:
non_existing_frames = []
tool = instance.data["tool"]
expected_files = instance.data["expectedFiles"]
for file in expected_files:
if not os.path.exists(file):
cls.log.error(
f"Missing file: {file}"
)
non_existing_frames.append(file)
if len(non_existing_frames) > 0:
cls.log.error(f"Some of {tool.Name}'s files does not exist")
return [tool]
def process(self, instance):
non_existing_frames = []
invalid = self.get_invalid(instance, non_existing_frames)
if invalid:
raise PublishValidationError(
"{} is set to publish existing frames but "
"some frames are missing. "
"The missing file(s) are:\n\n{}".format(
invalid[0].Name,
"\n\n".join(non_existing_frames),
),
title=self.label,
)
@classmethod
def repair(cls, instance):
invalid = cls.get_invalid(instance)
if invalid:
tool = instance.data["tool"]
# Change render target to local to render locally
tool.SetData("openpype.creator_attributes.render_target", "local")
cls.log.info(
f"Reload the publisher and {tool.Name} "
"will be set to render locally"
)

View file

@ -1,41 +0,0 @@
import os
import pyblish.api
from ayon_core.pipeline import PublishValidationError
from ayon_fusion.api.action import SelectInvalidAction
class ValidateFilenameHasExtension(pyblish.api.InstancePlugin):
"""Ensure the Saver has an extension in the filename path
This disallows files written as `filename` instead of `filename.frame.ext`.
Fusion does not always set an extension for your filename when
changing the file format of the saver.
"""
order = pyblish.api.ValidatorOrder
label = "Validate Filename Has Extension"
families = ["render", "image"]
hosts = ["fusion"]
actions = [SelectInvalidAction]
def process(self, instance):
invalid = self.get_invalid(instance)
if invalid:
raise PublishValidationError("Found Saver without an extension",
title=self.label)
@classmethod
def get_invalid(cls, instance):
path = instance.data["expectedFiles"][0]
fname, ext = os.path.splitext(path)
if not ext:
tool = instance.data["tool"]
cls.log.error("%s has no extension specified" % tool.Name)
return [tool]
return []

View file

@ -1,27 +0,0 @@
import pyblish.api
from ayon_core.pipeline import PublishValidationError
class ValidateImageFrame(pyblish.api.InstancePlugin):
"""Validates that `image` product type contains only single frame."""
order = pyblish.api.ValidatorOrder
label = "Validate Image Frame"
families = ["image"]
hosts = ["fusion"]
def process(self, instance):
render_start = instance.data["frameStartHandle"]
render_end = instance.data["frameEndHandle"]
too_many_frames = (isinstance(instance.data["expectedFiles"], list)
and len(instance.data["expectedFiles"]) > 1)
if render_end - render_start > 0 or too_many_frames:
desc = ("Trying to render multiple frames. 'image' product type "
"is meant for single frame. Please use 'render' creator.")
raise PublishValidationError(
title="Frame range outside of comp range",
message=desc,
description=desc
)

View file

@ -1,41 +0,0 @@
import pyblish.api
from ayon_core.pipeline import PublishValidationError
class ValidateInstanceFrameRange(pyblish.api.InstancePlugin):
"""Validate instance frame range is within comp's global render range."""
order = pyblish.api.ValidatorOrder
label = "Validate Frame Range"
families = ["render", "image"]
hosts = ["fusion"]
def process(self, instance):
context = instance.context
global_start = context.data["compFrameStart"]
global_end = context.data["compFrameEnd"]
render_start = instance.data["frameStartHandle"]
render_end = instance.data["frameEndHandle"]
if render_start < global_start or render_end > global_end:
message = (
f"Instance {instance} render frame range "
f"({render_start}-{render_end}) is outside of the comp's "
f"global render range ({global_start}-{global_end}) and thus "
f"can't be rendered. "
)
description = (
f"{message}\n\n"
f"Either update the comp's global range or the instance's "
f"frame range to ensure the comp's frame range includes the "
f"to render frame range for the instance."
)
raise PublishValidationError(
title="Frame range outside of comp range",
message=message,
description=description
)

View file

@ -1,80 +0,0 @@
# -*- coding: utf-8 -*-
"""Validate if instance context is the same as publish context."""
import pyblish.api
from ayon_fusion.api.action import SelectToolAction
from ayon_core.pipeline.publish import (
RepairAction,
ValidateContentsOrder,
PublishValidationError,
OptionalPyblishPluginMixin
)
class ValidateInstanceInContextFusion(pyblish.api.InstancePlugin,
OptionalPyblishPluginMixin):
"""Validator to check if instance context matches context of publish.
When working in per-shot style you always publish data in context of
current asset (shot). This validator checks if this is so. It is optional
so it can be disabled when needed.
"""
# Similar to maya and houdini-equivalent `ValidateInstanceInContext`
order = ValidateContentsOrder
label = "Instance in same Context"
optional = True
hosts = ["fusion"]
actions = [SelectToolAction, RepairAction]
def process(self, instance):
if not self.is_active(instance.data):
return
instance_context = self.get_context(instance.data)
context = self.get_context(instance.context.data)
if instance_context != context:
context_label = "{} > {}".format(*context)
instance_label = "{} > {}".format(*instance_context)
raise PublishValidationError(
message=(
"Instance '{}' publishes to different asset than current "
"context: {}. Current context: {}".format(
instance.name, instance_label, context_label
)
),
description=(
"## Publishing to a different asset\n"
"There are publish instances present which are publishing "
"into a different asset than your current context.\n\n"
"Usually this is not what you want but there can be cases "
"where you might want to publish into another asset or "
"shot. If that's the case you can disable the validation "
"on the instance to ignore it."
)
)
@classmethod
def repair(cls, instance):
create_context = instance.context.data["create_context"]
instance_id = instance.data.get("instance_id")
created_instance = create_context.get_instance_by_id(
instance_id
)
if created_instance is None:
raise RuntimeError(
f"No CreatedInstances found with id '{instance_id} "
f"in {create_context.instances_by_id}"
)
context_asset, context_task = cls.get_context(instance.context.data)
created_instance["folderPath"] = context_asset
created_instance["task"] = context_task
create_context.save_changes()
@staticmethod
def get_context(data):
"""Return asset, task from publishing context data"""
return data["folderPath"], data["task"]

View file

@ -1,36 +0,0 @@
import pyblish.api
from ayon_core.pipeline import PublishValidationError
from ayon_fusion.api.action import SelectInvalidAction
class ValidateSaverHasInput(pyblish.api.InstancePlugin):
"""Validate saver has incoming connection
This ensures a Saver has at least an input connection.
"""
order = pyblish.api.ValidatorOrder
label = "Validate Saver Has Input"
families = ["render", "image"]
hosts = ["fusion"]
actions = [SelectInvalidAction]
@classmethod
def get_invalid(cls, instance):
saver = instance.data["tool"]
if not saver.Input.GetConnectedOutput():
return [saver]
return []
def process(self, instance):
invalid = self.get_invalid(instance)
if invalid:
saver_name = invalid[0].Name
raise PublishValidationError(
"Saver has no incoming connection: {} ({})".format(instance,
saver_name),
title=self.label)

View file

@ -1,49 +0,0 @@
import pyblish.api
from ayon_core.pipeline import PublishValidationError
from ayon_fusion.api.action import SelectInvalidAction
class ValidateSaverPassthrough(pyblish.api.ContextPlugin):
"""Validate saver passthrough is similar to Pyblish publish state"""
order = pyblish.api.ValidatorOrder
label = "Validate Saver Passthrough"
families = ["render", "image"]
hosts = ["fusion"]
actions = [SelectInvalidAction]
def process(self, context):
# Workaround for ContextPlugin always running, even if no instance
# is present with the family
instances = pyblish.api.instances_by_plugin(instances=list(context),
plugin=self)
if not instances:
self.log.debug("Ignoring plugin.. (bugfix)")
invalid_instances = []
for instance in instances:
invalid = self.is_invalid(instance)
if invalid:
invalid_instances.append(instance)
if invalid_instances:
self.log.info("Reset pyblish to collect your current scene state, "
"that should fix error.")
raise PublishValidationError(
"Invalid instances: {0}".format(invalid_instances),
title=self.label)
def is_invalid(self, instance):
saver = instance.data["tool"]
attr = saver.GetAttrs()
active = not attr["TOOLB_PassThrough"]
if active != instance.data.get("publish", True):
self.log.info("Saver has different passthrough state than "
"Pyblish: {} ({})".format(instance, saver.Name))
return [saver]
return []

View file

@ -1,116 +0,0 @@
import pyblish.api
from ayon_core.pipeline import (
PublishValidationError,
OptionalPyblishPluginMixin,
)
from ayon_fusion.api.action import SelectInvalidAction
from ayon_fusion.api import comp_lock_and_undo_chunk
class ValidateSaverResolution(
pyblish.api.InstancePlugin, OptionalPyblishPluginMixin
):
"""Validate that the saver input resolution matches the folder resolution"""
order = pyblish.api.ValidatorOrder
label = "Validate Folder Resolution"
families = ["render", "image"]
hosts = ["fusion"]
optional = True
actions = [SelectInvalidAction]
def process(self, instance):
if not self.is_active(instance.data):
return
resolution = self.get_resolution(instance)
expected_resolution = self.get_expected_resolution(instance)
if resolution != expected_resolution:
raise PublishValidationError(
"The input's resolution does not match "
"the folder's resolution {}x{}.\n\n"
"The input's resolution is {}x{}.".format(
expected_resolution[0], expected_resolution[1],
resolution[0], resolution[1]
)
)
@classmethod
def get_invalid(cls, instance):
saver = instance.data["tool"]
try:
resolution = cls.get_resolution(instance)
except PublishValidationError:
resolution = None
expected_resolution = cls.get_expected_resolution(instance)
if resolution != expected_resolution:
return [saver]
@classmethod
def get_resolution(cls, instance):
saver = instance.data["tool"]
first_frame = instance.data["frameStartHandle"]
return cls.get_tool_resolution(saver, frame=first_frame)
@classmethod
def get_expected_resolution(cls, instance):
attributes = instance.data["folderEntity"]["attrib"]
return attributes["resolutionWidth"], attributes["resolutionHeight"]
@classmethod
def get_tool_resolution(cls, tool, frame):
"""Return the 2D input resolution to a Fusion tool
If the current tool hasn't been rendered its input resolution
hasn't been saved. To combat this, add an expression in
the comments field to read the resolution
Args
tool (Fusion Tool): The tool to query input resolution
frame (int): The frame to query the resolution on.
Returns:
tuple: width, height as 2-tuple of integers
"""
comp = tool.Composition
# False undo removes the undo-stack from the undo list
with comp_lock_and_undo_chunk(comp, "Read resolution", False):
# Save old comment
old_comment = ""
has_expression = False
if tool["Comments"][frame] not in ["", None]:
if tool["Comments"].GetExpression() is not None:
has_expression = True
old_comment = tool["Comments"].GetExpression()
tool["Comments"].SetExpression(None)
else:
old_comment = tool["Comments"][frame]
tool["Comments"][frame] = ""
# Get input width
tool["Comments"].SetExpression("self.Input.OriginalWidth")
if tool["Comments"][frame] is None:
raise PublishValidationError(
"Cannot get resolution info for frame '{}'.\n\n "
"Please check that saver has connected input.".format(
frame
)
)
width = int(tool["Comments"][frame])
# Get input height
tool["Comments"].SetExpression("self.Input.OriginalHeight")
height = int(tool["Comments"][frame])
# Reset old comment
tool["Comments"].SetExpression(None)
if has_expression:
tool["Comments"].SetExpression(old_comment)
else:
tool["Comments"][frame] = old_comment
return width, height

View file

@ -1,62 +0,0 @@
from collections import defaultdict
import pyblish.api
from ayon_core.pipeline import PublishValidationError
from ayon_fusion.api.action import SelectInvalidAction
class ValidateUniqueSubsets(pyblish.api.ContextPlugin):
"""Ensure all instances have a unique product name"""
order = pyblish.api.ValidatorOrder
label = "Validate Unique Products"
families = ["render", "image"]
hosts = ["fusion"]
actions = [SelectInvalidAction]
@classmethod
def get_invalid(cls, context):
# Collect instances per product per folder
instances_per_product_folder = defaultdict(lambda: defaultdict(list))
for instance in context:
folder_path = instance.data["folderPath"]
product_name = instance.data["productName"]
instances_per_product_folder[folder_path][product_name].append(
instance
)
# Find which folder + subset combination has more than one instance
# Those are considered invalid because they'd integrate to the same
# destination.
invalid = []
for folder_path, instances_per_product in (
instances_per_product_folder.items()
):
for product_name, instances in instances_per_product.items():
if len(instances) > 1:
cls.log.warning(
(
"{folder_path} > {product_name} used by more than "
"one instance: {instances}"
).format(
folder_path=folder_path,
product_name=product_name,
instances=instances
)
)
invalid.extend(instances)
# Return tools for the invalid instances so they can be selected
invalid = [instance.data["tool"] for instance in invalid]
return invalid
def process(self, context):
invalid = self.get_invalid(context)
if invalid:
raise PublishValidationError(
"Multiple instances are set to the same folder > product.",
title=self.label
)

View file

@ -1,45 +0,0 @@
from ayon_fusion.api import (
comp_lock_and_undo_chunk,
get_current_comp
)
def is_connected(input):
"""Return whether an input has incoming connection"""
return input.GetAttrs()["INPB_Connected"]
def duplicate_with_input_connections():
"""Duplicate selected tools with incoming connections."""
comp = get_current_comp()
original_tools = comp.GetToolList(True).values()
if not original_tools:
return # nothing selected
with comp_lock_and_undo_chunk(
comp, "Duplicate With Input Connections"):
# Generate duplicates
comp.Copy()
comp.SetActiveTool()
comp.Paste()
duplicate_tools = comp.GetToolList(True).values()
# Copy connections
for original, new in zip(original_tools, duplicate_tools):
original_inputs = original.GetInputList().values()
new_inputs = new.GetInputList().values()
assert len(original_inputs) == len(new_inputs)
for original_input, new_input in zip(original_inputs, new_inputs):
if is_connected(original_input):
if is_connected(new_input):
# Already connected if it is between the copied tools
continue
new_input.ConnectTo(original_input.GetConnectedOutput())
assert is_connected(new_input), "Must be connected now"

View file

@ -1,78 +0,0 @@
from __future__ import absolute_import, division, print_function
import sys
from functools import partial
from . import converters, exceptions, filters, setters, validators
from ._cmp import cmp_using
from ._config import get_run_validators, set_run_validators
from ._funcs import asdict, assoc, astuple, evolve, has, resolve_types
from ._make import (
NOTHING,
Attribute,
Factory,
attrib,
attrs,
fields,
fields_dict,
make_class,
validate,
)
from ._version_info import VersionInfo
__version__ = "21.2.0"
__version_info__ = VersionInfo._from_version_string(__version__)
__title__ = "attrs"
__description__ = "Classes Without Boilerplate"
__url__ = "https://www.attrs.org/"
__uri__ = __url__
__doc__ = __description__ + " <" + __uri__ + ">"
__author__ = "Hynek Schlawack"
__email__ = "hs@ox.cx"
__license__ = "MIT"
__copyright__ = "Copyright (c) 2015 Hynek Schlawack"
s = attributes = attrs
ib = attr = attrib
dataclass = partial(attrs, auto_attribs=True) # happy Easter ;)
__all__ = [
"Attribute",
"Factory",
"NOTHING",
"asdict",
"assoc",
"astuple",
"attr",
"attrib",
"attributes",
"attrs",
"cmp_using",
"converters",
"evolve",
"exceptions",
"fields",
"fields_dict",
"filters",
"get_run_validators",
"has",
"ib",
"make_class",
"resolve_types",
"s",
"set_run_validators",
"setters",
"validate",
"validators",
]
if sys.version_info[:2] >= (3, 6):
from ._next_gen import define, field, frozen, mutable
__all__.extend((define, field, frozen, mutable))

View file

@ -1,475 +0,0 @@
import sys
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)
# `import X as X` is required to make these public
from . import converters as converters
from . import exceptions as exceptions
from . import filters as filters
from . import setters as setters
from . import validators as validators
from ._version_info import VersionInfo
__version__: str
__version_info__: VersionInfo
__title__: str
__description__: str
__url__: str
__uri__: str
__author__: str
__email__: str
__license__: str
__copyright__: str
_T = TypeVar("_T")
_C = TypeVar("_C", bound=type)
_EqOrderType = Union[bool, Callable[[Any], Any]]
_ValidatorType = Callable[[Any, Attribute[_T], _T], Any]
_ConverterType = Callable[[Any], Any]
_FilterType = Callable[[Attribute[_T], _T], bool]
_ReprType = Callable[[Any], str]
_ReprArgType = Union[bool, _ReprType]
_OnSetAttrType = Callable[[Any, Attribute[Any], Any], Any]
_OnSetAttrArgType = Union[
_OnSetAttrType, List[_OnSetAttrType], setters._NoOpType
]
_FieldTransformer = Callable[[type, List[Attribute[Any]]], List[Attribute[Any]]]
# FIXME: in reality, if multiple validators are passed they must be in a list
# or tuple, but those are invariant and so would prevent subtypes of
# _ValidatorType from working when passed in a list or tuple.
_ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]]
# _make --
NOTHING: object
# NOTE: Factory lies about its return type to make this possible:
# `x: List[int] # = Factory(list)`
# Work around mypy issue #4554 in the common case by using an overload.
if sys.version_info >= (3, 8):
from typing import Literal
@overload
def Factory(factory: Callable[[], _T]) -> _T: ...
@overload
def Factory(
factory: Callable[[Any], _T],
takes_self: Literal[True],
) -> _T: ...
@overload
def Factory(
factory: Callable[[], _T],
takes_self: Literal[False],
) -> _T: ...
else:
@overload
def Factory(factory: Callable[[], _T]) -> _T: ...
@overload
def Factory(
factory: Union[Callable[[Any], _T], Callable[[], _T]],
takes_self: bool = ...,
) -> _T: ...
# Static type inference support via __dataclass_transform__ implemented as per:
# https://github.com/microsoft/pyright/blob/1.1.135/specs/dataclass_transforms.md
# This annotation must be applied to all overloads of "define" and "attrs"
#
# NOTE: This is a typing construct and does not exist at runtime. Extensions
# wrapping attrs decorators should declare a separate __dataclass_transform__
# signature in the extension module using the specification linked above to
# provide pyright support.
def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]: ...
class Attribute(Generic[_T]):
name: str
default: Optional[_T]
validator: Optional[_ValidatorType[_T]]
repr: _ReprArgType
cmp: _EqOrderType
eq: _EqOrderType
order: _EqOrderType
hash: Optional[bool]
init: bool
converter: Optional[_ConverterType]
metadata: Dict[Any, Any]
type: Optional[Type[_T]]
kw_only: bool
on_setattr: _OnSetAttrType
def evolve(self, **changes: Any) -> "Attribute[Any]": ...
# NOTE: We had several choices for the annotation to use for type arg:
# 1) Type[_T]
# - Pros: Handles simple cases correctly
# - Cons: Might produce less informative errors in the case of conflicting
# TypeVars e.g. `attr.ib(default='bad', type=int)`
# 2) Callable[..., _T]
# - Pros: Better error messages than #1 for conflicting TypeVars
# - Cons: Terrible error messages for validator checks.
# e.g. attr.ib(type=int, validator=validate_str)
# -> error: Cannot infer function type argument
# 3) type (and do all of the work in the mypy plugin)
# - Pros: Simple here, and we could customize the plugin with our own errors.
# - Cons: Would need to write mypy plugin code to handle all the cases.
# We chose option #1.
# `attr` lies about its return type to make the following possible:
# attr() -> Any
# attr(8) -> int
# attr(validator=<some callable>) -> Whatever the callable expects.
# This makes this type of assignments possible:
# x: int = attr(8)
#
# This form catches explicit None or no default but with no other arguments
# returns Any.
@overload
def attrib(
default: None = ...,
validator: None = ...,
repr: _ReprArgType = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: None = ...,
converter: None = ...,
factory: None = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
# This form catches an explicit None or no default and infers the type from the
# other arguments.
@overload
def attrib(
default: None = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: Optional[Type[_T]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...
# This form catches an explicit default argument.
@overload
def attrib(
default: _T,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: Optional[Type[_T]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...
# This form covers type=non-Type: e.g. forward references (str), Any
@overload
def attrib(
default: Optional[_T] = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: object = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
@overload
def field(
*,
default: None = ...,
validator: None = ...,
repr: _ReprArgType = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: None = ...,
factory: None = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
# This form catches an explicit None or no default and infers the type from the
# other arguments.
@overload
def field(
*,
default: None = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...
# This form catches an explicit default argument.
@overload
def field(
*,
default: _T,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...
# This form covers type=non-Type: e.g. forward references (str), Any
@overload
def field(
*,
default: Optional[_T] = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
def attrs(
maybe_cls: _C,
these: Optional[Dict[str, Any]] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
auto_detect: bool = ...,
collect_by_mro: bool = ...,
getstate_setstate: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
) -> _C: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
def attrs(
maybe_cls: None = ...,
these: Optional[Dict[str, Any]] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
auto_detect: bool = ...,
collect_by_mro: bool = ...,
getstate_setstate: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
) -> Callable[[_C], _C]: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
def define(
maybe_cls: _C,
*,
these: Optional[Dict[str, Any]] = ...,
repr: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
auto_detect: bool = ...,
getstate_setstate: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
) -> _C: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
def define(
maybe_cls: None = ...,
*,
these: Optional[Dict[str, Any]] = ...,
repr: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
auto_detect: bool = ...,
getstate_setstate: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
) -> Callable[[_C], _C]: ...
mutable = define
frozen = define # they differ only in their defaults
# TODO: add support for returning NamedTuple from the mypy plugin
class _Fields(Tuple[Attribute[Any], ...]):
def __getattr__(self, name: str) -> Attribute[Any]: ...
def fields(cls: type) -> _Fields: ...
def fields_dict(cls: type) -> Dict[str, Attribute[Any]]: ...
def validate(inst: Any) -> None: ...
def resolve_types(
cls: _C,
globalns: Optional[Dict[str, Any]] = ...,
localns: Optional[Dict[str, Any]] = ...,
attribs: Optional[List[Attribute[Any]]] = ...,
) -> _C: ...
# TODO: add support for returning a proper attrs class from the mypy plugin
# we use Any instead of _CountingAttr so that e.g. `make_class('Foo',
# [attr.ib()])` is valid
def make_class(
name: str,
attrs: Union[List[str], Tuple[str, ...], Dict[str, Any]],
bases: Tuple[type, ...] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
collect_by_mro: bool = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
) -> type: ...
# _funcs --
# TODO: add support for returning TypedDict from the mypy plugin
# FIXME: asdict/astuple do not honor their factory args. Waiting on one of
# these:
# https://github.com/python/mypy/issues/4236
# https://github.com/python/typing/issues/253
def asdict(
inst: Any,
recurse: bool = ...,
filter: Optional[_FilterType[Any]] = ...,
dict_factory: Type[Mapping[Any, Any]] = ...,
retain_collection_types: bool = ...,
value_serializer: Optional[Callable[[type, Attribute[Any], Any], Any]] = ...,
) -> Dict[str, Any]: ...
# TODO: add support for returning NamedTuple from the mypy plugin
def astuple(
inst: Any,
recurse: bool = ...,
filter: Optional[_FilterType[Any]] = ...,
tuple_factory: Type[Sequence[Any]] = ...,
retain_collection_types: bool = ...,
) -> Tuple[Any, ...]: ...
def has(cls: type) -> bool: ...
def assoc(inst: _T, **changes: Any) -> _T: ...
def evolve(inst: _T, **changes: Any) -> _T: ...
# _config --
def set_run_validators(run: bool) -> None: ...
def get_run_validators() -> bool: ...
# aliases --
s = attributes = attrs
ib = attr = attrib
dataclass = attrs # Technically, partial(attrs, auto_attribs=True) ;)

View file

@ -1,152 +0,0 @@
from __future__ import absolute_import, division, print_function
import functools
from ._compat import new_class
from ._make import _make_ne
_operation_names = {"eq": "==", "lt": "<", "le": "<=", "gt": ">", "ge": ">="}
def cmp_using(
eq=None,
lt=None,
le=None,
gt=None,
ge=None,
require_same_type=True,
class_name="Comparable",
):
"""
Create a class that can be passed into `attr.ib`'s ``eq``, ``order``, and
``cmp`` arguments to customize field comparison.
The resulting class will have a full set of ordering methods if
at least one of ``{lt, le, gt, ge}`` and ``eq`` are provided.
:param Optional[callable] eq: `callable` used to evaluate equality
of two objects.
:param Optional[callable] lt: `callable` used to evaluate whether
one object is less than another object.
:param Optional[callable] le: `callable` used to evaluate whether
one object is less than or equal to another object.
:param Optional[callable] gt: `callable` used to evaluate whether
one object is greater than another object.
:param Optional[callable] ge: `callable` used to evaluate whether
one object is greater than or equal to another object.
:param bool require_same_type: When `True`, equality and ordering methods
will return `NotImplemented` if objects are not of the same type.
:param Optional[str] class_name: Name of class. Defaults to 'Comparable'.
See `comparison` for more details.
.. versionadded:: 21.1.0
"""
body = {
"__slots__": ["value"],
"__init__": _make_init(),
"_requirements": [],
"_is_comparable_to": _is_comparable_to,
}
# Add operations.
num_order_functions = 0
has_eq_function = False
if eq is not None:
has_eq_function = True
body["__eq__"] = _make_operator("eq", eq)
body["__ne__"] = _make_ne()
if lt is not None:
num_order_functions += 1
body["__lt__"] = _make_operator("lt", lt)
if le is not None:
num_order_functions += 1
body["__le__"] = _make_operator("le", le)
if gt is not None:
num_order_functions += 1
body["__gt__"] = _make_operator("gt", gt)
if ge is not None:
num_order_functions += 1
body["__ge__"] = _make_operator("ge", ge)
type_ = new_class(class_name, (object,), {}, lambda ns: ns.update(body))
# Add same type requirement.
if require_same_type:
type_._requirements.append(_check_same_type)
# Add total ordering if at least one operation was defined.
if 0 < num_order_functions < 4:
if not has_eq_function:
# functools.total_ordering requires __eq__ to be defined,
# so raise early error here to keep a nice stack.
raise ValueError(
"eq must be define is order to complete ordering from "
"lt, le, gt, ge."
)
type_ = functools.total_ordering(type_)
return type_
def _make_init():
"""
Create __init__ method.
"""
def __init__(self, value):
"""
Initialize object with *value*.
"""
self.value = value
return __init__
def _make_operator(name, func):
"""
Create operator method.
"""
def method(self, other):
if not self._is_comparable_to(other):
return NotImplemented
result = func(self.value, other.value)
if result is NotImplemented:
return NotImplemented
return result
method.__name__ = "__%s__" % (name,)
method.__doc__ = "Return a %s b. Computed by attrs." % (
_operation_names[name],
)
return method
def _is_comparable_to(self, other):
"""
Check whether `other` is comparable to `self`.
"""
for func in self._requirements:
if not func(self, other):
return False
return True
def _check_same_type(self, other):
"""
Return True if *self* and *other* are of the same type, False otherwise.
"""
return other.value.__class__ is self.value.__class__

View file

@ -1,14 +0,0 @@
from typing import Type
from . import _CompareWithType
def cmp_using(
eq: Optional[_CompareWithType],
lt: Optional[_CompareWithType],
le: Optional[_CompareWithType],
gt: Optional[_CompareWithType],
ge: Optional[_CompareWithType],
require_same_type: bool,
class_name: str,
) -> Type: ...

View file

@ -1,242 +0,0 @@
from __future__ import absolute_import, division, print_function
import platform
import sys
import types
import warnings
PY2 = sys.version_info[0] == 2
PYPY = platform.python_implementation() == "PyPy"
if PYPY or sys.version_info[:2] >= (3, 6):
ordered_dict = dict
else:
from collections import OrderedDict
ordered_dict = OrderedDict
if PY2:
from collections import Mapping, Sequence
from UserDict import IterableUserDict
# We 'bundle' isclass instead of using inspect as importing inspect is
# fairly expensive (order of 10-15 ms for a modern machine in 2016)
def isclass(klass):
return isinstance(klass, (type, types.ClassType))
def new_class(name, bases, kwds, exec_body):
"""
A minimal stub of types.new_class that we need for make_class.
"""
ns = {}
exec_body(ns)
return type(name, bases, ns)
# TYPE is used in exceptions, repr(int) is different on Python 2 and 3.
TYPE = "type"
def iteritems(d):
return d.iteritems()
# Python 2 is bereft of a read-only dict proxy, so we make one!
class ReadOnlyDict(IterableUserDict):
"""
Best-effort read-only dict wrapper.
"""
def __setitem__(self, key, val):
# We gently pretend we're a Python 3 mappingproxy.
raise TypeError(
"'mappingproxy' object does not support item assignment"
)
def update(self, _):
# We gently pretend we're a Python 3 mappingproxy.
raise AttributeError(
"'mappingproxy' object has no attribute 'update'"
)
def __delitem__(self, _):
# We gently pretend we're a Python 3 mappingproxy.
raise TypeError(
"'mappingproxy' object does not support item deletion"
)
def clear(self):
# We gently pretend we're a Python 3 mappingproxy.
raise AttributeError(
"'mappingproxy' object has no attribute 'clear'"
)
def pop(self, key, default=None):
# We gently pretend we're a Python 3 mappingproxy.
raise AttributeError(
"'mappingproxy' object has no attribute 'pop'"
)
def popitem(self):
# We gently pretend we're a Python 3 mappingproxy.
raise AttributeError(
"'mappingproxy' object has no attribute 'popitem'"
)
def setdefault(self, key, default=None):
# We gently pretend we're a Python 3 mappingproxy.
raise AttributeError(
"'mappingproxy' object has no attribute 'setdefault'"
)
def __repr__(self):
# Override to be identical to the Python 3 version.
return "mappingproxy(" + repr(self.data) + ")"
def metadata_proxy(d):
res = ReadOnlyDict()
res.data.update(d) # We blocked update, so we have to do it like this.
return res
def just_warn(*args, **kw): # pragma: no cover
"""
We only warn on Python 3 because we are not aware of any concrete
consequences of not setting the cell on Python 2.
"""
else: # Python 3 and later.
from collections.abc import Mapping, Sequence # noqa
def just_warn(*args, **kw):
"""
We only warn on Python 3 because we are not aware of any concrete
consequences of not setting the cell on Python 2.
"""
warnings.warn(
"Running interpreter doesn't sufficiently support code object "
"introspection. Some features like bare super() or accessing "
"__class__ will not work with slotted classes.",
RuntimeWarning,
stacklevel=2,
)
def isclass(klass):
return isinstance(klass, type)
TYPE = "class"
def iteritems(d):
return d.items()
new_class = types.new_class
def metadata_proxy(d):
return types.MappingProxyType(dict(d))
def make_set_closure_cell():
"""Return a function of two arguments (cell, value) which sets
the value stored in the closure cell `cell` to `value`.
"""
# pypy makes this easy. (It also supports the logic below, but
# why not do the easy/fast thing?)
if PYPY:
def set_closure_cell(cell, value):
cell.__setstate__((value,))
return set_closure_cell
# Otherwise gotta do it the hard way.
# Create a function that will set its first cellvar to `value`.
def set_first_cellvar_to(value):
x = value
return
# This function will be eliminated as dead code, but
# not before its reference to `x` forces `x` to be
# represented as a closure cell rather than a local.
def force_x_to_be_a_cell(): # pragma: no cover
return x
try:
# Extract the code object and make sure our assumptions about
# the closure behavior are correct.
if PY2:
co = set_first_cellvar_to.func_code
else:
co = set_first_cellvar_to.__code__
if co.co_cellvars != ("x",) or co.co_freevars != ():
raise AssertionError # pragma: no cover
# Convert this code object to a code object that sets the
# function's first _freevar_ (not cellvar) to the argument.
if sys.version_info >= (3, 8):
# CPython 3.8+ has an incompatible CodeType signature
# (added a posonlyargcount argument) but also added
# CodeType.replace() to do this without counting parameters.
set_first_freevar_code = co.replace(
co_cellvars=co.co_freevars, co_freevars=co.co_cellvars
)
else:
args = [co.co_argcount]
if not PY2:
args.append(co.co_kwonlyargcount)
args.extend(
[
co.co_nlocals,
co.co_stacksize,
co.co_flags,
co.co_code,
co.co_consts,
co.co_names,
co.co_varnames,
co.co_filename,
co.co_name,
co.co_firstlineno,
co.co_lnotab,
# These two arguments are reversed:
co.co_cellvars,
co.co_freevars,
]
)
set_first_freevar_code = types.CodeType(*args)
def set_closure_cell(cell, value):
# Create a function using the set_first_freevar_code,
# whose first closure cell is `cell`. Calling it will
# change the value of that cell.
setter = types.FunctionType(
set_first_freevar_code, {}, "setter", (), (cell,)
)
# And call it to set the cell.
setter(value)
# Make sure it works on this interpreter:
def make_func_with_cell():
x = None
def func():
return x # pragma: no cover
return func
if PY2:
cell = make_func_with_cell().func_closure[0]
else:
cell = make_func_with_cell().__closure__[0]
set_closure_cell(cell, 100)
if cell.cell_contents != 100:
raise AssertionError # pragma: no cover
except Exception:
return just_warn
else:
return set_closure_cell
set_closure_cell = make_set_closure_cell()

View file

@ -1,23 +0,0 @@
from __future__ import absolute_import, division, print_function
__all__ = ["set_run_validators", "get_run_validators"]
_run_validators = True
def set_run_validators(run):
"""
Set whether or not validators are run. By default, they are run.
"""
if not isinstance(run, bool):
raise TypeError("'run' must be bool.")
global _run_validators
_run_validators = run
def get_run_validators():
"""
Return whether or not validators are run.
"""
return _run_validators

View file

@ -1,395 +0,0 @@
from __future__ import absolute_import, division, print_function
import copy
from ._compat import iteritems
from ._make import NOTHING, _obj_setattr, fields
from .exceptions import AttrsAttributeNotFoundError
def asdict(
inst,
recurse=True,
filter=None,
dict_factory=dict,
retain_collection_types=False,
value_serializer=None,
):
"""
Return the ``attrs`` attribute values of *inst* as a dict.
Optionally recurse into other ``attrs``-decorated classes.
:param inst: Instance of an ``attrs``-decorated class.
:param bool recurse: Recurse into classes that are also
``attrs``-decorated.
:param callable filter: A callable whose return code determines whether an
attribute or element is included (``True``) or dropped (``False``). Is
called with the `attr.Attribute` as the first argument and the
value as the second argument.
:param callable dict_factory: A callable to produce dictionaries from. For
example, to produce ordered dictionaries instead of normal Python
dictionaries, pass in ``collections.OrderedDict``.
:param bool retain_collection_types: Do not convert to ``list`` when
encountering an attribute whose type is ``tuple`` or ``set``. Only
meaningful if ``recurse`` is ``True``.
:param Optional[callable] value_serializer: A hook that is called for every
attribute or dict key/value. It receives the current instance, field
and value and must return the (updated) value. The hook is run *after*
the optional *filter* has been applied.
:rtype: return type of *dict_factory*
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. versionadded:: 16.0.0 *dict_factory*
.. versionadded:: 16.1.0 *retain_collection_types*
.. versionadded:: 20.3.0 *value_serializer*
"""
attrs = fields(inst.__class__)
rv = dict_factory()
for a in attrs:
v = getattr(inst, a.name)
if filter is not None and not filter(a, v):
continue
if value_serializer is not None:
v = value_serializer(inst, a, v)
if recurse is True:
if has(v.__class__):
rv[a.name] = asdict(
v,
True,
filter,
dict_factory,
retain_collection_types,
value_serializer,
)
elif isinstance(v, (tuple, list, set, frozenset)):
cf = v.__class__ if retain_collection_types is True else list
rv[a.name] = cf(
[
_asdict_anything(
i,
filter,
dict_factory,
retain_collection_types,
value_serializer,
)
for i in v
]
)
elif isinstance(v, dict):
df = dict_factory
rv[a.name] = df(
(
_asdict_anything(
kk,
filter,
df,
retain_collection_types,
value_serializer,
),
_asdict_anything(
vv,
filter,
df,
retain_collection_types,
value_serializer,
),
)
for kk, vv in iteritems(v)
)
else:
rv[a.name] = v
else:
rv[a.name] = v
return rv
def _asdict_anything(
val,
filter,
dict_factory,
retain_collection_types,
value_serializer,
):
"""
``asdict`` only works on attrs instances, this works on anything.
"""
if getattr(val.__class__, "__attrs_attrs__", None) is not None:
# Attrs class.
rv = asdict(
val,
True,
filter,
dict_factory,
retain_collection_types,
value_serializer,
)
elif isinstance(val, (tuple, list, set, frozenset)):
cf = val.__class__ if retain_collection_types is True else list
rv = cf(
[
_asdict_anything(
i,
filter,
dict_factory,
retain_collection_types,
value_serializer,
)
for i in val
]
)
elif isinstance(val, dict):
df = dict_factory
rv = df(
(
_asdict_anything(
kk, filter, df, retain_collection_types, value_serializer
),
_asdict_anything(
vv, filter, df, retain_collection_types, value_serializer
),
)
for kk, vv in iteritems(val)
)
else:
rv = val
if value_serializer is not None:
rv = value_serializer(None, None, rv)
return rv
def astuple(
inst,
recurse=True,
filter=None,
tuple_factory=tuple,
retain_collection_types=False,
):
"""
Return the ``attrs`` attribute values of *inst* as a tuple.
Optionally recurse into other ``attrs``-decorated classes.
:param inst: Instance of an ``attrs``-decorated class.
:param bool recurse: Recurse into classes that are also
``attrs``-decorated.
:param callable filter: A callable whose return code determines whether an
attribute or element is included (``True``) or dropped (``False``). Is
called with the `attr.Attribute` as the first argument and the
value as the second argument.
:param callable tuple_factory: A callable to produce tuples from. For
example, to produce lists instead of tuples.
:param bool retain_collection_types: Do not convert to ``list``
or ``dict`` when encountering an attribute which type is
``tuple``, ``dict`` or ``set``. Only meaningful if ``recurse`` is
``True``.
:rtype: return type of *tuple_factory*
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. versionadded:: 16.2.0
"""
attrs = fields(inst.__class__)
rv = []
retain = retain_collection_types # Very long. :/
for a in attrs:
v = getattr(inst, a.name)
if filter is not None and not filter(a, v):
continue
if recurse is True:
if has(v.__class__):
rv.append(
astuple(
v,
recurse=True,
filter=filter,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
)
elif isinstance(v, (tuple, list, set, frozenset)):
cf = v.__class__ if retain is True else list
rv.append(
cf(
[
astuple(
j,
recurse=True,
filter=filter,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
if has(j.__class__)
else j
for j in v
]
)
)
elif isinstance(v, dict):
df = v.__class__ if retain is True else dict
rv.append(
df(
(
astuple(
kk,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
if has(kk.__class__)
else kk,
astuple(
vv,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
if has(vv.__class__)
else vv,
)
for kk, vv in iteritems(v)
)
)
else:
rv.append(v)
else:
rv.append(v)
return rv if tuple_factory is list else tuple_factory(rv)
def has(cls):
"""
Check whether *cls* is a class with ``attrs`` attributes.
:param type cls: Class to introspect.
:raise TypeError: If *cls* is not a class.
:rtype: bool
"""
return getattr(cls, "__attrs_attrs__", None) is not None
def assoc(inst, **changes):
"""
Copy *inst* and apply *changes*.
:param inst: Instance of a class with ``attrs`` attributes.
:param changes: Keyword changes in the new copy.
:return: A copy of inst with *changes* incorporated.
:raise attr.exceptions.AttrsAttributeNotFoundError: If *attr_name* couldn't
be found on *cls*.
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. deprecated:: 17.1.0
Use `evolve` instead.
"""
import warnings
warnings.warn(
"assoc is deprecated and will be removed after 2018/01.",
DeprecationWarning,
stacklevel=2,
)
new = copy.copy(inst)
attrs = fields(inst.__class__)
for k, v in iteritems(changes):
a = getattr(attrs, k, NOTHING)
if a is NOTHING:
raise AttrsAttributeNotFoundError(
"{k} is not an attrs attribute on {cl}.".format(
k=k, cl=new.__class__
)
)
_obj_setattr(new, k, v)
return new
def evolve(inst, **changes):
"""
Create a new instance, based on *inst* with *changes* applied.
:param inst: Instance of a class with ``attrs`` attributes.
:param changes: Keyword changes in the new copy.
:return: A copy of inst with *changes* incorporated.
:raise TypeError: If *attr_name* couldn't be found in the class
``__init__``.
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. versionadded:: 17.1.0
"""
cls = inst.__class__
attrs = fields(cls)
for a in attrs:
if not a.init:
continue
attr_name = a.name # To deal with private attributes.
init_name = attr_name if attr_name[0] != "_" else attr_name[1:]
if init_name not in changes:
changes[init_name] = getattr(inst, attr_name)
return cls(**changes)
def resolve_types(cls, globalns=None, localns=None, attribs=None):
"""
Resolve any strings and forward annotations in type annotations.
This is only required if you need concrete types in `Attribute`'s *type*
field. In other words, you don't need to resolve your types if you only
use them for static type checking.
With no arguments, names will be looked up in the module in which the class
was created. If this is not what you want, e.g. if the name only exists
inside a method, you may pass *globalns* or *localns* to specify other
dictionaries in which to look up these names. See the docs of
`typing.get_type_hints` for more details.
:param type cls: Class to resolve.
:param Optional[dict] globalns: Dictionary containing global variables.
:param Optional[dict] localns: Dictionary containing local variables.
:param Optional[list] attribs: List of attribs for the given class.
This is necessary when calling from inside a ``field_transformer``
since *cls* is not an ``attrs`` class yet.
:raise TypeError: If *cls* is not a class.
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class and you didn't pass any attribs.
:raise NameError: If types cannot be resolved because of missing variables.
:returns: *cls* so you can use this function also as a class decorator.
Please note that you have to apply it **after** `attr.s`. That means
the decorator has to come in the line **before** `attr.s`.
.. versionadded:: 20.1.0
.. versionadded:: 21.1.0 *attribs*
"""
try:
# Since calling get_type_hints is expensive we cache whether we've
# done it already.
cls.__attrs_types_resolved__
except AttributeError:
import typing
hints = typing.get_type_hints(cls, globalns=globalns, localns=localns)
for field in fields(cls) if attribs is None else attribs:
if field.name in hints:
# Since fields have been frozen we must work around it.
_obj_setattr(field, "type", hints[field.name])
cls.__attrs_types_resolved__ = True
# Return the class so you can use it as a decorator too.
return cls

File diff suppressed because it is too large Load diff

View file

@ -1,158 +0,0 @@
"""
These are Python 3.6+-only and keyword-only APIs that call `attr.s` and
`attr.ib` with different default values.
"""
from functools import partial
from attr.exceptions import UnannotatedAttributeError
from . import setters
from ._make import NOTHING, _frozen_setattrs, attrib, attrs
def define(
maybe_cls=None,
*,
these=None,
repr=None,
hash=None,
init=None,
slots=True,
frozen=False,
weakref_slot=True,
str=False,
auto_attribs=None,
kw_only=False,
cache_hash=False,
auto_exc=True,
eq=None,
order=False,
auto_detect=True,
getstate_setstate=None,
on_setattr=None,
field_transformer=None,
):
r"""
The only behavioral differences are the handling of the *auto_attribs*
option:
:param Optional[bool] auto_attribs: If set to `True` or `False`, it behaves
exactly like `attr.s`. If left `None`, `attr.s` will try to guess:
1. If any attributes are annotated and no unannotated `attr.ib`\ s
are found, it assumes *auto_attribs=True*.
2. Otherwise it assumes *auto_attribs=False* and tries to collect
`attr.ib`\ s.
and that mutable classes (``frozen=False``) validate on ``__setattr__``.
.. versionadded:: 20.1.0
"""
def do_it(cls, auto_attribs):
return attrs(
maybe_cls=cls,
these=these,
repr=repr,
hash=hash,
init=init,
slots=slots,
frozen=frozen,
weakref_slot=weakref_slot,
str=str,
auto_attribs=auto_attribs,
kw_only=kw_only,
cache_hash=cache_hash,
auto_exc=auto_exc,
eq=eq,
order=order,
auto_detect=auto_detect,
collect_by_mro=True,
getstate_setstate=getstate_setstate,
on_setattr=on_setattr,
field_transformer=field_transformer,
)
def wrap(cls):
"""
Making this a wrapper ensures this code runs during class creation.
We also ensure that frozen-ness of classes is inherited.
"""
nonlocal frozen, on_setattr
had_on_setattr = on_setattr not in (None, setters.NO_OP)
# By default, mutable classes validate on setattr.
if frozen is False and on_setattr is None:
on_setattr = setters.validate
# However, if we subclass a frozen class, we inherit the immutability
# and disable on_setattr.
for base_cls in cls.__bases__:
if base_cls.__setattr__ is _frozen_setattrs:
if had_on_setattr:
raise ValueError(
"Frozen classes can't use on_setattr "
"(frozen-ness was inherited)."
)
on_setattr = setters.NO_OP
break
if auto_attribs is not None:
return do_it(cls, auto_attribs)
try:
return do_it(cls, True)
except UnannotatedAttributeError:
return do_it(cls, False)
# maybe_cls's type depends on the usage of the decorator. It's a class
# if it's used as `@attrs` but ``None`` if used as `@attrs()`.
if maybe_cls is None:
return wrap
else:
return wrap(maybe_cls)
mutable = define
frozen = partial(define, frozen=True, on_setattr=None)
def field(
*,
default=NOTHING,
validator=None,
repr=True,
hash=None,
init=True,
metadata=None,
converter=None,
factory=None,
kw_only=False,
eq=None,
order=None,
on_setattr=None,
):
"""
Identical to `attr.ib`, except keyword-only and with some arguments
removed.
.. versionadded:: 20.1.0
"""
return attrib(
default=default,
validator=validator,
repr=repr,
hash=hash,
init=init,
metadata=metadata,
converter=converter,
factory=factory,
kw_only=kw_only,
eq=eq,
order=order,
on_setattr=on_setattr,
)

View file

@ -1,85 +0,0 @@
from __future__ import absolute_import, division, print_function
from functools import total_ordering
from ._funcs import astuple
from ._make import attrib, attrs
@total_ordering
@attrs(eq=False, order=False, slots=True, frozen=True)
class VersionInfo(object):
"""
A version object that can be compared to tuple of length 1--4:
>>> attr.VersionInfo(19, 1, 0, "final") <= (19, 2)
True
>>> attr.VersionInfo(19, 1, 0, "final") < (19, 1, 1)
True
>>> vi = attr.VersionInfo(19, 2, 0, "final")
>>> vi < (19, 1, 1)
False
>>> vi < (19,)
False
>>> vi == (19, 2,)
True
>>> vi == (19, 2, 1)
False
.. versionadded:: 19.2
"""
year = attrib(type=int)
minor = attrib(type=int)
micro = attrib(type=int)
releaselevel = attrib(type=str)
@classmethod
def _from_version_string(cls, s):
"""
Parse *s* and return a _VersionInfo.
"""
v = s.split(".")
if len(v) == 3:
v.append("final")
return cls(
year=int(v[0]), minor=int(v[1]), micro=int(v[2]), releaselevel=v[3]
)
def _ensure_tuple(self, other):
"""
Ensure *other* is a tuple of a valid length.
Returns a possibly transformed *other* and ourselves as a tuple of
the same length as *other*.
"""
if self.__class__ is other.__class__:
other = astuple(other)
if not isinstance(other, tuple):
raise NotImplementedError
if not (1 <= len(other) <= 4):
raise NotImplementedError
return astuple(self)[: len(other)], other
def __eq__(self, other):
try:
us, them = self._ensure_tuple(other)
except NotImplementedError:
return NotImplemented
return us == them
def __lt__(self, other):
try:
us, them = self._ensure_tuple(other)
except NotImplementedError:
return NotImplemented
# Since alphabetically "dev0" < "final" < "post1" < "post2", we don't
# have to do anything special with releaselevel for now.
return us < them

View file

@ -1,9 +0,0 @@
class VersionInfo:
@property
def year(self) -> int: ...
@property
def minor(self) -> int: ...
@property
def micro(self) -> int: ...
@property
def releaselevel(self) -> str: ...

View file

@ -1,111 +0,0 @@
"""
Commonly useful converters.
"""
from __future__ import absolute_import, division, print_function
from ._compat import PY2
from ._make import NOTHING, Factory, pipe
if not PY2:
import inspect
import typing
__all__ = [
"pipe",
"optional",
"default_if_none",
]
def optional(converter):
"""
A converter that allows an attribute to be optional. An optional attribute
is one which can be set to ``None``.
Type annotations will be inferred from the wrapped converter's, if it
has any.
:param callable converter: the converter that is used for non-``None``
values.
.. versionadded:: 17.1.0
"""
def optional_converter(val):
if val is None:
return None
return converter(val)
if not PY2:
sig = None
try:
sig = inspect.signature(converter)
except (ValueError, TypeError): # inspect failed
pass
if sig:
params = list(sig.parameters.values())
if params and params[0].annotation is not inspect.Parameter.empty:
optional_converter.__annotations__["val"] = typing.Optional[
params[0].annotation
]
if sig.return_annotation is not inspect.Signature.empty:
optional_converter.__annotations__["return"] = typing.Optional[
sig.return_annotation
]
return optional_converter
def default_if_none(default=NOTHING, factory=None):
"""
A converter that allows to replace ``None`` values by *default* or the
result of *factory*.
:param default: Value to be used if ``None`` is passed. Passing an instance
of `attr.Factory` is supported, however the ``takes_self`` option
is *not*.
:param callable factory: A callable that takes no parameters whose result
is used if ``None`` is passed.
:raises TypeError: If **neither** *default* or *factory* is passed.
:raises TypeError: If **both** *default* and *factory* are passed.
:raises ValueError: If an instance of `attr.Factory` is passed with
``takes_self=True``.
.. versionadded:: 18.2.0
"""
if default is NOTHING and factory is None:
raise TypeError("Must pass either `default` or `factory`.")
if default is not NOTHING and factory is not None:
raise TypeError(
"Must pass either `default` or `factory` but not both."
)
if factory is not None:
default = Factory(factory)
if isinstance(default, Factory):
if default.takes_self:
raise ValueError(
"`takes_self` is not supported by default_if_none."
)
def default_if_none_converter(val):
if val is not None:
return val
return default.factory()
else:
def default_if_none_converter(val):
if val is not None:
return val
return default
return default_if_none_converter

View file

@ -1,13 +0,0 @@
from typing import Callable, Optional, TypeVar, overload
from . import _ConverterType
_T = TypeVar("_T")
def pipe(*validators: _ConverterType) -> _ConverterType: ...
def optional(converter: _ConverterType) -> _ConverterType: ...
@overload
def default_if_none(default: _T) -> _ConverterType: ...
@overload
def default_if_none(*, factory: Callable[[], _T]) -> _ConverterType: ...

View file

@ -1,92 +0,0 @@
from __future__ import absolute_import, division, print_function
class FrozenError(AttributeError):
"""
A frozen/immutable instance or attribute have been attempted to be
modified.
It mirrors the behavior of ``namedtuples`` by using the same error message
and subclassing `AttributeError`.
.. versionadded:: 20.1.0
"""
msg = "can't set attribute"
args = [msg]
class FrozenInstanceError(FrozenError):
"""
A frozen instance has been attempted to be modified.
.. versionadded:: 16.1.0
"""
class FrozenAttributeError(FrozenError):
"""
A frozen attribute has been attempted to be modified.
.. versionadded:: 20.1.0
"""
class AttrsAttributeNotFoundError(ValueError):
"""
An ``attrs`` function couldn't find an attribute that the user asked for.
.. versionadded:: 16.2.0
"""
class NotAnAttrsClassError(ValueError):
"""
A non-``attrs`` class has been passed into an ``attrs`` function.
.. versionadded:: 16.2.0
"""
class DefaultAlreadySetError(RuntimeError):
"""
A default has been set using ``attr.ib()`` and is attempted to be reset
using the decorator.
.. versionadded:: 17.1.0
"""
class UnannotatedAttributeError(RuntimeError):
"""
A class with ``auto_attribs=True`` has an ``attr.ib()`` without a type
annotation.
.. versionadded:: 17.3.0
"""
class PythonTooOldError(RuntimeError):
"""
It was attempted to use an ``attrs`` feature that requires a newer Python
version.
.. versionadded:: 18.2.0
"""
class NotCallableError(TypeError):
"""
A ``attr.ib()`` requiring a callable has been set with a value
that is not callable.
.. versionadded:: 19.2.0
"""
def __init__(self, msg, value):
super(TypeError, self).__init__(msg, value)
self.msg = msg
self.value = value
def __str__(self):
return str(self.msg)

View file

@ -1,18 +0,0 @@
from typing import Any
class FrozenError(AttributeError):
msg: str = ...
class FrozenInstanceError(FrozenError): ...
class FrozenAttributeError(FrozenError): ...
class AttrsAttributeNotFoundError(ValueError): ...
class NotAnAttrsClassError(ValueError): ...
class DefaultAlreadySetError(RuntimeError): ...
class UnannotatedAttributeError(RuntimeError): ...
class PythonTooOldError(RuntimeError): ...
class NotCallableError(TypeError):
msg: str = ...
value: Any = ...
def __init__(self, msg: str, value: Any) -> None: ...

View file

@ -1,52 +0,0 @@
"""
Commonly useful filters for `attr.asdict`.
"""
from __future__ import absolute_import, division, print_function
from ._compat import isclass
from ._make import Attribute
def _split_what(what):
"""
Returns a tuple of `frozenset`s of classes and attributes.
"""
return (
frozenset(cls for cls in what if isclass(cls)),
frozenset(cls for cls in what if isinstance(cls, Attribute)),
)
def include(*what):
"""
Whitelist *what*.
:param what: What to whitelist.
:type what: `list` of `type` or `attr.Attribute`\\ s
:rtype: `callable`
"""
cls, attrs = _split_what(what)
def include_(attribute, value):
return value.__class__ in cls or attribute in attrs
return include_
def exclude(*what):
"""
Blacklist *what*.
:param what: What to blacklist.
:type what: `list` of classes or `attr.Attribute`\\ s.
:rtype: `callable`
"""
cls, attrs = _split_what(what)
def exclude_(attribute, value):
return value.__class__ not in cls and attribute not in attrs
return exclude_

View file

@ -1,7 +0,0 @@
from typing import Any, Union
from . import Attribute, _FilterType
def include(*what: Union[type, Attribute[Any]]) -> _FilterType[Any]: ...
def exclude(*what: Union[type, Attribute[Any]]) -> _FilterType[Any]: ...

View file

@ -1,77 +0,0 @@
"""
Commonly used hooks for on_setattr.
"""
from __future__ import absolute_import, division, print_function
from . import _config
from .exceptions import FrozenAttributeError
def pipe(*setters):
"""
Run all *setters* and return the return value of the last one.
.. versionadded:: 20.1.0
"""
def wrapped_pipe(instance, attrib, new_value):
rv = new_value
for setter in setters:
rv = setter(instance, attrib, rv)
return rv
return wrapped_pipe
def frozen(_, __, ___):
"""
Prevent an attribute to be modified.
.. versionadded:: 20.1.0
"""
raise FrozenAttributeError()
def validate(instance, attrib, new_value):
"""
Run *attrib*'s validator on *new_value* if it has one.
.. versionadded:: 20.1.0
"""
if _config._run_validators is False:
return new_value
v = attrib.validator
if not v:
return new_value
v(instance, attrib, new_value)
return new_value
def convert(instance, attrib, new_value):
"""
Run *attrib*'s converter -- if it has one -- on *new_value* and return the
result.
.. versionadded:: 20.1.0
"""
c = attrib.converter
if c:
return c(new_value)
return new_value
NO_OP = object()
"""
Sentinel for disabling class-wide *on_setattr* hooks for certain attributes.
Does not work in `pipe` or within lists.
.. versionadded:: 20.1.0
"""

View file

@ -1,20 +0,0 @@
from typing import Any, NewType, NoReturn, TypeVar, cast
from . import Attribute, _OnSetAttrType
_T = TypeVar("_T")
def frozen(
instance: Any, attribute: Attribute[Any], new_value: Any
) -> NoReturn: ...
def pipe(*setters: _OnSetAttrType) -> _OnSetAttrType: ...
def validate(instance: Any, attribute: Attribute[_T], new_value: _T) -> _T: ...
# convert is allowed to return Any, because they can be chained using pipe.
def convert(
instance: Any, attribute: Attribute[Any], new_value: Any
) -> Any: ...
_NoOpType = NewType("_NoOpType", object)
NO_OP: _NoOpType

View file

@ -1,379 +0,0 @@
"""
Commonly useful validators.
"""
from __future__ import absolute_import, division, print_function
import re
from ._make import _AndValidator, and_, attrib, attrs
from .exceptions import NotCallableError
__all__ = [
"and_",
"deep_iterable",
"deep_mapping",
"in_",
"instance_of",
"is_callable",
"matches_re",
"optional",
"provides",
]
@attrs(repr=False, slots=True, hash=True)
class _InstanceOfValidator(object):
type = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not isinstance(value, self.type):
raise TypeError(
"'{name}' must be {type!r} (got {value!r} that is a "
"{actual!r}).".format(
name=attr.name,
type=self.type,
actual=value.__class__,
value=value,
),
attr,
self.type,
value,
)
def __repr__(self):
return "<instance_of validator for type {type!r}>".format(
type=self.type
)
def instance_of(type):
"""
A validator that raises a `TypeError` if the initializer is called
with a wrong type for this particular attribute (checks are performed using
`isinstance` therefore it's also valid to pass a tuple of types).
:param type: The type to check for.
:type type: type or tuple of types
:raises TypeError: With a human readable error message, the attribute
(of type `attr.Attribute`), the expected type, and the value it
got.
"""
return _InstanceOfValidator(type)
@attrs(repr=False, frozen=True, slots=True)
class _MatchesReValidator(object):
regex = attrib()
flags = attrib()
match_func = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not self.match_func(value):
raise ValueError(
"'{name}' must match regex {regex!r}"
" ({value!r} doesn't)".format(
name=attr.name, regex=self.regex.pattern, value=value
),
attr,
self.regex,
value,
)
def __repr__(self):
return "<matches_re validator for pattern {regex!r}>".format(
regex=self.regex
)
def matches_re(regex, flags=0, func=None):
r"""
A validator that raises `ValueError` if the initializer is called
with a string that doesn't match *regex*.
:param str regex: a regex string to match against
:param int flags: flags that will be passed to the underlying re function
(default 0)
:param callable func: which underlying `re` function to call (options
are `re.fullmatch`, `re.search`, `re.match`, default
is ``None`` which means either `re.fullmatch` or an emulation of
it on Python 2). For performance reasons, they won't be used directly
but on a pre-`re.compile`\ ed pattern.
.. versionadded:: 19.2.0
"""
fullmatch = getattr(re, "fullmatch", None)
valid_funcs = (fullmatch, None, re.search, re.match)
if func not in valid_funcs:
raise ValueError(
"'func' must be one of %s."
% (
", ".join(
sorted(
e and e.__name__ or "None" for e in set(valid_funcs)
)
),
)
)
pattern = re.compile(regex, flags)
if func is re.match:
match_func = pattern.match
elif func is re.search:
match_func = pattern.search
else:
if fullmatch:
match_func = pattern.fullmatch
else:
pattern = re.compile(r"(?:{})\Z".format(regex), flags)
match_func = pattern.match
return _MatchesReValidator(pattern, flags, match_func)
@attrs(repr=False, slots=True, hash=True)
class _ProvidesValidator(object):
interface = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not self.interface.providedBy(value):
raise TypeError(
"'{name}' must provide {interface!r} which {value!r} "
"doesn't.".format(
name=attr.name, interface=self.interface, value=value
),
attr,
self.interface,
value,
)
def __repr__(self):
return "<provides validator for interface {interface!r}>".format(
interface=self.interface
)
def provides(interface):
"""
A validator that raises a `TypeError` if the initializer is called
with an object that does not provide the requested *interface* (checks are
performed using ``interface.providedBy(value)`` (see `zope.interface
<https://zopeinterface.readthedocs.io/en/latest/>`_).
:param interface: The interface to check for.
:type interface: ``zope.interface.Interface``
:raises TypeError: With a human readable error message, the attribute
(of type `attr.Attribute`), the expected interface, and the
value it got.
"""
return _ProvidesValidator(interface)
@attrs(repr=False, slots=True, hash=True)
class _OptionalValidator(object):
validator = attrib()
def __call__(self, inst, attr, value):
if value is None:
return
self.validator(inst, attr, value)
def __repr__(self):
return "<optional validator for {what} or None>".format(
what=repr(self.validator)
)
def optional(validator):
"""
A validator that makes an attribute optional. An optional attribute is one
which can be set to ``None`` in addition to satisfying the requirements of
the sub-validator.
:param validator: A validator (or a list of validators) that is used for
non-``None`` values.
:type validator: callable or `list` of callables.
.. versionadded:: 15.1.0
.. versionchanged:: 17.1.0 *validator* can be a list of validators.
"""
if isinstance(validator, list):
return _OptionalValidator(_AndValidator(validator))
return _OptionalValidator(validator)
@attrs(repr=False, slots=True, hash=True)
class _InValidator(object):
options = attrib()
def __call__(self, inst, attr, value):
try:
in_options = value in self.options
except TypeError: # e.g. `1 in "abc"`
in_options = False
if not in_options:
raise ValueError(
"'{name}' must be in {options!r} (got {value!r})".format(
name=attr.name, options=self.options, value=value
)
)
def __repr__(self):
return "<in_ validator with options {options!r}>".format(
options=self.options
)
def in_(options):
"""
A validator that raises a `ValueError` if the initializer is called
with a value that does not belong in the options provided. The check is
performed using ``value in options``.
:param options: Allowed options.
:type options: list, tuple, `enum.Enum`, ...
:raises ValueError: With a human readable error message, the attribute (of
type `attr.Attribute`), the expected options, and the value it
got.
.. versionadded:: 17.1.0
"""
return _InValidator(options)
@attrs(repr=False, slots=False, hash=True)
class _IsCallableValidator(object):
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not callable(value):
message = (
"'{name}' must be callable "
"(got {value!r} that is a {actual!r})."
)
raise NotCallableError(
msg=message.format(
name=attr.name, value=value, actual=value.__class__
),
value=value,
)
def __repr__(self):
return "<is_callable validator>"
def is_callable():
"""
A validator that raises a `attr.exceptions.NotCallableError` if the
initializer is called with a value for this particular attribute
that is not callable.
.. versionadded:: 19.1.0
:raises `attr.exceptions.NotCallableError`: With a human readable error
message containing the attribute (`attr.Attribute`) name,
and the value it got.
"""
return _IsCallableValidator()
@attrs(repr=False, slots=True, hash=True)
class _DeepIterable(object):
member_validator = attrib(validator=is_callable())
iterable_validator = attrib(
default=None, validator=optional(is_callable())
)
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if self.iterable_validator is not None:
self.iterable_validator(inst, attr, value)
for member in value:
self.member_validator(inst, attr, member)
def __repr__(self):
iterable_identifier = (
""
if self.iterable_validator is None
else " {iterable!r}".format(iterable=self.iterable_validator)
)
return (
"<deep_iterable validator for{iterable_identifier}"
" iterables of {member!r}>"
).format(
iterable_identifier=iterable_identifier,
member=self.member_validator,
)
def deep_iterable(member_validator, iterable_validator=None):
"""
A validator that performs deep validation of an iterable.
:param member_validator: Validator to apply to iterable members
:param iterable_validator: Validator to apply to iterable itself
(optional)
.. versionadded:: 19.1.0
:raises TypeError: if any sub-validators fail
"""
return _DeepIterable(member_validator, iterable_validator)
@attrs(repr=False, slots=True, hash=True)
class _DeepMapping(object):
key_validator = attrib(validator=is_callable())
value_validator = attrib(validator=is_callable())
mapping_validator = attrib(default=None, validator=optional(is_callable()))
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if self.mapping_validator is not None:
self.mapping_validator(inst, attr, value)
for key in value:
self.key_validator(inst, attr, key)
self.value_validator(inst, attr, value[key])
def __repr__(self):
return (
"<deep_mapping validator for objects mapping {key!r} to {value!r}>"
).format(key=self.key_validator, value=self.value_validator)
def deep_mapping(key_validator, value_validator, mapping_validator=None):
"""
A validator that performs deep validation of a dictionary.
:param key_validator: Validator to apply to dictionary keys
:param value_validator: Validator to apply to dictionary values
:param mapping_validator: Validator to apply to top-level mapping
attribute (optional)
.. versionadded:: 19.1.0
:raises TypeError: if any sub-validators fail
"""
return _DeepMapping(key_validator, value_validator, mapping_validator)

View file

@ -1,68 +0,0 @@
from typing import (
Any,
AnyStr,
Callable,
Container,
Iterable,
List,
Mapping,
Match,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from . import _ValidatorType
_T = TypeVar("_T")
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_T3 = TypeVar("_T3")
_I = TypeVar("_I", bound=Iterable)
_K = TypeVar("_K")
_V = TypeVar("_V")
_M = TypeVar("_M", bound=Mapping)
# To be more precise on instance_of use some overloads.
# If there are more than 3 items in the tuple then we fall back to Any
@overload
def instance_of(type: Type[_T]) -> _ValidatorType[_T]: ...
@overload
def instance_of(type: Tuple[Type[_T]]) -> _ValidatorType[_T]: ...
@overload
def instance_of(
type: Tuple[Type[_T1], Type[_T2]]
) -> _ValidatorType[Union[_T1, _T2]]: ...
@overload
def instance_of(
type: Tuple[Type[_T1], Type[_T2], Type[_T3]]
) -> _ValidatorType[Union[_T1, _T2, _T3]]: ...
@overload
def instance_of(type: Tuple[type, ...]) -> _ValidatorType[Any]: ...
def provides(interface: Any) -> _ValidatorType[Any]: ...
def optional(
validator: Union[_ValidatorType[_T], List[_ValidatorType[_T]]]
) -> _ValidatorType[Optional[_T]]: ...
def in_(options: Container[_T]) -> _ValidatorType[_T]: ...
def and_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ...
def matches_re(
regex: AnyStr,
flags: int = ...,
func: Optional[
Callable[[AnyStr, AnyStr, int], Optional[Match[AnyStr]]]
] = ...,
) -> _ValidatorType[AnyStr]: ...
def deep_iterable(
member_validator: _ValidatorType[_T],
iterable_validator: Optional[_ValidatorType[_I]] = ...,
) -> _ValidatorType[_I]: ...
def deep_mapping(
key_validator: _ValidatorType[_K],
value_validator: _ValidatorType[_V],
mapping_validator: Optional[_ValidatorType[_M]] = ...,
) -> _ValidatorType[_M]: ...
def is_callable() -> _ValidatorType[_T]: ...

View file

@ -1,85 +0,0 @@
"""
Python HTTP library with thread-safe connection pooling, file post support, user friendly, and more
"""
from __future__ import absolute_import
# Set default logging handler to avoid "No handler found" warnings.
import logging
import warnings
from logging import NullHandler
from . import exceptions
from ._version import __version__
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, connection_from_url
from .filepost import encode_multipart_formdata
from .poolmanager import PoolManager, ProxyManager, proxy_from_url
from .response import HTTPResponse
from .util.request import make_headers
from .util.retry import Retry
from .util.timeout import Timeout
from .util.url import get_host
__author__ = "Andrey Petrov (andrey.petrov@shazow.net)"
__license__ = "MIT"
__version__ = __version__
__all__ = (
"HTTPConnectionPool",
"HTTPSConnectionPool",
"PoolManager",
"ProxyManager",
"HTTPResponse",
"Retry",
"Timeout",
"add_stderr_logger",
"connection_from_url",
"disable_warnings",
"encode_multipart_formdata",
"get_host",
"make_headers",
"proxy_from_url",
)
logging.getLogger(__name__).addHandler(NullHandler())
def add_stderr_logger(level=logging.DEBUG):
"""
Helper for quickly adding a StreamHandler to the logger. Useful for
debugging.
Returns the handler after adding it.
"""
# This method needs to be in this __init__.py to get the __name__ correct
# even if urllib3 is vendored within another package.
logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)
logger.setLevel(level)
logger.debug("Added a stderr logging handler to logger: %s", __name__)
return handler
# ... Clean up.
del NullHandler
# All warning filters *must* be appended unless you're really certain that they
# shouldn't be: otherwise, it's very hard for users to use most Python
# mechanisms to silence them.
# SecurityWarning's always go off by default.
warnings.simplefilter("always", exceptions.SecurityWarning, append=True)
# SubjectAltNameWarning's should go off once per host
warnings.simplefilter("default", exceptions.SubjectAltNameWarning, append=True)
# InsecurePlatformWarning's don't vary between requests, so we keep it default.
warnings.simplefilter("default", exceptions.InsecurePlatformWarning, append=True)
# SNIMissingWarnings should go off only once.
warnings.simplefilter("default", exceptions.SNIMissingWarning, append=True)
def disable_warnings(category=exceptions.HTTPWarning):
"""
Helper for quickly disabling all urllib3 warnings.
"""
warnings.simplefilter("ignore", category)

View file

@ -1,337 +0,0 @@
from __future__ import absolute_import
try:
from collections.abc import Mapping, MutableMapping
except ImportError:
from collections import Mapping, MutableMapping
try:
from threading import RLock
except ImportError: # Platform-specific: No threads available
class RLock:
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
pass
from collections import OrderedDict
from .exceptions import InvalidHeader
from .packages import six
from .packages.six import iterkeys, itervalues
__all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"]
_Null = object()
class RecentlyUsedContainer(MutableMapping):
"""
Provides a thread-safe dict-like container which maintains up to
``maxsize`` keys while throwing away the least-recently-used keys beyond
``maxsize``.
:param maxsize:
Maximum number of recent elements to retain.
:param dispose_func:
Every time an item is evicted from the container,
``dispose_func(value)`` is called. Callback which will get called
"""
ContainerCls = OrderedDict
def __init__(self, maxsize=10, dispose_func=None):
self._maxsize = maxsize
self.dispose_func = dispose_func
self._container = self.ContainerCls()
self.lock = RLock()
def __getitem__(self, key):
# Re-insert the item, moving it to the end of the eviction line.
with self.lock:
item = self._container.pop(key)
self._container[key] = item
return item
def __setitem__(self, key, value):
evicted_value = _Null
with self.lock:
# Possibly evict the existing value of 'key'
evicted_value = self._container.get(key, _Null)
self._container[key] = value
# If we didn't evict an existing value, we might have to evict the
# least recently used item from the beginning of the container.
if len(self._container) > self._maxsize:
_key, evicted_value = self._container.popitem(last=False)
if self.dispose_func and evicted_value is not _Null:
self.dispose_func(evicted_value)
def __delitem__(self, key):
with self.lock:
value = self._container.pop(key)
if self.dispose_func:
self.dispose_func(value)
def __len__(self):
with self.lock:
return len(self._container)
def __iter__(self):
raise NotImplementedError(
"Iteration over this class is unlikely to be threadsafe."
)
def clear(self):
with self.lock:
# Copy pointers to all values, then wipe the mapping
values = list(itervalues(self._container))
self._container.clear()
if self.dispose_func:
for value in values:
self.dispose_func(value)
def keys(self):
with self.lock:
return list(iterkeys(self._container))
class HTTPHeaderDict(MutableMapping):
"""
:param headers:
An iterable of field-value pairs. Must not contain multiple field names
when compared case-insensitively.
:param kwargs:
Additional field-value pairs to pass in to ``dict.update``.
A ``dict`` like container for storing HTTP Headers.
Field names are stored and compared case-insensitively in compliance with
RFC 7230. Iteration provides the first case-sensitive key seen for each
case-insensitive pair.
Using ``__setitem__`` syntax overwrites fields that compare equal
case-insensitively in order to maintain ``dict``'s api. For fields that
compare equal, instead create a new ``HTTPHeaderDict`` and use ``.add``
in a loop.
If multiple fields that are equal case-insensitively are passed to the
constructor or ``.update``, the behavior is undefined and some will be
lost.
>>> headers = HTTPHeaderDict()
>>> headers.add('Set-Cookie', 'foo=bar')
>>> headers.add('set-cookie', 'baz=quxx')
>>> headers['content-length'] = '7'
>>> headers['SET-cookie']
'foo=bar, baz=quxx'
>>> headers['Content-Length']
'7'
"""
def __init__(self, headers=None, **kwargs):
super(HTTPHeaderDict, self).__init__()
self._container = OrderedDict()
if headers is not None:
if isinstance(headers, HTTPHeaderDict):
self._copy_from(headers)
else:
self.extend(headers)
if kwargs:
self.extend(kwargs)
def __setitem__(self, key, val):
self._container[key.lower()] = [key, val]
return self._container[key.lower()]
def __getitem__(self, key):
val = self._container[key.lower()]
return ", ".join(val[1:])
def __delitem__(self, key):
del self._container[key.lower()]
def __contains__(self, key):
return key.lower() in self._container
def __eq__(self, other):
if not isinstance(other, Mapping) and not hasattr(other, "keys"):
return False
if not isinstance(other, type(self)):
other = type(self)(other)
return dict((k.lower(), v) for k, v in self.itermerged()) == dict(
(k.lower(), v) for k, v in other.itermerged()
)
def __ne__(self, other):
return not self.__eq__(other)
if six.PY2: # Python 2
iterkeys = MutableMapping.iterkeys
itervalues = MutableMapping.itervalues
__marker = object()
def __len__(self):
return len(self._container)
def __iter__(self):
# Only provide the originally cased names
for vals in self._container.values():
yield vals[0]
def pop(self, key, default=__marker):
"""D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised.
"""
# Using the MutableMapping function directly fails due to the private marker.
# Using ordinary dict.pop would expose the internal structures.
# So let's reinvent the wheel.
try:
value = self[key]
except KeyError:
if default is self.__marker:
raise
return default
else:
del self[key]
return value
def discard(self, key):
try:
del self[key]
except KeyError:
pass
def add(self, key, val):
"""Adds a (name, value) pair, doesn't overwrite the value if it already
exists.
>>> headers = HTTPHeaderDict(foo='bar')
>>> headers.add('Foo', 'baz')
>>> headers['foo']
'bar, baz'
"""
key_lower = key.lower()
new_vals = [key, val]
# Keep the common case aka no item present as fast as possible
vals = self._container.setdefault(key_lower, new_vals)
if new_vals is not vals:
vals.append(val)
def extend(self, *args, **kwargs):
"""Generic import function for any type of header-like object.
Adapted version of MutableMapping.update in order to insert items
with self.add instead of self.__setitem__
"""
if len(args) > 1:
raise TypeError(
"extend() takes at most 1 positional "
"arguments ({0} given)".format(len(args))
)
other = args[0] if len(args) >= 1 else ()
if isinstance(other, HTTPHeaderDict):
for key, val in other.iteritems():
self.add(key, val)
elif isinstance(other, Mapping):
for key in other:
self.add(key, other[key])
elif hasattr(other, "keys"):
for key in other.keys():
self.add(key, other[key])
else:
for key, value in other:
self.add(key, value)
for key, value in kwargs.items():
self.add(key, value)
def getlist(self, key, default=__marker):
"""Returns a list of all the values for the named field. Returns an
empty list if the key doesn't exist."""
try:
vals = self._container[key.lower()]
except KeyError:
if default is self.__marker:
return []
return default
else:
return vals[1:]
# Backwards compatibility for httplib
getheaders = getlist
getallmatchingheaders = getlist
iget = getlist
# Backwards compatibility for http.cookiejar
get_all = getlist
def __repr__(self):
return "%s(%s)" % (type(self).__name__, dict(self.itermerged()))
def _copy_from(self, other):
for key in other:
val = other.getlist(key)
if isinstance(val, list):
# Don't need to convert tuples
val = list(val)
self._container[key.lower()] = [key] + val
def copy(self):
clone = type(self)()
clone._copy_from(self)
return clone
def iteritems(self):
"""Iterate over all header lines, including duplicate ones."""
for key in self:
vals = self._container[key.lower()]
for val in vals[1:]:
yield vals[0], val
def itermerged(self):
"""Iterate over all headers, merging duplicate ones together."""
for key in self:
val = self._container[key.lower()]
yield val[0], ", ".join(val[1:])
def items(self):
return list(self.iteritems())
@classmethod
def from_httplib(cls, message): # Python 2
"""Read headers from a Python 2 httplib message object."""
# python2.7 does not expose a proper API for exporting multiheaders
# efficiently. This function re-reads raw lines from the message
# object and extracts the multiheaders properly.
obs_fold_continued_leaders = (" ", "\t")
headers = []
for line in message.headers:
if line.startswith(obs_fold_continued_leaders):
if not headers:
# We received a header line that starts with OWS as described
# in RFC-7230 S3.2.4. This indicates a multiline header, but
# there exists no previous header to which we can attach it.
raise InvalidHeader(
"Header continuation with no previous header: %s" % line
)
else:
key, value = headers[-1]
headers[-1] = (key, value + " " + line.strip())
continue
key, value = line.split(":", 1)
headers.append((key, value.strip()))
return cls(headers)

View file

@ -1,2 +0,0 @@
# This file is protected via CODEOWNERS
__version__ = "1.26.6"

View file

@ -1,539 +0,0 @@
from __future__ import absolute_import
import datetime
import logging
import os
import re
import socket
import warnings
from socket import error as SocketError
from socket import timeout as SocketTimeout
from .packages import six
from .packages.six.moves.http_client import HTTPConnection as _HTTPConnection
from .packages.six.moves.http_client import HTTPException # noqa: F401
from .util.proxy import create_proxy_ssl_context
try: # Compiled with SSL?
import ssl
BaseSSLError = ssl.SSLError
except (ImportError, AttributeError): # Platform-specific: No SSL.
ssl = None
class BaseSSLError(BaseException):
pass
try:
# Python 3: not a no-op, we're adding this to the namespace so it can be imported.
ConnectionError = ConnectionError
except NameError:
# Python 2
class ConnectionError(Exception):
pass
try: # Python 3:
# Not a no-op, we're adding this to the namespace so it can be imported.
BrokenPipeError = BrokenPipeError
except NameError: # Python 2:
class BrokenPipeError(Exception):
pass
from ._collections import HTTPHeaderDict # noqa (historical, removed in v2)
from ._version import __version__
from .exceptions import (
ConnectTimeoutError,
NewConnectionError,
SubjectAltNameWarning,
SystemTimeWarning,
)
from .packages.ssl_match_hostname import CertificateError, match_hostname
from .util import SKIP_HEADER, SKIPPABLE_HEADERS, connection
from .util.ssl_ import (
assert_fingerprint,
create_urllib3_context,
resolve_cert_reqs,
resolve_ssl_version,
ssl_wrap_socket,
)
log = logging.getLogger(__name__)
port_by_scheme = {"http": 80, "https": 443}
# When it comes time to update this value as a part of regular maintenance
# (ie test_recent_date is failing) update it to ~6 months before the current date.
RECENT_DATE = datetime.date(2020, 7, 1)
_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]")
class HTTPConnection(_HTTPConnection, object):
"""
Based on :class:`http.client.HTTPConnection` but provides an extra constructor
backwards-compatibility layer between older and newer Pythons.
Additional keyword parameters are used to configure attributes of the connection.
Accepted parameters include:
- ``strict``: See the documentation on :class:`urllib3.connectionpool.HTTPConnectionPool`
- ``source_address``: Set the source address for the current connection.
- ``socket_options``: Set specific options on the underlying socket. If not specified, then
defaults are loaded from ``HTTPConnection.default_socket_options`` which includes disabling
Nagle's algorithm (sets TCP_NODELAY to 1) unless the connection is behind a proxy.
For example, if you wish to enable TCP Keep Alive in addition to the defaults,
you might pass:
.. code-block:: python
HTTPConnection.default_socket_options + [
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
]
Or you may want to disable the defaults by passing an empty list (e.g., ``[]``).
"""
default_port = port_by_scheme["http"]
#: Disable Nagle's algorithm by default.
#: ``[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]``
default_socket_options = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]
#: Whether this connection verifies the host's certificate.
is_verified = False
def __init__(self, *args, **kw):
if not six.PY2:
kw.pop("strict", None)
# Pre-set source_address.
self.source_address = kw.get("source_address")
#: The socket options provided by the user. If no options are
#: provided, we use the default options.
self.socket_options = kw.pop("socket_options", self.default_socket_options)
# Proxy options provided by the user.
self.proxy = kw.pop("proxy", None)
self.proxy_config = kw.pop("proxy_config", None)
_HTTPConnection.__init__(self, *args, **kw)
@property
def host(self):
"""
Getter method to remove any trailing dots that indicate the hostname is an FQDN.
In general, SSL certificates don't include the trailing dot indicating a
fully-qualified domain name, and thus, they don't validate properly when
checked against a domain name that includes the dot. In addition, some
servers may not expect to receive the trailing dot when provided.
However, the hostname with trailing dot is critical to DNS resolution; doing a
lookup with the trailing dot will properly only resolve the appropriate FQDN,
whereas a lookup without a trailing dot will search the system's search domain
list. Thus, it's important to keep the original host around for use only in
those cases where it's appropriate (i.e., when doing DNS lookup to establish the
actual TCP connection across which we're going to send HTTP requests).
"""
return self._dns_host.rstrip(".")
@host.setter
def host(self, value):
"""
Setter for the `host` property.
We assume that only urllib3 uses the _dns_host attribute; httplib itself
only uses `host`, and it seems reasonable that other libraries follow suit.
"""
self._dns_host = value
def _new_conn(self):
"""Establish a socket connection and set nodelay settings on it.
:return: New socket connection.
"""
extra_kw = {}
if self.source_address:
extra_kw["source_address"] = self.source_address
if self.socket_options:
extra_kw["socket_options"] = self.socket_options
try:
conn = connection.create_connection(
(self._dns_host, self.port), self.timeout, **extra_kw
)
except SocketTimeout:
raise ConnectTimeoutError(
self,
"Connection to %s timed out. (connect timeout=%s)"
% (self.host, self.timeout),
)
except SocketError as e:
raise NewConnectionError(
self, "Failed to establish a new connection: %s" % e
)
return conn
def _is_using_tunnel(self):
# Google App Engine's httplib does not define _tunnel_host
return getattr(self, "_tunnel_host", None)
def _prepare_conn(self, conn):
self.sock = conn
if self._is_using_tunnel():
# TODO: Fix tunnel so it doesn't depend on self.sock state.
self._tunnel()
# Mark this connection as not reusable
self.auto_open = 0
def connect(self):
conn = self._new_conn()
self._prepare_conn(conn)
def putrequest(self, method, url, *args, **kwargs):
""" """
# Empty docstring because the indentation of CPython's implementation
# is broken but we don't want this method in our documentation.
match = _CONTAINS_CONTROL_CHAR_RE.search(method)
if match:
raise ValueError(
"Method cannot contain non-token characters %r (found at least %r)"
% (method, match.group())
)
return _HTTPConnection.putrequest(self, method, url, *args, **kwargs)
def putheader(self, header, *values):
""" """
if not any(isinstance(v, str) and v == SKIP_HEADER for v in values):
_HTTPConnection.putheader(self, header, *values)
elif six.ensure_str(header.lower()) not in SKIPPABLE_HEADERS:
raise ValueError(
"urllib3.util.SKIP_HEADER only supports '%s'"
% ("', '".join(map(str.title, sorted(SKIPPABLE_HEADERS))),)
)
def request(self, method, url, body=None, headers=None):
if headers is None:
headers = {}
else:
# Avoid modifying the headers passed into .request()
headers = headers.copy()
if "user-agent" not in (six.ensure_str(k.lower()) for k in headers):
headers["User-Agent"] = _get_default_user_agent()
super(HTTPConnection, self).request(method, url, body=body, headers=headers)
def request_chunked(self, method, url, body=None, headers=None):
"""
Alternative to the common request method, which sends the
body with chunked encoding and not as one block
"""
headers = headers or {}
header_keys = set([six.ensure_str(k.lower()) for k in headers])
skip_accept_encoding = "accept-encoding" in header_keys
skip_host = "host" in header_keys
self.putrequest(
method, url, skip_accept_encoding=skip_accept_encoding, skip_host=skip_host
)
if "user-agent" not in header_keys:
self.putheader("User-Agent", _get_default_user_agent())
for header, value in headers.items():
self.putheader(header, value)
if "transfer-encoding" not in header_keys:
self.putheader("Transfer-Encoding", "chunked")
self.endheaders()
if body is not None:
stringish_types = six.string_types + (bytes,)
if isinstance(body, stringish_types):
body = (body,)
for chunk in body:
if not chunk:
continue
if not isinstance(chunk, bytes):
chunk = chunk.encode("utf8")
len_str = hex(len(chunk))[2:]
to_send = bytearray(len_str.encode())
to_send += b"\r\n"
to_send += chunk
to_send += b"\r\n"
self.send(to_send)
# After the if clause, to always have a closed body
self.send(b"0\r\n\r\n")
class HTTPSConnection(HTTPConnection):
"""
Many of the parameters to this constructor are passed to the underlying SSL
socket by means of :py:func:`urllib3.util.ssl_wrap_socket`.
"""
default_port = port_by_scheme["https"]
cert_reqs = None
ca_certs = None
ca_cert_dir = None
ca_cert_data = None
ssl_version = None
assert_fingerprint = None
tls_in_tls_required = False
def __init__(
self,
host,
port=None,
key_file=None,
cert_file=None,
key_password=None,
strict=None,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
ssl_context=None,
server_hostname=None,
**kw
):
HTTPConnection.__init__(self, host, port, strict=strict, timeout=timeout, **kw)
self.key_file = key_file
self.cert_file = cert_file
self.key_password = key_password
self.ssl_context = ssl_context
self.server_hostname = server_hostname
# Required property for Google AppEngine 1.9.0 which otherwise causes
# HTTPS requests to go out as HTTP. (See Issue #356)
self._protocol = "https"
def set_cert(
self,
key_file=None,
cert_file=None,
cert_reqs=None,
key_password=None,
ca_certs=None,
assert_hostname=None,
assert_fingerprint=None,
ca_cert_dir=None,
ca_cert_data=None,
):
"""
This method should only be called once, before the connection is used.
"""
# If cert_reqs is not provided we'll assume CERT_REQUIRED unless we also
# have an SSLContext object in which case we'll use its verify_mode.
if cert_reqs is None:
if self.ssl_context is not None:
cert_reqs = self.ssl_context.verify_mode
else:
cert_reqs = resolve_cert_reqs(None)
self.key_file = key_file
self.cert_file = cert_file
self.cert_reqs = cert_reqs
self.key_password = key_password
self.assert_hostname = assert_hostname
self.assert_fingerprint = assert_fingerprint
self.ca_certs = ca_certs and os.path.expanduser(ca_certs)
self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir)
self.ca_cert_data = ca_cert_data
def connect(self):
# Add certificate verification
conn = self._new_conn()
hostname = self.host
tls_in_tls = False
if self._is_using_tunnel():
if self.tls_in_tls_required:
conn = self._connect_tls_proxy(hostname, conn)
tls_in_tls = True
self.sock = conn
# Calls self._set_hostport(), so self.host is
# self._tunnel_host below.
self._tunnel()
# Mark this connection as not reusable
self.auto_open = 0
# Override the host with the one we're requesting data from.
hostname = self._tunnel_host
server_hostname = hostname
if self.server_hostname is not None:
server_hostname = self.server_hostname
is_time_off = datetime.date.today() < RECENT_DATE
if is_time_off:
warnings.warn(
(
"System time is way off (before {0}). This will probably "
"lead to SSL verification errors"
).format(RECENT_DATE),
SystemTimeWarning,
)
# Wrap socket using verification with the root certs in
# trusted_root_certs
default_ssl_context = False
if self.ssl_context is None:
default_ssl_context = True
self.ssl_context = create_urllib3_context(
ssl_version=resolve_ssl_version(self.ssl_version),
cert_reqs=resolve_cert_reqs(self.cert_reqs),
)
context = self.ssl_context
context.verify_mode = resolve_cert_reqs(self.cert_reqs)
# Try to load OS default certs if none are given.
# Works well on Windows (requires Python3.4+)
if (
not self.ca_certs
and not self.ca_cert_dir
and not self.ca_cert_data
and default_ssl_context
and hasattr(context, "load_default_certs")
):
context.load_default_certs()
self.sock = ssl_wrap_socket(
sock=conn,
keyfile=self.key_file,
certfile=self.cert_file,
key_password=self.key_password,
ca_certs=self.ca_certs,
ca_cert_dir=self.ca_cert_dir,
ca_cert_data=self.ca_cert_data,
server_hostname=server_hostname,
ssl_context=context,
tls_in_tls=tls_in_tls,
)
# If we're using all defaults and the connection
# is TLSv1 or TLSv1.1 we throw a DeprecationWarning
# for the host.
if (
default_ssl_context
and self.ssl_version is None
and hasattr(self.sock, "version")
and self.sock.version() in {"TLSv1", "TLSv1.1"}
):
warnings.warn(
"Negotiating TLSv1/TLSv1.1 by default is deprecated "
"and will be disabled in urllib3 v2.0.0. Connecting to "
"'%s' with '%s' can be enabled by explicitly opting-in "
"with 'ssl_version'" % (self.host, self.sock.version()),
DeprecationWarning,
)
if self.assert_fingerprint:
assert_fingerprint(
self.sock.getpeercert(binary_form=True), self.assert_fingerprint
)
elif (
context.verify_mode != ssl.CERT_NONE
and not getattr(context, "check_hostname", False)
and self.assert_hostname is not False
):
# While urllib3 attempts to always turn off hostname matching from
# the TLS library, this cannot always be done. So we check whether
# the TLS Library still thinks it's matching hostnames.
cert = self.sock.getpeercert()
if not cert.get("subjectAltName", ()):
warnings.warn(
(
"Certificate for {0} has no `subjectAltName`, falling back to check for a "
"`commonName` for now. This feature is being removed by major browsers and "
"deprecated by RFC 2818. (See https://github.com/urllib3/urllib3/issues/497 "
"for details.)".format(hostname)
),
SubjectAltNameWarning,
)
_match_hostname(cert, self.assert_hostname or server_hostname)
self.is_verified = (
context.verify_mode == ssl.CERT_REQUIRED
or self.assert_fingerprint is not None
)
def _connect_tls_proxy(self, hostname, conn):
"""
Establish a TLS connection to the proxy using the provided SSL context.
"""
proxy_config = self.proxy_config
ssl_context = proxy_config.ssl_context
if ssl_context:
# If the user provided a proxy context, we assume CA and client
# certificates have already been set
return ssl_wrap_socket(
sock=conn,
server_hostname=hostname,
ssl_context=ssl_context,
)
ssl_context = create_proxy_ssl_context(
self.ssl_version,
self.cert_reqs,
self.ca_certs,
self.ca_cert_dir,
self.ca_cert_data,
)
# By default urllib3's SSLContext disables `check_hostname` and uses
# a custom check. For proxies we're good with relying on the default
# verification.
ssl_context.check_hostname = True
# If no cert was provided, use only the default options for server
# certificate validation
return ssl_wrap_socket(
sock=conn,
ca_certs=self.ca_certs,
ca_cert_dir=self.ca_cert_dir,
ca_cert_data=self.ca_cert_data,
server_hostname=hostname,
ssl_context=ssl_context,
)
def _match_hostname(cert, asserted_hostname):
try:
match_hostname(cert, asserted_hostname)
except CertificateError as e:
log.warning(
"Certificate did not match expected hostname: %s. Certificate: %s",
asserted_hostname,
cert,
)
# Add cert to exception and reraise so client code can inspect
# the cert when catching the exception, if they want to
e._peer_cert = cert
raise
def _get_default_user_agent():
return "python-urllib3/%s" % __version__
class DummyConnection(object):
"""Used to detect a failed ConnectionCls import."""
pass
if not ssl:
HTTPSConnection = DummyConnection # noqa: F811
VerifiedHTTPSConnection = HTTPSConnection

File diff suppressed because it is too large Load diff

View file

@ -1,36 +0,0 @@
"""
This module provides means to detect the App Engine environment.
"""
import os
def is_appengine():
return is_local_appengine() or is_prod_appengine()
def is_appengine_sandbox():
"""Reports if the app is running in the first generation sandbox.
The second generation runtimes are technically still in a sandbox, but it
is much less restrictive, so generally you shouldn't need to check for it.
see https://cloud.google.com/appengine/docs/standard/runtimes
"""
return is_appengine() and os.environ["APPENGINE_RUNTIME"] == "python27"
def is_local_appengine():
return "APPENGINE_RUNTIME" in os.environ and os.environ.get(
"SERVER_SOFTWARE", ""
).startswith("Development/")
def is_prod_appengine():
return "APPENGINE_RUNTIME" in os.environ and os.environ.get(
"SERVER_SOFTWARE", ""
).startswith("Google App Engine/")
def is_prod_appengine_mvms():
"""Deprecated."""
return False

View file

@ -1,519 +0,0 @@
"""
This module uses ctypes to bind a whole bunch of functions and constants from
SecureTransport. The goal here is to provide the low-level API to
SecureTransport. These are essentially the C-level functions and constants, and
they're pretty gross to work with.
This code is a bastardised version of the code found in Will Bond's oscrypto
library. An enormous debt is owed to him for blazing this trail for us. For
that reason, this code should be considered to be covered both by urllib3's
license and by oscrypto's:
Copyright (c) 2015-2016 Will Bond <will@wbond.net>
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import
import platform
from ctypes import (
CDLL,
CFUNCTYPE,
POINTER,
c_bool,
c_byte,
c_char_p,
c_int32,
c_long,
c_size_t,
c_uint32,
c_ulong,
c_void_p,
)
from ctypes.util import find_library
from urllib3.packages.six import raise_from
if platform.system() != "Darwin":
raise ImportError("Only macOS is supported")
version = platform.mac_ver()[0]
version_info = tuple(map(int, version.split(".")))
if version_info < (10, 8):
raise OSError(
"Only OS X 10.8 and newer are supported, not %s.%s"
% (version_info[0], version_info[1])
)
def load_cdll(name, macos10_16_path):
"""Loads a CDLL by name, falling back to known path on 10.16+"""
try:
# Big Sur is technically 11 but we use 10.16 due to the Big Sur
# beta being labeled as 10.16.
if version_info >= (10, 16):
path = macos10_16_path
else:
path = find_library(name)
if not path:
raise OSError # Caught and reraised as 'ImportError'
return CDLL(path, use_errno=True)
except OSError:
raise_from(ImportError("The library %s failed to load" % name), None)
Security = load_cdll(
"Security", "/System/Library/Frameworks/Security.framework/Security"
)
CoreFoundation = load_cdll(
"CoreFoundation",
"/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation",
)
Boolean = c_bool
CFIndex = c_long
CFStringEncoding = c_uint32
CFData = c_void_p
CFString = c_void_p
CFArray = c_void_p
CFMutableArray = c_void_p
CFDictionary = c_void_p
CFError = c_void_p
CFType = c_void_p
CFTypeID = c_ulong
CFTypeRef = POINTER(CFType)
CFAllocatorRef = c_void_p
OSStatus = c_int32
CFDataRef = POINTER(CFData)
CFStringRef = POINTER(CFString)
CFArrayRef = POINTER(CFArray)
CFMutableArrayRef = POINTER(CFMutableArray)
CFDictionaryRef = POINTER(CFDictionary)
CFArrayCallBacks = c_void_p
CFDictionaryKeyCallBacks = c_void_p
CFDictionaryValueCallBacks = c_void_p
SecCertificateRef = POINTER(c_void_p)
SecExternalFormat = c_uint32
SecExternalItemType = c_uint32
SecIdentityRef = POINTER(c_void_p)
SecItemImportExportFlags = c_uint32
SecItemImportExportKeyParameters = c_void_p
SecKeychainRef = POINTER(c_void_p)
SSLProtocol = c_uint32
SSLCipherSuite = c_uint32
SSLContextRef = POINTER(c_void_p)
SecTrustRef = POINTER(c_void_p)
SSLConnectionRef = c_uint32
SecTrustResultType = c_uint32
SecTrustOptionFlags = c_uint32
SSLProtocolSide = c_uint32
SSLConnectionType = c_uint32
SSLSessionOption = c_uint32
try:
Security.SecItemImport.argtypes = [
CFDataRef,
CFStringRef,
POINTER(SecExternalFormat),
POINTER(SecExternalItemType),
SecItemImportExportFlags,
POINTER(SecItemImportExportKeyParameters),
SecKeychainRef,
POINTER(CFArrayRef),
]
Security.SecItemImport.restype = OSStatus
Security.SecCertificateGetTypeID.argtypes = []
Security.SecCertificateGetTypeID.restype = CFTypeID
Security.SecIdentityGetTypeID.argtypes = []
Security.SecIdentityGetTypeID.restype = CFTypeID
Security.SecKeyGetTypeID.argtypes = []
Security.SecKeyGetTypeID.restype = CFTypeID
Security.SecCertificateCreateWithData.argtypes = [CFAllocatorRef, CFDataRef]
Security.SecCertificateCreateWithData.restype = SecCertificateRef
Security.SecCertificateCopyData.argtypes = [SecCertificateRef]
Security.SecCertificateCopyData.restype = CFDataRef
Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
Security.SecCopyErrorMessageString.restype = CFStringRef
Security.SecIdentityCreateWithCertificate.argtypes = [
CFTypeRef,
SecCertificateRef,
POINTER(SecIdentityRef),
]
Security.SecIdentityCreateWithCertificate.restype = OSStatus
Security.SecKeychainCreate.argtypes = [
c_char_p,
c_uint32,
c_void_p,
Boolean,
c_void_p,
POINTER(SecKeychainRef),
]
Security.SecKeychainCreate.restype = OSStatus
Security.SecKeychainDelete.argtypes = [SecKeychainRef]
Security.SecKeychainDelete.restype = OSStatus
Security.SecPKCS12Import.argtypes = [
CFDataRef,
CFDictionaryRef,
POINTER(CFArrayRef),
]
Security.SecPKCS12Import.restype = OSStatus
SSLReadFunc = CFUNCTYPE(OSStatus, SSLConnectionRef, c_void_p, POINTER(c_size_t))
SSLWriteFunc = CFUNCTYPE(
OSStatus, SSLConnectionRef, POINTER(c_byte), POINTER(c_size_t)
)
Security.SSLSetIOFuncs.argtypes = [SSLContextRef, SSLReadFunc, SSLWriteFunc]
Security.SSLSetIOFuncs.restype = OSStatus
Security.SSLSetPeerID.argtypes = [SSLContextRef, c_char_p, c_size_t]
Security.SSLSetPeerID.restype = OSStatus
Security.SSLSetCertificate.argtypes = [SSLContextRef, CFArrayRef]
Security.SSLSetCertificate.restype = OSStatus
Security.SSLSetCertificateAuthorities.argtypes = [SSLContextRef, CFTypeRef, Boolean]
Security.SSLSetCertificateAuthorities.restype = OSStatus
Security.SSLSetConnection.argtypes = [SSLContextRef, SSLConnectionRef]
Security.SSLSetConnection.restype = OSStatus
Security.SSLSetPeerDomainName.argtypes = [SSLContextRef, c_char_p, c_size_t]
Security.SSLSetPeerDomainName.restype = OSStatus
Security.SSLHandshake.argtypes = [SSLContextRef]
Security.SSLHandshake.restype = OSStatus
Security.SSLRead.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
Security.SSLRead.restype = OSStatus
Security.SSLWrite.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
Security.SSLWrite.restype = OSStatus
Security.SSLClose.argtypes = [SSLContextRef]
Security.SSLClose.restype = OSStatus
Security.SSLGetNumberSupportedCiphers.argtypes = [SSLContextRef, POINTER(c_size_t)]
Security.SSLGetNumberSupportedCiphers.restype = OSStatus
Security.SSLGetSupportedCiphers.argtypes = [
SSLContextRef,
POINTER(SSLCipherSuite),
POINTER(c_size_t),
]
Security.SSLGetSupportedCiphers.restype = OSStatus
Security.SSLSetEnabledCiphers.argtypes = [
SSLContextRef,
POINTER(SSLCipherSuite),
c_size_t,
]
Security.SSLSetEnabledCiphers.restype = OSStatus
Security.SSLGetNumberEnabledCiphers.argtype = [SSLContextRef, POINTER(c_size_t)]
Security.SSLGetNumberEnabledCiphers.restype = OSStatus
Security.SSLGetEnabledCiphers.argtypes = [
SSLContextRef,
POINTER(SSLCipherSuite),
POINTER(c_size_t),
]
Security.SSLGetEnabledCiphers.restype = OSStatus
Security.SSLGetNegotiatedCipher.argtypes = [SSLContextRef, POINTER(SSLCipherSuite)]
Security.SSLGetNegotiatedCipher.restype = OSStatus
Security.SSLGetNegotiatedProtocolVersion.argtypes = [
SSLContextRef,
POINTER(SSLProtocol),
]
Security.SSLGetNegotiatedProtocolVersion.restype = OSStatus
Security.SSLCopyPeerTrust.argtypes = [SSLContextRef, POINTER(SecTrustRef)]
Security.SSLCopyPeerTrust.restype = OSStatus
Security.SecTrustSetAnchorCertificates.argtypes = [SecTrustRef, CFArrayRef]
Security.SecTrustSetAnchorCertificates.restype = OSStatus
Security.SecTrustSetAnchorCertificatesOnly.argstypes = [SecTrustRef, Boolean]
Security.SecTrustSetAnchorCertificatesOnly.restype = OSStatus
Security.SecTrustEvaluate.argtypes = [SecTrustRef, POINTER(SecTrustResultType)]
Security.SecTrustEvaluate.restype = OSStatus
Security.SecTrustGetCertificateCount.argtypes = [SecTrustRef]
Security.SecTrustGetCertificateCount.restype = CFIndex
Security.SecTrustGetCertificateAtIndex.argtypes = [SecTrustRef, CFIndex]
Security.SecTrustGetCertificateAtIndex.restype = SecCertificateRef
Security.SSLCreateContext.argtypes = [
CFAllocatorRef,
SSLProtocolSide,
SSLConnectionType,
]
Security.SSLCreateContext.restype = SSLContextRef
Security.SSLSetSessionOption.argtypes = [SSLContextRef, SSLSessionOption, Boolean]
Security.SSLSetSessionOption.restype = OSStatus
Security.SSLSetProtocolVersionMin.argtypes = [SSLContextRef, SSLProtocol]
Security.SSLSetProtocolVersionMin.restype = OSStatus
Security.SSLSetProtocolVersionMax.argtypes = [SSLContextRef, SSLProtocol]
Security.SSLSetProtocolVersionMax.restype = OSStatus
try:
Security.SSLSetALPNProtocols.argtypes = [SSLContextRef, CFArrayRef]
Security.SSLSetALPNProtocols.restype = OSStatus
except AttributeError:
# Supported only in 10.12+
pass
Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
Security.SecCopyErrorMessageString.restype = CFStringRef
Security.SSLReadFunc = SSLReadFunc
Security.SSLWriteFunc = SSLWriteFunc
Security.SSLContextRef = SSLContextRef
Security.SSLProtocol = SSLProtocol
Security.SSLCipherSuite = SSLCipherSuite
Security.SecIdentityRef = SecIdentityRef
Security.SecKeychainRef = SecKeychainRef
Security.SecTrustRef = SecTrustRef
Security.SecTrustResultType = SecTrustResultType
Security.SecExternalFormat = SecExternalFormat
Security.OSStatus = OSStatus
Security.kSecImportExportPassphrase = CFStringRef.in_dll(
Security, "kSecImportExportPassphrase"
)
Security.kSecImportItemIdentity = CFStringRef.in_dll(
Security, "kSecImportItemIdentity"
)
# CoreFoundation time!
CoreFoundation.CFRetain.argtypes = [CFTypeRef]
CoreFoundation.CFRetain.restype = CFTypeRef
CoreFoundation.CFRelease.argtypes = [CFTypeRef]
CoreFoundation.CFRelease.restype = None
CoreFoundation.CFGetTypeID.argtypes = [CFTypeRef]
CoreFoundation.CFGetTypeID.restype = CFTypeID
CoreFoundation.CFStringCreateWithCString.argtypes = [
CFAllocatorRef,
c_char_p,
CFStringEncoding,
]
CoreFoundation.CFStringCreateWithCString.restype = CFStringRef
CoreFoundation.CFStringGetCStringPtr.argtypes = [CFStringRef, CFStringEncoding]
CoreFoundation.CFStringGetCStringPtr.restype = c_char_p
CoreFoundation.CFStringGetCString.argtypes = [
CFStringRef,
c_char_p,
CFIndex,
CFStringEncoding,
]
CoreFoundation.CFStringGetCString.restype = c_bool
CoreFoundation.CFDataCreate.argtypes = [CFAllocatorRef, c_char_p, CFIndex]
CoreFoundation.CFDataCreate.restype = CFDataRef
CoreFoundation.CFDataGetLength.argtypes = [CFDataRef]
CoreFoundation.CFDataGetLength.restype = CFIndex
CoreFoundation.CFDataGetBytePtr.argtypes = [CFDataRef]
CoreFoundation.CFDataGetBytePtr.restype = c_void_p
CoreFoundation.CFDictionaryCreate.argtypes = [
CFAllocatorRef,
POINTER(CFTypeRef),
POINTER(CFTypeRef),
CFIndex,
CFDictionaryKeyCallBacks,
CFDictionaryValueCallBacks,
]
CoreFoundation.CFDictionaryCreate.restype = CFDictionaryRef
CoreFoundation.CFDictionaryGetValue.argtypes = [CFDictionaryRef, CFTypeRef]
CoreFoundation.CFDictionaryGetValue.restype = CFTypeRef
CoreFoundation.CFArrayCreate.argtypes = [
CFAllocatorRef,
POINTER(CFTypeRef),
CFIndex,
CFArrayCallBacks,
]
CoreFoundation.CFArrayCreate.restype = CFArrayRef
CoreFoundation.CFArrayCreateMutable.argtypes = [
CFAllocatorRef,
CFIndex,
CFArrayCallBacks,
]
CoreFoundation.CFArrayCreateMutable.restype = CFMutableArrayRef
CoreFoundation.CFArrayAppendValue.argtypes = [CFMutableArrayRef, c_void_p]
CoreFoundation.CFArrayAppendValue.restype = None
CoreFoundation.CFArrayGetCount.argtypes = [CFArrayRef]
CoreFoundation.CFArrayGetCount.restype = CFIndex
CoreFoundation.CFArrayGetValueAtIndex.argtypes = [CFArrayRef, CFIndex]
CoreFoundation.CFArrayGetValueAtIndex.restype = c_void_p
CoreFoundation.kCFAllocatorDefault = CFAllocatorRef.in_dll(
CoreFoundation, "kCFAllocatorDefault"
)
CoreFoundation.kCFTypeArrayCallBacks = c_void_p.in_dll(
CoreFoundation, "kCFTypeArrayCallBacks"
)
CoreFoundation.kCFTypeDictionaryKeyCallBacks = c_void_p.in_dll(
CoreFoundation, "kCFTypeDictionaryKeyCallBacks"
)
CoreFoundation.kCFTypeDictionaryValueCallBacks = c_void_p.in_dll(
CoreFoundation, "kCFTypeDictionaryValueCallBacks"
)
CoreFoundation.CFTypeRef = CFTypeRef
CoreFoundation.CFArrayRef = CFArrayRef
CoreFoundation.CFStringRef = CFStringRef
CoreFoundation.CFDictionaryRef = CFDictionaryRef
except (AttributeError):
raise ImportError("Error initializing ctypes")
class CFConst(object):
"""
A class object that acts as essentially a namespace for CoreFoundation
constants.
"""
kCFStringEncodingUTF8 = CFStringEncoding(0x08000100)
class SecurityConst(object):
"""
A class object that acts as essentially a namespace for Security constants.
"""
kSSLSessionOptionBreakOnServerAuth = 0
kSSLProtocol2 = 1
kSSLProtocol3 = 2
kTLSProtocol1 = 4
kTLSProtocol11 = 7
kTLSProtocol12 = 8
# SecureTransport does not support TLS 1.3 even if there's a constant for it
kTLSProtocol13 = 10
kTLSProtocolMaxSupported = 999
kSSLClientSide = 1
kSSLStreamType = 0
kSecFormatPEMSequence = 10
kSecTrustResultInvalid = 0
kSecTrustResultProceed = 1
# This gap is present on purpose: this was kSecTrustResultConfirm, which
# is deprecated.
kSecTrustResultDeny = 3
kSecTrustResultUnspecified = 4
kSecTrustResultRecoverableTrustFailure = 5
kSecTrustResultFatalTrustFailure = 6
kSecTrustResultOtherError = 7
errSSLProtocol = -9800
errSSLWouldBlock = -9803
errSSLClosedGraceful = -9805
errSSLClosedNoNotify = -9816
errSSLClosedAbort = -9806
errSSLXCertChainInvalid = -9807
errSSLCrypto = -9809
errSSLInternal = -9810
errSSLCertExpired = -9814
errSSLCertNotYetValid = -9815
errSSLUnknownRootCert = -9812
errSSLNoRootCert = -9813
errSSLHostNameMismatch = -9843
errSSLPeerHandshakeFail = -9824
errSSLPeerUserCancelled = -9839
errSSLWeakPeerEphemeralDHKey = -9850
errSSLServerAuthCompleted = -9841
errSSLRecordOverflow = -9847
errSecVerifyFailed = -67808
errSecNoTrustSettings = -25263
errSecItemNotFound = -25300
errSecInvalidTrustSettings = -25262
# Cipher suites. We only pick the ones our default cipher string allows.
# Source: https://developer.apple.com/documentation/security/1550981-ssl_cipher_suite_values
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xC02C
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xC030
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xC02B
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA9
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA8
TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 = 0x009F
TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 = 0x009E
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xC024
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xC028
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = 0xC014
TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 = 0x006B
TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xC023
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xC009
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xC013
TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 = 0x0067
TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033
TLS_RSA_WITH_AES_256_GCM_SHA384 = 0x009D
TLS_RSA_WITH_AES_128_GCM_SHA256 = 0x009C
TLS_RSA_WITH_AES_256_CBC_SHA256 = 0x003D
TLS_RSA_WITH_AES_128_CBC_SHA256 = 0x003C
TLS_RSA_WITH_AES_256_CBC_SHA = 0x0035
TLS_RSA_WITH_AES_128_CBC_SHA = 0x002F
TLS_AES_128_GCM_SHA256 = 0x1301
TLS_AES_256_GCM_SHA384 = 0x1302
TLS_AES_128_CCM_8_SHA256 = 0x1305
TLS_AES_128_CCM_SHA256 = 0x1304

View file

@ -1,396 +0,0 @@
"""
Low-level helpers for the SecureTransport bindings.
These are Python functions that are not directly related to the high-level APIs
but are necessary to get them to work. They include a whole bunch of low-level
CoreFoundation messing about and memory management. The concerns in this module
are almost entirely about trying to avoid memory leaks and providing
appropriate and useful assistance to the higher-level code.
"""
import base64
import ctypes
import itertools
import os
import re
import ssl
import struct
import tempfile
from .bindings import CFConst, CoreFoundation, Security
# This regular expression is used to grab PEM data out of a PEM bundle.
_PEM_CERTS_RE = re.compile(
b"-----BEGIN CERTIFICATE-----\n(.*?)\n-----END CERTIFICATE-----", re.DOTALL
)
def _cf_data_from_bytes(bytestring):
"""
Given a bytestring, create a CFData object from it. This CFData object must
be CFReleased by the caller.
"""
return CoreFoundation.CFDataCreate(
CoreFoundation.kCFAllocatorDefault, bytestring, len(bytestring)
)
def _cf_dictionary_from_tuples(tuples):
"""
Given a list of Python tuples, create an associated CFDictionary.
"""
dictionary_size = len(tuples)
# We need to get the dictionary keys and values out in the same order.
keys = (t[0] for t in tuples)
values = (t[1] for t in tuples)
cf_keys = (CoreFoundation.CFTypeRef * dictionary_size)(*keys)
cf_values = (CoreFoundation.CFTypeRef * dictionary_size)(*values)
return CoreFoundation.CFDictionaryCreate(
CoreFoundation.kCFAllocatorDefault,
cf_keys,
cf_values,
dictionary_size,
CoreFoundation.kCFTypeDictionaryKeyCallBacks,
CoreFoundation.kCFTypeDictionaryValueCallBacks,
)
def _cfstr(py_bstr):
"""
Given a Python binary data, create a CFString.
The string must be CFReleased by the caller.
"""
c_str = ctypes.c_char_p(py_bstr)
cf_str = CoreFoundation.CFStringCreateWithCString(
CoreFoundation.kCFAllocatorDefault,
c_str,
CFConst.kCFStringEncodingUTF8,
)
return cf_str
def _create_cfstring_array(lst):
"""
Given a list of Python binary data, create an associated CFMutableArray.
The array must be CFReleased by the caller.
Raises an ssl.SSLError on failure.
"""
cf_arr = None
try:
cf_arr = CoreFoundation.CFArrayCreateMutable(
CoreFoundation.kCFAllocatorDefault,
0,
ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
)
if not cf_arr:
raise MemoryError("Unable to allocate memory!")
for item in lst:
cf_str = _cfstr(item)
if not cf_str:
raise MemoryError("Unable to allocate memory!")
try:
CoreFoundation.CFArrayAppendValue(cf_arr, cf_str)
finally:
CoreFoundation.CFRelease(cf_str)
except BaseException as e:
if cf_arr:
CoreFoundation.CFRelease(cf_arr)
raise ssl.SSLError("Unable to allocate array: %s" % (e,))
return cf_arr
def _cf_string_to_unicode(value):
"""
Creates a Unicode string from a CFString object. Used entirely for error
reporting.
Yes, it annoys me quite a lot that this function is this complex.
"""
value_as_void_p = ctypes.cast(value, ctypes.POINTER(ctypes.c_void_p))
string = CoreFoundation.CFStringGetCStringPtr(
value_as_void_p, CFConst.kCFStringEncodingUTF8
)
if string is None:
buffer = ctypes.create_string_buffer(1024)
result = CoreFoundation.CFStringGetCString(
value_as_void_p, buffer, 1024, CFConst.kCFStringEncodingUTF8
)
if not result:
raise OSError("Error copying C string from CFStringRef")
string = buffer.value
if string is not None:
string = string.decode("utf-8")
return string
def _assert_no_error(error, exception_class=None):
"""
Checks the return code and throws an exception if there is an error to
report
"""
if error == 0:
return
cf_error_string = Security.SecCopyErrorMessageString(error, None)
output = _cf_string_to_unicode(cf_error_string)
CoreFoundation.CFRelease(cf_error_string)
if output is None or output == u"":
output = u"OSStatus %s" % error
if exception_class is None:
exception_class = ssl.SSLError
raise exception_class(output)
def _cert_array_from_pem(pem_bundle):
"""
Given a bundle of certs in PEM format, turns them into a CFArray of certs
that can be used to validate a cert chain.
"""
# Normalize the PEM bundle's line endings.
pem_bundle = pem_bundle.replace(b"\r\n", b"\n")
der_certs = [
base64.b64decode(match.group(1)) for match in _PEM_CERTS_RE.finditer(pem_bundle)
]
if not der_certs:
raise ssl.SSLError("No root certificates specified")
cert_array = CoreFoundation.CFArrayCreateMutable(
CoreFoundation.kCFAllocatorDefault,
0,
ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
)
if not cert_array:
raise ssl.SSLError("Unable to allocate memory!")
try:
for der_bytes in der_certs:
certdata = _cf_data_from_bytes(der_bytes)
if not certdata:
raise ssl.SSLError("Unable to allocate memory!")
cert = Security.SecCertificateCreateWithData(
CoreFoundation.kCFAllocatorDefault, certdata
)
CoreFoundation.CFRelease(certdata)
if not cert:
raise ssl.SSLError("Unable to build cert object!")
CoreFoundation.CFArrayAppendValue(cert_array, cert)
CoreFoundation.CFRelease(cert)
except Exception:
# We need to free the array before the exception bubbles further.
# We only want to do that if an error occurs: otherwise, the caller
# should free.
CoreFoundation.CFRelease(cert_array)
return cert_array
def _is_cert(item):
"""
Returns True if a given CFTypeRef is a certificate.
"""
expected = Security.SecCertificateGetTypeID()
return CoreFoundation.CFGetTypeID(item) == expected
def _is_identity(item):
"""
Returns True if a given CFTypeRef is an identity.
"""
expected = Security.SecIdentityGetTypeID()
return CoreFoundation.CFGetTypeID(item) == expected
def _temporary_keychain():
"""
This function creates a temporary Mac keychain that we can use to work with
credentials. This keychain uses a one-time password and a temporary file to
store the data. We expect to have one keychain per socket. The returned
SecKeychainRef must be freed by the caller, including calling
SecKeychainDelete.
Returns a tuple of the SecKeychainRef and the path to the temporary
directory that contains it.
"""
# Unfortunately, SecKeychainCreate requires a path to a keychain. This
# means we cannot use mkstemp to use a generic temporary file. Instead,
# we're going to create a temporary directory and a filename to use there.
# This filename will be 8 random bytes expanded into base64. We also need
# some random bytes to password-protect the keychain we're creating, so we
# ask for 40 random bytes.
random_bytes = os.urandom(40)
filename = base64.b16encode(random_bytes[:8]).decode("utf-8")
password = base64.b16encode(random_bytes[8:]) # Must be valid UTF-8
tempdirectory = tempfile.mkdtemp()
keychain_path = os.path.join(tempdirectory, filename).encode("utf-8")
# We now want to create the keychain itself.
keychain = Security.SecKeychainRef()
status = Security.SecKeychainCreate(
keychain_path, len(password), password, False, None, ctypes.byref(keychain)
)
_assert_no_error(status)
# Having created the keychain, we want to pass it off to the caller.
return keychain, tempdirectory
def _load_items_from_file(keychain, path):
"""
Given a single file, loads all the trust objects from it into arrays and
the keychain.
Returns a tuple of lists: the first list is a list of identities, the
second a list of certs.
"""
certificates = []
identities = []
result_array = None
with open(path, "rb") as f:
raw_filedata = f.read()
try:
filedata = CoreFoundation.CFDataCreate(
CoreFoundation.kCFAllocatorDefault, raw_filedata, len(raw_filedata)
)
result_array = CoreFoundation.CFArrayRef()
result = Security.SecItemImport(
filedata, # cert data
None, # Filename, leaving it out for now
None, # What the type of the file is, we don't care
None, # what's in the file, we don't care
0, # import flags
None, # key params, can include passphrase in the future
keychain, # The keychain to insert into
ctypes.byref(result_array), # Results
)
_assert_no_error(result)
# A CFArray is not very useful to us as an intermediary
# representation, so we are going to extract the objects we want
# and then free the array. We don't need to keep hold of keys: the
# keychain already has them!
result_count = CoreFoundation.CFArrayGetCount(result_array)
for index in range(result_count):
item = CoreFoundation.CFArrayGetValueAtIndex(result_array, index)
item = ctypes.cast(item, CoreFoundation.CFTypeRef)
if _is_cert(item):
CoreFoundation.CFRetain(item)
certificates.append(item)
elif _is_identity(item):
CoreFoundation.CFRetain(item)
identities.append(item)
finally:
if result_array:
CoreFoundation.CFRelease(result_array)
CoreFoundation.CFRelease(filedata)
return (identities, certificates)
def _load_client_cert_chain(keychain, *paths):
"""
Load certificates and maybe keys from a number of files. Has the end goal
of returning a CFArray containing one SecIdentityRef, and then zero or more
SecCertificateRef objects, suitable for use as a client certificate trust
chain.
"""
# Ok, the strategy.
#
# This relies on knowing that macOS will not give you a SecIdentityRef
# unless you have imported a key into a keychain. This is a somewhat
# artificial limitation of macOS (for example, it doesn't necessarily
# affect iOS), but there is nothing inside Security.framework that lets you
# get a SecIdentityRef without having a key in a keychain.
#
# So the policy here is we take all the files and iterate them in order.
# Each one will use SecItemImport to have one or more objects loaded from
# it. We will also point at a keychain that macOS can use to work with the
# private key.
#
# Once we have all the objects, we'll check what we actually have. If we
# already have a SecIdentityRef in hand, fab: we'll use that. Otherwise,
# we'll take the first certificate (which we assume to be our leaf) and
# ask the keychain to give us a SecIdentityRef with that cert's associated
# key.
#
# We'll then return a CFArray containing the trust chain: one
# SecIdentityRef and then zero-or-more SecCertificateRef objects. The
# responsibility for freeing this CFArray will be with the caller. This
# CFArray must remain alive for the entire connection, so in practice it
# will be stored with a single SSLSocket, along with the reference to the
# keychain.
certificates = []
identities = []
# Filter out bad paths.
paths = (path for path in paths if path)
try:
for file_path in paths:
new_identities, new_certs = _load_items_from_file(keychain, file_path)
identities.extend(new_identities)
certificates.extend(new_certs)
# Ok, we have everything. The question is: do we have an identity? If
# not, we want to grab one from the first cert we have.
if not identities:
new_identity = Security.SecIdentityRef()
status = Security.SecIdentityCreateWithCertificate(
keychain, certificates[0], ctypes.byref(new_identity)
)
_assert_no_error(status)
identities.append(new_identity)
# We now want to release the original certificate, as we no longer
# need it.
CoreFoundation.CFRelease(certificates.pop(0))
# We now need to build a new CFArray that holds the trust chain.
trust_chain = CoreFoundation.CFArrayCreateMutable(
CoreFoundation.kCFAllocatorDefault,
0,
ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
)
for item in itertools.chain(identities, certificates):
# ArrayAppendValue does a CFRetain on the item. That's fine,
# because the finally block will release our other refs to them.
CoreFoundation.CFArrayAppendValue(trust_chain, item)
return trust_chain
finally:
for obj in itertools.chain(identities, certificates):
CoreFoundation.CFRelease(obj)
TLS_PROTOCOL_VERSIONS = {
"SSLv2": (0, 2),
"SSLv3": (3, 0),
"TLSv1": (3, 1),
"TLSv1.1": (3, 2),
"TLSv1.2": (3, 3),
}
def _build_tls_unknown_ca_alert(version):
"""
Builds a TLS alert record for an unknown CA.
"""
ver_maj, ver_min = TLS_PROTOCOL_VERSIONS[version]
severity_fatal = 0x02
description_unknown_ca = 0x30
msg = struct.pack(">BB", severity_fatal, description_unknown_ca)
msg_len = len(msg)
record_type_alert = 0x15
record = struct.pack(">BBBH", record_type_alert, ver_maj, ver_min, msg_len) + msg
return record

View file

@ -1,314 +0,0 @@
"""
This module provides a pool manager that uses Google App Engine's
`URLFetch Service <https://cloud.google.com/appengine/docs/python/urlfetch>`_.
Example usage::
from urllib3 import PoolManager
from urllib3.contrib.appengine import AppEngineManager, is_appengine_sandbox
if is_appengine_sandbox():
# AppEngineManager uses AppEngine's URLFetch API behind the scenes
http = AppEngineManager()
else:
# PoolManager uses a socket-level API behind the scenes
http = PoolManager()
r = http.request('GET', 'https://google.com/')
There are `limitations <https://cloud.google.com/appengine/docs/python/\
urlfetch/#Python_Quotas_and_limits>`_ to the URLFetch service and it may not be
the best choice for your application. There are three options for using
urllib3 on Google App Engine:
1. You can use :class:`AppEngineManager` with URLFetch. URLFetch is
cost-effective in many circumstances as long as your usage is within the
limitations.
2. You can use a normal :class:`~urllib3.PoolManager` by enabling sockets.
Sockets also have `limitations and restrictions
<https://cloud.google.com/appengine/docs/python/sockets/\
#limitations-and-restrictions>`_ and have a lower free quota than URLFetch.
To use sockets, be sure to specify the following in your ``app.yaml``::
env_variables:
GAE_USE_SOCKETS_HTTPLIB : 'true'
3. If you are using `App Engine Flexible
<https://cloud.google.com/appengine/docs/flexible/>`_, you can use the standard
:class:`PoolManager` without any configuration or special environment variables.
"""
from __future__ import absolute_import
import io
import logging
import warnings
from ..exceptions import (
HTTPError,
HTTPWarning,
MaxRetryError,
ProtocolError,
SSLError,
TimeoutError,
)
from ..packages.six.moves.urllib.parse import urljoin
from ..request import RequestMethods
from ..response import HTTPResponse
from ..util.retry import Retry
from ..util.timeout import Timeout
from . import _appengine_environ
try:
from google.appengine.api import urlfetch
except ImportError:
urlfetch = None
log = logging.getLogger(__name__)
class AppEnginePlatformWarning(HTTPWarning):
pass
class AppEnginePlatformError(HTTPError):
pass
class AppEngineManager(RequestMethods):
"""
Connection manager for Google App Engine sandbox applications.
This manager uses the URLFetch service directly instead of using the
emulated httplib, and is subject to URLFetch limitations as described in
the App Engine documentation `here
<https://cloud.google.com/appengine/docs/python/urlfetch>`_.
Notably it will raise an :class:`AppEnginePlatformError` if:
* URLFetch is not available.
* If you attempt to use this on App Engine Flexible, as full socket
support is available.
* If a request size is more than 10 megabytes.
* If a response size is more than 32 megabytes.
* If you use an unsupported request method such as OPTIONS.
Beyond those cases, it will raise normal urllib3 errors.
"""
def __init__(
self,
headers=None,
retries=None,
validate_certificate=True,
urlfetch_retries=True,
):
if not urlfetch:
raise AppEnginePlatformError(
"URLFetch is not available in this environment."
)
warnings.warn(
"urllib3 is using URLFetch on Google App Engine sandbox instead "
"of sockets. To use sockets directly instead of URLFetch see "
"https://urllib3.readthedocs.io/en/1.26.x/reference/urllib3.contrib.html.",
AppEnginePlatformWarning,
)
RequestMethods.__init__(self, headers)
self.validate_certificate = validate_certificate
self.urlfetch_retries = urlfetch_retries
self.retries = retries or Retry.DEFAULT
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Return False to re-raise any potential exceptions
return False
def urlopen(
self,
method,
url,
body=None,
headers=None,
retries=None,
redirect=True,
timeout=Timeout.DEFAULT_TIMEOUT,
**response_kw
):
retries = self._get_retries(retries, redirect)
try:
follow_redirects = redirect and retries.redirect != 0 and retries.total
response = urlfetch.fetch(
url,
payload=body,
method=method,
headers=headers or {},
allow_truncated=False,
follow_redirects=self.urlfetch_retries and follow_redirects,
deadline=self._get_absolute_timeout(timeout),
validate_certificate=self.validate_certificate,
)
except urlfetch.DeadlineExceededError as e:
raise TimeoutError(self, e)
except urlfetch.InvalidURLError as e:
if "too large" in str(e):
raise AppEnginePlatformError(
"URLFetch request too large, URLFetch only "
"supports requests up to 10mb in size.",
e,
)
raise ProtocolError(e)
except urlfetch.DownloadError as e:
if "Too many redirects" in str(e):
raise MaxRetryError(self, url, reason=e)
raise ProtocolError(e)
except urlfetch.ResponseTooLargeError as e:
raise AppEnginePlatformError(
"URLFetch response too large, URLFetch only supports"
"responses up to 32mb in size.",
e,
)
except urlfetch.SSLCertificateError as e:
raise SSLError(e)
except urlfetch.InvalidMethodError as e:
raise AppEnginePlatformError(
"URLFetch does not support method: %s" % method, e
)
http_response = self._urlfetch_response_to_http_response(
response, retries=retries, **response_kw
)
# Handle redirect?
redirect_location = redirect and http_response.get_redirect_location()
if redirect_location:
# Check for redirect response
if self.urlfetch_retries and retries.raise_on_redirect:
raise MaxRetryError(self, url, "too many redirects")
else:
if http_response.status == 303:
method = "GET"
try:
retries = retries.increment(
method, url, response=http_response, _pool=self
)
except MaxRetryError:
if retries.raise_on_redirect:
raise MaxRetryError(self, url, "too many redirects")
return http_response
retries.sleep_for_retry(http_response)
log.debug("Redirecting %s -> %s", url, redirect_location)
redirect_url = urljoin(url, redirect_location)
return self.urlopen(
method,
redirect_url,
body,
headers,
retries=retries,
redirect=redirect,
timeout=timeout,
**response_kw
)
# Check if we should retry the HTTP response.
has_retry_after = bool(http_response.getheader("Retry-After"))
if retries.is_retry(method, http_response.status, has_retry_after):
retries = retries.increment(method, url, response=http_response, _pool=self)
log.debug("Retry: %s", url)
retries.sleep(http_response)
return self.urlopen(
method,
url,
body=body,
headers=headers,
retries=retries,
redirect=redirect,
timeout=timeout,
**response_kw
)
return http_response
def _urlfetch_response_to_http_response(self, urlfetch_resp, **response_kw):
if is_prod_appengine():
# Production GAE handles deflate encoding automatically, but does
# not remove the encoding header.
content_encoding = urlfetch_resp.headers.get("content-encoding")
if content_encoding == "deflate":
del urlfetch_resp.headers["content-encoding"]
transfer_encoding = urlfetch_resp.headers.get("transfer-encoding")
# We have a full response's content,
# so let's make sure we don't report ourselves as chunked data.
if transfer_encoding == "chunked":
encodings = transfer_encoding.split(",")
encodings.remove("chunked")
urlfetch_resp.headers["transfer-encoding"] = ",".join(encodings)
original_response = HTTPResponse(
# In order for decoding to work, we must present the content as
# a file-like object.
body=io.BytesIO(urlfetch_resp.content),
msg=urlfetch_resp.header_msg,
headers=urlfetch_resp.headers,
status=urlfetch_resp.status_code,
**response_kw
)
return HTTPResponse(
body=io.BytesIO(urlfetch_resp.content),
headers=urlfetch_resp.headers,
status=urlfetch_resp.status_code,
original_response=original_response,
**response_kw
)
def _get_absolute_timeout(self, timeout):
if timeout is Timeout.DEFAULT_TIMEOUT:
return None # Defer to URLFetch's default.
if isinstance(timeout, Timeout):
if timeout._read is not None or timeout._connect is not None:
warnings.warn(
"URLFetch does not support granular timeout settings, "
"reverting to total or default URLFetch timeout.",
AppEnginePlatformWarning,
)
return timeout.total
return timeout
def _get_retries(self, retries, redirect):
if not isinstance(retries, Retry):
retries = Retry.from_int(retries, redirect=redirect, default=self.retries)
if retries.connect or retries.read or retries.redirect:
warnings.warn(
"URLFetch only supports total retries and does not "
"recognize connect, read, or redirect retry parameters.",
AppEnginePlatformWarning,
)
return retries
# Alias methods from _appengine_environ to maintain public API interface.
is_appengine = _appengine_environ.is_appengine
is_appengine_sandbox = _appengine_environ.is_appengine_sandbox
is_local_appengine = _appengine_environ.is_local_appengine
is_prod_appengine = _appengine_environ.is_prod_appengine
is_prod_appengine_mvms = _appengine_environ.is_prod_appengine_mvms

View file

@ -1,130 +0,0 @@
"""
NTLM authenticating pool, contributed by erikcederstran
Issue #10, see: http://code.google.com/p/urllib3/issues/detail?id=10
"""
from __future__ import absolute_import
import warnings
from logging import getLogger
from ntlm import ntlm
from .. import HTTPSConnectionPool
from ..packages.six.moves.http_client import HTTPSConnection
warnings.warn(
"The 'urllib3.contrib.ntlmpool' module is deprecated and will be removed "
"in urllib3 v2.0 release, urllib3 is not able to support it properly due "
"to reasons listed in issue: https://github.com/urllib3/urllib3/issues/2282. "
"If you are a user of this module please comment in the mentioned issue.",
DeprecationWarning,
)
log = getLogger(__name__)
class NTLMConnectionPool(HTTPSConnectionPool):
"""
Implements an NTLM authentication version of an urllib3 connection pool
"""
scheme = "https"
def __init__(self, user, pw, authurl, *args, **kwargs):
"""
authurl is a random URL on the server that is protected by NTLM.
user is the Windows user, probably in the DOMAIN\\username format.
pw is the password for the user.
"""
super(NTLMConnectionPool, self).__init__(*args, **kwargs)
self.authurl = authurl
self.rawuser = user
user_parts = user.split("\\", 1)
self.domain = user_parts[0].upper()
self.user = user_parts[1]
self.pw = pw
def _new_conn(self):
# Performs the NTLM handshake that secures the connection. The socket
# must be kept open while requests are performed.
self.num_connections += 1
log.debug(
"Starting NTLM HTTPS connection no. %d: https://%s%s",
self.num_connections,
self.host,
self.authurl,
)
headers = {"Connection": "Keep-Alive"}
req_header = "Authorization"
resp_header = "www-authenticate"
conn = HTTPSConnection(host=self.host, port=self.port)
# Send negotiation message
headers[req_header] = "NTLM %s" % ntlm.create_NTLM_NEGOTIATE_MESSAGE(
self.rawuser
)
log.debug("Request headers: %s", headers)
conn.request("GET", self.authurl, None, headers)
res = conn.getresponse()
reshdr = dict(res.getheaders())
log.debug("Response status: %s %s", res.status, res.reason)
log.debug("Response headers: %s", reshdr)
log.debug("Response data: %s [...]", res.read(100))
# Remove the reference to the socket, so that it can not be closed by
# the response object (we want to keep the socket open)
res.fp = None
# Server should respond with a challenge message
auth_header_values = reshdr[resp_header].split(", ")
auth_header_value = None
for s in auth_header_values:
if s[:5] == "NTLM ":
auth_header_value = s[5:]
if auth_header_value is None:
raise Exception(
"Unexpected %s response header: %s" % (resp_header, reshdr[resp_header])
)
# Send authentication message
ServerChallenge, NegotiateFlags = ntlm.parse_NTLM_CHALLENGE_MESSAGE(
auth_header_value
)
auth_msg = ntlm.create_NTLM_AUTHENTICATE_MESSAGE(
ServerChallenge, self.user, self.domain, self.pw, NegotiateFlags
)
headers[req_header] = "NTLM %s" % auth_msg
log.debug("Request headers: %s", headers)
conn.request("GET", self.authurl, None, headers)
res = conn.getresponse()
log.debug("Response status: %s %s", res.status, res.reason)
log.debug("Response headers: %s", dict(res.getheaders()))
log.debug("Response data: %s [...]", res.read()[:100])
if res.status != 200:
if res.status == 401:
raise Exception("Server rejected request: wrong username or password")
raise Exception("Wrong server response: %s %s" % (res.status, res.reason))
res.fp = None
log.debug("Connection established")
return conn
def urlopen(
self,
method,
url,
body=None,
headers=None,
retries=3,
redirect=True,
assert_same_host=True,
):
if headers is None:
headers = {}
headers["Connection"] = "Keep-Alive"
return super(NTLMConnectionPool, self).urlopen(
method, url, body, headers, retries, redirect, assert_same_host
)

View file

@ -1,511 +0,0 @@
"""
TLS with SNI_-support for Python 2. Follow these instructions if you would
like to verify TLS certificates in Python 2. Note, the default libraries do
*not* do certificate checking; you need to do additional work to validate
certificates yourself.
This needs the following packages installed:
* `pyOpenSSL`_ (tested with 16.0.0)
* `cryptography`_ (minimum 1.3.4, from pyopenssl)
* `idna`_ (minimum 2.0, from cryptography)
However, pyopenssl depends on cryptography, which depends on idna, so while we
use all three directly here we end up having relatively few packages required.
You can install them with the following command:
.. code-block:: bash
$ python -m pip install pyopenssl cryptography idna
To activate certificate checking, call
:func:`~urllib3.contrib.pyopenssl.inject_into_urllib3` from your Python code
before you begin making HTTP requests. This can be done in a ``sitecustomize``
module, or at any other time before your application begins using ``urllib3``,
like this:
.. code-block:: python
try:
import urllib3.contrib.pyopenssl
urllib3.contrib.pyopenssl.inject_into_urllib3()
except ImportError:
pass
Now you can use :mod:`urllib3` as you normally would, and it will support SNI
when the required modules are installed.
Activating this module also has the positive side effect of disabling SSL/TLS
compression in Python 2 (see `CRIME attack`_).
.. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication
.. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit)
.. _pyopenssl: https://www.pyopenssl.org
.. _cryptography: https://cryptography.io
.. _idna: https://github.com/kjd/idna
"""
from __future__ import absolute_import
import OpenSSL.SSL
from cryptography import x509
from cryptography.hazmat.backends.openssl import backend as openssl_backend
from cryptography.hazmat.backends.openssl.x509 import _Certificate
try:
from cryptography.x509 import UnsupportedExtension
except ImportError:
# UnsupportedExtension is gone in cryptography >= 2.1.0
class UnsupportedExtension(Exception):
pass
from io import BytesIO
from socket import error as SocketError
from socket import timeout
try: # Platform-specific: Python 2
from socket import _fileobject
except ImportError: # Platform-specific: Python 3
_fileobject = None
from ..packages.backports.makefile import backport_makefile
import logging
import ssl
import sys
from .. import util
from ..packages import six
from ..util.ssl_ import PROTOCOL_TLS_CLIENT
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
# SNI always works.
HAS_SNI = True
# Map from urllib3 to PyOpenSSL compatible parameter-values.
_openssl_versions = {
util.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD,
PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD,
ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
}
if hasattr(ssl, "PROTOCOL_SSLv3") and hasattr(OpenSSL.SSL, "SSLv3_METHOD"):
_openssl_versions[ssl.PROTOCOL_SSLv3] = OpenSSL.SSL.SSLv3_METHOD
if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"):
_openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD
if hasattr(ssl, "PROTOCOL_TLSv1_2") and hasattr(OpenSSL.SSL, "TLSv1_2_METHOD"):
_openssl_versions[ssl.PROTOCOL_TLSv1_2] = OpenSSL.SSL.TLSv1_2_METHOD
_stdlib_to_openssl_verify = {
ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE,
ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER,
ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER
+ OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}
_openssl_to_stdlib_verify = dict((v, k) for k, v in _stdlib_to_openssl_verify.items())
# OpenSSL will only write 16K at a time
SSL_WRITE_BLOCKSIZE = 16384
orig_util_HAS_SNI = util.HAS_SNI
orig_util_SSLContext = util.ssl_.SSLContext
log = logging.getLogger(__name__)
def inject_into_urllib3():
"Monkey-patch urllib3 with PyOpenSSL-backed SSL-support."
_validate_dependencies_met()
util.SSLContext = PyOpenSSLContext
util.ssl_.SSLContext = PyOpenSSLContext
util.HAS_SNI = HAS_SNI
util.ssl_.HAS_SNI = HAS_SNI
util.IS_PYOPENSSL = True
util.ssl_.IS_PYOPENSSL = True
def extract_from_urllib3():
"Undo monkey-patching by :func:`inject_into_urllib3`."
util.SSLContext = orig_util_SSLContext
util.ssl_.SSLContext = orig_util_SSLContext
util.HAS_SNI = orig_util_HAS_SNI
util.ssl_.HAS_SNI = orig_util_HAS_SNI
util.IS_PYOPENSSL = False
util.ssl_.IS_PYOPENSSL = False
def _validate_dependencies_met():
"""
Verifies that PyOpenSSL's package-level dependencies have been met.
Throws `ImportError` if they are not met.
"""
# Method added in `cryptography==1.1`; not available in older versions
from cryptography.x509.extensions import Extensions
if getattr(Extensions, "get_extension_for_class", None) is None:
raise ImportError(
"'cryptography' module missing required functionality. "
"Try upgrading to v1.3.4 or newer."
)
# pyOpenSSL 0.14 and above use cryptography for OpenSSL bindings. The _x509
# attribute is only present on those versions.
from OpenSSL.crypto import X509
x509 = X509()
if getattr(x509, "_x509", None) is None:
raise ImportError(
"'pyOpenSSL' module missing required functionality. "
"Try upgrading to v0.14 or newer."
)
def _dnsname_to_stdlib(name):
"""
Converts a dNSName SubjectAlternativeName field to the form used by the
standard library on the given Python version.
Cryptography produces a dNSName as a unicode string that was idna-decoded
from ASCII bytes. We need to idna-encode that string to get it back, and
then on Python 3 we also need to convert to unicode via UTF-8 (the stdlib
uses PyUnicode_FromStringAndSize on it, which decodes via UTF-8).
If the name cannot be idna-encoded then we return None signalling that
the name given should be skipped.
"""
def idna_encode(name):
"""
Borrowed wholesale from the Python Cryptography Project. It turns out
that we can't just safely call `idna.encode`: it can explode for
wildcard names. This avoids that problem.
"""
import idna
try:
for prefix in [u"*.", u"."]:
if name.startswith(prefix):
name = name[len(prefix) :]
return prefix.encode("ascii") + idna.encode(name)
return idna.encode(name)
except idna.core.IDNAError:
return None
# Don't send IPv6 addresses through the IDNA encoder.
if ":" in name:
return name
name = idna_encode(name)
if name is None:
return None
elif sys.version_info >= (3, 0):
name = name.decode("utf-8")
return name
def get_subj_alt_name(peer_cert):
"""
Given an PyOpenSSL certificate, provides all the subject alternative names.
"""
# Pass the cert to cryptography, which has much better APIs for this.
if hasattr(peer_cert, "to_cryptography"):
cert = peer_cert.to_cryptography()
else:
# This is technically using private APIs, but should work across all
# relevant versions before PyOpenSSL got a proper API for this.
cert = _Certificate(openssl_backend, peer_cert._x509)
# We want to find the SAN extension. Ask Cryptography to locate it (it's
# faster than looping in Python)
try:
ext = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
except x509.ExtensionNotFound:
# No such extension, return the empty list.
return []
except (
x509.DuplicateExtension,
UnsupportedExtension,
x509.UnsupportedGeneralNameType,
UnicodeError,
) as e:
# A problem has been found with the quality of the certificate. Assume
# no SAN field is present.
log.warning(
"A problem was encountered with the certificate that prevented "
"urllib3 from finding the SubjectAlternativeName field. This can "
"affect certificate validation. The error was %s",
e,
)
return []
# We want to return dNSName and iPAddress fields. We need to cast the IPs
# back to strings because the match_hostname function wants them as
# strings.
# Sadly the DNS names need to be idna encoded and then, on Python 3, UTF-8
# decoded. This is pretty frustrating, but that's what the standard library
# does with certificates, and so we need to attempt to do the same.
# We also want to skip over names which cannot be idna encoded.
names = [
("DNS", name)
for name in map(_dnsname_to_stdlib, ext.get_values_for_type(x509.DNSName))
if name is not None
]
names.extend(
("IP Address", str(name)) for name in ext.get_values_for_type(x509.IPAddress)
)
return names
class WrappedSocket(object):
"""API-compatibility wrapper for Python OpenSSL's Connection-class.
Note: _makefile_refs, _drop() and _reuse() are needed for the garbage
collector of pypy.
"""
def __init__(self, connection, socket, suppress_ragged_eofs=True):
self.connection = connection
self.socket = socket
self.suppress_ragged_eofs = suppress_ragged_eofs
self._makefile_refs = 0
self._closed = False
def fileno(self):
return self.socket.fileno()
# Copy-pasted from Python 3.5 source code
def _decref_socketios(self):
if self._makefile_refs > 0:
self._makefile_refs -= 1
if self._closed:
self.close()
def recv(self, *args, **kwargs):
try:
data = self.connection.recv(*args, **kwargs)
except OpenSSL.SSL.SysCallError as e:
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
return b""
else:
raise SocketError(str(e))
except OpenSSL.SSL.ZeroReturnError:
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
return b""
else:
raise
except OpenSSL.SSL.WantReadError:
if not util.wait_for_read(self.socket, self.socket.gettimeout()):
raise timeout("The read operation timed out")
else:
return self.recv(*args, **kwargs)
# TLS 1.3 post-handshake authentication
except OpenSSL.SSL.Error as e:
raise ssl.SSLError("read error: %r" % e)
else:
return data
def recv_into(self, *args, **kwargs):
try:
return self.connection.recv_into(*args, **kwargs)
except OpenSSL.SSL.SysCallError as e:
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
return 0
else:
raise SocketError(str(e))
except OpenSSL.SSL.ZeroReturnError:
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
return 0
else:
raise
except OpenSSL.SSL.WantReadError:
if not util.wait_for_read(self.socket, self.socket.gettimeout()):
raise timeout("The read operation timed out")
else:
return self.recv_into(*args, **kwargs)
# TLS 1.3 post-handshake authentication
except OpenSSL.SSL.Error as e:
raise ssl.SSLError("read error: %r" % e)
def settimeout(self, timeout):
return self.socket.settimeout(timeout)
def _send_until_done(self, data):
while True:
try:
return self.connection.send(data)
except OpenSSL.SSL.WantWriteError:
if not util.wait_for_write(self.socket, self.socket.gettimeout()):
raise timeout()
continue
except OpenSSL.SSL.SysCallError as e:
raise SocketError(str(e))
def sendall(self, data):
total_sent = 0
while total_sent < len(data):
sent = self._send_until_done(
data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE]
)
total_sent += sent
def shutdown(self):
# FIXME rethrow compatible exceptions should we ever use this
self.connection.shutdown()
def close(self):
if self._makefile_refs < 1:
try:
self._closed = True
return self.connection.close()
except OpenSSL.SSL.Error:
return
else:
self._makefile_refs -= 1
def getpeercert(self, binary_form=False):
x509 = self.connection.get_peer_certificate()
if not x509:
return x509
if binary_form:
return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509)
return {
"subject": ((("commonName", x509.get_subject().CN),),),
"subjectAltName": get_subj_alt_name(x509),
}
def version(self):
return self.connection.get_protocol_version_name()
def _reuse(self):
self._makefile_refs += 1
def _drop(self):
if self._makefile_refs < 1:
self.close()
else:
self._makefile_refs -= 1
if _fileobject: # Platform-specific: Python 2
def makefile(self, mode, bufsize=-1):
self._makefile_refs += 1
return _fileobject(self, mode, bufsize, close=True)
else: # Platform-specific: Python 3
makefile = backport_makefile
WrappedSocket.makefile = makefile
class PyOpenSSLContext(object):
"""
I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible
for translating the interface of the standard library ``SSLContext`` object
to calls into PyOpenSSL.
"""
def __init__(self, protocol):
self.protocol = _openssl_versions[protocol]
self._ctx = OpenSSL.SSL.Context(self.protocol)
self._options = 0
self.check_hostname = False
@property
def options(self):
return self._options
@options.setter
def options(self, value):
self._options = value
self._ctx.set_options(value)
@property
def verify_mode(self):
return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()]
@verify_mode.setter
def verify_mode(self, value):
self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback)
def set_default_verify_paths(self):
self._ctx.set_default_verify_paths()
def set_ciphers(self, ciphers):
if isinstance(ciphers, six.text_type):
ciphers = ciphers.encode("utf-8")
self._ctx.set_cipher_list(ciphers)
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
if cafile is not None:
cafile = cafile.encode("utf-8")
if capath is not None:
capath = capath.encode("utf-8")
try:
self._ctx.load_verify_locations(cafile, capath)
if cadata is not None:
self._ctx.load_verify_locations(BytesIO(cadata))
except OpenSSL.SSL.Error as e:
raise ssl.SSLError("unable to load trusted certificates: %r" % e)
def load_cert_chain(self, certfile, keyfile=None, password=None):
self._ctx.use_certificate_chain_file(certfile)
if password is not None:
if not isinstance(password, six.binary_type):
password = password.encode("utf-8")
self._ctx.set_passwd_cb(lambda *_: password)
self._ctx.use_privatekey_file(keyfile or certfile)
def set_alpn_protocols(self, protocols):
protocols = [six.ensure_binary(p) for p in protocols]
return self._ctx.set_alpn_protos(protocols)
def wrap_socket(
self,
sock,
server_side=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
server_hostname=None,
):
cnx = OpenSSL.SSL.Connection(self._ctx, sock)
if isinstance(server_hostname, six.text_type): # Platform-specific: Python 3
server_hostname = server_hostname.encode("utf-8")
if server_hostname is not None:
cnx.set_tlsext_host_name(server_hostname)
cnx.set_connect_state()
while True:
try:
cnx.do_handshake()
except OpenSSL.SSL.WantReadError:
if not util.wait_for_read(sock, sock.gettimeout()):
raise timeout("select timed out")
continue
except OpenSSL.SSL.Error as e:
raise ssl.SSLError("bad handshake: %r" % e)
break
return WrappedSocket(cnx, sock)
def _verify_callback(cnx, x509, err_no, err_depth, return_code):
return err_no == 0

View file

@ -1,922 +0,0 @@
"""
SecureTranport support for urllib3 via ctypes.
This makes platform-native TLS available to urllib3 users on macOS without the
use of a compiler. This is an important feature because the Python Package
Index is moving to become a TLSv1.2-or-higher server, and the default OpenSSL
that ships with macOS is not capable of doing TLSv1.2. The only way to resolve
this is to give macOS users an alternative solution to the problem, and that
solution is to use SecureTransport.
We use ctypes here because this solution must not require a compiler. That's
because pip is not allowed to require a compiler either.
This is not intended to be a seriously long-term solution to this problem.
The hope is that PEP 543 will eventually solve this issue for us, at which
point we can retire this contrib module. But in the short term, we need to
solve the impending tire fire that is Python on Mac without this kind of
contrib module. So...here we are.
To use this module, simply import and inject it::
import urllib3.contrib.securetransport
urllib3.contrib.securetransport.inject_into_urllib3()
Happy TLSing!
This code is a bastardised version of the code found in Will Bond's oscrypto
library. An enormous debt is owed to him for blazing this trail for us. For
that reason, this code should be considered to be covered both by urllib3's
license and by oscrypto's:
.. code-block::
Copyright (c) 2015-2016 Will Bond <will@wbond.net>
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import absolute_import
import contextlib
import ctypes
import errno
import os.path
import shutil
import socket
import ssl
import struct
import threading
import weakref
import six
from .. import util
from ..util.ssl_ import PROTOCOL_TLS_CLIENT
from ._securetransport.bindings import CoreFoundation, Security, SecurityConst
from ._securetransport.low_level import (
_assert_no_error,
_build_tls_unknown_ca_alert,
_cert_array_from_pem,
_create_cfstring_array,
_load_client_cert_chain,
_temporary_keychain,
)
try: # Platform-specific: Python 2
from socket import _fileobject
except ImportError: # Platform-specific: Python 3
_fileobject = None
from ..packages.backports.makefile import backport_makefile
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
# SNI always works
HAS_SNI = True
orig_util_HAS_SNI = util.HAS_SNI
orig_util_SSLContext = util.ssl_.SSLContext
# This dictionary is used by the read callback to obtain a handle to the
# calling wrapped socket. This is a pretty silly approach, but for now it'll
# do. I feel like I should be able to smuggle a handle to the wrapped socket
# directly in the SSLConnectionRef, but for now this approach will work I
# guess.
#
# We need to lock around this structure for inserts, but we don't do it for
# reads/writes in the callbacks. The reasoning here goes as follows:
#
# 1. It is not possible to call into the callbacks before the dictionary is
# populated, so once in the callback the id must be in the dictionary.
# 2. The callbacks don't mutate the dictionary, they only read from it, and
# so cannot conflict with any of the insertions.
#
# This is good: if we had to lock in the callbacks we'd drastically slow down
# the performance of this code.
_connection_refs = weakref.WeakValueDictionary()
_connection_ref_lock = threading.Lock()
# Limit writes to 16kB. This is OpenSSL's limit, but we'll cargo-cult it over
# for no better reason than we need *a* limit, and this one is right there.
SSL_WRITE_BLOCKSIZE = 16384
# This is our equivalent of util.ssl_.DEFAULT_CIPHERS, but expanded out to
# individual cipher suites. We need to do this because this is how
# SecureTransport wants them.
CIPHER_SUITES = [
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
SecurityConst.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_DHE_RSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
SecurityConst.TLS_AES_256_GCM_SHA384,
SecurityConst.TLS_AES_128_GCM_SHA256,
SecurityConst.TLS_RSA_WITH_AES_256_GCM_SHA384,
SecurityConst.TLS_RSA_WITH_AES_128_GCM_SHA256,
SecurityConst.TLS_AES_128_CCM_8_SHA256,
SecurityConst.TLS_AES_128_CCM_SHA256,
SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA256,
SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA256,
SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA,
SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA,
]
# Basically this is simple: for PROTOCOL_SSLv23 we turn it into a low of
# TLSv1 and a high of TLSv1.2. For everything else, we pin to that version.
# TLSv1 to 1.2 are supported on macOS 10.8+
_protocol_to_min_max = {
util.PROTOCOL_TLS: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12),
PROTOCOL_TLS_CLIENT: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12),
}
if hasattr(ssl, "PROTOCOL_SSLv2"):
_protocol_to_min_max[ssl.PROTOCOL_SSLv2] = (
SecurityConst.kSSLProtocol2,
SecurityConst.kSSLProtocol2,
)
if hasattr(ssl, "PROTOCOL_SSLv3"):
_protocol_to_min_max[ssl.PROTOCOL_SSLv3] = (
SecurityConst.kSSLProtocol3,
SecurityConst.kSSLProtocol3,
)
if hasattr(ssl, "PROTOCOL_TLSv1"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1] = (
SecurityConst.kTLSProtocol1,
SecurityConst.kTLSProtocol1,
)
if hasattr(ssl, "PROTOCOL_TLSv1_1"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1_1] = (
SecurityConst.kTLSProtocol11,
SecurityConst.kTLSProtocol11,
)
if hasattr(ssl, "PROTOCOL_TLSv1_2"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1_2] = (
SecurityConst.kTLSProtocol12,
SecurityConst.kTLSProtocol12,
)
def inject_into_urllib3():
"""
Monkey-patch urllib3 with SecureTransport-backed SSL-support.
"""
util.SSLContext = SecureTransportContext
util.ssl_.SSLContext = SecureTransportContext
util.HAS_SNI = HAS_SNI
util.ssl_.HAS_SNI = HAS_SNI
util.IS_SECURETRANSPORT = True
util.ssl_.IS_SECURETRANSPORT = True
def extract_from_urllib3():
"""
Undo monkey-patching by :func:`inject_into_urllib3`.
"""
util.SSLContext = orig_util_SSLContext
util.ssl_.SSLContext = orig_util_SSLContext
util.HAS_SNI = orig_util_HAS_SNI
util.ssl_.HAS_SNI = orig_util_HAS_SNI
util.IS_SECURETRANSPORT = False
util.ssl_.IS_SECURETRANSPORT = False
def _read_callback(connection_id, data_buffer, data_length_pointer):
"""
SecureTransport read callback. This is called by ST to request that data
be returned from the socket.
"""
wrapped_socket = None
try:
wrapped_socket = _connection_refs.get(connection_id)
if wrapped_socket is None:
return SecurityConst.errSSLInternal
base_socket = wrapped_socket.socket
requested_length = data_length_pointer[0]
timeout = wrapped_socket.gettimeout()
error = None
read_count = 0
try:
while read_count < requested_length:
if timeout is None or timeout >= 0:
if not util.wait_for_read(base_socket, timeout):
raise socket.error(errno.EAGAIN, "timed out")
remaining = requested_length - read_count
buffer = (ctypes.c_char * remaining).from_address(
data_buffer + read_count
)
chunk_size = base_socket.recv_into(buffer, remaining)
read_count += chunk_size
if not chunk_size:
if not read_count:
return SecurityConst.errSSLClosedGraceful
break
except (socket.error) as e:
error = e.errno
if error is not None and error != errno.EAGAIN:
data_length_pointer[0] = read_count
if error == errno.ECONNRESET or error == errno.EPIPE:
return SecurityConst.errSSLClosedAbort
raise
data_length_pointer[0] = read_count
if read_count != requested_length:
return SecurityConst.errSSLWouldBlock
return 0
except Exception as e:
if wrapped_socket is not None:
wrapped_socket._exception = e
return SecurityConst.errSSLInternal
def _write_callback(connection_id, data_buffer, data_length_pointer):
"""
SecureTransport write callback. This is called by ST to request that data
actually be sent on the network.
"""
wrapped_socket = None
try:
wrapped_socket = _connection_refs.get(connection_id)
if wrapped_socket is None:
return SecurityConst.errSSLInternal
base_socket = wrapped_socket.socket
bytes_to_write = data_length_pointer[0]
data = ctypes.string_at(data_buffer, bytes_to_write)
timeout = wrapped_socket.gettimeout()
error = None
sent = 0
try:
while sent < bytes_to_write:
if timeout is None or timeout >= 0:
if not util.wait_for_write(base_socket, timeout):
raise socket.error(errno.EAGAIN, "timed out")
chunk_sent = base_socket.send(data)
sent += chunk_sent
# This has some needless copying here, but I'm not sure there's
# much value in optimising this data path.
data = data[chunk_sent:]
except (socket.error) as e:
error = e.errno
if error is not None and error != errno.EAGAIN:
data_length_pointer[0] = sent
if error == errno.ECONNRESET or error == errno.EPIPE:
return SecurityConst.errSSLClosedAbort
raise
data_length_pointer[0] = sent
if sent != bytes_to_write:
return SecurityConst.errSSLWouldBlock
return 0
except Exception as e:
if wrapped_socket is not None:
wrapped_socket._exception = e
return SecurityConst.errSSLInternal
# We need to keep these two objects references alive: if they get GC'd while
# in use then SecureTransport could attempt to call a function that is in freed
# memory. That would be...uh...bad. Yeah, that's the word. Bad.
_read_callback_pointer = Security.SSLReadFunc(_read_callback)
_write_callback_pointer = Security.SSLWriteFunc(_write_callback)
class WrappedSocket(object):
"""
API-compatibility wrapper for Python's OpenSSL wrapped socket object.
Note: _makefile_refs, _drop(), and _reuse() are needed for the garbage
collector of PyPy.
"""
def __init__(self, socket):
self.socket = socket
self.context = None
self._makefile_refs = 0
self._closed = False
self._exception = None
self._keychain = None
self._keychain_dir = None
self._client_cert_chain = None
# We save off the previously-configured timeout and then set it to
# zero. This is done because we use select and friends to handle the
# timeouts, but if we leave the timeout set on the lower socket then
# Python will "kindly" call select on that socket again for us. Avoid
# that by forcing the timeout to zero.
self._timeout = self.socket.gettimeout()
self.socket.settimeout(0)
@contextlib.contextmanager
def _raise_on_error(self):
"""
A context manager that can be used to wrap calls that do I/O from
SecureTransport. If any of the I/O callbacks hit an exception, this
context manager will correctly propagate the exception after the fact.
This avoids silently swallowing those exceptions.
It also correctly forces the socket closed.
"""
self._exception = None
# We explicitly don't catch around this yield because in the unlikely
# event that an exception was hit in the block we don't want to swallow
# it.
yield
if self._exception is not None:
exception, self._exception = self._exception, None
self.close()
raise exception
def _set_ciphers(self):
"""
Sets up the allowed ciphers. By default this matches the set in
util.ssl_.DEFAULT_CIPHERS, at least as supported by macOS. This is done
custom and doesn't allow changing at this time, mostly because parsing
OpenSSL cipher strings is going to be a freaking nightmare.
"""
ciphers = (Security.SSLCipherSuite * len(CIPHER_SUITES))(*CIPHER_SUITES)
result = Security.SSLSetEnabledCiphers(
self.context, ciphers, len(CIPHER_SUITES)
)
_assert_no_error(result)
def _set_alpn_protocols(self, protocols):
"""
Sets up the ALPN protocols on the context.
"""
if not protocols:
return
protocols_arr = _create_cfstring_array(protocols)
try:
result = Security.SSLSetALPNProtocols(self.context, protocols_arr)
_assert_no_error(result)
finally:
CoreFoundation.CFRelease(protocols_arr)
def _custom_validate(self, verify, trust_bundle):
"""
Called when we have set custom validation. We do this in two cases:
first, when cert validation is entirely disabled; and second, when
using a custom trust DB.
Raises an SSLError if the connection is not trusted.
"""
# If we disabled cert validation, just say: cool.
if not verify:
return
successes = (
SecurityConst.kSecTrustResultUnspecified,
SecurityConst.kSecTrustResultProceed,
)
try:
trust_result = self._evaluate_trust(trust_bundle)
if trust_result in successes:
return
reason = "error code: %d" % (trust_result,)
except Exception as e:
# Do not trust on error
reason = "exception: %r" % (e,)
# SecureTransport does not send an alert nor shuts down the connection.
rec = _build_tls_unknown_ca_alert(self.version())
self.socket.sendall(rec)
# close the connection immediately
# l_onoff = 1, activate linger
# l_linger = 0, linger for 0 seoncds
opts = struct.pack("ii", 1, 0)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, opts)
self.close()
raise ssl.SSLError("certificate verify failed, %s" % reason)
def _evaluate_trust(self, trust_bundle):
# We want data in memory, so load it up.
if os.path.isfile(trust_bundle):
with open(trust_bundle, "rb") as f:
trust_bundle = f.read()
cert_array = None
trust = Security.SecTrustRef()
try:
# Get a CFArray that contains the certs we want.
cert_array = _cert_array_from_pem(trust_bundle)
# Ok, now the hard part. We want to get the SecTrustRef that ST has
# created for this connection, shove our CAs into it, tell ST to
# ignore everything else it knows, and then ask if it can build a
# chain. This is a buuuunch of code.
result = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
_assert_no_error(result)
if not trust:
raise ssl.SSLError("Failed to copy trust reference")
result = Security.SecTrustSetAnchorCertificates(trust, cert_array)
_assert_no_error(result)
result = Security.SecTrustSetAnchorCertificatesOnly(trust, True)
_assert_no_error(result)
trust_result = Security.SecTrustResultType()
result = Security.SecTrustEvaluate(trust, ctypes.byref(trust_result))
_assert_no_error(result)
finally:
if trust:
CoreFoundation.CFRelease(trust)
if cert_array is not None:
CoreFoundation.CFRelease(cert_array)
return trust_result.value
def handshake(
self,
server_hostname,
verify,
trust_bundle,
min_version,
max_version,
client_cert,
client_key,
client_key_passphrase,
alpn_protocols,
):
"""
Actually performs the TLS handshake. This is run automatically by
wrapped socket, and shouldn't be needed in user code.
"""
# First, we do the initial bits of connection setup. We need to create
# a context, set its I/O funcs, and set the connection reference.
self.context = Security.SSLCreateContext(
None, SecurityConst.kSSLClientSide, SecurityConst.kSSLStreamType
)
result = Security.SSLSetIOFuncs(
self.context, _read_callback_pointer, _write_callback_pointer
)
_assert_no_error(result)
# Here we need to compute the handle to use. We do this by taking the
# id of self modulo 2**31 - 1. If this is already in the dictionary, we
# just keep incrementing by one until we find a free space.
with _connection_ref_lock:
handle = id(self) % 2147483647
while handle in _connection_refs:
handle = (handle + 1) % 2147483647
_connection_refs[handle] = self
result = Security.SSLSetConnection(self.context, handle)
_assert_no_error(result)
# If we have a server hostname, we should set that too.
if server_hostname:
if not isinstance(server_hostname, bytes):
server_hostname = server_hostname.encode("utf-8")
result = Security.SSLSetPeerDomainName(
self.context, server_hostname, len(server_hostname)
)
_assert_no_error(result)
# Setup the ciphers.
self._set_ciphers()
# Setup the ALPN protocols.
self._set_alpn_protocols(alpn_protocols)
# Set the minimum and maximum TLS versions.
result = Security.SSLSetProtocolVersionMin(self.context, min_version)
_assert_no_error(result)
result = Security.SSLSetProtocolVersionMax(self.context, max_version)
_assert_no_error(result)
# If there's a trust DB, we need to use it. We do that by telling
# SecureTransport to break on server auth. We also do that if we don't
# want to validate the certs at all: we just won't actually do any
# authing in that case.
if not verify or trust_bundle is not None:
result = Security.SSLSetSessionOption(
self.context, SecurityConst.kSSLSessionOptionBreakOnServerAuth, True
)
_assert_no_error(result)
# If there's a client cert, we need to use it.
if client_cert:
self._keychain, self._keychain_dir = _temporary_keychain()
self._client_cert_chain = _load_client_cert_chain(
self._keychain, client_cert, client_key
)
result = Security.SSLSetCertificate(self.context, self._client_cert_chain)
_assert_no_error(result)
while True:
with self._raise_on_error():
result = Security.SSLHandshake(self.context)
if result == SecurityConst.errSSLWouldBlock:
raise socket.timeout("handshake timed out")
elif result == SecurityConst.errSSLServerAuthCompleted:
self._custom_validate(verify, trust_bundle)
continue
else:
_assert_no_error(result)
break
def fileno(self):
return self.socket.fileno()
# Copy-pasted from Python 3.5 source code
def _decref_socketios(self):
if self._makefile_refs > 0:
self._makefile_refs -= 1
if self._closed:
self.close()
def recv(self, bufsiz):
buffer = ctypes.create_string_buffer(bufsiz)
bytes_read = self.recv_into(buffer, bufsiz)
data = buffer[:bytes_read]
return data
def recv_into(self, buffer, nbytes=None):
# Read short on EOF.
if self._closed:
return 0
if nbytes is None:
nbytes = len(buffer)
buffer = (ctypes.c_char * nbytes).from_buffer(buffer)
processed_bytes = ctypes.c_size_t(0)
with self._raise_on_error():
result = Security.SSLRead(
self.context, buffer, nbytes, ctypes.byref(processed_bytes)
)
# There are some result codes that we want to treat as "not always
# errors". Specifically, those are errSSLWouldBlock,
# errSSLClosedGraceful, and errSSLClosedNoNotify.
if result == SecurityConst.errSSLWouldBlock:
# If we didn't process any bytes, then this was just a time out.
# However, we can get errSSLWouldBlock in situations when we *did*
# read some data, and in those cases we should just read "short"
# and return.
if processed_bytes.value == 0:
# Timed out, no data read.
raise socket.timeout("recv timed out")
elif result in (
SecurityConst.errSSLClosedGraceful,
SecurityConst.errSSLClosedNoNotify,
):
# The remote peer has closed this connection. We should do so as
# well. Note that we don't actually return here because in
# principle this could actually be fired along with return data.
# It's unlikely though.
self.close()
else:
_assert_no_error(result)
# Ok, we read and probably succeeded. We should return whatever data
# was actually read.
return processed_bytes.value
def settimeout(self, timeout):
self._timeout = timeout
def gettimeout(self):
return self._timeout
def send(self, data):
processed_bytes = ctypes.c_size_t(0)
with self._raise_on_error():
result = Security.SSLWrite(
self.context, data, len(data), ctypes.byref(processed_bytes)
)
if result == SecurityConst.errSSLWouldBlock and processed_bytes.value == 0:
# Timed out
raise socket.timeout("send timed out")
else:
_assert_no_error(result)
# We sent, and probably succeeded. Tell them how much we sent.
return processed_bytes.value
def sendall(self, data):
total_sent = 0
while total_sent < len(data):
sent = self.send(data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE])
total_sent += sent
def shutdown(self):
with self._raise_on_error():
Security.SSLClose(self.context)
def close(self):
# TODO: should I do clean shutdown here? Do I have to?
if self._makefile_refs < 1:
self._closed = True
if self.context:
CoreFoundation.CFRelease(self.context)
self.context = None
if self._client_cert_chain:
CoreFoundation.CFRelease(self._client_cert_chain)
self._client_cert_chain = None
if self._keychain:
Security.SecKeychainDelete(self._keychain)
CoreFoundation.CFRelease(self._keychain)
shutil.rmtree(self._keychain_dir)
self._keychain = self._keychain_dir = None
return self.socket.close()
else:
self._makefile_refs -= 1
def getpeercert(self, binary_form=False):
# Urgh, annoying.
#
# Here's how we do this:
#
# 1. Call SSLCopyPeerTrust to get hold of the trust object for this
# connection.
# 2. Call SecTrustGetCertificateAtIndex for index 0 to get the leaf.
# 3. To get the CN, call SecCertificateCopyCommonName and process that
# string so that it's of the appropriate type.
# 4. To get the SAN, we need to do something a bit more complex:
# a. Call SecCertificateCopyValues to get the data, requesting
# kSecOIDSubjectAltName.
# b. Mess about with this dictionary to try to get the SANs out.
#
# This is gross. Really gross. It's going to be a few hundred LoC extra
# just to repeat something that SecureTransport can *already do*. So my
# operating assumption at this time is that what we want to do is
# instead to just flag to urllib3 that it shouldn't do its own hostname
# validation when using SecureTransport.
if not binary_form:
raise ValueError("SecureTransport only supports dumping binary certs")
trust = Security.SecTrustRef()
certdata = None
der_bytes = None
try:
# Grab the trust store.
result = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
_assert_no_error(result)
if not trust:
# Probably we haven't done the handshake yet. No biggie.
return None
cert_count = Security.SecTrustGetCertificateCount(trust)
if not cert_count:
# Also a case that might happen if we haven't handshaked.
# Handshook? Handshaken?
return None
leaf = Security.SecTrustGetCertificateAtIndex(trust, 0)
assert leaf
# Ok, now we want the DER bytes.
certdata = Security.SecCertificateCopyData(leaf)
assert certdata
data_length = CoreFoundation.CFDataGetLength(certdata)
data_buffer = CoreFoundation.CFDataGetBytePtr(certdata)
der_bytes = ctypes.string_at(data_buffer, data_length)
finally:
if certdata:
CoreFoundation.CFRelease(certdata)
if trust:
CoreFoundation.CFRelease(trust)
return der_bytes
def version(self):
protocol = Security.SSLProtocol()
result = Security.SSLGetNegotiatedProtocolVersion(
self.context, ctypes.byref(protocol)
)
_assert_no_error(result)
if protocol.value == SecurityConst.kTLSProtocol13:
raise ssl.SSLError("SecureTransport does not support TLS 1.3")
elif protocol.value == SecurityConst.kTLSProtocol12:
return "TLSv1.2"
elif protocol.value == SecurityConst.kTLSProtocol11:
return "TLSv1.1"
elif protocol.value == SecurityConst.kTLSProtocol1:
return "TLSv1"
elif protocol.value == SecurityConst.kSSLProtocol3:
return "SSLv3"
elif protocol.value == SecurityConst.kSSLProtocol2:
return "SSLv2"
else:
raise ssl.SSLError("Unknown TLS version: %r" % protocol)
def _reuse(self):
self._makefile_refs += 1
def _drop(self):
if self._makefile_refs < 1:
self.close()
else:
self._makefile_refs -= 1
if _fileobject: # Platform-specific: Python 2
def makefile(self, mode, bufsize=-1):
self._makefile_refs += 1
return _fileobject(self, mode, bufsize, close=True)
else: # Platform-specific: Python 3
def makefile(self, mode="r", buffering=None, *args, **kwargs):
# We disable buffering with SecureTransport because it conflicts with
# the buffering that ST does internally (see issue #1153 for more).
buffering = 0
return backport_makefile(self, mode, buffering, *args, **kwargs)
WrappedSocket.makefile = makefile
class SecureTransportContext(object):
"""
I am a wrapper class for the SecureTransport library, to translate the
interface of the standard library ``SSLContext`` object to calls into
SecureTransport.
"""
def __init__(self, protocol):
self._min_version, self._max_version = _protocol_to_min_max[protocol]
self._options = 0
self._verify = False
self._trust_bundle = None
self._client_cert = None
self._client_key = None
self._client_key_passphrase = None
self._alpn_protocols = None
@property
def check_hostname(self):
"""
SecureTransport cannot have its hostname checking disabled. For more,
see the comment on getpeercert() in this file.
"""
return True
@check_hostname.setter
def check_hostname(self, value):
"""
SecureTransport cannot have its hostname checking disabled. For more,
see the comment on getpeercert() in this file.
"""
pass
@property
def options(self):
# TODO: Well, crap.
#
# So this is the bit of the code that is the most likely to cause us
# trouble. Essentially we need to enumerate all of the SSL options that
# users might want to use and try to see if we can sensibly translate
# them, or whether we should just ignore them.
return self._options
@options.setter
def options(self, value):
# TODO: Update in line with above.
self._options = value
@property
def verify_mode(self):
return ssl.CERT_REQUIRED if self._verify else ssl.CERT_NONE
@verify_mode.setter
def verify_mode(self, value):
self._verify = True if value == ssl.CERT_REQUIRED else False
def set_default_verify_paths(self):
# So, this has to do something a bit weird. Specifically, what it does
# is nothing.
#
# This means that, if we had previously had load_verify_locations
# called, this does not undo that. We need to do that because it turns
# out that the rest of the urllib3 code will attempt to load the
# default verify paths if it hasn't been told about any paths, even if
# the context itself was sometime earlier. We resolve that by just
# ignoring it.
pass
def load_default_certs(self):
return self.set_default_verify_paths()
def set_ciphers(self, ciphers):
# For now, we just require the default cipher string.
if ciphers != util.ssl_.DEFAULT_CIPHERS:
raise ValueError("SecureTransport doesn't support custom cipher strings")
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
# OK, we only really support cadata and cafile.
if capath is not None:
raise ValueError("SecureTransport does not support cert directories")
# Raise if cafile does not exist.
if cafile is not None:
with open(cafile):
pass
self._trust_bundle = cafile or cadata
def load_cert_chain(self, certfile, keyfile=None, password=None):
self._client_cert = certfile
self._client_key = keyfile
self._client_cert_passphrase = password
def set_alpn_protocols(self, protocols):
"""
Sets the ALPN protocols that will later be set on the context.
Raises a NotImplementedError if ALPN is not supported.
"""
if not hasattr(Security, "SSLSetALPNProtocols"):
raise NotImplementedError(
"SecureTransport supports ALPN only in macOS 10.12+"
)
self._alpn_protocols = [six.ensure_binary(p) for p in protocols]
def wrap_socket(
self,
sock,
server_side=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
server_hostname=None,
):
# So, what do we do here? Firstly, we assert some properties. This is a
# stripped down shim, so there is some functionality we don't support.
# See PEP 543 for the real deal.
assert not server_side
assert do_handshake_on_connect
assert suppress_ragged_eofs
# Ok, we're good to go. Now we want to create the wrapped socket object
# and store it in the appropriate place.
wrapped_socket = WrappedSocket(sock)
# Now we can handshake
wrapped_socket.handshake(
server_hostname,
self._verify,
self._trust_bundle,
self._min_version,
self._max_version,
self._client_cert,
self._client_key,
self._client_key_passphrase,
self._alpn_protocols,
)
return wrapped_socket

View file

@ -1,216 +0,0 @@
# -*- coding: utf-8 -*-
"""
This module contains provisional support for SOCKS proxies from within
urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
SOCKS5. To enable its functionality, either install PySocks or install this
module with the ``socks`` extra.
The SOCKS implementation supports the full range of urllib3 features. It also
supports the following SOCKS features:
- SOCKS4A (``proxy_url='socks4a://...``)
- SOCKS4 (``proxy_url='socks4://...``)
- SOCKS5 with remote DNS (``proxy_url='socks5h://...``)
- SOCKS5 with local DNS (``proxy_url='socks5://...``)
- Usernames and passwords for the SOCKS proxy
.. note::
It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in
your ``proxy_url`` to ensure that DNS resolution is done from the remote
server instead of client-side when connecting to a domain name.
SOCKS4 supports IPv4 and domain names with the SOCKS4A extension. SOCKS5
supports IPv4, IPv6, and domain names.
When connecting to a SOCKS4 proxy the ``username`` portion of the ``proxy_url``
will be sent as the ``userid`` section of the SOCKS request:
.. code-block:: python
proxy_url="socks4a://<userid>@proxy-host"
When connecting to a SOCKS5 proxy the ``username`` and ``password`` portion
of the ``proxy_url`` will be sent as the username/password to authenticate
with the proxy:
.. code-block:: python
proxy_url="socks5h://<username>:<password>@proxy-host"
"""
from __future__ import absolute_import
try:
import socks
except ImportError:
import warnings
from ..exceptions import DependencyWarning
warnings.warn(
(
"SOCKS support in urllib3 requires the installation of optional "
"dependencies: specifically, PySocks. For more information, see "
"https://urllib3.readthedocs.io/en/1.26.x/contrib.html#socks-proxies"
),
DependencyWarning,
)
raise
from socket import error as SocketError
from socket import timeout as SocketTimeout
from ..connection import HTTPConnection, HTTPSConnection
from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool
from ..exceptions import ConnectTimeoutError, NewConnectionError
from ..poolmanager import PoolManager
from ..util.url import parse_url
try:
import ssl
except ImportError:
ssl = None
class SOCKSConnection(HTTPConnection):
"""
A plain-text HTTP connection that connects via a SOCKS proxy.
"""
def __init__(self, *args, **kwargs):
self._socks_options = kwargs.pop("_socks_options")
super(SOCKSConnection, self).__init__(*args, **kwargs)
def _new_conn(self):
"""
Establish a new connection via the SOCKS proxy.
"""
extra_kw = {}
if self.source_address:
extra_kw["source_address"] = self.source_address
if self.socket_options:
extra_kw["socket_options"] = self.socket_options
try:
conn = socks.create_connection(
(self.host, self.port),
proxy_type=self._socks_options["socks_version"],
proxy_addr=self._socks_options["proxy_host"],
proxy_port=self._socks_options["proxy_port"],
proxy_username=self._socks_options["username"],
proxy_password=self._socks_options["password"],
proxy_rdns=self._socks_options["rdns"],
timeout=self.timeout,
**extra_kw
)
except SocketTimeout:
raise ConnectTimeoutError(
self,
"Connection to %s timed out. (connect timeout=%s)"
% (self.host, self.timeout),
)
except socks.ProxyError as e:
# This is fragile as hell, but it seems to be the only way to raise
# useful errors here.
if e.socket_err:
error = e.socket_err
if isinstance(error, SocketTimeout):
raise ConnectTimeoutError(
self,
"Connection to %s timed out. (connect timeout=%s)"
% (self.host, self.timeout),
)
else:
raise NewConnectionError(
self, "Failed to establish a new connection: %s" % error
)
else:
raise NewConnectionError(
self, "Failed to establish a new connection: %s" % e
)
except SocketError as e: # Defensive: PySocks should catch all these.
raise NewConnectionError(
self, "Failed to establish a new connection: %s" % e
)
return conn
# We don't need to duplicate the Verified/Unverified distinction from
# urllib3/connection.py here because the HTTPSConnection will already have been
# correctly set to either the Verified or Unverified form by that module. This
# means the SOCKSHTTPSConnection will automatically be the correct type.
class SOCKSHTTPSConnection(SOCKSConnection, HTTPSConnection):
pass
class SOCKSHTTPConnectionPool(HTTPConnectionPool):
ConnectionCls = SOCKSConnection
class SOCKSHTTPSConnectionPool(HTTPSConnectionPool):
ConnectionCls = SOCKSHTTPSConnection
class SOCKSProxyManager(PoolManager):
"""
A version of the urllib3 ProxyManager that routes connections via the
defined SOCKS proxy.
"""
pool_classes_by_scheme = {
"http": SOCKSHTTPConnectionPool,
"https": SOCKSHTTPSConnectionPool,
}
def __init__(
self,
proxy_url,
username=None,
password=None,
num_pools=10,
headers=None,
**connection_pool_kw
):
parsed = parse_url(proxy_url)
if username is None and password is None and parsed.auth is not None:
split = parsed.auth.split(":")
if len(split) == 2:
username, password = split
if parsed.scheme == "socks5":
socks_version = socks.PROXY_TYPE_SOCKS5
rdns = False
elif parsed.scheme == "socks5h":
socks_version = socks.PROXY_TYPE_SOCKS5
rdns = True
elif parsed.scheme == "socks4":
socks_version = socks.PROXY_TYPE_SOCKS4
rdns = False
elif parsed.scheme == "socks4a":
socks_version = socks.PROXY_TYPE_SOCKS4
rdns = True
else:
raise ValueError("Unable to determine SOCKS version from %s" % proxy_url)
self.proxy_url = proxy_url
socks_options = {
"socks_version": socks_version,
"proxy_host": parsed.host,
"proxy_port": parsed.port,
"username": username,
"password": password,
"rdns": rdns,
}
connection_pool_kw["_socks_options"] = socks_options
super(SOCKSProxyManager, self).__init__(
num_pools, headers, **connection_pool_kw
)
self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme

View file

@ -1,323 +0,0 @@
from __future__ import absolute_import
from .packages.six.moves.http_client import IncompleteRead as httplib_IncompleteRead
# Base Exceptions
class HTTPError(Exception):
"""Base exception used by this module."""
pass
class HTTPWarning(Warning):
"""Base warning used by this module."""
pass
class PoolError(HTTPError):
"""Base exception for errors caused within a pool."""
def __init__(self, pool, message):
self.pool = pool
HTTPError.__init__(self, "%s: %s" % (pool, message))
def __reduce__(self):
# For pickling purposes.
return self.__class__, (None, None)
class RequestError(PoolError):
"""Base exception for PoolErrors that have associated URLs."""
def __init__(self, pool, url, message):
self.url = url
PoolError.__init__(self, pool, message)
def __reduce__(self):
# For pickling purposes.
return self.__class__, (None, self.url, None)
class SSLError(HTTPError):
"""Raised when SSL certificate fails in an HTTPS connection."""
pass
class ProxyError(HTTPError):
"""Raised when the connection to a proxy fails."""
def __init__(self, message, error, *args):
super(ProxyError, self).__init__(message, error, *args)
self.original_error = error
class DecodeError(HTTPError):
"""Raised when automatic decoding based on Content-Type fails."""
pass
class ProtocolError(HTTPError):
"""Raised when something unexpected happens mid-request/response."""
pass
#: Renamed to ProtocolError but aliased for backwards compatibility.
ConnectionError = ProtocolError
# Leaf Exceptions
class MaxRetryError(RequestError):
"""Raised when the maximum number of retries is exceeded.
:param pool: The connection pool
:type pool: :class:`~urllib3.connectionpool.HTTPConnectionPool`
:param string url: The requested Url
:param exceptions.Exception reason: The underlying error
"""
def __init__(self, pool, url, reason=None):
self.reason = reason
message = "Max retries exceeded with url: %s (Caused by %r)" % (url, reason)
RequestError.__init__(self, pool, url, message)
class HostChangedError(RequestError):
"""Raised when an existing pool gets a request for a foreign host."""
def __init__(self, pool, url, retries=3):
message = "Tried to open a foreign host with url: %s" % url
RequestError.__init__(self, pool, url, message)
self.retries = retries
class TimeoutStateError(HTTPError):
"""Raised when passing an invalid state to a timeout"""
pass
class TimeoutError(HTTPError):
"""Raised when a socket timeout error occurs.
Catching this error will catch both :exc:`ReadTimeoutErrors
<ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`.
"""
pass
class ReadTimeoutError(TimeoutError, RequestError):
"""Raised when a socket timeout occurs while receiving data from a server"""
pass
# This timeout error does not have a URL attached and needs to inherit from the
# base HTTPError
class ConnectTimeoutError(TimeoutError):
"""Raised when a socket timeout occurs while connecting to a server"""
pass
class NewConnectionError(ConnectTimeoutError, PoolError):
"""Raised when we fail to establish a new connection. Usually ECONNREFUSED."""
pass
class EmptyPoolError(PoolError):
"""Raised when a pool runs out of connections and no more are allowed."""
pass
class ClosedPoolError(PoolError):
"""Raised when a request enters a pool after the pool has been closed."""
pass
class LocationValueError(ValueError, HTTPError):
"""Raised when there is something wrong with a given URL input."""
pass
class LocationParseError(LocationValueError):
"""Raised when get_host or similar fails to parse the URL input."""
def __init__(self, location):
message = "Failed to parse: %s" % location
HTTPError.__init__(self, message)
self.location = location
class URLSchemeUnknown(LocationValueError):
"""Raised when a URL input has an unsupported scheme."""
def __init__(self, scheme):
message = "Not supported URL scheme %s" % scheme
super(URLSchemeUnknown, self).__init__(message)
self.scheme = scheme
class ResponseError(HTTPError):
"""Used as a container for an error reason supplied in a MaxRetryError."""
GENERIC_ERROR = "too many error responses"
SPECIFIC_ERROR = "too many {status_code} error responses"
class SecurityWarning(HTTPWarning):
"""Warned when performing security reducing actions"""
pass
class SubjectAltNameWarning(SecurityWarning):
"""Warned when connecting to a host with a certificate missing a SAN."""
pass
class InsecureRequestWarning(SecurityWarning):
"""Warned when making an unverified HTTPS request."""
pass
class SystemTimeWarning(SecurityWarning):
"""Warned when system time is suspected to be wrong"""
pass
class InsecurePlatformWarning(SecurityWarning):
"""Warned when certain TLS/SSL configuration is not available on a platform."""
pass
class SNIMissingWarning(HTTPWarning):
"""Warned when making a HTTPS request without SNI available."""
pass
class DependencyWarning(HTTPWarning):
"""
Warned when an attempt is made to import a module with missing optional
dependencies.
"""
pass
class ResponseNotChunked(ProtocolError, ValueError):
"""Response needs to be chunked in order to read it as chunks."""
pass
class BodyNotHttplibCompatible(HTTPError):
"""
Body should be :class:`http.client.HTTPResponse` like
(have an fp attribute which returns raw chunks) for read_chunked().
"""
pass
class IncompleteRead(HTTPError, httplib_IncompleteRead):
"""
Response length doesn't match expected Content-Length
Subclass of :class:`http.client.IncompleteRead` to allow int value
for ``partial`` to avoid creating large objects on streamed reads.
"""
def __init__(self, partial, expected):
super(IncompleteRead, self).__init__(partial, expected)
def __repr__(self):
return "IncompleteRead(%i bytes read, %i more expected)" % (
self.partial,
self.expected,
)
class InvalidChunkLength(HTTPError, httplib_IncompleteRead):
"""Invalid chunk length in a chunked response."""
def __init__(self, response, length):
super(InvalidChunkLength, self).__init__(
response.tell(), response.length_remaining
)
self.response = response
self.length = length
def __repr__(self):
return "InvalidChunkLength(got length %r, %i bytes read)" % (
self.length,
self.partial,
)
class InvalidHeader(HTTPError):
"""The header provided was somehow invalid."""
pass
class ProxySchemeUnknown(AssertionError, URLSchemeUnknown):
"""ProxyManager does not support the supplied scheme"""
# TODO(t-8ch): Stop inheriting from AssertionError in v2.0.
def __init__(self, scheme):
# 'localhost' is here because our URL parser parses
# localhost:8080 -> scheme=localhost, remove if we fix this.
if scheme == "localhost":
scheme = None
if scheme is None:
message = "Proxy URL had no scheme, should start with http:// or https://"
else:
message = (
"Proxy URL had unsupported scheme %s, should use http:// or https://"
% scheme
)
super(ProxySchemeUnknown, self).__init__(message)
class ProxySchemeUnsupported(ValueError):
"""Fetching HTTPS resources through HTTPS proxies is unsupported"""
pass
class HeaderParsingError(HTTPError):
"""Raised by assert_header_parsing, but we convert it to a log.warning statement."""
def __init__(self, defects, unparsed_data):
message = "%s, unparsed data: %r" % (defects or "Unknown", unparsed_data)
super(HeaderParsingError, self).__init__(message)
class UnrewindableBodyError(HTTPError):
"""urllib3 encountered an error when trying to rewind a body"""
pass

View file

@ -1,274 +0,0 @@
from __future__ import absolute_import
import email.utils
import mimetypes
import re
from .packages import six
def guess_content_type(filename, default="application/octet-stream"):
"""
Guess the "Content-Type" of a file.
:param filename:
The filename to guess the "Content-Type" of using :mod:`mimetypes`.
:param default:
If no "Content-Type" can be guessed, default to `default`.
"""
if filename:
return mimetypes.guess_type(filename)[0] or default
return default
def format_header_param_rfc2231(name, value):
"""
Helper function to format and quote a single header parameter using the
strategy defined in RFC 2231.
Particularly useful for header parameters which might contain
non-ASCII values, like file names. This follows
`RFC 2388 Section 4.4 <https://tools.ietf.org/html/rfc2388#section-4.4>`_.
:param name:
The name of the parameter, a string expected to be ASCII only.
:param value:
The value of the parameter, provided as ``bytes`` or `str``.
:ret:
An RFC-2231-formatted unicode string.
"""
if isinstance(value, six.binary_type):
value = value.decode("utf-8")
if not any(ch in value for ch in '"\\\r\n'):
result = u'%s="%s"' % (name, value)
try:
result.encode("ascii")
except (UnicodeEncodeError, UnicodeDecodeError):
pass
else:
return result
if six.PY2: # Python 2:
value = value.encode("utf-8")
# encode_rfc2231 accepts an encoded string and returns an ascii-encoded
# string in Python 2 but accepts and returns unicode strings in Python 3
value = email.utils.encode_rfc2231(value, "utf-8")
value = "%s*=%s" % (name, value)
if six.PY2: # Python 2:
value = value.decode("utf-8")
return value
_HTML5_REPLACEMENTS = {
u"\u0022": u"%22",
# Replace "\" with "\\".
u"\u005C": u"\u005C\u005C",
}
# All control characters from 0x00 to 0x1F *except* 0x1B.
_HTML5_REPLACEMENTS.update(
{
six.unichr(cc): u"%{:02X}".format(cc)
for cc in range(0x00, 0x1F + 1)
if cc not in (0x1B,)
}
)
def _replace_multiple(value, needles_and_replacements):
def replacer(match):
return needles_and_replacements[match.group(0)]
pattern = re.compile(
r"|".join([re.escape(needle) for needle in needles_and_replacements.keys()])
)
result = pattern.sub(replacer, value)
return result
def format_header_param_html5(name, value):
"""
Helper function to format and quote a single header parameter using the
HTML5 strategy.
Particularly useful for header parameters which might contain
non-ASCII values, like file names. This follows the `HTML5 Working Draft
Section 4.10.22.7`_ and matches the behavior of curl and modern browsers.
.. _HTML5 Working Draft Section 4.10.22.7:
https://w3c.github.io/html/sec-forms.html#multipart-form-data
:param name:
The name of the parameter, a string expected to be ASCII only.
:param value:
The value of the parameter, provided as ``bytes`` or `str``.
:ret:
A unicode string, stripped of troublesome characters.
"""
if isinstance(value, six.binary_type):
value = value.decode("utf-8")
value = _replace_multiple(value, _HTML5_REPLACEMENTS)
return u'%s="%s"' % (name, value)
# For backwards-compatibility.
format_header_param = format_header_param_html5
class RequestField(object):
"""
A data container for request body parameters.
:param name:
The name of this request field. Must be unicode.
:param data:
The data/value body.
:param filename:
An optional filename of the request field. Must be unicode.
:param headers:
An optional dict-like object of headers to initially use for the field.
:param header_formatter:
An optional callable that is used to encode and format the headers. By
default, this is :func:`format_header_param_html5`.
"""
def __init__(
self,
name,
data,
filename=None,
headers=None,
header_formatter=format_header_param_html5,
):
self._name = name
self._filename = filename
self.data = data
self.headers = {}
if headers:
self.headers = dict(headers)
self.header_formatter = header_formatter
@classmethod
def from_tuples(cls, fieldname, value, header_formatter=format_header_param_html5):
"""
A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters.
Supports constructing :class:`~urllib3.fields.RequestField` from
parameter of key/value strings AND key/filetuple. A filetuple is a
(filename, data, MIME type) tuple where the MIME type is optional.
For example::
'foo': 'bar',
'fakefile': ('foofile.txt', 'contents of foofile'),
'realfile': ('barfile.txt', open('realfile').read()),
'typedfile': ('bazfile.bin', open('bazfile').read(), 'image/jpeg'),
'nonamefile': 'contents of nonamefile field',
Field names and filenames must be unicode.
"""
if isinstance(value, tuple):
if len(value) == 3:
filename, data, content_type = value
else:
filename, data = value
content_type = guess_content_type(filename)
else:
filename = None
content_type = None
data = value
request_param = cls(
fieldname, data, filename=filename, header_formatter=header_formatter
)
request_param.make_multipart(content_type=content_type)
return request_param
def _render_part(self, name, value):
"""
Overridable helper function to format a single header parameter. By
default, this calls ``self.header_formatter``.
:param name:
The name of the parameter, a string expected to be ASCII only.
:param value:
The value of the parameter, provided as a unicode string.
"""
return self.header_formatter(name, value)
def _render_parts(self, header_parts):
"""
Helper function to format and quote a single header.
Useful for single headers that are composed of multiple items. E.g.,
'Content-Disposition' fields.
:param header_parts:
A sequence of (k, v) tuples or a :class:`dict` of (k, v) to format
as `k1="v1"; k2="v2"; ...`.
"""
parts = []
iterable = header_parts
if isinstance(header_parts, dict):
iterable = header_parts.items()
for name, value in iterable:
if value is not None:
parts.append(self._render_part(name, value))
return u"; ".join(parts)
def render_headers(self):
"""
Renders the headers for this request field.
"""
lines = []
sort_keys = ["Content-Disposition", "Content-Type", "Content-Location"]
for sort_key in sort_keys:
if self.headers.get(sort_key, False):
lines.append(u"%s: %s" % (sort_key, self.headers[sort_key]))
for header_name, header_value in self.headers.items():
if header_name not in sort_keys:
if header_value:
lines.append(u"%s: %s" % (header_name, header_value))
lines.append(u"\r\n")
return u"\r\n".join(lines)
def make_multipart(
self, content_disposition=None, content_type=None, content_location=None
):
"""
Makes this request field into a multipart request field.
This method overrides "Content-Disposition", "Content-Type" and
"Content-Location" headers to the request parameter.
:param content_type:
The 'Content-Type' of the request body.
:param content_location:
The 'Content-Location' of the request body.
"""
self.headers["Content-Disposition"] = content_disposition or u"form-data"
self.headers["Content-Disposition"] += u"; ".join(
[
u"",
self._render_parts(
((u"name", self._name), (u"filename", self._filename))
),
]
)
self.headers["Content-Type"] = content_type
self.headers["Content-Location"] = content_location

View file

@ -1,98 +0,0 @@
from __future__ import absolute_import
import binascii
import codecs
import os
from io import BytesIO
from .fields import RequestField
from .packages import six
from .packages.six import b
writer = codecs.lookup("utf-8")[3]
def choose_boundary():
"""
Our embarrassingly-simple replacement for mimetools.choose_boundary.
"""
boundary = binascii.hexlify(os.urandom(16))
if not six.PY2:
boundary = boundary.decode("ascii")
return boundary
def iter_field_objects(fields):
"""
Iterate over fields.
Supports list of (k, v) tuples and dicts, and lists of
:class:`~urllib3.fields.RequestField`.
"""
if isinstance(fields, dict):
i = six.iteritems(fields)
else:
i = iter(fields)
for field in i:
if isinstance(field, RequestField):
yield field
else:
yield RequestField.from_tuples(*field)
def iter_fields(fields):
"""
.. deprecated:: 1.6
Iterate over fields.
The addition of :class:`~urllib3.fields.RequestField` makes this function
obsolete. Instead, use :func:`iter_field_objects`, which returns
:class:`~urllib3.fields.RequestField` objects.
Supports list of (k, v) tuples and dicts.
"""
if isinstance(fields, dict):
return ((k, v) for k, v in six.iteritems(fields))
return ((k, v) for k, v in fields)
def encode_multipart_formdata(fields, boundary=None):
"""
Encode a dictionary of ``fields`` using the multipart/form-data MIME format.
:param fields:
Dictionary of fields or list of (key, :class:`~urllib3.fields.RequestField`).
:param boundary:
If not specified, then a random boundary will be generated using
:func:`urllib3.filepost.choose_boundary`.
"""
body = BytesIO()
if boundary is None:
boundary = choose_boundary()
for field in iter_field_objects(fields):
body.write(b("--%s\r\n" % (boundary)))
writer(body).write(field.render_headers())
data = field.data
if isinstance(data, int):
data = str(data) # Backwards compatibility
if isinstance(data, six.text_type):
writer(body).write(data)
else:
body.write(data)
body.write(b"\r\n")
body.write(b("--%s--\r\n" % (boundary)))
content_type = str("multipart/form-data; boundary=%s" % boundary)
return body.getvalue(), content_type

View file

@ -1,5 +0,0 @@
from __future__ import absolute_import
from . import ssl_match_hostname
__all__ = ("ssl_match_hostname",)

View file

@ -1,51 +0,0 @@
# -*- coding: utf-8 -*-
"""
backports.makefile
~~~~~~~~~~~~~~~~~~
Backports the Python 3 ``socket.makefile`` method for use with anything that
wants to create a "fake" socket object.
"""
import io
from socket import SocketIO
def backport_makefile(
self, mode="r", buffering=None, encoding=None, errors=None, newline=None
):
"""
Backport of ``socket.makefile`` from Python 3.5.
"""
if not set(mode) <= {"r", "w", "b"}:
raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))
writing = "w" in mode
reading = "r" in mode or not writing
assert reading or writing
binary = "b" in mode
rawmode = ""
if reading:
rawmode += "r"
if writing:
rawmode += "w"
raw = SocketIO(self, rawmode)
self._makefile_refs += 1
if buffering is None:
buffering = -1
if buffering < 0:
buffering = io.DEFAULT_BUFFER_SIZE
if buffering == 0:
if not binary:
raise ValueError("unbuffered streams must be binary")
return raw
if reading and writing:
buffer = io.BufferedRWPair(raw, raw, buffering)
elif reading:
buffer = io.BufferedReader(raw, buffering)
else:
assert writing
buffer = io.BufferedWriter(raw, buffering)
if binary:
return buffer
text = io.TextIOWrapper(buffer, encoding, errors, newline)
text.mode = mode
return text

File diff suppressed because it is too large Load diff

View file

@ -1,24 +0,0 @@
import sys
try:
# Our match_hostname function is the same as 3.10's, so we only want to
# import the match_hostname function if it's at least that good.
# We also fallback on Python 3.10+ because our code doesn't emit
# deprecation warnings and is the same as Python 3.10 otherwise.
if sys.version_info < (3, 5) or sys.version_info >= (3, 10):
raise ImportError("Fallback to vendored code")
from ssl import CertificateError, match_hostname
except ImportError:
try:
# Backport of the function from a pypi module
from backports.ssl_match_hostname import ( # type: ignore
CertificateError,
match_hostname,
)
except ImportError:
# Our vendored copy
from ._implementation import CertificateError, match_hostname # type: ignore
# Not needed, but documenting what we provide.
__all__ = ("CertificateError", "match_hostname")

View file

@ -1,160 +0,0 @@
"""The match_hostname() function from Python 3.3.3, essential when using SSL."""
# Note: This file is under the PSF license as the code comes from the python
# stdlib. http://docs.python.org/3/license.html
import re
import sys
# ipaddress has been backported to 2.6+ in pypi. If it is installed on the
# system, use it to handle IPAddress ServerAltnames (this was added in
# python-3.5) otherwise only do DNS matching. This allows
# backports.ssl_match_hostname to continue to be used in Python 2.7.
try:
import ipaddress
except ImportError:
ipaddress = None
__version__ = "3.5.0.1"
class CertificateError(ValueError):
pass
def _dnsname_match(dn, hostname, max_wildcards=1):
"""Matching according to RFC 6125, section 6.4.3
http://tools.ietf.org/html/rfc6125#section-6.4.3
"""
pats = []
if not dn:
return False
# Ported from python3-syntax:
# leftmost, *remainder = dn.split(r'.')
parts = dn.split(r".")
leftmost = parts[0]
remainder = parts[1:]
wildcards = leftmost.count("*")
if wildcards > max_wildcards:
# Issue #17980: avoid denials of service by refusing more
# than one wildcard per fragment. A survey of established
# policy among SSL implementations showed it to be a
# reasonable choice.
raise CertificateError(
"too many wildcards in certificate DNS name: " + repr(dn)
)
# speed up common case w/o wildcards
if not wildcards:
return dn.lower() == hostname.lower()
# RFC 6125, section 6.4.3, subitem 1.
# The client SHOULD NOT attempt to match a presented identifier in which
# the wildcard character comprises a label other than the left-most label.
if leftmost == "*":
# When '*' is a fragment by itself, it matches a non-empty dotless
# fragment.
pats.append("[^.]+")
elif leftmost.startswith("xn--") or hostname.startswith("xn--"):
# RFC 6125, section 6.4.3, subitem 3.
# The client SHOULD NOT attempt to match a presented identifier
# where the wildcard character is embedded within an A-label or
# U-label of an internationalized domain name.
pats.append(re.escape(leftmost))
else:
# Otherwise, '*' matches any dotless string, e.g. www*
pats.append(re.escape(leftmost).replace(r"\*", "[^.]*"))
# add the remaining fragments, ignore any wildcards
for frag in remainder:
pats.append(re.escape(frag))
pat = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE)
return pat.match(hostname)
def _to_unicode(obj):
if isinstance(obj, str) and sys.version_info < (3,):
obj = unicode(obj, encoding="ascii", errors="strict")
return obj
def _ipaddress_match(ipname, host_ip):
"""Exact matching of IP addresses.
RFC 6125 explicitly doesn't define an algorithm for this
(section 1.7.2 - "Out of Scope").
"""
# OpenSSL may add a trailing newline to a subjectAltName's IP address
# Divergence from upstream: ipaddress can't handle byte str
ip = ipaddress.ip_address(_to_unicode(ipname).rstrip())
return ip == host_ip
def match_hostname(cert, hostname):
"""Verify that *cert* (in decoded format as returned by
SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125
rules are followed, but IP addresses are not accepted for *hostname*.
CertificateError is raised on failure. On success, the function
returns nothing.
"""
if not cert:
raise ValueError(
"empty or no certificate, match_hostname needs a "
"SSL socket or SSL context with either "
"CERT_OPTIONAL or CERT_REQUIRED"
)
try:
# Divergence from upstream: ipaddress can't handle byte str
host_ip = ipaddress.ip_address(_to_unicode(hostname))
except ValueError:
# Not an IP address (common case)
host_ip = None
except UnicodeError:
# Divergence from upstream: Have to deal with ipaddress not taking
# byte strings. addresses should be all ascii, so we consider it not
# an ipaddress in this case
host_ip = None
except AttributeError:
# Divergence from upstream: Make ipaddress library optional
if ipaddress is None:
host_ip = None
else:
raise
dnsnames = []
san = cert.get("subjectAltName", ())
for key, value in san:
if key == "DNS":
if host_ip is None and _dnsname_match(value, hostname):
return
dnsnames.append(value)
elif key == "IP Address":
if host_ip is not None and _ipaddress_match(value, host_ip):
return
dnsnames.append(value)
if not dnsnames:
# The subject is only checked when there is no dNSName entry
# in subjectAltName
for sub in cert.get("subject", ()):
for key, value in sub:
# XXX according to RFC 2818, the most specific Common Name
# must be used.
if key == "commonName":
if _dnsname_match(value, hostname):
return
dnsnames.append(value)
if len(dnsnames) > 1:
raise CertificateError(
"hostname %r "
"doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames)))
)
elif len(dnsnames) == 1:
raise CertificateError("hostname %r doesn't match %r" % (hostname, dnsnames[0]))
else:
raise CertificateError(
"no appropriate commonName or subjectAltName fields were found"
)

View file

@ -1,536 +0,0 @@
from __future__ import absolute_import
import collections
import functools
import logging
from ._collections import RecentlyUsedContainer
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, port_by_scheme
from .exceptions import (
LocationValueError,
MaxRetryError,
ProxySchemeUnknown,
ProxySchemeUnsupported,
URLSchemeUnknown,
)
from .packages import six
from .packages.six.moves.urllib.parse import urljoin
from .request import RequestMethods
from .util.proxy import connection_requires_http_tunnel
from .util.retry import Retry
from .util.url import parse_url
__all__ = ["PoolManager", "ProxyManager", "proxy_from_url"]
log = logging.getLogger(__name__)
SSL_KEYWORDS = (
"key_file",
"cert_file",
"cert_reqs",
"ca_certs",
"ssl_version",
"ca_cert_dir",
"ssl_context",
"key_password",
)
# All known keyword arguments that could be provided to the pool manager, its
# pools, or the underlying connections. This is used to construct a pool key.
_key_fields = (
"key_scheme", # str
"key_host", # str
"key_port", # int
"key_timeout", # int or float or Timeout
"key_retries", # int or Retry
"key_strict", # bool
"key_block", # bool
"key_source_address", # str
"key_key_file", # str
"key_key_password", # str
"key_cert_file", # str
"key_cert_reqs", # str
"key_ca_certs", # str
"key_ssl_version", # str
"key_ca_cert_dir", # str
"key_ssl_context", # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext
"key_maxsize", # int
"key_headers", # dict
"key__proxy", # parsed proxy url
"key__proxy_headers", # dict
"key__proxy_config", # class
"key_socket_options", # list of (level (int), optname (int), value (int or str)) tuples
"key__socks_options", # dict
"key_assert_hostname", # bool or string
"key_assert_fingerprint", # str
"key_server_hostname", # str
)
#: The namedtuple class used to construct keys for the connection pool.
#: All custom key schemes should include the fields in this key at a minimum.
PoolKey = collections.namedtuple("PoolKey", _key_fields)
_proxy_config_fields = ("ssl_context", "use_forwarding_for_https")
ProxyConfig = collections.namedtuple("ProxyConfig", _proxy_config_fields)
def _default_key_normalizer(key_class, request_context):
"""
Create a pool key out of a request context dictionary.
According to RFC 3986, both the scheme and host are case-insensitive.
Therefore, this function normalizes both before constructing the pool
key for an HTTPS request. If you wish to change this behaviour, provide
alternate callables to ``key_fn_by_scheme``.
:param key_class:
The class to use when constructing the key. This should be a namedtuple
with the ``scheme`` and ``host`` keys at a minimum.
:type key_class: namedtuple
:param request_context:
A dictionary-like object that contain the context for a request.
:type request_context: dict
:return: A namedtuple that can be used as a connection pool key.
:rtype: PoolKey
"""
# Since we mutate the dictionary, make a copy first
context = request_context.copy()
context["scheme"] = context["scheme"].lower()
context["host"] = context["host"].lower()
# These are both dictionaries and need to be transformed into frozensets
for key in ("headers", "_proxy_headers", "_socks_options"):
if key in context and context[key] is not None:
context[key] = frozenset(context[key].items())
# The socket_options key may be a list and needs to be transformed into a
# tuple.
socket_opts = context.get("socket_options")
if socket_opts is not None:
context["socket_options"] = tuple(socket_opts)
# Map the kwargs to the names in the namedtuple - this is necessary since
# namedtuples can't have fields starting with '_'.
for key in list(context.keys()):
context["key_" + key] = context.pop(key)
# Default to ``None`` for keys missing from the context
for field in key_class._fields:
if field not in context:
context[field] = None
return key_class(**context)
#: A dictionary that maps a scheme to a callable that creates a pool key.
#: This can be used to alter the way pool keys are constructed, if desired.
#: Each PoolManager makes a copy of this dictionary so they can be configured
#: globally here, or individually on the instance.
key_fn_by_scheme = {
"http": functools.partial(_default_key_normalizer, PoolKey),
"https": functools.partial(_default_key_normalizer, PoolKey),
}
pool_classes_by_scheme = {"http": HTTPConnectionPool, "https": HTTPSConnectionPool}
class PoolManager(RequestMethods):
"""
Allows for arbitrary requests while transparently keeping track of
necessary connection pools for you.
:param num_pools:
Number of connection pools to cache before discarding the least
recently used pool.
:param headers:
Headers to include with all requests, unless other headers are given
explicitly.
:param \\**connection_pool_kw:
Additional parameters are used to create fresh
:class:`urllib3.connectionpool.ConnectionPool` instances.
Example::
>>> manager = PoolManager(num_pools=2)
>>> r = manager.request('GET', 'http://google.com/')
>>> r = manager.request('GET', 'http://google.com/mail')
>>> r = manager.request('GET', 'http://yahoo.com/')
>>> len(manager.pools)
2
"""
proxy = None
proxy_config = None
def __init__(self, num_pools=10, headers=None, **connection_pool_kw):
RequestMethods.__init__(self, headers)
self.connection_pool_kw = connection_pool_kw
self.pools = RecentlyUsedContainer(num_pools, dispose_func=lambda p: p.close())
# Locally set the pool classes and keys so other PoolManagers can
# override them.
self.pool_classes_by_scheme = pool_classes_by_scheme
self.key_fn_by_scheme = key_fn_by_scheme.copy()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.clear()
# Return False to re-raise any potential exceptions
return False
def _new_pool(self, scheme, host, port, request_context=None):
"""
Create a new :class:`urllib3.connectionpool.ConnectionPool` based on host, port, scheme, and
any additional pool keyword arguments.
If ``request_context`` is provided, it is provided as keyword arguments
to the pool class used. This method is used to actually create the
connection pools handed out by :meth:`connection_from_url` and
companion methods. It is intended to be overridden for customization.
"""
pool_cls = self.pool_classes_by_scheme[scheme]
if request_context is None:
request_context = self.connection_pool_kw.copy()
# Although the context has everything necessary to create the pool,
# this function has historically only used the scheme, host, and port
# in the positional args. When an API change is acceptable these can
# be removed.
for key in ("scheme", "host", "port"):
request_context.pop(key, None)
if scheme == "http":
for kw in SSL_KEYWORDS:
request_context.pop(kw, None)
return pool_cls(host, port, **request_context)
def clear(self):
"""
Empty our store of pools and direct them all to close.
This will not affect in-flight connections, but they will not be
re-used after completion.
"""
self.pools.clear()
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
"""
Get a :class:`urllib3.connectionpool.ConnectionPool` based on the host, port, and scheme.
If ``port`` isn't given, it will be derived from the ``scheme`` using
``urllib3.connectionpool.port_by_scheme``. If ``pool_kwargs`` is
provided, it is merged with the instance's ``connection_pool_kw``
variable and used to create the new connection pool, if one is
needed.
"""
if not host:
raise LocationValueError("No host specified.")
request_context = self._merge_pool_kwargs(pool_kwargs)
request_context["scheme"] = scheme or "http"
if not port:
port = port_by_scheme.get(request_context["scheme"].lower(), 80)
request_context["port"] = port
request_context["host"] = host
return self.connection_from_context(request_context)
def connection_from_context(self, request_context):
"""
Get a :class:`urllib3.connectionpool.ConnectionPool` based on the request context.
``request_context`` must at least contain the ``scheme`` key and its
value must be a key in ``key_fn_by_scheme`` instance variable.
"""
scheme = request_context["scheme"].lower()
pool_key_constructor = self.key_fn_by_scheme.get(scheme)
if not pool_key_constructor:
raise URLSchemeUnknown(scheme)
pool_key = pool_key_constructor(request_context)
return self.connection_from_pool_key(pool_key, request_context=request_context)
def connection_from_pool_key(self, pool_key, request_context=None):
"""
Get a :class:`urllib3.connectionpool.ConnectionPool` based on the provided pool key.
``pool_key`` should be a namedtuple that only contains immutable
objects. At a minimum it must have the ``scheme``, ``host``, and
``port`` fields.
"""
with self.pools.lock:
# If the scheme, host, or port doesn't match existing open
# connections, open a new ConnectionPool.
pool = self.pools.get(pool_key)
if pool:
return pool
# Make a fresh ConnectionPool of the desired type
scheme = request_context["scheme"]
host = request_context["host"]
port = request_context["port"]
pool = self._new_pool(scheme, host, port, request_context=request_context)
self.pools[pool_key] = pool
return pool
def connection_from_url(self, url, pool_kwargs=None):
"""
Similar to :func:`urllib3.connectionpool.connection_from_url`.
If ``pool_kwargs`` is not provided and a new pool needs to be
constructed, ``self.connection_pool_kw`` is used to initialize
the :class:`urllib3.connectionpool.ConnectionPool`. If ``pool_kwargs``
is provided, it is used instead. Note that if a new pool does not
need to be created for the request, the provided ``pool_kwargs`` are
not used.
"""
u = parse_url(url)
return self.connection_from_host(
u.host, port=u.port, scheme=u.scheme, pool_kwargs=pool_kwargs
)
def _merge_pool_kwargs(self, override):
"""
Merge a dictionary of override values for self.connection_pool_kw.
This does not modify self.connection_pool_kw and returns a new dict.
Any keys in the override dictionary with a value of ``None`` are
removed from the merged dictionary.
"""
base_pool_kwargs = self.connection_pool_kw.copy()
if override:
for key, value in override.items():
if value is None:
try:
del base_pool_kwargs[key]
except KeyError:
pass
else:
base_pool_kwargs[key] = value
return base_pool_kwargs
def _proxy_requires_url_absolute_form(self, parsed_url):
"""
Indicates if the proxy requires the complete destination URL in the
request. Normally this is only needed when not using an HTTP CONNECT
tunnel.
"""
if self.proxy is None:
return False
return not connection_requires_http_tunnel(
self.proxy, self.proxy_config, parsed_url.scheme
)
def _validate_proxy_scheme_url_selection(self, url_scheme):
"""
Validates that were not attempting to do TLS in TLS connections on
Python2 or with unsupported SSL implementations.
"""
if self.proxy is None or url_scheme != "https":
return
if self.proxy.scheme != "https":
return
if six.PY2 and not self.proxy_config.use_forwarding_for_https:
raise ProxySchemeUnsupported(
"Contacting HTTPS destinations through HTTPS proxies "
"'via CONNECT tunnels' is not supported in Python 2"
)
def urlopen(self, method, url, redirect=True, **kw):
"""
Same as :meth:`urllib3.HTTPConnectionPool.urlopen`
with custom cross-host redirect logic and only sends the request-uri
portion of the ``url``.
The given ``url`` parameter must be absolute, such that an appropriate
:class:`urllib3.connectionpool.ConnectionPool` can be chosen for it.
"""
u = parse_url(url)
self._validate_proxy_scheme_url_selection(u.scheme)
conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme)
kw["assert_same_host"] = False
kw["redirect"] = False
if "headers" not in kw:
kw["headers"] = self.headers.copy()
if self._proxy_requires_url_absolute_form(u):
response = conn.urlopen(method, url, **kw)
else:
response = conn.urlopen(method, u.request_uri, **kw)
redirect_location = redirect and response.get_redirect_location()
if not redirect_location:
return response
# Support relative URLs for redirecting.
redirect_location = urljoin(url, redirect_location)
# RFC 7231, Section 6.4.4
if response.status == 303:
method = "GET"
retries = kw.get("retries")
if not isinstance(retries, Retry):
retries = Retry.from_int(retries, redirect=redirect)
# Strip headers marked as unsafe to forward to the redirected location.
# Check remove_headers_on_redirect to avoid a potential network call within
# conn.is_same_host() which may use socket.gethostbyname() in the future.
if retries.remove_headers_on_redirect and not conn.is_same_host(
redirect_location
):
headers = list(six.iterkeys(kw["headers"]))
for header in headers:
if header.lower() in retries.remove_headers_on_redirect:
kw["headers"].pop(header, None)
try:
retries = retries.increment(method, url, response=response, _pool=conn)
except MaxRetryError:
if retries.raise_on_redirect:
response.drain_conn()
raise
return response
kw["retries"] = retries
kw["redirect"] = redirect
log.info("Redirecting %s -> %s", url, redirect_location)
response.drain_conn()
return self.urlopen(method, redirect_location, **kw)
class ProxyManager(PoolManager):
"""
Behaves just like :class:`PoolManager`, but sends all requests through
the defined proxy, using the CONNECT method for HTTPS URLs.
:param proxy_url:
The URL of the proxy to be used.
:param proxy_headers:
A dictionary containing headers that will be sent to the proxy. In case
of HTTP they are being sent with each request, while in the
HTTPS/CONNECT case they are sent only once. Could be used for proxy
authentication.
:param proxy_ssl_context:
The proxy SSL context is used to establish the TLS connection to the
proxy when using HTTPS proxies.
:param use_forwarding_for_https:
(Defaults to False) If set to True will forward requests to the HTTPS
proxy to be made on behalf of the client instead of creating a TLS
tunnel via the CONNECT method. **Enabling this flag means that request
and response headers and content will be visible from the HTTPS proxy**
whereas tunneling keeps request and response headers and content
private. IP address, target hostname, SNI, and port are always visible
to an HTTPS proxy even when this flag is disabled.
Example:
>>> proxy = urllib3.ProxyManager('http://localhost:3128/')
>>> r1 = proxy.request('GET', 'http://google.com/')
>>> r2 = proxy.request('GET', 'http://httpbin.org/')
>>> len(proxy.pools)
1
>>> r3 = proxy.request('GET', 'https://httpbin.org/')
>>> r4 = proxy.request('GET', 'https://twitter.com/')
>>> len(proxy.pools)
3
"""
def __init__(
self,
proxy_url,
num_pools=10,
headers=None,
proxy_headers=None,
proxy_ssl_context=None,
use_forwarding_for_https=False,
**connection_pool_kw
):
if isinstance(proxy_url, HTTPConnectionPool):
proxy_url = "%s://%s:%i" % (
proxy_url.scheme,
proxy_url.host,
proxy_url.port,
)
proxy = parse_url(proxy_url)
if proxy.scheme not in ("http", "https"):
raise ProxySchemeUnknown(proxy.scheme)
if not proxy.port:
port = port_by_scheme.get(proxy.scheme, 80)
proxy = proxy._replace(port=port)
self.proxy = proxy
self.proxy_headers = proxy_headers or {}
self.proxy_ssl_context = proxy_ssl_context
self.proxy_config = ProxyConfig(proxy_ssl_context, use_forwarding_for_https)
connection_pool_kw["_proxy"] = self.proxy
connection_pool_kw["_proxy_headers"] = self.proxy_headers
connection_pool_kw["_proxy_config"] = self.proxy_config
super(ProxyManager, self).__init__(num_pools, headers, **connection_pool_kw)
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
if scheme == "https":
return super(ProxyManager, self).connection_from_host(
host, port, scheme, pool_kwargs=pool_kwargs
)
return super(ProxyManager, self).connection_from_host(
self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs
)
def _set_proxy_headers(self, url, headers=None):
"""
Sets headers needed by proxies: specifically, the Accept and Host
headers. Only sets headers not provided by the user.
"""
headers_ = {"Accept": "*/*"}
netloc = parse_url(url).netloc
if netloc:
headers_["Host"] = netloc
if headers:
headers_.update(headers)
return headers_
def urlopen(self, method, url, redirect=True, **kw):
"Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute."
u = parse_url(url)
if not connection_requires_http_tunnel(self.proxy, self.proxy_config, u.scheme):
# For connections using HTTP CONNECT, httplib sets the necessary
# headers on the CONNECT to the proxy. If we're not using CONNECT,
# we'll definitely need to set 'Host' at the very least.
headers = kw.get("headers", self.headers)
kw["headers"] = self._set_proxy_headers(url, headers)
return super(ProxyManager, self).urlopen(method, url, redirect=redirect, **kw)
def proxy_from_url(url, **kw):
return ProxyManager(proxy_url=url, **kw)

View file

@ -1,170 +0,0 @@
from __future__ import absolute_import
from .filepost import encode_multipart_formdata
from .packages.six.moves.urllib.parse import urlencode
__all__ = ["RequestMethods"]
class RequestMethods(object):
"""
Convenience mixin for classes who implement a :meth:`urlopen` method, such
as :class:`urllib3.HTTPConnectionPool` and
:class:`urllib3.PoolManager`.
Provides behavior for making common types of HTTP request methods and
decides which type of request field encoding to use.
Specifically,
:meth:`.request_encode_url` is for sending requests whose fields are
encoded in the URL (such as GET, HEAD, DELETE).
:meth:`.request_encode_body` is for sending requests whose fields are
encoded in the *body* of the request using multipart or www-form-urlencoded
(such as for POST, PUT, PATCH).
:meth:`.request` is for making any kind of request, it will look up the
appropriate encoding format and use one of the above two methods to make
the request.
Initializer parameters:
:param headers:
Headers to include with all requests, unless other headers are given
explicitly.
"""
_encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"}
def __init__(self, headers=None):
self.headers = headers or {}
def urlopen(
self,
method,
url,
body=None,
headers=None,
encode_multipart=True,
multipart_boundary=None,
**kw
): # Abstract
raise NotImplementedError(
"Classes extending RequestMethods must implement "
"their own ``urlopen`` method."
)
def request(self, method, url, fields=None, headers=None, **urlopen_kw):
"""
Make a request using :meth:`urlopen` with the appropriate encoding of
``fields`` based on the ``method`` used.
This is a convenience method that requires the least amount of manual
effort. It can be used in most situations, while still having the
option to drop down to more specific methods when necessary, such as
:meth:`request_encode_url`, :meth:`request_encode_body`,
or even the lowest level :meth:`urlopen`.
"""
method = method.upper()
urlopen_kw["request_url"] = url
if method in self._encode_url_methods:
return self.request_encode_url(
method, url, fields=fields, headers=headers, **urlopen_kw
)
else:
return self.request_encode_body(
method, url, fields=fields, headers=headers, **urlopen_kw
)
def request_encode_url(self, method, url, fields=None, headers=None, **urlopen_kw):
"""
Make a request using :meth:`urlopen` with the ``fields`` encoded in
the url. This is useful for request methods like GET, HEAD, DELETE, etc.
"""
if headers is None:
headers = self.headers
extra_kw = {"headers": headers}
extra_kw.update(urlopen_kw)
if fields:
url += "?" + urlencode(fields)
return self.urlopen(method, url, **extra_kw)
def request_encode_body(
self,
method,
url,
fields=None,
headers=None,
encode_multipart=True,
multipart_boundary=None,
**urlopen_kw
):
"""
Make a request using :meth:`urlopen` with the ``fields`` encoded in
the body. This is useful for request methods like POST, PUT, PATCH, etc.
When ``encode_multipart=True`` (default), then
:func:`urllib3.encode_multipart_formdata` is used to encode
the payload with the appropriate content type. Otherwise
:func:`urllib.parse.urlencode` is used with the
'application/x-www-form-urlencoded' content type.
Multipart encoding must be used when posting files, and it's reasonably
safe to use it in other times too. However, it may break request
signing, such as with OAuth.
Supports an optional ``fields`` parameter of key/value strings AND
key/filetuple. A filetuple is a (filename, data, MIME type) tuple where
the MIME type is optional. For example::
fields = {
'foo': 'bar',
'fakefile': ('foofile.txt', 'contents of foofile'),
'realfile': ('barfile.txt', open('realfile').read()),
'typedfile': ('bazfile.bin', open('bazfile').read(),
'image/jpeg'),
'nonamefile': 'contents of nonamefile field',
}
When uploading a file, providing a filename (the first parameter of the
tuple) is optional but recommended to best mimic behavior of browsers.
Note that if ``headers`` are supplied, the 'Content-Type' header will
be overwritten because it depends on the dynamic random boundary string
which is used to compose the body of the request. The random boundary
string can be explicitly set with the ``multipart_boundary`` parameter.
"""
if headers is None:
headers = self.headers
extra_kw = {"headers": {}}
if fields:
if "body" in urlopen_kw:
raise TypeError(
"request got values for both 'fields' and 'body', can only specify one."
)
if encode_multipart:
body, content_type = encode_multipart_formdata(
fields, boundary=multipart_boundary
)
else:
body, content_type = (
urlencode(fields),
"application/x-www-form-urlencoded",
)
extra_kw["body"] = body
extra_kw["headers"] = {"Content-Type": content_type}
extra_kw["headers"].update(headers)
extra_kw.update(urlopen_kw)
return self.urlopen(method, url, **extra_kw)

Some files were not shown because too many files have changed in this diff Show more