import threading
import time

import docker
import torch
from web3 import Web3

from src.communication import ServiceExplorationModule, ClusterCommunicationModule
from src.scheduler import Scheduler
from constant import *
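# NOTE: the `constant` module is assumed (based on the usage below) to provide at least
# WEB3_PROVIDER_URL, WEB3_PROVIDER_KEY, SCHEDULER_ADDR, SCHEDULER_ABI_FILE,
# WALLET_KEY, and the GPU_NAME2ID mapping.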


class NodeManager:

    def __init__(self, host, port):
        self.status = 'none'
        self.actions = [
            {'explanation': 'Add another node into our cluster', 'function': 'add_node'},
            {'explanation': 'Start waiting for the new task', 'function': 'start_listen_task'},
            {'explanation': 'Exit', 'function': 'exit'},
        ]
        self.get_GPU_info()
        print(f"You have {self.GPU_num} x {self.GPU}")

        # Cluster Communication Module:
        # lets the nodes in the cluster communicate with each other.
        self.cluster_communication_module = ClusterCommunicationModule(host, port, self)

        # Service Exploration Module:
        # lets clients discover which IP address hosts our service so they can connect to it.
        self.service_exploration_module = ServiceExplorationModule(host, port + 1, self)

        # Docker client
        self.docker_client = docker.from_env()

        # Web3 provider: only the master node initializes these (see start_listen_task).
        self.w3 = None
        self.scheduler = None
        self.wallet = None

        # Brief pause to let the communication modules finish initializing.
        time.sleep(2)
    def get_GPU_info(self):
        # Collect the number and model of local CUDA GPUs.
        self.GPU_num = torch.cuda.device_count()
        assert self.GPU_num > 0, "No CUDA GPU was found on this machine."

        self.GPU = torch.cuda.get_device_name(0)
        for i in range(self.GPU_num):
            assert torch.cuda.get_device_name(i) == self.GPU, "Please provide GPUs of the same type."
    def start_service(self):
        # Run both communication modules in daemon threads so they do not block exit.
        communication_thread = threading.Thread(target=self.cluster_communication_module.listen)
        communication_thread.daemon = True
        communication_thread.start()

        explore_service_thread = threading.Thread(target=self.service_exploration_module.listen)
        explore_service_thread.daemon = True
        explore_service_thread.start()
    def add_node(self):
        hosts = self.service_exploration_module.explore()
        if len(hosts) != 0:
            msg = "These are the nodes you can request to join our cluster:\n"
            msg += '\n'.join([f'{index+1}) {host}' for index, host in enumerate(hosts)])
            msg += '\n> '

            choose = input(msg)
            try:
                choose = int(choose) - 1
                accept = self.cluster_communication_module.request(hosts[choose])
                if accept:
                    self.actions = [
                        {'explanation': 'Add another node into our cluster', 'function': 'add_node'},
                        {'explanation': 'Cluster info', 'function': 'cluster_info'},
                        {'explanation': 'Start waiting for the new task', 'function': 'start_listen_task'},
                        {'explanation': 'Exit', 'function': 'exit'},
                    ]
            except Exception:
                print("=== FAIL ===")
        else:
            print("No other nodes in your subnet.")
    def start_listen_task(self):
        self.w3 = Web3(Web3.HTTPProvider(WEB3_PROVIDER_URL + WEB3_PROVIDER_KEY))
        self.scheduler = Scheduler(self.w3, SCHEDULER_ADDR, SCHEDULER_ABI_FILE)
        self.wallet = self.w3.eth.account.from_key(WALLET_KEY)
        print(f"Using {WEB3_PROVIDER_URL + WEB3_PROVIDER_KEY} as the web3 provider.")
        print(f"Loaded your wallet from the configured private key (address={self.wallet.address})")
        print()
        if self.w3.is_connected():
            # NOTE: the on-chain task flow below is temporarily disabled; a hard-coded
            # test image is used instead (see data_image = "test/test" further down).
            '''
            print("[INFO] Connected successfully.")
            print()

            # Register the cluster
            gpu_num = self.cluster_info()
            gpu_id = GPU_NAME2ID[self.GPU]
            print(f"\nWe will register this cluster ({self.GPU} x {gpu_num})...")
            receipt = self.scheduler.register_cluster(self.wallet, gpu_id, gpu_num)
            event = self.scheduler.get_cluster_event(receipt)
            print("\n[INFO] Registered our cluster successfully. \nThis is our cluster event on the blockchain: ")
            print(f"  {event[0]['args']}")

            # Start waiting
            self.cluster_communication_module.start_listen_task()
            print("\nWaiting for a new task from the Sepolia testnet...")
            print("Ctrl+C to stop waiting...")
            try:
                next_task = self.scheduler.listen_task(self.wallet.address)
            except KeyboardInterrupt:
                print("[INFO] Stopped waiting.")
                return

            # Get task info
            task_index = next_task['args']['taskIndex']
            data_image = next_task['args']['dataImage']
            train_image = next_task['args']['trainImage']

            print("\n[INFO] You received a new task:")
            print(f"  - Download Image: {data_image}")
            print(f"  - Training Image: {train_image}")
            '''

            data_image = "test/test"

            # Start downloading
            self.cluster_communication_module.run_container(data_image)

        else:
            print("[ERROR] Connection failed.")
            print("Please check your provider key & wallet key")
    def cluster_info(self):
        info = self.cluster_communication_module.cluster_info()
        print(f"\nThere are {len(info) + 1} nodes in this cluster.")
        print("Cluster Info:")
        print(f"  {self.service_exploration_module.host} (local) -> {self.GPU} * {self.GPU_num}")
        total_gpu_num = self.GPU_num
        for host in info:
            total_gpu_num += host['GPU_num']
            print(f"  {host['host']} -> {host['GPU']} * {host['GPU_num']}")
        return total_gpu_num
    def exit(self):
        # Shut down both communication modules before leaving.
        self.cluster_communication_module.exit()
        self.service_exploration_module.exit()
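

# --- Example usage (sketch) ---
# A minimal interactive driver, assuming NodeManager is run from the command line
# and that each entry in self.actions names a method on the class. The host/port
# values below are placeholders, not part of the original module.
if __name__ == "__main__":
    manager = NodeManager("0.0.0.0", 9000)
    manager.start_service()
    while True:
        print("\nAvailable actions:")
        for index, action in enumerate(manager.actions):
            print(f"  {index + 1}) {action['explanation']}")
        try:
            chosen = manager.actions[int(input("> ")) - 1]
        except (ValueError, IndexError):
            print("Invalid choice.")
            continue
        getattr(manager, chosen['function'])()
        if chosen['function'] == 'exit':
            break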