Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
blefaudeux authored Jan 7, 2022
1 parent 3422b41 commit 3000421
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## TBD
### Fixed
- Much faster fused dropout [#164]
- Fused dropout repeatability [#173]

### Added
- Embedding weight tying option [#172]
Expand Down
11 changes: 11 additions & 0 deletions tests/test_triton_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,17 @@ def test_dropout(shape, amp, bias, p):
drop_p = (y.numel() - y.count_nonzero()) / y.numel()
assert abs(drop_p - p) < 0.01

# Check that the same seeds lead to the same dropout
torch.manual_seed(0)
torch.cuda.manual_seed(0)
y_1 = triton_dropout(x, p=0.5)

torch.manual_seed(0)
torch.cuda.manual_seed(0)
y_2 = triton_dropout(x, p=0.5)

assert torch.allclose(y_1, y_2)


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
Expand Down
3 changes: 2 additions & 1 deletion xformers/triton/k_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

@triton.jit
def _get_4_bin_masks(seed, rand_offsets, p):
rand1, rand2, rand3, rand4 = tl.randint4x(seed.to(tl.int32), rand_offsets)
seed = tl.load(seed)
rand1, rand2, rand3, rand4 = tl.randint4x(seed, rand_offsets)

# binarize masks, save registers
# NOTE: We keep the random numbers as is there (integers over int32),
Expand Down

0 comments on commit 3000421

Please sign in to comment.