decent few-shot

main
cassanof 10 months ago
parent 7247f0c306
commit 9c8dc6617e

@ -1,4 +1,4 @@
from generators.model import ModelBase, Message
from generators.model import ModelBase, Message, message_to_str, messages_to_str
import random
from typing import Union, List, Optional, Callable
@ -41,7 +41,7 @@ def generic_generate_func_impl(
content=prompt,
),
Message(
role="user", # TODO: check this
role="user", # TODO: check this
content=reflexion_few_shot,
),
Message(
@ -61,7 +61,8 @@ def generic_generate_func_impl(
content=f"[improved impl]:\n{func_sig}",
),
]
func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
func_bodies = model.generate_chat(
messages=messages, num_comps=num_comps, temperature=temperature)
else:
system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}"
print_messages(system_prompt, func_sig)
@ -75,7 +76,8 @@ def generic_generate_func_impl(
content=func_sig,
),
]
func_bodies = model.generate_chat(messages=messages, num_comps=num_comps, temperature=temperature)
func_bodies = model.generate_chat(
messages=messages, num_comps=num_comps, temperature=temperature)
else:
if strategy == "reflexion":
prompt = f"{reflexion_completion_instruction}\n{add_code_block(prev_func_impl)}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}\n{code_block_instruction}"
@ -93,7 +95,8 @@ def generic_generate_func_impl(
return func_body_str
else:
func_bodies = [parse_code_block(func_body) for func_body in func_bodies]
func_bodies = [parse_code_block(func_body)
for func_body in func_bodies]
print_generated_func_body("\n\n".join(func_bodies))
return func_bodies
@ -102,7 +105,7 @@ def generic_generate_internal_tests(
func_sig: str,
model: ModelBase,
max_num_tests: int,
test_generation_few_shot: str,
test_generation_few_shot: List[Message],
test_generation_chat_instruction: str,
test_generation_completion_instruction: str,
parse_tests: Callable[[str], List[str]],
@ -119,7 +122,7 @@ def generic_generate_internal_tests(
),
Message(
role="user",
content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[think]:"
content=f"{messages_to_str(test_generation_few_shot)}\n\n[func signature]:\n{func_sig}\n\n[think]:"
)
]
output = model.generate_chat(messages=messages, max_tokens=1024)
@ -130,9 +133,10 @@ def generic_generate_internal_tests(
role="system",
content=test_generation_chat_instruction,
),
] + test_generation_few_shot + [
Message(
role="user",
content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[unit tests]:",
content=f"{func_sig}"
)
]
output = model.generate_chat(messages=messages, max_tokens=1024)
@ -193,6 +197,7 @@ def sample_n_random(items: List[str], n: int) -> List[str]:
return items
return random.sample(items, n)
def print_messages(system_message_text: str, user_message_text: str) -> None:
print(f"""----------------------- SYSTEM MESSAGE -----------------------)
{system_message_text}
@ -202,7 +207,9 @@ def print_messages(system_message_text: str, user_message_text: str) -> None:
----------------------------------------------
""", flush=True)
def print_generated_func_body(func_body_str: str) -> None:
    """Print the generated function body framed by a header and footer rule.

    Args:
        func_body_str: Text of the generated function body to display.
    """
    header = "--------------------- GENERATED FUNC BODY ---------------------"
    footer = "------------------------------------------"
    # Assemble the whole banner first so it is emitted by a single print call.
    banner = f"{header}\n{func_body_str}\n{footer}"
    print(banner)

@ -17,6 +17,14 @@ class Message():
content: str
def message_to_str(message: Message) -> str:
    """Render a single message as the text ``"<role>: <content>"``.

    Args:
        message: Object exposing ``role`` and ``content`` attributes.

    Returns:
        The role and content joined by a colon and a space.
    """
    role, content = message.role, message.content
    return f"{role}: {content}"
