|
|
|
@ -150,7 +150,7 @@ class HFModelBase(ModelBase):
|
|
|
|
|
else:
|
|
|
|
|
return outs # type: ignore
|
|
|
|
|
|
|
|
|
|
def prepare_prompt(self, messages: List[Message]) -> str:
|
|
|
|
|
def prepare_prompt(self, messages: List[Message]) -> List[int]:
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
def extract_output(self, output: str) -> str:
|
|
|
|
@ -170,14 +170,14 @@ class StarChat(HFModelBase):
|
|
|
|
|
)
|
|
|
|
|
super().__init__("star-chat", model, tokenizer)
|
|
|
|
|
|
|
|
|
|
def prepare_prompt(self, messages: List[Message]) -> str:
|
|
|
|
|
def prepare_prompt(self, messages: List[Message]) -> List[int]:
|
|
|
|
|
prompt = ""
|
|
|
|
|
for i, message in enumerate(messages):
|
|
|
|
|
prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
|
|
|
|
|
if i == len(messages) - 1:
|
|
|
|
|
prompt += "<|assistant|>\n"
|
|
|
|
|
|
|
|
|
|
return prompt
|
|
|
|
|
return self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
|
|
|
|
|
|
|
|
|
|
def extract_output(self, output: str) -> str:
|
|
|
|
|
out = output.split("<|assistant|>")[1]
|
|
|
|
|