
While operator test generates condition input as a parameter instead of a constant #7986

Open · aws-rhsoln opened this issue Sep 10, 2024 · 9 comments

@aws-rhsoln (Contributor)

When running the while operator test: https://github.com/pytorch/xla/blob/master/test/test_while_loop.py#L28

I see an HLO that looks like the following (this is the unoptimized graph):

while_loop.40 {
  p0.41 = s64[] parameter(0)
  p1.42 = s32[] parameter(1)
  tuple.43 = (s64[], s32[]) tuple(p0.41, p1.42)
  ROOT while.44 = (s64[], s32[]) while(tuple.43), condition=PyLoweringContext.5.35, body=PyLoweringContext.12.25
}

ENTRY SyncTensorsGraph.49 {
  p0.3 = s64[] parameter(0), sharding={replicated}
  constant.2 = s64[] constant(1)
  constant.1 = s64[] constant(1)
  multiply.4 = s64[] multiply(constant.2, constant.1)
  subtract.5 = s64[] subtract(p0.3, multiply.4)
  p1.8 = s32[] parameter(1), sharding={replicated}
  constant.7 = s32[] constant(1)
  constant.6 = s32[] constant(1)
  multiply.9 = s32[] multiply(constant.7, constant.6)
  add.10 = s32[] add(p1.8, multiply.9)
  constant.11 = s64[] constant(0)
  compare.12 = pred[] compare(p0.3, constant.11), direction=GT
  call.45 = (s64[], s32[]) call(p0.3, p1.8), to_apply=while_loop.40
  get-tuple-element.46 = s64[] get-tuple-element(call.45), index=0
  get-tuple-element.47 = s32[] get-tuple-element(call.45), index=1
  ROOT tuple.48 = (s64[], s32[], pred[], s64[], s32[]) tuple(subtract.5, add.10, compare.12, get-tuple-element.46, get-tuple-element.47)
} // SyncTensorsGraph.49

In this case the input tuple contains a parameter, so the XLA compiler won't know the trip count unless it evaluates the parameter during compilation. Shouldn't this be a constant?
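
For reference, a minimal sketch of how I dump this unoptimized HLO (imports follow test_while_loop.py; torch_xla._XLAC._get_xla_tensors_hlo is an internal debug helper, so treat the exact call as an assumption):

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.fori_loop  # registers the XLA lowering for while_loop
from torch._higher_order_ops.while_loop import while_loop

device = xm.xla_device()

def cond_fn(iteri, x):
  return iteri > 0

def body_fn(iteri, x):
  return iteri - 1, torch.add(x, 1)

init_val = torch.tensor(3, dtype=torch.int32, device=device)
iteri = torch.tensor(10, device=device)
_, res = while_loop(cond_fn, body_fn, (iteri, init_val))

# Print the pending (unoptimized) HLO for the result tensor.
print(torch_xla._XLAC._get_xla_tensors_hlo([res]))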

I have tried other ways to make the condition input a constant; however, that input gets optimized away, and I end up with an HLO like the following:

cond_hlo:

%PyLoweringContext.10 (p0.11: f32[2,2], UnusedArgumentsPlaceholder.15: f32[2,2]) -> pred[] {
  %constant.13 = s64[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="/home/ubuntu/while_loop.py" source_line=80}
  %p0.11 = f32[2,2]{1,0} parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/home/ubuntu/while_loop.py" source_line=80}
  %call.12 = s64[] call(f32[2,2]{1,0} %p0.11), to_apply=%return_constReturnConst.7, metadata={op_type="xla___op_return_constReturnConst" op_name="xla___op_return_constReturnConst" source_file="/home/ubuntu/pt24/lib/python3.10/site-packages/torch_xla/core/xla_op_registry.py" source_line=44}
  ROOT %compare.14 = pred[] compare(s64[] %constant.13, s64[] %call.12), direction=LT, metadata={op_type="aten__lt" op_name="aten__lt" source_file="/home/ubuntu/while_loop.py" source_line=75}
  %UnusedArgumentsPlaceholder.15 = f32[2,2]{1,0} parameter(1)
}

ENTRY %PyLoweringContext.12.17 (in.1: (f32[2,2], f32[2,2])) -> pred[] {
  %in.1 = (f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
  %get-tuple-element.2 = f32[2,2]{1,0} get-tuple-element((f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=0
  %get-tuple-element.3 = f32[2,2]{1,0} get-tuple-element((f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=1
  ROOT %call.16 = pred[] call(f32[2,2]{1,0} %get-tuple-element.2, f32[2,2]{1,0} %get-tuple-element.3), to_apply=%PyLoweringContext.10
}

Questions:

  1. Should the condition input be a constant?
  2. I see some constraints for the while op; is there a plan to make it work for different input values?
@JackCaoG (Collaborator)

@ManfeiBai @tengyifei

@tengyifei (Collaborator)

This way the xla compiler won't know the trip count unless it evaluates the parameter during compilation.

Correct. I'm wondering whether you have a specific motivation for teaching XLA the trip count. Even if one input to the While op is a constant, the cond computation still has to compare it with another constant to determine whether to break the loop.

I have tried other ways to make the condition input a constant, however, that input gets optimized away

Could you share how you produced the HLO in the second section? Thanks.

@aws-rhsoln (Contributor, Author)

I want the ability to unroll the while loop, which requires XLA to know the trip count.

I generated the constant using xla_builder and passed it as input to _xla_while_loop_wrapper(cond_fn, body_fn, (iteri, init_val), ()). Here iteri is a constant.
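
For concreteness, a minimal sketch of that setup (the xla_builder constant-creation step is elided, since I did not paste the exact code; _xla_while_loop_wrapper lives in torch_xla/experimental/fori_loop.py per the HLO metadata later in this thread):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.fori_loop import _xla_while_loop_wrapper

device = xm.xla_device()

def cond_fn(iteri, x):
  return iteri > 0

def body_fn(iteri, x):
  return iteri - 1, torch.add(x, 1)

# iteri is meant to be a compile-time constant (built via xla_builder in my
# experiment); it is shown here as a plain device scalar for brevity.
iteri = torch.tensor(10, device=device)
init_val = torch.tensor(3, dtype=torch.int32, device=device)
_, res = _xla_while_loop_wrapper(cond_fn, body_fn, (iteri, init_val), ())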

@ManfeiBai (Collaborator)

I want the ability to unroll the while loop, which requires XLA to know the trip count.

I generated the constant using xla_builder and passed it as input to _xla_while_loop_wrapper(cond_fn, body_fn, (iteri, init_val), ()). Here iteri is a constant.

Do you want to debug code/HLO when iteri gets to a specific value? For example:

    def cond_fn(iteri, x):
      return iteri > 0

    def body_fn(iteri, x):
      return iteri - 1, torch.add(x, 1)

    init_val = torch.tensor(3, dtype=torch.int32, device=device)
    iteri = torch.tensor(10, device=device)
    _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))

You might want to debug when iteri==5?

If that is the case, I would suggest using _xla_while_loop_wrapper directly, like the test case added in #7993, and doing some debugging in the code. As for the HLO, the HLO for each iteri/trip should be the same once they are wrapped into while_loop, so the HLO we got before would be the HLO used on each iteri/trip.

_xla_while_loop_wrapper skips some constraints, and we can catch a specific iteri value during the loop.

If you want the loop to run only once with iteri==5, you could limit iteri and init_val to 5, like:

    def cond_fn(iteri, x):
      # if iteri==5:
      #   print("cond_fn: iteri is 5 now !!!")
      # print("iteri: ", iteri)
      return iteri > 5

    def body_fn(iteri, x):
      # if iteri==5:
      #   print("body_fn: iteri is 5 now !!!")
      # print("iteri: ", iteri)
      return iteri - 1, torch.add(x, 1)

    init_val = torch.tensor(5, dtype=torch.int32, device=device)
    iteri = torch.tensor(5, device=device)
    _, res_with_loop = _xla_while_loop_wrapper(cond_fn, body_fn, (iteri, init_val), additional_inputs=())

The test case above will run only once, with iteri==5.

@ManfeiBai (Collaborator)

Question:

  1. Should the condition input be a constant?

For the while operator test (https://github.com/pytorch/xla/blob/master/test/test_while_loop.py#L28):

def test_while_loop_addition(self):
    device = xm.xla_device()

    def cond_fn(iteri, x):
      return iteri > 0

    def body_fn(iteri, x):
      return iteri - 1, torch.add(x, 1)

    init_val = torch.tensor(3, dtype=torch.int32, device=device)
    iteri = torch.tensor(10, device=device)
    _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
    _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))

The cond input has iteri and x: iteri is the iterator used as the trip count; it gets initialized and changes on each trip. x is the carry value updated on each trip.

Since iteri and x need to change on each trip, could you give more context on the requirement that the condition input be a constant? Any use cases?

Or would you like to try _xla_while_loop_wrapper as in the comment above?

  2. I see some constraints for the while op; is there a plan to make it work for different input values?

Yes, there are some constraints for while_loop (XLA::While), such as:

  • cond_fn's inputs, body_fn's inputs and returns, and while_loop's inputs must all have the same size and shape
  • cond_fn must return a bool

These constraints come from limitations of the XLA::While op when running on TPU.

As for making it work for different input values, do we have a use/test case as an example?

  • If different input values means different values of iteri and x, we plan to add more use cases, such as an MNIST model and more complex models, in the future; it is still WIP.
  • If different input values means using a different iteri and x across cond_fn's args, body_fn's args and returns, and while_loop's inputs, that would break the same-size-and-shape restriction mentioned above; it would violate XLA::While's prerequisites, so that would be an XLA op question (see the sketch below).
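
As an illustration of the same-size-and-shape constraint, here is a hypothetical sketch; the mismatched bad_body_fn is expected to fail with an INVALID_ARGUMENT error like the one quoted later in this thread:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.fori_loop  # registers the XLA lowering for while_loop
from torch._higher_order_ops.while_loop import while_loop

device = xm.xla_device()

def cond_fn(iteri, x):
  return iteri > 0  # must return a bool (a pred[] scalar)

def good_body_fn(iteri, x):
  # OK: returns the same structure and shapes as the inputs.
  return iteri - 1, torch.add(x, 1)

def bad_body_fn(iteri, x):
  # Not OK: drops the iterator, so the body's return no longer matches
  # cond_fn's inputs and while_loop's init, violating XLA::While's requirement.
  return torch.add(x, 1)

iteri = torch.tensor(10, dtype=torch.int32, device=device)
init_val = torch.tensor(3, dtype=torch.int32, device=device)

_, res = while_loop(cond_fn, good_body_fn, (iteri, init_val))  # works
# while_loop(cond_fn, bad_body_fn, (iteri, init_val))          # expected to fail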

@aws-rhsoln (Contributor, Author)

So for the while-loop unroller pass in OpenXLA, the condition input to the while loop needs to be a constant for the unroller to determine the trip count.

Can you expand a bit more on the MNIST model example? Are we planning to do gradient accumulation using this while loop?

@rpsilva-aws (Collaborator) commented Oct 7, 2024

@ManfeiBai, how is this handled for 0-dimensional tensors with values 0 or 1? These tensors are, by default, treated as constants (see XLA_NO_SPECIAL_SCALARS).

As it stands, the constants end up sunk into the body and condition, where they can be optimized away, bypassing the intended loop semantics. I can see this likely not occurring for iterator tensor values > 1 (as we have in the test), since those are not treated as constants, but I am wondering how this is handled at the moment.

Consider the following example:

def cond_fn(iteri_, x):
  return iteri_ < 5
 
def body_fn(iteri_, x):
  return iteri_ + 1, torch.add(x, 1)

init_val = torch.tensor(1).to(device)
iteri = torch.tensor(0, dtype=torch.int64).to(device)
_, res = _xla_while_loop_wrapper(cond_fn, body_fn, (iteri, init_val), ())
xm.mark_step()

The body computation will look as follows:

%PyLoweringContext.4 (UnusedArgumentsPlaceholder.15: s64[], UnusedArgumentsPlaceholder.16: s64[]) -> (s64[], s64[]) {
  %UnusedArgumentsPlaceholder.15 = s64[] parameter(0)
  %UnusedArgumentsPlaceholder.16 = s64[] parameter(1)
  %constant.7 = s64[] constant(0)
  %constant.6 = s64[] constant(1)
  %constant.5 = s64[] constant(1)
  %multiply.8 = s64[] multiply(s64[] %constant.6, s64[] %constant.5)
  %add.9 = s64[] add(s64[] %constant.7, s64[] %multiply.8)
  %constant.12 = s64[] constant(1)
  %constant.11 = s64[] constant(1)
  %constant.10 = s64[] constant(1)
  %multiply.13 = s64[] multiply(s64[] %constant.11, s64[] %constant.10)
  %add.14 = s64[] add(s64[] %constant.12, s64[] %multiply.13)
  ROOT %tuple.17 = (s64[], s64[]) tuple(s64[] %add.9, s64[] %add.14)
}

ENTRY %PyLoweringContext.14.19 (in.1: (s64[], s64[])) -> (s64[], s64[]) {
  %in.1 = (s64[], s64[]) parameter(0)
  %get-tuple-element.2 = s64[] get-tuple-element((s64[], s64[]) %in.1), index=0
  %get-tuple-element.3 = s64[] get-tuple-element((s64[], s64[]) %in.1), index=1
  ROOT %call.18 = (s64[], s64[]) call(s64[] %get-tuple-element.2, s64[] %get-tuple-element.3), to_apply=%PyLoweringContext.4
}

and the condition as follows:

%PyLoweringContext.4 (UnusedArgumentsPlaceholder.8: s64[], UnusedArgumentsPlaceholder.9: s64[]) -> pred[] {
  %constant.6 = s64[] constant(0)
  %constant.5 = s64[] constant(5)
  ROOT %compare.7 = pred[] compare(s64[] %constant.6, s64[] %constant.5), direction=LT
  %UnusedArgumentsPlaceholder.8 = s64[] parameter(0)
  %UnusedArgumentsPlaceholder.9 = s64[] parameter(1)
}

ENTRY %PyLoweringContext.6.11 (in.1: (s64[], s64[])) -> pred[] {
  %in.1 = (s64[], s64[]) parameter(0)
  %get-tuple-element.2 = s64[] get-tuple-element((s64[], s64[]) %in.1), index=0
  %get-tuple-element.3 = s64[] get-tuple-element((s64[], s64[]) %in.1), index=1
  ROOT %call.10 = pred[] call(s64[] %get-tuple-element.2, s64[] %get-tuple-element.3), to_apply=%PyLoweringContext.4
}

The final HLO computation is then:

%PyLoweringContext.4 (UnusedArgumentsPlaceholder.15: s64[], UnusedArgumentsPlaceholder.16: s64[]) -> (s64[], s64[]) {
  %UnusedArgumentsPlaceholder.15 = s64[] parameter(0)
  %UnusedArgumentsPlaceholder.16 = s64[] parameter(1)
  %constant.7 = s64[] constant(0)
  %constant.6 = s64[] constant(1)
  %constant.5 = s64[] constant(1)
  %multiply.8 = s64[] multiply(s64[] %constant.6, s64[] %constant.5)
  %add.9 = s64[] add(s64[] %constant.7, s64[] %multiply.8)
  %constant.12 = s64[] constant(1)
  %constant.11 = s64[] constant(1)
  %constant.10 = s64[] constant(1)
  %multiply.13 = s64[] multiply(s64[] %constant.11, s64[] %constant.10)
  %add.14 = s64[] add(s64[] %constant.12, s64[] %multiply.13)
  ROOT %tuple.17 = (s64[], s64[]) tuple(s64[] %add.9, s64[] %add.14)
}

%PyLoweringContext.14.18 (in.19: (s64[], s64[])) -> (s64[], s64[]) {
  %in.19 = (s64[], s64[]) parameter(0)
  %get-tuple-element.20 = s64[] get-tuple-element((s64[], s64[]) %in.19), index=0
  %get-tuple-element.21 = s64[] get-tuple-element((s64[], s64[]) %in.19), index=1
  ROOT %call.22 = (s64[], s64[]) call(s64[] %get-tuple-element.20, s64[] %get-tuple-element.21), to_apply=%PyLoweringContext.4
}

%PyLoweringContext.23 (UnusedArgumentsPlaceholder.27: s64[], UnusedArgumentsPlaceholder.28: s64[]) -> pred[] {
  %constant.25 = s64[] constant(0)
  %constant.24 = s64[] constant(5)
  ROOT %compare.26 = pred[] compare(s64[] %constant.25, s64[] %constant.24), direction=LT
  %UnusedArgumentsPlaceholder.27 = s64[] parameter(0)
  %UnusedArgumentsPlaceholder.28 = s64[] parameter(1)
}

%PyLoweringContext.6.29 (in.30: (s64[], s64[])) -> pred[] {
  %in.30 = (s64[], s64[]) parameter(0)
  %get-tuple-element.31 = s64[] get-tuple-element((s64[], s64[]) %in.30), index=0
  %get-tuple-element.32 = s64[] get-tuple-element((s64[], s64[]) %in.30), index=1
  ROOT %call.33 = pred[] call(s64[] %get-tuple-element.31, s64[] %get-tuple-element.32), to_apply=%PyLoweringContext.23
}

ENTRY %while_loop.35 (p0.1: s64[], p1.2: s64[]) -> (s64[], s64[]) {
  %p0.1 = s64[] parameter(0)
  %p1.2 = s64[] parameter(1)
  %tuple.3 = (s64[], s64[]) tuple(s64[] %p0.1, s64[] %p1.2)
  ROOT %while.34 = (s64[], s64[]) while((s64[], s64[]) %tuple.3), condition=%PyLoweringContext.6.29, body=%PyLoweringContext.14.18
}

In that case, wouldn't the condition:

  %constant.25 = s64[] constant(0)
  %constant.24 = s64[] constant(5)
  ROOT %compare.26 = pred[] compare(s64[] %constant.25, s64[] %constant.24), direction=LT

be problematic here, and potentially optimized away to true/false?

Naturally, I see this happen with OpenXLA's hlo_parser when reading the HLO text module received from torch_xla (only the condition snippet is included):

PyLoweringContext.22 {
  constant.24 = s64[] constant(0)
  constant.23 = s64[] constant(5)
  ROOT compare.25 = pred[] compare(constant.24, constant.23), direction=LT
  UnusedArgumentsPlaceholder.26 = s64[] parameter(0)
  UnusedArgumentsPlaceholder.27 = s64[] parameter(1)
}

PyLoweringContext.6.28 {
  in.29 = (s64[], s64[]) parameter(0)
  get-tuple-element.30 = s64[] get-tuple-element(in.29), index=0
  get-tuple-element.31 = s64[] get-tuple-element(in.29), index=1
  ROOT call.32 = pred[] call(get-tuple-element.30, get-tuple-element.31), to_apply=PyLoweringContext.22
}

which, after a generic call-inliner pass, becomes (only the condition and entry snippets are included):

PyLoweringContext.6.28 {
  in.29 = (s64[], s64[]) parameter(0)
  constant.14 = s64[] constant(0)
  constant.15 = s64[] constant(5)
  ROOT compare.0 = pred[] compare(constant.14, constant.15), direction=LT
}

ENTRY SyncTensorsGraph.42 {
  constant.1 = s64[] constant(1)
  constant.2 = s64[] constant(0)
  tuple.1 = (s64[], s64[]) tuple(constant.2, constant.1)
  while.0 = (s64[], s64[]) while(tuple.1), condition=PyLoweringContext.6.28, body=PyLoweringContext.14.17
  get-tuple-element.39 = s64[] get-tuple-element(while.0), index=0
  get-tuple-element.40 = s64[] get-tuple-element(while.0), index=1
  ROOT tuple.41 = (s64[], s64[], s64[], s64[]) tuple(constant.1, constant.2, get-tuple-element.39, get-tuple-element.40)
} // SyncTensorsGraph.42

and after a constant-folding pass, we get the following (only the condition snippet is included):

PyLoweringContext.6.28 {
  in.29 = (s64[], s64[]) parameter(0)
  constant.14 = s64[] constant(0)
  constant.15 = s64[] constant(5)
  ROOT constant.17 = pred[] constant(true)
}

Generally, these two passes are quite generic and usually run at the start of a pass pipeline. Is this different in the other backends you have tested with?

Note that, on the other hand, when using mkwhile directly, without the lowering context, we instead see the parameters in the body and condition, without the constant being sunk into both computations:

HloModule while_loop.20, entry_computation_layout={(s64[], s64[])->(s64[], s64[])}

%Body.4 (p0.5: (s64[], s64[])) -> (s64[], s64[]) {
  %p0.5 = (s64[], s64[]) parameter(0)
  %get-tuple-element.6 = s64[] get-tuple-element((s64[], s64[]) %p0.5), index=0
  %constant.8 = s64[] constant(1)
  %add.9 = s64[] add(s64[] %get-tuple-element.6, s64[] %constant.8)
  %get-tuple-element.7 = s64[] get-tuple-element((s64[], s64[]) %p0.5), index=1
  %constant.10 = s64[] constant(1)
  %add.11 = s64[] add(s64[] %get-tuple-element.7, s64[] %constant.10)
  ROOT %tuple.12 = (s64[], s64[]) tuple(s64[] %add.9, s64[] %add.11)
}

%Condition.13 (p0.14: (s64[], s64[])) -> pred[] {
  %p0.14 = (s64[], s64[]) parameter(0)
  %get-tuple-element.16 = s64[] get-tuple-element((s64[], s64[]) %p0.14), index=1
  %get-tuple-element.15 = s64[] get-tuple-element((s64[], s64[]) %p0.14), index=0
  %constant.17 = s64[] constant(5)
  ROOT %compare.18 = pred[] compare(s64[] %get-tuple-element.15, s64[] %constant.17), direction=LT
}

ENTRY %while_loop.20 (p0.1: s64[], p1.2: s64[]) -> (s64[], s64[]) {
  %p0.1 = s64[] parameter(0)
  %p1.2 = s64[] parameter(1)
  %tuple.3 = (s64[], s64[]) tuple(s64[] %p0.1, s64[] %p1.2)
  ROOT %while.19 = (s64[], s64[]) while((s64[], s64[]) %tuple.3), condition=%Condition.13, body=%Body.4
}

In this case it succeeds, so it would also help to shed light on the difference and on why this separate variant needed to be introduced.
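
As a possible angle on the special-scalar behavior above: a minimal sketch, assuming XLA_NO_SPECIAL_SCALARS disables the constant treatment of scalar values 0 and 1 and is read at startup, so it must be set before torch_xla is imported:

import os

# Assumption: this flag makes torch_xla lower 0/1 scalars as device data
# (parameters) instead of inline constants, preserving the loop semantics.
os.environ["XLA_NO_SPECIAL_SCALARS"] = "1"

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.fori_loop import _xla_while_loop_wrapper

device = xm.xla_device()

def cond_fn(iteri_, x):
  return iteri_ < 5

def body_fn(iteri_, x):
  return iteri_ + 1, torch.add(x, 1)

init_val = torch.tensor(1).to(device)
iteri = torch.tensor(0, dtype=torch.int64).to(device)
_, res = _xla_while_loop_wrapper(cond_fn, body_fn, (iteri, init_val), ())
xm.mark_step()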

@JackCaoG (Collaborator)

@tengyifei do you have time to take a look at the question above?

@rpsilva-aws (Collaborator) commented Oct 11, 2024

It seems that every test deliberately used an iterator tensor value >= 2. I am investigating how to fix this, but it would be helpful to discuss how we want to sort this out. This is clearly an existing bug.

Simple repro test change:

diff --git a/test/test_while_loop.py b/test/test_while_loop.py
index e8ea617b0..ee5254dda 100644
--- a/test/test_while_loop.py
+++ b/test/test_while_loop.py
@@ -84,7 +84,7 @@ class WhileLoopTest(unittest.TestCase):
     linear_model = SimpleLinear()
     linear_model.to(device)
     l_in_0 = torch.randn(2, 2, dtype=torch.float32, device=device)
-    iteri = torch.tensor(10, dtype=torch.int32, device=device)
+    iteri = torch.tensor(1, dtype=torch.int32, device=device)
     _, res_with_loop = linear_model(iteri, l_in_0)
     _, res_without_loop = linear_model.forward_without_while_loop_op(
         iteri, l_in_0)

Error:

+ run_test /workspaces/pytorch/xla/test/test_while_loop.py
+ echo 'Running in PjRt runtime: /workspaces/pytorch/xla/test/test_while_loop.py'
Running in PjRt runtime: /workspaces/pytorch/xla/test/test_while_loop.py
++ command -v nvidia-smi
+ '[' -x '' ']'
+ PJRT_DEVICE=CPU
+ CPU_NUM_DEVICES=1
+ run_coverage /workspaces/pytorch/xla/test/test_while_loop.py
+ '[' 0 '!=' 0 ']'
+ python3 /workspaces/pytorch/xla/test/test_while_loop.py
s..WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
F0000 00:00:1728686405.656992  289240 debug_macros.h:21] Non-OK-status: status.status()
Status: INVALID_ARGUMENT: The parameter of condition and body, the result of the body, and init must all have the same shape; got Condition: (in: (s32[], f32[2], f32[2,2], f32[2,2])) -> pred[]; body: (in: (f32[2], f32[2,2], f32[2,2], f32[2,2])) -> (s32[], f32[2], f32[2,2], f32[2,2]); init: (s32[], f32[2], f32[2,2], f32[2,2])..: 
*** Begin stack trace ***
        tsl::CurrentStackTrace[abi:cxx11]()
        xla::XlaComputation ConsumeValue<xla::XlaComputation>(absl::lts_20230802::StatusOr<xla::XlaComputation>&&)

Note that, interestingly, the parameter that goes missing from the expected signature (inlined as a constant instead) appears only in the body computation's HLO (see below). As mentioned above, there is also the issue of the inlined constant being optimized away, bypassing the while-loop semantics in both computations.

  • Body
HloModule PyLoweringContext.18.25, entry_computation_layout={((f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}))->(s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0})}

%PyLoweringContext.6 (p0.12: f32[2], p1.13: f32[2,2], p2.15: f32[2,2], UnusedArgumentsPlaceholder.22: f32[2,2]) -> (s32[], f32[2], f32[2,2], f32[2,2]) {
  %UnusedArgumentsPlaceholder.22 = f32[2,2]{1,0} parameter(3)
  %constant.9 = s32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="/workspaces/pytorch/xla/test/test_while_loop.py" source_line=89}
  %constant.8 = s32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<eval_with_key>.1" source_line=5}
  %constant.7 = s32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<eval_with_key>.1" source_line=5}
  %multiply.10 = s32[] multiply(s32[] %constant.8, s32[] %constant.7), metadata={op_type="aten__sub" op_name="aten__sub.3/aten__sub" source_file="<eval_with_key>.1" source_line=5}
  %subtract.11 = s32[] subtract(s32[] %constant.9, s32[] %multiply.10), metadata={op_type="aten__sub" op_name="aten__sub.3/aten__sub" source_file="<eval_with_key>.1" source_line=5}
  %p0.12 = f32[2]{0} parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/workspaces/pytorch/torch/nn/modules/module.py" source_line=1326}
  %p1.13 = f32[2,2]{1,0} parameter(1), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/workspaces/pytorch/torch/nn/modules/module.py" source_line=1326}
  %p2.15 = f32[2,2]{1,0} parameter(2), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/workspaces/pytorch/xla/torch_xla/experimental/fori_loop.py" source_line=78}
  %transpose.14 = f32[2,2]{0,1} transpose(f32[2,2]{1,0} %p1.13), dimensions={1,0}, metadata={op_type="aten__as_strided" op_name="aten__as_strided" source_file="<eval_with_key>.1" source_line=6}
  %dot.16 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %p2.15, f32[2,2]{0,1} %transpose.14), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="<eval_with_key>.1" source_line=6}
  %reshape.17 = f32[1,2]{1,0} reshape(f32[2]{0} %p0.12), metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="<eval_with_key>.1" source_line=6}
  %broadcast.18 = f32[1,2]{1,0} broadcast(f32[1,2]{1,0} %reshape.17), dimensions={0,1}, metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="<eval_with_key>.1" source_line=6}
  %reshape.19 = f32[2]{0} reshape(f32[1,2]{1,0} %broadcast.18), metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="<eval_with_key>.1" source_line=6}
  %broadcast.20 = f32[2,2]{1,0} broadcast(f32[2]{0} %reshape.19), dimensions={1}, metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="<eval_with_key>.1" source_line=6}
  %add.21 = f32[2,2]{1,0} add(f32[2,2]{1,0} %dot.16, f32[2,2]{1,0} %broadcast.20), metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="<eval_with_key>.1" source_line=6}
  ROOT %tuple.23 = (s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(s32[] %subtract.11, f32[2]{0} %p0.12, f32[2,2]{1,0} %p1.13, f32[2,2]{1,0} %add.21)
}

ENTRY %PyLoweringContext.18.25 (in.1: (f32[2], f32[2,2], f32[2,2], f32[2,2])) -> (s32[], f32[2], f32[2,2], f32[2,2]) {
  %in.1 = (f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
  %get-tuple-element.2 = f32[2]{0} get-tuple-element((f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=0
  %get-tuple-element.3 = f32[2,2]{1,0} get-tuple-element((f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=1
  %get-tuple-element.4 = f32[2,2]{1,0} get-tuple-element((f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=2
  %get-tuple-element.5 = f32[2,2]{1,0} get-tuple-element((f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=3
  ROOT %call.24 = (s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) call(f32[2]{0} %get-tuple-element.2, f32[2,2]{1,0} %get-tuple-element.3, f32[2,2]{1,0} %get-tuple-element.4, f32[2,2]{1,0} %get-tuple-element.5), to_apply=%PyLoweringContext.6
}
  • Condition
HloModule PyLoweringContext.9.16, entry_computation_layout={((s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}))->pred[]}

%PyLoweringContext.6 (UnusedArgumentsPlaceholder.11: s32[], UnusedArgumentsPlaceholder.12: f32[2], UnusedArgumentsPlaceholder.13: f32[2,2], UnusedArgumentsPlaceholder.14: f32[2,2]) -> pred[] {
  %constant.8 = s32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="/workspaces/pytorch/xla/test/test_while_loop.py" source_line=89}
  %convert.9 = s64[] convert(s32[] %constant.8), metadata={op_type="aten__gt" op_name="aten__gt" source_file="<eval_with_key>.0" source_line=5}
  %constant.7 = s64[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<eval_with_key>.0" source_line=5}
  ROOT %compare.10 = pred[] compare(s64[] %convert.9, s64[] %constant.7), direction=GT, metadata={op_type="aten__gt" op_name="aten__gt" source_file="<eval_with_key>.0" source_line=5}
  %UnusedArgumentsPlaceholder.11 = s32[] parameter(0)
  %UnusedArgumentsPlaceholder.12 = f32[2]{0} parameter(1)
  %UnusedArgumentsPlaceholder.13 = f32[2,2]{1,0} parameter(2)
  %UnusedArgumentsPlaceholder.14 = f32[2,2]{1,0} parameter(3)
}

ENTRY %PyLoweringContext.9.16 (in.1: (s32[], f32[2], f32[2,2], f32[2,2])) -> pred[] {
  %in.1 = (s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
  %get-tuple-element.2 = s32[] get-tuple-element((s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=0
  %get-tuple-element.3 = f32[2]{0} get-tuple-element((s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=1
  %get-tuple-element.4 = f32[2,2]{1,0} get-tuple-element((s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=2
  %get-tuple-element.5 = f32[2,2]{1,0} get-tuple-element((s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=3
  ROOT %call.15 = pred[] call(s32[] %get-tuple-element.2, f32[2]{0} %get-tuple-element.3, f32[2,2]{1,0} %get-tuple-element.4, f32[2,2]{1,0} %get-tuple-element.5), to_apply=%PyLoweringContext.6
}
