Import petals.utils.peft only when needed to avoid unnecessary import of bitsandbytes (#345)

The motivation is the same as in #180.
pull/348/head
Alexander Borzunov 10 months ago committed by GitHub
parent 294970fe18
commit 43acfe52a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,7 +4,6 @@ from collections import Counter
from itertools import chain
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import peft
import torch
from hivemind import BatchTensorDescriptor, TensorDescriptor
from hivemind.moe.expert_uid import ExpertUID
@ -156,9 +155,13 @@ class TransformerBackend(ModuleBackend):
def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
"""Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
# Import petals.utils.peft only when necessary to avoid importing bitsandbytes
from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt
adapter_was_loaded = False
for layer in self.module.modules(): # select adapter set -- leave empty string for no adapter
if isinstance(layer, (peft.tuners.lora.Linear, peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)):
layer.active_adapter = active_adapter # empty string for no adapter
if active_adapter in layer.lora_A.keys():
adapter_was_loaded = True

@ -13,7 +13,6 @@ from tensor_parallel.slicing_configs import get_bloom_config
from transformers import PretrainedConfig
from petals.utils.misc import QuantType
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
@ -56,6 +55,10 @@ def convert_block(
shard.to(device)
if adapters:
# Import petals.utils.peft only when necessary to avoid importing bitsandbytes
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
create_lora_adapter(block, quant_type=quant_type)
for adapter_name in adapters:
adapter_config, adapter_state_dict = load_peft(

Loading…
Cancel
Save