diff --git a/src/communication.py b/src/communication.py index 1b9bd42..99fa584 100644 --- a/src/communication.py +++ b/src/communication.py @@ -223,11 +223,11 @@ class ClusterCommunicationModule(): docker.types.DeviceRequest(count=-1, capabilities=[['gpu']]) ], name=f'train-{train_args["index"]}', - env={ + environment={ 'GPU_NUM': self.node_manager.GPU_num, 'NODE_NUM': train_args['node_num'], 'NODE_RANK': train_args['index'], - 'MASTER_IP': 'train-0', + 'MASTER_IP': 'train-0' if self.node_manager.status == 'worker' else '127.0.0.1', 'MASTER_PORT': 21046, }, detach=True @@ -241,7 +241,7 @@ class ClusterCommunicationModule(): status_code = result['StatusCode'] print(status_code, type(status_code)) - def scatter_container(self, image_name, train=False): + def scatter_container(self, image_name, train): def master_run(image_name): print("[Master] run") train_args = { diff --git a/src/node_manager.py b/src/node_manager.py index 1d0d5c9..460b158 100644 --- a/src/node_manager.py +++ b/src/node_manager.py @@ -120,9 +120,13 @@ class NodeManager(): ''' data_image = "snsd0805/cifar100-dataset:v1" + train_image = "snsd0805/cifar100-train:v3" # Start Downloading - self.cluster_communication_module.scatter_container(data_image) + # self.cluster_communication_module.scatter_container(data_image, train=False) + + # start training + self.cluster_communication_module.scatter_container(train_image, train=True) else: