'''
Use an LLM chain to summarize the observations.
'''
|
|
import os
|
|
import json
|
|
import asyncio
|
|
import argparse
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.llms.openai import OpenAI
|
|
from langchain.prompts import PromptTemplate
|
|
|
|
async def async_generate(chain, viewpointID, ob_list):
    """Summarize all observations of one viewpoint concurrently.

    Fires ``chain.arun`` once per entry of *ob_list* and awaits them
    together, so the per-observation LLM calls overlap instead of
    running back-to-back. Results come back in input order.
    """
    print(f"Summarizing {viewpointID} ...")
    pending = (chain.arun(description=description) for description in ob_list)
    resp_list = list(await asyncio.gather(*pending))
    print(f"Summarized {viewpointID}'s observations: {resp_list}\n")
    return resp_list
|
|
|
|
|
|
async def generate_concurrently(chain, obs):
    """Fan summarization out over every viewpoint in *obs* at once.

    *obs* maps viewpointID -> list of observation strings; the returned
    list of per-viewpoint summary lists follows the dict's iteration order.
    """
    results = await asyncio.gather(
        *(async_generate(chain, vp_id, ob) for vp_id, ob in obs.items())
    )
    return results
|
|
|
|
|
|
if __name__ == "__main__":
    # Summarize every observation file in --obs_dir with an LLM chain and
    # write one JSON file of summaries per input file into --output_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=5)
    parser.add_argument("--obs_dir", type=str, default="../datasets/R2R/observations_list/")
    parser.add_argument("--output_dir", type=str, default="../datasets/R2R/observations_list_summarized/")
    parser.add_argument("--sum_type", type=str, default="list", choices=["list", "single"])
    args = parser.parse_args()

    obs_dir = args.obs_dir
    obs_files = os.listdir(obs_dir)
    output_dir = args.output_dir
    # make sure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Deterministic decoding (temperature 0) so reruns produce stable summaries.
    llm = OpenAI(
        temperature=0.0,
        model_name="gpt-3.5-turbo",
    )

    # Two prompt variants: "single" treats the description as one paragraph,
    # "list" treats it as top/down/middle views of the same scene.
    if args.sum_type == "single":
        summarize_prompt = PromptTemplate(
            template='Given the description of a viewpoint. Summarize the scene from the viewpoint in one concise sentence.\n\nDescription:\n{description}\n\nSummarization: The scene from the viewpoint is a',
            input_variables=["description"],
        )
    elif args.sum_type == "list":
        summarize_prompt = PromptTemplate(
            template='Here is a single scene view from top, down and middle:\n{description}\nSummarize the scene in one sentence:',
            input_variables=["description"],
        )

    summarize_chain = LLMChain(llm=llm, prompt=summarize_prompt)

    for obs_file in obs_files:
        obs_path = os.path.join(obs_dir, obs_file)
        with open(obs_path) as f:
            obs = json.load(f)
        summary = {}
        viewpointIDs = list(obs.keys())
        # Number of batches, rounded UP so the progress display counts a
        # trailing partial batch (the old floor division under-reported it).
        total_batches = (len(viewpointIDs) + args.batch_size - 1) // args.batch_size
        # Process the viewpointIDs in batches to bound concurrent LLM calls.
        for i in range(0, len(viewpointIDs), args.batch_size):
            batch = viewpointIDs[i:i + args.batch_size]
            print(f"Summarizing scan {obs_file.split('.')[0]} batch [{i//args.batch_size}/{total_batches}]")
            batch_obs = {viewpointID: obs[viewpointID] for viewpointID in batch}
            summarized_obs = asyncio.run(generate_concurrently(summarize_chain, batch_obs))
            # generate_concurrently returns results in dict-iteration order,
            # which matches `batch`, so zip them back to their viewpointIDs.
            summarized_obs = {viewpointID: summarized_obs[j] for j, viewpointID in enumerate(batch)}
            summary.update(summarized_obs)
        # BUG FIX: obs_file already ends in ".json" (it was parsed with
        # json.load above), so the old f'{obs_file}.json' wrote "X.json.json".
        # Reuse the input filename unchanged.
        output_path = os.path.join(output_dir, obs_file)
        with open(output_path, 'w') as f:
            json.dump(summary, f, indent=2)