added websocket, socketio and engineio into python 2 vendor

Jakub Trllo 2022-06-08 16:17:03 +02:00
parent 5f3395683a
commit 1494c09d2a
72 changed files with 12949 additions and 0 deletions

View file

@@ -0,0 +1,25 @@
import sys
from .client import Client
from .middleware import WSGIApp, Middleware
from .server import Server
if sys.version_info >= (3, 5): # pragma: no cover
from .asyncio_server import AsyncServer
from .asyncio_client import AsyncClient
from .async_drivers.asgi import ASGIApp
try:
from .async_drivers.tornado import get_tornado_handler
except ImportError:
get_tornado_handler = None
else: # pragma: no cover
AsyncServer = None
AsyncClient = None
get_tornado_handler = None
ASGIApp = None
__version__ = '3.8.2.post1'
__all__ = ['__version__', 'Server', 'WSGIApp', 'Middleware', 'Client']
if AsyncServer is not None: # pragma: no cover
__all__ += ['AsyncServer', 'ASGIApp', 'get_tornado_handler',
'AsyncClient']
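
Illustrative usage of the package exported above (a minimal sketch, not part of the vendored sources; it assumes the vendored package is importable as engineio and that a plain WSGI callable is available to wrap):

import engineio

eio = engineio.Server(async_mode='threading')

def plain_wsgi_app(environ, start_response):
    # handles any traffic that is not addressed to the Engine.IO endpoint
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [b'Hello from the wrapped WSGI application']

# WSGIApp dispatches /engine.io/ requests to eio and everything else to plain_wsgi_app
app = engineio.WSGIApp(eio, plain_wsgi_app)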

View file

@@ -0,0 +1,129 @@
import asyncio
import sys
from urllib.parse import urlsplit
from aiohttp.web import Response, WebSocketResponse
import six
def create_route(app, engineio_server, engineio_endpoint):
"""This function sets up the engine.io endpoint as a route for the
application.
Note that both GET and POST requests must be hooked up on the engine.io
endpoint.
"""
app.router.add_get(engineio_endpoint, engineio_server.handle_request)
app.router.add_post(engineio_endpoint, engineio_server.handle_request)
app.router.add_route('OPTIONS', engineio_endpoint,
engineio_server.handle_request)
def translate_request(request):
"""This function takes the arguments passed to the request handler and
uses them to generate a WSGI compatible environ dictionary.
"""
message = request._message
payload = request._payload
uri_parts = urlsplit(message.path)
environ = {
'wsgi.input': payload,
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'aiohttp',
'REQUEST_METHOD': message.method,
'QUERY_STRING': uri_parts.query or '',
'RAW_URI': message.path,
'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version,
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'aiohttp',
'SERVER_PORT': '0',
'aiohttp.request': request
}
for hdr_name, hdr_value in message.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = '%s,%s' % (environ[key], hdr_value)
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
path_info = uri_parts.path
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = ''
return environ
def make_response(status, headers, payload, environ):
"""This function generates an appropriate response object for this async
mode.
"""
return Response(body=payload, status=int(status.split()[0]),
headers=headers)
class WebSocket(object): # pragma: no cover
"""
This wrapper class provides an aiohttp WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler):
self.handler = handler
self._sock = None
async def __call__(self, environ):
request = environ['aiohttp.request']
self._sock = WebSocketResponse()
await self._sock.prepare(request)
self.environ = environ
await self.handler(self)
return self._sock
async def close(self):
await self._sock.close()
async def send(self, message):
if isinstance(message, bytes):
f = self._sock.send_bytes
else:
f = self._sock.send_str
if asyncio.iscoroutinefunction(f):
await f(message)
else:
f(message)
async def wait(self):
msg = await self._sock.receive()
if not isinstance(msg.data, six.binary_type) and \
not isinstance(msg.data, six.text_type):
raise IOError()
return msg.data
_async = {
'asyncio': True,
'create_route': create_route,
'translate_request': translate_request,
'make_response': make_response,
'websocket': sys.modules[__name__],
'websocket_class': 'WebSocket'
}
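
A minimal sketch of how this aiohttp driver is typically selected (it assumes the vendored engineio package and aiohttp are importable; AsyncServer.attach() ends up calling the create_route() function defined above):

import engineio
from aiohttp import web

eio = engineio.AsyncServer(async_mode='aiohttp')
app = web.Application()
eio.attach(app)  # registers GET/POST/OPTIONS handlers on the /engine.io/ endpoint

if __name__ == '__main__':
    web.run_app(app, port=5000)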

View file

@@ -0,0 +1,223 @@
import sys
class ASGIApp:
"""ASGI application middleware for Engine.IO.
This middleware dispatches traffic to an Engine.IO application. It can
also serve a list of static files to the client, or forward unrelated
HTTP traffic to another ASGI application.
:param engineio_server: The Engine.IO server. Must be an instance of the
``engineio.AsyncServer`` class.
:param static_files: A dictionary where the keys are URLs that should be
served as static files. For each URL, the value is
a dictionary with ``content_type`` and ``filename``
keys. This option is intended to be used for serving
client files during development.
:param other_asgi_app: A separate ASGI app that receives all other traffic.
:param engineio_path: The endpoint where the Engine.IO application should
be installed. The default value is appropriate for
most cases.
Example usage::
import engineio
import uvicorn
eio = engineio.AsyncServer()
app = engineio.ASGIApp(eio, static_files={
'/': {'content_type': 'text/html', 'filename': 'index.html'},
'/index.html': {'content_type': 'text/html',
'filename': 'index.html'},
})
uvicorn.run(app, '127.0.0.1', 5000)
"""
def __init__(self, engineio_server, other_asgi_app=None,
static_files=None, engineio_path='engine.io'):
self.engineio_server = engineio_server
self.other_asgi_app = other_asgi_app
self.engineio_path = engineio_path.strip('/')
self.static_files = static_files or {}
def __call__(self, scope):
if scope['type'] in ['http', 'websocket'] and \
scope['path'].startswith('/{0}/'.format(self.engineio_path)):
return self.engineio_asgi_app(scope)
elif scope['type'] == 'http' and scope['path'] in self.static_files:
return self.serve_static_file(scope)
elif self.other_asgi_app is not None:
return self.other_asgi_app(scope)
elif scope['type'] == 'lifespan':
return self.lifespan
else:
return self.not_found
def engineio_asgi_app(self, scope):
async def _app(receive, send):
await self.engineio_server.handle_request(scope, receive, send)
return _app
def serve_static_file(self, scope):
async def _send_static_file(receive, send): # pragma: no cover
event = await receive()
if event['type'] == 'http.request':
if scope['path'] in self.static_files:
content_type = self.static_files[scope['path']][
'content_type'].encode('utf-8')
filename = self.static_files[scope['path']]['filename']
status_code = 200
with open(filename, 'rb') as f:
payload = f.read()
else:
content_type = b'text/plain'
status_code = 404
payload = b'not found'
await send({'type': 'http.response.start',
'status': status_code,
'headers': [(b'Content-Type', content_type)]})
await send({'type': 'http.response.body',
'body': payload})
return _send_static_file
async def lifespan(self, receive, send):
event = await receive()
if event['type'] == 'lifespan.startup':
await send({'type': 'lifespan.startup.complete'})
elif event['type'] == 'lifespan.shutdown':
await send({'type': 'lifespan.shutdown.complete'})
async def not_found(self, receive, send):
"""Return a 404 Not Found error to the client."""
await send({'type': 'http.response.start',
'status': 404,
'headers': [(b'Content-Type', b'text/plain')]})
await send({'type': 'http.response.body',
'body': b'not found'})
async def translate_request(scope, receive, send):
class AwaitablePayload(object): # pragma: no cover
def __init__(self, payload):
self.payload = payload or b''
async def read(self, length=None):
if length is None:
r = self.payload
self.payload = b''
else:
r = self.payload[:length]
self.payload = self.payload[length:]
return r
event = await receive()
payload = b''
if event['type'] == 'http.request':
payload += event.get('body') or b''
while event.get('more_body'):
event = await receive()
if event['type'] == 'http.request':
payload += event.get('body') or b''
elif event['type'] == 'websocket.connect':
await send({'type': 'websocket.accept'})
else:
return {}
raw_uri = scope['path'].encode('utf-8')
if 'query_string' in scope and scope['query_string']:
raw_uri += b'?' + scope['query_string']
environ = {
'wsgi.input': AwaitablePayload(payload),
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'asgi',
'REQUEST_METHOD': scope.get('method', 'GET'),
'PATH_INFO': scope['path'],
'QUERY_STRING': scope.get('query_string', b'').decode('utf-8'),
'RAW_URI': raw_uri.decode('utf-8'),
'SCRIPT_NAME': '',
'SERVER_PROTOCOL': 'HTTP/1.1',
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'asgi',
'SERVER_PORT': '0',
'asgi.receive': receive,
'asgi.send': send,
}
for hdr_name, hdr_value in scope['headers']:
hdr_name = hdr_name.upper().decode('utf-8')
hdr_value = hdr_value.decode('utf-8')
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = '%s,%s' % (environ[key], hdr_value)
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
return environ
async def make_response(status, headers, payload, environ):
headers = [(h[0].encode('utf-8'), h[1].encode('utf-8')) for h in headers]
await environ['asgi.send']({'type': 'http.response.start',
'status': int(status.split(' ')[0]),
'headers': headers})
await environ['asgi.send']({'type': 'http.response.body',
'body': payload})
class WebSocket(object): # pragma: no cover
"""
This wrapper class provides an asgi WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler):
self.handler = handler
self.asgi_receive = None
self.asgi_send = None
async def __call__(self, environ):
self.asgi_receive = environ['asgi.receive']
self.asgi_send = environ['asgi.send']
await self.handler(self)
async def close(self):
await self.asgi_send({'type': 'websocket.close'})
async def send(self, message):
msg_bytes = None
msg_text = None
if isinstance(message, bytes):
msg_bytes = message
else:
msg_text = message
await self.asgi_send({'type': 'websocket.send',
'bytes': msg_bytes,
'text': msg_text})
async def wait(self):
event = await self.asgi_receive()
if event['type'] != 'websocket.receive':
raise IOError()
return event.get('bytes') or event.get('text')
_async = {
'asyncio': True,
'translate_request': translate_request,
'make_response': make_response,
'websocket': sys.modules[__name__],
'websocket_class': 'WebSocket'
}

View file

@@ -0,0 +1,128 @@
import asyncio
import sys
from urllib.parse import urlsplit
from aiohttp.web import Response, WebSocketResponse
import six
def create_route(app, engineio_server, engineio_endpoint):
"""This function sets up the engine.io endpoint as a route for the
application.
Note that both GET and POST requests must be hooked up on the engine.io
endpoint.
"""
app.router.add_get(engineio_endpoint, engineio_server.handle_request)
app.router.add_post(engineio_endpoint, engineio_server.handle_request)
app.router.add_route('OPTIONS', engineio_endpoint,
engineio_server.handle_request)
def translate_request(request):
"""This function takes the arguments passed to the request handler and
uses them to generate a WSGI compatible environ dictionary.
"""
message = request._message
payload = request._payload
uri_parts = urlsplit(message.path)
environ = {
'wsgi.input': payload,
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'aiohttp',
'REQUEST_METHOD': message.method,
'QUERY_STRING': uri_parts.query or '',
'RAW_URI': message.path,
'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version,
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'aiohttp',
'SERVER_PORT': '0',
'aiohttp.request': request
}
for hdr_name, hdr_value in message.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = '%s,%s' % (environ[key], hdr_value)
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
path_info = uri_parts.path
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = ''
return environ
def make_response(status, headers, payload, environ):
"""This function generates an appropriate response object for this async
mode.
"""
return Response(body=payload, status=int(status.split()[0]),
headers=headers)
class WebSocket(object): # pragma: no cover
"""
This wrapper class provides an aiohttp WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler):
self.handler = handler
self._sock = None
async def __call__(self, environ):
request = environ['aiohttp.request']
self._sock = WebSocketResponse()
await self._sock.prepare(request)
self.environ = environ
await self.handler(self)
return self._sock
async def close(self):
await self._sock.close()
async def send(self, message):
if isinstance(message, bytes):
f = self._sock.send_bytes
else:
f = self._sock.send_str
if asyncio.iscoroutinefunction(f):
await f(message)
else:
f(message)
async def wait(self):
msg = await self._sock.receive()
if not isinstance(msg.data, six.binary_type) and \
not isinstance(msg.data, six.text_type):
raise IOError()
return msg.data
_async = {
'asyncio': True,
'create_route': create_route,
'translate_request': translate_request,
'make_response': make_response,
'websocket': WebSocket,
}

View file

@@ -0,0 +1,214 @@
import os
import sys
from engineio.static_files import get_static_file
class ASGIApp:
"""ASGI application middleware for Engine.IO.
This middleware dispatches traffic to an Engine.IO application. It can
also serve a list of static files to the client, or forward unrelated
HTTP traffic to another ASGI application.
:param engineio_server: The Engine.IO server. Must be an instance of the
``engineio.AsyncServer`` class.
:param static_files: A dictionary with static file mapping rules. See the
documentation for details on this argument.
:param other_asgi_app: A separate ASGI app that receives all other traffic.
:param engineio_path: The endpoint where the Engine.IO application should
be installed. The default value is appropriate for
most cases.
Example usage::
import engineio
import uvicorn
eio = engineio.AsyncServer()
app = engineio.ASGIApp(eio, static_files={
'/': {'content_type': 'text/html', 'filename': 'index.html'},
'/index.html': {'content_type': 'text/html',
'filename': 'index.html'},
})
uvicorn.run(app, '127.0.0.1', 5000)
"""
def __init__(self, engineio_server, other_asgi_app=None,
static_files=None, engineio_path='engine.io'):
self.engineio_server = engineio_server
self.other_asgi_app = other_asgi_app
self.engineio_path = engineio_path.strip('/')
self.static_files = static_files or {}
async def __call__(self, scope, receive, send):
if scope['type'] in ['http', 'websocket'] and \
scope['path'].startswith('/{0}/'.format(self.engineio_path)):
await self.engineio_server.handle_request(scope, receive, send)
else:
static_file = get_static_file(scope['path'], self.static_files) \
if scope['type'] == 'http' and self.static_files else None
if static_file:
await self.serve_static_file(static_file, receive, send)
elif self.other_asgi_app is not None:
await self.other_asgi_app(scope, receive, send)
elif scope['type'] == 'lifespan':
await self.lifespan(receive, send)
else:
await self.not_found(receive, send)
async def serve_static_file(self, static_file, receive,
send): # pragma: no cover
event = await receive()
if event['type'] == 'http.request':
if os.path.exists(static_file['filename']):
with open(static_file['filename'], 'rb') as f:
payload = f.read()
await send({'type': 'http.response.start',
'status': 200,
'headers': [(b'Content-Type', static_file[
'content_type'].encode('utf-8'))]})
await send({'type': 'http.response.body',
'body': payload})
else:
await self.not_found(receive, send)
async def lifespan(self, receive, send):
event = await receive()
if event['type'] == 'lifespan.startup':
await send({'type': 'lifespan.startup.complete'})
elif event['type'] == 'lifespan.shutdown':
await send({'type': 'lifespan.shutdown.complete'})
async def not_found(self, receive, send):
"""Return a 404 Not Found error to the client."""
await send({'type': 'http.response.start',
'status': 404,
'headers': [(b'Content-Type', b'text/plain')]})
await send({'type': 'http.response.body',
'body': b'Not Found'})
async def translate_request(scope, receive, send):
class AwaitablePayload(object): # pragma: no cover
def __init__(self, payload):
self.payload = payload or b''
async def read(self, length=None):
if length is None:
r = self.payload
self.payload = b''
else:
r = self.payload[:length]
self.payload = self.payload[length:]
return r
event = await receive()
payload = b''
if event['type'] == 'http.request':
payload += event.get('body') or b''
while event.get('more_body'):
event = await receive()
if event['type'] == 'http.request':
payload += event.get('body') or b''
elif event['type'] == 'websocket.connect':
await send({'type': 'websocket.accept'})
else:
return {}
raw_uri = scope['path'].encode('utf-8')
if 'query_string' in scope and scope['query_string']:
raw_uri += b'?' + scope['query_string']
environ = {
'wsgi.input': AwaitablePayload(payload),
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'asgi',
'REQUEST_METHOD': scope.get('method', 'GET'),
'PATH_INFO': scope['path'],
'QUERY_STRING': scope.get('query_string', b'').decode('utf-8'),
'RAW_URI': raw_uri.decode('utf-8'),
'SCRIPT_NAME': '',
'SERVER_PROTOCOL': 'HTTP/1.1',
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'asgi',
'SERVER_PORT': '0',
'asgi.receive': receive,
'asgi.send': send,
}
for hdr_name, hdr_value in scope['headers']:
hdr_name = hdr_name.upper().decode('utf-8')
hdr_value = hdr_value.decode('utf-8')
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = '%s,%s' % (environ[key], hdr_value)
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
return environ
async def make_response(status, headers, payload, environ):
headers = [(h[0].encode('utf-8'), h[1].encode('utf-8')) for h in headers]
await environ['asgi.send']({'type': 'http.response.start',
'status': int(status.split(' ')[0]),
'headers': headers})
await environ['asgi.send']({'type': 'http.response.body',
'body': payload})
class WebSocket(object): # pragma: no cover
"""
This wrapper class provides an asgi WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler):
self.handler = handler
self.asgi_receive = None
self.asgi_send = None
async def __call__(self, environ):
self.asgi_receive = environ['asgi.receive']
self.asgi_send = environ['asgi.send']
await self.handler(self)
async def close(self):
await self.asgi_send({'type': 'websocket.close'})
async def send(self, message):
msg_bytes = None
msg_text = None
if isinstance(message, bytes):
msg_bytes = message
else:
msg_text = message
await self.asgi_send({'type': 'websocket.send',
'bytes': msg_bytes,
'text': msg_text})
async def wait(self):
event = await self.asgi_receive()
if event['type'] != 'websocket.receive':
raise IOError()
return event.get('bytes') or event.get('text')
_async = {
'asyncio': True,
'translate_request': translate_request,
'make_response': make_response,
'websocket': WebSocket,
}
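
A small harness showing how an ASGI server drives this middleware for a request that matches neither the Engine.IO endpoint nor a static file (a sketch only; it assumes the module above is importable as engineio.async_drivers.asgi and exercises the not_found path, so no real Engine.IO server is needed):

import asyncio
from engineio.async_drivers.asgi import ASGIApp

async def main():
    app = ASGIApp(engineio_server=None)  # only the 404 branch is exercised here
    sent = []

    async def receive():
        return {'type': 'http.request', 'body': b'', 'more_body': False}

    async def send(event):
        sent.append(event)

    await app({'type': 'http', 'path': '/missing', 'headers': []}, receive, send)
    # sent now holds 'http.response.start' with status 404 followed by 'http.response.body'
    print(sent)

asyncio.run(main())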

View file

@@ -0,0 +1,30 @@
from __future__ import absolute_import
from eventlet.green.threading import Thread, Event
from eventlet import queue
from eventlet import sleep
from eventlet.websocket import WebSocketWSGI as _WebSocketWSGI
class WebSocketWSGI(_WebSocketWSGI):
def __init__(self, *args, **kwargs):
super(WebSocketWSGI, self).__init__(*args, **kwargs)
self._sock = None
def __call__(self, environ, start_response):
if 'eventlet.input' not in environ:
raise RuntimeError('You need to use the eventlet server. '
'See the Deployment section of the '
'documentation for more information.')
self._sock = environ['eventlet.input'].get_socket()
return super(WebSocketWSGI, self).__call__(environ, start_response)
_async = {
'thread': Thread,
'queue': queue.Queue,
'queue_empty': queue.Empty,
'event': Event,
'websocket': WebSocketWSGI,
'sleep': sleep,
}
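
Typical deployment sketch for this eventlet driver (it assumes the vendored engineio package and eventlet are installed; the WebSocketWSGI wrapper above requires the eventlet WSGI server):

import eventlet
import eventlet.wsgi
import engineio

eio = engineio.Server(async_mode='eventlet')
app = engineio.WSGIApp(eio)
eventlet.wsgi.server(eventlet.listen(('', 5000)), app)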

View file

@@ -0,0 +1,63 @@
from __future__ import absolute_import
import gevent
from gevent import queue
from gevent.event import Event
try:
import geventwebsocket # noqa
_websocket_available = True
except ImportError:
_websocket_available = False
class Thread(gevent.Greenlet): # pragma: no cover
"""
This wrapper class provides a gevent Greenlet interface that is compatible
with the standard library's Thread class.
"""
def __init__(self, target, args=[], kwargs={}):
super(Thread, self).__init__(target, *args, **kwargs)
def _run(self):
return self.run()
class WebSocketWSGI(object): # pragma: no cover
"""
This wrapper class provides a gevent WebSocket interface that is
compatible with eventlet's implementation.
"""
def __init__(self, app):
self.app = app
def __call__(self, environ, start_response):
if 'wsgi.websocket' not in environ:
raise RuntimeError('You need to use the gevent-websocket server. '
'See the Deployment section of the '
'documentation for more information.')
self._sock = environ['wsgi.websocket']
self.environ = environ
self.version = self._sock.version
self.path = self._sock.path
self.origin = self._sock.origin
self.protocol = self._sock.protocol
return self.app(self)
def close(self):
return self._sock.close()
def send(self, message):
return self._sock.send(message)
def wait(self):
return self._sock.receive()
_async = {
'thread': Thread,
'queue': queue.JoinableQueue,
'queue_empty': queue.Empty,
'event': Event,
'websocket': WebSocketWSGI if _websocket_available else None,
'sleep': gevent.sleep,
}
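
Deployment sketch for the gevent driver (it assumes gevent and the optional gevent-websocket package are installed; without gevent-websocket the driver falls back to long-polling only, as the _async dict above shows):

import engineio
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler  # provided by gevent-websocket

eio = engineio.Server(async_mode='gevent')
app = engineio.WSGIApp(eio)
pywsgi.WSGIServer(('', 5000), app, handler_class=WebSocketHandler).serve_forever()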

View file

@@ -0,0 +1,156 @@
from __future__ import absolute_import
import six
import gevent
from gevent import queue
from gevent.event import Event
import uwsgi
_websocket_available = hasattr(uwsgi, 'websocket_handshake')
class Thread(gevent.Greenlet): # pragma: no cover
"""
This wrapper class provides a gevent Greenlet interface that is compatible
with the standard library's Thread class.
"""
def __init__(self, target, args=[], kwargs={}):
super(Thread, self).__init__(target, *args, **kwargs)
def _run(self):
return self.run()
class uWSGIWebSocket(object): # pragma: no cover
"""
This wrapper class provides a uWSGI WebSocket interface that is
compatible with eventlet's implementation.
"""
def __init__(self, app):
self.app = app
self._sock = None
def __call__(self, environ, start_response):
self._sock = uwsgi.connection_fd()
self.environ = environ
uwsgi.websocket_handshake()
self._req_ctx = None
if hasattr(uwsgi, 'request_context'):
# uWSGI >= 2.1.x with support for API access across greenlets
self._req_ctx = uwsgi.request_context()
else:
# use event and queue for sending messages
from gevent.event import Event
from gevent.queue import Queue
from gevent.select import select
self._event = Event()
self._send_queue = Queue()
# spawn a select greenlet
def select_greenlet_runner(fd, event):
"""Sets event when data becomes available to read on fd."""
while True:
event.set()
try:
select([fd], [], [])[0]
except ValueError:
break
self._select_greenlet = gevent.spawn(
select_greenlet_runner,
self._sock,
self._event)
self.app(self)
def close(self):
"""Disconnects uWSGI from the client."""
uwsgi.disconnect()
if self._req_ctx is None:
# better kill it here in case wait() is not called again
self._select_greenlet.kill()
self._event.set()
def _send(self, msg):
"""Transmits message either in binary or UTF-8 text mode,
depending on its type."""
if isinstance(msg, six.binary_type):
method = uwsgi.websocket_send_binary
else:
method = uwsgi.websocket_send
if self._req_ctx is not None:
method(msg, request_context=self._req_ctx)
else:
method(msg)
def _decode_received(self, msg):
"""Returns either bytes or str, depending on message type."""
if not isinstance(msg, six.binary_type):
# already decoded - do nothing
return msg
# only decode from utf-8 if message is not binary data
type = six.byte2int(msg[0:1])
if type >= 48: # no binary
return msg.decode('utf-8')
# binary message, don't try to decode
return msg
def send(self, msg):
"""Queues a message for sending. Real transmission is done in
wait method.
Sends directly if uWSGI version is new enough."""
if self._req_ctx is not None:
self._send(msg)
else:
self._send_queue.put(msg)
self._event.set()
def wait(self):
"""Waits and returns received messages.
If running in compatibility mode for older uWSGI versions,
it also sends messages that have been queued by send().
A return value of None means that connection was closed.
This must be called repeatedly. For uWSGI < 2.1.x it must
be called from the main greenlet."""
while True:
if self._req_ctx is not None:
try:
msg = uwsgi.websocket_recv(request_context=self._req_ctx)
except IOError: # connection closed
return None
return self._decode_received(msg)
else:
# we wake up at least every 3 seconds to let uWSGI
# do its ping/ponging
event_set = self._event.wait(timeout=3)
if event_set:
self._event.clear()
# maybe there is something to send
msgs = []
while True:
try:
msgs.append(self._send_queue.get(block=False))
except gevent.queue.Empty:
break
for msg in msgs:
self._send(msg)
# maybe there is something to receive, if not, at least
# ensure uWSGI does its ping/ponging
try:
msg = uwsgi.websocket_recv_nb()
except IOError: # connection closed
self._select_greenlet.kill()
return None
if msg: # message available
return self._decode_received(msg)
_async = {
'thread': Thread,
'queue': queue.JoinableQueue,
'queue_empty': queue.Empty,
'event': Event,
'websocket': uWSGIWebSocket if _websocket_available else None,
'sleep': gevent.sleep,
}
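
Sketch for the gevent_uwsgi driver (the application module itself stays minimal; WebSocket support comes from uWSGI, so the process must be started by uWSGI with its gevent loop and native WebSocket support enabled, assuming a reasonably recent uWSGI build):

import engineio

# Started by uWSGI, for example:
#   uwsgi --http :5000 --gevent 1000 --http-websockets --wsgi-file app.py --callable app
eio = engineio.Server(async_mode='gevent_uwsgi')
app = engineio.WSGIApp(eio)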

View file

@@ -0,0 +1,144 @@
import sys
from urllib.parse import urlsplit
from sanic.response import HTTPResponse
try:
from sanic.websocket import WebSocketProtocol
except ImportError:
# the installed version of sanic does not have websocket support
WebSocketProtocol = None
import six
def create_route(app, engineio_server, engineio_endpoint):
"""This function sets up the engine.io endpoint as a route for the
application.
Note that both GET and POST requests must be hooked up on the engine.io
endpoint.
"""
app.add_route(engineio_server.handle_request, engineio_endpoint,
methods=['GET', 'POST', 'OPTIONS'])
try:
app.enable_websocket()
except AttributeError:
# ignore, this version does not support websocket
pass
def translate_request(request):
"""This function takes the arguments passed to the request handler and
uses them to generate a WSGI compatible environ dictionary.
"""
class AwaitablePayload(object):
def __init__(self, payload):
self.payload = payload or b''
async def read(self, length=None):
if length is None:
r = self.payload
self.payload = b''
else:
r = self.payload[:length]
self.payload = self.payload[length:]
return r
uri_parts = urlsplit(request.url)
environ = {
'wsgi.input': AwaitablePayload(request.body),
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'sanic',
'REQUEST_METHOD': request.method,
'QUERY_STRING': uri_parts.query or '',
'RAW_URI': request.url,
'SERVER_PROTOCOL': 'HTTP/' + request.version,
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'sanic',
'SERVER_PORT': '0',
'sanic.request': request
}
for hdr_name, hdr_value in request.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = '%s,%s' % (environ[key], hdr_value)
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
path_info = uri_parts.path
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = ''
return environ
def make_response(status, headers, payload, environ):
"""This function generates an appropriate response object for this async
mode.
"""
headers_dict = {}
content_type = None
for h in headers:
if h[0].lower() == 'content-type':
content_type = h[1]
else:
headers_dict[h[0]] = h[1]
return HTTPResponse(body_bytes=payload, content_type=content_type,
status=int(status.split()[0]), headers=headers_dict)
class WebSocket(object): # pragma: no cover
"""
This wrapper class provides a sanic WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler):
self.handler = handler
self._sock = None
async def __call__(self, environ):
request = environ['sanic.request']
protocol = request.transport.get_protocol()
self._sock = await protocol.websocket_handshake(request)
self.environ = environ
await self.handler(self)
async def close(self):
await self._sock.close()
async def send(self, message):
await self._sock.send(message)
async def wait(self):
data = await self._sock.recv()
if not isinstance(data, six.binary_type) and \
not isinstance(data, six.text_type):
raise IOError()
return data
_async = {
'asyncio': True,
'create_route': create_route,
'translate_request': translate_request,
'make_response': make_response,
'websocket': WebSocket if WebSocketProtocol else None,
}
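
Sketch for attaching an AsyncServer to a Sanic application (it assumes the vendored engineio package and a sanic version compatible with this driver; attach() wires up the create_route() function defined above):

import engineio
from sanic import Sanic

eio = engineio.AsyncServer(async_mode='sanic')
app = Sanic('engineio_example')  # newer Sanic versions require an application name
eio.attach(app)

if __name__ == '__main__':
    app.run(host='127.0.0.1', port=5000)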

View file

@@ -0,0 +1,17 @@
from __future__ import absolute_import
import threading
import time
try:
import queue
except ImportError: # pragma: no cover
import Queue as queue
_async = {
'thread': threading.Thread,
'queue': queue.Queue,
'queue_empty': queue.Empty,
'event': threading.Event,
'websocket': None,
'sleep': time.sleep,
}
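
Sketch for the threading fallback driver, which supports long-polling only ('websocket': None above); any multithreaded WSGI server works, and werkzeug's development server is used here purely as an assumed example:

import engineio
from werkzeug.serving import run_simple

eio = engineio.Server(async_mode='threading')
app = engineio.WSGIApp(eio)
run_simple('127.0.0.1', 5000, app, threaded=True)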

View file

@@ -0,0 +1,184 @@
import asyncio
import sys
from urllib.parse import urlsplit
from .. import exceptions
import tornado.web
import tornado.websocket
import six
def get_tornado_handler(engineio_server):
class Handler(tornado.websocket.WebSocketHandler): # pragma: no cover
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if isinstance(engineio_server.cors_allowed_origins,
six.string_types):
if engineio_server.cors_allowed_origins == '*':
self.allowed_origins = None
else:
self.allowed_origins = [
engineio_server.cors_allowed_origins]
else:
self.allowed_origins = engineio_server.cors_allowed_origins
self.receive_queue = asyncio.Queue()
async def get(self, *args, **kwargs):
if self.request.headers.get('Upgrade', '').lower() == 'websocket':
ret = super().get(*args, **kwargs)
if asyncio.iscoroutine(ret):
await ret
else:
await engineio_server.handle_request(self)
async def open(self, *args, **kwargs):
# this is the handler for the websocket request
asyncio.ensure_future(engineio_server.handle_request(self))
async def post(self, *args, **kwargs):
await engineio_server.handle_request(self)
async def options(self, *args, **kwargs):
await engineio_server.handle_request(self)
async def on_message(self, message):
await self.receive_queue.put(message)
async def get_next_message(self):
return await self.receive_queue.get()
def on_close(self):
self.receive_queue.put_nowait(None)
def check_origin(self, origin):
if self.allowed_origins is None or origin in self.allowed_origins:
return True
return super().check_origin(origin)
def get_compression_options(self):
# enable compression
return {}
return Handler
def translate_request(handler):
"""This function takes the arguments passed to the request handler and
uses them to generate a WSGI compatible environ dictionary.
"""
class AwaitablePayload(object):
def __init__(self, payload):
self.payload = payload or b''
async def read(self, length=None):
if length is None:
r = self.payload
self.payload = b''
else:
r = self.payload[:length]
self.payload = self.payload[length:]
return r
payload = handler.request.body
uri_parts = urlsplit(handler.request.path)
full_uri = handler.request.path
if handler.request.query: # pragma: no cover
full_uri += '?' + handler.request.query
environ = {
'wsgi.input': AwaitablePayload(payload),
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'aiohttp',
'REQUEST_METHOD': handler.request.method,
'QUERY_STRING': handler.request.query or '',
'RAW_URI': full_uri,
'SERVER_PROTOCOL': 'HTTP/%s' % handler.request.version,
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'aiohttp',
'SERVER_PORT': '0',
'tornado.handler': handler
}
for hdr_name, hdr_value in handler.request.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
path_info = uri_parts.path
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = ''
return environ
def make_response(status, headers, payload, environ):
"""This function generates an appropriate response object for this async
mode.
"""
tornado_handler = environ['tornado.handler']
try:
tornado_handler.set_status(int(status.split()[0]))
except RuntimeError: # pragma: no cover
# for websocket connections Tornado does not accept a response, since
# it already emitted the 101 status code
return
for header, value in headers:
tornado_handler.set_header(header, value)
tornado_handler.write(payload)
tornado_handler.finish()
class WebSocket(object): # pragma: no cover
"""
This wrapper class provides a tornado WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler):
self.handler = handler
self.tornado_handler = None
async def __call__(self, environ):
self.tornado_handler = environ['tornado.handler']
self.environ = environ
await self.handler(self)
async def close(self):
self.tornado_handler.close()
async def send(self, message):
try:
self.tornado_handler.write_message(
message, binary=isinstance(message, bytes))
except tornado.websocket.WebSocketClosedError:
raise exceptions.EngineIOError()
async def wait(self):
msg = await self.tornado_handler.get_next_message()
if not isinstance(msg, six.binary_type) and \
not isinstance(msg, six.text_type):
raise IOError()
return msg
_async = {
'asyncio': True,
'translate_request': translate_request,
'make_response': make_response,
'websocket': WebSocket,
}
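
Sketch for serving Engine.IO under Tornado with the handler factory above (it assumes the vendored engineio package and tornado are importable; get_tornado_handler() is the same factory exported from the package __init__):

import engineio
import tornado.ioloop
import tornado.web

eio = engineio.AsyncServer(async_mode='tornado')
app = tornado.web.Application([
    (r'/engine.io/', engineio.get_tornado_handler(eio)),
])
app.listen(5000)
tornado.ioloop.IOLoop.current().start()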

View file

@@ -0,0 +1,30 @@
import importlib
import sys
from eventlet import sleep
from eventlet.websocket import WebSocketWSGI as _WebSocketWSGI
class WebSocketWSGI(_WebSocketWSGI):
def __init__(self, *args, **kwargs):
super(WebSocketWSGI, self).__init__(*args, **kwargs)
self._sock = None
def __call__(self, environ, start_response):
if 'eventlet.input' not in environ:
raise RuntimeError('You need to use the eventlet server. '
'See the Deployment section of the '
'documentation for more information.')
self._sock = environ['eventlet.input'].get_socket()
return super(WebSocketWSGI, self).__call__(environ, start_response)
_async = {
'threading': importlib.import_module('eventlet.green.threading'),
'thread_class': 'Thread',
'queue': importlib.import_module('eventlet.queue'),
'queue_class': 'Queue',
'websocket': sys.modules[__name__],
'websocket_class': 'WebSocketWSGI',
'sleep': sleep
}

View file

@@ -0,0 +1,63 @@
import importlib
import sys
import gevent
try:
import geventwebsocket # noqa
_websocket_available = True
except ImportError:
_websocket_available = False
class Thread(gevent.Greenlet): # pragma: no cover
"""
This wrapper class provides a gevent Greenlet interface that is compatible
with the standard library's Thread class.
"""
def __init__(self, target, args=[], kwargs={}):
super(Thread, self).__init__(target, *args, **kwargs)
def _run(self):
return self.run()
class WebSocketWSGI(object): # pragma: no cover
"""
This wrapper class provides a gevent WebSocket interface that is
compatible with eventlet's implementation.
"""
def __init__(self, app):
self.app = app
def __call__(self, environ, start_response):
if 'wsgi.websocket' not in environ:
raise RuntimeError('You need to use the gevent-websocket server. '
'See the Deployment section of the '
'documentation for more information.')
self._sock = environ['wsgi.websocket']
self.environ = environ
self.version = self._sock.version
self.path = self._sock.path
self.origin = self._sock.origin
self.protocol = self._sock.protocol
return self.app(self)
def close(self):
return self._sock.close()
def send(self, message):
return self._sock.send(message)
def wait(self):
return self._sock.receive()
_async = {
'threading': sys.modules[__name__],
'thread_class': 'Thread',
'queue': importlib.import_module('gevent.queue'),
'queue_class': 'JoinableQueue',
'websocket': sys.modules[__name__] if _websocket_available else None,
'websocket_class': 'WebSocketWSGI' if _websocket_available else None,
'sleep': gevent.sleep
}

View file

@@ -0,0 +1,155 @@
import importlib
import sys
import six
import gevent
import uwsgi
_websocket_available = hasattr(uwsgi, 'websocket_handshake')
class Thread(gevent.Greenlet): # pragma: no cover
"""
This wrapper class provides a gevent Greenlet interface that is compatible
with the standard library's Thread class.
"""
def __init__(self, target, args=[], kwargs={}):
super(Thread, self).__init__(target, *args, **kwargs)
def _run(self):
return self.run()
class uWSGIWebSocket(object): # pragma: no cover
"""
This wrapper class provides a uWSGI WebSocket interface that is
compatible with eventlet's implementation.
"""
def __init__(self, app):
self.app = app
self._sock = None
def __call__(self, environ, start_response):
self._sock = uwsgi.connection_fd()
self.environ = environ
uwsgi.websocket_handshake()
self._req_ctx = None
if hasattr(uwsgi, 'request_context'):
# uWSGI >= 2.1.x with support for API access across greenlets
self._req_ctx = uwsgi.request_context()
else:
# use event and queue for sending messages
from gevent.event import Event
from gevent.queue import Queue
from gevent.select import select
self._event = Event()
self._send_queue = Queue()
# spawn a select greenlet
def select_greenlet_runner(fd, event):
"""Sets event when data becomes available to read on fd."""
while True:
event.set()
try:
select([fd], [], [])[0]
except ValueError:
break
self._select_greenlet = gevent.spawn(
select_greenlet_runner,
self._sock,
self._event)
self.app(self)
def close(self):
"""Disconnects uWSGI from the client."""
uwsgi.disconnect()
if self._req_ctx is None:
# better kill it here in case wait() is not called again
self._select_greenlet.kill()
self._event.set()
def _send(self, msg):
"""Transmits message either in binary or UTF-8 text mode,
depending on its type."""
if isinstance(msg, six.binary_type):
method = uwsgi.websocket_send_binary
else:
method = uwsgi.websocket_send
if self._req_ctx is not None:
method(msg, request_context=self._req_ctx)
else:
method(msg)
def _decode_received(self, msg):
"""Returns either bytes or str, depending on message type."""
if not isinstance(msg, six.binary_type):
# already decoded - do nothing
return msg
# only decode from utf-8 if message is not binary data
type = six.byte2int(msg[0:1])
if type >= 48: # no binary
return msg.decode('utf-8')
# binary message, don't try to decode
return msg
def send(self, msg):
"""Queues a message for sending. Real transmission is done in
wait method.
Sends directly if uWSGI version is new enough."""
if self._req_ctx is not None:
self._send(msg)
else:
self._send_queue.put(msg)
self._event.set()
def wait(self):
"""Waits and returns received messages.
If running in compatibility mode for older uWSGI versions,
it also sends messages that have been queued by send().
A return value of None means that connection was closed.
This must be called repeatedly. For uWSGI < 2.1.x it must
be called from the main greenlet."""
while True:
if self._req_ctx is not None:
try:
msg = uwsgi.websocket_recv(request_context=self._req_ctx)
except IOError: # connection closed
return None
return self._decode_received(msg)
else:
# we wake up at least every 3 seconds to let uWSGI
# do its ping/ponging
event_set = self._event.wait(timeout=3)
if event_set:
self._event.clear()
# maybe there is something to send
msgs = []
while True:
try:
msgs.append(self._send_queue.get(block=False))
except gevent.queue.Empty:
break
for msg in msgs:
self._send(msg)
# maybe there is something to receive, if not, at least
# ensure uWSGI does its ping/ponging
try:
msg = uwsgi.websocket_recv_nb()
except IOError: # connection closed
self._select_greenlet.kill()
return None
if msg: # message available
return self._decode_received(msg)
_async = {
'threading': sys.modules[__name__],
'thread_class': 'Thread',
'queue': importlib.import_module('gevent.queue'),
'queue_class': 'JoinableQueue',
'websocket': sys.modules[__name__] if _websocket_available else None,
'websocket_class': 'uWSGIWebSocket' if _websocket_available else None,
'sleep': gevent.sleep
}

View file

@@ -0,0 +1,145 @@
import sys
from urllib.parse import urlsplit
from sanic.response import HTTPResponse
try:
from sanic.websocket import WebSocketProtocol
except ImportError:
# the installed version of sanic does not have websocket support
WebSocketProtocol = None
import six
def create_route(app, engineio_server, engineio_endpoint):
"""This function sets up the engine.io endpoint as a route for the
application.
Note that both GET and POST requests must be hooked up on the engine.io
endpoint.
"""
app.add_route(engineio_server.handle_request, engineio_endpoint,
methods=['GET', 'POST', 'OPTIONS'])
try:
app.enable_websocket()
except AttributeError:
# ignore, this version does not support websocket
pass
def translate_request(request):
"""This function takes the arguments passed to the request handler and
uses them to generate a WSGI compatible environ dictionary.
"""
class AwaitablePayload(object):
def __init__(self, payload):
self.payload = payload or b''
async def read(self, length=None):
if length is None:
r = self.payload
self.payload = b''
else:
r = self.payload[:length]
self.payload = self.payload[length:]
return r
uri_parts = urlsplit(request.url)
environ = {
'wsgi.input': AwaitablePayload(request.body),
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'sanic',
'REQUEST_METHOD': request.method,
'QUERY_STRING': uri_parts.query or '',
'RAW_URI': request.url,
'SERVER_PROTOCOL': 'HTTP/' + request.version,
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'sanic',
'SERVER_PORT': '0',
'sanic.request': request
}
for hdr_name, hdr_value in request.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = '%s,%s' % (environ[key], hdr_value)
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
path_info = uri_parts.path
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = ''
return environ
def make_response(status, headers, payload, environ):
"""This function generates an appropriate response object for this async
mode.
"""
headers_dict = {}
content_type = None
for h in headers:
if h[0].lower() == 'content-type':
content_type = h[1]
else:
headers_dict[h[0]] = h[1]
return HTTPResponse(body_bytes=payload, content_type=content_type,
status=int(status.split()[0]), headers=headers_dict)
class WebSocket(object): # pragma: no cover
"""
This wrapper class provides a sanic WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler):
self.handler = handler
self._sock = None
async def __call__(self, environ):
request = environ['sanic.request']
protocol = request.transport.get_protocol()
self._sock = await protocol.websocket_handshake(request)
self.environ = environ
await self.handler(self)
async def close(self):
await self._sock.close()
async def send(self, message):
await self._sock.send(message)
async def wait(self):
data = await self._sock.recv()
if not isinstance(data, six.binary_type) and \
not isinstance(data, six.text_type):
raise IOError()
return data
_async = {
'asyncio': True,
'create_route': create_route,
'translate_request': translate_request,
'make_response': make_response,
'websocket': sys.modules[__name__] if WebSocketProtocol else None,
'websocket_class': 'WebSocket' if WebSocketProtocol else None
}

View file

@@ -0,0 +1,17 @@
import importlib
import time
try:
queue = importlib.import_module('queue')
except ImportError: # pragma: no cover
queue = importlib.import_module('Queue') # pragma: no cover
_async = {
'threading': importlib.import_module('threading'),
'thread_class': 'Thread',
'queue': queue,
'queue_class': 'Queue',
'websocket': None,
'websocket_class': None,
'sleep': time.sleep
}

View file

@ -0,0 +1,154 @@
import asyncio
import sys
from urllib.parse import urlsplit
try:
import tornado.web
import tornado.websocket
except ImportError: # pragma: no cover
pass
import six
def get_tornado_handler(engineio_server):
class Handler(tornado.websocket.WebSocketHandler): # pragma: no cover
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.receive_queue = asyncio.Queue()
async def get(self):
if self.request.headers.get('Upgrade', '').lower() == 'websocket':
super().get()
await engineio_server.handle_request(self)
async def post(self):
await engineio_server.handle_request(self)
async def options(self):
await engineio_server.handle_request(self)
async def on_message(self, message):
await self.receive_queue.put(message)
async def get_next_message(self):
return await self.receive_queue.get()
def on_close(self):
self.receive_queue.put_nowait(None)
return Handler
def translate_request(handler):
"""This function takes the arguments passed to the request handler and
uses them to generate a WSGI compatible environ dictionary.
"""
class AwaitablePayload(object):
def __init__(self, payload):
self.payload = payload or b''
async def read(self, length=None):
if length is None:
r = self.payload
self.payload = b''
else:
r = self.payload[:length]
self.payload = self.payload[length:]
return r
payload = handler.request.body
uri_parts = urlsplit(handler.request.path)
full_uri = handler.request.path
if handler.request.query: # pragma: no cover
full_uri += '?' + handler.request.query
environ = {
'wsgi.input': AwaitablePayload(payload),
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': 'aiohttp',
'REQUEST_METHOD': handler.request.method,
'QUERY_STRING': handler.request.query or '',
'RAW_URI': full_uri,
'SERVER_PROTOCOL': 'HTTP/%s' % handler.request.version,
'REMOTE_ADDR': '127.0.0.1',
'REMOTE_PORT': '0',
'SERVER_NAME': 'aiohttp',
'SERVER_PORT': '0',
'tornado.handler': handler
}
for hdr_name, hdr_value in handler.request.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
environ[key] = hdr_value
environ['wsgi.url_scheme'] = environ.get('HTTP_X_FORWARDED_PROTO', 'http')
path_info = uri_parts.path
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = ''
return environ
def make_response(status, headers, payload, environ):
"""This function generates an appropriate response object for this async
mode.
"""
tornado_handler = environ['tornado.handler']
tornado_handler.set_status(int(status.split()[0]))
for header, value in headers:
tornado_handler.set_header(header, value)
tornado_handler.write(payload)
tornado_handler.finish()
class WebSocket(object): # pragma: no cover
"""
This wrapper class provides a tornado WebSocket interface that is
somewhat compatible with eventlet's implementation.
"""
def __init__(self, handler):
self.handler = handler
self.tornado_handler = None
async def __call__(self, environ):
self.tornado_handler = environ['tornado.handler']
self.environ = environ
await self.handler(self)
async def close(self):
self.tornado_handler.close()
async def send(self, message):
self.tornado_handler.write_message(
message, binary=isinstance(message, bytes))
async def wait(self):
msg = await self.tornado_handler.get_next_message()
if not isinstance(msg, six.binary_type) and \
not isinstance(msg, six.text_type):
raise IOError()
return msg
_async = {
'asyncio': True,
'translate_request': translate_request,
'make_response': make_response,
'websocket': sys.modules[__name__],
'websocket_class': 'WebSocket'
}

View file

@@ -0,0 +1,556 @@
import asyncio
try:
import aiohttp
except ImportError: # pragma: no cover
aiohttp = None
import six
try:
import websockets
except ImportError: # pragma: no cover
websockets = None
from . import client
from . import exceptions
from . import packet
from . import payload
class AsyncClient(client.Client):
"""An Engine.IO client for asyncio.
This class implements a fully compliant Engine.IO web client with support
for websocket and long-polling transports, compatible with the asyncio
framework on Python 3.5 or newer.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
"""
def is_asyncio_based(self):
return True
async def connect(self, url, headers={}, transports=None,
engineio_path='engine.io'):
"""Connect to an Engine.IO server.
:param url: The URL of the Engine.IO server. It can include custom
query string parameters if required by the server.
:param headers: A dictionary with custom headers to send with the
connection request.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
then an upgrade to websocket is attempted.
:param engineio_path: The endpoint where the Engine.IO server is
installed. The default value is appropriate for
most cases.
Note: this method is a coroutine.
Example usage::
eio = engineio.Client()
await eio.connect('http://localhost:5000')
"""
if self.state != 'disconnected':
raise ValueError('Client is not in a disconnected state')
valid_transports = ['polling', 'websocket']
if transports is not None:
if isinstance(transports, six.text_type):
transports = [transports]
transports = [transport for transport in transports
if transport in valid_transports]
if not transports:
raise ValueError('No valid transports provided')
self.transports = transports or valid_transports
self.queue = self.create_queue()
return await getattr(self, '_connect_' + self.transports[0])(
url, headers, engineio_path)
async def wait(self):
"""Wait until the connection with the server ends.
Client applications can use this function to block the main thread
during the life of the connection.
Note: this method is a coroutine.
"""
if self.read_loop_task:
await self.read_loop_task
async def send(self, data, binary=None):
"""Send a message to a client.
:param data: The data to send to the client. Data can be of type
``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
or ``dict``, the data will be serialized as JSON.
:param binary: ``True`` to send packet as binary, ``False`` to send
as text. If not given, unicode (Python 2) and str
(Python 3) are sent as text, and str (Python 2) and
bytes (Python 3) are sent as binary.
Note: this method is a coroutine.
"""
await self._send_packet(packet.Packet(packet.MESSAGE, data=data,
binary=binary))
async def disconnect(self, abort=False):
"""Disconnect from the server.
:param abort: If set to ``True``, do not wait for background tasks
associated with the connection to end.
Note: this method is a coroutine.
"""
if self.state == 'connected':
await self._send_packet(packet.Packet(packet.CLOSE))
await self.queue.put(None)
self.state = 'disconnecting'
await self._trigger_event('disconnect', run_async=False)
if self.current_transport == 'websocket':
await self.ws.close()
if not abort:
await self.read_loop_task
self.state = 'disconnected'
try:
client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
def start_background_task(self, target, *args, **kwargs):
"""Start a background task.
This is a utility function that applications can use to start a
background task.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
Note: unlike most methods in this class, this one is not a coroutine; it
returns the task object immediately.
"""
return asyncio.ensure_future(target(*args, **kwargs))
async def sleep(self, seconds=0):
"""Sleep for the requested amount of time.
Note: this method is a coroutine.
"""
return await asyncio.sleep(seconds)
def create_queue(self):
"""Create a queue object."""
q = asyncio.Queue()
q.Empty = asyncio.QueueEmpty
return q
def create_event(self):
"""Create an event object."""
return asyncio.Event()
def _reset(self):
if self.http: # pragma: no cover
asyncio.ensure_future(self.http.close())
super()._reset()
async def _connect_polling(self, url, headers, engineio_path):
"""Establish a long-polling connection to the Engine.IO server."""
if aiohttp is None: # pragma: no cover
self.logger.error('aiohttp not installed -- cannot make HTTP '
'requests!')
return
self.base_url = self._get_engineio_url(url, engineio_path, 'polling')
self.logger.info('Attempting polling connection to ' + self.base_url)
r = await self._send_request(
'GET', self.base_url + self._get_url_timestamp(), headers=headers)
if r is None:
self._reset()
raise exceptions.ConnectionError(
'Connection refused by the server')
if r.status != 200:
raise exceptions.ConnectionError(
'Unexpected status code {} in server response'.format(
r.status))
try:
p = payload.Payload(encoded_payload=await r.read())
except ValueError:
six.raise_from(exceptions.ConnectionError(
'Unexpected response from server'), None)
open_packet = p.packets[0]
if open_packet.packet_type != packet.OPEN:
raise exceptions.ConnectionError(
'OPEN packet not returned by server')
self.logger.info(
'Polling connection accepted with ' + str(open_packet.data))
self.sid = open_packet.data['sid']
self.upgrades = open_packet.data['upgrades']
self.ping_interval = open_packet.data['pingInterval'] / 1000.0
self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0
self.current_transport = 'polling'
self.base_url += '&sid=' + self.sid
self.state = 'connected'
client.connected_clients.append(self)
await self._trigger_event('connect', run_async=False)
for pkt in p.packets[1:]:
await self._receive_packet(pkt)
if 'websocket' in self.upgrades and 'websocket' in self.transports:
# attempt to upgrade to websocket
if await self._connect_websocket(url, headers, engineio_path):
# upgrade to websocket succeeded, we're done here
return
self.ping_loop_task = self.start_background_task(self._ping_loop)
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_polling)
async def _connect_websocket(self, url, headers, engineio_path):
"""Establish or upgrade to a WebSocket connection with the server."""
if websockets is None: # pragma: no cover
self.logger.error('websockets package not installed')
return False
websocket_url = self._get_engineio_url(url, engineio_path,
'websocket')
if self.sid:
self.logger.info(
'Attempting WebSocket upgrade to ' + websocket_url)
upgrade = True
websocket_url += '&sid=' + self.sid
else:
upgrade = False
self.base_url = websocket_url
self.logger.info(
'Attempting WebSocket connection to ' + websocket_url)
# get the cookies from the long-polling connection so that they can
# also be sent to the WebSocket route
cookies = None
if self.http:
cookies = '; '.join(["{}={}".format(cookie.key, cookie.value)
for cookie in self.http._cookie_jar])
headers = headers.copy()
headers['Cookie'] = cookies
try:
ws = await websockets.connect(
websocket_url + self._get_url_timestamp(),
extra_headers=headers)
except (websockets.exceptions.InvalidURI,
websockets.exceptions.InvalidHandshake):
if upgrade:
self.logger.warning(
'WebSocket upgrade failed: connection error')
return False
else:
raise exceptions.ConnectionError('Connection error')
if upgrade:
p = packet.Packet(packet.PING, data='probe').encode(
always_bytes=False)
try:
await ws.send(p)
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected send exception: %s',
str(e))
return False
try:
p = await ws.recv()
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected recv exception: %s',
str(e))
return False
pkt = packet.Packet(encoded_packet=p)
if pkt.packet_type != packet.PONG or pkt.data != 'probe':
self.logger.warning(
'WebSocket upgrade failed: no PONG packet')
return False
p = packet.Packet(packet.UPGRADE).encode(always_bytes=False)
try:
await ws.send(p)
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected send exception: %s',
str(e))
return False
self.current_transport = 'websocket'
if self.http: # pragma: no cover
await self.http.close()
self.logger.info('WebSocket upgrade was successful')
else:
try:
p = await ws.recv()
except Exception as e: # pragma: no cover
raise exceptions.ConnectionError(
'Unexpected recv exception: ' + str(e))
open_packet = packet.Packet(encoded_packet=p)
if open_packet.packet_type != packet.OPEN:
raise exceptions.ConnectionError('no OPEN packet')
self.logger.info(
'WebSocket connection accepted with ' + str(open_packet.data))
self.sid = open_packet.data['sid']
self.upgrades = open_packet.data['upgrades']
self.ping_interval = open_packet.data['pingInterval'] / 1000.0
self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0
self.current_transport = 'websocket'
self.state = 'connected'
client.connected_clients.append(self)
await self._trigger_event('connect', run_async=False)
self.ws = ws
self.ping_loop_task = self.start_background_task(self._ping_loop)
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_websocket)
return True
async def _receive_packet(self, pkt):
"""Handle incoming packets from the server."""
packet_name = packet.packet_names[pkt.packet_type] \
if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
self.logger.info(
'Received packet %s data %s', packet_name,
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
if pkt.packet_type == packet.MESSAGE:
await self._trigger_event('message', pkt.data, run_async=True)
elif pkt.packet_type == packet.PONG:
self.pong_received = True
elif pkt.packet_type == packet.CLOSE:
await self.disconnect(abort=True)
elif pkt.packet_type == packet.NOOP:
pass
else:
self.logger.error('Received unexpected packet of type %s',
pkt.packet_type)
async def _send_packet(self, pkt):
"""Queue a packet to be sent to the server."""
if self.state != 'connected':
return
await self.queue.put(pkt)
self.logger.info(
'Sending packet %s data %s',
packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
async def _send_request(
self, method, url, headers=None, body=None): # pragma: no cover
if self.http is None or self.http.closed:
self.http = aiohttp.ClientSession()
method = getattr(self.http, method.lower())
try:
return await method(url, headers=headers, data=body)
except aiohttp.ClientError:
return
async def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
ret = None
if event in self.handlers:
if asyncio.iscoroutinefunction(self.handlers[event]) is True:
if run_async:
return self.start_background_task(self.handlers[event],
*args)
else:
try:
ret = await self.handlers[event](*args)
except asyncio.CancelledError: # pragma: no cover
pass
except:
self.logger.exception(event + ' async handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
else:
if run_async:
async def async_handler():
return self.handlers[event](*args)
return self.start_background_task(async_handler)
else:
try:
ret = self.handlers[event](*args)
except:
self.logger.exception(event + ' handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
return ret
async def _ping_loop(self):
"""This background task sends a PING to the server at the requested
interval.
"""
self.pong_received = True
self.ping_loop_event.clear()
while self.state == 'connected':
if not self.pong_received:
self.logger.info(
'PONG response has not been received, aborting')
if self.ws:
await self.ws.close()
await self.queue.put(None)
break
self.pong_received = False
await self._send_packet(packet.Packet(packet.PING))
try:
await asyncio.wait_for(self.ping_loop_event.wait(),
self.ping_interval)
except (asyncio.TimeoutError,
asyncio.CancelledError): # pragma: no cover
pass
self.logger.info('Exiting ping task')
async def _read_loop_polling(self):
"""Read packets by polling the Engine.IO server."""
while self.state == 'connected':
self.logger.info(
'Sending polling GET request to ' + self.base_url)
r = await self._send_request(
'GET', self.base_url + self._get_url_timestamp())
if r is None:
self.logger.warning(
'Connection refused by the server, aborting')
await self.queue.put(None)
break
if r.status != 200:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status)
await self.queue.put(None)
break
try:
p = payload.Payload(encoded_payload=await r.read())
except ValueError:
self.logger.warning(
'Unexpected packet from server, aborting')
await self.queue.put(None)
break
for pkt in p.packets:
await self._receive_packet(pkt)
self.logger.info('Waiting for write loop task to end')
await self.write_loop_task
self.logger.info('Waiting for ping loop task to end')
self.ping_loop_event.set()
await self.ping_loop_task
if self.state == 'connected':
await self._trigger_event('disconnect', run_async=False)
try:
client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')
async def _read_loop_websocket(self):
"""Read packets from the Engine.IO WebSocket connection."""
while self.state == 'connected':
p = None
try:
p = await self.ws.recv()
except websockets.exceptions.ConnectionClosed:
self.logger.info(
'Read loop: WebSocket connection was closed, aborting')
await self.queue.put(None)
break
except Exception as e:
self.logger.info(
'Unexpected error "%s", aborting', str(e))
await self.queue.put(None)
break
if isinstance(p, six.text_type): # pragma: no cover
p = p.encode('utf-8')
pkt = packet.Packet(encoded_packet=p)
await self._receive_packet(pkt)
self.logger.info('Waiting for write loop task to end')
await self.write_loop_task
self.logger.info('Waiting for ping loop task to end')
self.ping_loop_event.set()
await self.ping_loop_task
if self.state == 'connected':
await self._trigger_event('disconnect', run_async=False)
try:
client.connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')
async def _write_loop(self):
"""This background task sends packages to the server as they are
pushed to the send queue.
"""
while self.state == 'connected':
# to simplify the timeout handling, use the maximum of the
# ping interval and ping timeout as timeout, with an extra 5
# seconds grace period
timeout = max(self.ping_interval, self.ping_timeout) + 5
packets = None
try:
packets = [await asyncio.wait_for(self.queue.get(), timeout)]
except (self.queue.Empty, asyncio.TimeoutError,
asyncio.CancelledError):
self.logger.error('packet queue is empty, aborting')
break
if packets == [None]:
self.queue.task_done()
packets = []
else:
while True:
try:
packets.append(self.queue.get_nowait())
except self.queue.Empty:
break
if packets[-1] is None:
packets = packets[:-1]
self.queue.task_done()
break
if not packets:
# empty packet list returned -> connection closed
break
if self.current_transport == 'polling':
p = payload.Payload(packets=packets)
r = await self._send_request(
'POST', self.base_url, body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
for pkt in packets:
self.queue.task_done()
if r is None:
self.logger.warning(
'Connection refused by the server, aborting')
break
if r.status != 200:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status)
self._reset()
break
else:
# websocket
try:
for pkt in packets:
await self.ws.send(pkt.encode(always_bytes=False))
self.queue.task_done()
except websockets.exceptions.ConnectionClosed:
self.logger.info(
'Write loop: WebSocket connection was closed, '
'aborting')
break
self.logger.info('Exiting write loop task')
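# Illustrative usage sketch (not part of the vendored source), assuming this
# client is exposed as engineio.AsyncClient and an Engine.IO server is
# listening on localhost:5000:
#
#     import asyncio
#     import engineio
#
#     eio = engineio.AsyncClient()
#
#     @eio.on('message')
#     async def on_message(data):
#         print('received:', data)
#
#     async def main():
#         await eio.connect('http://localhost:5000')
#         await eio.send('hello')
#         await eio.wait()
#
#     asyncio.get_event_loop().run_until_complete(main())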

View file

@ -0,0 +1,444 @@
import asyncio
import six
from six.moves import urllib
from . import exceptions
from . import packet
from . import server
from . import asyncio_socket
class AsyncServer(server.Server):
"""An Engine.IO server for asyncio.
This class implements a fully compliant Engine.IO web server with support
for websocket and long-polling transports, compatible with the asyncio
framework on Python 3.5 or newer.
:param async_mode: The asynchronous model to use. See the Deployment
section in the documentation for a description of the
available options. Valid async modes are "aiohttp",
"sanic", "tornado" and "asgi". If this argument is not
given, "aiohttp" is tried first, followed by "sanic",
"tornado", and finally "asgi". The first async mode that
has all its dependencies installed is the one that is
chosen.
:param ping_timeout: The time in seconds that the client waits for the
server to respond before disconnecting.
:param ping_interval: The interval in seconds at which the client pings
the server.
:param max_http_buffer_size: The maximum size of a message when using the
polling transport.
:param allow_upgrades: Whether to allow transport upgrades or not.
    :param http_compression: Whether to compress messages when using the
polling transport.
:param compression_threshold: Only compress messages when their byte size
is greater than this value.
:param cookie: Name of the HTTP cookie that contains the client session
id. If set to ``None``, a cookie is not sent to the client.
:param cors_allowed_origins: List of origins that are allowed to connect
to this server. All origins are allowed by
default.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param async_handlers: If set to ``True``, run message event handlers in
non-blocking threads. To run handlers synchronously,
set to ``False``. The default is ``True``.
:param kwargs: Reserved for future extensions, any additional parameters
given as keyword arguments will be silently ignored.
"""
def is_asyncio_based(self):
return True
def async_modes(self):
return ['aiohttp', 'sanic', 'tornado', 'asgi']
def attach(self, app, engineio_path='engine.io'):
"""Attach the Engine.IO server to an application."""
engineio_path = engineio_path.strip('/')
self._async['create_route'](app, self, '/{}/'.format(engineio_path))
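    # Illustrative sketch (not part of the vendored source): attaching the
    # server to an aiohttp application, assuming the 'aiohttp' async mode and
    # that aiohttp is installed.
    #
    #     import aiohttp.web
    #     import engineio
    #
    #     eio = engineio.AsyncServer(async_mode='aiohttp')
    #     app = aiohttp.web.Application()
    #     eio.attach(app)   # hooks the Engine.IO handler at /engine.io/
    #     aiohttp.web.run_app(app)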
async def send(self, sid, data, binary=None):
"""Send a message to a client.
:param sid: The session id of the recipient client.
:param data: The data to send to the client. Data can be of type
``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
or ``dict``, the data will be serialized as JSON.
:param binary: ``True`` to send packet as binary, ``False`` to send
as text. If not given, unicode (Python 2) and str
(Python 3) are sent as text, and str (Python 2) and
bytes (Python 3) are sent as binary.
Note: this method is a coroutine.
"""
try:
socket = self._get_socket(sid)
except KeyError:
# the socket is not available
self.logger.warning('Cannot send to sid %s', sid)
return
await socket.send(packet.Packet(packet.MESSAGE, data=data,
binary=binary))
async def get_session(self, sid):
"""Return the user session for a client.
:param sid: The session id of the client.
The return value is a dictionary. Modifications made to this
dictionary are not guaranteed to be preserved. If you want to modify
the user session, use the ``session`` context manager instead.
"""
socket = self._get_socket(sid)
return socket.session
async def save_session(self, sid, session):
"""Store the user session for a client.
:param sid: The session id of the client.
:param session: The session dictionary.
"""
socket = self._get_socket(sid)
socket.session = session
def session(self, sid):
"""Return the user session for a client with context manager syntax.
:param sid: The session id of the client.
This is a context manager that returns the user session dictionary for
the client. Any changes that are made to this dictionary inside the
context manager block are saved back to the session. Example usage::
            @eio.on('connect')
            async def on_connect(sid, environ):
                username = authenticate_user(environ)
                if not username:
                    return False
                async with eio.session(sid) as session:
                    session['username'] = username
            @eio.on('message')
            async def on_message(sid, msg):
                async with eio.session(sid) as session:
                    print('received message from ', session['username'])
"""
class _session_context_manager(object):
def __init__(self, server, sid):
self.server = server
self.sid = sid
self.session = None
async def __aenter__(self):
self.session = await self.server.get_session(sid)
return self.session
async def __aexit__(self, *args):
await self.server.save_session(sid, self.session)
return _session_context_manager(self, sid)
async def disconnect(self, sid=None):
"""Disconnect a client.
:param sid: The session id of the client to close. If this parameter
is not given, then all clients are closed.
Note: this method is a coroutine.
"""
if sid is not None:
try:
socket = self._get_socket(sid)
except KeyError: # pragma: no cover
# the socket was already closed or gone
pass
else:
await socket.close()
del self.sockets[sid]
else:
await asyncio.wait([client.close()
for client in six.itervalues(self.sockets)])
self.sockets = {}
async def handle_request(self, *args, **kwargs):
"""Handle an HTTP request from the client.
This is the entry point of the Engine.IO application. This function
returns the HTTP response to deliver to the client.
Note: this method is a coroutine.
"""
translate_request = self._async['translate_request']
if asyncio.iscoroutinefunction(translate_request):
environ = await translate_request(*args, **kwargs)
else:
environ = translate_request(*args, **kwargs)
method = environ['REQUEST_METHOD']
query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))
sid = query['sid'][0] if 'sid' in query else None
b64 = False
jsonp = False
jsonp_index = None
if 'b64' in query:
if query['b64'][0] == "1" or query['b64'][0].lower() == "true":
b64 = True
if 'j' in query:
jsonp = True
try:
jsonp_index = int(query['j'][0])
except (ValueError, KeyError, IndexError):
# Invalid JSONP index number
pass
if jsonp and jsonp_index is None:
self.logger.warning('Invalid JSONP index number')
r = self._bad_request()
elif method == 'GET':
if sid is None:
transport = query.get('transport', ['polling'])[0]
if transport != 'polling' and transport != 'websocket':
self.logger.warning('Invalid transport %s', transport)
r = self._bad_request()
else:
r = await self._handle_connect(environ, transport,
b64, jsonp_index)
else:
if sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
socket = self._get_socket(sid)
try:
packets = await socket.handle_get_request(environ)
if isinstance(packets, list):
r = self._ok(packets, b64=b64,
jsonp_index=jsonp_index)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
elif method == 'POST':
if sid is None or sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
socket = self._get_socket(sid)
try:
await socket.handle_post_request(environ)
r = self._ok(jsonp_index=jsonp_index)
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
await self.disconnect(sid)
r = self._bad_request()
except: # pragma: no cover
# for any other unexpected errors, we log the error
# and keep going
self.logger.exception('post request handler error')
r = self._ok(jsonp_index=jsonp_index)
elif method == 'OPTIONS':
r = self._ok()
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()
if not isinstance(r, dict):
return r
if self.http_compression and \
len(r['response']) >= self.compression_threshold:
encodings = [e.split(';')[0].strip() for e in
environ.get('HTTP_ACCEPT_ENCODING', '').split(',')]
for encoding in encodings:
if encoding in self.compression_methods:
r['response'] = \
getattr(self, '_' + encoding)(r['response'])
r['headers'] += [('Content-Encoding', encoding)]
break
cors_headers = self._cors_headers(environ)
make_response = self._async['make_response']
if asyncio.iscoroutinefunction(make_response):
response = await make_response(r['status'],
r['headers'] + cors_headers,
r['response'], environ)
else:
response = make_response(r['status'], r['headers'] + cors_headers,
r['response'], environ)
return response
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
        The return value is an ``asyncio.Task`` object.
"""
return asyncio.ensure_future(target(*args, **kwargs))
async def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
Note: this method is a coroutine.
"""
return await asyncio.sleep(seconds)
def create_queue(self, *args, **kwargs):
"""Create a queue object using the appropriate async model.
This is a utility function that applications can use to create a queue
without having to worry about using the correct call for the selected
async mode. For asyncio based async modes, this returns an instance of
``asyncio.Queue``.
"""
return asyncio.Queue(*args, **kwargs)
def get_queue_empty_exception(self):
"""Return the queue empty exception for the appropriate async model.
This is a utility function that applications can use to work with a
queue without having to worry about using the correct call for the
selected async mode. For asyncio based async modes, this returns an
instance of ``asyncio.QueueEmpty``.
"""
return asyncio.QueueEmpty
def create_event(self, *args, **kwargs):
"""Create an event object using the appropriate async model.
This is a utility function that applications can use to create an
event without having to worry about using the correct call for the
selected async mode. For asyncio based async modes, this returns
an instance of ``asyncio.Event``.
"""
return asyncio.Event(*args, **kwargs)
async def _handle_connect(self, environ, transport, b64=False,
jsonp_index=None):
"""Handle a client connection request."""
if self.start_service_task:
# start the service task to monitor connected clients
self.start_service_task = False
self.start_background_task(self._service_task)
sid = self._generate_id()
s = asyncio_socket.AsyncSocket(self, sid)
self.sockets[sid] = s
pkt = packet.Packet(
packet.OPEN, {'sid': sid,
'upgrades': self._upgrades(sid, transport),
'pingTimeout': int(self.ping_timeout * 1000),
'pingInterval': int(self.ping_interval * 1000)})
await s.send(pkt)
ret = await self._trigger_event('connect', sid, environ,
run_async=False)
if ret is False:
del self.sockets[sid]
self.logger.warning('Application rejected connection')
return self._unauthorized()
if transport == 'websocket':
ret = await s.handle_get_request(environ)
if s.closed:
# websocket connection ended, so we are done
del self.sockets[sid]
return ret
else:
s.connected = True
headers = None
if self.cookie:
headers = [('Set-Cookie', self.cookie + '=' + sid)]
try:
return self._ok(await s.poll(), headers=headers, b64=b64,
jsonp_index=jsonp_index)
except exceptions.QueueEmpty:
return self._bad_request()
async def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
ret = None
if event in self.handlers:
if asyncio.iscoroutinefunction(self.handlers[event]) is True:
if run_async:
return self.start_background_task(self.handlers[event],
*args)
else:
try:
ret = await self.handlers[event](*args)
except asyncio.CancelledError: # pragma: no cover
pass
except:
self.logger.exception(event + ' async handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
else:
if run_async:
async def async_handler():
return self.handlers[event](*args)
return self.start_background_task(async_handler)
else:
try:
ret = self.handlers[event](*args)
except:
self.logger.exception(event + ' handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
return ret
async def _service_task(self): # pragma: no cover
"""Monitor connected clients and clean up those that time out."""
while True:
if len(self.sockets) == 0:
# nothing to do
await self.sleep(self.ping_timeout)
continue
# go through the entire client list in a ping interval cycle
sleep_interval = self.ping_timeout / len(self.sockets)
try:
# iterate over the current clients
for socket in self.sockets.copy().values():
if not socket.closing and not socket.closed:
await socket.check_ping_timeout()
await self.sleep(sleep_interval)
except (SystemExit, KeyboardInterrupt, asyncio.CancelledError):
self.logger.info('service task canceled')
break
except:
if asyncio.get_event_loop().is_closed():
self.logger.info('event loop is closed, exiting service '
'task')
break
# an unexpected exception has occurred, log it and continue
self.logger.exception('service task exception')

View file

@ -0,0 +1,235 @@
import asyncio
import six
import sys
import time
from . import exceptions
from . import packet
from . import payload
from . import socket
class AsyncSocket(socket.Socket):
async def poll(self):
"""Wait for packets to send to the client."""
try:
packets = [await asyncio.wait_for(self.queue.get(),
self.server.ping_timeout)]
self.queue.task_done()
except (asyncio.TimeoutError, asyncio.CancelledError):
raise exceptions.QueueEmpty()
if packets == [None]:
return []
try:
packets.append(self.queue.get_nowait())
self.queue.task_done()
except asyncio.QueueEmpty:
pass
return packets
async def receive(self, pkt):
"""Receive packet from the client."""
self.server.logger.info('%s: Received packet %s data %s',
self.sid, packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes)
else '<binary>')
if pkt.packet_type == packet.PING:
self.last_ping = time.time()
await self.send(packet.Packet(packet.PONG, pkt.data))
elif pkt.packet_type == packet.MESSAGE:
await self.server._trigger_event(
'message', self.sid, pkt.data,
run_async=self.server.async_handlers)
elif pkt.packet_type == packet.UPGRADE:
await self.send(packet.Packet(packet.NOOP))
elif pkt.packet_type == packet.CLOSE:
await self.close(wait=False, abort=True)
else:
raise exceptions.UnknownPacketError()
async def check_ping_timeout(self):
"""Make sure the client is still sending pings.
This helps detect disconnections for long-polling clients.
"""
if self.closed:
raise exceptions.SocketIsClosedError()
if time.time() - self.last_ping > self.server.ping_interval + 5:
self.server.logger.info('%s: Client is gone, closing socket',
self.sid)
# Passing abort=False here will cause close() to write a
# CLOSE packet. This has the effect of updating half-open sockets
# to their correct state of disconnected
await self.close(wait=False, abort=False)
return False
return True
async def send(self, pkt):
"""Send a packet to the client."""
if not await self.check_ping_timeout():
return
if self.upgrading:
self.packet_backlog.append(pkt)
else:
await self.queue.put(pkt)
self.server.logger.info('%s: Sending packet %s data %s',
self.sid, packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes)
else '<binary>')
async def handle_get_request(self, environ):
"""Handle a long-polling GET request from the client."""
connections = [
s.strip()
for s in environ.get('HTTP_CONNECTION', '').lower().split(',')]
transport = environ.get('HTTP_UPGRADE', '').lower()
if 'upgrade' in connections and transport in self.upgrade_protocols:
self.server.logger.info('%s: Received request to upgrade to %s',
self.sid, transport)
return await getattr(self, '_upgrade_' + transport)(environ)
try:
packets = await self.poll()
except exceptions.QueueEmpty:
exc = sys.exc_info()
await self.close(wait=False)
six.reraise(*exc)
return packets
async def handle_post_request(self, environ):
"""Handle a long-polling POST request from the client."""
length = int(environ.get('CONTENT_LENGTH', '0'))
if length > self.server.max_http_buffer_size:
raise exceptions.ContentTooLongError()
else:
body = await environ['wsgi.input'].read(length)
p = payload.Payload(encoded_payload=body)
for pkt in p.packets:
await self.receive(pkt)
async def close(self, wait=True, abort=False):
"""Close the socket connection."""
if not self.closed and not self.closing:
self.closing = True
await self.server._trigger_event('disconnect', self.sid)
if not abort:
await self.send(packet.Packet(packet.CLOSE))
self.closed = True
if wait:
await self.queue.join()
async def _upgrade_websocket(self, environ):
"""Upgrade the connection from polling to websocket."""
if self.upgraded:
raise IOError('Socket has been upgraded already')
if self.server._async['websocket'] is None:
# the selected async mode does not support websocket
return self.server._bad_request()
ws = self.server._async['websocket'](self._websocket_handler)
return await ws(environ)
async def _websocket_handler(self, ws):
"""Engine.IO handler for websocket transport."""
if self.connected:
# the socket was already connected, so this is an upgrade
self.upgrading = True # hold packet sends during the upgrade
try:
pkt = await ws.wait()
except IOError: # pragma: no cover
return
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.PING or \
decoded_pkt.data != 'probe':
self.server.logger.info(
'%s: Failed websocket upgrade, no PING packet', self.sid)
return
await ws.send(packet.Packet(
packet.PONG,
data=six.text_type('probe')).encode(always_bytes=False))
await self.queue.put(packet.Packet(packet.NOOP)) # end poll
try:
pkt = await ws.wait()
except IOError: # pragma: no cover
return
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.UPGRADE:
self.upgraded = False
self.server.logger.info(
('%s: Failed websocket upgrade, expected UPGRADE packet, '
'received %s instead.'),
self.sid, pkt)
return
self.upgraded = True
# flush any packets that were sent during the upgrade
for pkt in self.packet_backlog:
await self.queue.put(pkt)
self.packet_backlog = []
self.upgrading = False
else:
self.connected = True
self.upgraded = True
# start separate writer thread
async def writer():
while True:
packets = None
try:
packets = await self.poll()
except exceptions.QueueEmpty:
break
if not packets:
# empty packet list returned -> connection closed
break
try:
for pkt in packets:
await ws.send(pkt.encode(always_bytes=False))
except:
break
writer_task = asyncio.ensure_future(writer())
self.server.logger.info(
'%s: Upgrade to websocket successful', self.sid)
while True:
p = None
wait_task = asyncio.ensure_future(ws.wait())
try:
p = await asyncio.wait_for(wait_task, self.server.ping_timeout)
except asyncio.CancelledError: # pragma: no cover
# there is a bug (https://bugs.python.org/issue30508) in
# asyncio that causes a "Task exception never retrieved" error
# to appear when wait_task raises an exception before it gets
# cancelled. Calling wait_task.exception() prevents the error
# from being issued in Python 3.6, but causes other errors in
# other versions, so we run it with all errors suppressed and
# hope for the best.
try:
wait_task.exception()
except:
pass
break
except:
break
if p is None:
# connection closed by client
break
if isinstance(p, six.text_type): # pragma: no cover
p = p.encode('utf-8')
pkt = packet.Packet(encoded_packet=p)
try:
await self.receive(pkt)
except exceptions.UnknownPacketError: # pragma: no cover
pass
except exceptions.SocketIsClosedError: # pragma: no cover
self.server.logger.info('Receive error -- socket is closed')
break
except: # pragma: no cover
# if we get an unexpected exception we log the error and exit
# the connection properly
self.server.logger.exception('Unknown receive error')
await self.queue.put(None) # unlock the writer task so it can exit
await asyncio.wait_for(writer_task, timeout=None)
await self.close(wait=False, abort=True)
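# Illustrative note (not part of the vendored source): the upgrade branch of
# _websocket_handler above implements the Engine.IO probe sequence:
#
#     client -> PING 'probe'   (sent over the new websocket connection)
#     server -> PONG 'probe'
#     server -> NOOP           (queued so the pending long-polling GET returns)
#     client -> UPGRADE        (polling is abandoned, websocket takes over)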

View file

@ -0,0 +1,641 @@
import logging
try:
import queue
except ImportError: # pragma: no cover
import Queue as queue
import signal
import threading
import time
import six
from six.moves import urllib
try:
import requests
except ImportError: # pragma: no cover
requests = None
try:
import websocket
except ImportError: # pragma: no cover
websocket = None
from . import exceptions
from . import packet
from . import payload
default_logger = logging.getLogger('engineio.client')
connected_clients = []
if six.PY2: # pragma: no cover
ConnectionError = OSError
def signal_handler(sig, frame):
"""SIGINT handler.
Disconnect all active clients and then invoke the original signal handler.
"""
for client in connected_clients[:]:
if client.is_asyncio_based():
client.start_background_task(client.disconnect, abort=True)
else:
client.disconnect(abort=True)
return original_signal_handler(sig, frame)
original_signal_handler = signal.signal(signal.SIGINT, signal_handler)
class Client(object):
"""An Engine.IO client.
This class implements a fully compliant Engine.IO web client with support
for websocket and long-polling transports.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
"""
event_names = ['connect', 'disconnect', 'message']
def __init__(self, logger=False, json=None):
self.handlers = {}
self.base_url = None
self.transports = None
self.current_transport = None
self.sid = None
self.upgrades = None
self.ping_interval = None
self.ping_timeout = None
self.pong_received = True
self.http = None
self.ws = None
self.read_loop_task = None
self.write_loop_task = None
self.ping_loop_task = None
self.ping_loop_event = self.create_event()
self.queue = None
self.state = 'disconnected'
if json is not None:
packet.Packet.json = json
if not isinstance(logger, bool):
self.logger = logger
else:
self.logger = default_logger
if not logging.root.handlers and \
self.logger.level == logging.NOTSET:
if logger:
self.logger.setLevel(logging.INFO)
else:
self.logger.setLevel(logging.ERROR)
self.logger.addHandler(logging.StreamHandler())
def is_asyncio_based(self):
return False
def on(self, event, handler=None):
"""Register an event handler.
:param event: The event name. Can be ``'connect'``, ``'message'`` or
``'disconnect'``.
:param handler: The function that should be invoked to handle the
event. When this parameter is not given, the method
acts as a decorator for the handler function.
Example usage::
# as a decorator:
@eio.on('connect')
def connect_handler():
print('Connection request')
# as a method:
def message_handler(msg):
print('Received message: ', msg)
eio.send('response')
eio.on('message', message_handler)
"""
if event not in self.event_names:
raise ValueError('Invalid event')
def set_handler(handler):
self.handlers[event] = handler
return handler
if handler is None:
return set_handler
set_handler(handler)
def connect(self, url, headers={}, transports=None,
engineio_path='engine.io'):
"""Connect to an Engine.IO server.
:param url: The URL of the Engine.IO server. It can include custom
query string parameters if required by the server.
:param headers: A dictionary with custom headers to send with the
connection request.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
then an upgrade to websocket is attempted.
:param engineio_path: The endpoint where the Engine.IO server is
installed. The default value is appropriate for
most cases.
Example usage::
eio = engineio.Client()
eio.connect('http://localhost:5000')
"""
if self.state != 'disconnected':
raise ValueError('Client is not in a disconnected state')
valid_transports = ['polling', 'websocket']
if transports is not None:
if isinstance(transports, six.string_types):
transports = [transports]
transports = [transport for transport in transports
if transport in valid_transports]
if not transports:
raise ValueError('No valid transports provided')
self.transports = transports or valid_transports
self.queue = self.create_queue()
return getattr(self, '_connect_' + self.transports[0])(
url, headers, engineio_path)
def wait(self):
"""Wait until the connection with the server ends.
Client applications can use this function to block the main thread
during the life of the connection.
"""
if self.read_loop_task:
self.read_loop_task.join()
def send(self, data, binary=None):
"""Send a message to a client.
:param data: The data to send to the client. Data can be of type
``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
or ``dict``, the data will be serialized as JSON.
:param binary: ``True`` to send packet as binary, ``False`` to send
as text. If not given, unicode (Python 2) and str
(Python 3) are sent as text, and str (Python 2) and
bytes (Python 3) are sent as binary.
"""
self._send_packet(packet.Packet(packet.MESSAGE, data=data,
binary=binary))
def disconnect(self, abort=False):
"""Disconnect from the server.
:param abort: If set to ``True``, do not wait for background tasks
associated with the connection to end.
"""
if self.state == 'connected':
self._send_packet(packet.Packet(packet.CLOSE))
self.queue.put(None)
self.state = 'disconnecting'
self._trigger_event('disconnect', run_async=False)
if self.current_transport == 'websocket':
self.ws.close()
if not abort:
self.read_loop_task.join()
self.state = 'disconnected'
try:
connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
def transport(self):
"""Return the name of the transport currently in use.
The possible values returned by this function are ``'polling'`` and
``'websocket'``.
"""
return self.current_transport
def start_background_task(self, target, *args, **kwargs):
"""Start a background task.
This is a utility function that applications can use to start a
background task.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
"""
th = threading.Thread(target=target, args=args, kwargs=kwargs)
th.start()
return th
def sleep(self, seconds=0):
"""Sleep for the requested amount of time."""
return time.sleep(seconds)
def create_queue(self, *args, **kwargs):
"""Create a queue object."""
q = queue.Queue(*args, **kwargs)
q.Empty = queue.Empty
return q
def create_event(self, *args, **kwargs):
"""Create an event object."""
return threading.Event(*args, **kwargs)
def _reset(self):
self.state = 'disconnected'
self.sid = None
def _connect_polling(self, url, headers, engineio_path):
"""Establish a long-polling connection to the Engine.IO server."""
if requests is None: # pragma: no cover
# not installed
self.logger.error('requests package is not installed -- cannot '
'send HTTP requests!')
return
self.base_url = self._get_engineio_url(url, engineio_path, 'polling')
self.logger.info('Attempting polling connection to ' + self.base_url)
r = self._send_request(
'GET', self.base_url + self._get_url_timestamp(), headers=headers)
if r is None:
self._reset()
raise exceptions.ConnectionError(
'Connection refused by the server')
if r.status_code != 200:
raise exceptions.ConnectionError(
'Unexpected status code {} in server response'.format(
r.status_code))
try:
p = payload.Payload(encoded_payload=r.content)
except ValueError:
six.raise_from(exceptions.ConnectionError(
'Unexpected response from server'), None)
open_packet = p.packets[0]
if open_packet.packet_type != packet.OPEN:
raise exceptions.ConnectionError(
'OPEN packet not returned by server')
self.logger.info(
'Polling connection accepted with ' + str(open_packet.data))
self.sid = open_packet.data['sid']
self.upgrades = open_packet.data['upgrades']
self.ping_interval = open_packet.data['pingInterval'] / 1000.0
self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0
self.current_transport = 'polling'
self.base_url += '&sid=' + self.sid
self.state = 'connected'
connected_clients.append(self)
self._trigger_event('connect', run_async=False)
for pkt in p.packets[1:]:
self._receive_packet(pkt)
if 'websocket' in self.upgrades and 'websocket' in self.transports:
# attempt to upgrade to websocket
if self._connect_websocket(url, headers, engineio_path):
# upgrade to websocket succeeded, we're done here
return
# start background tasks associated with this client
self.ping_loop_task = self.start_background_task(self._ping_loop)
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_polling)
def _connect_websocket(self, url, headers, engineio_path):
"""Establish or upgrade to a WebSocket connection with the server."""
if websocket is None: # pragma: no cover
# not installed
self.logger.warning('websocket-client package not installed, only '
'polling transport is available')
return False
websocket_url = self._get_engineio_url(url, engineio_path, 'websocket')
if self.sid:
self.logger.info(
'Attempting WebSocket upgrade to ' + websocket_url)
upgrade = True
websocket_url += '&sid=' + self.sid
else:
upgrade = False
self.base_url = websocket_url
self.logger.info(
'Attempting WebSocket connection to ' + websocket_url)
# get the cookies from the long-polling connection so that they can
        # also be sent to the WebSocket route
cookies = None
if self.http:
cookies = '; '.join(["{}={}".format(cookie.name, cookie.value)
for cookie in self.http.cookies])
try:
ws = websocket.create_connection(
websocket_url + self._get_url_timestamp(), header=headers,
cookie=cookies)
except ConnectionError:
if upgrade:
self.logger.warning(
'WebSocket upgrade failed: connection error')
return False
else:
raise exceptions.ConnectionError('Connection error')
if upgrade:
p = packet.Packet(packet.PING,
data=six.text_type('probe')).encode()
try:
ws.send(p)
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected send exception: %s',
str(e))
return False
try:
p = ws.recv()
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected recv exception: %s',
str(e))
return False
pkt = packet.Packet(encoded_packet=p)
if pkt.packet_type != packet.PONG or pkt.data != 'probe':
self.logger.warning(
'WebSocket upgrade failed: no PONG packet')
return False
p = packet.Packet(packet.UPGRADE).encode()
try:
ws.send(p)
except Exception as e: # pragma: no cover
self.logger.warning(
'WebSocket upgrade failed: unexpected send exception: %s',
str(e))
return False
self.current_transport = 'websocket'
self.logger.info('WebSocket upgrade was successful')
else:
try:
p = ws.recv()
except Exception as e: # pragma: no cover
raise exceptions.ConnectionError(
'Unexpected recv exception: ' + str(e))
open_packet = packet.Packet(encoded_packet=p)
if open_packet.packet_type != packet.OPEN:
raise exceptions.ConnectionError('no OPEN packet')
self.logger.info(
'WebSocket connection accepted with ' + str(open_packet.data))
self.sid = open_packet.data['sid']
self.upgrades = open_packet.data['upgrades']
self.ping_interval = open_packet.data['pingInterval'] / 1000.0
self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0
self.current_transport = 'websocket'
self.state = 'connected'
connected_clients.append(self)
self._trigger_event('connect', run_async=False)
self.ws = ws
# start background tasks associated with this client
self.ping_loop_task = self.start_background_task(self._ping_loop)
self.write_loop_task = self.start_background_task(self._write_loop)
self.read_loop_task = self.start_background_task(
self._read_loop_websocket)
return True
def _receive_packet(self, pkt):
"""Handle incoming packets from the server."""
packet_name = packet.packet_names[pkt.packet_type] \
if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
self.logger.info(
'Received packet %s data %s', packet_name,
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
if pkt.packet_type == packet.MESSAGE:
self._trigger_event('message', pkt.data, run_async=True)
elif pkt.packet_type == packet.PONG:
self.pong_received = True
elif pkt.packet_type == packet.CLOSE:
self.disconnect(abort=True)
elif pkt.packet_type == packet.NOOP:
pass
else:
self.logger.error('Received unexpected packet of type %s',
pkt.packet_type)
def _send_packet(self, pkt):
"""Queue a packet to be sent to the server."""
if self.state != 'connected':
return
self.queue.put(pkt)
self.logger.info(
'Sending packet %s data %s',
packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes) else '<binary>')
def _send_request(
self, method, url, headers=None, body=None): # pragma: no cover
if self.http is None:
self.http = requests.Session()
try:
return self.http.request(method, url, headers=headers, data=body)
except requests.exceptions.RequestException:
pass
def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
if event in self.handlers:
if run_async:
return self.start_background_task(self.handlers[event], *args)
else:
try:
return self.handlers[event](*args)
except:
self.logger.exception(event + ' handler error')
def _get_engineio_url(self, url, engineio_path, transport):
"""Generate the Engine.IO connection URL."""
engineio_path = engineio_path.strip('/')
parsed_url = urllib.parse.urlparse(url)
if transport == 'polling':
scheme = 'http'
elif transport == 'websocket':
scheme = 'ws'
else: # pragma: no cover
raise ValueError('invalid transport')
if parsed_url.scheme in ['https', 'wss']:
scheme += 's'
return ('{scheme}://{netloc}/{path}/?{query}'
'{sep}transport={transport}&EIO=3').format(
scheme=scheme, netloc=parsed_url.netloc,
path=engineio_path, query=parsed_url.query,
sep='&' if parsed_url.query else '',
transport=transport)
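    # Illustrative note (not part of the vendored source): for a call such as
    # _get_engineio_url('https://example.com', 'engine.io', 'websocket') the
    # method above returns
    # 'wss://example.com/engine.io/?transport=websocket&EIO=3'.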
def _get_url_timestamp(self):
"""Generate the Engine.IO query string timestamp."""
return '&t=' + str(time.time())
def _ping_loop(self):
"""This background task sends a PING to the server at the requested
interval.
"""
self.pong_received = True
self.ping_loop_event.clear()
while self.state == 'connected':
if not self.pong_received:
self.logger.info(
'PONG response has not been received, aborting')
if self.ws:
self.ws.close()
self.queue.put(None)
break
self.pong_received = False
self._send_packet(packet.Packet(packet.PING))
self.ping_loop_event.wait(timeout=self.ping_interval)
self.logger.info('Exiting ping task')
def _read_loop_polling(self):
"""Read packets by polling the Engine.IO server."""
while self.state == 'connected':
self.logger.info(
'Sending polling GET request to ' + self.base_url)
r = self._send_request(
'GET', self.base_url + self._get_url_timestamp())
if r is None:
self.logger.warning(
'Connection refused by the server, aborting')
self.queue.put(None)
break
if r.status_code != 200:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status_code)
self.queue.put(None)
break
try:
p = payload.Payload(encoded_payload=r.content)
except ValueError:
self.logger.warning(
'Unexpected packet from server, aborting')
self.queue.put(None)
break
for pkt in p.packets:
self._receive_packet(pkt)
self.logger.info('Waiting for write loop task to end')
self.write_loop_task.join()
self.logger.info('Waiting for ping loop task to end')
self.ping_loop_event.set()
self.ping_loop_task.join()
if self.state == 'connected':
self._trigger_event('disconnect', run_async=False)
try:
connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')
def _read_loop_websocket(self):
"""Read packets from the Engine.IO WebSocket connection."""
while self.state == 'connected':
p = None
try:
p = self.ws.recv()
except websocket.WebSocketConnectionClosedException:
self.logger.warning(
'WebSocket connection was closed, aborting')
self.queue.put(None)
break
except Exception as e:
self.logger.info(
'Unexpected error "%s", aborting', str(e))
self.queue.put(None)
break
if isinstance(p, six.text_type): # pragma: no cover
p = p.encode('utf-8')
pkt = packet.Packet(encoded_packet=p)
self._receive_packet(pkt)
self.logger.info('Waiting for write loop task to end')
self.write_loop_task.join()
self.logger.info('Waiting for ping loop task to end')
self.ping_loop_event.set()
self.ping_loop_task.join()
if self.state == 'connected':
self._trigger_event('disconnect', run_async=False)
try:
connected_clients.remove(self)
except ValueError: # pragma: no cover
pass
self._reset()
self.logger.info('Exiting read loop task')
def _write_loop(self):
"""This background task sends packages to the server as they are
pushed to the send queue.
"""
while self.state == 'connected':
# to simplify the timeout handling, use the maximum of the
# ping interval and ping timeout as timeout, with an extra 5
# seconds grace period
timeout = max(self.ping_interval, self.ping_timeout) + 5
packets = None
try:
packets = [self.queue.get(timeout=timeout)]
except self.queue.Empty:
self.logger.error('packet queue is empty, aborting')
break
if packets == [None]:
self.queue.task_done()
packets = []
else:
while True:
try:
packets.append(self.queue.get(block=False))
except self.queue.Empty:
break
if packets[-1] is None:
packets = packets[:-1]
self.queue.task_done()
break
if not packets:
# empty packet list returned -> connection closed
break
if self.current_transport == 'polling':
p = payload.Payload(packets=packets)
r = self._send_request(
'POST', self.base_url, body=p.encode(),
headers={'Content-Type': 'application/octet-stream'})
for pkt in packets:
self.queue.task_done()
if r is None:
self.logger.warning(
'Connection refused by the server, aborting')
break
if r.status_code != 200:
self.logger.warning('Unexpected status code %s in server '
'response, aborting', r.status_code)
self._reset()
break
else:
# websocket
try:
for pkt in packets:
encoded_packet = pkt.encode(always_bytes=False)
if pkt.binary:
self.ws.send_binary(encoded_packet)
else:
self.ws.send(encoded_packet)
self.queue.task_done()
except websocket.WebSocketConnectionClosedException:
self.logger.warning(
'WebSocket connection was closed, aborting')
break
self.logger.info('Exiting write loop task')
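# Illustrative usage sketch (not part of the vendored source), assuming an
# Engine.IO server is listening on localhost:5000:
#
#     import engineio
#
#     eio = engineio.Client()
#
#     @eio.on('message')
#     def on_message(data):
#         print('received:', data)
#
#     eio.connect('http://localhost:5000')
#     eio.send('hello')
#     eio.wait()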

View file

@ -0,0 +1,22 @@
class EngineIOError(Exception):
pass
class ContentTooLongError(EngineIOError):
pass
class UnknownPacketError(EngineIOError):
pass
class QueueEmpty(EngineIOError):
pass
class SocketIsClosedError(EngineIOError):
pass
class ConnectionError(EngineIOError):
pass

View file

@ -0,0 +1,87 @@
import os
from engineio.static_files import get_static_file
class WSGIApp(object):
"""WSGI application middleware for Engine.IO.
This middleware dispatches traffic to an Engine.IO application. It can
also serve a list of static files to the client, or forward unrelated
HTTP traffic to another WSGI application.
:param engineio_app: The Engine.IO server. Must be an instance of the
``engineio.Server`` class.
:param wsgi_app: The WSGI app that receives all other traffic.
:param static_files: A dictionary with static file mapping rules. See the
documentation for details on this argument.
:param engineio_path: The endpoint where the Engine.IO application should
be installed. The default value is appropriate for
most cases.
Example usage::
import engineio
import eventlet
eio = engineio.Server()
app = engineio.WSGIApp(eio, static_files={
'/': {'content_type': 'text/html', 'filename': 'index.html'},
'/index.html': {'content_type': 'text/html',
'filename': 'index.html'},
})
eventlet.wsgi.server(eventlet.listen(('', 8000)), app)
"""
def __init__(self, engineio_app, wsgi_app=None, static_files=None,
engineio_path='engine.io'):
self.engineio_app = engineio_app
self.wsgi_app = wsgi_app
self.engineio_path = engineio_path.strip('/')
self.static_files = static_files or {}
def __call__(self, environ, start_response):
if 'gunicorn.socket' in environ:
# gunicorn saves the socket under environ['gunicorn.socket'], while
# eventlet saves it under environ['eventlet.input']. Eventlet also
            # stores the socket inside a wrapper class, while gunicorn writes it
# directly into the environment. To give eventlet's WebSocket
# module access to this socket when running under gunicorn, here we
# copy the socket to the eventlet format.
class Input(object):
def __init__(self, socket):
self.socket = socket
def get_socket(self):
return self.socket
environ['eventlet.input'] = Input(environ['gunicorn.socket'])
path = environ['PATH_INFO']
if path is not None and \
path.startswith('/{0}/'.format(self.engineio_path)):
return self.engineio_app.handle_request(environ, start_response)
else:
static_file = get_static_file(path, self.static_files) \
if self.static_files else None
if static_file:
if os.path.exists(static_file['filename']):
start_response(
'200 OK',
[('Content-Type', static_file['content_type'])])
with open(static_file['filename'], 'rb') as f:
return [f.read()]
else:
return self.not_found(start_response)
elif self.wsgi_app is not None:
return self.wsgi_app(environ, start_response)
return self.not_found(start_response)
def not_found(self, start_response):
start_response("404 Not Found", [('Content-Type', 'text/plain')])
return [b'Not Found']
class Middleware(WSGIApp):
"""This class has been renamed to ``WSGIApp`` and is now deprecated."""
def __init__(self, engineio_app, wsgi_app=None,
engineio_path='engine.io'):
super(Middleware, self).__init__(engineio_app, wsgi_app,
engineio_path=engineio_path)

View file

@ -0,0 +1,92 @@
import base64
import json as _json
import six
(OPEN, CLOSE, PING, PONG, MESSAGE, UPGRADE, NOOP) = (0, 1, 2, 3, 4, 5, 6)
packet_names = ['OPEN', 'CLOSE', 'PING', 'PONG', 'MESSAGE', 'UPGRADE', 'NOOP']
binary_types = (six.binary_type, bytearray)
class Packet(object):
"""Engine.IO packet."""
json = _json
def __init__(self, packet_type=NOOP, data=None, binary=None,
encoded_packet=None):
self.packet_type = packet_type
self.data = data
if binary is not None:
self.binary = binary
elif isinstance(data, six.text_type):
self.binary = False
elif isinstance(data, binary_types):
self.binary = True
else:
self.binary = False
if encoded_packet:
self.decode(encoded_packet)
def encode(self, b64=False, always_bytes=True):
"""Encode the packet for transmission."""
if self.binary and not b64:
encoded_packet = six.int2byte(self.packet_type)
else:
encoded_packet = six.text_type(self.packet_type)
if self.binary and b64:
encoded_packet = 'b' + encoded_packet
if self.binary:
if b64:
encoded_packet += base64.b64encode(self.data).decode('utf-8')
else:
encoded_packet += self.data
elif isinstance(self.data, six.string_types):
encoded_packet += self.data
elif isinstance(self.data, dict) or isinstance(self.data, list):
encoded_packet += self.json.dumps(self.data,
separators=(',', ':'))
elif self.data is not None:
encoded_packet += str(self.data)
if always_bytes and not isinstance(encoded_packet, binary_types):
encoded_packet = encoded_packet.encode('utf-8')
return encoded_packet
def decode(self, encoded_packet):
"""Decode a transmitted package."""
b64 = False
if not isinstance(encoded_packet, binary_types):
encoded_packet = encoded_packet.encode('utf-8')
elif not isinstance(encoded_packet, bytes):
encoded_packet = bytes(encoded_packet)
self.packet_type = six.byte2int(encoded_packet[0:1])
if self.packet_type == 98: # 'b' --> binary base64 encoded packet
self.binary = True
encoded_packet = encoded_packet[1:]
self.packet_type = six.byte2int(encoded_packet[0:1])
self.packet_type -= 48
b64 = True
elif self.packet_type >= 48:
self.packet_type -= 48
self.binary = False
else:
self.binary = True
self.data = None
if len(encoded_packet) > 1:
if self.binary:
if b64:
self.data = base64.b64decode(encoded_packet[1:])
else:
self.data = encoded_packet[1:]
else:
try:
self.data = self.json.loads(
encoded_packet[1:].decode('utf-8'))
if isinstance(self.data, int):
# do not allow integer payloads, see
# github.com/miguelgrinberg/python-engineio/issues/75
# for background on this decision
raise ValueError
except ValueError:
self.data = encoded_packet[1:].decode('utf-8')
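# Illustrative round trip (not part of the vendored source), following the
# encode()/decode() logic above:
#
#     Packet(MESSAGE, data='hello').encode()      # -> b'4hello'
#     Packet(encoded_packet=b'4hello').data       # -> 'hello'
#
#     Packet(MESSAGE, data=b'\x01\x02').encode()          # -> b'\x04\x01\x02'
#     Packet(MESSAGE, data=b'\x01\x02').encode(b64=True)  # -> b'b4AQI='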

View file

@ -0,0 +1,80 @@
import six
from . import packet
from six.moves import urllib
class Payload(object):
"""Engine.IO payload."""
def __init__(self, packets=None, encoded_payload=None):
self.packets = packets or []
if encoded_payload is not None:
self.decode(encoded_payload)
def encode(self, b64=False, jsonp_index=None):
"""Encode the payload for transmission."""
encoded_payload = b''
for pkt in self.packets:
encoded_packet = pkt.encode(b64=b64)
packet_len = len(encoded_packet)
if b64:
encoded_payload += str(packet_len).encode('utf-8') + b':' + \
encoded_packet
else:
binary_len = b''
while packet_len != 0:
binary_len = six.int2byte(packet_len % 10) + binary_len
packet_len = int(packet_len / 10)
if not pkt.binary:
encoded_payload += b'\0'
else:
encoded_payload += b'\1'
encoded_payload += binary_len + b'\xff' + encoded_packet
if jsonp_index is not None:
encoded_payload = b'___eio[' + \
str(jsonp_index).encode() + \
b']("' + \
encoded_payload.replace(b'"', b'\\"') + \
b'");'
return encoded_payload
def decode(self, encoded_payload):
"""Decode a transmitted payload."""
self.packets = []
while encoded_payload:
# JSONP POST payload starts with 'd='
if encoded_payload.startswith(b'd='):
encoded_payload = urllib.parse.parse_qs(
encoded_payload)[b'd'][0]
if six.byte2int(encoded_payload[0:1]) <= 1:
packet_len = 0
i = 1
while six.byte2int(encoded_payload[i:i + 1]) != 255:
packet_len = packet_len * 10 + six.byte2int(
encoded_payload[i:i + 1])
i += 1
self.packets.append(packet.Packet(
encoded_packet=encoded_payload[i + 1:i + 1 + packet_len]))
else:
i = encoded_payload.find(b':')
if i == -1:
raise ValueError('invalid payload')
# extracting the packet out of the payload is extremely
# inefficient, because the payload needs to be treated as
# binary, but the non-binary packets have to be parsed as
# unicode. Luckily this complication only applies to long
# polling, as the websocket transport sends packets
# individually wrapped.
packet_len = int(encoded_payload[0:i])
pkt = encoded_payload.decode('utf-8', errors='ignore')[
i + 1: i + 1 + packet_len].encode('utf-8')
self.packets.append(packet.Packet(encoded_packet=pkt))
# the engine.io protocol sends the packet length in
# utf-8 characters, but we need it in bytes to be able to
# jump to the next packet in the payload
packet_len = len(pkt)
encoded_payload = encoded_payload[i + 1 + packet_len:]
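# Illustrative note (not part of the vendored source): with the framing
# implemented above, a payload holding a single text packet encodes as
#
#     Payload(packets=[packet.Packet(packet.MESSAGE, data='hi')]).encode()
#     # -> b'\x00\x03\xff4hi'
#
# i.e. a 0/1 text/binary marker, the packet length as one byte per decimal
# digit, a 0xff separator, and then the encoded packet itself.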

View file

@ -0,0 +1,633 @@
import gzip
import importlib
import logging
import uuid
import zlib
import six
from six.moves import urllib
from . import exceptions
from . import packet
from . import payload
from . import socket
default_logger = logging.getLogger('engineio.server')
class Server(object):
"""An Engine.IO server.
This class implements a fully compliant Engine.IO web server with support
for websocket and long-polling transports.
:param async_mode: The asynchronous model to use. See the Deployment
section in the documentation for a description of the
available options. Valid async modes are "threading",
"eventlet", "gevent" and "gevent_uwsgi". If this
argument is not given, "eventlet" is tried first, then
"gevent_uwsgi", then "gevent", and finally "threading".
The first async mode that has all its dependencies
installed is the one that is chosen.
:param ping_timeout: The time in seconds that the client waits for the
server to respond before disconnecting. The default
is 60 seconds.
:param ping_interval: The interval in seconds at which the client pings
the server. The default is 25 seconds.
:param max_http_buffer_size: The maximum size of a message when using the
polling transport. The default is 100,000,000
bytes.
:param allow_upgrades: Whether to allow transport upgrades or not. The
default is ``True``.
:param http_compression: Whether to compress packets when using the
polling transport. The default is ``True``.
:param compression_threshold: Only compress messages when their byte size
is greater than this value. The default is
1024 bytes.
:param cookie: Name of the HTTP cookie that contains the client session
id. If set to ``None``, a cookie is not sent to the client.
The default is ``'io'``.
:param cors_allowed_origins: Origin or list of origins that are allowed to
connect to this server. All origins are
allowed by default, which is equivalent to
setting this argument to ``'*'``.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server. The default
is ``True``.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param async_handlers: If set to ``True``, run message event handlers in
non-blocking threads. To run handlers synchronously,
set to ``False``. The default is ``True``.
:param monitor_clients: If set to ``True``, a background task will ensure
inactive clients are closed. Set to ``False`` to
disable the monitoring task (not recommended). The
default is ``True``.
:param kwargs: Reserved for future extensions, any additional parameters
given as keyword arguments will be silently ignored.
"""
compression_methods = ['gzip', 'deflate']
event_names = ['connect', 'disconnect', 'message']
_default_monitor_clients = True
def __init__(self, async_mode=None, ping_timeout=60, ping_interval=25,
max_http_buffer_size=100000000, allow_upgrades=True,
http_compression=True, compression_threshold=1024,
cookie='io', cors_allowed_origins=None,
cors_credentials=True, logger=False, json=None,
async_handlers=True, monitor_clients=None, **kwargs):
self.ping_timeout = ping_timeout
self.ping_interval = ping_interval
self.max_http_buffer_size = max_http_buffer_size
self.allow_upgrades = allow_upgrades
self.http_compression = http_compression
self.compression_threshold = compression_threshold
self.cookie = cookie
self.cors_allowed_origins = cors_allowed_origins
self.cors_credentials = cors_credentials
self.async_handlers = async_handlers
self.sockets = {}
self.handlers = {}
self.start_service_task = monitor_clients \
if monitor_clients is not None else self._default_monitor_clients
if json is not None:
packet.Packet.json = json
if not isinstance(logger, bool):
self.logger = logger
else:
self.logger = default_logger
if not logging.root.handlers and \
self.logger.level == logging.NOTSET:
if logger:
self.logger.setLevel(logging.INFO)
else:
self.logger.setLevel(logging.ERROR)
self.logger.addHandler(logging.StreamHandler())
modes = self.async_modes()
if async_mode is not None:
modes = [async_mode] if async_mode in modes else []
self._async = None
self.async_mode = None
for mode in modes:
try:
self._async = importlib.import_module(
'engineio.async_drivers.' + mode)._async
asyncio_based = self._async['asyncio'] \
if 'asyncio' in self._async else False
if asyncio_based != self.is_asyncio_based():
continue # pragma: no cover
self.async_mode = mode
break
except ImportError:
pass
if self.async_mode is None:
raise ValueError('Invalid async_mode specified')
if self.is_asyncio_based() and \
('asyncio' not in self._async or not
self._async['asyncio']): # pragma: no cover
raise ValueError('The selected async_mode is not asyncio '
'compatible')
if not self.is_asyncio_based() and 'asyncio' in self._async and \
self._async['asyncio']: # pragma: no cover
raise ValueError('The selected async_mode requires asyncio and '
'must use the AsyncServer class')
self.logger.info('Server initialized for %s.', self.async_mode)
def is_asyncio_based(self):
return False
def async_modes(self):
return ['eventlet', 'gevent_uwsgi', 'gevent', 'threading']
def on(self, event, handler=None):
"""Register an event handler.
:param event: The event name. Can be ``'connect'``, ``'message'`` or
``'disconnect'``.
:param handler: The function that should be invoked to handle the
event. When this parameter is not given, the method
acts as a decorator for the handler function.
Example usage::
# as a decorator:
@eio.on('connect')
def connect_handler(sid, environ):
print('Connection request')
if environ['REMOTE_ADDR'] in blacklisted:
return False # reject
# as a method:
def message_handler(sid, msg):
print('Received message: ', msg)
eio.send(sid, 'response')
eio.on('message', message_handler)
The handler function receives the ``sid`` (session ID) for the
client as first argument. The ``'connect'`` event handler receives the
WSGI environment as a second argument, and can return ``False`` to
reject the connection. The ``'message'`` handler receives the message
payload as a second argument. The ``'disconnect'`` handler does not
take a second argument.
"""
if event not in self.event_names:
raise ValueError('Invalid event')
def set_handler(handler):
self.handlers[event] = handler
return handler
if handler is None:
return set_handler
set_handler(handler)
def send(self, sid, data, binary=None):
"""Send a message to a client.
:param sid: The session id of the recipient client.
:param data: The data to send to the client. Data can be of type
``str``, ``bytes``, ``list`` or ``dict``. If a ``list``
or ``dict``, the data will be serialized as JSON.
:param binary: ``True`` to send packet as binary, ``False`` to send
as text. If not given, unicode (Python 2) and str
(Python 3) are sent as text, and str (Python 2) and
bytes (Python 3) are sent as binary.
"""
try:
socket = self._get_socket(sid)
except KeyError:
# the socket is not available
self.logger.warning('Cannot send to sid %s', sid)
return
socket.send(packet.Packet(packet.MESSAGE, data=data, binary=binary))
def get_session(self, sid):
"""Return the user session for a client.
:param sid: The session id of the client.
The return value is a dictionary. Modifications made to this
dictionary are not guaranteed to be preserved unless
``save_session()`` is called, or when the ``session`` context manager
is used.
"""
socket = self._get_socket(sid)
return socket.session
def save_session(self, sid, session):
"""Store the user session for a client.
:param sid: The session id of the client.
:param session: The session dictionary.
"""
socket = self._get_socket(sid)
socket.session = session
def session(self, sid):
"""Return the user session for a client with context manager syntax.
:param sid: The session id of the client.
This is a context manager that returns the user session dictionary for
the client. Any changes that are made to this dictionary inside the
context manager block are saved back to the session. Example usage::
@eio.on('connect')
def on_connect(sid, environ):
username = authenticate_user(environ)
if not username:
return False
with eio.session(sid) as session:
session['username'] = username
@eio.on('message')
def on_message(sid, msg):
with eio.session(sid) as session:
print('received message from ', session['username'])
"""
class _session_context_manager(object):
def __init__(self, server, sid):
self.server = server
self.sid = sid
self.session = None
def __enter__(self):
self.session = self.server.get_session(sid)
return self.session
def __exit__(self, *args):
self.server.save_session(sid, self.session)
return _session_context_manager(self, sid)
def disconnect(self, sid=None):
"""Disconnect a client.
:param sid: The session id of the client to close. If this parameter
is not given, then all clients are closed.
"""
if sid is not None:
try:
socket = self._get_socket(sid)
except KeyError: # pragma: no cover
# the socket was already closed or gone
pass
else:
socket.close()
del self.sockets[sid]
else:
for client in six.itervalues(self.sockets):
client.close()
self.sockets = {}
def transport(self, sid):
"""Return the name of the transport used by the client.
The two possible values returned by this function are ``'polling'``
and ``'websocket'``.
:param sid: The session of the client.
"""
return 'websocket' if self._get_socket(sid).upgraded else 'polling'
def handle_request(self, environ, start_response):
"""Handle an HTTP request from the client.
This is the entry point of the Engine.IO application, using the same
interface as a WSGI application. For the typical usage, this function
is invoked by the :class:`Middleware` instance, but it can be invoked
directly when the middleware is not used.
:param environ: The WSGI environment.
:param start_response: The WSGI ``start_response`` function.
This function returns the HTTP response body to deliver to the client
as a byte sequence.
"""
method = environ['REQUEST_METHOD']
query = urllib.parse.parse_qs(environ.get('QUERY_STRING', ''))
sid = query['sid'][0] if 'sid' in query else None
b64 = False
jsonp = False
jsonp_index = None
if 'b64' in query:
if query['b64'][0] == "1" or query['b64'][0].lower() == "true":
b64 = True
if 'j' in query:
jsonp = True
try:
jsonp_index = int(query['j'][0])
except (ValueError, KeyError, IndexError):
# Invalid JSONP index number
pass
if jsonp and jsonp_index is None:
self.logger.warning('Invalid JSONP index number')
r = self._bad_request()
elif method == 'GET':
if sid is None:
transport = query.get('transport', ['polling'])[0]
if transport != 'polling' and transport != 'websocket':
self.logger.warning('Invalid transport %s', transport)
r = self._bad_request()
else:
r = self._handle_connect(environ, start_response,
transport, b64, jsonp_index)
else:
if sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
socket = self._get_socket(sid)
try:
packets = socket.handle_get_request(
environ, start_response)
if isinstance(packets, list):
r = self._ok(packets, b64=b64,
jsonp_index=jsonp_index)
else:
r = packets
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
r = self._bad_request()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
elif method == 'POST':
if sid is None or sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
r = self._bad_request()
else:
socket = self._get_socket(sid)
try:
socket.handle_post_request(environ)
r = self._ok(jsonp_index=jsonp_index)
except exceptions.EngineIOError:
if sid in self.sockets: # pragma: no cover
self.disconnect(sid)
r = self._bad_request()
except: # pragma: no cover
# for any other unexpected errors, we log the error
# and keep going
self.logger.exception('post request handler error')
r = self._ok(jsonp_index=jsonp_index)
elif method == 'OPTIONS':
r = self._ok()
else:
self.logger.warning('Method %s not supported', method)
r = self._method_not_found()
if not isinstance(r, dict):
return r or []
if self.http_compression and \
len(r['response']) >= self.compression_threshold:
encodings = [e.split(';')[0].strip() for e in
environ.get('HTTP_ACCEPT_ENCODING', '').split(',')]
for encoding in encodings:
if encoding in self.compression_methods:
r['response'] = \
getattr(self, '_' + encoding)(r['response'])
r['headers'] += [('Content-Encoding', encoding)]
break
cors_headers = self._cors_headers(environ)
start_response(r['status'], r['headers'] + cors_headers)
return [r['response']]
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
"""
th = self._async['thread'](target=target, args=args, kwargs=kwargs)
th.start()
return th # pragma: no cover
def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
"""
return self._async['sleep'](seconds)
def create_queue(self, *args, **kwargs):
"""Create a queue object using the appropriate async model.
This is a utility function that applications can use to create a queue
without having to worry about using the correct call for the selected
async mode.
"""
return self._async['queue'](*args, **kwargs)
def get_queue_empty_exception(self):
"""Return the queue empty exception for the appropriate async model.
This is a utility function that applications can use to work with a
queue without having to worry about using the correct call for the
selected async mode.
"""
return self._async['queue_empty']
def create_event(self, *args, **kwargs):
"""Create an event object using the appropriate async model.
This is a utility function that applications can use to create an
event without having to worry about using the correct call for the
selected async mode.
"""
return self._async['event'](*args, **kwargs)
def _generate_id(self):
"""Generate a unique session id."""
return uuid.uuid4().hex
def _handle_connect(self, environ, start_response, transport, b64=False,
jsonp_index=None):
"""Handle a client connection request."""
if self.start_service_task:
# start the service task to monitor connected clients
self.start_service_task = False
self.start_background_task(self._service_task)
sid = self._generate_id()
s = socket.Socket(self, sid)
self.sockets[sid] = s
pkt = packet.Packet(
packet.OPEN, {'sid': sid,
'upgrades': self._upgrades(sid, transport),
'pingTimeout': int(self.ping_timeout * 1000),
'pingInterval': int(self.ping_interval * 1000)})
s.send(pkt)
ret = self._trigger_event('connect', sid, environ, run_async=False)
if ret is False:
del self.sockets[sid]
self.logger.warning('Application rejected connection')
return self._unauthorized()
if transport == 'websocket':
ret = s.handle_get_request(environ, start_response)
if s.closed:
# websocket connection ended, so we are done
del self.sockets[sid]
return ret
else:
s.connected = True
headers = None
if self.cookie:
headers = [('Set-Cookie', self.cookie + '=' + sid)]
try:
return self._ok(s.poll(), headers=headers, b64=b64,
jsonp_index=jsonp_index)
except exceptions.QueueEmpty:
return self._bad_request()
def _upgrades(self, sid, transport):
"""Return the list of possible upgrades for a client connection."""
if not self.allow_upgrades or self._get_socket(sid).upgraded or \
self._async['websocket'] is None or transport == 'websocket':
return []
return ['websocket']
def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
run_async = kwargs.pop('run_async', False)
if event in self.handlers:
if run_async:
return self.start_background_task(self.handlers[event], *args)
else:
try:
return self.handlers[event](*args)
except:
self.logger.exception(event + ' handler error')
if event == 'connect':
# if connect handler raised error we reject the
# connection
return False
def _get_socket(self, sid):
"""Return the socket object for a given session."""
try:
s = self.sockets[sid]
except KeyError:
raise KeyError('Session not found')
if s.closed:
del self.sockets[sid]
raise KeyError('Session is disconnected')
return s
def _ok(self, packets=None, headers=None, b64=False, jsonp_index=None):
"""Generate a successful HTTP response."""
if packets is not None:
if headers is None:
headers = []
if b64:
headers += [('Content-Type', 'text/plain; charset=UTF-8')]
else:
headers += [('Content-Type', 'application/octet-stream')]
return {'status': '200 OK',
'headers': headers,
'response': payload.Payload(packets=packets).encode(
b64=b64, jsonp_index=jsonp_index)}
else:
return {'status': '200 OK',
'headers': [('Content-Type', 'text/plain')],
'response': b'OK'}
def _bad_request(self):
"""Generate a bad request HTTP error response."""
return {'status': '400 BAD REQUEST',
'headers': [('Content-Type', 'text/plain')],
'response': b'Bad Request'}
def _method_not_found(self):
"""Generate a method not found HTTP error response."""
return {'status': '405 METHOD NOT FOUND',
'headers': [('Content-Type', 'text/plain')],
'response': b'Method Not Found'}
def _unauthorized(self):
"""Generate a unauthorized HTTP error response."""
return {'status': '401 UNAUTHORIZED',
'headers': [('Content-Type', 'text/plain')],
'response': b'Unauthorized'}
def _cors_headers(self, environ):
"""Return the cross-origin-resource-sharing headers."""
if isinstance(self.cors_allowed_origins, six.string_types):
if self.cors_allowed_origins == '*':
allowed_origins = None
else:
allowed_origins = [self.cors_allowed_origins]
else:
allowed_origins = self.cors_allowed_origins
if allowed_origins is not None and \
environ.get('HTTP_ORIGIN', '') not in allowed_origins:
return []
if 'HTTP_ORIGIN' in environ:
headers = [('Access-Control-Allow-Origin', environ['HTTP_ORIGIN'])]
else:
headers = [('Access-Control-Allow-Origin', '*')]
if environ['REQUEST_METHOD'] == 'OPTIONS':
headers += [('Access-Control-Allow-Methods', 'OPTIONS, GET, POST')]
if 'HTTP_ACCESS_CONTROL_REQUEST_HEADERS' in environ:
headers += [('Access-Control-Allow-Headers',
environ['HTTP_ACCESS_CONTROL_REQUEST_HEADERS'])]
if self.cors_credentials:
headers += [('Access-Control-Allow-Credentials', 'true')]
return headers
def _gzip(self, response):
"""Apply gzip compression to a response."""
bytesio = six.BytesIO()
with gzip.GzipFile(fileobj=bytesio, mode='w') as gz:
gz.write(response)
return bytesio.getvalue()
def _deflate(self, response):
"""Apply deflate compression to a response."""
return zlib.compress(response)
def _service_task(self): # pragma: no cover
"""Monitor connected clients and clean up those that time out."""
while True:
if len(self.sockets) == 0:
# nothing to do
self.sleep(self.ping_timeout)
continue
# go through the entire client list in a ping interval cycle
sleep_interval = self.ping_timeout / len(self.sockets)
try:
# iterate over the current clients
for s in self.sockets.copy().values():
if not s.closing and not s.closed:
s.check_ping_timeout()
self.sleep(sleep_interval)
except (SystemExit, KeyboardInterrupt):
self.logger.info('service task canceled')
break
except:
# an unexpected exception has occurred, log it and continue
self.logger.exception('service task exception')
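
# A minimal WSGI usage sketch (not part of the vendored file), based on the
# on()/send() docstrings above. The threading async_mode and the use of the
# Middleware wrapper referenced in the handle_request() docstring are
# assumptions for the example.
import engineio

eio = engineio.Server(async_mode='threading')

@eio.on('connect')
def on_connect(sid, environ):
    print('client connected:', sid)

@eio.on('message')
def on_message(sid, data):
    eio.send(sid, 'response')  # echo an acknowledgement back to the client

# handle_request() speaks WSGI, so it is normally mounted behind Middleware:
app = engineio.Middleware(eio)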

View file

@@ -0,0 +1,247 @@
import six
import sys
import time
from . import exceptions
from . import packet
from . import payload
class Socket(object):
"""An Engine.IO socket."""
upgrade_protocols = ['websocket']
def __init__(self, server, sid):
self.server = server
self.sid = sid
self.queue = self.server.create_queue()
self.last_ping = time.time()
self.connected = False
self.upgrading = False
self.upgraded = False
self.packet_backlog = []
self.closing = False
self.closed = False
self.session = {}
def poll(self):
"""Wait for packets to send to the client."""
queue_empty = self.server.get_queue_empty_exception()
try:
packets = [self.queue.get(timeout=self.server.ping_timeout)]
self.queue.task_done()
except queue_empty:
raise exceptions.QueueEmpty()
if packets == [None]:
return []
while True:
try:
packets.append(self.queue.get(block=False))
self.queue.task_done()
except queue_empty:
break
return packets
def receive(self, pkt):
"""Receive packet from the client."""
packet_name = packet.packet_names[pkt.packet_type] \
if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
self.server.logger.info('%s: Received packet %s data %s',
self.sid, packet_name,
pkt.data if not isinstance(pkt.data, bytes)
else '<binary>')
if pkt.packet_type == packet.PING:
self.last_ping = time.time()
self.send(packet.Packet(packet.PONG, pkt.data))
elif pkt.packet_type == packet.MESSAGE:
self.server._trigger_event('message', self.sid, pkt.data,
run_async=self.server.async_handlers)
elif pkt.packet_type == packet.UPGRADE:
self.send(packet.Packet(packet.NOOP))
elif pkt.packet_type == packet.CLOSE:
self.close(wait=False, abort=True)
else:
raise exceptions.UnknownPacketError()
def check_ping_timeout(self):
"""Make sure the client is still sending pings.
This helps detect disconnections for long-polling clients.
"""
if self.closed:
raise exceptions.SocketIsClosedError()
if time.time() - self.last_ping > self.server.ping_interval + 5:
self.server.logger.info('%s: Client is gone, closing socket',
self.sid)
# Passing abort=False here will cause close() to write a
# CLOSE packet. This has the effect of updating half-open sockets
# to their correct state of disconnected
self.close(wait=False, abort=False)
return False
return True
def send(self, pkt):
"""Send a packet to the client."""
if not self.check_ping_timeout():
return
if self.upgrading:
self.packet_backlog.append(pkt)
else:
self.queue.put(pkt)
self.server.logger.info('%s: Sending packet %s data %s',
self.sid, packet.packet_names[pkt.packet_type],
pkt.data if not isinstance(pkt.data, bytes)
else '<binary>')
def handle_get_request(self, environ, start_response):
"""Handle a long-polling GET request from the client."""
connections = [
s.strip()
for s in environ.get('HTTP_CONNECTION', '').lower().split(',')]
transport = environ.get('HTTP_UPGRADE', '').lower()
if 'upgrade' in connections and transport in self.upgrade_protocols:
self.server.logger.info('%s: Received request to upgrade to %s',
self.sid, transport)
return getattr(self, '_upgrade_' + transport)(environ,
start_response)
try:
packets = self.poll()
except exceptions.QueueEmpty:
exc = sys.exc_info()
self.close(wait=False)
six.reraise(*exc)
return packets
def handle_post_request(self, environ):
"""Handle a long-polling POST request from the client."""
length = int(environ.get('CONTENT_LENGTH', '0'))
if length > self.server.max_http_buffer_size:
raise exceptions.ContentTooLongError()
else:
body = environ['wsgi.input'].read(length)
p = payload.Payload(encoded_payload=body)
for pkt in p.packets:
self.receive(pkt)
def close(self, wait=True, abort=False):
"""Close the socket connection."""
if not self.closed and not self.closing:
self.closing = True
self.server._trigger_event('disconnect', self.sid, run_async=False)
if not abort:
self.send(packet.Packet(packet.CLOSE))
self.closed = True
self.queue.put(None)
if wait:
self.queue.join()
def _upgrade_websocket(self, environ, start_response):
"""Upgrade the connection from polling to websocket."""
if self.upgraded:
raise IOError('Socket has been upgraded already')
if self.server._async['websocket'] is None:
# the selected async mode does not support websocket
return self.server._bad_request()
ws = self.server._async['websocket'](self._websocket_handler)
return ws(environ, start_response)
def _websocket_handler(self, ws):
"""Engine.IO handler for websocket transport."""
# try to set a socket timeout matching the configured ping interval
for attr in ['_sock', 'socket']: # pragma: no cover
if hasattr(ws, attr) and hasattr(getattr(ws, attr), 'settimeout'):
getattr(ws, attr).settimeout(self.server.ping_timeout)
if self.connected:
# the socket was already connected, so this is an upgrade
self.upgrading = True # hold packet sends during the upgrade
pkt = ws.wait()
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.PING or \
decoded_pkt.data != 'probe':
self.server.logger.info(
'%s: Failed websocket upgrade, no PING packet', self.sid)
return []
ws.send(packet.Packet(
packet.PONG,
data=six.text_type('probe')).encode(always_bytes=False))
self.queue.put(packet.Packet(packet.NOOP)) # end poll
pkt = ws.wait()
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.UPGRADE:
self.upgraded = False
self.server.logger.info(
('%s: Failed websocket upgrade, expected UPGRADE packet, '
'received %s instead.'),
self.sid, pkt)
return []
self.upgraded = True
# flush any packets that were sent during the upgrade
for pkt in self.packet_backlog:
self.queue.put(pkt)
self.packet_backlog = []
self.upgrading = False
else:
self.connected = True
self.upgraded = True
# start separate writer thread
def writer():
while True:
packets = None
try:
packets = self.poll()
except exceptions.QueueEmpty:
break
if not packets:
# empty packet list returned -> connection closed
break
try:
for pkt in packets:
ws.send(pkt.encode(always_bytes=False))
except:
break
writer_task = self.server.start_background_task(writer)
self.server.logger.info(
'%s: Upgrade to websocket successful', self.sid)
while True:
p = None
try:
p = ws.wait()
except Exception as e:
# if the socket is already closed, we can assume this is a
# downstream error of that
if not self.closed: # pragma: no cover
self.server.logger.info(
'%s: Unexpected error "%s", closing connection',
self.sid, str(e))
break
if p is None:
# connection closed by client
break
if isinstance(p, six.text_type): # pragma: no cover
p = p.encode('utf-8')
pkt = packet.Packet(encoded_packet=p)
try:
self.receive(pkt)
except exceptions.UnknownPacketError: # pragma: no cover
pass
except exceptions.SocketIsClosedError: # pragma: no cover
self.server.logger.info('Receive error -- socket is closed')
break
except: # pragma: no cover
# if we get an unexpected exception we log the error and exit
# the connection properly
self.server.logger.exception('Unknown receive error')
break
self.queue.put(None) # unlock the writer task so that it can exit
writer_task.join()
self.close(wait=False, abort=True)
return []
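
# A small sketch (not part of the vendored file) of the probe exchange that
# _websocket_handler() validates above: the client sends a PING carrying
# "probe" over the new websocket, expects a matching PONG, then sends UPGRADE.
# Assumes the vendored `engineio.packet` module.
from engineio import packet

probe = packet.Packet(packet.PING, data='probe')
print(probe.encode(always_bytes=False))   # '2probe' -- what the client sends
pong = packet.Packet(packet.PONG, data='probe')
print(pong.encode(always_bytes=False))    # '3probe' -- the reply sent above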

View file

@@ -0,0 +1,55 @@
content_types = {
'css': 'text/css',
'gif': 'image/gif',
'html': 'text/html',
'jpg': 'image/jpeg',
'js': 'application/javascript',
'json': 'application/json',
'png': 'image/png',
'txt': 'text/plain',
}
def get_static_file(path, static_files):
"""Return the local filename and content type for the requested static
file URL.
:param path: the path portion of the requested URL.
:param static_files: a static file configuration dictionary.
This function returns a dictionary with two keys, "filename" and
"content_type". If the requested URL does not match any static file, the
return value is None.
"""
if path in static_files:
f = static_files[path]
else:
f = None
rest = ''
while path != '':
path, last = path.rsplit('/', 1)
rest = '/' + last + rest
if path in static_files:
f = static_files[path] + rest
break
elif path + '/' in static_files:
f = static_files[path + '/'] + rest[1:]
break
if f:
if isinstance(f, str):
f = {'filename': f}
if f['filename'].endswith('/'):
if '' in static_files:
if isinstance(static_files[''], str):
f['filename'] += static_files['']
else:
f['filename'] += static_files['']['filename']
if 'content_type' in static_files['']:
f['content_type'] = static_files['']['content_type']
else:
f['filename'] += 'index.html'
if 'content_type' not in f:
ext = f['filename'].rsplit('.')[-1]
f['content_type'] = content_types.get(
ext, 'application/octet-stream')
return f
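
# A small sketch (not part of the vendored file) of how the helper above
# resolves a request path against a static file map; the paths used here are
# made-up examples.
static_files = {
    '/': 'index.html',
    '/static': './public',
}
print(get_static_file('/static/style.css', static_files))
# {'filename': './public/style.css', 'content_type': 'text/css'}
print(get_static_file('/', static_files))
# {'filename': 'index.html', 'content_type': 'text/html'}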

View file

@@ -0,0 +1,35 @@
import sys
from .client import Client
from .base_manager import BaseManager
from .pubsub_manager import PubSubManager
from .kombu_manager import KombuManager
from .redis_manager import RedisManager
from .zmq_manager import ZmqManager
from .server import Server
from .namespace import Namespace, ClientNamespace
from .middleware import WSGIApp, Middleware
from .tornado import get_tornado_handler
if sys.version_info >= (3, 5): # pragma: no cover
from .asyncio_client import AsyncClient
from .asyncio_server import AsyncServer
from .asyncio_manager import AsyncManager
from .asyncio_namespace import AsyncNamespace, AsyncClientNamespace
from .asyncio_redis_manager import AsyncRedisManager
from .asgi import ASGIApp
else: # pragma: no cover
AsyncClient = None
AsyncServer = None
AsyncManager = None
AsyncNamespace = None
AsyncRedisManager = None
__version__ = '4.2.1'
__all__ = ['__version__', 'Client', 'Server', 'BaseManager', 'PubSubManager',
'KombuManager', 'RedisManager', 'ZmqManager', 'Namespace',
'ClientNamespace', 'WSGIApp', 'Middleware']
if AsyncServer is not None: # pragma: no cover
__all__ += ['AsyncClient', 'AsyncServer', 'AsyncNamespace',
'AsyncClientNamespace', 'AsyncManager', 'AsyncRedisManager',
'ASGIApp', 'get_tornado_handler']

View file

@@ -0,0 +1,36 @@
import engineio
class ASGIApp(engineio.ASGIApp): # pragma: no cover
"""ASGI application middleware for Socket.IO.
This middleware dispatches traffic to a Socket.IO application. It can
also serve a list of static files to the client, or forward unrelated
HTTP traffic to another ASGI application.
:param socketio_server: The Socket.IO server. Must be an instance of the
``socketio.AsyncServer`` class.
:param static_files: A dictionary with static file mapping rules. See the
documentation for details on this argument.
:param other_asgi_app: A separate ASGI app that receives all other traffic.
:param socketio_path: The endpoint where the Socket.IO application should
be installed. The default value is appropriate for
most cases.
Example usage::
import socketio
import uvicorn
sio = socketio.AsyncServer()
app = socketio.ASGIApp(sio, static_files={
'/': 'index.html',
'/static': './public',
})
uvicorn.run(app, host='127.0.0.1', port=5000)
"""
def __init__(self, socketio_server, other_asgi_app=None,
static_files=None, socketio_path='socket.io'):
super().__init__(socketio_server, other_asgi_app,
static_files=static_files,
engineio_path=socketio_path)

View file

@@ -0,0 +1,445 @@
import asyncio
import logging
import random
import engineio
import six
from . import client
from . import exceptions
from . import packet
default_logger = logging.getLogger('socketio.client')
class AsyncClient(client.Client):
"""A Socket.IO client for asyncio.
This class implements a fully compliant Socket.IO web client with support
for websocket and long-polling transports.
:param reconnection: ``True`` if the client should automatically attempt to
reconnect to the server after an interruption, or
``False`` to not reconnect. The default is ``True``.
:param reconnection_attempts: How many reconnection attempts to issue
before giving up, or 0 for infinite attempts.
The default is 0.
:param reconnection_delay: How long to wait in seconds before the first
reconnection attempt. Each successive attempt
doubles this delay.
:param reconnection_delay_max: The maximum delay between reconnection
attempts.
:param randomization_factor: Randomization amount for each delay between
reconnection attempts. The default is 0.5,
which means that each delay is randomly
adjusted by +/- 50%.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``.
:param binary: ``True`` to support binary payloads, ``False`` to treat all
payloads as text. On Python 2, if this is set to ``True``,
``unicode`` values are treated as text, and ``str`` and
``bytes`` values are treated as binary. This option has no
effect on Python 3, where text and binary payloads are
always automatically discovered.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
The Engine.IO configuration supports the following settings:
:param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
a logger object to use. To disable logging set to
``False``. The default is ``False``.
"""
def is_asyncio_based(self):
return True
async def connect(self, url, headers={}, transports=None,
namespaces=None, socketio_path='socket.io'):
"""Connect to a Socket.IO server.
:param url: The URL of the Socket.IO server. It can include custom
query string parameters if required by the server.
:param headers: A dictionary with custom headers to send with the
connection request.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
then an upgrade to websocket is attempted.
:param namespaces: The list of custom namespaces to connect, in
addition to the default namespace. If not given,
the namespace list is obtained from the registered
event handlers.
:param socketio_path: The endpoint where the Socket.IO server is
installed. The default value is appropriate for
most cases.
Note: this method is a coroutine.
Example usage::
sio = socketio.Client()
sio.connect('http://localhost:5000')
"""
self.connection_url = url
self.connection_headers = headers
self.connection_transports = transports
self.connection_namespaces = namespaces
self.socketio_path = socketio_path
if namespaces is None:
namespaces = set(self.handlers.keys()).union(
set(self.namespace_handlers.keys()))
elif isinstance(namespaces, six.string_types):
namespaces = [namespaces]
self.connection_namespaces = namespaces
self.namespaces = [n for n in namespaces if n != '/']
try:
await self.eio.connect(url, headers=headers,
transports=transports,
engineio_path=socketio_path)
except engineio.exceptions.ConnectionError as exc:
six.raise_from(exceptions.ConnectionError(exc.args[0]), None)
async def wait(self):
"""Wait until the connection with the server ends.
Client applications can use this function to block the main thread
during the life of the connection.
Note: this method is a coroutine.
"""
while True:
await self.eio.wait()
await self.sleep(1) # give the reconnect task time to start up
if not self._reconnect_task:
break
await self._reconnect_task
if self.eio.state != 'connected':
break
async def emit(self, event, data=None, namespace=None, callback=None):
"""Emit a custom event to one or more connected clients.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
that the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
Note: this method is a coroutine.
"""
namespace = namespace or '/'
self.logger.info('Emitting event "%s" [%s]', event, namespace)
if callback is not None:
id = self._generate_ack_id(namespace, callback)
else:
id = None
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
binary = None
# tuples are expanded to multiple arguments, everything else is sent
# as a single argument
if isinstance(data, tuple):
data = list(data)
elif data is not None:
data = [data]
else:
data = []
await self._send_packet(packet.Packet(
packet.EVENT, namespace=namespace, data=[event] + data, id=id,
binary=binary))
async def send(self, data, namespace=None, callback=None):
"""Send a message to one or more connected clients.
This function emits an event with the name ``'message'``. Use
:func:`emit` to issue custom event names.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
that the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
Note: this method is a coroutine.
"""
await self.emit('message', data=data, namespace=namespace,
callback=callback)
async def call(self, event, data=None, namespace=None, timeout=60):
"""Emit a custom event to a client and wait for the response.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param timeout: The waiting timeout. If the timeout is reached before
the client acknowledges the event, then a
``TimeoutError`` exception is raised.
Note: this method is a coroutine.
"""
callback_event = self.eio.create_event()
callback_args = []
def event_callback(*args):
callback_args.append(args)
callback_event.set()
await self.emit(event, data=data, namespace=namespace,
callback=event_callback)
try:
await asyncio.wait_for(callback_event.wait(), timeout)
except asyncio.TimeoutError:
six.raise_from(exceptions.TimeoutError(), None)
return callback_args[0] if len(callback_args[0]) > 1 \
else callback_args[0][0] if len(callback_args[0]) == 1 \
else None
async def disconnect(self):
"""Disconnect from the server.
Note: this method is a coroutine.
"""
# here we just request the disconnection
# later in _handle_eio_disconnect we invoke the disconnect handler
for n in self.namespaces:
await self._send_packet(packet.Packet(packet.DISCONNECT,
namespace=n))
await self._send_packet(packet.Packet(
packet.DISCONNECT, namespace='/'))
await self.eio.disconnect(abort=True)
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
"""
return self.eio.start_background_task(target, *args, **kwargs)
async def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
Note: this method is a coroutine.
"""
return await self.eio.sleep(seconds)
async def _send_packet(self, pkt):
"""Send a Socket.IO packet to the server."""
encoded_packet = pkt.encode()
if isinstance(encoded_packet, list):
binary = False
for ep in encoded_packet:
await self.eio.send(ep, binary=binary)
binary = True
else:
await self.eio.send(encoded_packet, binary=False)
async def _handle_connect(self, namespace):
namespace = namespace or '/'
self.logger.info('Namespace {} is connected'.format(namespace))
await self._trigger_event('connect', namespace=namespace)
if namespace == '/':
for n in self.namespaces:
await self._send_packet(packet.Packet(packet.CONNECT,
namespace=n))
elif namespace not in self.namespaces:
self.namespaces.append(namespace)
async def _handle_disconnect(self, namespace):
namespace = namespace or '/'
await self._trigger_event('disconnect', namespace=namespace)
if namespace in self.namespaces:
self.namespaces.remove(namespace)
async def _handle_event(self, namespace, id, data):
namespace = namespace or '/'
self.logger.info('Received event "%s" [%s]', data[0], namespace)
r = await self._trigger_event(data[0], namespace, *data[1:])
if id is not None:
# send ACK packet with the response returned by the handler
# tuples are expanded as multiple arguments
if r is None:
data = []
elif isinstance(r, tuple):
data = list(r)
else:
data = [r]
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
binary = None
await self._send_packet(packet.Packet(
packet.ACK, namespace=namespace, id=id, data=data,
binary=binary))
async def _handle_ack(self, namespace, id, data):
namespace = namespace or '/'
self.logger.info('Received ack [%s]', namespace)
callback = None
try:
callback = self.callbacks[namespace][id]
except KeyError:
# if we get an unknown callback we just ignore it
self.logger.warning('Unknown callback received, ignoring.')
else:
del self.callbacks[namespace][id]
if callback is not None:
if asyncio.iscoroutinefunction(callback):
await callback(*data)
else:
callback(*data)
def _handle_error(self, namespace):
namespace = namespace or '/'
self.logger.info('Connection to namespace {} was rejected'.format(
namespace))
if namespace in self.namespaces:
self.namespaces.remove(namespace)
async def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
# first see if we have an explicit handler for the event
if namespace in self.handlers and event in self.handlers[namespace]:
if asyncio.iscoroutinefunction(self.handlers[namespace][event]):
try:
ret = await self.handlers[namespace][event](*args)
except asyncio.CancelledError: # pragma: no cover
ret = None
else:
ret = self.handlers[namespace][event](*args)
return ret
# or else, forward the event to a namespace handler if one exists
elif namespace in self.namespace_handlers:
return await self.namespace_handlers[namespace].trigger_event(
event, *args)
async def _handle_reconnect(self):
self._reconnect_abort.clear()
client.reconnecting_clients.append(self)
attempt_count = 0
current_delay = self.reconnection_delay
while True:
delay = current_delay
current_delay *= 2
if delay > self.reconnection_delay_max:
delay = self.reconnection_delay_max
delay += self.randomization_factor * (2 * random.random() - 1)
self.logger.info(
'Connection failed, new attempt in {:.02f} seconds'.format(
delay))
try:
await asyncio.wait_for(self._reconnect_abort.wait(), delay)
self.logger.info('Reconnect task aborted')
break
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
attempt_count += 1
try:
await self.connect(self.connection_url,
headers=self.connection_headers,
transports=self.connection_transports,
namespaces=self.connection_namespaces,
socketio_path=self.socketio_path)
except (exceptions.ConnectionError, ValueError):
pass
else:
self.logger.info('Reconnection successful')
self._reconnect_task = None
break
if self.reconnection_attempts and \
attempt_count >= self.reconnection_attempts:
self.logger.info(
'Maximum reconnection attempts reached, giving up')
break
client.reconnecting_clients.remove(self)
def _handle_eio_connect(self):
"""Handle the Engine.IO connection event."""
self.logger.info('Engine.IO connection established')
self.sid = self.eio.sid
async def _handle_eio_message(self, data):
"""Dispatch Engine.IO messages."""
if self._binary_packet:
pkt = self._binary_packet
if pkt.add_attachment(data):
self._binary_packet = None
if pkt.packet_type == packet.BINARY_EVENT:
await self._handle_event(pkt.namespace, pkt.id, pkt.data)
else:
await self._handle_ack(pkt.namespace, pkt.id, pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
await self._handle_connect(pkt.namespace)
elif pkt.packet_type == packet.DISCONNECT:
await self._handle_disconnect(pkt.namespace)
elif pkt.packet_type == packet.EVENT:
await self._handle_event(pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.ACK:
await self._handle_ack(pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \
pkt.packet_type == packet.BINARY_ACK:
self._binary_packet = pkt
elif pkt.packet_type == packet.ERROR:
self._handle_error(pkt.namespace)
else:
raise ValueError('Unknown packet type.')
async def _handle_eio_disconnect(self):
"""Handle the Engine.IO disconnection event."""
self.logger.info('Engine.IO connection dropped')
self._reconnect_abort.set()
for n in self.namespaces:
await self._trigger_event('disconnect', namespace=n)
await self._trigger_event('disconnect', namespace='/')
self.callbacks = {}
self._binary_packet = None
self.sid = None
if self.eio.state == 'connected' and self.reconnection:
self._reconnect_task = self.start_background_task(
self._handle_reconnect)
def _engineio_client_class(self):
return engineio.AsyncClient
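
# A minimal asyncio usage sketch (not part of the vendored file), following the
# connect()/emit()/wait() docstrings above. The URL and event name are
# placeholders, and the on() decorator comes from the base Client class.
import asyncio
import socketio

sio = socketio.AsyncClient()

@sio.on('server_event')
async def on_server_event(data):
    print('received:', data)

async def main():
    await sio.connect('http://localhost:5000')
    await sio.emit('message', {'hello': 'world'})
    await sio.wait()

asyncio.get_event_loop().run_until_complete(main())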

View file

@@ -0,0 +1,58 @@
import asyncio
from .base_manager import BaseManager
class AsyncManager(BaseManager):
"""Manage a client list for an asyncio server."""
async def emit(self, event, data, namespace, room=None, skip_sid=None,
callback=None, **kwargs):
"""Emit a message to a single client, a room, or all the clients
connected to the namespace.
Note: this method is a coroutine.
"""
if namespace not in self.rooms or room not in self.rooms[namespace]:
return
tasks = []
if not isinstance(skip_sid, list):
skip_sid = [skip_sid]
for sid in self.get_participants(namespace, room):
if sid not in skip_sid:
if callback is not None:
id = self._generate_ack_id(sid, namespace, callback)
else:
id = None
tasks.append(self.server._emit_internal(sid, event, data,
namespace, id))
if tasks == []: # pragma: no cover
return
await asyncio.wait(tasks)
async def close_room(self, room, namespace):
"""Remove all participants from a room.
Note: this method is a coroutine.
"""
return super().close_room(room, namespace)
async def trigger_callback(self, sid, namespace, id, data):
"""Invoke an application callback.
Note: this method is a coroutine.
"""
callback = None
try:
callback = self.callbacks[sid][namespace][id]
except KeyError:
# if we get an unknown callback we just ignore it
self._get_logger().warning('Unknown callback received, ignoring.')
else:
del self.callbacks[sid][namespace][id]
if callback is not None:
ret = callback(*data)
if asyncio.iscoroutine(ret):
try:
await ret
except asyncio.CancelledError: # pragma: no cover
pass

View file

@@ -0,0 +1,204 @@
import asyncio
from socketio import namespace
class AsyncNamespace(namespace.Namespace):
"""Base class for asyncio server-side class-based namespaces.
A class-based namespace is a class that contains all the event handlers
for a Socket.IO namespace. The event handlers are methods of the class
with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
``on_message``, ``on_json``, and so on. These can be regular functions or
coroutines.
:param namespace: The Socket.IO namespace to be used with all the event
handlers defined in this class. If this argument is
omitted, the default namespace is used.
"""
def is_asyncio_based(self):
return True
async def trigger_event(self, event, *args):
"""Dispatch an event to the proper handler method.
In the most common usage, this method is not overloaded by subclasses,
as it performs the routing of events to methods. However, this
method can be overridden if special dispatching rules are needed, or if
having a single method that catches all events is desired.
Note: this method is a coroutine.
"""
handler_name = 'on_' + event
if hasattr(self, handler_name):
handler = getattr(self, handler_name)
if asyncio.iscoroutinefunction(handler) is True:
try:
ret = await handler(*args)
except asyncio.CancelledError: # pragma: no cover
ret = None
else:
ret = handler(*args)
return ret
async def emit(self, event, data=None, room=None, skip_sid=None,
namespace=None, callback=None):
"""Emit a custom event to one or more connected clients.
The only difference with the :func:`socketio.Server.emit` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.emit(event, data=data, room=room,
skip_sid=skip_sid,
namespace=namespace or self.namespace,
callback=callback)
async def send(self, data, room=None, skip_sid=None, namespace=None,
callback=None):
"""Send a message to one or more connected clients.
The only difference with the :func:`socketio.Server.send` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.send(data, room=room, skip_sid=skip_sid,
namespace=namespace or self.namespace,
callback=callback)
async def close_room(self, room, namespace=None):
"""Close a room.
The only difference with the :func:`socketio.Server.close_room` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.close_room(
room, namespace=namespace or self.namespace)
async def get_session(self, sid, namespace=None):
"""Return the user session for a client.
The only difference with the :func:`socketio.Server.get_session`
method is that when the ``namespace`` argument is not given the
namespace associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.get_session(
sid, namespace=namespace or self.namespace)
async def save_session(self, sid, session, namespace=None):
"""Store the user session for a client.
The only difference with the :func:`socketio.Server.save_session`
method is that when the ``namespace`` argument is not given the
namespace associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.save_session(
sid, session, namespace=namespace or self.namespace)
def session(self, sid, namespace=None):
"""Return the user session for a client with context manager syntax.
The only difference with the :func:`socketio.Server.session` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.session(sid, namespace=namespace or self.namespace)
async def disconnect(self, sid, namespace=None):
"""Disconnect a client.
The only difference with the :func:`socketio.Server.disconnect` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.server.disconnect(
sid, namespace=namespace or self.namespace)
class AsyncClientNamespace(namespace.ClientNamespace):
"""Base class for asyncio client-side class-based namespaces.
A class-based namespace is a class that contains all the event handlers
for a Socket.IO namespace. The event handlers are methods of the class
with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
``on_message``, ``on_json``, and so on. These can be regular functions or
coroutines.
:param namespace: The Socket.IO namespace to be used with all the event
handlers defined in this class. If this argument is
omitted, the default namespace is used.
"""
def is_asyncio_based(self):
return True
async def trigger_event(self, event, *args):
"""Dispatch an event to the proper handler method.
In the most common usage, this method is not overloaded by subclasses,
as it performs the routing of events to methods. However, this
method can be overridden if special dispatching rules are needed, or if
having a single method that catches all events is desired.
Note: this method is a coroutine.
"""
handler_name = 'on_' + event
if hasattr(self, handler_name):
handler = getattr(self, handler_name)
if asyncio.iscoroutinefunction(handler) is True:
try:
ret = await handler(*args)
except asyncio.CancelledError: # pragma: no cover
ret = None
else:
ret = handler(*args)
return ret
async def emit(self, event, data=None, namespace=None, callback=None):
"""Emit a custom event to the server.
The only difference with the :func:`socketio.Client.emit` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.client.emit(event, data=data,
namespace=namespace or self.namespace,
callback=callback)
async def send(self, data, namespace=None, callback=None):
"""Send a message to the server.
The only difference with the :func:`socketio.Client.send` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.client.send(data,
namespace=namespace or self.namespace,
callback=callback)
async def disconnect(self):
"""Disconnect a client.
The only difference with the :func:`socketio.Client.disconnect` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
Note: this method is a coroutine.
"""
return await self.client.disconnect()
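
# A short sketch (not part of the vendored file) of a server-side class-based
# namespace as described above. The '/chat' namespace and event names are
# illustrative, and register_namespace() is assumed from the AsyncServer API.
import socketio

class ChatNamespace(socketio.AsyncNamespace):
    def on_connect(self, sid, environ):
        print('connected:', sid)

    async def on_message(self, sid, data):
        await self.emit('reply', data, room=sid)

sio = socketio.AsyncServer()
sio.register_namespace(ChatNamespace('/chat'))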

View file

@@ -0,0 +1,163 @@
from functools import partial
import uuid
import json
import pickle
import six
from .asyncio_manager import AsyncManager
class AsyncPubSubManager(AsyncManager):
"""Manage a client list attached to a pub/sub backend under asyncio.
This is a base class that enables multiple servers to share the list of
clients, with the servers communicating events through a pub/sub backend.
The use of a pub/sub backend also allows any client connected to the
backend to emit events addressed to Socket.IO clients.
The actual backends must be implemented by subclasses; this class only
provides a generic pub/sub framework for asyncio applications.
:param channel: The channel name on which the server sends and receives
notifications.
"""
name = 'asyncpubsub'
def __init__(self, channel='socketio', write_only=False, logger=None):
super().__init__()
self.channel = channel
self.write_only = write_only
self.host_id = uuid.uuid4().hex
self.logger = logger
def initialize(self):
super().initialize()
if not self.write_only:
self.thread = self.server.start_background_task(self._thread)
self._get_logger().info(self.name + ' backend initialized.')
async def emit(self, event, data, namespace=None, room=None, skip_sid=None,
callback=None, **kwargs):
"""Emit a message to a single client, a room, or all the clients
connected to the namespace.
This method takes care of propagating the message to all the servers
that are connected through the message queue.
The parameters are the same as in :meth:`.Server.emit`.
Note: this method is a coroutine.
"""
if kwargs.get('ignore_queue'):
return await super().emit(
event, data, namespace=namespace, room=room, skip_sid=skip_sid,
callback=callback)
namespace = namespace or '/'
if callback is not None:
if self.server is None:
raise RuntimeError('Callbacks can only be issued from the '
'context of a server.')
if room is None:
raise ValueError('Cannot use callback without a room set.')
id = self._generate_ack_id(room, namespace, callback)
callback = (room, namespace, id)
else:
callback = None
await self._publish({'method': 'emit', 'event': event, 'data': data,
'namespace': namespace, 'room': room,
'skip_sid': skip_sid, 'callback': callback,
'host_id': self.host_id})
async def close_room(self, room, namespace=None):
await self._publish({'method': 'close_room', 'room': room,
'namespace': namespace or '/'})
async def _publish(self, data):
"""Publish a message on the Socket.IO channel.
This method needs to be implemented by the different subclasses that
support pub/sub backends.
"""
raise NotImplementedError('This method must be implemented in a '
'subclass.') # pragma: no cover
async def _listen(self):
"""Return the next message published on the Socket.IO channel,
blocking until a message is available.
This method needs to be implemented by the different subclasses that
support pub/sub backends.
"""
raise NotImplementedError('This method must be implemented in a '
'subclass.') # pragma: no cover
async def _handle_emit(self, message):
# Events with callbacks are very tricky to handle across hosts
# Here on the receiving end we set up a local callback that preserves
# the callback host and id from the sender
remote_callback = message.get('callback')
remote_host_id = message.get('host_id')
if remote_callback is not None and len(remote_callback) == 3:
callback = partial(self._return_callback, remote_host_id,
*remote_callback)
else:
callback = None
await super().emit(message['event'], message['data'],
namespace=message.get('namespace'),
room=message.get('room'),
skip_sid=message.get('skip_sid'),
callback=callback)
async def _handle_callback(self, message):
if self.host_id == message.get('host_id'):
try:
sid = message['sid']
namespace = message['namespace']
id = message['id']
args = message['args']
except KeyError:
return
await self.trigger_callback(sid, namespace, id, args)
async def _return_callback(self, host_id, sid, namespace, callback_id,
*args):
# When an event callback is received, the callback is returned back to
# the sender, which is identified by the host_id
await self._publish({'method': 'callback', 'host_id': host_id,
'sid': sid, 'namespace': namespace,
'id': callback_id, 'args': args})
async def _handle_close_room(self, message):
await super().close_room(
room=message.get('room'), namespace=message.get('namespace'))
async def _thread(self):
while True:
try:
message = await self._listen()
except:
import traceback
traceback.print_exc()
break
data = None
if isinstance(message, dict):
data = message
else:
if isinstance(message, six.binary_type): # pragma: no cover
try:
data = pickle.loads(message)
except:
pass
if data is None:
try:
data = json.loads(message)
except:
pass
if data and 'method' in data:
if data['method'] == 'emit':
await self._handle_emit(data)
elif data['method'] == 'callback':
await self._handle_callback(data)
elif data['method'] == 'close_room':
await self._handle_close_room(data)
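# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the vendored file): a minimal custom
# backend showing how subclasses are expected to provide the _publish() and
# _listen() coroutines declared above. The asyncio.Queue transport and the
# class name are assumptions made for this example; a real backend would talk
# to Redis, RabbitMQ or a similar broker.
import asyncio
class InProcessPubSubManager(AsyncPubSubManager):
    name = 'inprocess'
    def __init__(self, channel='socketio', write_only=False, logger=None):
        super().__init__(channel=channel, write_only=write_only,
                         logger=logger)
        self._queue = asyncio.Queue()
    async def _publish(self, data):
        # hand the message dictionary straight to the local listener
        await self._queue.put(data)
    async def _listen(self):
        # block until the next published message becomes available
        return await self._queue.get()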

View file

@@ -0,0 +1,107 @@
import asyncio
import pickle
from urllib.parse import urlparse
try:
import aioredis
except ImportError:
aioredis = None
from .asyncio_pubsub_manager import AsyncPubSubManager
def _parse_redis_url(url):
p = urlparse(url)
if p.scheme not in {'redis', 'rediss'}:
raise ValueError('Invalid redis url')
ssl = p.scheme == 'rediss'
host = p.hostname or 'localhost'
port = p.port or 6379
password = p.password
if p.path:
db = int(p.path[1:])
else:
db = 0
return host, port, password, db, ssl
class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover
"""Redis based client manager for asyncio servers.
This class implements a Redis backend for event sharing across multiple
processes. Only kept here as one more example of how to build a custom
backend, since the kombu backend is perfectly adequate to support a Redis
message queue.
To use a Redis backend, initialize the :class:`AsyncServer` instance as
follows::
        server = socketio.AsyncServer(client_manager=socketio.AsyncRedisManager(
            'redis://hostname:port/0'))
:param url: The connection URL for the Redis server. For a default Redis
store running on the same host, use ``redis://``. To use an
SSL connection, use ``rediss://``.
:param channel: The channel name on which the server sends and receives
notifications. Must be the same in all the servers.
    :param write_only: If set to ``True``, only initialize to emit events. The
default of ``False`` initializes the class for emitting
and receiving.
"""
name = 'aioredis'
def __init__(self, url='redis://localhost:6379/0', channel='socketio',
write_only=False, logger=None):
if aioredis is None:
raise RuntimeError('Redis package is not installed '
'(Run "pip install aioredis" in your '
'virtualenv).')
(
self.host, self.port, self.password, self.db, self.ssl
) = _parse_redis_url(url)
self.pub = None
self.sub = None
super().__init__(channel=channel, write_only=write_only, logger=logger)
async def _publish(self, data):
retry = True
while True:
try:
if self.pub is None:
self.pub = await aioredis.create_redis(
(self.host, self.port), db=self.db,
password=self.password, ssl=self.ssl
)
return await self.pub.publish(self.channel,
pickle.dumps(data))
except (aioredis.RedisError, OSError):
if retry:
self._get_logger().error('Cannot publish to redis... '
'retrying')
self.pub = None
retry = False
else:
self._get_logger().error('Cannot publish to redis... '
'giving up')
break
async def _listen(self):
retry_sleep = 1
while True:
try:
if self.sub is None:
self.sub = await aioredis.create_redis(
(self.host, self.port), db=self.db,
password=self.password, ssl=self.ssl
)
self.ch = (await self.sub.subscribe(self.channel))[0]
return await self.ch.get()
except (aioredis.RedisError, OSError):
self._get_logger().error('Cannot receive from redis... '
'retrying in '
'{} secs'.format(retry_sleep))
self.sub = None
await asyncio.sleep(retry_sleep)
retry_sleep *= 2
if retry_sleep > 60:
retry_sleep = 60
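# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the vendored file): a worker process
# can publish events through the Redis channel without running the listener
# by creating a write-only manager. The URL, event name and room used below
# are assumptions made for this example.
async def notify_admins():
    external = AsyncRedisManager('redis://localhost:6379/0', write_only=True)
    await external.emit('notification', {'text': 'deploy finished'},
                        room='admins')
# e.g. asyncio.get_event_loop().run_until_complete(notify_admins())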

View file

@@ -0,0 +1,515 @@
import asyncio
import engineio
import six
from . import asyncio_manager
from . import exceptions
from . import packet
from . import server
class AsyncServer(server.Server):
"""A Socket.IO server for asyncio.
This class implements a fully compliant Socket.IO web server with support
for websocket and long-polling transports, compatible with the asyncio
framework on Python 3.5 or newer.
:param client_manager: The client manager instance that will manage the
client list. When this is omitted, the client list
is stored in an in-memory structure, so the use of
multiple connected servers is not possible.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param async_handlers: If set to ``True``, event handlers are executed in
separate threads. To run handlers synchronously,
set to ``False``. The default is ``True``.
:param kwargs: Connection parameters for the underlying Engine.IO server.
The Engine.IO configuration supports the following settings:
:param async_mode: The asynchronous model to use. See the Deployment
section in the documentation for a description of the
available options. Valid async modes are "aiohttp". If
this argument is not given, an async mode is chosen
based on the installed packages.
:param ping_timeout: The time in seconds that the client waits for the
server to respond before disconnecting.
:param ping_interval: The interval in seconds at which the client pings
the server.
:param max_http_buffer_size: The maximum size of a message when using the
polling transport.
:param allow_upgrades: Whether to allow transport upgrades or not.
:param http_compression: Whether to compress packages when using the
polling transport.
:param compression_threshold: Only compress messages when their byte size
is greater than this value.
:param cookie: Name of the HTTP cookie that contains the client session
id. If set to ``None``, a cookie is not sent to the client.
:param cors_allowed_origins: List of origins that are allowed to connect
to this server. All origins are allowed by
default.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server.
:param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
a logger object to use. To disable logging set to
``False``.
"""
def __init__(self, client_manager=None, logger=False, json=None,
async_handlers=True, **kwargs):
if client_manager is None:
client_manager = asyncio_manager.AsyncManager()
super().__init__(client_manager=client_manager, logger=logger,
binary=False, json=json,
async_handlers=async_handlers, **kwargs)
def is_asyncio_based(self):
return True
def attach(self, app, socketio_path='socket.io'):
"""Attach the Socket.IO server to an application."""
self.eio.attach(app, socketio_path)
async def emit(self, event, data=None, to=None, room=None, skip_sid=None,
namespace=None, callback=None, **kwargs):
"""Emit a custom event to one or more connected clients.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param to: The recipient of the message. This can be set to the
session ID of a client to address only that client, or to
               any custom room created by the application to address all
               the clients in that room. If this argument is omitted the
               event is broadcasted to all connected clients.
:param room: Alias for the ``to`` parameter.
:param skip_sid: The session ID of a client to skip when broadcasting
to a room or to all clients. This can be used to
prevent a message from being sent to the sender.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
                     that the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
clients directly, without going through the queue.
This is more efficient, but only works when a
single server process is used. It is recommended
to always leave this parameter with its default
value of ``False``.
Note: this method is a coroutine.
"""
namespace = namespace or '/'
room = to or room
self.logger.info('emitting event "%s" to %s [%s]', event,
room or 'all', namespace)
await self.manager.emit(event, data, namespace, room=room,
skip_sid=skip_sid, callback=callback,
**kwargs)
async def send(self, data, to=None, room=None, skip_sid=None,
namespace=None, callback=None, **kwargs):
"""Send a message to one or more connected clients.
This function emits an event with the name ``'message'``. Use
:func:`emit` to issue custom event names.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param to: The recipient of the message. This can be set to the
session ID of a client to address only that client, or to
               any custom room created by the application to address all
               the clients in that room. If this argument is omitted the
               event is broadcasted to all connected clients.
:param room: Alias for the ``to`` parameter.
:param skip_sid: The session ID of a client to skip when broadcasting
to a room or to all clients. This can be used to
prevent a message from being sent to the sender.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
                     that the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
clients directly, without going through the queue.
This is more efficient, but only works when a
single server process is used. It is recommended
to always leave this parameter with its default
value of ``False``.
Note: this method is a coroutine.
"""
await self.emit('message', data=data, to=to, room=room,
skip_sid=skip_sid, namespace=namespace,
callback=callback, **kwargs)
async def call(self, event, data=None, to=None, sid=None, namespace=None,
timeout=60, **kwargs):
"""Emit a custom event to a client and wait for the response.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param to: The session ID of the recipient client.
:param sid: Alias for the ``to`` parameter.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param timeout: The waiting timeout. If the timeout is reached before
the client acknowledges the event, then a
``TimeoutError`` exception is raised.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
client directly, without going through the queue.
This is more efficient, but only works when a
single server process is used. It is recommended
to always leave this parameter with its default
value of ``False``.
"""
if not self.async_handlers:
raise RuntimeError(
'Cannot use call() when async_handlers is False.')
callback_event = self.eio.create_event()
callback_args = []
def event_callback(*args):
callback_args.append(args)
callback_event.set()
await self.emit(event, data=data, room=to or sid, namespace=namespace,
callback=event_callback, **kwargs)
try:
await asyncio.wait_for(callback_event.wait(), timeout)
except asyncio.TimeoutError:
six.raise_from(exceptions.TimeoutError(), None)
return callback_args[0] if len(callback_args[0]) > 1 \
else callback_args[0][0] if len(callback_args[0]) == 1 \
else None
async def close_room(self, room, namespace=None):
"""Close a room.
This function removes all the clients from the given room.
:param room: Room name.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the default namespace is used.
Note: this method is a coroutine.
"""
namespace = namespace or '/'
self.logger.info('room %s is closing [%s]', room, namespace)
await self.manager.close_room(room, namespace)
async def get_session(self, sid, namespace=None):
"""Return the user session for a client.
:param sid: The session id of the client.
:param namespace: The Socket.IO namespace. If this argument is omitted
the default namespace is used.
The return value is a dictionary. Modifications made to this
dictionary are not guaranteed to be preserved. If you want to modify
the user session, use the ``session`` context manager instead.
"""
namespace = namespace or '/'
eio_session = await self.eio.get_session(sid)
return eio_session.setdefault(namespace, {})
async def save_session(self, sid, session, namespace=None):
"""Store the user session for a client.
:param sid: The session id of the client.
:param session: The session dictionary.
:param namespace: The Socket.IO namespace. If this argument is omitted
the default namespace is used.
"""
namespace = namespace or '/'
eio_session = await self.eio.get_session(sid)
eio_session[namespace] = session
def session(self, sid, namespace=None):
"""Return the user session for a client with context manager syntax.
:param sid: The session id of the client.
This is a context manager that returns the user session dictionary for
the client. Any changes that are made to this dictionary inside the
context manager block are saved back to the session. Example usage::
            @sio.on('connect')
            async def on_connect(sid, environ):
                username = authenticate_user(environ)
                if not username:
                    return False
                async with sio.session(sid) as session:
                    session['username'] = username
            @sio.on('message')
            async def on_message(sid, msg):
                async with sio.session(sid) as session:
                    print('received message from ', session['username'])
"""
class _session_context_manager(object):
def __init__(self, server, sid, namespace):
self.server = server
self.sid = sid
self.namespace = namespace
self.session = None
async def __aenter__(self):
self.session = await self.server.get_session(
sid, namespace=self.namespace)
return self.session
async def __aexit__(self, *args):
await self.server.save_session(sid, self.session,
namespace=self.namespace)
return _session_context_manager(self, sid, namespace)
async def disconnect(self, sid, namespace=None):
"""Disconnect a client.
:param sid: Session ID of the client.
:param namespace: The Socket.IO namespace to disconnect. If this
argument is omitted the default namespace is used.
Note: this method is a coroutine.
"""
namespace = namespace or '/'
if self.manager.is_connected(sid, namespace=namespace):
self.logger.info('Disconnecting %s [%s]', sid, namespace)
self.manager.pre_disconnect(sid, namespace=namespace)
await self._send_packet(sid, packet.Packet(packet.DISCONNECT,
namespace=namespace))
await self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
async def handle_request(self, *args, **kwargs):
"""Handle an HTTP request from the client.
This is the entry point of the Socket.IO application. This function
returns the HTTP response body to deliver to the client.
Note: this method is a coroutine.
"""
return await self.eio.handle_request(*args, **kwargs)
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute. Must be a coroutine.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
        The return value is an ``asyncio.Task`` object.
"""
return self.eio.start_background_task(target, *args, **kwargs)
async def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
Note: this method is a coroutine.
"""
return await self.eio.sleep(seconds)
async def _emit_internal(self, sid, event, data, namespace=None, id=None):
"""Send a message to a client."""
# tuples are expanded to multiple arguments, everything else is sent
# as a single argument
if isinstance(data, tuple):
data = list(data)
else:
data = [data]
await self._send_packet(sid, packet.Packet(
packet.EVENT, namespace=namespace, data=[event] + data, id=id,
binary=None))
async def _send_packet(self, sid, pkt):
"""Send a Socket.IO packet to a client."""
encoded_packet = pkt.encode()
if isinstance(encoded_packet, list):
binary = False
for ep in encoded_packet:
await self.eio.send(sid, ep, binary=binary)
binary = True
else:
await self.eio.send(sid, encoded_packet, binary=False)
async def _handle_connect(self, sid, namespace):
"""Handle a client connection request."""
namespace = namespace or '/'
self.manager.connect(sid, namespace)
if self.always_connect:
await self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
fail_reason = None
try:
success = await self._trigger_event('connect', namespace, sid,
self.environ[sid])
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
if success is False:
if self.always_connect:
self.manager.pre_disconnect(sid, namespace)
await self._send_packet(sid, packet.Packet(
packet.DISCONNECT, data=fail_reason, namespace=namespace))
self.manager.disconnect(sid, namespace)
if not self.always_connect:
await self._send_packet(sid, packet.Packet(
packet.ERROR, data=fail_reason, namespace=namespace))
if sid in self.environ: # pragma: no cover
del self.environ[sid]
return False
elif not self.always_connect:
await self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
async def _handle_disconnect(self, sid, namespace):
"""Handle a client disconnect."""
namespace = namespace or '/'
if namespace == '/':
namespace_list = list(self.manager.get_namespaces())
else:
namespace_list = [namespace]
for n in namespace_list:
if n != '/' and self.manager.is_connected(sid, n):
await self._trigger_event('disconnect', n, sid)
self.manager.disconnect(sid, n)
if namespace == '/' and self.manager.is_connected(sid, namespace):
await self._trigger_event('disconnect', '/', sid)
self.manager.disconnect(sid, '/')
async def _handle_event(self, sid, namespace, id, data):
"""Handle an incoming client event."""
namespace = namespace or '/'
self.logger.info('received event "%s" from %s [%s]', data[0], sid,
namespace)
if self.async_handlers:
self.start_background_task(self._handle_event_internal, self, sid,
data, namespace, id)
else:
await self._handle_event_internal(self, sid, data, namespace, id)
async def _handle_event_internal(self, server, sid, data, namespace, id):
r = await server._trigger_event(data[0], namespace, sid, *data[1:])
if id is not None:
# send ACK packet with the response returned by the handler
# tuples are expanded as multiple arguments
if r is None:
data = []
elif isinstance(r, tuple):
data = list(r)
else:
data = [r]
await server._send_packet(sid, packet.Packet(packet.ACK,
namespace=namespace,
id=id, data=data,
binary=None))
async def _handle_ack(self, sid, namespace, id, data):
"""Handle ACK packets from the client."""
namespace = namespace or '/'
self.logger.info('received ack from %s [%s]', sid, namespace)
await self.manager.trigger_callback(sid, namespace, id, data)
async def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
# first see if we have an explicit handler for the event
if namespace in self.handlers and event in self.handlers[namespace]:
if asyncio.iscoroutinefunction(self.handlers[namespace][event]) \
is True:
try:
ret = await self.handlers[namespace][event](*args)
except asyncio.CancelledError: # pragma: no cover
ret = None
else:
ret = self.handlers[namespace][event](*args)
return ret
        # or else, forward the event to a namespace handler if one exists
elif namespace in self.namespace_handlers:
return await self.namespace_handlers[namespace].trigger_event(
event, *args)
async def _handle_eio_connect(self, sid, environ):
"""Handle the Engine.IO connection event."""
if not self.manager_initialized:
self.manager_initialized = True
self.manager.initialize()
self.environ[sid] = environ
return await self._handle_connect(sid, '/')
async def _handle_eio_message(self, sid, data):
"""Dispatch Engine.IO messages."""
if sid in self._binary_packet:
pkt = self._binary_packet[sid]
if pkt.add_attachment(data):
del self._binary_packet[sid]
if pkt.packet_type == packet.BINARY_EVENT:
await self._handle_event(sid, pkt.namespace, pkt.id,
pkt.data)
else:
await self._handle_ack(sid, pkt.namespace, pkt.id,
pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
await self._handle_connect(sid, pkt.namespace)
elif pkt.packet_type == packet.DISCONNECT:
await self._handle_disconnect(sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT:
await self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.ACK:
await self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \
pkt.packet_type == packet.BINARY_ACK:
self._binary_packet[sid] = pkt
elif pkt.packet_type == packet.ERROR:
raise ValueError('Unexpected ERROR packet.')
else:
raise ValueError('Unknown packet type.')
async def _handle_eio_disconnect(self, sid):
"""Handle Engine.IO disconnect event."""
await self._handle_disconnect(sid, '/')
if sid in self.environ:
del self.environ[sid]
def _engineio_server_class(self):
return engineio.AsyncServer
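# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the vendored file): a minimal aiohttp
# deployment of the AsyncServer defined above. The handler names, event names
# and port are assumptions made for this example.
if __name__ == '__main__':
    from aiohttp import web
    sio = AsyncServer(async_mode='aiohttp')
    app = web.Application()
    sio.attach(app)  # registers the /socket.io routes on the aiohttp app
    @sio.on('connect')
    async def connect(sid, environ):
        print('client connected:', sid)
    @sio.on('my_event')
    async def my_event(sid, data):
        # echo back to the sender only
        await sio.emit('my_response', {'echo': data}, room=sid)
    web.run_app(app, port=8080)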

View file

@@ -0,0 +1,178 @@
import itertools
import logging
import six
default_logger = logging.getLogger('socketio')
class BaseManager(object):
"""Manage client connections.
This class keeps track of all the clients and the rooms they are in, to
support the broadcasting of messages. The data used by this class is
stored in a memory structure, making it appropriate only for single process
services. More sophisticated storage backends can be implemented by
subclasses.
"""
def __init__(self):
self.logger = None
self.server = None
self.rooms = {}
self.callbacks = {}
self.pending_disconnect = {}
def set_server(self, server):
self.server = server
def initialize(self):
"""Invoked before the first request is received. Subclasses can add
their initialization code here.
"""
pass
def get_namespaces(self):
"""Return an iterable with the active namespace names."""
return six.iterkeys(self.rooms)
def get_participants(self, namespace, room):
"""Return an iterable with the active participants in a room."""
for sid, active in six.iteritems(self.rooms[namespace][room].copy()):
yield sid
def connect(self, sid, namespace):
"""Register a client connection to a namespace."""
self.enter_room(sid, namespace, None)
self.enter_room(sid, namespace, sid)
def is_connected(self, sid, namespace):
if namespace in self.pending_disconnect and \
sid in self.pending_disconnect[namespace]:
# the client is in the process of being disconnected
return False
try:
return self.rooms[namespace][None][sid]
except KeyError:
pass
def pre_disconnect(self, sid, namespace):
"""Put the client in the to-be-disconnected list.
This allows the client data structures to be present while the
disconnect handler is invoked, but still recognize the fact that the
client is soon going away.
"""
if namespace not in self.pending_disconnect:
self.pending_disconnect[namespace] = []
self.pending_disconnect[namespace].append(sid)
def disconnect(self, sid, namespace):
"""Register a client disconnect from a namespace."""
if namespace not in self.rooms:
return
rooms = []
for room_name, room in six.iteritems(self.rooms[namespace].copy()):
if sid in room:
rooms.append(room_name)
for room in rooms:
self.leave_room(sid, namespace, room)
if sid in self.callbacks and namespace in self.callbacks[sid]:
del self.callbacks[sid][namespace]
if len(self.callbacks[sid]) == 0:
del self.callbacks[sid]
if namespace in self.pending_disconnect and \
sid in self.pending_disconnect[namespace]:
self.pending_disconnect[namespace].remove(sid)
if len(self.pending_disconnect[namespace]) == 0:
del self.pending_disconnect[namespace]
def enter_room(self, sid, namespace, room):
"""Add a client to a room."""
if namespace not in self.rooms:
self.rooms[namespace] = {}
if room not in self.rooms[namespace]:
self.rooms[namespace][room] = {}
self.rooms[namespace][room][sid] = True
def leave_room(self, sid, namespace, room):
"""Remove a client from a room."""
try:
del self.rooms[namespace][room][sid]
if len(self.rooms[namespace][room]) == 0:
del self.rooms[namespace][room]
if len(self.rooms[namespace]) == 0:
del self.rooms[namespace]
except KeyError:
pass
def close_room(self, room, namespace):
"""Remove all participants from a room."""
try:
for sid in self.get_participants(namespace, room):
self.leave_room(sid, namespace, room)
except KeyError:
pass
def get_rooms(self, sid, namespace):
"""Return the rooms a client is in."""
r = []
try:
for room_name, room in six.iteritems(self.rooms[namespace]):
if room_name is not None and sid in room and room[sid]:
r.append(room_name)
except KeyError:
pass
return r
def emit(self, event, data, namespace, room=None, skip_sid=None,
callback=None, **kwargs):
"""Emit a message to a single client, a room, or all the clients
connected to the namespace."""
if namespace not in self.rooms or room not in self.rooms[namespace]:
return
if not isinstance(skip_sid, list):
skip_sid = [skip_sid]
for sid in self.get_participants(namespace, room):
if sid not in skip_sid:
if callback is not None:
id = self._generate_ack_id(sid, namespace, callback)
else:
id = None
self.server._emit_internal(sid, event, data, namespace, id)
def trigger_callback(self, sid, namespace, id, data):
"""Invoke an application callback."""
callback = None
try:
callback = self.callbacks[sid][namespace][id]
except KeyError:
# if we get an unknown callback we just ignore it
self._get_logger().warning('Unknown callback received, ignoring.')
else:
del self.callbacks[sid][namespace][id]
if callback is not None:
callback(*data)
def _generate_ack_id(self, sid, namespace, callback):
"""Generate a unique identifier for an ACK packet."""
namespace = namespace or '/'
if sid not in self.callbacks:
self.callbacks[sid] = {}
if namespace not in self.callbacks[sid]:
self.callbacks[sid][namespace] = {0: itertools.count(1)}
id = six.next(self.callbacks[sid][namespace][0])
self.callbacks[sid][namespace][id] = callback
return id
def _get_logger(self):
"""Get the appropriate logger
Prevents uninitialized servers in write-only mode from failing.
"""
if self.logger:
return self.logger
elif self.server:
return self.server.logger
else:
return default_logger
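# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the vendored file): how the room
# bookkeeping above behaves when driven directly. A server normally calls
# these methods itself; the sids and the room name are made up for this
# example.
if __name__ == '__main__':
    mgr = BaseManager()
    mgr.connect('sid1', '/')             # joins the namespace and its own room
    mgr.enter_room('sid1', '/', 'chat')
    mgr.connect('sid2', '/')
    mgr.enter_room('sid2', '/', 'chat')
    print(list(mgr.get_participants('/', 'chat')))  # e.g. ['sid1', 'sid2']
    print(mgr.get_rooms('sid1', '/'))               # e.g. ['sid1', 'chat']
    mgr.close_room('chat', '/')
    print('chat' in mgr.rooms['/'])                 # False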

View file

@@ -0,0 +1,590 @@
import itertools
import logging
import random
import signal
import engineio
import six
from . import exceptions
from . import namespace
from . import packet
default_logger = logging.getLogger('socketio.client')
reconnecting_clients = []
def signal_handler(sig, frame): # pragma: no cover
"""SIGINT handler.
Notify any clients that are in a reconnect loop to abort. Other
disconnection tasks are handled at the engine.io level.
"""
for client in reconnecting_clients[:]:
client._reconnect_abort.set()
return original_signal_handler(sig, frame)
original_signal_handler = signal.signal(signal.SIGINT, signal_handler)
class Client(object):
"""A Socket.IO client.
This class implements a fully compliant Socket.IO web client with support
for websocket and long-polling transports.
:param reconnection: ``True`` if the client should automatically attempt to
reconnect to the server after an interruption, or
``False`` to not reconnect. The default is ``True``.
:param reconnection_attempts: How many reconnection attempts to issue
                                  before giving up, or 0 for infinite attempts.
The default is 0.
:param reconnection_delay: How long to wait in seconds before the first
reconnection attempt. Each successive attempt
doubles this delay.
:param reconnection_delay_max: The maximum delay between reconnection
attempts.
:param randomization_factor: Randomization amount for each delay between
reconnection attempts. The default is 0.5,
which means that each delay is randomly
adjusted by +/- 50%.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``.
:param binary: ``True`` to support binary payloads, ``False`` to treat all
payloads as text. On Python 2, if this is set to ``True``,
``unicode`` values are treated as text, and ``str`` and
``bytes`` values are treated as binary. This option has no
effect on Python 3, where text and binary payloads are
always automatically discovered.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
The Engine.IO configuration supports the following settings:
:param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
a logger object to use. To disable logging set to
``False``. The default is ``False``.
"""
def __init__(self, reconnection=True, reconnection_attempts=0,
reconnection_delay=1, reconnection_delay_max=5,
randomization_factor=0.5, logger=False, binary=False,
json=None, **kwargs):
self.reconnection = reconnection
self.reconnection_attempts = reconnection_attempts
self.reconnection_delay = reconnection_delay
self.reconnection_delay_max = reconnection_delay_max
self.randomization_factor = randomization_factor
self.binary = binary
engineio_options = kwargs
engineio_logger = engineio_options.pop('engineio_logger', None)
if engineio_logger is not None:
engineio_options['logger'] = engineio_logger
if json is not None:
packet.Packet.json = json
engineio_options['json'] = json
self.eio = self._engineio_client_class()(**engineio_options)
self.eio.on('connect', self._handle_eio_connect)
self.eio.on('message', self._handle_eio_message)
self.eio.on('disconnect', self._handle_eio_disconnect)
if not isinstance(logger, bool):
self.logger = logger
else:
self.logger = default_logger
if not logging.root.handlers and \
self.logger.level == logging.NOTSET:
if logger:
self.logger.setLevel(logging.INFO)
else:
self.logger.setLevel(logging.ERROR)
self.logger.addHandler(logging.StreamHandler())
self.connection_url = None
self.connection_headers = None
self.connection_transports = None
self.connection_namespaces = None
self.socketio_path = None
self.sid = None
self.namespaces = []
self.handlers = {}
self.namespace_handlers = {}
self.callbacks = {}
self._binary_packet = None
self._reconnect_task = None
self._reconnect_abort = self.eio.create_event()
def is_asyncio_based(self):
return False
def on(self, event, handler=None, namespace=None):
"""Register an event handler.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param handler: The function that should be invoked to handle the
event. When this parameter is not given, the method
acts as a decorator for the handler function.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the handler is associated with
the default namespace.
Example usage::
# as a decorator:
@sio.on('connect')
def connect_handler():
print('Connected!')
# as a method:
def message_handler(msg):
print('Received message: ', msg)
            sio.send('response')
sio.on('message', message_handler)
The ``'connect'`` event handler receives no arguments. The
``'message'`` handler and handlers for custom event names receive the
        message payload as their only argument. Any values returned from a message
handler will be passed to the client's acknowledgement callback
function if it exists. The ``'disconnect'`` handler does not take
arguments.
"""
namespace = namespace or '/'
def set_handler(handler):
if namespace not in self.handlers:
self.handlers[namespace] = {}
self.handlers[namespace][event] = handler
return handler
if handler is None:
return set_handler
set_handler(handler)
def event(self, *args, **kwargs):
"""Decorator to register an event handler.
This is a simplified version of the ``on()`` method that takes the
event name from the decorated function.
Example usage::
@sio.event
def my_event(data):
print('Received data: ', data)
The above example is equivalent to::
@sio.on('my_event')
def my_event(data):
print('Received data: ', data)
A custom namespace can be given as an argument to the decorator::
@sio.event(namespace='/test')
def my_event(data):
print('Received data: ', data)
"""
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# the decorator was invoked without arguments
# args[0] is the decorated function
return self.on(args[0].__name__)(args[0])
else:
# the decorator was invoked with arguments
def set_handler(handler):
return self.on(handler.__name__, *args, **kwargs)(handler)
return set_handler
def register_namespace(self, namespace_handler):
"""Register a namespace handler object.
:param namespace_handler: An instance of a :class:`Namespace`
subclass that handles all the event traffic
for a namespace.
"""
if not isinstance(namespace_handler, namespace.ClientNamespace):
raise ValueError('Not a namespace instance')
if self.is_asyncio_based() != namespace_handler.is_asyncio_based():
raise ValueError('Not a valid namespace class for this client')
namespace_handler._set_client(self)
self.namespace_handlers[namespace_handler.namespace] = \
namespace_handler
def connect(self, url, headers={}, transports=None,
namespaces=None, socketio_path='socket.io'):
"""Connect to a Socket.IO server.
:param url: The URL of the Socket.IO server. It can include custom
query string parameters if required by the server.
:param headers: A dictionary with custom headers to send with the
connection request.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. If not
given, the polling transport is connected first,
then an upgrade to websocket is attempted.
:param namespaces: The list of custom namespaces to connect, in
addition to the default namespace. If not given,
the namespace list is obtained from the registered
event handlers.
:param socketio_path: The endpoint where the Socket.IO server is
installed. The default value is appropriate for
most cases.
Example usage::
sio = socketio.Client()
sio.connect('http://localhost:5000')
"""
self.connection_url = url
self.connection_headers = headers
self.connection_transports = transports
self.connection_namespaces = namespaces
self.socketio_path = socketio_path
if namespaces is None:
namespaces = set(self.handlers.keys()).union(
set(self.namespace_handlers.keys()))
elif isinstance(namespaces, six.string_types):
namespaces = [namespaces]
self.connection_namespaces = namespaces
self.namespaces = [n for n in namespaces if n != '/']
try:
self.eio.connect(url, headers=headers, transports=transports,
engineio_path=socketio_path)
except engineio.exceptions.ConnectionError as exc:
six.raise_from(exceptions.ConnectionError(exc.args[0]), None)
def wait(self):
"""Wait until the connection with the server ends.
Client applications can use this function to block the main thread
during the life of the connection.
"""
while True:
self.eio.wait()
self.sleep(1) # give the reconnect task time to start up
if not self._reconnect_task:
break
self._reconnect_task.join()
if self.eio.state != 'connected':
break
def emit(self, event, data=None, namespace=None, callback=None):
"""Emit a custom event to one or more connected clients.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
                         that the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
"""
namespace = namespace or '/'
self.logger.info('Emitting event "%s" [%s]', event, namespace)
if callback is not None:
id = self._generate_ack_id(namespace, callback)
else:
id = None
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
binary = None
# tuples are expanded to multiple arguments, everything else is sent
# as a single argument
if isinstance(data, tuple):
data = list(data)
elif data is not None:
data = [data]
else:
data = []
self._send_packet(packet.Packet(packet.EVENT, namespace=namespace,
data=[event] + data, id=id,
binary=binary))
def send(self, data, namespace=None, callback=None):
"""Send a message to one or more connected clients.
This function emits an event with the name ``'message'``. Use
:func:`emit` to issue custom event names.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
                         that the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
"""
self.emit('message', data=data, namespace=namespace,
callback=callback)
def call(self, event, data=None, namespace=None, timeout=60):
"""Emit a custom event to a client and wait for the response.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param timeout: The waiting timeout. If the timeout is reached before
the client acknowledges the event, then a
``TimeoutError`` exception is raised.
"""
callback_event = self.eio.create_event()
callback_args = []
def event_callback(*args):
callback_args.append(args)
callback_event.set()
self.emit(event, data=data, namespace=namespace,
callback=event_callback)
if not callback_event.wait(timeout=timeout):
raise exceptions.TimeoutError()
return callback_args[0] if len(callback_args[0]) > 1 \
else callback_args[0][0] if len(callback_args[0]) == 1 \
else None
def disconnect(self):
"""Disconnect from the server."""
# here we just request the disconnection
# later in _handle_eio_disconnect we invoke the disconnect handler
for n in self.namespaces:
self._send_packet(packet.Packet(packet.DISCONNECT, namespace=n))
self._send_packet(packet.Packet(
packet.DISCONNECT, namespace='/'))
self.eio.disconnect(abort=True)
def transport(self):
"""Return the name of the transport used by the client.
The two possible values returned by this function are ``'polling'``
and ``'websocket'``.
"""
return self.eio.transport()
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
"""
return self.eio.start_background_task(target, *args, **kwargs)
def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
"""
return self.eio.sleep(seconds)
def _send_packet(self, pkt):
"""Send a Socket.IO packet to the server."""
encoded_packet = pkt.encode()
if isinstance(encoded_packet, list):
binary = False
for ep in encoded_packet:
self.eio.send(ep, binary=binary)
binary = True
else:
self.eio.send(encoded_packet, binary=False)
def _generate_ack_id(self, namespace, callback):
"""Generate a unique identifier for an ACK packet."""
namespace = namespace or '/'
if namespace not in self.callbacks:
self.callbacks[namespace] = {0: itertools.count(1)}
id = six.next(self.callbacks[namespace][0])
self.callbacks[namespace][id] = callback
return id
def _handle_connect(self, namespace):
namespace = namespace or '/'
self.logger.info('Namespace {} is connected'.format(namespace))
self._trigger_event('connect', namespace=namespace)
if namespace == '/':
for n in self.namespaces:
self._send_packet(packet.Packet(packet.CONNECT, namespace=n))
elif namespace not in self.namespaces:
self.namespaces.append(namespace)
def _handle_disconnect(self, namespace):
namespace = namespace or '/'
self._trigger_event('disconnect', namespace=namespace)
if namespace in self.namespaces:
self.namespaces.remove(namespace)
def _handle_event(self, namespace, id, data):
namespace = namespace or '/'
self.logger.info('Received event "%s" [%s]', data[0], namespace)
r = self._trigger_event(data[0], namespace, *data[1:])
if id is not None:
# send ACK packet with the response returned by the handler
# tuples are expanded as multiple arguments
if r is None:
data = []
elif isinstance(r, tuple):
data = list(r)
else:
data = [r]
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
binary = None
self._send_packet(packet.Packet(packet.ACK, namespace=namespace,
id=id, data=data, binary=binary))
def _handle_ack(self, namespace, id, data):
namespace = namespace or '/'
self.logger.info('Received ack [%s]', namespace)
callback = None
try:
callback = self.callbacks[namespace][id]
except KeyError:
# if we get an unknown callback we just ignore it
self.logger.warning('Unknown callback received, ignoring.')
else:
del self.callbacks[namespace][id]
if callback is not None:
callback(*data)
def _handle_error(self, namespace):
namespace = namespace or '/'
self.logger.info('Connection to namespace {} was rejected'.format(
namespace))
if namespace in self.namespaces:
self.namespaces.remove(namespace)
def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
# first see if we have an explicit handler for the event
if namespace in self.handlers and event in self.handlers[namespace]:
return self.handlers[namespace][event](*args)
# or else, forward the event to a namespace handler if one exists
elif namespace in self.namespace_handlers:
return self.namespace_handlers[namespace].trigger_event(
event, *args)
def _handle_reconnect(self):
self._reconnect_abort.clear()
reconnecting_clients.append(self)
attempt_count = 0
current_delay = self.reconnection_delay
while True:
delay = current_delay
current_delay *= 2
if delay > self.reconnection_delay_max:
delay = self.reconnection_delay_max
delay += self.randomization_factor * (2 * random.random() - 1)
self.logger.info(
'Connection failed, new attempt in {:.02f} seconds'.format(
delay))
if self._reconnect_abort.wait(delay):
self.logger.info('Reconnect task aborted')
break
attempt_count += 1
try:
self.connect(self.connection_url,
headers=self.connection_headers,
transports=self.connection_transports,
namespaces=self.connection_namespaces,
socketio_path=self.socketio_path)
except (exceptions.ConnectionError, ValueError):
pass
else:
self.logger.info('Reconnection successful')
self._reconnect_task = None
break
if self.reconnection_attempts and \
attempt_count >= self.reconnection_attempts:
self.logger.info(
'Maximum reconnection attempts reached, giving up')
break
reconnecting_clients.remove(self)
def _handle_eio_connect(self):
"""Handle the Engine.IO connection event."""
self.logger.info('Engine.IO connection established')
self.sid = self.eio.sid
def _handle_eio_message(self, data):
"""Dispatch Engine.IO messages."""
if self._binary_packet:
pkt = self._binary_packet
if pkt.add_attachment(data):
self._binary_packet = None
if pkt.packet_type == packet.BINARY_EVENT:
self._handle_event(pkt.namespace, pkt.id, pkt.data)
else:
self._handle_ack(pkt.namespace, pkt.id, pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
self._handle_connect(pkt.namespace)
elif pkt.packet_type == packet.DISCONNECT:
self._handle_disconnect(pkt.namespace)
elif pkt.packet_type == packet.EVENT:
self._handle_event(pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.ACK:
self._handle_ack(pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \
pkt.packet_type == packet.BINARY_ACK:
self._binary_packet = pkt
elif pkt.packet_type == packet.ERROR:
self._handle_error(pkt.namespace)
else:
raise ValueError('Unknown packet type.')
def _handle_eio_disconnect(self):
"""Handle the Engine.IO disconnection event."""
self.logger.info('Engine.IO connection dropped')
for n in self.namespaces:
self._trigger_event('disconnect', namespace=n)
self._trigger_event('disconnect', namespace='/')
self.callbacks = {}
self._binary_packet = None
self.sid = None
if self.eio.state == 'connected' and self.reconnection:
self._reconnect_task = self.start_background_task(
self._handle_reconnect)
def _engineio_client_class(self):
return engineio.Client
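# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the vendored file): a small client
# session exercising the API defined above. The server URL and event names
# are assumptions made for this example.
if __name__ == '__main__':
    sio = Client(reconnection_attempts=5, logger=True)
    @sio.event
    def connect():
        print('connected, sid = {}'.format(sio.sid))
    @sio.on('my_response')
    def my_response(data):
        print('server said: {}'.format(data))
    def ack(*args):
        print('ack received: {}'.format(args))
    sio.connect('http://localhost:8080')
    # fire-and-forget emit with an acknowledgement callback
    sio.emit('my_event', {'value': 1}, callback=ack)
    # emit and block until the server acknowledges, or raise TimeoutError
    print(sio.call('my_event', {'value': 2}, timeout=10))
    sio.wait()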

View file

@@ -0,0 +1,26 @@
class SocketIOError(Exception):
pass
class ConnectionError(SocketIOError):
pass
class ConnectionRefusedError(ConnectionError):
"""Connection refused exception.
This exception can be raised from a connect handler when the connection
is not accepted. The positional arguments provided with the exception are
returned with the error packet to the client.
"""
def __init__(self, *args):
if len(args) == 0:
self.error_args = None
elif len(args) == 1:
self.error_args = args[0]
else:
self.error_args = args
class TimeoutError(SocketIOError):
pass
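# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the vendored file): a connect handler
# in application code can raise the ConnectionRefusedError defined above to
# reject a handshake; the positional arguments travel back to the client in
# the error packet. The server object and the token check are assumptions
# made for this example.
if __name__ == '__main__':
    import socketio
    sio = socketio.Server()
    @sio.on('connect')
    def connect(sid, environ):
        if environ.get('HTTP_AUTHORIZATION') != 'secret-token':
            raise ConnectionRefusedError('authentication failed')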

View file

@@ -0,0 +1,105 @@
import pickle
import uuid
try:
import kombu
except ImportError:
kombu = None
from .pubsub_manager import PubSubManager
class KombuManager(PubSubManager): # pragma: no cover
"""Client manager that uses kombu for inter-process messaging.
This class implements a client manager backend for event sharing across
multiple processes, using RabbitMQ, Redis or any other messaging mechanism
supported by `kombu <http://kombu.readthedocs.org/en/latest/>`_.
To use a kombu backend, initialize the :class:`Server` instance as
follows::
url = 'amqp://user:password@hostname:port//'
server = socketio.Server(client_manager=socketio.KombuManager(url))
:param url: The connection URL for the backend messaging queue. Example
connection URLs are ``'amqp://guest:guest@localhost:5672//'``
and ``'redis://localhost:6379/'`` for RabbitMQ and Redis
respectively. Consult the `kombu documentation
<http://kombu.readthedocs.org/en/latest/userguide\
/connections.html#urls>`_ for more on how to construct
connection URLs.
:param channel: The channel name on which the server sends and receives
notifications. Must be the same in all the servers.
    :param write_only: If set to ``True``, only initialize to emit events. The
default of ``False`` initializes the class for emitting
and receiving.
"""
name = 'kombu'
def __init__(self, url='amqp://guest:guest@localhost:5672//',
channel='socketio', write_only=False, logger=None):
if kombu is None:
raise RuntimeError('Kombu package is not installed '
'(Run "pip install kombu" in your '
'virtualenv).')
super(KombuManager, self).__init__(channel=channel,
write_only=write_only,
logger=logger)
self.url = url
self.producer = self._producer()
def initialize(self):
super(KombuManager, self).initialize()
monkey_patched = True
if self.server.async_mode == 'eventlet':
from eventlet.patcher import is_monkey_patched
monkey_patched = is_monkey_patched('socket')
elif 'gevent' in self.server.async_mode:
from gevent.monkey import is_module_patched
monkey_patched = is_module_patched('socket')
if not monkey_patched:
raise RuntimeError(
'Kombu requires a monkey patched socket library to work '
'with ' + self.server.async_mode)
def _connection(self):
return kombu.Connection(self.url)
def _exchange(self):
return kombu.Exchange(self.channel, type='fanout', durable=False)
def _queue(self):
queue_name = 'flask-socketio.' + str(uuid.uuid4())
return kombu.Queue(queue_name, self._exchange(),
durable=False,
queue_arguments={'x-expires': 300000})
def _producer(self):
return self._connection().Producer(exchange=self._exchange())
def __error_callback(self, exception, interval):
self._get_logger().exception('Sleeping {}s'.format(interval))
def _publish(self, data):
connection = self._connection()
publish = connection.ensure(self.producer, self.producer.publish,
errback=self.__error_callback)
publish(pickle.dumps(data))
def _listen(self):
reader_queue = self._queue()
while True:
connection = self._connection().ensure_connection(
errback=self.__error_callback)
try:
with connection.SimpleQueue(reader_queue) as queue:
while True:
message = queue.get(block=True)
message.ack()
yield message.payload
except connection.connection_errors:
self._get_logger().exception("Connection error "
"while reading from queue")

View file

@@ -0,0 +1,42 @@
import engineio
class WSGIApp(engineio.WSGIApp):
"""WSGI middleware for Socket.IO.
This middleware dispatches traffic to a Socket.IO application. It can also
serve a list of static files to the client, or forward unrelated HTTP
traffic to another WSGI application.
:param socketio_app: The Socket.IO server. Must be an instance of the
``socketio.Server`` class.
:param wsgi_app: The WSGI app that receives all other traffic.
:param static_files: A dictionary with static file mapping rules. See the
documentation for details on this argument.
:param socketio_path: The endpoint where the Socket.IO application should
be installed. The default value is appropriate for
most cases.
Example usage::
import socketio
import eventlet
from . import wsgi_app
sio = socketio.Server()
app = socketio.WSGIApp(sio, wsgi_app)
eventlet.wsgi.server(eventlet.listen(('', 8000)), app)
"""
def __init__(self, socketio_app, wsgi_app=None, static_files=None,
socketio_path='socket.io'):
super(WSGIApp, self).__init__(socketio_app, wsgi_app,
static_files=static_files,
engineio_path=socketio_path)
class Middleware(WSGIApp):
"""This class has been renamed to WSGIApp and is now deprecated."""
def __init__(self, socketio_app, wsgi_app=None,
socketio_path='socket.io'):
super(Middleware, self).__init__(socketio_app, wsgi_app,
socketio_path=socketio_path)
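# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the vendored file): wrapping a plain
# WSGI application so that Socket.IO traffic is dispatched to the server and
# every other request falls through to the wrapped app. The eventlet worker
# and port follow the docstring example above; the handler body is an
# assumption made for this example.
if __name__ == '__main__':
    import eventlet
    import socketio
    def other_app(environ, start_response):
        # receives every request that is not addressed to socketio_path
        start_response('200 OK', [('Content-Type', 'text/plain')])
        return [b'hello from the wrapped WSGI app']
    sio = socketio.Server(async_mode='eventlet')
    app = WSGIApp(sio, other_app)
    eventlet.wsgi.server(eventlet.listen(('', 8000)), app)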

View file

@@ -0,0 +1,191 @@
class BaseNamespace(object):
def __init__(self, namespace=None):
self.namespace = namespace or '/'
def is_asyncio_based(self):
return False
def trigger_event(self, event, *args):
"""Dispatch an event to the proper handler method.
In the most common usage, this method is not overloaded by subclasses,
as it performs the routing of events to methods. However, this
    method can be overridden if special dispatching rules are needed, or if
having a single method that catches all events is desired.
"""
handler_name = 'on_' + event
if hasattr(self, handler_name):
return getattr(self, handler_name)(*args)
class Namespace(BaseNamespace):
"""Base class for server-side class-based namespaces.
A class-based namespace is a class that contains all the event handlers
for a Socket.IO namespace. The event handlers are methods of the class
with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
``on_message``, ``on_json``, and so on.
:param namespace: The Socket.IO namespace to be used with all the event
handlers defined in this class. If this argument is
omitted, the default namespace is used.
"""
def __init__(self, namespace=None):
super(Namespace, self).__init__(namespace=namespace)
self.server = None
def _set_server(self, server):
self.server = server
def emit(self, event, data=None, room=None, skip_sid=None, namespace=None,
callback=None):
"""Emit a custom event to one or more connected clients.
The only difference with the :func:`socketio.Server.emit` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.emit(event, data=data, room=room, skip_sid=skip_sid,
namespace=namespace or self.namespace,
callback=callback)
def send(self, data, room=None, skip_sid=None, namespace=None,
callback=None):
"""Send a message to one or more connected clients.
The only difference with the :func:`socketio.Server.send` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.send(data, room=room, skip_sid=skip_sid,
namespace=namespace or self.namespace,
callback=callback)
def enter_room(self, sid, room, namespace=None):
"""Enter a room.
The only difference with the :func:`socketio.Server.enter_room` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.enter_room(sid, room,
namespace=namespace or self.namespace)
def leave_room(self, sid, room, namespace=None):
"""Leave a room.
The only difference with the :func:`socketio.Server.leave_room` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.leave_room(sid, room,
namespace=namespace or self.namespace)
def close_room(self, room, namespace=None):
"""Close a room.
The only difference with the :func:`socketio.Server.close_room` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.close_room(room,
namespace=namespace or self.namespace)
def rooms(self, sid, namespace=None):
"""Return the rooms a client is in.
The only difference with the :func:`socketio.Server.rooms` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.rooms(sid, namespace=namespace or self.namespace)
def get_session(self, sid, namespace=None):
"""Return the user session for a client.
The only difference with the :func:`socketio.Server.get_session`
method is that when the ``namespace`` argument is not given the
namespace associated with the class is used.
"""
return self.server.get_session(
sid, namespace=namespace or self.namespace)
def save_session(self, sid, session, namespace=None):
"""Store the user session for a client.
The only difference with the :func:`socketio.Server.save_session`
method is that when the ``namespace`` argument is not given the
namespace associated with the class is used.
"""
return self.server.save_session(
sid, session, namespace=namespace or self.namespace)
def session(self, sid, namespace=None):
"""Return the user session for a client with context manager syntax.
The only difference with the :func:`socketio.Server.session` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.session(sid, namespace=namespace or self.namespace)
def disconnect(self, sid, namespace=None):
"""Disconnect a client.
The only difference with the :func:`socketio.Server.disconnect` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.server.disconnect(sid,
namespace=namespace or self.namespace)
class ClientNamespace(BaseNamespace):
"""Base class for client-side class-based namespaces.
A class-based namespace is a class that contains all the event handlers
for a Socket.IO namespace. The event handlers are methods of the class
with the prefix ``on_``, such as ``on_connect``, ``on_disconnect``,
``on_message``, ``on_json``, and so on.
:param namespace: The Socket.IO namespace to be used with all the event
handlers defined in this class. If this argument is
omitted, the default namespace is used.
"""
def __init__(self, namespace=None):
super(ClientNamespace, self).__init__(namespace=namespace)
self.client = None
def _set_client(self, client):
self.client = client
def emit(self, event, data=None, namespace=None, callback=None):
"""Emit a custom event to the server.
The only difference with the :func:`socketio.Client.emit` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.client.emit(event, data=data,
namespace=namespace or self.namespace,
callback=callback)
def send(self, data, room=None, skip_sid=None, namespace=None,
callback=None):
"""Send a message to the server.
The only difference with the :func:`socketio.Client.send` method is
that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.client.send(data, namespace=namespace or self.namespace,
callback=callback)
def disconnect(self):
"""Disconnect from the server.
The only difference with the :func:`socketio.Client.disconnect` method
is that when the ``namespace`` argument is not given the namespace
associated with the class is used.
"""
return self.client.disconnect()
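The class-based pattern described in the docstrings above is normally consumed by subclassing and registering a namespace on a server. A minimal sketch, assuming the package-level ``socketio.Namespace`` and ``register_namespace`` API; the '/chat' namespace, room and event names are made up for illustration:

# Illustrative sketch only; names are placeholders.
import socketio

class ChatNamespace(socketio.Namespace):
    def on_connect(self, sid, environ):
        # called when a client connects to '/chat'
        self.enter_room(sid, 'lobby')

    def on_message(self, sid, data):
        # relay to everyone else in the room
        self.send(data, room='lobby', skip_sid=sid)

    def on_disconnect(self, sid):
        self.leave_room(sid, 'lobby')

sio = socketio.Server()
sio.register_namespace(ChatNamespace('/chat'))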

View file

@ -0,0 +1,179 @@
import functools
import json as _json
import six
(CONNECT, DISCONNECT, EVENT, ACK, ERROR, BINARY_EVENT, BINARY_ACK) = \
(0, 1, 2, 3, 4, 5, 6)
packet_names = ['CONNECT', 'DISCONNECT', 'EVENT', 'ACK', 'ERROR',
'BINARY_EVENT', 'BINARY_ACK']
class Packet(object):
"""Socket.IO packet."""
# the format of the Socket.IO packet is as follows:
#
# packet type: 1 byte, values 0-6
# num_attachments: ASCII encoded, only if num_attachments != 0
# '-': only if num_attachments != 0
# namespace: only if namespace != '/'
# ',': only if namespace and one of id and data are defined in this packet
# id: ASCII encoded, only if id is not None
# data: JSON dump of data payload
json = _json
def __init__(self, packet_type=EVENT, data=None, namespace=None, id=None,
binary=None, encoded_packet=None):
self.packet_type = packet_type
self.data = data
self.namespace = namespace
self.id = id
if binary or (binary is None and self._data_is_binary(self.data)):
if self.packet_type == EVENT:
self.packet_type = BINARY_EVENT
elif self.packet_type == ACK:
self.packet_type = BINARY_ACK
else:
raise ValueError('Packet does not support binary payload.')
self.attachment_count = 0
self.attachments = []
if encoded_packet:
self.attachment_count = self.decode(encoded_packet)
def encode(self):
"""Encode the packet for transmission.
If the packet contains binary elements, this function returns a list
of packets where the first is the original packet with placeholders for
the binary components and the remaining ones the binary attachments.
"""
encoded_packet = six.text_type(self.packet_type)
if self.packet_type == BINARY_EVENT or self.packet_type == BINARY_ACK:
data, attachments = self._deconstruct_binary(self.data)
encoded_packet += six.text_type(len(attachments)) + '-'
else:
data = self.data
attachments = None
needs_comma = False
if self.namespace is not None and self.namespace != '/':
encoded_packet += self.namespace
needs_comma = True
if self.id is not None:
if needs_comma:
encoded_packet += ','
needs_comma = False
encoded_packet += six.text_type(self.id)
if data is not None:
if needs_comma:
encoded_packet += ','
encoded_packet += self.json.dumps(data, separators=(',', ':'))
if attachments is not None:
encoded_packet = [encoded_packet] + attachments
return encoded_packet
def decode(self, encoded_packet):
"""Decode a transmitted package.
The return value indicates how many binary attachment packets are
necessary to fully decode the packet.
"""
ep = encoded_packet
try:
self.packet_type = int(ep[0:1])
except TypeError:
self.packet_type = ep
ep = ''
self.namespace = None
self.data = None
ep = ep[1:]
dash = ep.find('-')
attachment_count = 0
if dash > 0 and ep[0:dash].isdigit():
attachment_count = int(ep[0:dash])
ep = ep[dash + 1:]
if ep and ep[0:1] == '/':
sep = ep.find(',')
if sep == -1:
self.namespace = ep
ep = ''
else:
self.namespace = ep[0:sep]
ep = ep[sep + 1:]
q = self.namespace.find('?')
if q != -1:
self.namespace = self.namespace[0:q]
if ep and ep[0].isdigit():
self.id = 0
while ep and ep[0].isdigit():
self.id = self.id * 10 + int(ep[0])
ep = ep[1:]
if ep:
self.data = self.json.loads(ep)
return attachment_count
def add_attachment(self, attachment):
if self.attachment_count <= len(self.attachments):
raise ValueError('Unexpected binary attachment')
self.attachments.append(attachment)
if self.attachment_count == len(self.attachments):
self.reconstruct_binary(self.attachments)
return True
return False
def reconstruct_binary(self, attachments):
"""Reconstruct a decoded packet using the given list of binary
attachments.
"""
self.data = self._reconstruct_binary_internal(self.data,
self.attachments)
def _reconstruct_binary_internal(self, data, attachments):
if isinstance(data, list):
return [self._reconstruct_binary_internal(item, attachments)
for item in data]
elif isinstance(data, dict):
if data.get('_placeholder') and 'num' in data:
return attachments[data['num']]
else:
return {key: self._reconstruct_binary_internal(value,
attachments)
for key, value in six.iteritems(data)}
else:
return data
def _deconstruct_binary(self, data):
"""Extract binary components in the packet."""
attachments = []
data = self._deconstruct_binary_internal(data, attachments)
return data, attachments
def _deconstruct_binary_internal(self, data, attachments):
if isinstance(data, six.binary_type):
attachments.append(data)
return {'_placeholder': True, 'num': len(attachments) - 1}
elif isinstance(data, list):
return [self._deconstruct_binary_internal(item, attachments)
for item in data]
elif isinstance(data, dict):
return {key: self._deconstruct_binary_internal(value, attachments)
for key, value in six.iteritems(data)}
else:
return data
def _data_is_binary(self, data):
"""Check if the data contains binary components."""
if isinstance(data, six.binary_type):
return True
elif isinstance(data, list):
return functools.reduce(
lambda a, b: a or b, [self._data_is_binary(item)
for item in data], False)
elif isinstance(data, dict):
return functools.reduce(
lambda a, b: a or b, [self._data_is_binary(item)
for item in six.itervalues(data)],
False)
else:
return False
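The wire format sketched in the comment block at the top of the class can be exercised directly. A small round-trip sketch, assuming the module is importable as ``socketio.packet``; the event name and payload are arbitrary:

# Round-trip example for the packet format described above.
from socketio import packet

pkt = packet.Packet(packet.EVENT, data=['my_event', {'x': 1}],
                    namespace='/chat', id=7)
print(pkt.encode())   # -> 2/chat,7["my_event",{"x":1}]

decoded = packet.Packet(encoded_packet='2/chat,7["my_event",{"x":1}]')
print(decoded.packet_type, decoded.namespace, decoded.id, decoded.data)
# packet type, namespace, id and data survive the round trip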

View file

@ -0,0 +1,154 @@
from functools import partial
import uuid
import json
import pickle
import six
from .base_manager import BaseManager
class PubSubManager(BaseManager):
"""Manage a client list attached to a pub/sub backend.
This is a base class that enables multiple servers to share the list of
clients, with the servers communicating events through a pub/sub backend.
The use of a pub/sub backend also allows any client connected to the
backend to emit events addressed to Socket.IO clients.
The actual backends must be implemented by subclasses, this class only
provides a pub/sub generic framework.
:param channel: The channel name on which the server sends and receives
notifications.
"""
name = 'pubsub'
def __init__(self, channel='socketio', write_only=False, logger=None):
super(PubSubManager, self).__init__()
self.channel = channel
self.write_only = write_only
self.host_id = uuid.uuid4().hex
self.logger = logger
def initialize(self):
super(PubSubManager, self).initialize()
if not self.write_only:
self.thread = self.server.start_background_task(self._thread)
self._get_logger().info(self.name + ' backend initialized.')
def emit(self, event, data, namespace=None, room=None, skip_sid=None,
callback=None, **kwargs):
"""Emit a message to a single client, a room, or all the clients
connected to the namespace.
This method takes care of propagating the message to all the servers
that are connected through the message queue.
The parameters are the same as in :meth:`.Server.emit`.
"""
if kwargs.get('ignore_queue'):
return super(PubSubManager, self).emit(
event, data, namespace=namespace, room=room, skip_sid=skip_sid,
callback=callback)
namespace = namespace or '/'
if callback is not None:
if self.server is None:
raise RuntimeError('Callbacks can only be issued from the '
'context of a server.')
if room is None:
raise ValueError('Cannot use callback without a room set.')
id = self._generate_ack_id(room, namespace, callback)
callback = (room, namespace, id)
else:
callback = None
self._publish({'method': 'emit', 'event': event, 'data': data,
'namespace': namespace, 'room': room,
'skip_sid': skip_sid, 'callback': callback,
'host_id': self.host_id})
def close_room(self, room, namespace=None):
self._publish({'method': 'close_room', 'room': room,
'namespace': namespace or '/'})
def _publish(self, data):
"""Publish a message on the Socket.IO channel.
This method needs to be implemented by the different subclasses that
support pub/sub backends.
"""
raise NotImplementedError('This method must be implemented in a '
'subclass.') # pragma: no cover
def _listen(self):
"""Return the next message published on the Socket.IO channel,
blocking until a message is available.
This method needs to be implemented by the different subclasses that
support pub/sub backends.
"""
raise NotImplementedError('This method must be implemented in a '
'subclass.') # pragma: no cover
def _handle_emit(self, message):
# Events with callbacks are very tricky to handle across hosts
# Here in the receiving end we set up a local callback that preserves
# the callback host and id from the sender
remote_callback = message.get('callback')
remote_host_id = message.get('host_id')
if remote_callback is not None and len(remote_callback) == 3:
callback = partial(self._return_callback, remote_host_id,
*remote_callback)
else:
callback = None
super(PubSubManager, self).emit(message['event'], message['data'],
namespace=message.get('namespace'),
room=message.get('room'),
skip_sid=message.get('skip_sid'),
callback=callback)
def _handle_callback(self, message):
if self.host_id == message.get('host_id'):
try:
sid = message['sid']
namespace = message['namespace']
id = message['id']
args = message['args']
except KeyError:
return
self.trigger_callback(sid, namespace, id, args)
def _return_callback(self, host_id, sid, namespace, callback_id, *args):
# When an event callback is received, the callback is returned back
# to the sender, which is identified by the host_id
self._publish({'method': 'callback', 'host_id': host_id,
'sid': sid, 'namespace': namespace, 'id': callback_id,
'args': args})
def _handle_close_room(self, message):
super(PubSubManager, self).close_room(
room=message.get('room'), namespace=message.get('namespace'))
def _thread(self):
for message in self._listen():
data = None
if isinstance(message, dict):
data = message
else:
if isinstance(message, six.binary_type): # pragma: no cover
try:
data = pickle.loads(message)
except:
pass
if data is None:
try:
data = json.loads(message)
except:
pass
if data and 'method' in data:
if data['method'] == 'emit':
self._handle_emit(data)
elif data['method'] == 'callback':
self._handle_callback(data)
elif data['method'] == 'close_room':
self._handle_close_room(data)
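Concrete backends only need to provide ``_publish()`` and ``_listen()``. A minimal sketch of a hypothetical subclass backed by an in-process queue (not part of the package, purely to show the required interface):

# Hypothetical backend: an in-process queue stands in for the broker.
from six.moves import queue

from socketio.pubsub_manager import PubSubManager


class LocalQueueManager(PubSubManager):
    name = 'local_queue'

    def __init__(self, channel='socketio', write_only=False, logger=None):
        super(LocalQueueManager, self).__init__(
            channel=channel, write_only=write_only, logger=logger)
        self._queue = queue.Queue()

    def _publish(self, data):
        self._queue.put(data)

    def _listen(self):
        while True:
            yield self._queue.get()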

View file

@ -0,0 +1,111 @@
import logging
import pickle
import time
try:
import redis
except ImportError:
redis = None
from .pubsub_manager import PubSubManager
logger = logging.getLogger('socketio')
class RedisManager(PubSubManager): # pragma: no cover
"""Redis based client manager.
This class implements a Redis backend for event sharing across multiple
processes. Only kept here as one more example of how to build a custom
backend, since the kombu backend is perfectly adequate to support a Redis
message queue.
To use a Redis backend, initialize the :class:`Server` instance as
follows::
url = 'redis://hostname:port/0'
server = socketio.Server(client_manager=socketio.RedisManager(url))
:param url: The connection URL for the Redis server. For a default Redis
store running on the same host, use ``redis://``.
:param channel: The channel name on which the server sends and receives
notifications. Must be the same in all the servers.
:param write_only: If set to ``True``, only initialize to emit events. The
default of ``False`` initializes the class for emitting
and receiving.
"""
name = 'redis'
def __init__(self, url='redis://localhost:6379/0', channel='socketio',
write_only=False, logger=None):
if redis is None:
raise RuntimeError('Redis package is not installed '
'(Run "pip install redis" in your '
'virtualenv).')
self.redis_url = url
self._redis_connect()
super(RedisManager, self).__init__(channel=channel,
write_only=write_only,
logger=logger)
def initialize(self):
super(RedisManager, self).initialize()
monkey_patched = True
if self.server.async_mode == 'eventlet':
from eventlet.patcher import is_monkey_patched
monkey_patched = is_monkey_patched('socket')
elif 'gevent' in self.server.async_mode:
from gevent.monkey import is_module_patched
monkey_patched = is_module_patched('socket')
if not monkey_patched:
raise RuntimeError(
'Redis requires a monkey patched socket library to work '
'with ' + self.server.async_mode)
def _redis_connect(self):
self.redis = redis.Redis.from_url(self.redis_url)
self.pubsub = self.redis.pubsub()
def _publish(self, data):
retry = True
while True:
try:
if not retry:
self._redis_connect()
return self.redis.publish(self.channel, pickle.dumps(data))
except redis.exceptions.ConnectionError:
if retry:
logger.error('Cannot publish to redis... retrying')
retry = False
else:
logger.error('Cannot publish to redis... giving up')
break
def _redis_listen_with_retries(self):
retry_sleep = 1
connect = False
while True:
try:
if connect:
self._redis_connect()
self.pubsub.subscribe(self.channel)
for message in self.pubsub.listen():
yield message
except redis.exceptions.ConnectionError:
logger.error('Cannot receive from redis... '
'retrying in {} secs'.format(retry_sleep))
connect = True
time.sleep(retry_sleep)
retry_sleep *= 2
if retry_sleep > 60:
retry_sleep = 60
def _listen(self):
channel = self.channel.encode('utf-8')
self.pubsub.subscribe(self.channel)
for message in self._redis_listen_with_retries():
if message['channel'] == channel and \
message['type'] == 'message' and 'data' in message:
yield message['data']
self.pubsub.unsubscribe(self.channel)
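The ``write_only`` flag is what external processes (workers, cron jobs) use to push events into a running Socket.IO cluster without serving clients themselves. A sketch, assuming the Redis URL and channel match the ones used by the servers; the event name and payload are placeholders:

# Emit into the cluster from outside any server process.
import socketio

external = socketio.RedisManager('redis://localhost:6379/0', write_only=True)
external.emit('job_done', {'job_id': 42}, namespace='/')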

View file

@ -0,0 +1,719 @@
import logging
import engineio
import six
from . import base_manager
from . import exceptions
from . import namespace
from . import packet
default_logger = logging.getLogger('socketio.server')
class Server(object):
"""A Socket.IO server.
This class implements a fully compliant Socket.IO web server with support
for websocket and long-polling transports.
:param client_manager: The client manager instance that will manage the
client list. When this is omitted, the client list
is stored in an in-memory structure, so the use of
multiple connected servers is not possible.
:param logger: To enable logging set to ``True`` or pass a logger object to
use. To disable logging set to ``False``. The default is
``False``.
:param binary: ``True`` to support binary payloads, ``False`` to treat all
payloads as text. On Python 2, if this is set to ``True``,
``unicode`` values are treated as text, and ``str`` and
``bytes`` values are treated as binary. This option has no
effect on Python 3, where text and binary payloads are
always automatically discovered.
:param json: An alternative json module to use for encoding and decoding
packets. Custom json modules must have ``dumps`` and ``loads``
functions that are compatible with the standard library
versions.
:param async_handlers: If set to ``True``, event handlers for a client are
executed in separate threads. To run handlers for a
client synchronously, set to ``False``. The default
is ``True``.
:param always_connect: When set to ``False``, new connections are
provisional until the connect handler returns
something other than ``False``, at which point they
are accepted. When set to ``True``, connections are
immediately accepted, and then if the connect
handler returns ``False`` a disconnect is issued.
Set to ``True`` if you need to emit events from the
connect handler and your client is confused when it
receives events before the connection acceptance.
In any other case use the default of ``False``.
:param kwargs: Connection parameters for the underlying Engine.IO server.
The Engine.IO configuration supports the following settings:
:param async_mode: The asynchronous model to use. See the Deployment
section in the documentation for a description of the
available options. Valid async modes are "threading",
"eventlet", "gevent" and "gevent_uwsgi". If this
argument is not given, "eventlet" is tried first, then
"gevent_uwsgi", then "gevent", and finally "threading".
The first async mode that has all its dependencies
installed is the one that is chosen.
:param ping_timeout: The time in seconds that the client waits for the
server to respond before disconnecting. The default
is 60 seconds.
:param ping_interval: The interval in seconds at which the client pings
the server. The default is 25 seconds.
:param max_http_buffer_size: The maximum size of a message when using the
polling transport. The default is 100,000,000
bytes.
:param allow_upgrades: Whether to allow transport upgrades or not. The
default is ``True``.
:param http_compression: Whether to compress packages when using the
polling transport. The default is ``True``.
:param compression_threshold: Only compress messages when their byte size
is greater than this value. The default is
1024 bytes.
:param cookie: Name of the HTTP cookie that contains the client session
id. If set to ``None``, a cookie is not sent to the client.
The default is ``'io'``.
:param cors_allowed_origins: List of origins that are allowed to connect
to this server. All origins are allowed by
default.
:param cors_credentials: Whether credentials (cookies, authentication) are
allowed in requests to this server. The default is
``True``.
:param engineio_logger: To enable Engine.IO logging set to ``True`` or pass
a logger object to use. To disable logging set to
``False``. The default is ``False``.
"""
def __init__(self, client_manager=None, logger=False, binary=False,
json=None, async_handlers=True, always_connect=False,
**kwargs):
engineio_options = kwargs
engineio_logger = engineio_options.pop('engineio_logger', None)
if engineio_logger is not None:
engineio_options['logger'] = engineio_logger
if json is not None:
packet.Packet.json = json
engineio_options['json'] = json
engineio_options['async_handlers'] = False
self.eio = self._engineio_server_class()(**engineio_options)
self.eio.on('connect', self._handle_eio_connect)
self.eio.on('message', self._handle_eio_message)
self.eio.on('disconnect', self._handle_eio_disconnect)
self.binary = binary
self.environ = {}
self.handlers = {}
self.namespace_handlers = {}
self._binary_packet = {}
if not isinstance(logger, bool):
self.logger = logger
else:
self.logger = default_logger
if not logging.root.handlers and \
self.logger.level == logging.NOTSET:
if logger:
self.logger.setLevel(logging.INFO)
else:
self.logger.setLevel(logging.ERROR)
self.logger.addHandler(logging.StreamHandler())
if client_manager is None:
client_manager = base_manager.BaseManager()
self.manager = client_manager
self.manager.set_server(self)
self.manager_initialized = False
self.async_handlers = async_handlers
self.always_connect = always_connect
self.async_mode = self.eio.async_mode
def is_asyncio_based(self):
return False
def on(self, event, handler=None, namespace=None):
"""Register an event handler.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param handler: The function that should be invoked to handle the
event. When this parameter is not given, the method
acts as a decorator for the handler function.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the handler is associated with
the default namespace.
Example usage::
# as a decorator:
@socket_io.on('connect', namespace='/chat')
def connect_handler(sid, environ):
print('Connection request')
if environ['REMOTE_ADDR'] in blacklisted:
return False # reject
# as a method:
def message_handler(sid, msg):
print('Received message: ', msg)
eio.send(sid, 'response')
socket_io.on('message', message_handler, namespace='/chat')
The handler function receives the ``sid`` (session ID) for the
client as first argument. The ``'connect'`` event handler receives the
WSGI environment as a second argument, and can return ``False`` to
reject the connection. The ``'message'`` handler and handlers for
custom event names receive the message payload as a second argument.
Any values returned from a message handler will be passed to the
client's acknowledgement callback function if it exists. The
``'disconnect'`` handler does not take a second argument.
"""
namespace = namespace or '/'
def set_handler(handler):
if namespace not in self.handlers:
self.handlers[namespace] = {}
self.handlers[namespace][event] = handler
return handler
if handler is None:
return set_handler
set_handler(handler)
def event(self, *args, **kwargs):
"""Decorator to register an event handler.
This is a simplified version of the ``on()`` method that takes the
event name from the decorated function.
Example usage::
@sio.event
def my_event(data):
print('Received data: ', data)
The above example is equivalent to::
@sio.on('my_event')
def my_event(data):
print('Received data: ', data)
A custom namespace can be given as an argument to the decorator::
@sio.event(namespace='/test')
def my_event(data):
print('Received data: ', data)
"""
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# the decorator was invoked without arguments
# args[0] is the decorated function
return self.on(args[0].__name__)(args[0])
else:
# the decorator was invoked with arguments
def set_handler(handler):
return self.on(handler.__name__, *args, **kwargs)(handler)
return set_handler
def register_namespace(self, namespace_handler):
"""Register a namespace handler object.
:param namespace_handler: An instance of a :class:`Namespace`
subclass that handles all the event traffic
for a namespace.
"""
if not isinstance(namespace_handler, namespace.Namespace):
raise ValueError('Not a namespace instance')
if self.is_asyncio_based() != namespace_handler.is_asyncio_based():
raise ValueError('Not a valid namespace class for this server')
namespace_handler._set_server(self)
self.namespace_handlers[namespace_handler.namespace] = \
namespace_handler
def emit(self, event, data=None, to=None, room=None, skip_sid=None,
namespace=None, callback=None, **kwargs):
"""Emit a custom event to one or more connected clients.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param to: The recipient of the message. This can be set to the
session ID of a client to address only that client, or to
any custom room created by the application to address all
the clients in that room. If this argument is omitted the
event is broadcasted to all connected clients.
:param room: Alias for the ``to`` parameter.
:param skip_sid: The session ID of a client to skip when broadcasting
to a room or to all clients. This can be used to
prevent a message from being sent to the sender. To
skip multiple sids, pass a list.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
that the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
clients directly, without going through the queue.
This is more efficient, but only works when a
single server process is used. It is recommended
to always leave this parameter with its default
value of ``False``.
"""
namespace = namespace or '/'
room = to or room
self.logger.info('emitting event "%s" to %s [%s]', event,
room or 'all', namespace)
self.manager.emit(event, data, namespace, room=room,
skip_sid=skip_sid, callback=callback, **kwargs)
def send(self, data, to=None, room=None, skip_sid=None, namespace=None,
callback=None, **kwargs):
"""Send a message to one or more connected clients.
This function emits an event with the name ``'message'``. Use
:func:`emit` to issue custom event names.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param to: The recipient of the message. This can be set to the
session ID of a client to address only that client, or to
any custom room created by the application to address all
the clients in that room. If this argument is omitted the
event is broadcasted to all connected clients.
:param room: Alias for the ``to`` parameter.
:param skip_sid: The session ID of a client to skip when broadcasting
to a room or to all clients. This can be used to
prevent a message from being sent to the sender. To
skip multiple sids, pass a list.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param callback: If given, this function will be called to acknowledge
that the client has received the message. The arguments
that will be passed to the function are those provided
by the client. Callback functions can only be used
when addressing an individual client.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
clients directly, without going through the queue.
This is more efficient, but only works when a
single server process is used. It is recommended
to always leave this parameter with its default
value of ``False``.
"""
self.emit('message', data=data, to=to, room=room, skip_sid=skip_sid,
namespace=namespace, callback=callback, **kwargs)
def call(self, event, data=None, to=None, sid=None, namespace=None,
timeout=60, **kwargs):
"""Emit a custom event to a client and wait for the response.
:param event: The event name. It can be any string. The event names
``'connect'``, ``'message'`` and ``'disconnect'`` are
reserved and should not be used.
:param data: The data to send to the client or clients. Data can be of
type ``str``, ``bytes``, ``list`` or ``dict``. If a
``list`` or ``dict``, the data will be serialized as JSON.
:param to: The session ID of the recipient client.
:param sid: Alias for the ``to`` parameter.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the event is emitted to the
default namespace.
:param timeout: The waiting timeout. If the timeout is reached before
the client acknowledges the event, then a
``TimeoutError`` exception is raised.
:param ignore_queue: Only used when a message queue is configured. If
set to ``True``, the event is emitted to the
client directly, without going through the queue.
This is more efficient, but only works when a
single server process is used. It is recommended
to always leave this parameter with its default
value of ``False``.
"""
if not self.async_handlers:
raise RuntimeError(
'Cannot use call() when async_handlers is False.')
callback_event = self.eio.create_event()
callback_args = []
def event_callback(*args):
callback_args.append(args)
callback_event.set()
self.emit(event, data=data, room=to or sid, namespace=namespace,
callback=event_callback, **kwargs)
if not callback_event.wait(timeout=timeout):
raise exceptions.TimeoutError()
return callback_args[0] if len(callback_args[0]) > 1 \
else callback_args[0][0] if len(callback_args[0]) == 1 \
else None
def enter_room(self, sid, room, namespace=None):
"""Enter a room.
This function adds the client to a room. The :func:`emit` and
:func:`send` functions can optionally broadcast events to all the
clients in a room.
:param sid: Session ID of the client.
:param room: Room name. If the room does not exist it is created.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the default namespace is used.
"""
namespace = namespace or '/'
self.logger.info('%s is entering room %s [%s]', sid, room, namespace)
self.manager.enter_room(sid, namespace, room)
def leave_room(self, sid, room, namespace=None):
"""Leave a room.
This function removes the client from a room.
:param sid: Session ID of the client.
:param room: Room name.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the default namespace is used.
"""
namespace = namespace or '/'
self.logger.info('%s is leaving room %s [%s]', sid, room, namespace)
self.manager.leave_room(sid, namespace, room)
def close_room(self, room, namespace=None):
"""Close a room.
This function removes all the clients from the given room.
:param room: Room name.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the default namespace is used.
"""
namespace = namespace or '/'
self.logger.info('room %s is closing [%s]', room, namespace)
self.manager.close_room(room, namespace)
def rooms(self, sid, namespace=None):
"""Return the rooms a client is in.
:param sid: Session ID of the client.
:param namespace: The Socket.IO namespace for the event. If this
argument is omitted the default namespace is used.
"""
namespace = namespace or '/'
return self.manager.get_rooms(sid, namespace)
def get_session(self, sid, namespace=None):
"""Return the user session for a client.
:param sid: The session id of the client.
:param namespace: The Socket.IO namespace. If this argument is omitted
the default namespace is used.
The return value is a dictionary. Modifications made to this
dictionary are not guaranteed to be preserved unless
``save_session()`` is called, or when the ``session`` context manager
is used.
"""
namespace = namespace or '/'
eio_session = self.eio.get_session(sid)
return eio_session.setdefault(namespace, {})
def save_session(self, sid, session, namespace=None):
"""Store the user session for a client.
:param sid: The session id of the client.
:param session: The session dictionary.
:param namespace: The Socket.IO namespace. If this argument is omitted
the default namespace is used.
"""
namespace = namespace or '/'
eio_session = self.eio.get_session(sid)
eio_session[namespace] = session
def session(self, sid, namespace=None):
"""Return the user session for a client with context manager syntax.
:param sid: The session id of the client.
This is a context manager that returns the user session dictionary for
the client. Any changes that are made to this dictionary inside the
context manager block are saved back to the session. Example usage::
@sio.on('connect')
def on_connect(sid, environ):
username = authenticate_user(environ)
if not username:
return False
with sio.session(sid) as session:
session['username'] = username
@sio.on('message')
def on_message(sid, msg):
with sio.session(sid) as session:
print('received message from ', session['username'])
"""
class _session_context_manager(object):
def __init__(self, server, sid, namespace):
self.server = server
self.sid = sid
self.namespace = namespace
self.session = None
def __enter__(self):
self.session = self.server.get_session(sid,
namespace=namespace)
return self.session
def __exit__(self, *args):
self.server.save_session(sid, self.session,
namespace=namespace)
return _session_context_manager(self, sid, namespace)
def disconnect(self, sid, namespace=None):
"""Disconnect a client.
:param sid: Session ID of the client.
:param namespace: The Socket.IO namespace to disconnect. If this
argument is omitted the default namespace is used.
"""
namespace = namespace or '/'
if self.manager.is_connected(sid, namespace=namespace):
self.logger.info('Disconnecting %s [%s]', sid, namespace)
self.manager.pre_disconnect(sid, namespace=namespace)
self._send_packet(sid, packet.Packet(packet.DISCONNECT,
namespace=namespace))
self._trigger_event('disconnect', namespace, sid)
self.manager.disconnect(sid, namespace=namespace)
def transport(self, sid):
"""Return the name of the transport used by the client.
The two possible values returned by this function are ``'polling'``
and ``'websocket'``.
:param sid: The session of the client.
"""
return self.eio.transport(sid)
def handle_request(self, environ, start_response):
"""Handle an HTTP request from the client.
This is the entry point of the Socket.IO application, using the same
interface as a WSGI application. For the typical usage, this function
is invoked by the :class:`Middleware` instance, but it can be invoked
directly when the middleware is not used.
:param environ: The WSGI environment.
:param start_response: The WSGI ``start_response`` function.
This function returns the HTTP response body to deliver to the client
as a byte sequence.
"""
return self.eio.handle_request(environ, start_response)
def start_background_task(self, target, *args, **kwargs):
"""Start a background task using the appropriate async model.
This is a utility function that applications can use to start a
background task using the method that is compatible with the
selected async mode.
:param target: the target function to execute.
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
"""
return self.eio.start_background_task(target, *args, **kwargs)
def sleep(self, seconds=0):
"""Sleep for the requested amount of time using the appropriate async
model.
This is a utility function that applications can use to put a task to
sleep without having to worry about using the correct call for the
selected async mode.
"""
return self.eio.sleep(seconds)
def _emit_internal(self, sid, event, data, namespace=None, id=None):
"""Send a message to a client."""
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
binary = None
# tuples are expanded to multiple arguments, everything else is sent
# as a single argument
if isinstance(data, tuple):
data = list(data)
else:
data = [data]
self._send_packet(sid, packet.Packet(packet.EVENT, namespace=namespace,
data=[event] + data, id=id,
binary=binary))
def _send_packet(self, sid, pkt):
"""Send a Socket.IO packet to a client."""
encoded_packet = pkt.encode()
if isinstance(encoded_packet, list):
binary = False
for ep in encoded_packet:
self.eio.send(sid, ep, binary=binary)
binary = True
else:
self.eio.send(sid, encoded_packet, binary=False)
def _handle_connect(self, sid, namespace):
"""Handle a client connection request."""
namespace = namespace or '/'
self.manager.connect(sid, namespace)
if self.always_connect:
self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
fail_reason = None
try:
success = self._trigger_event('connect', namespace, sid,
self.environ[sid])
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.error_args
success = False
if success is False:
if self.always_connect:
self.manager.pre_disconnect(sid, namespace)
self._send_packet(sid, packet.Packet(
packet.DISCONNECT, data=fail_reason, namespace=namespace))
self.manager.disconnect(sid, namespace)
if not self.always_connect:
self._send_packet(sid, packet.Packet(
packet.ERROR, data=fail_reason, namespace=namespace))
if sid in self.environ: # pragma: no cover
del self.environ[sid]
return False
elif not self.always_connect:
self._send_packet(sid, packet.Packet(packet.CONNECT,
namespace=namespace))
def _handle_disconnect(self, sid, namespace):
"""Handle a client disconnect."""
namespace = namespace or '/'
if namespace == '/':
namespace_list = list(self.manager.get_namespaces())
else:
namespace_list = [namespace]
for n in namespace_list:
if n != '/' and self.manager.is_connected(sid, n):
self._trigger_event('disconnect', n, sid)
self.manager.disconnect(sid, n)
if namespace == '/' and self.manager.is_connected(sid, namespace):
self._trigger_event('disconnect', '/', sid)
self.manager.disconnect(sid, '/')
def _handle_event(self, sid, namespace, id, data):
"""Handle an incoming client event."""
namespace = namespace or '/'
self.logger.info('received event "%s" from %s [%s]', data[0], sid,
namespace)
if self.async_handlers:
self.start_background_task(self._handle_event_internal, self, sid,
data, namespace, id)
else:
self._handle_event_internal(self, sid, data, namespace, id)
def _handle_event_internal(self, server, sid, data, namespace, id):
r = server._trigger_event(data[0], namespace, sid, *data[1:])
if id is not None:
# send ACK packet with the response returned by the handler
# tuples are expanded as multiple arguments
if r is None:
data = []
elif isinstance(r, tuple):
data = list(r)
else:
data = [r]
if six.PY2 and not self.binary:
binary = False # pragma: nocover
else:
binary = None
server._send_packet(sid, packet.Packet(packet.ACK,
namespace=namespace,
id=id, data=data,
binary=binary))
def _handle_ack(self, sid, namespace, id, data):
"""Handle ACK packets from the client."""
namespace = namespace or '/'
self.logger.info('received ack from %s [%s]', sid, namespace)
self.manager.trigger_callback(sid, namespace, id, data)
def _trigger_event(self, event, namespace, *args):
"""Invoke an application event handler."""
# first see if we have an explicit handler for the event
if namespace in self.handlers and event in self.handlers[namespace]:
return self.handlers[namespace][event](*args)
# or else, forward the event to a namespace handler if one exists
elif namespace in self.namespace_handlers:
return self.namespace_handlers[namespace].trigger_event(
event, *args)
def _handle_eio_connect(self, sid, environ):
"""Handle the Engine.IO connection event."""
if not self.manager_initialized:
self.manager_initialized = True
self.manager.initialize()
self.environ[sid] = environ
return self._handle_connect(sid, '/')
def _handle_eio_message(self, sid, data):
"""Dispatch Engine.IO messages."""
if sid in self._binary_packet:
pkt = self._binary_packet[sid]
if pkt.add_attachment(data):
del self._binary_packet[sid]
if pkt.packet_type == packet.BINARY_EVENT:
self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
else:
self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
else:
pkt = packet.Packet(encoded_packet=data)
if pkt.packet_type == packet.CONNECT:
self._handle_connect(sid, pkt.namespace)
elif pkt.packet_type == packet.DISCONNECT:
self._handle_disconnect(sid, pkt.namespace)
elif pkt.packet_type == packet.EVENT:
self._handle_event(sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.ACK:
self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data)
elif pkt.packet_type == packet.BINARY_EVENT or \
pkt.packet_type == packet.BINARY_ACK:
self._binary_packet[sid] = pkt
elif pkt.packet_type == packet.ERROR:
raise ValueError('Unexpected ERROR packet.')
else:
raise ValueError('Unknown packet type.')
def _handle_eio_disconnect(self, sid):
"""Handle Engine.IO disconnect event."""
self._handle_disconnect(sid, '/')
if sid in self.environ:
del self.environ[sid]
def _engineio_server_class(self):
return engineio.Server
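End to end, the server above is normally wrapped in the package's WSGI middleware and served by one of the supported async modes. A minimal deployment sketch with eventlet, assuming the package-level ``socketio.WSGIApp`` wrapper; the port and event names are arbitrary:

# Minimal deployment sketch for the Server class above (eventlet mode).
import eventlet
import eventlet.wsgi
import socketio

sio = socketio.Server(async_mode='eventlet')
app = socketio.WSGIApp(sio)

@sio.event
def connect(sid, environ):
    print('client connected:', sid)

@sio.on('chat')
def chat(sid, data):
    # relay to everyone except the sender
    sio.emit('chat', data, skip_sid=sid)

if __name__ == '__main__':
    eventlet.wsgi.server(eventlet.listen(('', 5000)), app)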

View file

@ -0,0 +1,11 @@
import sys
if sys.version_info >= (3, 5):
try:
from engineio.async_drivers.tornado import get_tornado_handler as \
get_engineio_handler
except ImportError: # pragma: no cover
get_engineio_handler = None
def get_tornado_handler(socketio_server): # pragma: no cover
return get_engineio_handler(socketio_server.eio)
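A sketch of how this handler is typically wired into a Tornado application (Python 3.5+ only, per the version guard above); the route and port are illustrative:

# Illustrative Tornado wiring for the handler defined above.
import tornado.ioloop
import tornado.web
import socketio

sio = socketio.AsyncServer(async_mode='tornado')
app = tornado.web.Application([
    (r'/socket.io/', socketio.get_tornado_handler(sio)),
])

if __name__ == '__main__':
    app.listen(8888)
    tornado.ioloop.IOLoop.current().start()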

View file

@ -0,0 +1,111 @@
import pickle
import re
try:
import eventlet.green.zmq as zmq
except ImportError:
zmq = None
import six
from .pubsub_manager import PubSubManager
class ZmqManager(PubSubManager): # pragma: no cover
"""zmq based client manager.
NOTE: this zmq implementation should be considered experimental at this
time. Currently, eventlet is required to use zmq.
This class implements a zmq backend for event sharing across multiple
processes. To use a zmq backend, initialize the :class:`Server` instance as
follows::
url = 'zmq+tcp://hostname:port1+port2'
server = socketio.Server(client_manager=socketio.ZmqManager(url))
:param url: The connection URL for the zmq message broker,
which will need to be provided and running.
:param channel: The channel name on which the server sends and receives
notifications. Must be the same in all the servers.
:param write_only: If set to ``True``, only initialize to emit events. The
default of ``False`` initializes the class for emitting
and receiving.
A zmq message broker must be running for the zmq_manager to work.
You can write your own or adapt the simple broker below::
import zmq
receiver = zmq.Context().socket(zmq.PULL)
receiver.bind("tcp://*:5555")
publisher = zmq.Context().socket(zmq.PUB)
publisher.bind("tcp://*:5556")
while True:
publisher.send(receiver.recv())
"""
name = 'zmq'
def __init__(self, url='zmq+tcp://localhost:5555+5556',
channel='socketio',
write_only=False,
logger=None):
if zmq is None:
raise RuntimeError('zmq package is not installed '
'(Run "pip install pyzmq" in your '
'virtualenv).')
r = re.compile(r':\d+\+\d+$')
if not (url.startswith('zmq+tcp://') and r.search(url)):
raise RuntimeError('unexpected connection string: ' + url)
url = url.replace('zmq+', '')
(sink_url, sub_port) = url.split('+')
sink_port = sink_url.split(':')[-1]
sub_url = sink_url.replace(sink_port, sub_port)
sink = zmq.Context().socket(zmq.PUSH)
sink.connect(sink_url)
sub = zmq.Context().socket(zmq.SUB)
sub.setsockopt_string(zmq.SUBSCRIBE, u'')
sub.connect(sub_url)
self.sink = sink
self.sub = sub
self.channel = channel
super(ZmqManager, self).__init__(channel=channel,
write_only=write_only,
logger=logger)
def _publish(self, data):
pickled_data = pickle.dumps(
{
'type': 'message',
'channel': self.channel,
'data': data
}
)
return self.sink.send(pickled_data)
def zmq_listen(self):
while True:
response = self.sub.recv()
if response is not None:
yield response
def _listen(self):
for message in self.zmq_listen():
if isinstance(message, six.binary_type):
try:
message = pickle.loads(message)
except Exception:
pass
if isinstance(message, dict) and \
message['type'] == 'message' and \
message['channel'] == self.channel and \
'data' in message:
yield message['data']
return
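Attaching this manager to a server follows the same pattern as the Redis backend. A sketch, assuming the broker from the class docstring is already listening on ports 5555/5556:

# Plug ZmqManager into a Server (eventlet required, see the docstring).
import socketio

mgr = socketio.ZmqManager('zmq+tcp://localhost:5555+5556')
sio = socketio.Server(client_manager=mgr, async_mode='eventlet')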

View file

@ -0,0 +1,28 @@
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
from ._abnf import *
from ._app import WebSocketApp
from ._core import *
from ._exceptions import *
from ._logging import *
from ._socket import *
__version__ = "0.59.0"

View file

@ -0,0 +1,458 @@
"""
"""
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import array
import os
import struct
import six
from ._exceptions import *
from ._utils import validate_utf8
from threading import Lock
try:
if six.PY3:
import numpy
else:
numpy = None
except ImportError:
numpy = None
try:
# If wsaccel is available we use compiled routines to mask data.
if not numpy:
from wsaccel.xormask import XorMaskerSimple
def _mask(_m, _d):
return XorMaskerSimple(_m).process(_d)
except ImportError:
# wsaccel is not available, we rely on python implementations.
def _mask(_m, _d):
for i in range(len(_d)):
_d[i] ^= _m[i % 4]
if six.PY3:
return _d.tobytes()
else:
return _d.tostring()
__all__ = [
'ABNF', 'continuous_frame', 'frame_buffer',
'STATUS_NORMAL',
'STATUS_GOING_AWAY',
'STATUS_PROTOCOL_ERROR',
'STATUS_UNSUPPORTED_DATA_TYPE',
'STATUS_STATUS_NOT_AVAILABLE',
'STATUS_ABNORMAL_CLOSED',
'STATUS_INVALID_PAYLOAD',
'STATUS_POLICY_VIOLATION',
'STATUS_MESSAGE_TOO_BIG',
'STATUS_INVALID_EXTENSION',
'STATUS_UNEXPECTED_CONDITION',
'STATUS_BAD_GATEWAY',
'STATUS_TLS_HANDSHAKE_ERROR',
]
# closing frame status codes.
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006
STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_BAD_GATEWAY = 1014
STATUS_TLS_HANDSHAKE_ERROR = 1015
VALID_CLOSE_STATUS = (
STATUS_NORMAL,
STATUS_GOING_AWAY,
STATUS_PROTOCOL_ERROR,
STATUS_UNSUPPORTED_DATA_TYPE,
STATUS_INVALID_PAYLOAD,
STATUS_POLICY_VIOLATION,
STATUS_MESSAGE_TOO_BIG,
STATUS_INVALID_EXTENSION,
STATUS_UNEXPECTED_CONDITION,
STATUS_BAD_GATEWAY,
)
class ABNF(object):
"""
ABNF frame class.
See http://tools.ietf.org/html/rfc5234
and http://tools.ietf.org/html/rfc6455#section-5.2
"""
# operation code values.
OPCODE_CONT = 0x0
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
OPCODE_CLOSE = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xa
# available operation code value tuple
OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
OPCODE_PING, OPCODE_PONG)
# opcode human readable string
OPCODE_MAP = {
OPCODE_CONT: "cont",
OPCODE_TEXT: "text",
OPCODE_BINARY: "binary",
OPCODE_CLOSE: "close",
OPCODE_PING: "ping",
OPCODE_PONG: "pong"
}
# data length threshold.
LENGTH_7 = 0x7e
LENGTH_16 = 1 << 16
LENGTH_63 = 1 << 63
def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0,
opcode=OPCODE_TEXT, mask=1, data=""):
"""
Constructor for ABNF. Please check RFC for arguments.
"""
self.fin = fin
self.rsv1 = rsv1
self.rsv2 = rsv2
self.rsv3 = rsv3
self.opcode = opcode
self.mask = mask
if data is None:
data = ""
self.data = data
self.get_mask_key = os.urandom
def validate(self, skip_utf8_validation=False):
"""
Validate the ABNF frame.
Parameters
----------
skip_utf8_validation: skip utf8 validation.
"""
if self.rsv1 or self.rsv2 or self.rsv3:
raise WebSocketProtocolException("rsv is not implemented, yet")
if self.opcode not in ABNF.OPCODES:
raise WebSocketProtocolException("Invalid opcode %r", self.opcode)
if self.opcode == ABNF.OPCODE_PING and not self.fin:
raise WebSocketProtocolException("Invalid ping frame.")
if self.opcode == ABNF.OPCODE_CLOSE:
l = len(self.data)
if not l:
return
if l == 1 or l >= 126:
raise WebSocketProtocolException("Invalid close frame.")
if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
raise WebSocketProtocolException("Invalid close frame.")
code = 256 * \
six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2])
if not self._is_valid_close_status(code):
raise WebSocketProtocolException("Invalid close opcode.")
@staticmethod
def _is_valid_close_status(code):
return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)
def __str__(self):
return "fin=" + str(self.fin) \
+ " opcode=" + str(self.opcode) \
+ " data=" + str(self.data)
@staticmethod
def create_frame(data, opcode, fin=1):
"""
Create frame to send text, binary and other data.
Parameters
----------
data: <type>
data to send. This is a string value (byte array).
If opcode is OPCODE_TEXT and this value is unicode,
the data is encoded to a utf-8 byte string automatically.
opcode: <type>
operation code. Please see OPCODE_XXX.
fin: <type>
fin flag. If set to 0, a continuation fragment is created.
"""
if opcode == ABNF.OPCODE_TEXT and isinstance(data, six.text_type):
data = data.encode("utf-8")
# mask must be set if send data from client
return ABNF(fin, 0, 0, 0, opcode, 1, data)
def format(self):
"""
Format this object to string(byte array) to send data to server.
"""
if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
raise ValueError("not 0 or 1")
if self.opcode not in ABNF.OPCODES:
raise ValueError("Invalid OPCODE")
length = len(self.data)
if length >= ABNF.LENGTH_63:
raise ValueError("data is too long")
frame_header = chr(self.fin << 7 |
self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 |
self.opcode)
if length < ABNF.LENGTH_7:
frame_header += chr(self.mask << 7 | length)
frame_header = six.b(frame_header)
elif length < ABNF.LENGTH_16:
frame_header += chr(self.mask << 7 | 0x7e)
frame_header = six.b(frame_header)
frame_header += struct.pack("!H", length)
else:
frame_header += chr(self.mask << 7 | 0x7f)
frame_header = six.b(frame_header)
frame_header += struct.pack("!Q", length)
if not self.mask:
return frame_header + self.data
else:
mask_key = self.get_mask_key(4)
return frame_header + self._get_masked(mask_key)
def _get_masked(self, mask_key):
s = ABNF.mask(mask_key, self.data)
if isinstance(mask_key, six.text_type):
mask_key = mask_key.encode('utf-8')
return mask_key + s
@staticmethod
def mask(mask_key, data):
"""
Mask or unmask data. Just do xor for each byte
Parameters
----------
mask_key: <type>
4 byte string(byte).
data: <type>
data to mask/unmask.
"""
if data is None:
data = ""
if isinstance(mask_key, six.text_type):
mask_key = six.b(mask_key)
if isinstance(data, six.text_type):
data = six.b(data)
if numpy:
origlen = len(data)
_mask_key = mask_key[3] << 24 | mask_key[2] << 16 | mask_key[1] << 8 | mask_key[0]
# We need data to be a multiple of four...
data += bytes(" " * (4 - (len(data) % 4)), "us-ascii")
a = numpy.frombuffer(data, dtype="uint32")
masked = numpy.bitwise_xor(a, [_mask_key]).astype("uint32")
if len(data) > origlen:
return masked.tobytes()[:origlen]
return masked.tobytes()
else:
_m = array.array("B", mask_key)
_d = array.array("B", data)
return _mask(_m, _d)
class frame_buffer(object):
_HEADER_MASK_INDEX = 5
_HEADER_LENGTH_INDEX = 6
def __init__(self, recv_fn, skip_utf8_validation):
self.recv = recv_fn
self.skip_utf8_validation = skip_utf8_validation
# Buffers over the packets from the layer beneath until the desired
# number of bytes has been received.
self.recv_buffer = []
self.clear()
self.lock = Lock()
def clear(self):
self.header = None
self.length = None
self.mask = None
def has_received_header(self):
return self.header is None
def recv_header(self):
header = self.recv_strict(2)
b1 = header[0]
if six.PY2:
b1 = ord(b1)
fin = b1 >> 7 & 1
rsv1 = b1 >> 6 & 1
rsv2 = b1 >> 5 & 1
rsv3 = b1 >> 4 & 1
opcode = b1 & 0xf
b2 = header[1]
if six.PY2:
b2 = ord(b2)
has_mask = b2 >> 7 & 1
length_bits = b2 & 0x7f
self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)
def has_mask(self):
if not self.header:
return False
return self.header[frame_buffer._HEADER_MASK_INDEX]
def has_received_length(self):
return self.length is None
def recv_length(self):
bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
length_bits = bits & 0x7f
if length_bits == 0x7e:
v = self.recv_strict(2)
self.length = struct.unpack("!H", v)[0]
elif length_bits == 0x7f:
v = self.recv_strict(8)
self.length = struct.unpack("!Q", v)[0]
else:
self.length = length_bits
def has_received_mask(self):
return self.mask is None
def recv_mask(self):
self.mask = self.recv_strict(4) if self.has_mask() else ""
def recv_frame(self):
with self.lock:
# Header
if self.has_received_header():
self.recv_header()
(fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header
# Frame length
if self.has_received_length():
self.recv_length()
length = self.length
# Mask
if self.has_received_mask():
self.recv_mask()
mask = self.mask
# Payload
payload = self.recv_strict(length)
if has_mask:
payload = ABNF.mask(mask, payload)
# Reset for next frame
self.clear()
frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
frame.validate(self.skip_utf8_validation)
return frame
def recv_strict(self, bufsize):
shortage = bufsize - sum(len(x) for x in self.recv_buffer)
while shortage > 0:
# Limit buffer size that we pass to socket.recv() to avoid
# fragmenting the heap -- the number of bytes recv() actually
# reads is limited by socket buffer and is relatively small,
# yet passing large numbers repeatedly causes lots of large
# buffers allocated and then shrunk, which results in
# fragmentation.
bytes_ = self.recv(min(16384, shortage))
self.recv_buffer.append(bytes_)
shortage -= len(bytes_)
unified = six.b("").join(self.recv_buffer)
if shortage == 0:
self.recv_buffer = []
return unified
else:
self.recv_buffer = [unified[bufsize:]]
return unified[:bufsize]
class continuous_frame(object):
def __init__(self, fire_cont_frame, skip_utf8_validation):
self.fire_cont_frame = fire_cont_frame
self.skip_utf8_validation = skip_utf8_validation
self.cont_data = None
self.recving_frames = None
def validate(self, frame):
if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
raise WebSocketProtocolException("Illegal frame")
if self.recving_frames and \
frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
raise WebSocketProtocolException("Illegal frame")
def add(self, frame):
if self.cont_data:
self.cont_data[1] += frame.data
else:
if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
self.recving_frames = frame.opcode
self.cont_data = [frame.opcode, frame.data]
if frame.fin:
self.recving_frames = None
def is_fire(self, frame):
return frame.fin or self.fire_cont_frame
def extract(self, frame):
data = self.cont_data
self.cont_data = None
frame.data = data[1]
if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data):
raise WebSocketPayloadException(
"cannot decode: " + repr(frame.data))
return [data[0], frame]

View file

@ -0,0 +1,399 @@
"""
"""
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import inspect
import select
import sys
import threading
import time
import traceback
import six
from ._abnf import ABNF
from ._core import WebSocket, getdefaulttimeout
from ._exceptions import *
from . import _logging
__all__ = ["WebSocketApp"]
class Dispatcher:
"""
Dispatcher
"""
def __init__(self, app, ping_timeout):
self.app = app
self.ping_timeout = ping_timeout
def read(self, sock, read_callback, check_callback):
while self.app.keep_running:
r, w, e = select.select(
(self.app.sock.sock, ), (), (), self.ping_timeout)
if r:
if not read_callback():
break
check_callback()
class SSLDispatcher:
"""
SSLDispatcher
"""
def __init__(self, app, ping_timeout):
self.app = app
self.ping_timeout = ping_timeout
def read(self, sock, read_callback, check_callback):
while self.app.keep_running:
r = self.select()
if r:
if not read_callback():
break
check_callback()
def select(self):
sock = self.app.sock.sock
if sock.pending():
return [sock,]
r, w, e = select.select((sock, ), (), (), self.ping_timeout)
return r
class WebSocketApp(object):
"""
A higher-level API is provided. The interface is similar to the JavaScript WebSocket object.
"""
def __init__(self, url, header=None,
on_open=None, on_message=None, on_error=None,
on_close=None, on_ping=None, on_pong=None,
on_cont_message=None,
keep_running=True, get_mask_key=None, cookie=None,
subprotocols=None,
on_data=None):
"""
WebSocketApp initialization
Parameters
----------
url: <type>
websocket url.
header: list or dict
custom header for websocket handshake.
on_open: <type>
callable object which is called when the websocket is opened.
This function has one argument, which is this class object.
on_message: <type>
callable object which is called when data is received.
on_message has 2 arguments.
The 1st argument is this class object.
The 2nd argument is the utf-8 string received from the server.
on_error: <type>
callable object which is called when an error occurs.
on_error has 2 arguments.
The 1st argument is this class object.
The 2nd argument is exception object.
on_close: <type>
callable object which is called when the connection is closed.
This function has one argument, which is this class object.
on_cont_message: <type>
callback object which is called when continued
frame data is received.
on_cont_message has 3 arguments.
The 1st argument is this class object.
The 2nd argument is the utf-8 string received from the server.
The 3rd argument is the continue flag; if 0, the data continues
in the next frame.
on_data: <type>
callback object which is called when a message is received.
This is called before on_message or on_cont_message,
and then on_message or on_cont_message is called.
on_data has 4 arguments.
The 1st argument is this class object.
The 2nd argument is the utf-8 string received from the server.
The 3rd argument is the data type: either ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY.
The 4th argument is the continue flag; if 0, the data continues in the next frame.
keep_running: <type>
this parameter is obsolete and ignored.
get_mask_key: func
a callable to produce new mask keys,
see the WebSocket.set_mask_key's docstring for more information
cookie: str
cookie value.
subprotocols: <type>
array of available sub protocols. default is None.
"""
self.url = url
self.header = header if header is not None else []
self.cookie = cookie
self.on_open = on_open
self.on_message = on_message
self.on_data = on_data
self.on_error = on_error
self.on_close = on_close
self.on_ping = on_ping
self.on_pong = on_pong
self.on_cont_message = on_cont_message
self.keep_running = False
self.get_mask_key = get_mask_key
self.sock = None
self.last_ping_tm = 0
self.last_pong_tm = 0
self.subprotocols = subprotocols
def send(self, data, opcode=ABNF.OPCODE_TEXT):
"""
send message
Parameters
----------
data: <type>
Message to send. If you set opcode to OPCODE_TEXT,
data must be utf-8 string or unicode.
opcode: <type>
Operation code of data. default is OPCODE_TEXT.
"""
if not self.sock or self.sock.send(data, opcode) == 0:
raise WebSocketConnectionClosedException(
"Connection is already closed.")
def close(self, **kwargs):
"""
Close websocket connection.
"""
self.keep_running = False
if self.sock:
self.sock.close(**kwargs)
self.sock = None
def _send_ping(self, interval, event, payload):
while not event.wait(interval):
self.last_ping_tm = time.time()
if self.sock:
try:
self.sock.ping(payload)
except Exception as ex:
_logging.warning("send_ping routine terminated: {}".format(ex))
break
def run_forever(self, sockopt=None, sslopt=None,
ping_interval=0, ping_timeout=None,
ping_payload="",
http_proxy_host=None, http_proxy_port=None,
http_no_proxy=None, http_proxy_auth=None,
skip_utf8_validation=False,
host=None, origin=None, dispatcher=None,
suppress_origin=False, proxy_type=None):
"""
Run event loop for WebSocket framework.
This loop is infinite and runs while the websocket connection is alive.
Parameters
----------
sockopt: tuple
values for socket.setsockopt.
sockopt must be a tuple
and each element is an argument of sock.setsockopt.
sslopt: dict
optional dict object for ssl socket option.
ping_interval: int or float
automatically send "ping" command
every specified period (in seconds)
if set to 0, pings are not sent automatically.
ping_timeout: int or float
timeout (in seconds) if the pong message is not received.
ping_payload: str
payload message to send with each ping.
http_proxy_host: <type>
http proxy host name.
http_proxy_port: <type>
http proxy port. If not set, set to 80.
http_no_proxy: <type>
host names which should not go through the proxy.
skip_utf8_validation: bool
skip utf8 validation.
host: str
update host header.
origin: str
update origin header.
dispatcher: <type>
customize reading data from socket.
suppress_origin: bool
suppress outputting origin header.
Returns
-------
teardown: bool
False if a KeyboardInterrupt was caught, True if another exception was raised during the loop
"""
if ping_timeout is not None and ping_timeout <= 0:
ping_timeout = None
if ping_timeout and ping_interval and ping_interval <= ping_timeout:
raise WebSocketException("Ensure ping_interval > ping_timeout")
if not sockopt:
sockopt = []
if not sslopt:
sslopt = {}
if self.sock:
raise WebSocketException("socket is already opened")
thread = None
self.keep_running = True
self.last_ping_tm = 0
self.last_pong_tm = 0
def teardown(close_frame=None):
"""
Tears down the connection.
If close_frame is set, we will invoke the on_close handler with the
status code and reason from there.
"""
if thread and thread.is_alive():
event.set()
thread.join()
self.keep_running = False
if self.sock:
self.sock.close()
close_args = self._get_close_args(
close_frame.data if close_frame else None)
self._callback(self.on_close, *close_args)
self.sock = None
try:
self.sock = WebSocket(
self.get_mask_key, sockopt=sockopt, sslopt=sslopt,
fire_cont_frame=self.on_cont_message is not None,
skip_utf8_validation=skip_utf8_validation,
enable_multithread=True if ping_interval else False)
self.sock.settimeout(getdefaulttimeout())
self.sock.connect(
self.url, header=self.header, cookie=self.cookie,
http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy,
http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols,
host=host, origin=origin, suppress_origin=suppress_origin,
proxy_type=proxy_type)
if not dispatcher:
dispatcher = self.create_dispatcher(ping_timeout)
self._callback(self.on_open)
if ping_interval:
event = threading.Event()
thread = threading.Thread(
target=self._send_ping, args=(ping_interval, event, ping_payload))
thread.daemon = True
thread.start()
def read():
if not self.keep_running:
return teardown()
op_code, frame = self.sock.recv_data_frame(True)
if op_code == ABNF.OPCODE_CLOSE:
return teardown(frame)
elif op_code == ABNF.OPCODE_PING:
self._callback(self.on_ping, frame.data)
elif op_code == ABNF.OPCODE_PONG:
self.last_pong_tm = time.time()
self._callback(self.on_pong, frame.data)
elif op_code == ABNF.OPCODE_CONT and self.on_cont_message:
self._callback(self.on_data, frame.data,
frame.opcode, frame.fin)
self._callback(self.on_cont_message,
frame.data, frame.fin)
else:
data = frame.data
if six.PY3 and op_code == ABNF.OPCODE_TEXT:
data = data.decode("utf-8")
self._callback(self.on_data, data, frame.opcode, True)
self._callback(self.on_message, data)
return True
def check():
if (ping_timeout):
has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout
has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0
has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout
if (self.last_ping_tm and
has_timeout_expired and
(has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)):
raise WebSocketTimeoutException("ping/pong timed out")
return True
dispatcher.read(self.sock.sock, read, check)
except (Exception, KeyboardInterrupt, SystemExit) as e:
self._callback(self.on_error, e)
if isinstance(e, SystemExit):
# propagate SystemExit further
raise
teardown()
return not isinstance(e, KeyboardInterrupt)
def create_dispatcher(self, ping_timeout):
timeout = ping_timeout or 10
if self.sock.is_ssl():
return SSLDispatcher(self, timeout)
return Dispatcher(self, timeout)
def _get_close_args(self, data):
"""
_get_close_args extracts the code and reason from the close body
if they exist, and only if self.on_close accepts three arguments
"""
# if the on_close callback is "old", just return empty list
if sys.version_info < (3, 0):
if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3:
return []
else:
if not self.on_close or len(inspect.getfullargspec(self.on_close).args) != 3:
return []
if data and len(data) >= 2:
code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2])
reason = data[2:].decode('utf-8')
return [code, reason]
return [None, None]
def _callback(self, callback, *args):
if callback:
try:
callback(self, *args)
except Exception as e:
_logging.error("error from callback {}: {}".format(callback, e))
if _logging.isEnabledForDebug():
_, _, tb = sys.exc_info()
traceback.print_tb(tb)
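if __name__ == "__main__":  # pragma: no cover
    # Illustrative sketch (added for documentation, not part of the upstream
    # module): minimal WebSocketApp usage. Every callback receives the app
    # instance as its first argument, as described in the constructor
    # docstring above. The echo URL is only a placeholder.
    def on_open(app):
        app.send("Hello, Server")

    def on_message(app, message):
        print("received: {}".format(message))
        app.close()

    def on_error(app, error):
        print("error: {}".format(error))

    app = WebSocketApp("ws://echo.websocket.org/",
                       on_open=on_open,
                       on_message=on_message,
                       on_error=on_error)
    # run_forever() blocks until the connection closes; ping_interval starts
    # the background ping thread implemented in _send_ping() above.
    app.run_forever(ping_interval=30, ping_timeout=10)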

View file

@ -0,0 +1,78 @@
"""
"""
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
try:
import Cookie
except:
import http.cookies as Cookie
class SimpleCookieJar(object):
def __init__(self):
self.jar = dict()
def add(self, set_cookie):
if set_cookie:
try:
simpleCookie = Cookie.SimpleCookie(set_cookie)
except:
simpleCookie = Cookie.SimpleCookie(set_cookie.encode('ascii', 'ignore'))
for k, v in simpleCookie.items():
domain = v.get("domain")
if domain:
if not domain.startswith("."):
domain = "." + domain
cookie = self.jar.get(domain) if self.jar.get(domain) else Cookie.SimpleCookie()
cookie.update(simpleCookie)
self.jar[domain.lower()] = cookie
def set(self, set_cookie):
if set_cookie:
try:
simpleCookie = Cookie.SimpleCookie(set_cookie)
except:
simpleCookie = Cookie.SimpleCookie(set_cookie.encode('ascii', 'ignore'))
for k, v in simpleCookie.items():
domain = v.get("domain")
if domain:
if not domain.startswith("."):
domain = "." + domain
self.jar[domain.lower()] = simpleCookie
def get(self, host):
if not host:
return ""
cookies = []
for domain, simpleCookie in self.jar.items():
host = host.lower()
if host.endswith(domain) or host == domain[1:]:
cookies.append(self.jar.get(domain))
return "; ".join(filter(
None, sorted(
["%s=%s" % (k, v.value) for cookie in filter(None, cookies) for k, v in cookie.items()]
)))
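if __name__ == "__main__":  # pragma: no cover
    # Illustrative sketch (added for documentation, not part of the upstream
    # module): cookies are stored per domain (normalized to a leading dot) and
    # get() joins the matching ones into a single Cookie header value.
    jar = SimpleCookieJar()
    jar.add("sid=abc123; Domain=example.com; Path=/")
    jar.add("theme=dark; Domain=.example.com; Path=/")
    print(jar.get("www.example.com"))  # -> "sid=abc123; theme=dark"
    print(jar.get("example.com"))      # the bare domain matches as well
    print(jar.get("other.org"))        # -> "" (no cookies for this host)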

View file

@ -0,0 +1,595 @@
from __future__ import print_function
"""
_core.py
====================================
WebSocket Python client
"""
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import socket
import struct
import threading
import time
import six
# websocket modules
from ._abnf import *
from ._exceptions import *
from ._handshake import *
from ._http import *
from ._logging import *
from ._socket import *
from ._ssl_compat import *
from ._utils import *
__all__ = ['WebSocket', 'create_connection']
class WebSocket(object):
"""
Low level WebSocket interface.
This class is based on the WebSocket protocol `draft-hixie-thewebsocketprotocol-76 <http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76>`_
We can connect to the websocket server and send/receive data.
The following example is an echo client.
>>> import websocket
>>> ws = websocket.WebSocket()
>>> ws.connect("ws://echo.websocket.org")
>>> ws.send("Hello, Server")
>>> ws.recv()
'Hello, Server'
>>> ws.close()
Parameters
----------
get_mask_key: func
a callable to produce new mask keys, see the set_mask_key
function's docstring for more details
sockopt: tuple
values for socket.setsockopt.
sockopt must be a tuple and each element is an argument of sock.setsockopt.
sslopt: dict
optional dict object for ssl socket option.
fire_cont_frame: bool
fire recv event for each cont frame. default is False
enable_multithread: bool
if set to True, lock send method.
skip_utf8_validation: bool
skip utf8 validation.
"""
def __init__(self, get_mask_key=None, sockopt=None, sslopt=None,
fire_cont_frame=False, enable_multithread=False,
skip_utf8_validation=False, **_):
"""
Initialize WebSocket object.
Parameters
----------
sslopt: specify ssl certification verification options
"""
self.sock_opt = sock_opt(sockopt, sslopt)
self.handshake_response = None
self.sock = None
self.connected = False
self.get_mask_key = get_mask_key
# These buffer the build-up of a single frame.
self.frame_buffer = frame_buffer(self._recv, skip_utf8_validation)
self.cont_frame = continuous_frame(
fire_cont_frame, skip_utf8_validation)
if enable_multithread:
self.lock = threading.Lock()
self.readlock = threading.Lock()
else:
self.lock = NoLock()
self.readlock = NoLock()
def __iter__(self):
"""
Allow iteration over websocket, implying sequential `recv` executions.
"""
while True:
yield self.recv()
def __next__(self):
return self.recv()
def next(self):
return self.__next__()
def fileno(self):
return self.sock.fileno()
def set_mask_key(self, func):
"""
Set the function used to create mask keys. You can customize the mask key generator.
Mainly, this is for testing purposes.
Parameters
----------
func: func
callable object. the func takes 1 argument as integer.
The argument is the length of the mask key.
This func must return a string (byte array)
of the specified length.
"""
self.get_mask_key = func
def gettimeout(self):
"""
Get the websocket timeout (in seconds) as an int or float
Returns
----------
timeout: int or float
returns timeout value (in seconds). This value could be either float/integer.
"""
return self.sock_opt.timeout
def settimeout(self, timeout):
"""
Set the timeout to the websocket.
Parameters
----------
timeout: int or float
timeout time (in seconds). This value could be either float/integer.
"""
self.sock_opt.timeout = timeout
if self.sock:
self.sock.settimeout(timeout)
timeout = property(gettimeout, settimeout)
def getsubprotocol(self):
"""
Get subprotocol
"""
if self.handshake_response:
return self.handshake_response.subprotocol
else:
return None
subprotocol = property(getsubprotocol)
def getstatus(self):
"""
Get handshake status
"""
if self.handshake_response:
return self.handshake_response.status
else:
return None
status = property(getstatus)
def getheaders(self):
"""
Get handshake response header
"""
if self.handshake_response:
return self.handshake_response.headers
else:
return None
def is_ssl(self):
return isinstance(self.sock, ssl.SSLSocket)
headers = property(getheaders)
def connect(self, url, **options):
"""
Connect to url. url is a websocket url scheme,
i.e. ws://host:port/resource
You can customize using 'options'.
If you set "header" list object, you can set your own custom header.
>>> ws = WebSocket()
>>> ws.connect("ws://echo.websocket.org/",
... header=["User-Agent: MyProgram",
... "x-custom: header"])
timeout: <type>
socket timeout time. This value is an integer or float.
if you set None for this value, it means "use default_timeout value"
Parameters
----------
options:
- header: list or dict
custom http header list or dict.
- cookie: str
cookie value.
- origin: str
custom origin url.
- suppress_origin: bool
suppress outputting origin header.
- host: str
custom host header string.
- http_proxy_host: <type>
http proxy host name.
- http_proxy_port: <type>
http proxy port. If not set, set to 80.
- http_no_proxy: <type>
host names which should not go through the proxy.
- http_proxy_auth: <type>
http proxy auth information. tuple of username and password. default is None
- redirect_limit: <type>
number of redirects to follow.
- subprotocols: <type>
array of available sub protocols. default is None.
- socket: <type>
pre-initialized stream socket.
"""
self.sock_opt.timeout = options.get('timeout', self.sock_opt.timeout)
self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options),
options.pop('socket', None))
try:
self.handshake_response = handshake(self.sock, *addrs, **options)
for attempt in range(options.pop('redirect_limit', 3)):
if self.handshake_response.status in SUPPORTED_REDIRECT_STATUSES:
url = self.handshake_response.headers['location']
self.sock.close()
self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options),
options.pop('socket', None))
self.handshake_response = handshake(self.sock, *addrs, **options)
self.connected = True
except:
if self.sock:
self.sock.close()
self.sock = None
raise
def send(self, payload, opcode=ABNF.OPCODE_TEXT):
"""
Send the data as string.
Parameters
----------
payload: <type>
Payload must be utf-8 string or unicode,
if the opcode is OPCODE_TEXT.
Otherwise, it must be string(byte array)
opcode: <type>
operation code to send. Please see OPCODE_XXX.
"""
frame = ABNF.create_frame(payload, opcode)
return self.send_frame(frame)
def send_frame(self, frame):
"""
Send the data frame.
>>> ws = create_connection("ws://echo.websocket.org/")
>>> frame = ABNF.create_frame("Hello", ABNF.OPCODE_TEXT)
>>> ws.send_frame(frame)
>>> cont_frame = ABNF.create_frame("My name is ", ABNF.OPCODE_CONT, 0)
>>> ws.send_frame(cont_frame)
>>> cont_frame = ABNF.create_frame("Foo Bar", ABNF.OPCODE_CONT, 1)
>>> ws.send_frame(cont_frame)
Parameters
----------
frame: <type>
frame data created by ABNF.create_frame
"""
if self.get_mask_key:
frame.get_mask_key = self.get_mask_key
data = frame.format()
length = len(data)
if (isEnabledForTrace()):
trace("send: " + repr(data))
with self.lock:
while data:
l = self._send(data)
data = data[l:]
return length
def send_binary(self, payload):
return self.send(payload, ABNF.OPCODE_BINARY)
def ping(self, payload=""):
"""
Send ping data.
Parameters
----------
payload: <type>
data payload to send server.
"""
if isinstance(payload, six.text_type):
payload = payload.encode("utf-8")
self.send(payload, ABNF.OPCODE_PING)
def pong(self, payload=""):
"""
Send pong data.
Parameters
----------
payload: <type>
data payload to send server.
"""
if isinstance(payload, six.text_type):
payload = payload.encode("utf-8")
self.send(payload, ABNF.OPCODE_PONG)
def recv(self):
"""
Receive string data(byte array) from the server.
Returns
----------
data: string (byte array) value.
"""
with self.readlock:
opcode, data = self.recv_data()
if six.PY3 and opcode == ABNF.OPCODE_TEXT:
return data.decode("utf-8")
elif opcode == ABNF.OPCODE_TEXT or opcode == ABNF.OPCODE_BINARY:
return data
else:
return ''
def recv_data(self, control_frame=False):
"""
Receive data with operation code.
Parameters
----------
control_frame: bool
a boolean flag indicating whether to return control frame
data, defaults to False
Returns
-------
opcode, frame.data: tuple
tuple of operation code and string(byte array) value.
"""
opcode, frame = self.recv_data_frame(control_frame)
return opcode, frame.data
def recv_data_frame(self, control_frame=False):
"""
Receive data with operation code.
Parameters
----------
control_frame: bool
a boolean flag indicating whether to return control frame
data, defaults to False
Returns
-------
frame.opcode, frame: tuple
tuple of operation code and string(byte array) value.
"""
while True:
frame = self.recv_frame()
if not frame:
# handle error:
# 'NoneType' object has no attribute 'opcode'
raise WebSocketProtocolException(
"Not a valid frame %s" % frame)
elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT):
self.cont_frame.validate(frame)
self.cont_frame.add(frame)
if self.cont_frame.is_fire(frame):
return self.cont_frame.extract(frame)
elif frame.opcode == ABNF.OPCODE_CLOSE:
self.send_close()
return frame.opcode, frame
elif frame.opcode == ABNF.OPCODE_PING:
if len(frame.data) < 126:
self.pong(frame.data)
else:
raise WebSocketProtocolException(
"Ping message is too long")
if control_frame:
return frame.opcode, frame
elif frame.opcode == ABNF.OPCODE_PONG:
if control_frame:
return frame.opcode, frame
def recv_frame(self):
"""
Receive data as frame from server.
Returns
-------
self.frame_buffer.recv_frame(): ABNF frame object
"""
return self.frame_buffer.recv_frame()
def send_close(self, status=STATUS_NORMAL, reason=six.b("")):
"""
Send close data to the server.
Parameters
----------
status: <type>
status code to send. see STATUS_XXX.
reason: str or bytes
the reason to close. This must be string or bytes.
"""
if status < 0 or status >= ABNF.LENGTH_16:
raise ValueError("code is invalid range")
self.connected = False
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
def close(self, status=STATUS_NORMAL, reason=six.b(""), timeout=3):
"""
Close Websocket object
Parameters
----------
status: <type>
status code to send. see STATUS_XXX.
reason: <type>
the reason to close. This must be string.
timeout: int or float
timeout until a close frame is received.
If None, it will wait forever until a close frame is received.
"""
if self.connected:
if status < 0 or status >= ABNF.LENGTH_16:
raise ValueError("code is invalid range")
try:
self.connected = False
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
sock_timeout = self.sock.gettimeout()
self.sock.settimeout(timeout)
start_time = time.time()
while timeout is None or time.time() - start_time < timeout:
try:
frame = self.recv_frame()
if frame.opcode != ABNF.OPCODE_CLOSE:
continue
if isEnabledForError():
recv_status = struct.unpack("!H", frame.data[0:2])[0]
if recv_status >= 3000 and recv_status <= 4999:
debug("close status: " + repr(recv_status))
elif recv_status != STATUS_NORMAL:
error("close status: " + repr(recv_status))
break
except:
break
self.sock.settimeout(sock_timeout)
self.sock.shutdown(socket.SHUT_RDWR)
except:
pass
self.shutdown()
def abort(self):
"""
Low-level asynchronous abort, wakes up other threads that are waiting in recv_*
"""
if self.connected:
self.sock.shutdown(socket.SHUT_RDWR)
def shutdown(self):
"""
Close the socket immediately.
"""
if self.sock:
self.sock.close()
self.sock = None
self.connected = False
def _send(self, data):
return send(self.sock, data)
def _recv(self, bufsize):
try:
return recv(self.sock, bufsize)
except WebSocketConnectionClosedException:
if self.sock:
self.sock.close()
self.sock = None
self.connected = False
raise
def create_connection(url, timeout=None, class_=WebSocket, **options):
"""
Connect to url and return the WebSocket object.
Passing optional timeout parameter will set the timeout on the socket.
If no timeout is supplied,
the global default timeout setting returned by getdefaulttimeout() is used.
You can customize using 'options'.
If you set "header" list object, you can set your own custom header.
>>> conn = create_connection("ws://echo.websocket.org/",
... header=["User-Agent: MyProgram",
... "x-custom: header"])
Parameters
----------
timeout: int or float
socket timeout time. This value could be either float/integer.
if you set None for this value,
it means "use default_timeout value"
class_: <type>
class to instantiate when creating the connection. It has to implement
settimeout and connect. Its __init__ should be compatible with
WebSocket.__init__, i.e. accept all of its kwargs.
options: <type>
- header: list or dict
custom http header list or dict.
- cookie: str
cookie value.
- origin: str
custom origin url.
- suppress_origin: bool
suppress outputting origin header.
- host: <type>
custom host header string.
- http_proxy_host: <type>
http proxy host name.
- http_proxy_port: <type>
http proxy port. If not set, set to 80.
- http_no_proxy: <type>
host names which should not go through the proxy.
- http_proxy_auth: <type>
http proxy auth information. tuple of username and password. default is None
- enable_multithread: bool
enable lock for multithread.
- redirect_limit: <type>
number of redirects to follow.
- sockopt: <type>
socket options
- sslopt: <type>
ssl option
- subprotocols: <type>
array of available sub protocols. default is None.
- skip_utf8_validation: bool
skip utf8 validation.
- socket: <type>
pre-initialized stream socket.
"""
sockopt = options.pop("sockopt", [])
sslopt = options.pop("sslopt", {})
fire_cont_frame = options.pop("fire_cont_frame", False)
enable_multithread = options.pop("enable_multithread", False)
skip_utf8_validation = options.pop("skip_utf8_validation", False)
websock = class_(sockopt=sockopt, sslopt=sslopt,
fire_cont_frame=fire_cont_frame,
enable_multithread=enable_multithread,
skip_utf8_validation=skip_utf8_validation, **options)
websock.settimeout(timeout if timeout is not None else getdefaulttimeout())
websock.connect(url, **options)
return websock
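if __name__ == "__main__":  # pragma: no cover
    # Illustrative sketch (added for documentation, not part of the upstream
    # module): create_connection() builds the WebSocket, applies the timeout
    # and performs the opening handshake in one call. The echo URL is only a
    # placeholder.
    ws = create_connection("ws://echo.websocket.org/", timeout=10,
                           enable_multithread=True)
    try:
        ws.send("Hello, Server")
        print(ws.recv())
        # The object is also iterable; `for message in ws: ...` calls recv()
        # repeatedly, as implemented in WebSocket.__iter__ above.
    finally:
        ws.close()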

View file

@ -0,0 +1,86 @@
"""
Define WebSocket exceptions
"""
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
class WebSocketException(Exception):
"""
WebSocket exception class.
"""
pass
class WebSocketProtocolException(WebSocketException):
"""
If the WebSocket protocol is invalid, this exception will be raised.
"""
pass
class WebSocketPayloadException(WebSocketException):
"""
If the WebSocket payload is invalid, this exception will be raised.
"""
pass
class WebSocketConnectionClosedException(WebSocketException):
"""
If the remote host closed the connection or a network error occurred,
this exception will be raised.
"""
pass
class WebSocketTimeoutException(WebSocketException):
"""
WebSocketTimeoutException will be raised at socket timeout during read/write data.
"""
pass
class WebSocketProxyException(WebSocketException):
"""
WebSocketProxyException will be raised when a proxy error occurs.
"""
pass
class WebSocketBadStatusException(WebSocketException):
"""
WebSocketBadStatusException will be raised when we get a bad handshake status code.
"""
def __init__(self, message, status_code, status_message=None, resp_headers=None):
msg = message % (status_code, status_message)
super(WebSocketBadStatusException, self).__init__(msg)
self.status_code = status_code
self.resp_headers = resp_headers
class WebSocketAddressException(WebSocketException):
"""
If the websocket address info cannot be found, this exception will be raised.
"""
pass
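if __name__ == "__main__":  # pragma: no cover
    # Illustrative sketch (added for documentation, not part of the upstream
    # module): _handshake._get_resp_headers() raises this exception with a
    # "%d %s" template, which the constructor interpolates with the status
    # code and status message.
    try:
        raise WebSocketBadStatusException(
            "Handshake status %d %s", 403, "Forbidden", resp_headers={})
    except WebSocketBadStatusException as exc:
        print(str(exc))          # -> "Handshake status 403 Forbidden"
        print(exc.status_code)   # -> 403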

View file

@ -0,0 +1,212 @@
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import hashlib
import hmac
import os
import six
from ._cookiejar import SimpleCookieJar
from ._exceptions import *
from ._http import *
from ._logging import *
from ._socket import *
if hasattr(six, 'PY3') and six.PY3:
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
if hasattr(six, 'PY3') and six.PY3:
if hasattr(six, 'PY34') and six.PY34:
from http import client as HTTPStatus
else:
from http import HTTPStatus
else:
import httplib as HTTPStatus
__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
def compare_digest(s1, s2):
return s1 == s2
# websocket supported version.
VERSION = 13
SUPPORTED_REDIRECT_STATUSES = (HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER,)
SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,)
CookieJar = SimpleCookieJar()
class handshake_response(object):
def __init__(self, status, headers, subprotocol):
self.status = status
self.headers = headers
self.subprotocol = subprotocol
CookieJar.add(headers.get("set-cookie"))
def handshake(sock, hostname, port, resource, **options):
headers, key = _get_handshake_headers(resource, hostname, port, options)
header_str = "\r\n".join(headers)
send(sock, header_str)
dump("request header", header_str)
status, resp = _get_resp_headers(sock)
if status in SUPPORTED_REDIRECT_STATUSES:
return handshake_response(status, resp, None)
success, subproto = _validate(resp, key, options.get("subprotocols"))
if not success:
raise WebSocketException("Invalid WebSocket Header")
return handshake_response(status, resp, subproto)
def _pack_hostname(hostname):
# IPv6 address
if ':' in hostname:
return '[' + hostname + ']'
return hostname
def _get_handshake_headers(resource, host, port, options):
headers = [
"GET %s HTTP/1.1" % resource,
"Upgrade: websocket"
]
if port == 80 or port == 443:
hostport = _pack_hostname(host)
else:
hostport = "%s:%d" % (_pack_hostname(host), port)
if "host" in options and options["host"] is not None:
headers.append("Host: %s" % options["host"])
else:
headers.append("Host: %s" % hostport)
if "suppress_origin" not in options or not options["suppress_origin"]:
if "origin" in options and options["origin"] is not None:
headers.append("Origin: %s" % options["origin"])
else:
headers.append("Origin: http://%s" % hostport)
key = _create_sec_websocket_key()
# Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
if 'header' not in options or 'Sec-WebSocket-Key' not in options['header']:
key = _create_sec_websocket_key()
headers.append("Sec-WebSocket-Key: %s" % key)
else:
key = options['header']['Sec-WebSocket-Key']
if 'header' not in options or 'Sec-WebSocket-Version' not in options['header']:
headers.append("Sec-WebSocket-Version: %s" % VERSION)
if 'connection' not in options or options['connection'] is None:
headers.append('Connection: Upgrade')
else:
headers.append(options['connection'])
subprotocols = options.get("subprotocols")
if subprotocols:
headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
if "header" in options:
header = options["header"]
if isinstance(header, dict):
header = [
": ".join([k, v])
for k, v in header.items()
if v is not None
]
headers.extend(header)
server_cookie = CookieJar.get(host)
client_cookie = options.get("cookie", None)
cookie = "; ".join(filter(None, [server_cookie, client_cookie]))
if cookie:
headers.append("Cookie: %s" % cookie)
headers.append("")
headers.append("")
return headers, key
def _get_resp_headers(sock, success_statuses=SUCCESS_STATUSES):
status, resp_headers, status_message = read_headers(sock)
if status not in success_statuses:
raise WebSocketBadStatusException("Handshake status %d %s", status, status_message, resp_headers)
return status, resp_headers
_HEADERS_TO_CHECK = {
"upgrade": "websocket",
"connection": "upgrade",
}
def _validate(headers, key, subprotocols):
subproto = None
for k, v in _HEADERS_TO_CHECK.items():
r = headers.get(k, None)
if not r:
return False, None
r = [x.strip().lower() for x in r.split(',')]
if v not in r:
return False, None
if subprotocols:
subproto = headers.get("sec-websocket-protocol", None)
if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]:
error("Invalid subprotocol: " + str(subprotocols))
return False, None
subproto = subproto.lower()
result = headers.get("sec-websocket-accept", None)
if not result:
return False, None
result = result.lower()
if isinstance(result, six.text_type):
result = result.encode('utf-8')
value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')
hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
success = compare_digest(hashed, result)
if success:
return True, subproto
else:
return False, None
def _create_sec_websocket_key():
randomness = os.urandom(16)
return base64encode(randomness).decode('utf-8').strip()
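if __name__ == "__main__":  # pragma: no cover
    # Illustrative sketch (added for documentation, not part of the upstream
    # module): the accept-key check performed by _validate(). The server must
    # answer with base64(sha1(key + the fixed GUID)) in Sec-WebSocket-Accept.
    key = _create_sec_websocket_key()
    value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')
    accept = base64encode(hashlib.sha1(value).digest()).strip()
    fake_response_headers = {
        "upgrade": "websocket",
        "connection": "Upgrade",
        "sec-websocket-accept": accept.decode('utf-8'),
    }
    success, subproto = _validate(fake_response_headers, key, None)
    assert success and subproto is None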

View file

@ -0,0 +1,335 @@
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import errno
import os
import socket
import sys
import six
from ._exceptions import *
from ._logging import *
from ._socket import *
from ._ssl_compat import *
from ._url import *
if six.PY3:
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
__all__ = ["proxy_info", "connect", "read_headers"]
try:
import socks
ProxyConnectionError = socks.ProxyConnectionError
HAS_PYSOCKS = True
except:
class ProxyConnectionError(BaseException):
pass
HAS_PYSOCKS = False
class proxy_info(object):
def __init__(self, **options):
self.type = options.get("proxy_type") or "http"
if self.type not in ['http', 'socks4', 'socks5', 'socks5h']:
raise ValueError("proxy_type must be 'http', 'socks4', 'socks5' or 'socks5h'")
self.host = options.get("http_proxy_host", None)
if self.host:
self.port = options.get("http_proxy_port", 0)
self.auth = options.get("http_proxy_auth", None)
self.no_proxy = options.get("http_no_proxy", None)
else:
self.port = 0
self.auth = None
self.no_proxy = None
def _open_proxied_socket(url, options, proxy):
hostname, port, resource, is_secure = parse_url(url)
if not HAS_PYSOCKS:
raise WebSocketException("PySocks module not found.")
ptype = socks.SOCKS5
rdns = False
if proxy.type == "socks4":
ptype = socks.SOCKS4
if proxy.type == "http":
ptype = socks.HTTP
if proxy.type[-1] == "h":
rdns = True
sock = socks.create_connection(
(hostname, port),
proxy_type=ptype,
proxy_addr=proxy.host,
proxy_port=proxy.port,
proxy_rdns=rdns,
proxy_username=proxy.auth[0] if proxy.auth else None,
proxy_password=proxy.auth[1] if proxy.auth else None,
timeout=options.timeout,
socket_options=DEFAULT_SOCKET_OPTION + options.sockopt
)
if is_secure:
if HAVE_SSL:
sock = _ssl_socket(sock, options.sslopt, hostname)
else:
raise WebSocketException("SSL not available.")
return sock, (hostname, port, resource)
def connect(url, options, proxy, socket):
if proxy.host and not socket and not (proxy.type == 'http'):
return _open_proxied_socket(url, options, proxy)
hostname, port, resource, is_secure = parse_url(url)
if socket:
return socket, (hostname, port, resource)
addrinfo_list, need_tunnel, auth = _get_addrinfo_list(
hostname, port, is_secure, proxy)
if not addrinfo_list:
raise WebSocketException(
"Host not found.: " + hostname + ":" + str(port))
sock = None
try:
sock = _open_socket(addrinfo_list, options.sockopt, options.timeout)
if need_tunnel:
sock = _tunnel(sock, hostname, port, auth)
if is_secure:
if HAVE_SSL:
sock = _ssl_socket(sock, options.sslopt, hostname)
else:
raise WebSocketException("SSL not available.")
return sock, (hostname, port, resource)
except:
if sock:
sock.close()
raise
def _get_addrinfo_list(hostname, port, is_secure, proxy):
phost, pport, pauth = get_proxy_info(
hostname, is_secure, proxy.host, proxy.port, proxy.auth, proxy.no_proxy)
try:
# when running on windows 10, getaddrinfo without socktype returns a socktype 0.
# This generates an error exception: `_on_error: exception Socket type must be stream or datagram, not 0`
# or `OSError: [Errno 22] Invalid argument` when creating socket. Force the socket type to SOCK_STREAM.
if not phost:
addrinfo_list = socket.getaddrinfo(
hostname, port, 0, socket.SOCK_STREAM, socket.SOL_TCP)
return addrinfo_list, False, None
else:
pport = pport and pport or 80
# when running on windows 10, the getaddrinfo used above
# returns a socktype 0. This generates an error exception:
# _on_error: exception Socket type must be stream or datagram, not 0
# Force the socket type to SOCK_STREAM
addrinfo_list = socket.getaddrinfo(phost, pport, 0, socket.SOCK_STREAM, socket.SOL_TCP)
return addrinfo_list, True, pauth
except socket.gaierror as e:
raise WebSocketAddressException(e)
def _open_socket(addrinfo_list, sockopt, timeout):
err = None
for addrinfo in addrinfo_list:
family, socktype, proto = addrinfo[:3]
sock = socket.socket(family, socktype, proto)
sock.settimeout(timeout)
for opts in DEFAULT_SOCKET_OPTION:
sock.setsockopt(*opts)
for opts in sockopt:
sock.setsockopt(*opts)
address = addrinfo[4]
err = None
while not err:
try:
sock.connect(address)
except ProxyConnectionError as error:
err = WebSocketProxyException(str(error))
err.remote_ip = str(address[0])
continue
except socket.error as error:
error.remote_ip = str(address[0])
try:
eConnRefused = (errno.ECONNREFUSED, errno.WSAECONNREFUSED)
except:
eConnRefused = (errno.ECONNREFUSED, )
if error.errno == errno.EINTR:
continue
elif error.errno in eConnRefused:
err = error
continue
else:
raise error
else:
break
else:
continue
break
else:
if err:
raise err
return sock
def _can_use_sni():
return six.PY2 and sys.version_info >= (2, 7, 9) or sys.version_info >= (3, 2)
def _wrap_sni_socket(sock, sslopt, hostname, check_hostname):
context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23))
if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE:
cafile = sslopt.get('ca_certs', None)
capath = sslopt.get('ca_cert_path', None)
if cafile or capath:
context.load_verify_locations(cafile=cafile, capath=capath)
elif hasattr(context, 'load_default_certs'):
context.load_default_certs(ssl.Purpose.SERVER_AUTH)
if sslopt.get('certfile', None):
context.load_cert_chain(
sslopt['certfile'],
sslopt.get('keyfile', None),
sslopt.get('password', None),
)
# see
# https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153
context.verify_mode = sslopt['cert_reqs']
if HAVE_CONTEXT_CHECK_HOSTNAME:
context.check_hostname = check_hostname
if 'ciphers' in sslopt:
context.set_ciphers(sslopt['ciphers'])
if 'cert_chain' in sslopt:
certfile, keyfile, password = sslopt['cert_chain']
context.load_cert_chain(certfile, keyfile, password)
if 'ecdh_curve' in sslopt:
context.set_ecdh_curve(sslopt['ecdh_curve'])
return context.wrap_socket(
sock,
do_handshake_on_connect=sslopt.get('do_handshake_on_connect', True),
suppress_ragged_eofs=sslopt.get('suppress_ragged_eofs', True),
server_hostname=hostname,
)
def _ssl_socket(sock, user_sslopt, hostname):
sslopt = dict(cert_reqs=ssl.CERT_REQUIRED)
sslopt.update(user_sslopt)
certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE')
if certPath and os.path.isfile(certPath) \
and user_sslopt.get('ca_certs', None) is None \
and user_sslopt.get('ca_cert', None) is None:
sslopt['ca_certs'] = certPath
elif certPath and os.path.isdir(certPath) \
and user_sslopt.get('ca_cert_path', None) is None:
sslopt['ca_cert_path'] = certPath
check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop(
'check_hostname', True)
if _can_use_sni():
sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname)
else:
sslopt.pop('check_hostname', True)
sock = ssl.wrap_socket(sock, **sslopt)
if not HAVE_CONTEXT_CHECK_HOSTNAME and check_hostname:
match_hostname(sock.getpeercert(), hostname)
return sock
def _tunnel(sock, host, port, auth):
debug("Connecting proxy...")
connect_header = "CONNECT %s:%d HTTP/1.1\r\n" % (host, port)
connect_header += "Host: %s:%d\r\n" % (host, port)
# TODO: support digest auth.
if auth and auth[0]:
auth_str = auth[0]
if auth[1]:
auth_str += ":" + auth[1]
encoded_str = base64encode(auth_str.encode()).strip().decode().replace('\n', '')
connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str
connect_header += "\r\n"
dump("request header", connect_header)
send(sock, connect_header)
try:
status, resp_headers, status_message = read_headers(sock)
except Exception as e:
raise WebSocketProxyException(str(e))
if status != 200:
raise WebSocketProxyException(
"failed CONNECT via proxy status: %r" % status)
return sock
def read_headers(sock):
status = None
status_message = None
headers = {}
trace("--- response header ---")
while True:
line = recv_line(sock)
line = line.decode('utf-8').strip()
if not line:
break
trace(line)
if not status:
status_info = line.split(" ", 2)
status = int(status_info[1])
if len(status_info) > 2:
status_message = status_info[2]
else:
kv = line.split(":", 1)
if len(kv) == 2:
key, value = kv
if key.lower() == "set-cookie" and headers.get("set-cookie"):
headers["set-cookie"] = headers.get("set-cookie") + "; " + value.strip()
else:
headers[key.lower()] = value.strip()
else:
raise WebSocketException("Invalid header")
trace("-----------------------")
return status, headers, status_message
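if __name__ == "__main__":  # pragma: no cover
    # Illustrative sketch (added for documentation, not part of the upstream
    # module): connect() only uses _open_proxied_socket() (and therefore
    # PySocks) for the socks4/socks5/socks5h proxy types; a plain "http" proxy
    # is handled with a CONNECT request through _tunnel() instead. The proxy
    # host below is a placeholder.
    info = proxy_info(http_proxy_host="proxy.local", http_proxy_port=3128,
                      http_proxy_auth=("user", "secret"), proxy_type="http")
    print("{0} {1}:{2}".format(info.type, info.host, info.port))
    # -> "http proxy.local:3128"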

View file

@ -0,0 +1,92 @@
"""
"""
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import logging
_logger = logging.getLogger('websocket')
try:
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
pass
_logger.addHandler(NullHandler())
_traceEnabled = False
__all__ = ["enableTrace", "dump", "error", "warning", "debug", "trace",
"isEnabledForError", "isEnabledForDebug", "isEnabledForTrace"]
def enableTrace(traceable, handler=logging.StreamHandler()):
"""
Turn on/off the traceability.
Parameters
----------
traceable: bool
If set to True, traceability is enabled.
"""
global _traceEnabled
_traceEnabled = traceable
if traceable:
_logger.addHandler(handler)
_logger.setLevel(logging.DEBUG)
def dump(title, message):
if _traceEnabled:
_logger.debug("--- " + title + " ---")
_logger.debug(message)
_logger.debug("-----------------------")
def error(msg):
_logger.error(msg)
def warning(msg):
_logger.warning(msg)
def debug(msg):
_logger.debug(msg)
def trace(msg):
if _traceEnabled:
_logger.debug(msg)
def isEnabledForError():
return _logger.isEnabledFor(logging.ERROR)
def isEnabledForDebug():
return _logger.isEnabledFor(logging.DEBUG)
def isEnabledForTrace():
return _traceEnabled
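if __name__ == "__main__":  # pragma: no cover
    # Illustrative sketch (added for documentation, not part of the upstream
    # module): enableTrace(True) attaches a StreamHandler to the "websocket"
    # logger and switches it to DEBUG, so dump() and trace() start emitting
    # handshake headers and frame data.
    enableTrace(True)
    dump("request header", "GET /chat HTTP/1.1\r\nUpgrade: websocket")
    trace("sent 42 bytes")
    print(isEnabledForTrace())  # -> True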

View file

@ -0,0 +1,176 @@
"""
"""
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import errno
import select
import socket
import six
from ._exceptions import *
from ._ssl_compat import *
from ._utils import *
DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)]
if hasattr(socket, "SO_KEEPALIVE"):
DEFAULT_SOCKET_OPTION.append((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1))
if hasattr(socket, "TCP_KEEPIDLE"):
DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPIDLE, 30))
if hasattr(socket, "TCP_KEEPINTVL"):
DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPINTVL, 10))
if hasattr(socket, "TCP_KEEPCNT"):
DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPCNT, 3))
_default_timeout = None
__all__ = ["DEFAULT_SOCKET_OPTION", "sock_opt", "setdefaulttimeout", "getdefaulttimeout",
"recv", "recv_line", "send"]
class sock_opt(object):
def __init__(self, sockopt, sslopt):
if sockopt is None:
sockopt = []
if sslopt is None:
sslopt = {}
self.sockopt = sockopt
self.sslopt = sslopt
self.timeout = None
def setdefaulttimeout(timeout):
"""
Set the global timeout setting to connect.
Parameters
----------
timeout: int or float
default socket timeout time (in seconds)
"""
global _default_timeout
_default_timeout = timeout
def getdefaulttimeout():
"""
Get default timeout
Returns
----------
_default_timeout: int or float
Return the global timeout setting (in seconds) to connect.
"""
return _default_timeout
def recv(sock, bufsize):
if not sock:
raise WebSocketConnectionClosedException("socket is already closed.")
def _recv():
try:
return sock.recv(bufsize)
except SSLWantReadError:
pass
except socket.error as exc:
error_code = extract_error_code(exc)
if error_code is None:
raise
if error_code not in (errno.EAGAIN, errno.EWOULDBLOCK):
raise
r, w, e = select.select((sock, ), (), (), sock.gettimeout())
if r:
return sock.recv(bufsize)
try:
if sock.gettimeout() == 0:
bytes_ = sock.recv(bufsize)
else:
bytes_ = _recv()
except socket.timeout as e:
message = extract_err_message(e)
raise WebSocketTimeoutException(message)
except SSLError as e:
message = extract_err_message(e)
if isinstance(message, str) and 'timed out' in message:
raise WebSocketTimeoutException(message)
else:
raise
if not bytes_:
raise WebSocketConnectionClosedException(
"Connection is already closed.")
return bytes_
def recv_line(sock):
line = []
while True:
c = recv(sock, 1)
line.append(c)
if c == six.b("\n"):
break
return six.b("").join(line)
def send(sock, data):
if isinstance(data, six.text_type):
data = data.encode('utf-8')
if not sock:
raise WebSocketConnectionClosedException("socket is already closed.")
def _send():
try:
return sock.send(data)
except SSLWantWriteError:
pass
except socket.error as exc:
error_code = extract_error_code(exc)
if error_code is None:
raise
if error_code not in (errno.EAGAIN, errno.EWOULDBLOCK):
raise
r, w, e = select.select((), (sock, ), (), sock.gettimeout())
if w:
return sock.send(data)
try:
if sock.gettimeout() == 0:
return sock.send(data)
else:
return _send()
except socket.timeout as e:
message = extract_err_message(e)
raise WebSocketTimeoutException(message)
except Exception as e:
message = extract_err_message(e)
if isinstance(message, str) and "timed out" in message:
raise WebSocketTimeoutException(message)
else:
raise
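if __name__ == "__main__":  # pragma: no cover
    # Illustrative sketch (added for documentation, not part of the upstream
    # module): the module-level default timeout is what create_connection()
    # and WebSocketApp.run_forever() fall back to when no explicit timeout is
    # supplied; sock_opt simply bundles the per-connection socket/ssl options.
    setdefaulttimeout(5)
    print(getdefaulttimeout())  # -> 5
    opts = sock_opt(sockopt=[(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)],
                    sslopt=None)
    opts.timeout = getdefaulttimeout()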

View file

@ -0,0 +1,53 @@
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
__all__ = ["HAVE_SSL", "ssl", "SSLError", "SSLWantReadError", "SSLWantWriteError"]
try:
import ssl
from ssl import SSLError
from ssl import SSLWantReadError
from ssl import SSLWantWriteError
if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'):
HAVE_CONTEXT_CHECK_HOSTNAME = True
else:
HAVE_CONTEXT_CHECK_HOSTNAME = False
if hasattr(ssl, "match_hostname"):
from ssl import match_hostname
else:
from backports.ssl_match_hostname import match_hostname
__all__.append("match_hostname")
__all__.append("HAVE_CONTEXT_CHECK_HOSTNAME")
HAVE_SSL = True
except ImportError:
# dummy class of SSLError for ssl none-support environment.
class SSLError(Exception):
pass
class SSLWantReadError(Exception):
pass
class SSLWantWriteError(Exception):
pass
ssl = None
HAVE_SSL = False

View file

@ -0,0 +1,178 @@
"""
"""
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import os
import socket
import struct
from six.moves.urllib.parse import urlparse
__all__ = ["parse_url", "get_proxy_info"]
def parse_url(url):
"""
Parse a url; the result is a tuple of
(hostname, port, resource path, flag for secure mode).
Parameters
----------
url: str
url string.
"""
if ":" not in url:
raise ValueError("url is invalid")
scheme, url = url.split(":", 1)
parsed = urlparse(url, scheme="http")
if parsed.hostname:
hostname = parsed.hostname
else:
raise ValueError("hostname is invalid")
port = 0
if parsed.port:
port = parsed.port
is_secure = False
if scheme == "ws":
if not port:
port = 80
elif scheme == "wss":
is_secure = True
if not port:
port = 443
else:
raise ValueError("scheme %s is invalid" % scheme)
if parsed.path:
resource = parsed.path
else:
resource = "/"
if parsed.query:
resource += "?" + parsed.query
return hostname, port, resource, is_secure
DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"]
def _is_ip_address(addr):
try:
socket.inet_aton(addr)
except socket.error:
return False
else:
return True
def _is_subnet_address(hostname):
try:
addr, netmask = hostname.split("/")
return _is_ip_address(addr) and 0 <= int(netmask) < 32
except ValueError:
return False
def _is_address_in_network(ip, net):
ipaddr = struct.unpack('!I', socket.inet_aton(ip))[0]
netaddr, netmask = net.split('/')
netaddr = struct.unpack('!I', socket.inet_aton(netaddr))[0]
netmask = (0xFFFFFFFF << (32 - int(netmask))) & 0xFFFFFFFF
return ipaddr & netmask == netaddr
def _is_no_proxy_host(hostname, no_proxy):
if not no_proxy:
v = os.environ.get("no_proxy", "").replace(" ", "")
if v:
no_proxy = v.split(",")
if not no_proxy:
no_proxy = DEFAULT_NO_PROXY_HOST
if '*' in no_proxy:
return True
if hostname in no_proxy:
return True
if _is_ip_address(hostname):
return any([_is_address_in_network(hostname, subnet) for subnet in no_proxy if _is_subnet_address(subnet)])
for domain in [domain for domain in no_proxy if domain.startswith('.')]:
if hostname.endswith(domain):
return True
return False
def get_proxy_info(
hostname, is_secure, proxy_host=None, proxy_port=0, proxy_auth=None,
no_proxy=None, proxy_type='http'):
"""
Try to retrieve proxy host and port from environment
if not provided in options.
Result is (proxy_host, proxy_port, proxy_auth).
proxy_auth is tuple of username and password
of proxy authentication information.
Parameters
----------
hostname: <type>
websocket server name.
is_secure: <type>
is the connection secure? (wss) looks for "https_proxy" in env
before falling back to "http_proxy"
options: <type>
- http_proxy_host: <type>
http proxy host name.
- http_proxy_port: <type>
http proxy port.
- http_no_proxy: <type>
host names which should not go through the proxy.
- http_proxy_auth: <type>
http proxy auth information. tuple of username and password. default is None
- proxy_type: <type>
if set to "socks5" PySocks wrapper will be used in place of a http proxy. default is "http"
"""
if _is_no_proxy_host(hostname, no_proxy):
return None, 0, None
if proxy_host:
port = proxy_port
auth = proxy_auth
return proxy_host, port, auth
env_keys = ["http_proxy"]
if is_secure:
env_keys.insert(0, "https_proxy")
for key in env_keys:
value = os.environ.get(key, None)
if value:
proxy = urlparse(value)
auth = (proxy.username, proxy.password) if proxy.username else None
return proxy.hostname, proxy.port, auth
return None, 0, None
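if __name__ == "__main__":  # pragma: no cover
    # Illustrative sketch (added for documentation, not part of the upstream
    # module): the default port follows the scheme (80 for ws, 443 for wss)
    # and the resource keeps the query string.
    print(parse_url("ws://example.com/chat"))
    # -> ('example.com', 80, '/chat', False)
    print(parse_url("wss://example.com:8443/chat?room=1"))
    # -> ('example.com', 8443, '/chat?room=1', True)
    # get_proxy_info() falls back to the https_proxy/http_proxy/no_proxy
    # environment variables when no explicit proxy settings are given.
    print(get_proxy_info("example.com", True))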

View file

@ -0,0 +1,110 @@
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import six
__all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"]
class NoLock(object):
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
pass
try:
# If wsaccel is available we use compiled routines to validate UTF-8
# strings.
from wsaccel.utf8validator import Utf8Validator
def _validate_utf8(utfbytes):
return Utf8Validator().validate(utfbytes)[0]
except ImportError:
# UTF-8 validator
# python implementation of http://bjoern.hoehrmann.de/utf-8/decoder/dfa/
_UTF8_ACCEPT = 0
_UTF8_REJECT = 12
_UTF8D = [
# The first part of the table maps bytes to character classes
# to reduce the size of the transition table and create bitmasks.
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,
7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,
8,8,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
10,3,3,3,3,3,3,3,3,3,3,3,3,4,3,3, 11,6,6,6,5,8,8,8,8,8,8,8,8,8,8,8,
# The second part is a transition table that maps a combination
# of a state of the automaton and a character class to a state.
0,12,24,36,60,96,84,12,12,12,48,72, 12,12,12,12,12,12,12,12,12,12,12,12,
12, 0,12,12,12,12,12, 0,12, 0,12,12, 12,24,12,12,12,12,12,24,12,24,12,12,
12,12,12,12,12,12,12,24,12,12,12,12, 12,24,12,12,12,12,12,12,12,24,12,12,
12,12,12,12,12,12,12,36,12,36,12,12, 12,36,12,12,12,12,12,36,12,36,12,12,
12,36,12,12,12,12,12,12,12,12,12,12, ]
def _decode(state, codep, ch):
tp = _UTF8D[ch]
codep = (ch & 0x3f) | (codep << 6) if (
state != _UTF8_ACCEPT) else (0xff >> tp) & ch
state = _UTF8D[256 + state + tp]
return state, codep
def _validate_utf8(utfbytes):
state = _UTF8_ACCEPT
codep = 0
for i in utfbytes:
if six.PY2:
i = ord(i)
state, codep = _decode(state, codep, i)
if state == _UTF8_REJECT:
return False
return True
def validate_utf8(utfbytes):
"""
Validate a UTF-8 byte string.
utfbytes: byte string to check.
return value: True if the string is valid UTF-8, otherwise False.
"""
return _validate_utf8(utfbytes)
def extract_err_message(exception):
if exception.args:
return exception.args[0]
else:
return None
def extract_error_code(exception):
if exception.args and len(exception.args) > 1:
return exception.args[0] if isinstance(exception.args[0], int) else None
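# A short illustration of validate_utf8 (the byte strings are examples):
# >>> validate_utf8(six.b('\xf0\x90\x80\x80'))  # U+10000, a valid 4-byte sequence
# True
# >>> validate_utf8(six.b('\xed\xa0\x80'))      # lone UTF-16 surrogate, invalid in UTF-8
# False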

View file

@@ -0,0 +1,6 @@
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade
Upgrade: WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
some_header: something

View file

@@ -0,0 +1,6 @@
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade
Upgrade WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
some_header: something

View file

@@ -0,0 +1,6 @@
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade, Keep-Alive
Upgrade: WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
some_header: something

View file

@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
#
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import os
import websocket as ws
from websocket._abnf import *
import sys
sys.path[0:0] = [""]
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
class ABNFTest(unittest.TestCase):
def testInit(self):
a = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING)
self.assertEqual(a.fin, 0)
self.assertEqual(a.rsv1, 0)
self.assertEqual(a.rsv2, 0)
self.assertEqual(a.rsv3, 0)
self.assertEqual(a.opcode, 9)
self.assertEqual(a.data, '')
a_bad = ABNF(0,1,0,0, opcode=77)
self.assertEqual(a_bad.rsv1, 1)
self.assertEqual(a_bad.opcode, 77)
def testValidate(self):
a = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING)
self.assertRaises(ws.WebSocketProtocolException, a.validate)
a_bad = ABNF(0,1,0,0, opcode=77)
self.assertRaises(ws.WebSocketProtocolException, a_bad.validate)
a_close = ABNF(0,1,0,0, opcode=ABNF.OPCODE_CLOSE, data="abcdefgh1234567890abcdefgh1234567890abcdefgh1234567890abcdefgh1234567890")
self.assertRaises(ws.WebSocketProtocolException, a_close.validate)
# This caused an error in the Python 2.7 GitHub Actions build.
# Uncomment this test case once Python 2 support is no longer wanted.
# def testMask(self):
# ab = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING)
# bytes_val = bytes("aaaa", 'utf-8')
# self.assertEqual(ab._get_masked(bytes_val), bytes_val)
def testFrameBuffer(self):
fb = frame_buffer(0, True)
self.assertEqual(fb.recv, 0)
self.assertEqual(fb.skip_utf8_validation, True)
fb.clear()  # reset the frame buffer before checking its cleared state
self.assertEqual(fb.header, None)
self.assertEqual(fb.length, None)
self.assertEqual(fb.mask, None)
self.assertEqual(fb.has_mask(), False)
if __name__ == "__main__":
unittest.main()

View file

@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
#
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import os
import os.path
import websocket as ws
import sys
sys.path[0:0] = [""]
try:
import ssl
HAVE_SSL = True
except ImportError:
HAVE_SSL = False
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
# Skip tests that require internet access unless TEST_WITH_INTERNET=1 is set.
TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1'
TRACEABLE = True
class WebSocketAppTest(unittest.TestCase):
class NotSetYet(object):
""" A marker class for signalling that a value hasn't been set yet.
"""
def setUp(self):
ws.enableTrace(TRACEABLE)
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
def tearDown(self):
WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet()
WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet()
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testKeepRunning(self):
""" A WebSocketApp should keep running as long as its self.keep_running
is not False (in the boolean context).
"""
def on_open(self, *args, **kwargs):
""" Set the keep_running flag for later inspection and immediately
close the connection.
"""
WebSocketAppTest.keep_running_open = self.keep_running
self.close()
def on_close(self, *args, **kwargs):
""" Set the keep_running flag for the test to use.
"""
WebSocketAppTest.keep_running_close = self.keep_running
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close)
app.run_forever()
# if numpy is installed, these assertions fail
# self.assertFalse(isinstance(WebSocketAppTest.keep_running_open,
# WebSocketAppTest.NotSetYet))
# self.assertFalse(isinstance(WebSocketAppTest.keep_running_close,
# WebSocketAppTest.NotSetYet))
# self.assertEqual(True, WebSocketAppTest.keep_running_open)
# self.assertEqual(False, WebSocketAppTest.keep_running_close)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSockMaskKey(self):
""" A WebSocketApp should forward the received mask_key function down
to the actual socket.
"""
def my_mask_key_func():
pass
def on_open(self, *args, **kwargs):
""" Set the value so the test can use it later on and immediately
close the connection.
"""
WebSocketAppTest.get_mask_key_id = id(self.get_mask_key)
self.close()
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, get_mask_key=my_mask_key_func)
app.run_forever()
# if numpy is installed, this assertion fails
# Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
# self.assertEqual(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func))
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testPingInterval(self):
""" A WebSocketApp should ping regularly
"""
def on_ping(app, msg):
print("Got a ping!")
app.close()
def on_pong(app, msg):
print("Got a pong! No need to respond")
app.close()
app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1', on_ping=on_ping, on_pong=on_pong)
app.run_forever(ping_interval=2, ping_timeout=1) # , sslopt={"cert_reqs": ssl.CERT_NONE}
self.assertRaises(ws.WebSocketException, app.run_forever, ping_interval=2, ping_timeout=3, sslopt={"cert_reqs": ssl.CERT_NONE})
if __name__ == "__main__":
unittest.main()

View file

@@ -0,0 +1,117 @@
"""
"""
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import unittest
from websocket._cookiejar import SimpleCookieJar
class CookieJarTest(unittest.TestCase):
def testAdd(self):
cookie_jar = SimpleCookieJar()
cookie_jar.add("")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=.abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=xyz")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
self.assertEqual(cookie_jar.get("xyz"), "e=f")
self.assertEqual(cookie_jar.get("something"), "")
def testSet(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=.abc")
self.assertEqual(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=xyz")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
self.assertEqual(cookie_jar.get("xyz"), "e=f")
self.assertEqual(cookie_jar.get("something"), "")
def testGet(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc.com")
self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEqual(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEqual(cookie_jar.get("abc.com.es"), "")
self.assertEqual(cookie_jar.get("xabc.com"), "")
cookie_jar.set("a=b; c=d; domain=.abc.com")
self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEqual(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEqual(cookie_jar.get("abc.com.es"), "")
self.assertEqual(cookie_jar.get("xabc.com"), "")

View file

@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
#
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import os
import os.path
import websocket as ws
from websocket._http import proxy_info, read_headers, _open_proxied_socket, _tunnel
import sys
sys.path[0:0] = [""]
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
class SockMock(object):
def __init__(self):
self.data = []
self.sent = []
def add_packet(self, data):
self.data.append(data)
def gettimeout(self):
return None
def recv(self, bufsize):
if self.data:
e = self.data.pop(0)
if isinstance(e, Exception):
raise e
if len(e) > bufsize:
self.data.insert(0, e[bufsize:])
return e[:bufsize]
def send(self, data):
self.sent.append(data)
return len(data)
def close(self):
pass
class HeaderSockMock(SockMock):
def __init__(self, fname):
SockMock.__init__(self)
path = os.path.join(os.path.dirname(__file__), fname)
with open(path, "rb") as f:
self.add_packet(f.read())
class OptsList():
def __init__(self):
self.timeout = 0
self.sockopt = []
class HttpTest(unittest.TestCase):
def testReadHeader(self):
status, header, status_message = read_headers(HeaderSockMock("data/header01.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade")
# header02.txt is intentionally malformed
self.assertRaises(ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt"))
def testTunnel(self):
self.assertRaises(ws.WebSocketProxyException, _tunnel, HeaderSockMock("data/header01.txt"), "example.com", 80, ("username", "password"))
self.assertRaises(ws.WebSocketProxyException, _tunnel, HeaderSockMock("data/header02.txt"), "example.com", 80, ("username", "password"))
def testConnect(self):
# Not currently testing an actual proxy connection, so just check whether TypeError is raised
self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="http"))
self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks4"))
self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks5h"))
def testProxyInfo(self):
self.assertEqual(proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http").type, "http")
self.assertRaises(ValueError, proxy_info, http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="badval")
self.assertEqual(proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="http").host, "example.com")
self.assertEqual(proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http").port, "8080")
self.assertEqual(proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http").auth, None)
if __name__ == "__main__":
unittest.main()

View file

@@ -0,0 +1,309 @@
# -*- coding: utf-8 -*-
#
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import sys
import os
from websocket._url import get_proxy_info, parse_url, _is_address_in_network, _is_no_proxy_host
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
sys.path[0:0] = [""]
class UrlTest(unittest.TestCase):
def test_address_in_network(self):
self.assertTrue(_is_address_in_network('127.0.0.1', '127.0.0.0/8'))
self.assertTrue(_is_address_in_network('127.1.0.1', '127.0.0.0/8'))
self.assertFalse(_is_address_in_network('127.1.0.1', '127.0.0.0/24'))
def testParseUrl(self):
p = parse_url("ws://www.example.com/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com/r/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080/")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("ws://www.example.com:8080")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/")
self.assertEqual(p[3], False)
p = parse_url("wss://www.example.com:8080/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
p = parse_url("wss://www.example.com:8080/r?key=value")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r?key=value")
self.assertEqual(p[3], True)
self.assertRaises(ValueError, parse_url, "http://www.example.com/r")
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
return
p = parse_url("ws://[2a03:4000:123:83::3]/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 80)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("ws://[2a03:4000:123:83::3]:8080/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], False)
p = parse_url("wss://[2a03:4000:123:83::3]/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 443)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
p = parse_url("wss://[2a03:4000:123:83::3]:8080/r")
self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 8080)
self.assertEqual(p[2], "/r")
self.assertEqual(p[3], True)
class IsNoProxyHostTest(unittest.TestCase):
def setUp(self):
self.no_proxy = os.environ.get("no_proxy", None)
if "no_proxy" in os.environ:
del os.environ["no_proxy"]
def tearDown(self):
if self.no_proxy:
os.environ["no_proxy"] = self.no_proxy
elif "no_proxy" in os.environ:
del os.environ["no_proxy"]
def testMatchAll(self):
self.assertTrue(_is_no_proxy_host("any.websocket.org", ['*']))
self.assertTrue(_is_no_proxy_host("192.168.0.1", ['*']))
self.assertTrue(_is_no_proxy_host("any.websocket.org", ['other.websocket.org', '*']))
os.environ['no_proxy'] = '*'
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
self.assertTrue(_is_no_proxy_host("192.168.0.1", None))
os.environ['no_proxy'] = 'other.websocket.org, *'
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
def testIpAddress(self):
self.assertTrue(_is_no_proxy_host("127.0.0.1", ['127.0.0.1']))
self.assertFalse(_is_no_proxy_host("127.0.0.2", ['127.0.0.1']))
self.assertTrue(_is_no_proxy_host("127.0.0.1", ['other.websocket.org', '127.0.0.1']))
self.assertFalse(_is_no_proxy_host("127.0.0.2", ['other.websocket.org', '127.0.0.1']))
os.environ['no_proxy'] = '127.0.0.1'
self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertFalse(_is_no_proxy_host("127.0.0.2", None))
os.environ['no_proxy'] = 'other.websocket.org, 127.0.0.1'
self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertFalse(_is_no_proxy_host("127.0.0.2", None))
def testIpAddressInRange(self):
self.assertTrue(_is_no_proxy_host("127.0.0.1", ['127.0.0.0/8']))
self.assertTrue(_is_no_proxy_host("127.0.0.2", ['127.0.0.0/8']))
self.assertFalse(_is_no_proxy_host("127.1.0.1", ['127.0.0.0/24']))
os.environ['no_proxy'] = '127.0.0.0/8'
self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertTrue(_is_no_proxy_host("127.0.0.2", None))
os.environ['no_proxy'] = '127.0.0.0/24'
self.assertFalse(_is_no_proxy_host("127.1.0.1", None))
def testHostnameMatch(self):
self.assertTrue(_is_no_proxy_host("my.websocket.org", ['my.websocket.org']))
self.assertTrue(_is_no_proxy_host("my.websocket.org", ['other.websocket.org', 'my.websocket.org']))
self.assertFalse(_is_no_proxy_host("my.websocket.org", ['other.websocket.org']))
os.environ['no_proxy'] = 'my.websocket.org'
self.assertTrue(_is_no_proxy_host("my.websocket.org", None))
self.assertFalse(_is_no_proxy_host("other.websocket.org", None))
os.environ['no_proxy'] = 'other.websocket.org, my.websocket.org'
self.assertTrue(_is_no_proxy_host("my.websocket.org", None))
def testHostnameMatchDomain(self):
self.assertTrue(_is_no_proxy_host("any.websocket.org", ['.websocket.org']))
self.assertTrue(_is_no_proxy_host("my.other.websocket.org", ['.websocket.org']))
self.assertTrue(_is_no_proxy_host("any.websocket.org", ['my.websocket.org', '.websocket.org']))
self.assertFalse(_is_no_proxy_host("any.websocket.com", ['.websocket.org']))
os.environ['no_proxy'] = '.websocket.org'
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
self.assertTrue(_is_no_proxy_host("my.other.websocket.org", None))
self.assertFalse(_is_no_proxy_host("any.websocket.com", None))
os.environ['no_proxy'] = 'my.websocket.org, .websocket.org'
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
class ProxyInfoTest(unittest.TestCase):
def setUp(self):
self.http_proxy = os.environ.get("http_proxy", None)
self.https_proxy = os.environ.get("https_proxy", None)
self.no_proxy = os.environ.get("no_proxy", None)
if "http_proxy" in os.environ:
del os.environ["http_proxy"]
if "https_proxy" in os.environ:
del os.environ["https_proxy"]
if "no_proxy" in os.environ:
del os.environ["no_proxy"]
def tearDown(self):
if self.http_proxy:
os.environ["http_proxy"] = self.http_proxy
elif "http_proxy" in os.environ:
del os.environ["http_proxy"]
if self.https_proxy:
os.environ["https_proxy"] = self.https_proxy
elif "https_proxy" in os.environ:
del os.environ["https_proxy"]
if self.no_proxy:
os.environ["no_proxy"] = self.no_proxy
elif "no_proxy" in os.environ:
del os.environ["no_proxy"]
def testProxyFromArgs(self):
self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost"), ("localhost", 0, None))
self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128),
("localhost", 3128, None))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost"), ("localhost", 0, None))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128),
("localhost", 3128, None))
self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_auth=("a", "b")),
("localhost", 0, ("a", "b")))
self.assertEqual(
get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128, proxy_auth=("a", "b")),
("localhost", 3128, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_auth=("a", "b")),
("localhost", 0, ("a", "b")))
self.assertEqual(
get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, proxy_auth=("a", "b")),
("localhost", 3128, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128,
no_proxy=["example.com"], proxy_auth=("a", "b")),
("localhost", 3128, ("a", "b")))
self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128,
no_proxy=["echo.websocket.org"], proxy_auth=("a", "b")),
(None, 0, None))
def testProxyFromEnv(self):
os.environ["http_proxy"] = "http://localhost/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, None))
os.environ["http_proxy"] = "http://localhost:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, None))
os.environ["http_proxy"] = "http://localhost/"
os.environ["https_proxy"] = "http://localhost2/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, None))
os.environ["http_proxy"] = "http://localhost:3128/"
os.environ["https_proxy"] = "http://localhost2:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, None))
os.environ["http_proxy"] = "http://localhost/"
os.environ["https_proxy"] = "http://localhost2/"
self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", None, None))
os.environ["http_proxy"] = "http://localhost:3128/"
os.environ["https_proxy"] = "http://localhost2:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", 3128, None))
os.environ["http_proxy"] = "http://a:b@localhost/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", None, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", 3128, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost/"
os.environ["https_proxy"] = "http://a:b@localhost2/"
os.environ["no_proxy"] = "example1.com,example2.com"
self.assertEqual(get_proxy_info("example.1.com", True), ("localhost2", None, ("a", "b")))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "example1.com,example2.com, echo.websocket.org"
self.assertEqual(get_proxy_info("echo.websocket.org", True), (None, 0, None))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "example1.com,example2.com, .websocket.org"
self.assertEqual(get_proxy_info("echo.websocket.org", True), (None, 0, None))
os.environ["http_proxy"] = "http://a:b@localhost:3128/"
os.environ["https_proxy"] = "http://a:b@localhost2:3128/"
os.environ["no_proxy"] = "127.0.0.0/8, 192.168.0.0/16"
self.assertEqual(get_proxy_info("127.0.0.1", False), (None, 0, None))
self.assertEqual(get_proxy_info("192.168.1.1", False), (None, 0, None))
if __name__ == "__main__":
unittest.main()

View file

@@ -0,0 +1,433 @@
# -*- coding: utf-8 -*-
#
"""
"""
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""
import sys
sys.path[0:0] = [""]
import os
import os.path
import socket
import six
# websocket-client
import websocket as ws
from websocket._handshake import _create_sec_websocket_key, \
_validate as _validate_header
from websocket._http import read_headers
from websocket._utils import validate_utf8
if six.PY3:
from base64 import decodebytes as base64decode
else:
from base64 import decodestring as base64decode
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
try:
from ssl import SSLError
except ImportError:
# dummy class of SSLError for ssl none-support environment.
class SSLError(Exception):
pass
# Skip tests that require internet access unless TEST_WITH_INTERNET=1 is set.
TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1'
TRACEABLE = True
def create_mask_key(_):
return "abcd"
class SockMock(object):
def __init__(self):
self.data = []
self.sent = []
def add_packet(self, data):
self.data.append(data)
def gettimeout(self):
return None
def recv(self, bufsize):
if self.data:
e = self.data.pop(0)
if isinstance(e, Exception):
raise e
if len(e) > bufsize:
self.data.insert(0, e[bufsize:])
return e[:bufsize]
def send(self, data):
self.sent.append(data)
return len(data)
def close(self):
pass
class HeaderSockMock(SockMock):
def __init__(self, fname):
SockMock.__init__(self)
path = os.path.join(os.path.dirname(__file__), fname)
with open(path, "rb") as f:
self.add_packet(f.read())
class WebSocketTest(unittest.TestCase):
def setUp(self):
ws.enableTrace(TRACEABLE)
def tearDown(self):
pass
def testDefaultTimeout(self):
self.assertEqual(ws.getdefaulttimeout(), None)
ws.setdefaulttimeout(10)
self.assertEqual(ws.getdefaulttimeout(), 10)
ws.setdefaulttimeout(None)
def testWSKey(self):
key = _create_sec_websocket_key()
self.assertEqual(24, len(key))  # a 16-byte nonce base64-encodes to 24 characters
self.assertTrue("\n" not in key)
def testNonce(self):
""" WebSocket key should be a random 16-byte nonce.
"""
key = _create_sec_websocket_key()
nonce = base64decode(key.encode("utf-8"))
self.assertEqual(16, len(nonce))
def testWsUtils(self):
key = "c6b8hTg4EeGb2gQMztV1/g=="
required_header = {
"upgrade": "websocket",
"connection": "upgrade",
"sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0="}
self.assertEqual(_validate_header(required_header, key, None), (True, None))
header = required_header.copy()
header["upgrade"] = "http"
self.assertEqual(_validate_header(header, key, None), (False, None))
del header["upgrade"]
self.assertEqual(_validate_header(header, key, None), (False, None))
header = required_header.copy()
header["connection"] = "something"
self.assertEqual(_validate_header(header, key, None), (False, None))
del header["connection"]
self.assertEqual(_validate_header(header, key, None), (False, None))
header = required_header.copy()
header["sec-websocket-accept"] = "something"
self.assertEqual(_validate_header(header, key, None), (False, None))
del header["sec-websocket-accept"]
self.assertEqual(_validate_header(header, key, None), (False, None))
header = required_header.copy()
header["sec-websocket-protocol"] = "sub1"
self.assertEqual(_validate_header(header, key, ["sub1", "sub2"]), (True, "sub1"))
self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None))
header = required_header.copy()
header["sec-websocket-protocol"] = "sUb1"
self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (True, "sub1"))
header = required_header.copy()
self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None))
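# The accept value used above follows RFC 6455: base64 of the SHA-1 of the
# client key concatenated with the fixed GUID. A standard-library sketch:
# >>> import base64, hashlib
# >>> key = "c6b8hTg4EeGb2gQMztV1/g=="
# >>> guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
# >>> base64.b64encode(hashlib.sha1((key + guid).encode("utf-8")).digest()).decode()
# 'Kxep+hNu9n51529fGidYu7a3wO0='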
def testReadHeader(self):
status, header, status_message = read_headers(HeaderSockMock("data/header01.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade")
status, header, status_message = read_headers(HeaderSockMock("data/header03.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade, Keep-Alive")
HeaderSockMock("data/header02.txt")
self.assertRaises(ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt"))
def testSend(self):
# TODO: add longer frame data
sock = ws.WebSocket()
sock.set_mask_key(create_mask_key)
s = sock.sock = HeaderSockMock("data/header01.txt")
sock.send("Hello")
self.assertEqual(s.sent[0], six.b("\x81\x85abcd)\x07\x0f\x08\x0e"))
sock.send("こんにちは")
self.assertEqual(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc"))
sock.send(u"こんにちは")
self.assertEqual(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc"))
# sock.send("x" * 5000)
# self.assertEqual(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc"))
self.assertEqual(sock.send_binary(b'1111111111101'), 19)
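# How the expected masked frames above are built (a sketch using the fixed
# mask key "abcd" from create_mask_key): 0x81 is FIN + the text opcode,
# 0x85 is the mask bit plus a payload length of 5, then the 4-byte mask key,
# then the payload XOR-ed byte-wise against the repeating mask key.
# >>> mask, payload = b"abcd", b"Hello"
# >>> masked = bytes(bytearray(p ^ m for p, m in zip(bytearray(payload), bytearray(mask * 2))))
# >>> b"\x81\x85" + mask + masked == six.b("\x81\x85abcd)\x07\x0f\x08\x0e")
# True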
def testRecv(self):
# TODO: add longer frame data
sock = ws.WebSocket()
s = sock.sock = SockMock()
something = six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
s.add_packet(something)
data = sock.recv()
self.assertEqual(data, "こんにちは")
s.add_packet(six.b("\x81\x85abcd)\x07\x0f\x08\x0e"))
data = sock.recv()
self.assertEqual(data, "Hello")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testIter(self):
count = 2
for _ in ws.create_connection('wss://stream.meetup.com/2/rsvps'):
count -= 1
if count == 0:
break
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testNext(self):
sock = ws.create_connection('wss://stream.meetup.com/2/rsvps')
self.assertEqual(str, type(next(sock)))
def testInternalRecvStrict(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
s.add_packet(six.b("foo"))
s.add_packet(socket.timeout())
s.add_packet(six.b("bar"))
# s.add_packet(SSLError("The read operation timed out"))
s.add_packet(six.b("baz"))
with self.assertRaises(ws.WebSocketTimeoutException):
sock.frame_buffer.recv_strict(9)
# if six.PY2:
# with self.assertRaises(ws.WebSocketTimeoutException):
# data = sock._recv_strict(9)
# else:
# with self.assertRaises(SSLError):
# data = sock._recv_strict(9)
data = sock.frame_buffer.recv_strict(9)
self.assertEqual(data, six.b("foobarbaz"))
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.frame_buffer.recv_strict(1)
def testRecvTimeout(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
s.add_packet(six.b("\x81"))
s.add_packet(socket.timeout())
s.add_packet(six.b("\x8dabcd\x29\x07\x0f\x08\x0e"))
s.add_packet(socket.timeout())
s.add_packet(six.b("\x4e\x43\x33\x0e\x10\x0f\x00\x40"))
with self.assertRaises(ws.WebSocketTimeoutException):
sock.recv()
with self.assertRaises(ws.WebSocketTimeoutException):
sock.recv()
data = sock.recv()
self.assertEqual(data, "Hello, World!")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def testRecvWithSimpleFragmentation(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is "
s.add_packet(six.b("\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C"))
# OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17"))
data = sock.recv()
self.assertEqual(data, "Brevity is the soul of wit")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def testRecvWithFireEventOfFragmentation(self):
sock = ws.WebSocket(fire_cont_frame=True)
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is "
s.add_packet(six.b("\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C"))
# OPCODE=CONT, FIN=0, MSG="Brevity is "
s.add_packet(six.b("\x00\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C"))
# OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17"))
_, data = sock.recv_data()
self.assertEqual(data, six.b("Brevity is "))
_, data = sock.recv_data()
self.assertEqual(data, six.b("Brevity is "))
_, data = sock.recv_data()
self.assertEqual(data, six.b("the soul of wit"))
# OPCODE=CONT, FIN=0, MSG="Brevity is "
s.add_packet(six.b("\x80\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C"))
with self.assertRaises(ws.WebSocketException):
sock.recv_data()
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def testClose(self):
sock = ws.WebSocket()
sock.sock = SockMock()
sock.connected = True
sock.close()
self.assertEqual(sock.connected, False)
sock = ws.WebSocket()
s = sock.sock = SockMock()
sock.connected = True
s.add_packet(six.b('\x88\x80\x17\x98p\x84'))
sock.recv()
self.assertEqual(sock.connected, False)
def testRecvContFragmentation(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17"))
self.assertRaises(ws.WebSocketException, sock.recv)
def testRecvWithProlongedFragmentation(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
s.add_packet(six.b("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15"
"\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC"))
# OPCODE=CONT, FIN=0, MSG="dear friends, "
s.add_packet(six.b("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07"
"\x17MB"))
# OPCODE=CONT, FIN=1, MSG="once more"
s.add_packet(six.b("\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04"))
data = sock.recv()
self.assertEqual(
data,
"Once more unto the breach, dear friends, once more")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def testRecvWithFragmentationAndControlFrame(self):
sock = ws.WebSocket()
sock.set_mask_key(create_mask_key)
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Too much "
s.add_packet(six.b("\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA"))
# OPCODE=PING, FIN=1, MSG="Please PONG this"
s.add_packet(six.b("\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17"))
# OPCODE=CONT, FIN=1, MSG="of a good thing"
s.add_packet(six.b("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c"
"\x08\x0c\x04"))
data = sock.recv()
self.assertEqual(data, "Too much of a good thing")
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
self.assertEqual(
s.sent[0],
six.b("\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17"))
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testWebSocket(self):
s = ws.create_connection("ws://echo.websocket.org/")
self.assertNotEqual(s, None)
s.send("Hello, World")
result = s.recv()
self.assertEqual(result, "Hello, World")
s.send(u"こにゃにゃちは、世界")
result = s.recv()
self.assertEqual(result, "こにゃにゃちは、世界")
self.assertRaises(ValueError, s.send_close, -1, "")
s.close()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testPingPong(self):
s = ws.create_connection("ws://echo.websocket.org/")
self.assertNotEqual(s, None)
s.ping("Hello")
s.pong("Hi")
s.close()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSecureWebSocket(self):
import ssl
s = ws.create_connection("wss://api.bitfinex.com/ws/2")
self.assertNotEqual(s, None)
self.assertTrue(isinstance(s.sock, ssl.SSLSocket))
self.assertEqual(s.getstatus(), 101)
self.assertNotEqual(s.getheaders(), None)
s.close()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testWebSocketWithCustomHeader(self):
s = ws.create_connection("ws://echo.websocket.org/",
headers={"User-Agent": "PythonWebsocketClient"})
self.assertNotEqual(s, None)
s.send("Hello, World")
result = s.recv()
self.assertEqual(result, "Hello, World")
self.assertRaises(ValueError, s.close, -1, "")
s.close()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testAfterClose(self):
s = ws.create_connection("ws://echo.websocket.org/")
self.assertNotEqual(s, None)
s.close()
self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello")
self.assertRaises(ws.WebSocketConnectionClosedException, s.recv)
class SockOptTest(unittest.TestCase):
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSockOpt(self):
sockopt = ((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),)
s = ws.create_connection("ws://echo.websocket.org", sockopt=sockopt)
self.assertNotEqual(s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0)
s.close()
class UtilsTest(unittest.TestCase):
def testUtf8Validator(self):
state = validate_utf8(six.b('\xf0\x90\x80\x80'))
self.assertEqual(state, True)
state = validate_utf8(six.b('\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5\xed\xa0\x80edited'))
self.assertEqual(state, False)
state = validate_utf8(six.b(''))
self.assertEqual(state, True)
if __name__ == "__main__":
unittest.main()