@ -1,4 +1,4 @@
from generators . model import ModelBase , Message , message_to_str , messages_to_str
from generators . model import ModelBase , Message
import random
from typing import Union , List , Optional , Callable
@ -41,7 +41,7 @@ def generic_generate_func_impl(
content = prompt ,
) ,
Message (
role = " user " , # TODO: check this
role = " user " , # TODO: check this
content = reflexion_few_shot ,
) ,
Message (
@ -61,8 +61,7 @@ def generic_generate_func_impl(
content = f " [improved impl]: \n { func_sig } " ,
) ,
]
func_bodies = model . generate_chat (
messages = messages , num_comps = num_comps , temperature = temperature )
func_bodies = model . generate_chat ( messages = messages , num_comps = num_comps , temperature = temperature )
else :
system_prompt = f " { simple_chat_instruction } \n { code_block_instruction } "
print_messages ( system_prompt , func_sig )
@ -76,8 +75,7 @@ def generic_generate_func_impl(
content = func_sig ,
) ,
]
func_bodies = model . generate_chat (
messages = messages , num_comps = num_comps , temperature = temperature )
func_bodies = model . generate_chat ( messages = messages , num_comps = num_comps , temperature = temperature )
else :
if strategy == " reflexion " :
prompt = f " { reflexion_completion_instruction } \n { add_code_block ( prev_func_impl ) } \n \n unit tests: \n { feedback } \n \n hint: \n { self_reflection } \n \n # improved implementation \n { func_sig } \n { code_block_instruction } "
@ -95,8 +93,7 @@ def generic_generate_func_impl(
return func_body_str
else :
func_bodies = [ parse_code_block ( func_body )
for func_body in func_bodies ]
func_bodies = [ parse_code_block ( func_body ) for func_body in func_bodies ]
print_generated_func_body ( " \n \n " . join ( func_bodies ) )
return func_bodies
@ -105,7 +102,7 @@ def generic_generate_internal_tests(
func_sig : str ,
model : ModelBase ,
max_num_tests : int ,
test_generation_few_shot : List [ Message ] ,
test_generation_few_shot : str ,
test_generation_chat_instruction : str ,
test_generation_completion_instruction : str ,
parse_tests : Callable [ [ str ] , List [ str ] ] ,
@ -122,7 +119,7 @@ def generic_generate_internal_tests(
) ,
Message (
role = " user " ,
content = f " { messages_to_str( test_generation_few_shot) } \n \n [func signature]: \n { func_sig } \n \n [think]: "
content = f " { test_generation_few_shot} \n \n [func signature]: \n { func_sig } \n \n [think]: "
)
]
output = model . generate_chat ( messages = messages , max_tokens = 1024 )
@ -133,10 +130,9 @@ def generic_generate_internal_tests(
role = " system " ,
content = test_generation_chat_instruction ,
) ,
] + test_generation_few_shot + [
Message (
role = " user " ,
content = f " { func_sig} "
content = f " { test_generation_few_shot} \n \n [func signature]: \n { func_sig} \n \n [unit tests]: " ,
)
]
output = model . generate_chat ( messages = messages , max_tokens = 1024 )
@ -197,7 +193,6 @@ def sample_n_random(items: List[str], n: int) -> List[str]:
return items
return random . sample ( items , n )
def print_messages ( system_message_text : str , user_message_text : str ) - > None :
print ( f """ ----------------------- SYSTEM MESSAGE -----------------------)
{ system_message_text }
@ -207,9 +202,7 @@ def print_messages(system_message_text: str, user_message_text: str) -> None:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
""" , flush=True)
def print_generated_func_body ( func_body_str : str ) - > None :
print ( f """ --------------------- GENERATED FUNC BODY ---------------------
{ func_body_str }
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - """ )