While operator test generates condition input as a parameter instead of a constant #7986
Comments
Correct. I'm wondering, do you have a specific motivation to teach XLA the trip count? Even if one input to the
Could you share how you produced the HLO in the second section? Thanks. |
I want the ability to unroll the while loop, which requires XLA to know the trip count. I generated the constant using xla_builder and passed that as the input to |
Do you want to debug the code/HLO when running the following?

```python
def cond_fn(iteri, x):
    return iteri > 0

def body_fn(iteri, x):
    return iteri - 1, torch.add(x, 1)

init_val = torch.tensor(3, dtype=torch.int32, device=device)
iteri = torch.tensor(10, device=device)
_, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
```

If that is the case, I would suggest using
If you want the loop to run only once with

```python
def cond_fn(iteri, x):
    # if iteri == 5:
    #     print("cond_fn: iteri is 5 now !!!")
    # print("iteri: ", iteri)
    return iteri > 5

def body_fn(iteri, x):
    # if iteri == 5:
    #     print("body_fn: iteri is 5 now !!!")
    # print("iteri: ", iteri)
    return iteri - 1, torch.add(x, 1)

init_val = torch.tensor(5, dtype=torch.int32, device=device)
iteri = torch.tensor(5, device=device)
_, res_with_loop = _xla_while_loop_wrapper(cond_fn, body_fn, (iteri, init_val), additional_inputs=())
```

the above test case will only run once with |
For

```python
def test_while_loop_addition(self):
    device = xm.xla_device()

    def cond_fn(iteri, x):
        return iteri > 0

    def body_fn(iteri, x):
        return iteri - 1, torch.add(x, 1)

    init_val = torch.tensor(3, dtype=torch.int32, device=device)
    iteri = torch.tensor(10, device=device)
    _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
    _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))
```

the cond input has … due to …, or you want to try
Yes, there are some constraints for while_loop (XLA::While), such as:
These constraints are limited due to … for …
|
So, for the while-loop unroller pass in OpenXLA, the condition input to the while loop needs to be a constant for the unroller pass to determine the trip count. Can you expand a bit more on the MNIST model example? Are we planning to do gradient accumulation using this while loop? |
@ManfeiBai, how is this handled for 0-dimensional tensors with values 0 or 1? These tensors are, by default, treated as constants: Line 113 in e3cf356
As it stands, this ends up sinking the constants into the body and condition, where they could be optimized away, bypassing the intended loop semantics. I can see this likely not happening when using iterator tensor values > 1 (as we have in the test), since those are not treated as constants, but I'm wondering how this is handled at the moment. Consider the following example:
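A minimal sketch along these lines, assuming a 0-dimensional iterator tensor with value 1 (the cond/body shape and import paths follow the tests above and are assumptions, not an exact reproduction):

```python
import torch
import torch_xla.core.xla_model as xm
from torch._higher_order_ops.while_loop import while_loop

device = xm.xla_device()

def cond_fn(iteri, x):
    return iteri > 0

def body_fn(iteri, x):
    return iteri - 1, torch.add(x, 1)

init_val = torch.tensor(3, dtype=torch.int32, device=device)
# 0-dim iterator with value 1: per the code reference above, values 0/1 are
# treated as constants and may be inlined into the while computations.
iteri = torch.tensor(1, device=device)
_, res = while_loop(cond_fn, body_fn, (iteri, init_val))
```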
The body computation will look as follows:
and the condition as follows:
The final HLO computation is then:
In which case wouldn't the condition:
be problematic here, and potentially optimized away to true/false? Naturally, I see this happen with OpenXLA's
which after a generic call inliner pass becomes (only including the condition and the entry snippet):
and once encountering a constant folding pass, we get the following (only including the condition snippet):
Generally, these two passes are quite generic and usually at the start of a pass pipeline. Is it different in the other backends that you have tested with? Note that, on the other hand, when using
In this case, it succeeds, so it would help to also shed light on the difference and the need to introduce this different variant. |
@tengyifei do you have time to take a look at the question above? |
It seems that every test deliberately used an iterator tensor value >= 2. I am investigating how to fix this, but it would be helpful to get/discuss insights on how we want to sort this out. This is clearly an existing bug. Simple repro test change:
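A sketch of what such a repro tweak could be, assuming it simply lowers the iterator value of `test_while_loop_addition` into the 0/1 range that gets treated as a constant:

```python
# In test_while_loop_addition, change the iterator value from 10 to 1 so it
# falls into the range that is inlined as a constant rather than passed as a
# parameter.
iteri = torch.tensor(1, device=device)
```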
Error:
Note that, interestingly, the mismatch (the inlined constant missing from the expected parameter signature) shows up only in the body computation's HLO (see below). As mentioned above, there is also the issue of the inlined constant being optimized away and bypassing while-loop semantics in both computations.
|
When running the while operator test: https://github.com/pytorch/xla/blob/master/test/test_while_loop.py#L28
I see an HLO that looks like the following (this is the un-optimized graph):
In this case the input tuple has a parameter, so the XLA compiler won't know the trip count unless it evaluates the parameter during compilation. Shouldn't this be a constant?
I have tried other ways to make the condition input a constant; however, that input gets optimized away and I end up with an HLO like the following:
cond_hlo:
Question: