Skip to content

Commit 9788a74

Browse files
supriyar authored and facebook-github-bot committed
[quant][bug] Fix histogram observer with 0 input (pytorch#40191)
Summary: Pull Request resolved: pytorch#40191 When the first couple of inputs passed to histogram observer are all 0's subsequent non-zero inputs cause a div by 0 error Test Plan: python test/test_quantization.py TestHistogramObserver.test_histogram_observer_zero_inputs Imported from OSS Differential Revision: D22119422 fbshipit-source-id: 8bbbba914ba7f343121830c768ca0444439f8e03
1 parent 262ad8e commit 9788a74

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

test/quantization/test_workflow_module.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,20 @@ def test_histogram_observer_one_sided(self):
397397
qparams = myobs.calculate_qparams()
398398
self.assertEqual(qparams[1].item(), 0)
399399

400+
def test_histogram_observer_zero_inputs(self):
    """All-zero initial inputs must not poison later histogram updates.

    Regression test for pytorch#40191: feeding tensors of zeros before
    any non-zero data used to trigger a division-by-zero inside
    HistogramObserver when the first real inputs arrived.
    """
    obs = HistogramObserver(
        bins=3,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=False,
    )
    zeros = torch.zeros(4, requires_grad=True)
    first = torch.tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
    second = torch.tensor([5.0, 6.0, 7.0, 8.0])
    # Two all-zero batches first, then genuine data — the buggy code
    # path is hit on the transition from zero to non-zero inputs.
    for batch in (zeros, zeros, first, second):
        obs(batch)
    obs.calculate_qparams()  # must not raise (div-by-zero regression)
    # The zero-only batches should have been discarded: range is [2, 8].
    self.assertEqual(obs.min_val, 2.0)
    self.assertEqual(obs.max_val, 8.0)
    # 8 real values over 3 bins spanning [2, 8]: {2,3} | {4,5,5} | {6,7,8}.
    self.assertEqual(obs.histogram, [2., 3., 3.])
400414
class TestFakeQuantizePerTensor(TestCase):
401415
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
402416
X=hu.tensor(shapes=hu.array_shapes(1, 5,),

torch/quantization/observer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,10 @@ def forward(self, x_orig):
828828
x = x_orig.detach()
829829
min_val = self.min_val
830830
max_val = self.max_val
831-
if min_val.numel() == 0 or max_val.numel() == 0:
831+
prev_zeros = False
832+
if min_val.numel() > 0 and max_val.numel() > 0:
833+
prev_zeros = (min_val.item() == 0) and (max_val.item() == 0)
834+
if min_val.numel() == 0 or max_val.numel() == 0 or prev_zeros:
832835
min_val = torch.min(x)
833836
max_val = torch.max(x)
834837
self.min_val.resize_(min_val.shape)

0 commit comments

Comments (0)