|
|
|
@ -49,14 +49,12 @@ def alfworld_run(env, base_prompt, memory: List[str], to_print=True, ob='', mode
|
|
|
|
|
else:
|
|
|
|
|
env_history = EnvironmentHistory(base_prompt, ob, memory, [])
|
|
|
|
|
env_history.reset()
|
|
|
|
|
# init_prompt = prompt + ob + '\n>'
|
|
|
|
|
# prompt = ''
|
|
|
|
|
if to_print:
|
|
|
|
|
print(ob)
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
cur_step = 0
|
|
|
|
|
while cur_step < 49:
|
|
|
|
|
action = llm(str(env_history) + ">", stop=['\n']).strip()
|
|
|
|
|
action = llm(str(env_history) + ">", stop=['\n'], model=model).strip()
|
|
|
|
|
env_history.add("action", action)
|
|
|
|
|
observation, reward, done, info = env.step([action])
|
|
|
|
|
observation, reward, done = process_ob(observation[0]), info['won'][0], done[0]
|
|
|
|
|