Skip to content

Commit

Permalink
It is now possible to upload()/download() with non-seekable file-like…
Browse files Browse the repository at this point in the history
… objects
  • Loading branch information
ivknv committed Oct 15, 2023
1 parent cefb96e commit 71b37d2
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 16 deletions.
28 changes: 28 additions & 0 deletions tests/yadisk_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,31 @@ def test_ensure_path_has_schema(self):
self.assertEqual(ensure_path_has_schema("/asd:123", "trash"), "trash:/asd:123")
self.assertEqual(ensure_path_has_schema("example/path"), "disk:/example/path")
self.assertEqual(ensure_path_has_schema("app:/test"), "app:/test")

@async_test
async def test_upload_download_non_seekable(self):
# It should be possible to upload/download non-seekable file objects (such as stdin/stdout)
# See https://github.com/ivknv/yadisk/pull/31 for more details

test_input_file = BytesIO(b"0" * 1000)
test_input_file.seekable = lambda: False

def seek(*args, **kwargs):
raise NotImplementedError

test_input_file.seek = seek

dst_path = posixpath.join(self.path, "zeroes.txt")

await self.yadisk.upload(test_input_file, dst_path, n_retries=50)

test_output_file = BytesIO()
test_output_file.seekable = lambda: False
test_output_file.seek = seek

await self.yadisk.download(dst_path, test_output_file, n_retries=50)

await self.yadisk.remove(dst_path, permanently=True)

self.assertEqual(test_input_file.tell(), 1000)
self.assertEqual(test_output_file.tell(), 1000)
62 changes: 46 additions & 16 deletions yadisk_async/yadisk.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,28 @@ def is_async_file(file: Any) -> bool:

return is_async_func(read_method)

async def _file_tell(file: Any) -> int:
if is_async_func(file.tell):
return await file.tell()
else:
return file.tell()

async def _file_seek(file: Any, offset: int, whence: int = 0) -> int:
if is_async_func(file.seek):
return await file.seek(offset, whence)
else:
return file.seek(offset, whence)

async def _is_file_seekable(file: Any) -> bool:
if not hasattr(file, "seekable"):
# Assume the file is seekable if there's no way to check
return True

if is_async_func(file.seekable):
return await file.seekable()

return file.seekable();

def _apply_default_args(args: Dict[str, Any], default_args: Dict[str, Any]) -> None:
new_args = dict(default_args)
new_args.update(args)
Expand Down Expand Up @@ -682,11 +704,17 @@ async def _upload(self,
if n_retries is None:
n_retries = settings.DEFAULT_N_RETRIES

# Number of retries for getting the upload link.
# It is set to 0, unless the file is not seekable, in which case
# we have to use a different retry scheme
n_retries_for_upload_link = 0

kwargs["timeout"] = timeout

file = None
close_file = False
generator_factory: Optional[Callable[[], AsyncGenerator]] = None
file_position = 0

session = self.get_session()

Expand All @@ -701,14 +729,14 @@ async def _upload(self,
file = file_or_path

if generator_factory is None:
if is_async_func(file.tell):
file_position = await file.tell()
if await _is_file_seekable(file):
file_position = await _file_tell(file)
else:
file_position = file.tell()
n_retries, n_retries_for_upload_link = 0, n_retries

async def attempt():
temp_kwargs = dict(kwargs)
temp_kwargs["n_retries"] = 0
temp_kwargs["n_retries"] = n_retries_for_upload_link
temp_kwargs["retry_interval"] = 0.0

link = await get_upload_link_function(dst_path, **temp_kwargs)
Expand All @@ -725,10 +753,8 @@ async def attempt():
data = None

if generator_factory is None:
if is_async_func(file.seek):
await file.seek(file_position)
else:
file.seek(file_position)
if await _is_file_seekable(file):
await _file_seek(file, file_position)

if is_async_func(file.read):
data = read_in_chunks(file)
Expand Down Expand Up @@ -838,6 +864,11 @@ async def _download(self,
if n_retries is None:
n_retries = settings.DEFAULT_N_RETRIES

# Number of retries for getting the download link.
# It is set to 0, unless the file is not seekable, in which case
# we have to use a different retry scheme
n_retries_for_download_link = 0

retry_interval = kwargs.get("retry_interval")

if retry_interval is None:
Expand All @@ -852,6 +883,7 @@ async def _download(self,

file = None
close_file = False
file_position = 0

session = self.get_session()

Expand All @@ -863,14 +895,14 @@ async def _download(self,
close_file = False
file = file_or_path

if is_async_func(file.tell):
file_position = await file.tell()
if await _is_file_seekable(file):
file_position = await _file_tell(file)
else:
file_position = file.tell()
n_retries, n_retries_for_download_link = 0, n_retries

async def attempt() -> None:
temp_kwargs = dict(kwargs)
temp_kwargs["n_retries"] = 0
temp_kwargs["n_retries"] = n_retries_for_download_link
temp_kwargs["retry_interval"] = 0.0
link = await get_download_link_function(src_path, **temp_kwargs)

Expand All @@ -883,10 +915,8 @@ async def attempt() -> None:
except KeyError:
temp_kwargs["headers"] = {"Connection": "close"}

if is_async_func(file.seek):
await file.seek(file_position)
else:
file.seek(file_position)
if await _is_file_seekable(file):
await _file_seek(file, file_position)

async with session.get(link, **temp_kwargs) as response:
async for chunk in response.content.iter_chunked(8192):
Expand Down

0 comments on commit 71b37d2

Please sign in to comment.