diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 9e81170..f6b9691 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -4,7 +4,6 @@
 from collections import Counter
 from itertools import chain
 from typing import Any, Dict, Optional, Sequence, Tuple, Union
-import peft
 import torch
 from hivemind import BatchTensorDescriptor, TensorDescriptor
 from hivemind.moe.expert_uid import ExpertUID
@@ -156,9 +155,13 @@ class TransformerBackend(ModuleBackend):
 
     def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
         """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
+
+        # Import peft.tuners.lora only when necessary to avoid importing bitsandbytes
+        from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt
+
         adapter_was_loaded = False
         for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
-            if isinstance(layer, (peft.tuners.lora.Linear, peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
+            if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)):
                 layer.active_adapter = active_adapter  # empty string for no adapter
                 if active_adapter in layer.lora_A.keys():
                     adapter_was_loaded = True
diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py
index b1c412e..5c04092 100644
--- a/src/petals/utils/convert_block.py
+++ b/src/petals/utils/convert_block.py
@@ -13,7 +13,6 @@
 from tensor_parallel.slicing_configs import get_bloom_config
 from transformers import PretrainedConfig
 from petals.utils.misc import QuantType
-from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -56,6 +55,10 @@ def convert_block(
         shard.to(device)
 
     if adapters:
+        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
+        os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+
         create_lora_adapter(block, quant_type=quant_type)
         for adapter_name in adapters:
             adapter_config, adapter_state_dict = load_peft(
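
For context, below is a minimal standalone sketch (not part of the patch) of the deferred-import pattern both hunks apply: any import that pulls in bitsandbytes is moved inside the function that needs it, so merely importing the enclosing module stays free of bitsandbytes side effects. The helper name `select_lora_adapter` and its signature are hypothetical; the peft classes and the BITSANDBYTES_NOWELCOME variable are taken from the diff above.

import os
from typing import Optional

import torch.nn as nn


def select_lora_adapter(module: nn.Module, active_adapter: Optional[str] = None) -> None:
    """Hypothetical helper illustrating the same lazy-import pattern as the patch."""
    # Deferred import: peft (and, through it, bitsandbytes) is only loaded when this
    # code path actually runs, not when the enclosing module is imported.
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"  # silence the bitsandbytes welcome banner
    from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt

    for layer in module.modules():
        if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)):
            layer.active_adapter = active_adapter  # empty string for no adapter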