From c84b2fd10f4ca9a4f4a588491c3530c0a0e4f160 Mon Sep 17 00:00:00 2001 From: Laurel Orr <57237365+lorr1@users.noreply.github.com> Date: Tue, 16 Jan 2024 22:51:02 -0800 Subject: [PATCH] fix: pass do sample generation (#118) --- Makefile | 2 +- manifest/api/models/diffuser.py | 2 +- manifest/api/models/huggingface.py | 7 ++++--- manifest/api/models/model.py | 2 +- manifest/api/models/sentence_transformer.py | 2 +- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index 5fc797f..6aaf992 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ format: black manifest/ tests/ web_app/ check: - isort -c -v manifest/ tests/ web_app/ + isort -c manifest/ tests/ web_app/ black manifest/ tests/ web_app/ --check flake8 manifest/ tests/ web_app/ mypy manifest/ tests/ web_app/ diff --git a/manifest/api/models/diffuser.py b/manifest/api/models/diffuser.py index b42ed3a..e04db4f 100644 --- a/manifest/api/models/diffuser.py +++ b/manifest/api/models/diffuser.py @@ -75,7 +75,7 @@ class DiffuserModel(Model): @torch.no_grad() def generate( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[Any, float, List[int], List[float]]]: + ) -> List[Tuple[Any, float, List[str], List[float]]]: """ Generate the prompt from model. diff --git a/manifest/api/models/huggingface.py b/manifest/api/models/huggingface.py index 89c4c2b..912832b 100644 --- a/manifest/api/models/huggingface.py +++ b/manifest/api/models/huggingface.py @@ -132,7 +132,7 @@ class GenerationPipeline: def __call__( self, text: Union[str, List[str]], **kwargs: Any - ) -> List[Dict[str, Union[str, List[float]]]]: + ) -> List[Dict[str, Union[str, List[float], List[str]]]]: """Generate from text. Args: @@ -162,6 +162,7 @@ class GenerationPipeline: top_p=kwargs.get("top_p"), repetition_penalty=kwargs.get("repetition_penalty"), num_return_sequences=kwargs.get("num_return_sequences"), + do_sample=kwargs.get("do_sample"), ) kwargs_to_pass = {k: v for k, v in kwargs_to_pass.items() if v is not None} output_dict = self.model.generate( # type: ignore @@ -587,7 +588,7 @@ class TextGenerationModel(HuggingFaceModel): @torch.no_grad() def generate( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[Any, float, List[int], List[float]]]: + ) -> List[Tuple[Any, float, List[str], List[float]]]: """ Generate the prompt from model. @@ -616,7 +617,7 @@ class TextGenerationModel(HuggingFaceModel): ( cast(str, r["generated_text"]), sum(cast(List[float], r["logprobs"])), - cast(List[int], r["tokens"]), + cast(List[str], r["tokens"]), cast(List[float], r["logprobs"]), ) for r in result diff --git a/manifest/api/models/model.py b/manifest/api/models/model.py index 3317211..dcb04b9 100644 --- a/manifest/api/models/model.py +++ b/manifest/api/models/model.py @@ -45,7 +45,7 @@ class Model: def generate( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[Any, float, List[int], List[float]]]: + ) -> List[Tuple[Any, float, List[str], List[float]]]: """ Generate the prompt from model. diff --git a/manifest/api/models/sentence_transformer.py b/manifest/api/models/sentence_transformer.py index bd3f5fa..5f6c2fb 100644 --- a/manifest/api/models/sentence_transformer.py +++ b/manifest/api/models/sentence_transformer.py @@ -66,7 +66,7 @@ class SentenceTransformerModel(Model): @torch.no_grad() def generate( self, prompt: Union[str, List[str]], **kwargs: Any - ) -> List[Tuple[Any, float, List[int], List[float]]]: + ) -> List[Tuple[Any, float, List[str], List[float]]]: """ Generate the prompt from model.