Skip to content

Commit

Permalink
fix wrong condition
Browse files Browse the repository at this point in the history
  • Loading branch information
noahshinn committed Aug 4, 2023
1 parent 7e8b29a commit 59db1fb
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 7 deletions.
4 changes: 1 addition & 3 deletions alfworld_runs/alfworld_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,12 @@ def alfworld_run(env, base_prompt, memory: List[str], to_print=True, ob='', mode
else:
env_history = EnvironmentHistory(base_prompt, ob, memory, [])
env_history.reset()
# init_prompt = prompt + ob + '\n>'
# prompt = ''
if to_print:
print(ob)
sys.stdout.flush()
cur_step = 0
while cur_step < 49:
action = llm(str(env_history) + ">", stop=['\n']).strip()
action = llm(str(env_history) + ">", stop=['\n'], model=model).strip()
env_history.add("action", action)
observation, reward, done, info = env.step([action])
observation, reward, done = process_ob(observation[0]), info['won'][0], done[0]
Expand Down
3 changes: 0 additions & 3 deletions programming_runs/generators/generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

from typing import Union, List, Optional, Callable

# openai.api_key = os.getenv("OPENAI_API_KEY")



def generic_generate_func_impl(
func_sig: str,
Expand Down
2 changes: 1 addition & 1 deletion programming_runs/generators/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def generate_chat(self, messages: List[Message], max_tokens: int = 1024, tempera
prompt = ""
for i, message in enumerate(messages):
prompt += f"<|{message.role}|>\n{message.content}<|end|>\n"
if i != len(messages) - 1:
if i == len(messages) - 1:
prompt += "\n<|assistant|>"

outputs = self.pipe(
Expand Down

0 comments on commit 59db1fb

Please sign in to comment.