Skip to content

Commit

Permalink
add check for cloud path in pn_dir
Browse files Browse the repository at this point in the history
  • Loading branch information
briangow committed Feb 3, 2025
1 parent e3a83b6 commit 7569042
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 14 deletions.
1 change: 1 addition & 0 deletions wfdb/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
wfdbdesc,
wfdbtime,
SIGNAL_CLASSES,
CLOUD_PROTOCOLS,
)
from wfdb.io._signal import est_res, wr_dat_file
from wfdb.io.annotation import (
Expand Down
7 changes: 6 additions & 1 deletion wfdb/io/_coreio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from wfdb.io import _url
from wfdb.io.download import config


def _open_file(
pn_dir,
file_name,
Expand Down Expand Up @@ -59,6 +58,12 @@ def _open_file(
newline=newline,
)
else:
# check to make sure a cloud path isn't being passed under pn_dir
if any(pn_dir.startswith(proto) for proto in CLOUD_PROTOCOLS):
raise ValueError(
"Cloud paths should be passed under record_name, not under pn_dir"
)

url = posixpath.join(config.db_index_url, pn_dir, file_name)
return _url.openurl(
url,
Expand Down
13 changes: 12 additions & 1 deletion wfdb/io/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from wfdb.io import download, _coreio, util


MAX_I32 = 2147483647
MIN_I32 = -2147483648

Expand Down Expand Up @@ -1698,6 +1697,12 @@ def _rd_dat_file(file_name, dir_name, pn_dir, fmt, start_byte, n_samp):

# Stream dat file from PhysioNet
else:
# check to make sure a cloud path isn't being passed under pn_dir
if any(pn_dir.startswith(proto) for proto in CLOUD_PROTOCOLS):
raise ValueError(
"Cloud paths should be passed under record_name, not under pn_dir"
)

dtype_in = np.dtype(DATA_LOAD_TYPES[fmt])
sig_data = download._stream_dat(
file_name, pn_dir, byte_count, start_byte, dtype_in
Expand Down Expand Up @@ -2613,6 +2618,12 @@ def _infer_sig_len(

# If the PhysioNet database path is provided, construct the download path using the database version
elif pn_dir is not None:
# check to make sure a cloud path isn't being passed under pn_dir
if any(pn_dir.startswith(proto) for proto in CLOUD_PROTOCOLS):
raise ValueError(
"Cloud paths should be passed under record_name, not under pn_dir"
)

file_size = download._remote_file_size(
file_name=file_name, pn_dir=pn_dir
)
Expand Down
2 changes: 0 additions & 2 deletions wfdb/io/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from wfdb.io import _header
from wfdb.io import record
from wfdb.io import util
from wfdb.io.record import CLOUD_PROTOCOLS


class Annotation(object):
"""
Expand Down
38 changes: 28 additions & 10 deletions wfdb/io/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -1837,6 +1837,12 @@ def rdheader(record_name, pn_dir=None, rd_segments=False):

# If the PhysioNet database path is provided, construct the download path using the database version
elif pn_dir is not None:
# check to make sure a cloud path isn't being passed under pn_dir
if any(pn_dir.startswith(proto) for proto in CLOUD_PROTOCOLS):
raise ValueError(
"Cloud paths should be passed under record_name, not under pn_dir"
)

if "." not in pn_dir:
dir_list = pn_dir.split("/")
pn_dir = posixpath.join(
Expand Down Expand Up @@ -2032,11 +2038,17 @@ def rdrecord(
dir_name = os.path.abspath(dir_name)

# Read the header fields
if (pn_dir is not None) and ("." not in pn_dir):
dir_list = pn_dir.split("/")
pn_dir = posixpath.join(
dir_list[0], download.get_version(dir_list[0]), *dir_list[1:]
)
if pn_dir is not None:
# check to make sure a cloud path isn't being passed under pn_dir
if any(pn_dir.startswith(proto) for proto in CLOUD_PROTOCOLS):
raise ValueError(
"Cloud paths should be passed under record_name, not under pn_dir"
)
if "." not in pn_dir:
dir_list = pn_dir.split("/")
pn_dir = posixpath.join(
dir_list[0], download.get_version(dir_list[0]), *dir_list[1:]
)

record = rdheader(record_name, pn_dir=pn_dir, rd_segments=False)

Expand Down Expand Up @@ -2320,11 +2332,17 @@ def rdsamp(
channels=[1,3])
"""
if (pn_dir is not None) and ("." not in pn_dir):
dir_list = pn_dir.split("/")
pn_dir = posixpath.join(
dir_list[0], download.get_version(dir_list[0]), *dir_list[1:]
)
if pn_dir is not None:
# check to make sure a cloud path isn't being passed under pn_dir
if any(pn_dir.startswith(proto) for proto in CLOUD_PROTOCOLS):
raise ValueError(
"Cloud paths should be passed under record_name, not under pn_dir"
)
if "." not in pn_dir:
dir_list = pn_dir.split("/")
pn_dir = posixpath.join(
dir_list[0], download.get_version(dir_list[0]), *dir_list[1:]
)

record = rdrecord(
record_name=record_name,
Expand Down

0 comments on commit 7569042

Please sign in to comment.