Commit

Remove .clone() calls in profiler

ghstack-source-id: fb9e0c0815328c2feacdbf4dec8164d52e0e951e
Pull Request resolved: https://github.com/fairinternal/xformers/pull/485

__original_commit__ = fairinternal/xformers@6f3f928d8c709d05d6dd353c4a42128b43744719
danthe3rd authored and xFormers Bot committed Mar 10, 2023
1 parent f831656 commit ec1e933
Showing 1 changed file with 0 additions and 7 deletions.
xformers/profiler/profiler.py (0 additions, 7 deletions)
@@ -17,7 +17,6 @@
 import torch.nn as nn
 import torch.profiler
 import torch.utils.hooks
-from torch.utils._pytree import tree_map
 
 logger = logging.getLogger(__name__)

@@ -267,9 +266,6 @@ def _enter_module_hook(self, name):
         class PopState(torch.autograd.Function):
             @staticmethod
             def forward(ctx, *args):
-                args = tree_map(
-                    lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args
-                )
                 if len(args) == 1:
                     return args[0]
                 return args
@@ -291,9 +287,6 @@ def _exit_module_hook(self, name):
         class PushState(torch.autograd.Function):
             @staticmethod
             def forward(ctx, *args):
-                args = tree_map(
-                    lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args
-                )
                 if len(args) == 1:
                     return args[0]
                 return args
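For context, a minimal, self-contained sketch (derived from the hunks above, not part of the commit itself) of what the affected forward methods reduce to once the cloning is removed; anything outside the hunks shown above is untouched by this commit:

import torch

class PopState(torch.autograd.Function):
    # PushState.forward is simplified in exactly the same way.
    @staticmethod
    def forward(ctx, *args):
        # The deleted tree_map call used to clone every torch.Tensor in args;
        # the arguments are now passed through unmodified.
        if len(args) == 1:
            return args[0]
        return args

With both tree_map call sites gone, the from torch.utils._pytree import tree_map import at the top of the file becomes unused, which accounts for the seventh deleted line.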
