diff --git a/docs/changelog-fragments/280.bugfix.rst b/docs/changelog-fragments/280.bugfix.rst new file mode 100644 index 000000000..9b806055b --- /dev/null +++ b/docs/changelog-fragments/280.bugfix.rst @@ -0,0 +1,2 @@ +Improved ``channel.exec_command`` to always use a newly created ``ssh_channel`` to avoid +segfaults on repeated calls -- by :user:`Qalthos` diff --git a/src/pylibsshext/channel.pyx b/src/pylibsshext/channel.pyx index 0ac5d9ea7..baa9330c6 100644 --- a/src/pylibsshext/channel.pyx +++ b/src/pylibsshext/channel.pyx @@ -89,12 +89,10 @@ cdef class Channel: def poll(self, timeout=-1, stderr=0): if timeout < 0: rc = libssh.ssh_channel_poll(self._libssh_channel, stderr) - if rc == libssh.SSH_ERROR: - raise LibsshChannelException("Failed to poll channel: [%d]" % rc) else: rc = libssh.ssh_channel_poll_timeout(self._libssh_channel, timeout, stderr) - if rc == libssh.SSH_ERROR: - raise LibsshChannelException("Failed to poll channel: [%d]" % rc) + if rc == libssh.SSH_ERROR: + raise LibsshChannelException("Failed to poll channel: [{0}]".format(rc)) return rc def read_nonblocking(self, size=1024, stderr=0): @@ -113,7 +111,10 @@ cdef class Channel: return self.read_nonblocking(size=size, stderr=stderr) def write(self, data): - return libssh.ssh_channel_write(self._libssh_channel, PyBytes_AS_STRING(data), len(data)) + written = libssh.ssh_channel_write(self._libssh_channel, PyBytes_AS_STRING(data), len(data)) + if written == libssh.SSH_ERROR: + raise LibsshChannelException("Failed to write to ssh channel") + return written def sendall(self, data): return self.write(data) @@ -139,11 +140,21 @@ cdef class Channel: return response def exec_command(self, command): - rc = libssh.ssh_channel_request_exec(self._libssh_channel, command.encode("utf-8")) + # request_exec requires a fresh channel each run, so do not use the existing channel + cdef libssh.ssh_channel channel = libssh.ssh_channel_new(self._libssh_session) + if channel is NULL: + raise MemoryError + + rc = libssh.ssh_channel_open_session(channel) + if rc != libssh.SSH_OK: + libssh.ssh_channel_free(channel) + raise LibsshChannelException("Failed to open_session: [{0}]".format(rc)) + rc = libssh.ssh_channel_request_exec(channel, command.encode("utf-8")) if rc != libssh.SSH_OK: - self.close() - raise CalledProcessError() + libssh.ssh_channel_close(channel) + libssh.ssh_channel_free(channel) + raise LibsshChannelException("Failed to execute command [{0}]: [{1}]".format(command, rc)) result = CompletedProcess(args=command, returncode=-1, stdout=b'', stderr=b'') cdef callbacks.ssh_channel_callbacks_struct cb @@ -151,11 +162,13 @@ cdef class Channel: cb.channel_data_function = &_process_outputs cb.userdata = result callbacks.ssh_callbacks_init(&cb) - callbacks.ssh_set_channel_callbacks(self._libssh_channel, &cb) - - libssh.ssh_channel_send_eof(self._libssh_channel) + callbacks.ssh_set_channel_callbacks(channel, &cb) - result.returncode = self.get_channel_exit_status() + libssh.ssh_channel_send_eof(channel) + result.returncode = libssh.ssh_channel_get_exit_status(channel) + if channel is not NULL: + libssh.ssh_channel_close(channel) + libssh.ssh_channel_free(channel) return result diff --git a/src/pylibsshext/includes/libssh.pxd b/src/pylibsshext/includes/libssh.pxd index 2b8215e83..8862220d5 100644 --- a/src/pylibsshext/includes/libssh.pxd +++ b/src/pylibsshext/includes/libssh.pxd @@ -21,7 +21,7 @@ from libc.stdint cimport uint32_t cdef extern from "libssh/libssh.h" nogil: - cpdef const char * libssh_version "SSH_STRINGIFY(LIBSSH_VERSION)" + cdef const char * libssh_version "SSH_STRINGIFY(LIBSSH_VERSION)" cdef struct ssh_session_struct: pass diff --git a/tests/unit/channel_test.py b/tests/unit/channel_test.py index 17b4ddb9b..d7832f49f 100644 --- a/tests/unit/channel_test.py +++ b/tests/unit/channel_test.py @@ -32,11 +32,14 @@ def ssh_channel(ssh_client_session): 'Ref: https://github.com/ansible/pylibssh/issues/57', # noqa: WPS326 strict=False, ) -@pytest.mark.forked() # noqa: PT023 -- it's unclear if braces are needed here +@pytest.mark.forked def test_exec_command(ssh_channel): """Test getting the output of a remotely executed command.""" u_cmd_out = ssh_channel.exec_command('echo -n Hello World').stdout.decode() assert u_cmd_out == u'Hello World' # noqa: WPS302 + # Test that repeated calls to exec_command do not segfault. + u_cmd_out = ssh_channel.exec_command('echo -n Hello Again').stdout.decode() + assert u_cmd_out == u'Hello Again' # noqa: WPS302 def test_double_close(ssh_channel):