|
|
|
@ -151,7 +151,7 @@ class HFModelBase(ModelBase):
|
|
|
|
|
else:
|
|
|
|
|
return outs # type: ignore
|
|
|
|
|
|
|
|
|
|
def prepare_prompt(self, messages: List[Message]) -> List[int]:
|
|
|
|
|
def prepare_prompt(self, messages: List[Message]):
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
def extract_output(self, output: str) -> str:
|
|
|
|
@ -172,7 +172,7 @@ class StarChat(HFModelBase):
|
|
|
|
|
)
|
|
|
|
|
super().__init__("starchat", model, tokenizer, eos_token_id=49155)
|
|
|
|
|
|
|
|
|
|
def prepare_prompt(self, messages: List[Message]) -> List[int]:
|
|
|
|
|
def prepare_prompt(self, messages: List[Message]):
|
|
|
|
|
prompt = ""
|
|
|
|
|
for i, message in enumerate(messages):
|
|
|
|
|
prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
|
|
|
|
@ -214,7 +214,7 @@ If a question does not make any sense, or is not factually coherent, explain why
|
|
|
|
|
)
|
|
|
|
|
super().__init__("codellama", model, tokenizer)
|
|
|
|
|
|
|
|
|
|
def prepare_prompt(self, messages: List[Message]) -> str:
|
|
|
|
|
def prepare_prompt(self, messages: List[Message]):
|
|
|
|
|
if messages[0].role != "system":
|
|
|
|
|
messages = [
|
|
|
|
|
Message(role="system", content=self.DEFAULT_SYSTEM_PROMPT)
|
|
|
|
@ -247,7 +247,8 @@ If a question does not make any sense, or is not factually coherent, explain why
|
|
|
|
|
)
|
|
|
|
|
# remove eos token from last message
|
|
|
|
|
messages_tokens = messages_tokens[:-1]
|
|
|
|
|
return messages_tokens
|
|
|
|
|
import torch
|
|
|
|
|
return torch.tensor([messages_tokens]).to(self.model.device)
|
|
|
|
|
|
|
|
|
|
def extract_output(self, output: str) -> str:
|
|
|
|
|
print(output)
|
|
|
|
|