main
cassanof 10 months ago
parent 9c8dc6617e
commit 49485efc8f

@ -1,4 +1,4 @@
from generators.model import ModelBase, Message, message_to_str, messages_to_str from generators.model import ModelBase, Message
import random import random
from typing import Union, List, Optional, Callable from typing import Union, List, Optional, Callable
@ -41,7 +41,7 @@ def generic_generate_func_impl(
content=prompt, content=prompt,
), ),
Message( Message(
role="user", # TODO: check this role="user", # TODO: check this
content=reflexion_few_shot, content=reflexion_few_shot,
), ),
Message( Message(
@ -61,8 +61,7 @@ def generic_generate_func_impl(
content=f"[improved impl]:\n{func_sig}", content=f"[improved impl]:\n{func_sig}",
), ),
] ]
func_bodies = model.generate_chat( func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
messages=messages, num_comps=num_comps, temperature=temperature)
else: else:
system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}" system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}"
print_messages(system_prompt, func_sig) print_messages(system_prompt, func_sig)
@ -76,8 +75,7 @@ def generic_generate_func_impl(
content=func_sig, content=func_sig,
), ),
] ]
func_bodies = model.generate_chat( func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
messages=messages, num_comps=num_comps, temperature=temperature)
else: else:
if strategy == "reflexion": if strategy == "reflexion":
prompt = f"{reflexion_completion_instruction}\n{add_code_block(prev_func_impl)}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}\n{code_block_instruction}" prompt = f"{reflexion_completion_instruction}\n{add_code_block(prev_func_impl)}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}\n{code_block_instruction}"
@ -95,8 +93,7 @@ def generic_generate_func_impl(
return func_body_str return func_body_str
else: else:
func_bodies = [parse_code_block(func_body) func_bodies = [parse_code_block(func_body) for func_body in func_bodies]
for func_body in func_bodies]
print_generated_func_body("\n\n".join(func_bodies)) print_generated_func_body("\n\n".join(func_bodies))
return func_bodies return func_bodies
@ -105,7 +102,7 @@ def generic_generate_internal_tests(
func_sig: str, func_sig: str,
model: ModelBase, model: ModelBase,
max_num_tests: int, max_num_tests: int,
test_generation_few_shot: List[Message], test_generation_few_shot: str,
test_generation_chat_instruction: str, test_generation_chat_instruction: str,
test_generation_completion_instruction: str, test_generation_completion_instruction: str,
parse_tests: Callable[[str], List[str]], parse_tests: Callable[[str], List[str]],
@ -122,7 +119,7 @@ def generic_generate_internal_tests(
), ),
Message( Message(
role="user", role="user",
content=f"{messages_to_str(test_generation_few_shot)}\n\n[func signature]:\n{func_sig}\n\n[think]:" content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[think]:"
) )
] ]
output = model.generate_chat(messages=messages, max_tokens=1024) output = model.generate_chat(messages=messages, max_tokens=1024)
@ -133,10 +130,9 @@ def generic_generate_internal_tests(
role="system", role="system",
content=test_generation_chat_instruction, content=test_generation_chat_instruction,
), ),
] + test_generation_few_shot + [
Message( Message(
role="user", role="user",
content=f"{func_sig}" content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[unit tests]:",
) )
] ]
output = model.generate_chat(messages=messages, max_tokens=1024) output = model.generate_chat(messages=messages, max_tokens=1024)
@ -197,7 +193,6 @@ def sample_n_random(items: List[str], n: int) -> List[str]:
return items return items
return random.sample(items, n) return random.sample(items, n)
def print_messages(system_message_text: str, user_message_text: str) -> None: def print_messages(system_message_text: str, user_message_text: str) -> None:
print(f"""----------------------- SYSTEM MESSAGE -----------------------) print(f"""----------------------- SYSTEM MESSAGE -----------------------)
{system_message_text} {system_message_text}
@ -207,9 +202,7 @@ def print_messages(system_message_text: str, user_message_text: str) -> None:
---------------------------------------------- ----------------------------------------------
""", flush=True) """, flush=True)
def print_generated_func_body(func_body_str: str) -> None: def print_generated_func_body(func_body_str: str) -> None:
print(f"""--------------------- GENERATED FUNC BODY --------------------- print(f"""--------------------- GENERATED FUNC BODY ---------------------
{func_body_str} {func_body_str}
------------------------------------------""") ------------------------------------------""")

