diff --git a/python/kvikio/kvikio/remote_file.py b/python/kvikio/kvikio/remote_file.py index 2fa0cb2c92..90a8b81fb4 100644 --- a/python/kvikio/kvikio/remote_file.py +++ b/python/kvikio/kvikio/remote_file.py @@ -33,17 +33,39 @@ def _get_remote_module(): class RemoteFile: - """File handle of a remote file (currently, only AWS S3 is supported).""" + """File handle of a remote file.""" - def __init__(self, url: str, nbytes: Optional[int] = None): - """Open a remote file given a bucket and object name. + def __init__(self, handle): + """Create a remote file from a Cython handle. + + This constructor should not be called directly instead use a + factory method like `RemoteFile.from_http_url()` + + Parameters + ---------- + handle : kvikio._lib.remote_handle.RemoteFile + The Cython handle + """ + assert isinstance(handle, _get_remote_module().RemoteFile) + self._handle = handle + + @classmethod + def from_http_url( + cls, + url: str, + nbytes: Optional[int] = None, + ) -> RemoteFile: + """Open a http file. Parameters ---------- url URL to the remote file. + nbytes + The size of the file. If None, KvikIO will ask the server + for the file size. """ - self._handle = _get_remote_module().RemoteFile.from_url(url, nbytes) + return RemoteFile(_get_remote_module().RemoteFile.from_url(url, nbytes)) def __enter__(self) -> RemoteFile: return self diff --git a/python/kvikio/tests/test_http_io.py b/python/kvikio/tests/test_http_io.py index 9b510d042e..6bfb34c93a 100644 --- a/python/kvikio/tests/test_http_io.py +++ b/python/kvikio/tests/test_http_io.py @@ -51,7 +51,7 @@ def http_server(request, tmpdir): def test_file_size(http_server, tmpdir): a = np.arange(100) a.tofile(tmpdir / "a") - with kvikio.RemoteFile(f"{http_server}/a") as f: + with kvikio.RemoteFile.from_http_url(f"{http_server}/a") as f: assert f.nbytes() == a.nbytes @@ -64,7 +64,7 @@ def test_read(http_server, tmpdir, xp, size, nthreads, tasksize): with kvikio.defaults.set_num_threads(nthreads): with kvikio.defaults.set_task_size(tasksize): - with kvikio.RemoteFile(f"{http_server}/a") as f: + with kvikio.RemoteFile.from_http_url(f"{http_server}/a") as f: assert f.nbytes() == a.nbytes b = xp.empty_like(a) assert f.read(b) == a.nbytes @@ -77,7 +77,7 @@ def test_large_read(http_server, tmpdir, xp, nthreads): a.tofile(tmpdir / "a") with kvikio.defaults.set_num_threads(nthreads): - with kvikio.RemoteFile(f"{http_server}/a") as f: + with kvikio.RemoteFile.from_http_url(f"{http_server}/a") as f: assert f.nbytes() == a.nbytes b = xp.empty_like(a) assert f.read(b) == a.nbytes @@ -88,7 +88,7 @@ def test_error_too_small_file(http_server, tmpdir, xp): a = xp.arange(10, dtype="uint8") b = xp.empty(100, dtype="uint8") a.tofile(tmpdir / "a") - with kvikio.RemoteFile(f"{http_server}/a") as f: + with kvikio.RemoteFile.from_http_url(f"{http_server}/a") as f: assert f.nbytes() == a.nbytes with pytest.raises( ValueError, match=r"cannot read 0\+100 bytes into a 10 bytes file" @@ -105,7 +105,7 @@ def test_no_range_support(http_server, tmpdir, xp): a = xp.arange(100, dtype="uint8") a.tofile(tmpdir / "a") b = xp.empty_like(a) - with kvikio.RemoteFile(f"{http_server}/a") as f: + with kvikio.RemoteFile.from_http_url(f"{http_server}/a") as f: assert f.nbytes() == a.nbytes with pytest.raises( OverflowError, match="maybe the server doesn't support file ranges?"