main
cassanof 9 months ago
parent 0365be2c6e
commit f7d1613f8e

@@ -151,7 +151,7 @@ class HFModelBase(ModelBase):
         else:
             return outs  # type: ignore
 
-    def prepare_prompt(self, messages: List[Message]) -> List[int]:
+    def prepare_prompt(self, messages: List[Message]):
         raise NotImplementedError
 
     def extract_output(self, output: str) -> str:
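The base class drops the `List[int]` return annotation because the subclasses no longer agree on a return type: StarChat works with a serialized prompt and token ids, while CodeLlama (below) now returns a batched `torch.Tensor`. A minimal sketch of the resulting contract, using a stand-in `Message` dataclass rather than the repo's own types:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Message:
    role: str
    content: str


class HFModelBase:
    def prepare_prompt(self, messages: List[Message]):
        # Intentionally unannotated after this commit: StarChat builds
        # token ids, CodeLlama returns a torch.Tensor on the model device.
        raise NotImplementedError
```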
@@ -172,7 +172,7 @@ class StarChat(HFModelBase):
         )
         super().__init__("starchat", model, tokenizer, eos_token_id=49155)
 
-    def prepare_prompt(self, messages: List[Message]) -> List[int]:
+    def prepare_prompt(self, messages: List[Message]):
         prompt = ""
         for i, message in enumerate(messages):
             prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
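For reference, replaying the loop above on a hypothetical two-message conversation (the `Message` dataclass is a stand-in for the repo's type) shows the StarChat-style prompt string it serializes:

```python
from dataclasses import dataclass


@dataclass
class Message:
    role: str
    content: str


messages = [
    Message(role="system", content="You are a helpful coding assistant."),
    Message(role="user", content="Write a function that reverses a list."),
]

# Same f-string as in the diff above.
prompt = ""
for message in messages:
    prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"

print(prompt)
# <|system|>
# You are a helpful coding assistant.
# <|end|>
# <|user|>
# Write a function that reverses a list.
# <|end|>
```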
@@ -214,7 +214,7 @@ If a question does not make any sense, or is not factually coherent, explain why
         )
         super().__init__("codellama", model, tokenizer)
 
-    def prepare_prompt(self, messages: List[Message]) -> str:
+    def prepare_prompt(self, messages: List[Message]):
         if messages[0].role != "system":
             messages = [
                 Message(role="system", content=self.DEFAULT_SYSTEM_PROMPT)
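The guard above prepends the default system message when the caller did not supply one. The closing `] + messages` falls outside this hunk, so the completion below is an assumption, with a stand-in prompt in place of the repo's `DEFAULT_SYSTEM_PROMPT`:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Message:
    role: str
    content: str


DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."  # stand-in value


def ensure_system_message(messages: List[Message]) -> List[Message]:
    # Assumed completion of the guard shown in the diff: prepend the
    # default system message so the Llama-2 chat template always has one.
    if messages[0].role != "system":
        messages = [Message(role="system", content=DEFAULT_SYSTEM_PROMPT)] + messages
    return messages
```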
@@ -247,7 +247,8 @@ If a question does not make any sense, or is not factually coherent, explain why
         )
         # remove eos token from last message
         messages_tokens = messages_tokens[:-1]
-        return messages_tokens
+        import torch
+        return torch.tensor([messages_tokens]).to(self.model.device)
 
     def extract_output(self, output: str) -> str:
         print(output)
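The change at -247 stops returning raw token ids and instead returns a batched tensor already placed on the model's device. The surrounding `generate` call is not part of this hunk, so the usage below is a hedged sketch with placeholder token ids:

```python
import torch

# Placeholder token ids standing in for the tokenized chat prompt.
messages_tokens = [1, 518, 25580, 29962]

device = "cuda" if torch.cuda.is_available() else "cpu"
input_ids = torch.tensor([messages_tokens]).to(device)  # shape: (1, seq_len)
print(input_ids.shape)  # torch.Size([1, 4])
# outputs = model.generate(input_ids, max_new_tokens=256)  # hypothetical call
```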
