79 lines
3.1 KiB
Python
79 lines
3.1 KiB
Python
import threading
|
|
from src.communication import ServiceExplorationModule, ClusterCommunicationModule
|
|
import torch
|
|
import time
|
|
|
|
class NodeManager:
    """Coordinate this machine's participation in a GPU cluster.

    On construction it inspects the local GPUs, then creates a
    ``ClusterCommunicationModule`` on ``(host, port)`` and a
    ``ServiceExplorationModule`` on ``(host, port + 1)``, passing itself to
    both. ``actions`` holds the user-facing menu entries: each is a dict with
    an ``explanation`` text and the name of the method (``function``) on this
    object to call.
    """

    def __init__(self, host, port):
        """Inspect local GPUs and create the cluster/exploration modules.

        Args:
            host: address both modules are created with.
            port: port for the cluster communication module; the service
                exploration module is created with ``port + 1``.

        Raises:
            AssertionError: if no GPU is visible or the GPUs are not all the
                same model (raised by ``get_GPU_info``).
        """
        self.status = 'none'
        # Menu entries; 'function' names a method on this object. The 'Exit'
        # entry is kept last: _register_cluster_info_action relies on that.
        self.actions = [
            {'explanation': 'Add another node into our cluster', 'function': 'add_node'},
            {'explanation': 'Exit', 'function': 'exit'},
        ]

        self.get_GPU_info()
        print(f"You have {self.GPU} * {self.GPU_num}")

        # Create the Cluster Communication Module so that the nodes in the
        # cluster can communicate with each other.
        self.cluster_communication_module = ClusterCommunicationModule(host, port, self)

        # Create the Service Exploration Module so that clients can find out
        # which IP addresses offer our service and connect to them.
        self.service_exploration_module = ServiceExplorationModule(host, port + 1, self)

        # NOTE(review): fixed 2 s pause after creating the modules; presumably
        # gives them time to initialise -- confirm whether it is still needed.
        time.sleep(2)

    def get_GPU_info(self):
        """Record the local GPU model in ``self.GPU`` and count in ``self.GPU_num``.

        Raises:
            AssertionError: if no CUDA device is visible, or if the visible
                devices are not all the same model.
        """
        self.GPU_num = torch.cuda.device_count()
        # Raise explicitly instead of using ``assert`` so the check still runs
        # under ``python -O``; AssertionError is kept for existing callers.
        if self.GPU_num <= 0:
            raise AssertionError("Your computer doesn't have GPU resource")

        self.GPU = torch.cuda.get_device_name(0)
        # Device 0 defines the reference model; every other device must match.
        for index in range(1, self.GPU_num):
            if torch.cuda.get_device_name(index) != self.GPU:
                raise AssertionError("Please provide same type of GPUs.")

    def start_service(self):
        """Run both modules' ``listen`` methods on background daemon threads.

        The cluster communication listener is started first, then the service
        exploration listener. Daemon threads do not keep the interpreter alive
        on exit; the thread objects are not retained.
        """
        for listener in (self.cluster_communication_module.listen,
                         self.service_exploration_module.listen):
            threading.Thread(target=listener, daemon=True).start()

    def add_node(self):
        """Let the user pick a discovered node and request that it join.

        Hosts returned by the service exploration module are listed with
        1-based numbers. If the chosen node accepts, a 'cluster info' action
        is added to the menu (only once), just before the trailing 'Exit'
        action. An invalid selection or a failed request prints
        ``=== FAIL ===``.
        """
        hosts = self.service_exploration_module.explore()
        if not hosts:
            print("No other nodes in your subnet.")
            return

        msg = "These are the nodes you can request for join into our cluster: \n"
        msg += '\n'.join(f'{index + 1}) {host}' for index, host in enumerate(hosts))
        msg += '\n> '

        # Validate the selection up front: a plain int(...) - 1 would turn
        # "0" or negative input into a negative index and pick a host from
        # the end of the list.
        choice = self._parse_choice(input(msg), len(hosts))
        if choice is None:
            print("=== FAIL ===")
            return

        try:
            accept = self.cluster_communication_module.request(hosts[choice])
        except Exception:
            # Report request failures to the user instead of aborting the
            # interactive menu; KeyboardInterrupt/SystemExit still propagate.
            print("=== FAIL ===")
            return

        if accept:
            self._register_cluster_info_action()

    @staticmethod
    def _parse_choice(raw, count):
        # Convert a 1-based menu selection string into a 0-based index, or None if invalid.
        try:
            number = int(raw)
        except ValueError:
            return None
        if 1 <= number <= count:
            return number - 1
        return None

    def _register_cluster_info_action(self):
        # Add the 'cluster info' menu entry once, keeping the 'Exit' entry last.
        info_func = {'explanation': 'cluster info', 'function': 'cluster_info'}
        if info_func not in self.actions:
            self.actions.insert(len(self.actions) - 1, info_func)

    def cluster_info(self):
        """Print the local node and each remote node with its GPU model and count."""
        # The local node is printed separately and counted with +1, so `info`
        # presumably lists only the remote nodes -- confirm against
        # ClusterCommunicationModule.cluster_info.
        info = self.cluster_communication_module.cluster_info()
        print(f"\nThere are {len(info)+1} nodes in this cluster.")
        print("Cluster Info:")
        print(f" {self.service_exploration_module.IP}(local) -> {self.GPU} * {self.GPU_num}")
        for host in info:
            print(f" {host['host']} -> {host['GPU']} * {host['GPU_num']}")

    def exit(self):
        """Call ``exit`` on the cluster module, then on the exploration module."""
        self.cluster_communication_module.exit()
        self.service_exploration_module.exit()
|
|
|
|