Skip to content

Commit 1277092

Browse files
authored
Merge pull request #67 from jmptbl/master
More fixes for non-blocking SSL connections.
2 parents f06c361 + cd68a87 commit 1277092

File tree

1 file changed

+55
-32
lines changed

1 file changed

+55
-32
lines changed

puka/connection.py

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _connect(self):
9696
self.needs_write = self.needs_write_connect
9797
self.on_write = self.on_write_connect
9898
if self.ssl:
99-
self.on_read = self.on_read_handshake
99+
self.on_read = self.on_read_handshake_connect
100100
else:
101101
self.on_read = self.on_read_nohandshake
102102

@@ -129,24 +129,45 @@ def _wrap_socket(self, sock):
129129
cert_reqs=cert_reqs,
130130
ca_certs=ca_certs)
131131

132+
def on_read_handshake_connect(self):
133+
errno = self.sd.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
134+
if errno:
135+
self._shutdown(exceptions.mark_frame(spec.Frame(),
136+
exceptions.ConnectionBroken()))
137+
return
138+
self.sd = self._wrap_socket(self.sd)
139+
self.needs_write = self.needs_write_handshake
140+
self.on_write = self.on_write_handshake
141+
self.on_read = self.on_read_handshake
142+
132143
def on_read_handshake(self):
133-
pass
144+
try:
145+
self.sd.do_handshake()
146+
self.needs_write = self.needs_write_nohandshake
147+
self.on_write = self.on_write_nohandshake
148+
self.on_read = self.on_read_nohandshake
149+
except ssl.SSLError, e:
150+
if e.args[0] == ssl.SSL_ERROR_WANT_READ:
151+
return
152+
raise
153+
except socket.error, e:
154+
if e.errno == errno.EAGAIN:
155+
return
156+
self._shutdown(exceptions.mark_frame(spec.Frame(),
157+
exceptions.ConnectionBroken()))
158+
return
134159

135160
def on_read_nohandshake(self):
136-
while True:
137-
try:
138-
r = self.sd.recv(Connection.frame_max)
139-
break
140-
except ssl.SSLError, e:
141-
if e.args[0] == ssl.SSL_ERROR_WANT_READ:
142-
select.select([self.sd], [], [])
143-
continue
144-
raise
145-
except socket.error, e:
146-
if e.errno == errno.EAGAIN:
147-
return
148-
else:
149-
raise
161+
try:
162+
r = self.sd.recv(Connection.frame_max)
163+
except ssl.SSLError, e:
164+
if e.args[0] == ssl.SSL_ERROR_WANT_READ:
165+
return
166+
raise
167+
except socket.error, e:
168+
if e.errno == errno.EAGAIN:
169+
return
170+
raise
150171

151172
if len(r) == 0:
152173
# a = self.sd.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
@@ -261,6 +282,7 @@ def on_write_connect(self):
261282
self.sd = self._wrap_socket(self.sd)
262283
self.needs_write = self.needs_write_handshake
263284
self.on_write = self.on_write_handshake
285+
self.on_read = self.on_read_handshake
264286
else:
265287
self.needs_write = self.needs_write_nohandshake
266288
self.on_write = self.on_write_nohandshake
@@ -272,21 +294,17 @@ def on_write_handshake(self):
272294
def on_write_nohandshake(self):
273295
if not self.send_buf: # already shutdown or empty buffer?
274296
return
275-
while True:
276-
try:
277-
# On windows socket.send blows up if the buffer is too large.
278-
r = self.sd.send(self.send_buf.read(128*1024))
279-
break
280-
except ssl.SSLError, e:
281-
if e.args[0] == ssl.SSL_ERROR_WANT_WRITE:
282-
select.select([], [self.sd], [])
283-
continue
284-
raise
285-
except socket.error, e:
286-
if e.errno in (errno.EWOULDBLOCK, errno.ENOBUFS):
287-
return
288-
else:
289-
raise
297+
try:
298+
# On windows socket.send blows up if the buffer is too large.
299+
r = self.sd.send(self.send_buf.read(128*1024))
300+
except ssl.SSLError, e:
301+
if e.args[0] == ssl.SSL_ERROR_WANT_WRITE:
302+
return
303+
raise
304+
except socket.error, e:
305+
if e.errno in (errno.EWOULDBLOCK, errno.ENOBUFS):
306+
return
307+
raise
290308
self.send_buf.consume(r)
291309

292310
def _tune_frame_max(self, new_frame_max):
@@ -502,7 +520,12 @@ def set_ridiculously_high_buffers(sd):
502520
Set large tcp/ip buffers kernel. Let's move the complexity
503521
to the operating system! That's a wonderful idea!
504522
'''
505-
for flag in [socket.SO_SNDBUF, socket.SO_RCVBUF]:
523+
flags = []
524+
if hasattr(socket, 'SO_SNDBUF'):
525+
flags.append(socket.SO_SNDBUF)
526+
if hasattr(socket, 'SO_RECVBUF'):
527+
flags.append(socket.SO_RECVBUF)
528+
for flag in flags:
506529
for i in range(10):
507530
bef = sd.getsockopt(socket.SOL_SOCKET, flag)
508531
try:

0 commit comments

Comments
 (0)