Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AMD Support #173

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions gpustat/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,15 @@
from blessed import Terminal

from gpustat import util
from gpustat import nvml
from gpustat.nvml import pynvml as N
from gpustat.nvml import check_driver_nvml_version

if util.has_AMD():
from gpustat import rocml as nvml
from gpustat import rocml as N
from gpustat.rocml import check_driver_nvml_version
else:
from gpustat import nvml
from gpustat.nvml import pynvml as N
from gpustat.nvml import check_driver_nvml_version

NOT_SUPPORTED = 'Not Supported'
MB = 1024 * 1024
Expand Down Expand Up @@ -612,6 +618,8 @@ def _wrapped(*args, **kwargs):
gpu_stat = InvalidGPU(index, "((Unknown Error))", e)
except N.NVMLError_GpuIsLost as e:
gpu_stat = InvalidGPU(index, "((GPU is lost))", e)
except Exception as e:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we raise the N.NVMLError_Unknown Error for consistency?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ps: we can catch NVMLError instead of Base Exception, since you may ignore some python native errors

gpu_stat = InvalidGPU(index, "((Unknown Error))", e)

if isinstance(gpu_stat, InvalidGPU):
log.add_exception("GPU %d" % index, gpu_stat.exception)
Expand Down
181 changes: 181 additions & 0 deletions gpustat/rocml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
"""Imports rocmi and wraps it in a pynvml compatible interface."""

import sys
import textwrap
import warnings

from collections import namedtuple

try:
# Check for rocmi.
import rocmi
except (ImportError, SyntaxError, RuntimeError) as e:
_rocmi = sys.modules.get("rocmi", None)

raise ImportError(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make this a dedicated NVMLError subclass?

textwrap.dedent(
"""\
rocmi is missing or an outdated version is installed.

The root cause: """
+ str(e)
+ """

Your rocmi installation: """
+ repr(_rocmi)
+ """

-----------------------------------------------------------
(Suggested Fix) Please install rocmi using pip.
"""
)
) from e

NVML_TEMPERATURE_GPU = 1


class NVMLError(Exception):
def __init__(self, message="ROCM Error"):
self.message = message
super().__init__(self.message)


class NVMLError_Unknown(Exception):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these NVMLError_xxx inherit NVMLError?

def __init__(self, message="An unknown ROCM Error has occurred"):
self.message = message
super().__init__(self.message)


class NVMLError_GpuIsLost(Exception):
def __init__(self, message="ROCM Device is lost."):
self.message = message
super().__init__(self.message)


def nvmlDeviceGetCount():
return len(rocmi.get_devices())


def nvmlDeviceGetHandleByIndex(dev):
return rocmi.get_devices()[dev]


def nvmlDeviceGetIndex(handle):
for i, d in enumerate(rocmi.get_devices()):
if d.bus_id == handle.bus_id:
return i

return -1


def nvmlDeviceGetName(handle):
return handle.name


def nvmlDeviceGetUUID(handle):
return handle.unique_id


def nvmlDeviceGetTemperature(handle, loc=NVML_TEMPERATURE_GPU):
metrics = handle.get_metrics()
return metrics.temperature_hotspot


def nvmlSystemGetDriverVersion():
retval = rocmi.get_driver_version()
if retval is None:
return ""
return retval


def check_driver_nvml_version(driver_version_str: str):
"""Show warnings when an incompatible driver is used."""

def safeint(v) -> int:
try:
return int(v)
except (ValueError, TypeError):
return 0

driver_version = tuple(safeint(v) for v in driver_version_str.strip().split("."))

if len(driver_version) == 0 or driver_version <= (0,):
return
if driver_version < (6, 7, 8):
warnings.warn(f"This version of ROCM Driver {driver_version_str} is untested, ")


def nvmlDeviceGetFanSpeed(handle):
try:
speed = handle.get_metrics().current_fan_speed
except AttributeError:
return None

return speed


MemoryInfo = namedtuple("MemoryInfo", ["total", "used"])


def nvmlDeviceGetMemoryInfo(handle):

return MemoryInfo(
total=handle.vram_total,
used=handle.vram_used,
)


UtilizationRates = namedtuple("UtilizationRates", ["gpu"])


def nvmlDeviceGetUtilizationRates(handle):
metrics = handle.get_metrics()
return UtilizationRates(gpu=metrics.average_gfx_activity)


def nvmlDeviceGetEncoderUtilization(dev):
return None


def nvmlDeviceGetDecoderUtilization(dev):
return None


def nvmlDeviceGetPowerUsage(handle):
return handle.current_power / 1000


def nvmlDeviceGetEnforcedPowerLimit(handle):
return handle.power_limit / 1000


ComputeProcess = namedtuple("ComputeProcess", ["pid", "usedGpuMemory"])


def nvmlDeviceGetComputeRunningProcesses(handle):
results = handle.get_processes()
return [ComputeProcess(pid=x.pid, usedGpuMemory=x.vram_usage) for x in results]


def nvmlDeviceGetGraphicsRunningProcesses(dev):
return None


def nvmlDeviceGetClockInfo(handle):
metrics = handle.get_metrics()

try:
clk = metrics.current_gfxclks[0]
except AttributeError:
clk = metrics.current_gfxclk

return clk


def nvmlDeviceGetMaxClockInfo(handle):
return handle.get_clock_info()[-1]


# rocmi does not require initialization
def ensure_initialized():
pass
9 changes: 9 additions & 0 deletions gpustat/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import collections
import os.path
import subprocess
import sys
import traceback
from typing import Callable, Tuple, Type, TypeVar, Union
Expand Down Expand Up @@ -101,3 +102,11 @@ def report_summary(self, concise=True):
self._write("{msg} -> Total {value} occurrences.".format(
msg=msg, value=value))
self._write('')


def has_AMD():
try:
subprocess.check_output('rocm-smi')
Stonesjtu marked this conversation as resolved.
Show resolved Hide resolved
return True
except Exception:
return False
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def run(self):

install_requires = [
'nvidia-ml-py>=12.535.108', # see #107, #143, #161
'rocmi>=0.3', # see #137
Stonesjtu marked this conversation as resolved.
Show resolved Hide resolved
'psutil>=5.6.0', # GH-1447
'blessed>=1.17.1', # GH-126
'typing_extensions',
Expand Down
Loading