fix wrong condition

main
Noah Shinn 10 months ago
parent 7e8b29a5bb
commit 59db1fb92a

@ -49,14 +49,12 @@ def alfworld_run(env, base_prompt, memory: List[str], to_print=True, ob='', mode
else:
env_history = EnvironmentHistory(base_prompt, ob, memory, [])
env_history.reset()
# init_prompt = prompt + ob + '\n>'
# prompt = ''
if to_print:
print(ob)
sys.stdout.flush()
cur_step = 0
while cur_step < 49:
action = llm(str(env_history) + ">", stop=['\n']).strip()
action = llm(str(env_history) + ">", stop=['\n'], model=model).strip()
env_history.add("action", action)
observation, reward, done, info = env.step([action])
observation, reward, done = process_ob(observation[0]), info['won'][0], done[0]

@ -3,9 +3,6 @@ import random
from typing import Union, List, Optional, Callable
# openai.api_key = os.getenv("OPENAI_API_KEY")
def generic_generate_func_impl(
func_sig: str,

@ -123,7 +123,7 @@ class StarChat(ModelBase):
prompt = ""
for i, message in enumerate(messages):
prompt += f"<|{message.role}|>\n{message.content}<|end|>\n"
if i != len(messages) - 1:
if i == len(messages) - 1:
prompt += "\n<|assistant|>"
outputs = self.pipe(

Loading…
Cancel
Save