From 793726b041d5d4b9622ef70e84fdf93ab6cbdc3d Mon Sep 17 00:00:00 2001
From: Max Ryabinin
Date: Sun, 12 Mar 2023 22:49:04 +0100
Subject: [PATCH] Speed up loading blocks using init with meta weights (#285)

* Init WrappedBloomBlock with meta weights

---------

Co-authored-by: Alexander Borzunov
---
 pyproject.toml                      |  3 +-
 src/petals/bloom/from_pretrained.py | 26 +++++++++------
 src/petals/server/block_utils.py    |  2 +-
 tests/test_aux_functions.py         |  4 +--
 tests/test_block_exact_match.py     | 51 +++++++++++++++++++++++++++--
 tests/test_chained_calls.py         |  2 +-
 tests/test_full_model.py            |  2 +-
 tests/test_remote_sequential.py     |  4 +--
 tests/test_sequence_manager.py      |  2 +-
 tests/test_server_stats.py          |  2 +-
 tests/test_tensor_parallel.py       |  2 +-
 11 files changed, 77 insertions(+), 23 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index e6f5197..cfc991c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,4 +14,5 @@ profile = "black"
 line_length = 120
 combine_as_imports = true
 combine_star = true
-known_local_folder = ["tests", "cli"]
\ No newline at end of file
+known_local_folder = ["tests", "cli"]
+known_first_party = ["test_utils"]
diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py
index f8e41a7..9f1d12b 100644
--- a/src/petals/bloom/from_pretrained.py
+++ b/src/petals/bloom/from_pretrained.py
@@ -13,6 +13,8 @@ import time
 from typing import Optional, OrderedDict, Union
 
 import torch
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
 from hivemind.utils.logging import get_logger
 from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.models.bloom.configuration_bloom import BloomConfig
@@ -38,13 +40,16 @@ def load_pretrained_block(
     max_disk_space: Optional[int] = None,
 ) -> WrappedBloomBlock:
     """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
+    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
     if config is None:
         config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
 
-    block = WrappedBloomBlock(config)
+    with init_empty_weights():
+        block = WrappedBloomBlock(config)
+
     state_dict = _load_state_dict(
         converted_model_name_or_path,
         block_index,
@@ -54,16 +59,17 @@ def load_pretrained_block(
         max_disk_space=max_disk_space,
     )
 
-    if torch_dtype == "auto":
-        with torch.no_grad():
-            for name, param in block.named_parameters():
-                assert name in state_dict, f"{name} not in state dict"
-                param.data = param.data.to(state_dict[name].dtype)
-    else:
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        block = block.to(dtype=torch_dtype)
-
+    # dummy load, check that keys match
     report = block.load_state_dict(state_dict, strict=True)
+    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
+
+    for param_name, _ in block.named_parameters():
+        assert param_name in state_dict, f"{param_name} not in state dict"
+        param = state_dict[param_name]
+        if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+            param = param.to(torch_dtype)
+        set_module_tensor_to_device(block, param_name, "cpu", value=param)
+
     logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
     return block
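The hunk above is the heart of this patch: the block is first constructed on PyTorch's meta device, where parameters carry only shapes and dtypes, and each tensor is then materialized from the checkpoint's state dict. This skips the expensive random initialization that was previously run and immediately overwritten. A minimal sketch of the same pattern, assuming a toy module in place of WrappedBloomBlock: init_empty_weights and set_module_tensor_to_device are the real accelerate APIs imported by the patch, while TinyBlock and the in-memory state dict are hypothetical stand-ins.

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device


class TinyBlock(torch.nn.Module):
    """Hypothetical stand-in for WrappedBloomBlock."""

    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.dense = torch.nn.Linear(hidden_size, hidden_size)


# 1. Construct the module on the meta device: parameters get shapes and dtypes,
#    but no storage is allocated and no init kernels run.
with init_empty_weights():
    block = TinyBlock()
assert block.dense.weight.is_meta

# 2. Stand-in for the state dict that _load_state_dict() reads from the checkpoint.
state_dict = {name: torch.randn_like(p) for name, p in TinyBlock().named_parameters()}

# 3. Materialize every tensor on CPU, casting floating-point weights to the target dtype.
torch_dtype = torch.float16
for param_name, _ in block.named_parameters():
    param = state_dict[param_name]
    if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
        param = param.to(torch_dtype)
    set_module_tensor_to_device(block, param_name, "cpu", value=param)

assert not block.dense.weight.is_meta  # the block is now fully materialized

Note that the "dummy" load_state_dict onto the meta module costs nothing but still validates that the checkpoint keys match the architecture, which is why the patch keeps it as a strict check before materializing weights one by one.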
diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py
index eca7143..fd39ad6 100644
--- a/src/petals/server/block_utils.py
+++ b/src/petals/server/block_utils.py
@@ -30,7 +30,7 @@ def get_block_size(
             dtype is not None and load_in_8bit is not None
         ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
 
-    with init_empty_weights():
+    with init_empty_weights(include_buffers=True):
         block = WrappedBloomBlock(config)
     n_params = sum(param.numel() for param in block.parameters())
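A related detail: get_block_size() builds a throwaway block purely to count its weights, so passing include_buffers=True keeps buffers on the meta device along with the parameters, and the measurement allocates essentially nothing. A small sketch of the effect, again with a hypothetical TinyBlock standing in for WrappedBloomBlock:

import torch
from accelerate import init_empty_weights


class TinyBlock(torch.nn.Module):
    """Hypothetical module with a registered buffer."""

    def __init__(self, hidden_size: int = 1024):
        super().__init__()
        self.dense = torch.nn.Linear(hidden_size, 4 * hidden_size)
        self.register_buffer("scale", torch.ones(hidden_size))


# include_buffers=True puts buffers on the meta device too, so even a large
# block can be "built" for measurement without allocating real memory.
with init_empty_weights(include_buffers=True):
    block = TinyBlock()
assert all(t.is_meta for t in block.parameters())
assert all(t.is_meta for t in block.buffers())

n_params = sum(param.numel() for param in block.parameters())
print(f"~{n_params * 2 / 2**20:.1f} MiB per block in float16")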
diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py
index 1986f0a..6909ccf 100644
--- a/tests/test_aux_functions.py
+++ b/tests/test_aux_functions.py
@@ -1,9 +1,9 @@
 import pytest
 import torch
-from test_utils import MODEL_NAME
 
 from petals.client import DistributedBloomConfig
-from petals.server.throughput import measure_compute_rps, measure_network_rps
+from petals.server.throughput import measure_compute_rps
+from test_utils import MODEL_NAME
 
 
 @pytest.mark.forked
diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py
index 664f255..d2fbdde 100644
--- a/tests/test_block_exact_match.py
+++ b/tests/test_block_exact_match.py
@@ -1,15 +1,18 @@
 import random
+from typing import Union
 
 import hivemind
 import pytest
 import torch
-from test_utils import *
+from transformers.models.bloom.configuration_bloom import BloomConfig
 
-from petals.bloom.from_pretrained import load_pretrained_block
+from petals.bloom.block import WrappedBloomBlock
+from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
 from petals.client import DistributedBloomConfig
 from petals.client.remote_sequential import RemoteTransformerBlock
 from petals.data_structures import UID_DELIMITER
 from petals.dht_utils import get_remote_module
+from test_utils import *
 
 
 @pytest.mark.forked
@@ -41,3 +44,47 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
 
         assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
         assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
+
+
+def _old_load_pretrained_block(
+    converted_model_name_or_path: str,
+    block_index: int,
+    torch_dtype: Union[torch.dtype, str] = "auto",
+) -> WrappedBloomBlock:
+    """Load the BLOOM block by directly initializing the weights.
+    This test is used to check consistency with the previous implementation and can be removed in the future."""
+    config = BloomConfig.from_pretrained(converted_model_name_or_path)
+
+    block = WrappedBloomBlock(config)
+    state_dict = _load_state_dict(
+        converted_model_name_or_path,
+        block_index,
+        config,
+        cache_dir=None,
+    )
+
+    if torch_dtype == "auto":
+        with torch.no_grad():
+            for name, param in block.named_parameters():
+                assert name in state_dict, f"{name} not in state dict"
+                param.data = param.data.to(state_dict[name].dtype)
+    else:
+        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        block = block.to(dtype=torch_dtype)
+
+    block.load_state_dict(state_dict, strict=True)
+    return block
+
+
+@pytest.mark.forked
+def test_init_pretrained_block(torch_dtype=torch.float32, atol_forward=1e-8):
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    torch.random.manual_seed(0)
+    inputs = torch.randn(1, 16, config.hidden_size, dtype=torch_dtype)
+
+    block = load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
+    ref_block = _old_load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
+
+    outputs = block.forward(inputs)[0]
+    outputs_ref = ref_block.forward(inputs)[0]
+    assert torch.allclose(outputs, outputs_ref, rtol=0, atol=atol_forward)
diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py
index 261361b..9a619b7 100644
--- a/tests/test_chained_calls.py
+++ b/tests/test_chained_calls.py
@@ -7,12 +7,12 @@
 import hivemind
 import pytest
 import torch
-from test_utils import *
 
 from petals.bloom.from_pretrained import load_pretrained_block
 from petals.client import DistributedBloomConfig
 from petals.client.remote_sequential import RemoteSequential
 from petals.dht_utils import get_remote_sequence
+from test_utils import *
 
 
 @pytest.mark.forked
diff --git a/tests/test_full_model.py b/tests/test_full_model.py
index 1c48c87..cef002e 100644
--- a/tests/test_full_model.py
+++ b/tests/test_full_model.py
@@ -2,11 +2,11 @@ import pytest
 import torch
 import transformers
 from hivemind import get_logger
-from test_utils import *
 from transformers.generation import BeamSearchScorer
 from transformers.models.bloom import BloomForCausalLM
 
 from petals.client.remote_model import DistributedBloomForCausalLM
+from test_utils import *
 
 logger = get_logger(__name__)
diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py
index 7f49a6e..18b41a1 100644
--- a/tests/test_remote_sequential.py
+++ b/tests/test_remote_sequential.py
@@ -1,14 +1,14 @@
 import pytest
 import torch
 import torch.nn.functional as F
-from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler
+from hivemind import DHT, BatchTensorDescriptor, get_logger
 from hivemind.proto import runtime_pb2
-from test_utils import *
 
 from petals.bloom.from_pretrained import load_pretrained_block
 from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
+from test_utils import *
 
 logger = get_logger(__name__)
diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py
index 7c175a8..9185ef1 100644
--- a/tests/test_sequence_manager.py
+++ b/tests/test_sequence_manager.py
@@ -4,11 +4,11 @@ import time
 import pytest
 import torch
 from hivemind import DHT, get_logger
-from test_utils import *
 
 from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
+from test_utils import *
 
 logger = get_logger(__name__)
diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py
index 0f2b3f0..54d6d33 100644
--- a/tests/test_server_stats.py
+++ b/tests/test_server_stats.py
@@ -3,12 +3,12 @@ import time
 import hivemind
 import pytest
 import torch
-from test_utils import *
 
 from petals.client import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
 from petals.dht_utils import get_remote_sequence
 from petals.server.handler import CACHE_TOKENS_AVAILABLE
+from test_utils import *
 
 
 @pytest.mark.forked
diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py
index 9d3ba59..84fcab4 100644
--- a/tests/test_tensor_parallel.py
+++ b/tests/test_tensor_parallel.py
@@ -5,9 +5,9 @@ import torch
 import transformers
 from tensor_parallel import TensorParallel
 from tensor_parallel.slicing_configs import get_bloom_config
-from test_utils import MODEL_NAME
 
 from petals.bloom.from_pretrained import load_pretrained_block
+from test_utils import MODEL_NAME
 
 
 @pytest.mark.forked
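Since test_block_exact_match.py keeps the previous loader around as _old_load_pretrained_block, the claimed speedup can also be timed directly against the old path. A rough timing sketch, assuming the converted test model from test_utils is reachable exactly as in the test suite (absolute numbers vary by machine and disk cache):

import time

import torch

from petals.bloom.from_pretrained import load_pretrained_block
from test_block_exact_match import _old_load_pretrained_block
from test_utils import MODEL_NAME

for name, loader in [("meta init", load_pretrained_block), ("eager init", _old_load_pretrained_block)]:
    start = time.perf_counter()
    loader(MODEL_NAME, 3, torch_dtype=torch.float32)  # block 3, as in the tests
    print(f"{name}: {time.perf_counter() - start:.2f} s")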