Files
TexTeller/texteller/utils/device.py

42 lines
1.0 KiB
Python
Raw Normal View History

2025-04-16 14:23:02 +00:00
from typing import Literal, Optional

import torch
def str2device(device_str: Literal["cpu", "cuda", "mps"]) -> torch.device:
    """Convert a device name string into a ``torch.device``.

    Args:
        device_str: One of ``"cpu"``, ``"cuda"`` or ``"mps"``. Any other
            value (including indexed forms such as ``"cuda:0"``) is rejected.

    Returns:
        torch.device: A device of the requested type, with no index attached.

    Raises:
        ValueError: If ``device_str`` is not one of the supported names.
    """
    for supported in ("cpu", "cuda", "mps"):
        if device_str == supported:
            return torch.device(supported)
    raise ValueError(f"Invalid device: {device_str}")
def get_device(device_index: Optional[int] = None) -> torch.device:
    """
    Automatically detect the best available device for inference.

    Preference order is CUDA, then Apple MPS, then CPU.

    Args:
        device_index: The index of the CUDA device to use if multiple are
            available. Defaults to None, which returns an un-indexed
            ``cuda`` device (i.e. PyTorch's current CUDA device). Ignored
            when CUDA is unavailable and MPS or CPU is selected instead.

    Returns:
        torch.device: Selected device for model inference.

    Raises:
        ValueError: If CUDA is selected, ``device_index`` is given, and it is
            outside ``[0, torch.cuda.device_count())``.
    """
    if cuda_available():
        if device_index is None:
            return str2device("cuda")
        # Validate eagerly so a bad index fails here with a clear message
        # instead of later, when a model or tensor is moved to the device.
        device_count = torch.cuda.device_count()
        if not 0 <= device_index < device_count:
            raise ValueError(
                f"Invalid CUDA device index {device_index}; "
                f"{device_count} device(s) available"
            )
        return torch.device("cuda", device_index)
    elif mps_available():
        return str2device("mps")
    else:
        return str2device("cpu")
def cuda_available() -> bool:
    """Report whether PyTorch can use a CUDA device in this process."""
    has_cuda = torch.cuda.is_available()
    return has_cuda
def mps_available() -> bool:
    """Report whether PyTorch's Apple MPS (Metal) backend is usable."""
    has_mps = torch.backends.mps.is_available()
    return has_mps