Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exec_command segfault partial fix #280

Merged
merged 5 commits into from
Dec 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changelog-fragments/280.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Improved ``channel.exec_command`` to always use a newly created ``ssh_channel`` to avoid
segfaults on repeated calls -- by :user:`Qalthos`
37 changes: 25 additions & 12 deletions src/pylibsshext/channel.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,10 @@ cdef class Channel:
def poll(self, timeout=-1, stderr=0):
if timeout < 0:
rc = libssh.ssh_channel_poll(self._libssh_channel, stderr)
if rc == libssh.SSH_ERROR:
raise LibsshChannelException("Failed to poll channel: [%d]" % rc)
else:
rc = libssh.ssh_channel_poll_timeout(self._libssh_channel, timeout, stderr)
if rc == libssh.SSH_ERROR:
raise LibsshChannelException("Failed to poll channel: [%d]" % rc)
if rc == libssh.SSH_ERROR:
raise LibsshChannelException("Failed to poll channel: [{0}]".format(rc))
return rc

def read_nonblocking(self, size=1024, stderr=0):
Expand All @@ -113,7 +111,10 @@ cdef class Channel:
return self.read_nonblocking(size=size, stderr=stderr)

def write(self, data):
return libssh.ssh_channel_write(self._libssh_channel, PyBytes_AS_STRING(data), len(data))
written = libssh.ssh_channel_write(self._libssh_channel, PyBytes_AS_STRING(data), len(data))
if written == libssh.SSH_ERROR:
raise LibsshChannelException("Failed to write to ssh channel")
return written

def sendall(self, data):
return self.write(data)
Expand All @@ -139,23 +140,35 @@ cdef class Channel:
return response

def exec_command(self, command):
rc = libssh.ssh_channel_request_exec(self._libssh_channel, command.encode("utf-8"))
# request_exec requires a fresh channel each run, so do not use the existing channel
cdef libssh.ssh_channel channel = libssh.ssh_channel_new(self._libssh_session)
if channel is NULL:
raise MemoryError

rc = libssh.ssh_channel_open_session(channel)
if rc != libssh.SSH_OK:
libssh.ssh_channel_free(channel)
raise LibsshChannelException("Failed to open_session: [{0}]".format(rc))

rc = libssh.ssh_channel_request_exec(channel, command.encode("utf-8"))
if rc != libssh.SSH_OK:
self.close()
raise CalledProcessError()
libssh.ssh_channel_close(channel)
libssh.ssh_channel_free(channel)
raise LibsshChannelException("Failed to execute command [{0}]: [{1}]".format(command, rc))
result = CompletedProcess(args=command, returncode=-1, stdout=b'', stderr=b'')

cdef callbacks.ssh_channel_callbacks_struct cb
memset(&cb, 0, sizeof(cb))
cb.channel_data_function = <callbacks.ssh_channel_data_callback>&_process_outputs
cb.userdata = <void *>result
callbacks.ssh_callbacks_init(&cb)
callbacks.ssh_set_channel_callbacks(self._libssh_channel, &cb)

libssh.ssh_channel_send_eof(self._libssh_channel)
callbacks.ssh_set_channel_callbacks(channel, &cb)

result.returncode = self.get_channel_exit_status()
libssh.ssh_channel_send_eof(channel)
result.returncode = libssh.ssh_channel_get_exit_status(channel)
if channel is not NULL:
libssh.ssh_channel_close(channel)
libssh.ssh_channel_free(channel)

return result

Expand Down
2 changes: 1 addition & 1 deletion src/pylibsshext/includes/libssh.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ from libc.stdint cimport uint32_t

cdef extern from "libssh/libssh.h" nogil:

cpdef const char * libssh_version "SSH_STRINGIFY(LIBSSH_VERSION)"
cdef const char * libssh_version "SSH_STRINGIFY(LIBSSH_VERSION)"

cdef struct ssh_session_struct:
pass
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/channel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ def ssh_channel(ssh_client_session):
'Ref: https://github.com/ansible/pylibssh/issues/57', # noqa: WPS326
strict=False,
)
@pytest.mark.forked() # noqa: PT023 -- it's unclear if braces are needed here
@pytest.mark.forked
def test_exec_command(ssh_channel):
"""Test getting the output of a remotely executed command."""
u_cmd_out = ssh_channel.exec_command('echo -n Hello World').stdout.decode()
assert u_cmd_out == u'Hello World' # noqa: WPS302
# Test that repeated calls to exec_command do not segfault.
u_cmd_out = ssh_channel.exec_command('echo -n Hello Again').stdout.decode()
assert u_cmd_out == u'Hello Again' # noqa: WPS302


def test_double_close(ssh_channel):
Expand Down