Make SFTPFile.seek return the new position and .write return the number of written bytes #1696

Open: wants to merge 1 commit into base: master
36 changes: 34 additions & 2 deletions fsspec/implementations/sftp.py
@@ -13,6 +13,35 @@
 logger = logging.getLogger("fsspec.sftp")


+def _patch_SFTPFile(file):
Member

Hah, ugly :( (not a criticism of your work, just a shame)
Does this create reference cycles on self and the closure?

Other options might be:

  • patch the upstream ftp.open to return a subclass FsspecSShFile(SFTPFile, IOBase)
  • make our own class with the upstream file object as an attribute and pass all the methods through except for these two

Do these sound more palatable?
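
On the reference-cycle question: a minimal sketch, assuming CPython and a hypothetical `FakeFile` stand-in rather than paramiko's real `SFTPFile`. Storing a closure that captures the file as an attribute on that same file does form a cycle, so the object is reclaimed by the cyclic garbage collector rather than by reference counting alone.

```python
import gc
import weakref


class FakeFile:
    """Hypothetical stand-in for the upstream file class (not paramiko's)."""

    def seek(self, offset, whence=0):
        return None

    def tell(self):
        return 0


def patch(file):
    real_seek = file.seek  # bound method, keeps a reference to `file`

    def seek(offset, whence=0):
        result = real_seek(offset, whence)
        return file.tell() if result is None else result

    file.seek = seek  # file -> closure -> file: a reference cycle
    return file


gc.disable()  # stop the cyclic collector from running on its own
try:
    f = patch(FakeFile())
    ref = weakref.ref(f)
    del f
    # Reference counting alone cannot free the object while the cycle exists.
    alive_after_del = ref() is not None
    gc.collect()
    # The cyclic collector finds and frees the unreachable cycle.
    alive_after_gc = ref() is not None
finally:
    gc.enable()

print(alive_after_del, alive_after_gc)
```

The cycle is not a leak, but it delays cleanup until a collector pass, which may matter for objects that hold network resources.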

Contributor Author

Yeah, I kinda agree that it is not the most beautiful. I also considered a separate file object class, but I didn't want to copy 10+ methods with the danger of missing some. Subclassing also does not work because the object is created inside SFTPClient.open.

I have also opened a PR to fix this upstream. I guess it is fine to wait for that to be merged and simply close this PR and issue here. I might have been too impatient. Fixing this here has the advantage that it will work with older versions of paramiko just fine, though.
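
For reference, the pass-through class discussed above would not necessarily need every method copied by hand. A minimal sketch, assuming a hypothetical `PassThroughFile` wrapper and a `NoneReturningFile` stand-in (neither exists in fsspec or paramiko): `__getattr__` forwards everything that is not overridden. Special methods such as `__enter__` are looked up on the type, so they still need explicit forwarding.

```python
import io


class NoneReturningFile:
    """Hypothetical stand-in whose seek/write return None."""

    def __init__(self):
        self._buf = io.BytesIO()

    def seek(self, offset, whence=0):
        self._buf.seek(offset, whence)  # result deliberately discarded

    def write(self, data):
        self._buf.write(data)  # result deliberately discarded

    def tell(self):
        return self._buf.tell()

    def read(self, size=-1):
        return self._buf.read(size)

    def close(self):
        self._buf.close()


class PassThroughFile:
    """Hypothetical wrapper: overrides seek/write, forwards the rest."""

    def __init__(self, file):
        self._file = file

    def seek(self, offset, whence=0):
        result = self._file.seek(offset, whence)
        return self._file.tell() if result is None else result

    def write(self, data):
        old_offset = self._file.tell()
        result = self._file.write(data)
        return self._file.tell() - old_offset if result is None else result

    def __getattr__(self, name):
        # Only called when normal lookup fails, so seek/write above win.
        if name == "_file":  # avoid infinite recursion before __init__ ran
            raise AttributeError(name)
        return getattr(self._file, name)

    # Dunder methods bypass __getattr__, so forward them explicitly.
    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self._file.close()


with PassThroughFile(NoneReturningFile()) as wrapped:
    written = wrapped.write(b"hello")  # overridden method
    position = wrapped.seek(2)         # overridden method
    rest = wrapped.read()              # forwarded via __getattr__

print(written, position, rest)
```

Unlike monkey-patching the instance, the wrapper holds a one-way reference to the wrapped file, so it does not create a reference cycle.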

Member

Ah, I'm glad you made that PR. In that case let's wait - it doesn't appear that this part of the API gets touched much.

+    """
+    This patcher tries to rectify the file API of paramiko.sftp_file.SFTPFile
+    to be more consistent with io.IOBase.
+    https://github.com/paramiko/paramiko/issues/2452
+    """
+
+    if getattr(file, "_patched_by_fsspec", False):
+        return file
+    file._patched_by_fsspec = True
+
+    self = file
+
+    real_seek = self.seek
+    def seek(offset: int, whence: int = 0) -> int:
+        result = real_seek(offset, whence)
+        return self.tell() if result is None else result
+    self.seek = seek
+
+    real_write = self.write
+    def write(data) -> int:
+        old_offset = self.tell()
+        result = real_write(data)
+        return self.tell() - old_offset if result is None else result
+    self.write = write
+
+    return self
+
+
 class SFTPFileSystem(AbstractFileSystem):
     """Files over SFTP/SSH

@@ -141,6 +170,9 @@ def get_file(self, rpath, lpath, **kwargs):
         else:
             self.ftp.get(self._strip_protocol(rpath), lpath)

+    def _open_patched(self, path, mode="rb", **kwargs):
+        return _patch_SFTPFile(self.ftp.open(path, mode, **kwargs))
+
     def _open(self, path, mode="rb", block_size=None, **kwargs):
         """
         block_size: int or None
@@ -151,14 +183,14 @@ def _open(self, path, mode="rb", block_size=None, **kwargs):
         if kwargs.get("autocommit", True) is False:
             # writes to temporary file, move on commit
             path2 = "/".join([self.temppath, str(uuid.uuid4())])
-            f = self.ftp.open(path2, mode, bufsize=block_size if block_size else -1)
+            f = self._open_patched(path2, mode, bufsize=block_size if block_size else -1)
             f.temppath = path2
             f.targetpath = path
             f.fs = self
             f.commit = types.MethodType(commit_a_file, f)
             f.discard = types.MethodType(discard_a_file, f)
         else:
-            f = self.ftp.open(path, mode, bufsize=block_size if block_size else -1)
+            f = self._open_patched(path, mode, bufsize=block_size if block_size else -1)
         return f

     def _rm(self, path):
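
To illustrate the behavior the patch aims for, here is a small sketch that runs the same patching logic against a stand-in object and compares the results with `io.BytesIO`, which follows the `io.IOBase` conventions. `NoneReturningFile` is a hypothetical class that returns `None` from `seek` and `write`, as the patch's `result is None` checks imply for the upstream file object. It is not paramiko's `SFTPFile`, and no SFTP connection is involved.

```python
import io


class NoneReturningFile:
    """Hypothetical stand-in: seek/write return None."""

    def __init__(self):
        self._buf = io.BytesIO()

    def seek(self, offset, whence=0):
        self._buf.seek(offset, whence)  # result deliberately discarded

    def write(self, data):
        self._buf.write(data)  # result deliberately discarded

    def tell(self):
        return self._buf.tell()


def _patch_SFTPFile(file):
    # Same logic as the function added in the diff, without annotations.
    if getattr(file, "_patched_by_fsspec", False):
        return file
    file._patched_by_fsspec = True

    real_seek = file.seek

    def seek(offset, whence=0):
        result = real_seek(offset, whence)
        return file.tell() if result is None else result

    file.seek = seek

    real_write = file.write

    def write(data):
        old_offset = file.tell()
        result = real_write(data)
        return file.tell() - old_offset if result is None else result

    file.write = write
    return file


reference = io.BytesIO()
patched = _patch_SFTPFile(NoneReturningFile())

# Run the same three calls on both objects and collect the return values.
reference_results = (
    reference.write(b"hello"),
    reference.seek(2),
    reference.seek(0, io.SEEK_END),
)
patched_results = (
    patched.write(b"hello"),
    patched.seek(2),
    patched.seek(0, io.SEEK_END),
)

# Patching twice is a no-op thanks to the _patched_by_fsspec guard.
same_object = _patch_SFTPFile(patched) is patched

print(reference_results, patched_results, same_object)
```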