-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
bump version, merge pull request #5 from AMYPAD/cuda
- Loading branch information
Showing
9 changed files
with
218 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
"""CUDA helpers | ||
Usage: | ||
miutil.cuinfo [options] | ||
Options: | ||
-n, --num-devices : print number of devices (ignores `-d`) | ||
-f, --nvcc-flags : print out flags for use nvcc compilation | ||
-d ID, --dev-id ID : select device ID [default: None:int] for all | ||
""" | ||
from argopt import argopt | ||
import pynvml | ||
|
||
__all__ = ["num_devices", "compute_capability", "memory", "name", "nvcc_flags"] | ||
|
||
|
||
def nvmlDeviceGetCudaComputeCapability(handle): | ||
major = pynvml.c_int() | ||
minor = pynvml.c_int() | ||
fn = pynvml.get_func_pointer("nvmlDeviceGetCudaComputeCapability") | ||
ret = fn(handle, pynvml.byref(major), pynvml.byref(minor)) | ||
pynvml.check_return(ret) | ||
return [major.value, minor.value] | ||
|
||
|
||
def num_devices(): | ||
"""returns total number of devices""" | ||
pynvml.nvmlInit() | ||
return pynvml.nvmlDeviceGetCount() | ||
|
||
|
||
def get_handle(dev_id=-1): | ||
"""allows negative indexing""" | ||
pynvml.nvmlInit() | ||
dev_id = num_devices() + dev_id if dev_id < 0 else dev_id | ||
try: | ||
return pynvml.nvmlDeviceGetHandleByIndex(dev_id) | ||
except pynvml.NVMLError: | ||
raise IndexError("invalid dev_id") | ||
|
||
|
||
def compute_capability(dev_id=-1): | ||
"""returns compute capability (major, minor)""" | ||
return tuple(nvmlDeviceGetCudaComputeCapability(get_handle(dev_id))) | ||
|
||
|
||
def memory(dev_id=-1): | ||
"""returns memory (total, free, used)""" | ||
mem = pynvml.nvmlDeviceGetMemoryInfo(get_handle(dev_id)) | ||
return (mem.total, mem.free, mem.used) | ||
|
||
|
||
def name(dev_id=-1): | ||
"""returns device name""" | ||
return pynvml.nvmlDeviceGetName(get_handle(dev_id)).decode("U8") | ||
|
||
|
||
def nvcc_flags(dev_id=-1): | ||
return "-gencode=arch=compute_{0:d}{1:d},code=compute_{0:d}{1:d}".format( | ||
*compute_capability(dev_id) | ||
) | ||
|
||
|
||
def main(*args, **kwargs): | ||
args = argopt(__doc__).parse_args(*args, **kwargs) | ||
noargs = True | ||
devices = range(num_devices()) if args.dev_id is None else [args.dev_id] | ||
|
||
if args.num_devices: | ||
print(num_devices()) | ||
noargs = False | ||
if args.nvcc_flags: | ||
print(" ".join(sorted(set(map(nvcc_flags, devices)))[::-1])) | ||
noargs = False | ||
if noargs: | ||
for dev_id in devices: | ||
print( | ||
"Device {:2d}:{}:compute capability:{:d}.{:d}".format( | ||
dev_id, name(dev_id), *compute_capability(dev_id) | ||
) | ||
) | ||
|
||
|
||
if __name__ == "__main__": # pragma: no cover | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from pytest import importorskip, raises | ||
|
||
cuinfo = importorskip("miutil.cuinfo") | ||
|
||
|
||
def test_num_devices(): | ||
devices = cuinfo.num_devices() | ||
assert isinstance(devices, int) | ||
|
||
|
||
def test_compute_capability(): | ||
cc = cuinfo.compute_capability() | ||
if cc: | ||
assert len(cc) == 2, cc | ||
assert all(isinstance(i, int) for i in cc), cc | ||
|
||
|
||
def test_memory(): | ||
mem = cuinfo.memory() | ||
if mem: | ||
assert len(mem) == 3, mem | ||
assert all(isinstance(i, int) for i in mem), mem | ||
|
||
|
||
def test_cuinfo_cli(capsys): | ||
cuinfo.main(["--num-devices"]) | ||
out, _ = capsys.readouterr() | ||
devices = int(out) | ||
assert devices >= 0 | ||
|
||
# individual dev_id | ||
for dev_id in range(devices): | ||
cuinfo.main(["--dev-id", str(dev_id)]) | ||
out, _ = capsys.readouterr() | ||
assert len(out.split("Device ")) == devices + 1 | ||
|
||
# all dev_ids | ||
cuinfo.main() | ||
out, _ = capsys.readouterr() | ||
assert len(out.split("Device ")) == devices + 1 | ||
|
||
# dev_id one too much | ||
with raises(IndexError): | ||
cuinfo.main(["--dev-id", str(devices)]) | ||
|
||
cuinfo.main(["--nvcc-flags"]) | ||
out, _ = capsys.readouterr() | ||
assert not devices or out.startswith("-gencode=") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from os import path | ||
|
||
from miutil import fdio | ||
|
||
|
||
def test_create_dir(tmp_path): | ||
tmpdir = str(tmp_path / "create_dir") | ||
assert not path.exists(tmpdir) | ||
fdio.create_dir(tmpdir) | ||
assert path.exists(tmpdir) | ||
|
||
|
||
def test_hasext(): | ||
for fname, ext in [ | ||
("foo.bar", ".bar"), | ||
("foo.bar", "bar"), | ||
("foo.bar.baz", "baz"), | ||
("foo/bar.baz", "baz"), | ||
]: | ||
assert fdio.hasext(fname, ext) | ||
|
||
for fname, ext in [ | ||
("foo.bar", "baz"), | ||
("foo.bar.baz", "bar.baz"), | ||
("foo", "foo"), | ||
]: | ||
assert not fdio.hasext(fname, ext) | ||
|
||
|
||
def test_tmpdir(): | ||
with fdio.tmpdir() as tmpdir: | ||
assert path.exists(tmpdir) | ||
res = tmpdir | ||
assert not path.exists(res) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters