NavGPT_explore_module/nav_src/scripts/obs_summarizer.py

'''
Use an LLM chain to summarize the observations.

Reads per-scan JSON files of viewpoint observation descriptions from --obs_dir
and writes one-sentence summaries for each viewpoint to --output_dir.
'''
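
# Example invocation (presumably run from nav_src/scripts so the relative default
# paths resolve; the LangChain OpenAI wrapper reads OPENAI_API_KEY from the environment):
#   python obs_summarizer.py --obs_dir ../datasets/R2R/observations_list/ \
#       --output_dir ../datasets/R2R/observations_list_summarized/ \
#       --batch_size 5 --sum_type list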
import os
import json
import asyncio
import argparse
from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.prompts import PromptTemplate


async def async_generate(chain, viewpointID, ob_list):
    """Summarize every observation of a single viewpoint concurrently."""
    print(f"Summarizing {viewpointID} ...")
    tasks = [chain.arun(description=ob) for ob in ob_list]
    resp_list = await asyncio.gather(*tasks)
    print(f"Summarized {viewpointID}'s observations: {resp_list}\n")
    return resp_list


async def generate_concurrently(chain, obs):
    """Run async_generate for every viewpoint in the batch and gather the results."""
    tasks = [async_generate(chain, viewpointID, ob) for viewpointID, ob in obs.items()]
    results = await asyncio.gather(*tasks)
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=5)
    parser.add_argument("--obs_dir", type=str, default="../datasets/R2R/observations_list/")
    parser.add_argument("--output_dir", type=str, default="../datasets/R2R/observations_list_summarized/")
    parser.add_argument("--sum_type", type=str, default="list", choices=["list", "single"])
    args = parser.parse_args()

    obs_dir = args.obs_dir
    obs_files = os.listdir(obs_dir)
    output_dir = args.output_dir
    # Make sure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)

    # The LangChain OpenAI wrapper reads OPENAI_API_KEY from the environment.
    llm = OpenAI(
        temperature=0.0,
        model_name="gpt-3.5-turbo",
    )

    if args.sum_type == "single":
        summarize_prompt = PromptTemplate(
            template='Given the description of a viewpoint. Summarize the scene from the viewpoint in one concise sentence.\n\nDescription:\n{description}\n\nSummarization: The scene from the viewpoint is a',
            input_variables=["description"],
        )
    elif args.sum_type == "list":
        summarize_prompt = PromptTemplate(
            template='Here is a single scene view from top, down and middle:\n{description}\nSummarize the scene in one sentence:',
            input_variables=["description"],
        )
    summarize_chain = LLMChain(llm=llm, prompt=summarize_prompt)

    for obs_file in obs_files:
        obs_path = os.path.join(obs_dir, obs_file)
        with open(obs_path) as f:
            obs = json.load(f)
        summary = {}
        viewpointIDs = list(obs.keys())
        # Summarize the viewpoints in batches to bound the number of concurrent requests.
        for i in range(0, len(viewpointIDs), args.batch_size):
            batch = viewpointIDs[i:i+args.batch_size]
            print(f"Summarizing scan {obs_file.split('.')[0]} batch [{i//args.batch_size}/{len(viewpointIDs)//args.batch_size}]")
            batch_obs = {viewpointID: obs[viewpointID] for viewpointID in batch}
            summarized_obs = asyncio.run(generate_concurrently(summarize_chain, batch_obs))
            summarized_obs = {viewpointID: summarized_obs[j] for j, viewpointID in enumerate(batch)}
            summary.update(summarized_obs)
        # Strip the input file's extension so the output is not named '<scan>.json.json'.
        output_path = os.path.join(output_dir, f"{obs_file.split('.')[0]}.json")
        with open(output_path, 'w') as f:
            json.dump(summary, f, indent=2)
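
# Each output file maps a viewpointID to the list of generated summaries
# (structure inferred from the loop above), e.g.:
#   {"<viewpointID>": ["<summary of observation 1>", "<summary of observation 2>", ...]}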