Bug fix for illegal memory access error caused when running medusa lora and plain loras in parallel. (predibase#525)
ajtejankar authored Jun 26, 2024
1 parent 3247ef6 commit f3a67bb
Showing 2 changed files with 80 additions and 7 deletions.
22 changes: 16 additions & 6 deletions server/lorax_server/adapters/lora.py
@@ -301,8 +301,6 @@ def load(
             idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights
         }
 
-        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
-
         rank_indices = defaultdict(list)
         for segment_idx, adapter_idx in enumerate(segment_indices):
             if adapter_idx not in adapter_weights:
@@ -338,10 +336,22 @@ def load(
                         segment_starts[i] = prefill_head_segment_starts[segment_index]
                         segment_ends[i] = prefill_head_segment_ends[segment_index]
             else:
-                rank_indices = set(indices)
-                batch_indices = [adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()]
-                batch_indices = [idx if idx in rank_indices else -1 for idx in batch_indices]
-                batch_indices = torch.tensor(batch_indices, dtype=torch.int64, device=device)
+                # `indices` indexes into `segment_indices`, which holds the per-segment adapter index.
+                # `lora_a_ptr` holds per-segment pointers to the LoRA weights,
+                # so `lora_a_ptr` and `segment_indices` must have the same length.
+                # `indices` is used to slice the `lora_a_ptr` tensor.
+                # First, map each adapter index to its location in the `indices` array.
+                idx_locs = {}
+                for loc, idx in enumerate(indices):
+                    # Use `idx` to look up this segment's adapter index.
+                    if segment_indices[idx] not in idx_locs:
+                        # Record the first location at which this adapter index appears.
+                        idx_locs[segment_indices[idx]] = loc
+                # Second, for each token's adapter index, look up its location in the `indices` array.
+                batch_indices = torch.tensor([
+                    idx_locs[idx] if idx in adapter_weights and adapter_weights[idx].lora_a_r == rank else -1
+                    for idx in meta.adapter_indices.tolist()
+                ], dtype=torch.int64, device=device)
 
             rank_data[rank] = RankSegments(
                 rank=rank,
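For readers following the fix, here is a minimal standalone sketch (not part of the commit) of the decode-path mapping introduced above. The values of `segment_indices`, the per-token adapter indices, and the adapter-to-rank dict are toy assumptions, and `adapter_rank` stands in for `adapter_weights[idx].lora_a_r`; the sketch only illustrates how `idx_locs` and `batch_indices` are derived per rank.

# Standalone sketch (not from the repo): toy reproduction of the decode-path
# index mapping added above. All values below are illustrative assumptions.
from collections import defaultdict

# Per-segment adapter ids (what segment_indices holds) and per-token adapter ids
# (what meta.adapter_indices.tolist() would return).
segment_indices = [0, 1, 0, 1, 2]
token_adapter_indices = [0, 0, 1, 1, 0, 0, 1, 1, 2, 2]
# Adapter id -> LoRA rank, standing in for adapter_weights[idx].lora_a_r.
adapter_rank = {0: 8, 1: 8, 2: 16}

# Group segment positions by rank, mirroring the rank_indices loop in load().
rank_indices = defaultdict(list)
for segment_pos, adapter_idx in enumerate(segment_indices):
    rank_indices[adapter_rank[adapter_idx]].append(segment_pos)

for rank, indices in rank_indices.items():
    # First pass: first location of each adapter id within this rank's `indices`.
    idx_locs = {}
    for loc, idx in enumerate(indices):
        if segment_indices[idx] not in idx_locs:
            idx_locs[segment_indices[idx]] = loc
    # Second pass: per-token index into the rank's sliced lora_a_ptr, -1 for other ranks.
    batch_indices = [
        idx_locs[idx] if adapter_rank.get(idx) == rank else -1
        for idx in token_adapter_indices
    ]
    print(rank, batch_indices)
# 8  [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]
# 16 [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]

With these toy inputs the output matches the expected per-token indices asserted for the first parametrized case in the new test below.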
65 changes: 64 additions & 1 deletion server/tests/utils/test_lora.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Type
+from typing import Dict, List, Optional, Tuple, Type
 from unittest import mock
 
 import pytest
@@ -102,6 +102,69 @@ def test_batched_lora_weights(lora_ranks: List[int]):
         assert rd.segment_ends.shape == (2,)
 
 
+
+@pytest.mark.parametrize(
+    "lora_ranks,adapter_indices,expected",
+    [
+        (
+            [8, 8, 16],
+            [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],
+            {
+                8: (4, [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]),
+                16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0])
+            }
+        ),
+        (
+            [4, 8, 16],
+            [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],
+            {
+                4: (2, [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]),
+                8: (2, [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]),
+                16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]),
+            }
+        ),
+    ],
+)
+def test_batched_lora_weights_decode(
+    lora_ranks: List[int],
+    adapter_indices: List[int],
+    expected: Dict[int, Tuple[int, List[int]]]
+):
+    from lorax_server.utils.segments import find_segments
+    batched_weights = LayerAdapterWeights()
+    assert batched_weights.is_empty()
+
+    h = 1024
+    for idx, lora_rank in enumerate(lora_ranks):
+        weights = LoraWeights(
+            weights_a=[torch.randn((h, lora_rank), dtype=torch.float16)],
+            weights_b=[torch.randn((lora_rank, h), dtype=torch.float16)],
+            adapter_config=LoraConfig(r=lora_rank),
+        )
+        batched_weights.add_adapter(idx, weights)
+
+    segments, segment_indices = find_segments(adapter_indices)
+
+    meta = AdapterBatchMetadata(
+        adapter_indices=torch.tensor(adapter_indices, dtype=torch.int64),
+        adapter_set=set(adapter_indices),
+        adapter_segments=torch.tensor(segments, dtype=torch.int64),
+        segment_indices=segment_indices,
+    )
+
+    with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
+        data = batched_weights.get_data(meta, prefill=False, prefill_head_indices=None).get(LORA)
+
+    for lora_rank, rd in data.rank_data.items():
+        expected_indices = torch.tensor(expected[lora_rank][1], dtype=rd.indices.dtype, device=rd.indices.device)
+        assert rd.lora_a_ptr.shape == (expected[lora_rank][0],)
+        assert rd.lora_b_ptr.shape == (expected[lora_rank][0],)
+        assert all(rd.indices == expected_indices)
+        assert rd.segment_starts == None
+        assert rd.segment_ends == None
+        assert rd.tmp_shrink == None
+        assert rd.tmp_expand == None
+
 def test_batched_lora_weights_no_segments():
     batched_weights = LayerAdapterWeights()
     assert batched_weights.is_empty()
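As a cross-check on the expected tuples above, here is a short sketch of why each rank's `lora_a_ptr` length equals the first element of the tuple, assuming `find_segments` collapses consecutive runs of equal adapter ids into segments (as the first parametrized case implies):

# Sketch (assumption): find_segments collapses consecutive equal adapter ids into segments.
from collections import Counter

adapter_indices = [0, 0, 1, 1, 0, 0, 1, 1, 2, 2]
segment_indices = [0, 1, 0, 1, 2]  # one adapter id per segment
lora_ranks = {0: 4, 1: 8, 2: 16}   # second parametrized case

# Number of segments that land in each rank bucket == expected lora_a_ptr length.
print(Counter(lora_ranks[i] for i in segment_indices))
# Counter({4: 2, 8: 2, 16: 1}) -> shapes (2,), (2,), (1,)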
