forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_inputs.py
97 lines (75 loc) · 3.17 KB
/
test_inputs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# SPDX-License-Identifier: Apache-2.0
import torch
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
def assert_nested_tensors_equal(expected: NestedTensors,
                                actual: NestedTensors):
    """Assert that two nested tensor structures match exactly.

    Both arguments must have the same concrete type. Tensors are compared
    with ``torch.equal``; any other container is compared item by item,
    recursively, and must hold the same number of items.

    Args:
        expected: Reference tensor or (possibly nested) sequence of tensors.
        actual: Structure to compare against ``expected``.

    Raises:
        AssertionError: If the types, lengths, or tensor contents differ.
    """
    assert type(expected) == type(actual)  # noqa: E721
    if isinstance(expected, torch.Tensor):
        assert torch.equal(expected, actual)
        return

    # zip() stops at the shorter input, so without this check a missing or
    # extra trailing item would silently go unnoticed.
    assert len(expected) == len(actual)
    for expected_item, actual_item in zip(expected, actual):
        assert_nested_tensors_equal(expected_item, actual_item)
def assert_multimodal_inputs_equal(expected: MultiModalKwargs,
                                   actual: MultiModalKwargs):
    """Assert that two multimodal keyword mappings are equal.

    The two mappings must expose the same set of keys, and the values
    stored under each key must compare equal via
    ``assert_nested_tensors_equal``.
    """
    expected_keys = set(expected.keys())
    actual_keys = set(actual.keys())
    assert expected_keys == actual_keys

    # Walk the keys of ``expected`` in its own iteration order.
    for name in expected:
        expected_value = expected[name]
        actual_value = actual[name]
        assert_nested_tensors_equal(expected_value, actual_value)
def test_multimodal_input_batch_single_tensor():
    """Batching one item is expected to add a leading dimension of size 1."""
    image = torch.rand([1, 2])

    batched = MultiModalKwargs.batch([{"image": image}])

    reference = {"image": image.unsqueeze(0)}
    assert_multimodal_inputs_equal(batched, reference)
def test_multimodal_input_batch_multiple_tensors():
    """Same-shaped tensors are expected to be stacked into one tensor."""
    images = [torch.rand([1, 1, 2]) for _ in range(3)]

    batched = MultiModalKwargs.batch([{"image": image} for image in images])

    reference = {"image": torch.stack(images)}
    assert_multimodal_inputs_equal(batched, reference)
def test_multimodal_input_batch_multiple_heterogeneous_tensors():
    """Tensors with differing shapes are expected to stay a plain list."""
    # The middle dimension differs per item (2, 3, 4).
    images = [torch.rand([1, rows, 2]) for rows in (2, 3, 4)]

    batched = MultiModalKwargs.batch([{"image": image} for image in images])

    reference = {"image": images}
    assert_multimodal_inputs_equal(batched, reference)
def test_multimodal_input_batch_nested_tensors():
    """Single-element lists of same-shaped tensors are expected to stack."""
    images = [torch.rand([2, 3]) for _ in range(3)]

    # Each item wraps its tensor in a one-element list.
    batched = MultiModalKwargs.batch([{"image": [image]} for image in images])

    stacked = torch.stack([image.unsqueeze(0) for image in images])
    assert_multimodal_inputs_equal(batched, {"image": stacked})
def test_multimodal_input_batch_heterogeneous_lists():
    """Lists of different lengths are expected to be stacked per item only."""
    first, second, third = (torch.rand([1, 2, 3]) for _ in range(3))

    items = [{"image": [first, second]}, {"image": [third]}]
    batched = MultiModalKwargs.batch(items)

    # Each inner list is stacked on its own; the outer level stays a list.
    reference = {
        "image": [torch.stack([first, second]),
                  third.unsqueeze(0)],
    }
    assert_multimodal_inputs_equal(batched, reference)
def test_multimodal_input_batch_multiple_batchable_lists():
    """Equal-length lists of same-shaped tensors are expected to fully stack."""
    first, second, third, fourth = (torch.rand([1, 2, 3]) for _ in range(4))

    items = [{"image": [first, second]}, {"image": [third, fourth]}]
    batched = MultiModalKwargs.batch(items)

    # Inner lists stack first, then the two results stack together.
    left = torch.stack([first, second])
    right = torch.stack([third, fourth])
    reference = {"image": torch.stack([left, right])}
    assert_multimodal_inputs_equal(batched, reference)
def test_multimodal_input_batch_mixed_stacking_depths():
    """Items are expected to stack independently when shapes differ."""
    # The middle dimension differs per tensor (2, 3, 4).
    small, medium, large = (torch.rand([1, rows, 3]) for rows in (2, 3, 4))

    # First item holds two unstackable tensors, second holds a single one.
    batched = MultiModalKwargs.batch([
        {"image": [small, medium]},
        {"image": [large]},
    ])
    reference = {"image": [[small, medium], large.unsqueeze(0)]}
    assert_multimodal_inputs_equal(batched, reference)

    # Mirror case: first item holds one tensor, second holds two.
    batched = MultiModalKwargs.batch([
        {"image": [small]},
        {"image": [medium, large]},
    ])
    reference = {"image": [small.unsqueeze(0), [medium, large]]}
    assert_multimodal_inputs_equal(batched, reference)