# Source: NavGPT_explore_module/nav_src/extract_obj.py
# Snapshot: 2024-04-28 20:44:10 +08:00
#
# Original listing: 77 lines
# Size: 1.8 KiB
# Language: Python

import openai
from openai import OpenAI
import os
import json
# OpenAI API key, read once at import time. os.environ[...] raises KeyError
# when OPENAI_API_KEY is not set, so merely importing this module requires it.
OPENAI_API_KEY: str = os.environ['OPENAI_API_KEY']
# Few-shot prompt asking the model to extract the target-object phrase from a
# housework/navigation instruction. The expected reply, per the example, is a
# JSON object with a "target" key.
# The literal token ___input___ is meant to be substituted (via str.replace)
# with the instruction text.
# NOTE(review): the placeholder sits where a JSON string value belongs but has
# no surrounding quotes, so the substituted text must supply its own quoting
# (e.g. json.dumps(instruction)) for the "Input" section to be valid JSON.
PROMPT_TEMPLATE: str = """
You are a good housework assistant, please help me to find the target object in a housework instruction.
You will receive a housework instruction, and you need to return the target object and its location.
For example:
Input:
{
"instruction": "Enter the kitchen and pick up the cup on the table"
}
Output:
{
"target": "the cup on the table"
}
Now, it's your turn:
Input:
{
"instruction": ___input___
}
Output:
"""
def query(openai: "OpenAI", prompt: str, model: str = "gpt-3.5-turbo-1106"):
    """Send ``prompt`` to a chat model in JSON mode and decode the reply.

    Args:
        openai: The ``OpenAI`` client used to make the request. The parameter
            name shadows the ``openai`` module inside this function; it is
            kept unchanged so existing keyword callers keep working.
        prompt: The user message. JSON mode also requires the conversation to
            ask for JSON explicitly, which the fixed system message does.
        model: Chat model name. Defaults to the model this script was built
            against.

    Returns:
        A ``(reply, total_tokens)`` tuple: the message content decoded with
        ``json.loads``, and the total token usage reported for the request.

    Raises:
        json.JSONDecodeError: if the reply is not valid JSON (for example, if
            generation stopped at the token limit in the middle of the object,
            or the content is empty/None).
    """
    # Use the client that was passed in, not a module-level global: the
    # original code referenced `client`, which only exists when this file is
    # run as a script.
    response = openai.chat.completions.create(
        model=model,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": "Please output JSON."},
            {"role": "user", "content": prompt},
        ],
    )
    # message.content is Optional in the SDK. Map None to "" so the failure
    # surfaces as JSONDecodeError rather than an unrelated TypeError.
    content = response.choices[0].message.content or ""
    return json.loads(content), response.usage.total_tokens
def load_json(filename):
    """Load and return the JSON document stored at ``filename``.

    The file is decoded as UTF-8, the encoding RFC 8259 requires for JSON,
    rather than the platform's locale default. Non-ASCII text therefore
    reads identically on every OS.
    """
    with open(filename, encoding='utf-8') as fp:
        return json.load(fp)
def dump_json(data, filename, **kwargs):
    """Serialize ``data`` as JSON into ``filename``, overwriting any content.

    Extra keyword arguments (e.g. ``indent=2`` or ``ensure_ascii=False``) are
    forwarded to ``json.dump``. The file is written as UTF-8, matching
    ``load_json``, so output produced with ``ensure_ascii=False`` does not
    depend on the locale encoding.
    """
    with open(filename, 'w', encoding='utf-8') as fp:
        json.dump(data, fp, **kwargs)
def _extract_target(client, instruction, max_attempts=5):
    """Ask the model for the target-object phrase of one instruction.

    The request is retried when the reply is not valid JSON or lacks a
    "target" key, up to ``max_attempts`` calls in total, instead of looping
    forever.

    Returns:
        ``(target, tokens_used)``: the extracted target, or ``None`` if every
        attempt failed, and the tokens consumed by the attempts whose reply
        could be decoded. Undecodable replies raise inside ``query`` before
        their usage is returned, so those tokens are not counted.
    """
    # json.dumps adds quotes and escaping so the prompt's "Input" section is
    # well-formed JSON, matching the few-shot example in the template.
    prompt = PROMPT_TEMPLATE.replace(
        '___input___', json.dumps(instruction, ensure_ascii=False)
    )
    tokens_used = 0
    for _ in range(max_attempts):
        try:
            reply, tokens = query(client, prompt)
        except json.JSONDecodeError:
            # Malformed (e.g. truncated) JSON from the model: try again rather
            # than aborting the whole run and losing all previous results.
            continue
        tokens_used += tokens
        if isinstance(reply, dict) and 'target' in reply:
            return reply['target'], tokens_used
    return None, tokens_used


if __name__ == '__main__':
    client = OpenAI(api_key=OPENAI_API_KEY)
    # NOTE(review): both paths are relative to the current working directory,
    # not to this file.
    data = load_json('../datasets/REVERIE/annotations/REVERIE_val_unseen_instr.json')
    total_tokens = 0
    for item in data:
        instr = item['instruction']
        target, tokens = _extract_target(client, instr)
        total_tokens += tokens
        if target is None:
            # Stored as JSON null so downstream consumers can detect the gap.
            print(f'WARNING: no target extracted for: {instr}')
        item['target'] = target
        print(instr)
        print(target)
        print()
    print(f'Total tokens used: {total_tokens}')
    dump_json(data, 'new_REVERIE_val_unseen_instr.json')