From 9e5d2f95ba105a2b7689da933cae08d29a27d49a Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Sun, 12 Nov 2023 22:29:08 +0800 Subject: [PATCH] fix: train set can get 1.0 found score but unseen val only get about 0.44 found score I ignore the target object when this is the adversarial instruction --- r2r_src/agent.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/r2r_src/agent.py b/r2r_src/agent.py index 93d6a72..c1e3eb6 100644 --- a/r2r_src/agent.py +++ b/r2r_src/agent.py @@ -231,7 +231,7 @@ class Seq2SeqAgent(BaseAgent): else: candidate_objs = ob['candidate_obj'][2] for k, kid in enumerate(candidate_objs): - if kid == ob['objId']: + if kid == ob['objId'] and ob['found']: a[i] = k break else: @@ -454,10 +454,28 @@ class Seq2SeqAgent(BaseAgent): ''' ''' Supervised training for REF ''' + if train_ml is not None: target_obj = self._teacher_REF(perm_obs, just_ended) ref_loss += self.criterion_REF(logit_REF, target_obj) + ''' + print("LENG ", candidate_leng) + print("TARGET:", target) + print("OBJ: ", target_obj) + print() + for index, ob in enumerate(perm_obs): + if target[index] == visual_temp_mask.size(1): + output = -1 + elif target[index] == (candidate_leng[index]-1): + output = -2 + else: + output = target[index].item() + print(index, candidate_leng[index], cpu_a_t[index], a_t.cpu().numpy()[index], target[index].item(), output) + print(target_obj[index].item(), ob['found']) + + print() + ''' # print('logit', logit) # print('logit_REF', logit_REF) # print('just_ended', just_ended)