@ -1,4 +1,4 @@
from generators.model import Message, ModelBase, messages_to_str from generators.model import ModelBase, message_to_str
from .generator_types import Generator from .generator_types import Generator
from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection
@ -221,22 +221,24 @@ The implementation failed 4 out of the 7 test cases due to an IndexError. The is
END OF EXAMPLES END OF EXAMPLES
""" """
PY_TEST_GENERATION_FEW_SHOT = [ PY_TEST_GENERATION_FEW_SHOT = """Examples:
Message(role="user", content="""def add3Numbers(x, y, z): func signature:
def add3Numbers(x, y, z):
\"\"\" Add three numbers together. \"\"\" Add three numbers together.
This function takes three numbers as input and returns the sum of the three numbers. This function takes three numbers as input and returns the sum of the three numbers.
\"\"\""""), \"\"\"
Message(role="assistant", content="""assert add3Numbers(1, 2, 3) == 6 unit tests:
assert add3Numbers(1, 2, 3) == 6
assert add3Numbers(-1, 2, 3) == 4 assert add3Numbers(-1, 2, 3) == 4
assert add3Numbers(1, -2, 3) == 2 assert add3Numbers(1, -2, 3) == 2
assert add3Numbers(1, 2, -3) == 0 assert add3Numbers(1, 2, -3) == 0
assert add3Numbers(-3, -2, -1) == -6 assert add3Numbers(-3, -2, -1) == -6
assert add3Numbers(0, 0, 0) == 0""") assert add3Numbers(0, 0, 0) == 0
] """
PY_TEST_GENERATION_COMPLETION_INSTRUCTION = f"""You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring. PY_TEST_GENERATION_COMPLETION_INSTRUCTION = f"""You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring.
{messages_to_str(PY_TEST_GENERATION_FEW_SHOT)}""" {PY_TEST_GENERATION_FEW_SHOT}"""
PY_TEST_GENERATION_CHAT_INSTRUCTION = """You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring.""" PY_TEST_GENERATION_CHAT_INSTRUCTION = """You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring."""

@ -1,4 +1,4 @@
from generators.model import Message, ModelBase, messages_to_str from generators.model import ModelBase
from .generator_types import Generator from .generator_types import Generator
from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection
from .parse import parse_code_block, add_code_block from .parse import parse_code_block, add_code_block
@ -47,18 +47,21 @@ fn add(a: i32, b: i32) -> i32 {
END EXAMPLES END EXAMPLES
''' '''
RS_TEST_GENERATION_FEW_SHOT = [ RS_TEST_GENERATION_FEW_SHOT = """For example:
Message(role="user", content="""/// Add three numbers together.
func signature:
/// Add three numbers together.
/// This function takes three numbers as input and returns the sum of the three numbers. /// This function takes three numbers as input and returns the sum of the three numbers.
fn add3Numbers(x: i32, y: i32, z: i32) -> i32 { fn add3Numbers(x: i32, y: i32, z: i32) -> i32 {
"""),
Message(role="assistant", content="""assert_eq!(add3Numbers(1, 2, 3), 6); unit tests:
assert_eq!(add3Numbers(1, 2, 3), 6);
assert_eq!(add3Numbers(-1, 2, 3), 4); assert_eq!(add3Numbers(-1, 2, 3), 4);
assert_eq!(add3Numbers(1, -2, 3), 2); assert_eq!(add3Numbers(1, -2, 3), 2);
assert_eq!(add3Numbers(1, 2, -3), 0); assert_eq!(add3Numbers(1, 2, -3), 0);
assert_eq!(add3Numbers(-3, -2, -1), -6); assert_eq!(add3Numbers(-3, -2, -1), -6);
assert_eq!(add3Numbers(0, 0, 0), 0);""") assert_eq!(add3Numbers(0, 0, 0), 0);
] """
RS_SELF_REFLECTION_FEW_SHOT = '''Example 1: RS_SELF_REFLECTION_FEW_SHOT = '''Example 1:
[function impl]: [function impl]:

Loading…
Cancel
Save