# NavGPT_explore_module/nav_src/scripts/action_planner.py
# Snapshot: 2023-10-20 03:41:33 +10:30 (source listing: 37 lines, 912 B, Python)

import json
from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.prompts import PromptTemplate
from prompt.planner_prompt import (
PLANNER_PROMPT,
)
from data_utils import construct_instrs
# Planner LLM: langchain's OpenAI completion wrapper. No model name is passed,
# so the wrapper's built-in default completion model is used. (An earlier note
# said "davinci-text-003"; the OpenAI model id is `text-davinci-003`, presumably
# the langchain default when this was written. NOTE(review): that default
# differs between langchain releases, so confirm against the installed version.)
llm = OpenAI(temperature=0.0)  # temperature 0 to minimise sampling randomness

# Prompt with a single input variable, "instruction"; PLANNER_PROMPT is
# expected to reference it as {instruction}.
plan_prompt = PromptTemplate(
    input_variables=["instruction"],
    template=PLANNER_PROMPT,
)

# Chain that formats plan_prompt with an instruction and sends it to llm.
plan_chain = LLMChain(prompt=plan_prompt, llm=llm)
# Run configuration: dataset name, annotation directory and split(s) to plan.
# Relative paths resolve against the current working directory, not this file.
dataset = 'R2R'
anno_dir = f'../datasets/{dataset}/annotations'
splits = ['val_72']

# Load the instruction samples for the configured split(s). Each item is used
# below as a dict with an 'instruction' key; the exact record format comes
# from data_utils.construct_instrs (confirm there).
data = construct_instrs(anno_dir, dataset, splits)
# Generate an action plan for every instruction and attach it to its sample.
# Samples are mutated in place (assuming `data` is a list of dicts), so `data`
# carries the plans when it is dumped to JSON afterwards.
for i, sample in enumerate(data):
    instruction = sample['instruction']
    print(f"Sample {i}:")
    print(instruction)
    # One LLM call per sample. An exception here aborts the run before the
    # results are written out, because the dump happens only after the loop.
    action_plan = plan_chain.run(instruction)
    print(action_plan)
    # `sample` is the list element itself; re-indexing via data[i] is redundant.
    sample['action_plan'] = action_plan
# Persist the samples (now carrying 'action_plan') next to the input
# annotations. The file name is derived from the run configuration rather than
# hard-coded, so changing `dataset`/`splits` cannot silently overwrite another
# split's output. For the current config this is still
# ../datasets/R2R/annotations/R2R_val_72_action_plan.json.
output_path = f"{anno_dir}/{dataset}_{'_'.join(splits)}_action_plan.json"
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(data, f, indent=2)