Skip to content

Commit 70f0450

Browse files
authored
connectors/ssh: add flag to use SCP as file transfer protocol
1 parent 73cc854 commit 70f0450

File tree

3 files changed

+244
-7
lines changed

3 files changed

+244
-7
lines changed

pyinfra/connectors/scp/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .client import SCPClient # noqa: F401

pyinfra/connectors/scp/client.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
from __future__ import annotations
2+
3+
import ntpath
4+
import os
5+
from pathlib import PurePath
6+
from shlex import quote
7+
from socket import timeout as SocketTimeoutError
8+
from typing import IO, AnyStr
9+
10+
from paramiko import Channel
11+
from paramiko.transport import Transport
12+
13+
SCP_COMMAND = b"scp"
14+
15+
16+
# Unicode conversion functions; assume UTF-8
17+
def asbytes(s: bytes | str | PurePath) -> bytes:
18+
"""Turns unicode into bytes, if needed.
19+
20+
Assumes UTF-8.
21+
"""
22+
if isinstance(s, bytes):
23+
return s
24+
elif isinstance(s, PurePath):
25+
return bytes(s)
26+
else:
27+
return s.encode("utf-8")
28+
29+
30+
def asunicode(s: bytes | str) -> str:
31+
"""Turns bytes into unicode, if needed.
32+
33+
Uses UTF-8.
34+
"""
35+
if isinstance(s, bytes):
36+
return s.decode("utf-8", "replace")
37+
else:
38+
return s
39+
40+
41+
class SCPClient:
42+
"""
43+
An scp1 implementation, compatible with openssh scp.
44+
Raises SCPException for all transport related errors. Local filesystem
45+
and OS errors pass through.
46+
47+
Main public methods are .putfo and .getfo
48+
"""
49+
50+
def __init__(
51+
self,
52+
transport: Transport,
53+
buff_size: int = 16384,
54+
socket_timeout: float = 10.0,
55+
):
56+
self.transport = transport
57+
self.buff_size = buff_size
58+
self.socket_timeout = socket_timeout
59+
self._channel: Channel | None = None
60+
self.scp_command = SCP_COMMAND
61+
62+
@property
63+
def channel(self) -> Channel:
64+
"""Return an open Channel, (re)opening if needed."""
65+
if self._channel is None or self._channel.closed:
66+
self._channel = self.transport.open_session()
67+
return self._channel
68+
69+
def __enter__(self):
70+
_ = self.channel # triggers opening if not already open
71+
return self
72+
73+
def __exit__(self, type, value, traceback):
74+
self.close()
75+
76+
def putfo(
77+
self,
78+
fl: IO[AnyStr],
79+
remote_path: str | bytes,
80+
mode: str | bytes = "0644",
81+
size: int | None = None,
82+
) -> None:
83+
if size is None:
84+
pos = fl.tell()
85+
fl.seek(0, os.SEEK_END) # Seek to end
86+
size = fl.tell() - pos
87+
fl.seek(pos, os.SEEK_SET) # Seek back
88+
89+
self.channel.settimeout(self.socket_timeout)
90+
self.channel.exec_command(
91+
self.scp_command + b" -t " + asbytes(quote(asunicode(remote_path)))
92+
)
93+
self._recv_confirm()
94+
self._send_file(fl, remote_path, mode, size=size)
95+
self.close()
96+
97+
def getfo(self, remote_path: str, fl: IO):
98+
remote_path_sanitized = quote(remote_path)
99+
if os.name == "nt":
100+
remote_file_name = ntpath.basename(remote_path_sanitized)
101+
else:
102+
remote_file_name = os.path.basename(remote_path_sanitized)
103+
self.channel.settimeout(self.socket_timeout)
104+
self.channel.exec_command(self.scp_command + b" -f " + asbytes(remote_path_sanitized))
105+
self._recv_all(fl, remote_file_name)
106+
self.close()
107+
return fl
108+
109+
def close(self):
110+
"""close scp channel"""
111+
if self._channel is not None:
112+
self._channel.close()
113+
self._channel = None
114+
115+
def _send_file(self, fl, name, mode, size):
116+
basename = asbytes(os.path.basename(name))
117+
# The protocol can't handle \n in the filename.
118+
# Quote them as the control sequence \^J for now,
119+
# which is how openssh handles it.
120+
self.channel.sendall(
121+
("C%s %d " % (mode, size)).encode("ascii") + basename.replace(b"\n", b"\\^J") + b"\n"
122+
)
123+
self._recv_confirm()
124+
file_pos = 0
125+
buff_size = self.buff_size
126+
chan = self.channel
127+
while file_pos < size:
128+
chan.sendall(fl.read(buff_size))
129+
file_pos = fl.tell()
130+
chan.sendall(b"\x00")
131+
self._recv_confirm()
132+
133+
def _recv_confirm(self):
134+
# read scp response
135+
msg = b""
136+
try:
137+
msg = self.channel.recv(512)
138+
except SocketTimeoutError:
139+
raise SCPException("Timeout waiting for scp response")
140+
# slice off the first byte, so this compare will work in py2 and py3
141+
if msg and msg[0:1] == b"\x00":
142+
return
143+
elif msg and msg[0:1] == b"\x01":
144+
raise SCPException(asunicode(msg[1:]))
145+
elif self.channel.recv_stderr_ready():
146+
msg = self.channel.recv_stderr(512)
147+
raise SCPException(asunicode(msg))
148+
elif not msg:
149+
raise SCPException("No response from server")
150+
else:
151+
raise SCPException("Invalid response from server", msg)
152+
153+
def _recv_all(self, fh: IO, remote_file_name: str) -> None:
154+
# loop over scp commands, and receive as necessary
155+
commands = (b"C",)
156+
while not self.channel.closed:
157+
# wait for command as long as we're open
158+
self.channel.sendall(b"\x00")
159+
msg = self.channel.recv(1024)
160+
if not msg: # chan closed while receiving
161+
break
162+
assert msg[-1:] == b"\n"
163+
msg = msg[:-1]
164+
code = msg[0:1]
165+
if code not in commands:
166+
raise SCPException(asunicode(msg[1:]))
167+
self._recv_file(msg[1:], fh, remote_file_name)
168+
169+
def _recv_file(self, cmd: bytes, fh: IO, remote_file_name: str) -> None:
170+
chan = self.channel
171+
parts = cmd.strip().split(b" ", 2)
172+
173+
try:
174+
size = int(parts[1])
175+
except (ValueError, IndexError):
176+
chan.send(b"\x01")
177+
chan.close()
178+
raise SCPException("Bad file format")
179+
180+
buff_size = self.buff_size
181+
pos = 0
182+
chan.send(b"\x00")
183+
try:
184+
while pos < size:
185+
# we have to make sure we don't read the final byte
186+
if size - pos <= buff_size:
187+
buff_size = size - pos
188+
data = chan.recv(buff_size)
189+
if not data:
190+
raise SCPException("Underlying channel was closed")
191+
fh.write(data)
192+
pos = fh.tell()
193+
msg = chan.recv(512)
194+
if msg and msg[0:1] != b"\x00":
195+
raise SCPException(asunicode(msg[1:]))
196+
except SocketTimeoutError:
197+
chan.close()
198+
raise SCPException("Error receiving, socket.timeout")
199+
200+
201+
class SCPException(Exception):
202+
"""SCP exception class"""
203+
204+
pass

