WebSockets in Python

Since the dawn of AJAX, web developers have longed for persistent server-side connections. For a while Comet was hailed as the bastion of “server push”, but deep down we knew it was just a hack. Now, years later, we finally have an API and a protocol being standardized for socket connections between the browser and the server, aptly named WebSockets.

WebSockets are bi-directional communication channels that run over a single TCP connection between the client and the server. Because they sit on top of ordinary TCP (INET) sockets, with only an HTTP-style handshake and a simple message framing layered on top, we should be able to implement them easily with existing tools. However, when I was looking for example implementations in Python, I didn’t find anything that quite satisfied me.

Python's socket module

import socket

# A plain listening TCP socket: the WebSocket handshake and message framing
# still have to be layered on top of it before a browser can talk to it.
websocket = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
websocket.bind(("localhost", 9999))
websocket.listen(5)

That’s all you need to get a listening socket up and running. Granted, it’s not very useful yet: a browser can’t complete a WebSocket connection to it because nothing answers the handshake. Still, it is the foundation every WebSocket server is built on. When a client connects to the socket, it initiates the handshake with the following request:

GET / HTTP/1.1
Upgrade: WebSocket
Connection: Upgrade
Host: localhost:9999
Origin: file://
Sec-WebSocket-Key1:   d3L703 2  {63 k  L1( 90
Sec-WebSocket-Key2:   14   +40Z7R<12om I8  0[

??????????????

And expects a response in a similar form:

HTTP/1.1 101 Web Socket Protocol Handshake
Upgrade: WebSocket
Connection: Upgrade
WebSocket-Origin: file://
WebSocket-Location: ws://localhost:9999/
Sec-Websocket-Origin: file://
Sec-Websocket-Location: ws://localhost:9999/

??????????????

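The “?” characters stand in for raw binary bytes that don’t render as text. The client sends 8 random bytes after its headers. The server answers with the 16-byte MD5 digest computed from those bytes and the numbers encoded in the two Sec-WebSocket-Key headers; this challenge/response proves the server really understood the WebSocket handshake. (This is the draft-76 handshake; the final standard, RFC 6455, replaced it with a single Sec-WebSocket-Key header and a SHA-1-based response.) Interesting note: besides not performing the challenge/response, Chrome looks for the “WebSocket-X” headers, while Safari (correctly) looks for the “Sec-WebSocket-X” headers.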

Here is a complete example WebSocket server (also published at http://gist.github.com/512987):

import struct
import socket
import hashlib
import sys
from select import select
import re
import logging
from threading import Thread
import signal

class WebSocket(object):
    """One client connection speaking the draft-76 (hixie) WebSocket protocol.

    The server creates one instance per accepted socket and calls ``feed``
    with every chunk read from it. Incoming data is buffered until the
    opening handshake is complete. After that it is split into 0x00 ... 0xFF
    framed messages, and each payload is passed to ``onmessage``. Subclass
    and override ``onmessage`` to handle messages.

    NOTE: this is the pre-RFC 6455 handshake (Sec-WebSocket-Key1/Key2 plus an
    MD5 challenge); clients that only speak RFC 6455 cannot connect.
    """

    # Response template, filled in by dohandshake(). Both the WebSocket-* and
    # the Sec-Websocket-* header spellings are sent so that clients looking
    # for either variant accept the response.
    handshake = (
        "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
        "Upgrade: WebSocket\r\n"
        "Connection: Upgrade\r\n"
        "WebSocket-Origin: %(origin)s\r\n"
        "WebSocket-Location: ws://%(bind)s:%(port)s/\r\n"
        "Sec-Websocket-Origin: %(origin)s\r\n"
        "Sec-Websocket-Location: ws://%(bind)s:%(port)s/\r\n"
        "\r\n"
    )

    # Size of the challenge body a draft-76 client sends after its headers.
    _CHALLENGE_LENGTH = 8
    _NON_DIGITS = re.compile(r'[^0-9]')

    def __init__(self, client, server):
        self.client = client      # accepted client socket
        self.server = server      # owning server; .bind/.port go in the reply
        self.handshaken = False
        self.header = ""          # handshake request buffered so far
        self.data = ""            # frame data not yet terminated by 0xFF

    @staticmethod
    def _to_bytes(value):
        """Return *value* as bytes, one byte per character.

        On Python 2 a ``str`` already is ``bytes`` and is returned unchanged.
        On Python 3 a ``str`` holding raw byte values is encoded with latin-1,
        which maps code points 0-255 one-to-one onto bytes.
        """
        if isinstance(value, bytes):
            return value
        return value.encode('latin-1')

    @staticmethod
    def _has_header(header, name):
        """Return True if *header* has a ``name: ...`` line (case-insensitive)."""
        prefix = name.lower() + ':'
        return any(line.lower().startswith(prefix)
                   for line in header.split('\r\n')[1:])

    @classmethod
    def _key_part(cls, value):
        """Decode one Sec-WebSocket-Key1/Key2 header value.

        The digits of *value*, read as a decimal number, must be an exact
        multiple of the number of spaces in *value*. The quotient must also
        fit the 32-bit field it is packed into.

        :returns: the quotient, or None if the value is malformed.
        """
        digits = cls._NON_DIGITS.sub('', value)
        spaces = value.count(' ')
        if not digits or spaces == 0:
            return None
        number = int(digits)
        if number % spaces != 0:
            return None
        part = number // spaces
        if part > 0xFFFFFFFF:
            return None
        return part

    def feed(self, data):
        """Consume one chunk of data read from the client socket.

        Before the handshake completes, data is buffered until the whole
        request header has arrived, plus the 8-byte challenge that follows it
        when both Sec-WebSocket-Key headers are present. Afterwards every
        complete frame is dispatched to ``onmessage``. An unterminated
        trailing frame is kept for the next call.
        """
        if not isinstance(data, str):
            # Python 3 sockets return bytes. Work on a lossless latin-1 view
            # so the str-based parsing below behaves exactly as on Python 2.
            data = data.decode('latin-1')
        if self.handshaken:
            self._read_frames(data)
            return

        self.header += data
        head, sep, rest = self.header.partition('\r\n\r\n')
        if not sep:
            return  # request header not complete yet

        key = None
        if (self._has_header(head, 'sec-websocket-key1')
                and self._has_header(head, 'sec-websocket-key2')):
            # The challenge body may arrive in a later read. A short read
            # must not be mistaken for the whole key.
            if len(rest) < self._CHALLENGE_LENGTH:
                return
            key = rest[:self._CHALLENGE_LENGTH]
            rest = rest[self._CHALLENGE_LENGTH:]

        self.header = head
        if not self.dohandshake(head, key):
            logging.warning("Handshake failed")
            return
        logging.info("Handshake successful")
        self.handshaken = True
        if rest:
            # Frames sent right behind the handshake arrive in the same read;
            # don't drop them.
            self._read_frames(rest)

    def _read_frames(self, data):
        """Buffer *data* and pass every complete 0x00-prefixed frame to onmessage."""
        self.data += data
        frames = self.data.split('\xff')
        # The last piece has no terminator yet; keep it for the next chunk.
        self.data = frames.pop()
        for frame in frames:
            # Frames not starting with 0x00 are ignored. This includes the
            # empty piece produced by back-to-back 0xFF bytes.
            if frame.startswith('\x00'):
                self.onmessage(frame[1:])

    def dohandshake(self, header, key=None):
        """Validate the client's opening handshake and send the response.

        :param header: request line and header lines, ``\\r\\n``-separated,
            without the terminating blank line.
        :param key: the 8-byte challenge body that follows the headers when
            both Sec-WebSocket-Key headers are sent, else None.
        :returns: True once the response has been sent. Returns False, without
            sending anything, if a key header is malformed or the challenge
            body is missing.
        """
        logging.debug("Begin handshake: %s", header)
        part_1 = part_2 = origin = None
        for line in header.split('\r\n')[1:]:
            name, sep, value = line.partition(': ')
            if not sep:
                continue  # not a "Name: value" line
            name = name.lower()
            if name == "sec-websocket-key1":
                part_1 = self._key_part(value)
                if part_1 is None:
                    return False
            elif name == "sec-websocket-key2":
                part_2 = self._key_part(value)
                if part_2 is None:
                    return False
            elif name == "origin":
                origin = value

        # Fill in the template BEFORE appending the binary MD5 digest. The
        # digest may contain '%' bytes that would otherwise corrupt formatting.
        handshake = WebSocket.handshake % {'origin': origin,
                                           'port': self.server.port,
                                           'bind': self.server.bind}
        handshake = self._to_bytes(handshake)
        if part_1 is not None and part_2 is not None:
            if key is None or len(key) != self._CHALLENGE_LENGTH:
                logging.warning("Missing or short challenge key")
                return False
            logging.debug("Using challenge + response")
            challenge = struct.pack('!II', part_1, part_2) + self._to_bytes(key)
            handshake += hashlib.md5(challenge).digest()
        else:
            logging.warning("Not using challenge + response")
        logging.debug("Sending handshake %s", handshake)
        # send() may write only part of the buffer; sendall() writes all.
        self.client.sendall(handshake)
        return True

    def onmessage(self, data):
        """Handle one received message payload (framing bytes removed).

        The default implementation only logs it; override in a subclass.
        """
        logging.info("Got message: %s" % data)

    def send(self, data):
        """Send *data* to the client as one 0x00 ... 0xFF framed message.

        Byte strings are sent as-is. Anything else is converted with ``%s``
        and encoded as UTF-8.
        """
        logging.info("Sent message: %s" % (data,))
        payload = data if isinstance(data, bytes) else "%s" % (data,)
        if not isinstance(payload, bytes):
            payload = payload.encode('utf-8')
        self.client.sendall(b"\x00" + payload + b"\xff")

    def close(self):
        """Close the underlying client socket."""
        self.client.close()

class WebSocketServer(object):
    """Accept TCP clients and feed their data to per-connection handlers.

    Runs a single select() loop (see ``listen``). Every accepted socket is
    wrapped as ``cls(client_socket, server)``. The wrapper must expose
    ``client`` (the socket), ``feed(data)`` and ``close()``, as WebSocket
    does. Wrappers are kept in ``connections``, keyed by the client socket's
    file descriptor.
    """

    def __init__(self, bind, port, cls):
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Let a restarted server rebind the port while connections from the
        # previous run are still in TIME_WAIT.
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind((bind, port))
        self.bind = bind
        self.port = port
        self.cls = cls
        self.connections = {}
        # What select() watches: the listening socket plus client fds.
        self.listeners = [self.socket]
        # Loop flag. listen() sets it; clearing it stops the loop.
        self.running = False

    def listen(self, backlog=5):
        """Serve clients until ``running`` is cleared or the socket fails.

        select() times out every second, so clearing ``running`` from another
        thread takes effect promptly. When the loop ends, every client
        connection and the listening socket are closed.
        """
        self.socket.listen(backlog)
        logging.info("Listening on %s", self.port)
        self.running = True
        try:
            while self.running:
                readable, _, failed = select(self.listeners, [],
                                             self.listeners, 1)
                for ready in readable:
                    if ready == self.socket:
                        self._accept()
                    else:
                        self._read(ready)
                if self.socket in failed:
                    logging.error("Socket broke")
                    self.running = False
        finally:
            self._close_all()

    def _accept(self):
        """Accept one pending client and start watching its descriptor."""
        logging.debug("New client connection")
        try:
            client, _address = self.socket.accept()
        except socket.error as err:
            # A failed accept() only affects that one client; keep serving.
            logging.debug("accept() failed: %s", err)
            return
        fileno = client.fileno()
        self.listeners.append(fileno)
        self.connections[fileno] = self.cls(client, self)

    def _read(self, fileno):
        """Read from one client; feed the data, or drop the client on EOF/error."""
        logging.debug("Client ready for reading %s", fileno)
        conn = self.connections[fileno]
        try:
            data = conn.client.recv(1024)
        except socket.error as err:
            # e.g. connection reset by peer: treat like an orderly close
            # instead of letting it kill the loop for every other client.
            logging.debug("Error reading from client %s: %s", fileno, err)
            data = None
        if data:
            conn.feed(data)
        else:
            self._drop(fileno)

    def _drop(self, fileno):
        """Close one client connection and stop watching its descriptor."""
        logging.debug("Closing client %s", fileno)
        self.connections.pop(fileno).close()
        self.listeners.remove(fileno)

    def _close_all(self):
        """Close every client connection and the listening socket."""
        for fileno in list(self.connections):
            self._drop(fileno)
        self.socket.close()

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
    server = WebSocketServer("localhost", 9999, WebSocket)

    def signal_handler(signum, frame):
        """On Ctrl+C, ask the server loop to stop (it polls every second)."""
        logging.info("Caught Ctrl+C, shutting down...")
        server.running = False

    # Install the handler before the server thread starts, so an early Ctrl+C
    # is not handled by the default KeyboardInterrupt path.
    signal.signal(signal.SIGINT, signal_handler)
    server_thread = Thread(target=server.listen, args=[5])
    server_thread.start()
    # Python runs signal handlers only in the main thread, and on Python 2 an
    # untimed join() cannot be interrupted. Keep the main thread in a timed
    # join so Ctrl+C reaches the handler; the loop ends once listen() returns.
    while server_thread.is_alive():
        server_thread.join(0.5)



coded by nessus