Import petals.utils.peft only when needed to avoid unnecessary import of bitsandbytes (#345)

The motivation is the same as in #180.
Alexander Borzunov authored 11 months ago · committed by GitHub
parent 294970fe18
commit 43acfe52a7
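For readers skimming the diff, the change in both files is the standard deferred (function-level) import: the heavy dependency is imported inside the function that needs it, so processes that never reach that code path never load it. Below is a minimal, self-contained sketch of the pattern, for illustration only, with the stdlib module json standing in for bitsandbytes and a hypothetical load_adapter function standing in for the Petals code paths:

    import sys

    def load_adapter(name: str) -> dict:
        # Imported only when an adapter is actually requested, so callers that
        # never invoke this function never pay the import (and GPU setup) cost.
        import json  # stand-in for a heavy optional dependency such as bitsandbytes

        return json.loads('{"adapter": "%s"}' % name)

    print("json" in sys.modules)  # typically False in a fresh interpreter: nothing imported yet
    print(load_adapter("demo"))   # {'adapter': 'demo'}
    print("json" in sys.modules)  # True: the import happened on first use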

@@ -4,7 +4,6 @@ from collections import Counter
 from itertools import chain
 from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
-import peft
 import torch
 from hivemind import BatchTensorDescriptor, TensorDescriptor
 from hivemind.moe.expert_uid import ExpertUID
@@ -156,9 +155,13 @@ class TransformerBackend(ModuleBackend):
     def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
         """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
+        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
+        from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt
+
         adapter_was_loaded = False
         for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
-            if isinstance(layer, (peft.tuners.lora.Linear, peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
+            if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)):
                 layer.active_adapter = active_adapter  # empty string for no adapter
                 if active_adapter in layer.lora_A.keys():
                     adapter_was_loaded = True
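As an aside, the loop above switches the active LoRA adapter by walking all submodules and toggling an attribute on every LoRA-wrapped linear layer. The toy sketch below mirrors that logic with a stand-in layer class; the class, its fields, and the return rule are simplified assumptions for illustration, not the peft or Petals API:

    from typing import Optional

    import torch.nn as nn

    class ToyLoraLinear(nn.Linear):
        """Stand-in for a LoRA-wrapped linear layer (simplified, not peft's class)."""
        def __init__(self, in_features: int, out_features: int):
            super().__init__(in_features, out_features)
            self.lora_A = nn.ModuleDict()          # adapter name -> low-rank A projection
            self.active_adapter: Optional[str] = None

    def load_adapter_(module: nn.Module, active_adapter: Optional[str] = None) -> bool:
        """Toy version: activate `active_adapter` on every LoRA layer, report availability."""
        adapter_was_loaded = False
        for layer in module.modules():             # walk all submodules, like self.module.modules()
            if isinstance(layer, ToyLoraLinear):
                layer.active_adapter = active_adapter
                if active_adapter in layer.lora_A.keys():
                    adapter_was_loaded = True
        return adapter_was_loaded or not active_adapter  # no adapter requested counts as success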

@@ -13,7 +13,6 @@ from tensor_parallel.slicing_configs import get_bloom_config
 from transformers import PretrainedConfig
 
 from petals.utils.misc import QuantType
-from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -56,6 +55,10 @@ def convert_block(
         shard.to(device)
 
     if adapters:
+        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
+        os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+
         create_lora_adapter(block, quant_type=quant_type)
         for adapter_name in adapters:
             adapter_config, adapter_state_dict = load_peft(
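The hunk above also sets BITSANDBYTES_NOWELCOME=1 so that, when the deferred import of petals.utils.peft eventually pulls in bitsandbytes, its startup banner stays quiet. A rough way to check the intended effect of the commit (assuming a Petals checkout that includes it, with dependencies installed) is to import the touched module and inspect sys.modules:

    import sys

    import petals.utils.convert_block  # the module whose eager peft import was removed here

    # After this change, merely importing the converter should not drag in bitsandbytes;
    # it is expected to appear in sys.modules only once an adapter code path actually runs.
    print("bitsandbytes" in sys.modules)  # expected: False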