pyinfra/connectors/ssh.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from shutil import which
66
from socket import error as socket_error, gaierror
77
from time import sleep
8-
from typing import TYPE_CHECKING, Any, Iterable, Optional, Tuple
8+
from typing import IO, TYPE_CHECKING, Any, Iterable, Optional, Protocol, Tuple
99

1010
import click
1111
from paramiko import AuthenticationException, BadHostKeyException, SFTPClient, SSHException
@@ -17,6 +17,7 @@
1717
from pyinfra.api.util import get_file_io, memoize
1818

1919
from .base import BaseConnector, DataMeta
20+
from .scp import SCPClient
2021
from .ssh_util import get_private_key, raise_connect_error
2122
from .sshuserclient import SSHClient
2223
from .util import (
@@ -53,6 +54,7 @@ class ConnectorData(TypedDict):
5354
ssh_connect_retries: int
5455
ssh_connect_retry_min_delay: float
5556
ssh_connect_retry_max_delay: float
57+
ssh_file_transfer_protocol: str
5658

5759

5860
connector_data_meta: dict[str, DataMeta] = {
@@ -92,9 +94,27 @@ class ConnectorData(TypedDict):
9294
"Upper bound for random delay between retries",
9395
0.5,
9496
),
97+
"ssh_file_transfer_protocol": DataMeta(
98+
"Protocol to use for file transfers. Can be ``sftp`` or ``scp``.",
99+
"sftp",
100+
),
95101
}
96102

