feat: get GPU name in the node
This commit is contained in:
parent
68123809bb
commit
847b6f8814
@ -1,8 +1,8 @@
|
||||
import threading
|
||||
from src.communication import ServiceExplorationModule, ClusterCommunicationModule
|
||||
import torch
|
||||
import time
|
||||
|
||||
|
||||
class NodeManager():
|
||||
def __init__(self, host, port):
|
||||
self.status = 'none'
|
||||
@ -10,8 +10,8 @@ class NodeManager():
|
||||
{'explanation': 'Add another node into our cluster', 'function': 'add_node'},
|
||||
{'explanation': 'Exit', 'function': 'exit'},
|
||||
]
|
||||
self.GPU = 'RTX 4090'
|
||||
self.GPU_num = 1
|
||||
self.get_GPU_info()
|
||||
print(f"You have {self.GPU} * {self.GPU_num}")
|
||||
|
||||
# start Cluster Communication Module
|
||||
# let the nodes in the cluster can communicate
|
||||
@ -22,6 +22,14 @@ class NodeManager():
|
||||
self.service_exploration_module = ServiceExplorationModule(host, port+1, self)
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
def get_GPU_info(self):
|
||||
self.GPU_num = torch.cuda.device_count()
|
||||
assert self.GPU_num > 0, "Your computer doesn't have GPU resource"
|
||||
|
||||
self.GPU = torch.cuda.get_device_name(0)
|
||||
for i in range(self.GPU_num):
|
||||
assert torch.cuda.get_device_name(i) == self.GPU, "Please provide same type of GPUs."
|
||||
|
||||
def start_service(self):
|
||||
communication_thread = threading.Thread(target=self.cluster_communication_module.listen)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user