|
|
|
@ -132,7 +132,7 @@ class GenerationPipeline:
|
|
|
|
|
|
|
|
|
|
def __call__(
|
|
|
|
|
self, text: Union[str, List[str]], **kwargs: Any
|
|
|
|
|
) -> List[Dict[str, Union[str, List[float]]]]:
|
|
|
|
|
) -> List[Dict[str, Union[str, List[float], List[str]]]]:
|
|
|
|
|
"""Generate from text.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -162,6 +162,7 @@ class GenerationPipeline:
|
|
|
|
|
top_p=kwargs.get("top_p"),
|
|
|
|
|
repetition_penalty=kwargs.get("repetition_penalty"),
|
|
|
|
|
num_return_sequences=kwargs.get("num_return_sequences"),
|
|
|
|
|
do_sample=kwargs.get("do_sample"),
|
|
|
|
|
)
|
|
|
|
|
kwargs_to_pass = {k: v for k, v in kwargs_to_pass.items() if v is not None}
|
|
|
|
|
output_dict = self.model.generate( # type: ignore
|
|
|
|
@ -587,7 +588,7 @@ class TextGenerationModel(HuggingFaceModel):
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def generate(
|
|
|
|
|
self, prompt: Union[str, List[str]], **kwargs: Any
|
|
|
|
|
) -> List[Tuple[Any, float, List[int], List[float]]]:
|
|
|
|
|
) -> List[Tuple[Any, float, List[str], List[float]]]:
|
|
|
|
|
"""
|
|
|
|
|
Generate the prompt from model.
|
|
|
|
|
|
|
|
|
@ -616,7 +617,7 @@ class TextGenerationModel(HuggingFaceModel):
|
|
|
|
|
(
|
|
|
|
|
cast(str, r["generated_text"]),
|
|
|
|
|
sum(cast(List[float], r["logprobs"])),
|
|
|
|
|
cast(List[int], r["tokens"]),
|
|
|
|
|
cast(List[str], r["tokens"]),
|
|
|
|
|
cast(List[float], r["logprobs"]),
|
|
|
|
|
)
|
|
|
|
|
for r in result
|
|
|
|
|