97103

104+
class FileTransferClient(Protocol):
105+
def getfo(self, remote_filename: str, fl: IO) -> Any | None:
106+
"""
107+
Get a file from the remote host, writing to the provided file-like object.
108+
"""
109+
...
110+
111+
def putfo(self, fl: IO, remote_filename: str) -> Any | None:
112+
"""
113+
Put a file to the remote host, reading from the provided file-like object.
114+
"""
115+
...
116+
117+
98118
class SSHConnector(BaseConnector):
99119
"""
100120
Connect to hosts over SSH. This is the default connector and all targets default
@@ -268,7 +288,7 @@ def _connect(self) -> None:
268288

269289
@override
270290
def disconnect(self) -> None:
271-
self.get_sftp_connection.cache.clear()
291+
self.get_file_transfer_connection.cache.clear()
272292

273293
@override
274294
def run_shell_command(
@@ -353,23 +373,35 @@ def execute_command() -> Tuple[int, CommandOutput]:
353373
return status, combined_output
354374

355375
@memoize
356-
def get_sftp_connection(self):
376+
def get_file_transfer_connection(self) -> FileTransferClient | None:
357377
assert self.client is not None
358378
transport = self.client.get_transport()
359379
assert transport is not None, "No transport"
360380
try:
361-
return SFTPClient.from_transport(transport)
381+
if self.data["ssh_file_transfer_protocol"] == "sftp":
382+
logger.debug("Using SFTP for file transfer")
383+
return SFTPClient.from_transport(transport)
384+
elif self.data["ssh_file_transfer_protocol"] == "scp":
385+
logger.debug("Using SCP for file transfer")
386+
return SCPClient(transport)
387+
else:
388+
raise ConnectError(
389+
"Unsupported file transfer protocol: {0}".format(
390+
self.data["ssh_file_transfer_protocol"],
391+
),
392+
)
362393
except SSHException as e:
394+
363395
raise ConnectError(
364396
(
365397
"Unable to establish SFTP connection. Check that the SFTP subsystem "
366398
"for the SSH service at {0} is enabled."
367399
).format(self.host),
368400
) from e
369401

370-
def _get_file(self, remote_filename: str, filename_or_io):
402+
def _get_file(self, remote_filename: str, filename_or_io: str | IO):
371403
with get_file_io(filename_or_io, "wb") as file_io:
372-
sftp = self.get_sftp_connection()
404+
sftp = self.get_file_transfer_connection()
373405
sftp.getfo(remote_filename, file_io)
374406

375407
@override
@@ -448,7 +480,7 @@ def _put_file(self, filename_or_io, remote_location):
448480
while attempts < 3:
449481
try:
450482
with get_file_io(filename_or_io) as file_io:
451-
sftp = self.get_sftp_connection()
483+
sftp = self.get_file_transfer_connection()
452484
sftp.putfo(file_io, remote_location)
453485
return
454486
except OSError as e:

0 commit comments

Comments
 (0)