def messages_to_str(messages: List[Message]) -> str:
    """Render a list of messages as newline-separated ``"role: content"`` lines.

    Args:
        messages: Messages to render, in order.

    Returns:
        One line per message (as produced by ``message_to_str``), joined with
        ``"\\n"``; an empty string when ``messages`` is empty.
    """
    rendered_lines = map(message_to_str, messages)
    return "\n".join(rendered_lines)
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gpt_completion(
model: str,

@ -1,4 +1,4 @@
from generators.model import ModelBase
from generators.model import Message, ModelBase, messages_to_str
from .generator_types import Generator
from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection
@ -221,33 +221,26 @@ The implementation failed 4 out of the 7 test cases due to an IndexError. The is
END OF EXAMPLES
"""
# Few-shot example for Python unit-test generation, stored as one plain string:
# a "func signature" section (signature plus docstring) followed by a
# "unit tests" section of bare `assert` statements.
# NOTE(review): the same name is assigned again right after this literal (as a
# list of Message objects); this string form looks like the pre-change side of
# a diff -- confirm which definition is live.
PY_TEST_GENERATION_FEW_SHOT = """Examples:
func signature:
def has_close_elements(numbers: List[float], threshold: float) -> bool:
\"\"\" Check if in given list of numbers, are any two numbers closer to each other than
given threshold.
>>> has_close_elements([1.0, 2.0, 3.0], 0.5)
False
>>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
True
\"\"\"
unit tests:
assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True
assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False
assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True
assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False
assert has_close_elements([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True
assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True
assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False"""
# Few-shot example for Python unit-test generation, expressed as a chat
# exchange: the "user" message holds a function signature with its docstring,
# and the "assistant" message holds the matching `assert`-based unit tests.
# NOTE(review): presumably passed as `test_generation_few_shot` to
# generic_generate_internal_tests -- confirm against the caller.
PY_TEST_GENERATION_FEW_SHOT = [
Message(role="user", content="""def add3Numbers(x, y, z):
\"\"\" Add three numbers together.
This function takes three numbers as input and returns the sum of the three numbers.
\"\"\""""),
Message(role="assistant", content="""assert add3Numbers(1, 2, 3) == 6
assert add3Numbers(-1, 2, 3) == 4
assert add3Numbers(1, -2, 3) == 2
assert add3Numbers(1, 2, -3) == 0
assert add3Numbers(-3, -2, -1) == -6
assert add3Numbers(0, 0, 0) == 0""")
]
PY_TEST_GENERATION_COMPLETION_INSTRUCTION = f"""You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring.
{PY_TEST_GENERATION_FEW_SHOT}"""
{messages_to_str(PY_TEST_GENERATION_FEW_SHOT)}"""
# Instruction text for chat-style Python unit-test generation; it contains no
# embedded examples.
# NOTE(review): presumably sent as the system message -- confirm against the caller.
PY_TEST_GENERATION_CHAT_INSTRUCTION = """You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring."""
class PyGenerator(Generator):
def self_reflection(self, func: str, feedback: str, model: ModelBase) -> str:
return generic_generate_self_reflection(

@ -1,4 +1,4 @@
from generators.model import ModelBase
from generators.model import Message, ModelBase, messages_to_str
from .generator_types import Generator
from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection
from .parse import parse_code_block, add_code_block
@ -47,32 +47,18 @@ fn add(a: i32, b: i32) -> i32 {
END EXAMPLES
'''
# Few-shot example for Rust unit-test generation, stored as one plain string:
# a fenced Rust function (doc comments, signature and body) followed by a
# "unit tests" section of `assert_eq!` calls against `candidate`.
# NOTE(review): the same name is assigned again right after this literal (as a
# list of Message objects); this string form looks like the pre-change side of
# a diff -- confirm which definition is live.
RS_TEST_GENERATION_FEW_SHOT = """For example:
func signature:
```rust
/// For a given number n, find the largest number that divides n evenly, smaller than n
/// >>> largest_divisor(15)
/// 5
fn largest_divisor(n: isize) -> isize {
for i in (1..n).rev() {
if n % i == 0 {
return i;
}
}
// if no divisor is found, return 1
1
}
```
unit tests:
assert_eq!(candidate(3), 1);
assert_eq!(candidate(7), 1);
assert_eq!(candidate(10), 5);
assert_eq!(candidate(100), 50);
assert_eq!(candidate(49), 7);
"""
# Few-shot example for Rust unit-test generation, expressed as a chat
# exchange: the "user" message holds a doc-commented function signature, and
# the "assistant" message holds the matching `assert_eq!` unit tests.
# NOTE(review): presumably passed as `test_generation_few_shot` to
# generic_generate_internal_tests -- confirm against the caller.
RS_TEST_GENERATION_FEW_SHOT = [
Message(role="user", content="""/// Add three numbers together.
/// This function takes three numbers as input and returns the sum of the three numbers.
fn add3Numbers(x: i32, y: i32, z: i32) -> i32 {
"""),
Message(role="assistant", content="""assert_eq!(add3Numbers(1, 2, 3), 6);
assert_eq!(add3Numbers(-1, 2, 3), 4);
assert_eq!(add3Numbers(1, -2, 3), 2);
assert_eq!(add3Numbers(1, 2, -3), 0);
assert_eq!(add3Numbers(-3, -2, -1), -6);
assert_eq!(add3Numbers(0, 0, 0), 0);""")
]
RS_SELF_REFLECTION_FEW_SHOT = '''Example 1:
[function impl]:

Loading…
Cancel
Save