cassanof 9 months ago
parent 7470891d85
commit 3fec014d0a

@@ -150,7 +150,7 @@ class HFModelBase(ModelBase):
         else:
             return outs  # type: ignore
 
-    def prepare_prompt(self, messages: List[Message]) -> str:
+    def prepare_prompt(self, messages: List[Message]) -> List[int]:
         raise NotImplementedError
 
     def extract_output(self, output: str) -> str:
@@ -170,14 +170,14 @@ class StarChat(HFModelBase):
         )
         super().__init__("star-chat", model, tokenizer)
 
-    def prepare_prompt(self, messages: List[Message]) -> str:
+    def prepare_prompt(self, messages: List[Message]) -> List[int]:
         prompt = ""
         for i, message in enumerate(messages):
             prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
            if i == len(messages) - 1:
                 prompt += "<|assistant|>\n"
-        return prompt
+        return self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
 
     def extract_output(self, output: str) -> str:
         out = output.split("<|assistant|>")[1]
