PjRtComputationClient::ExecuteReplicated core dump when encountering a scalar #8057
Comments
Hmmm, why is the last data's device ...
This can be reproduced with the following unit test; the environment I used is 8×A100.
Busy this week; I will try to find some time to take a look.
I have a simpler reproduction, @JackCaoG, using pure torch/torch_xla:

import torch
import torch.optim as optim
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd.xla_sharding as xs
from torch_xla.distributed.spmd.xla_sharding import Mesh
from torch_xla.amp import autocast, GradScaler
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
from torch import nn
import numpy as np
import torch_xla.debug.profiler as xp
import time
import os


class InternRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # The `+ self.variance_epsilon` scalar here is what triggers the issue.
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


os.environ['XLA_HLO_DEBUG'] = '1'

server = xp.start_server(9012)

# Enable XLA SPMD execution mode.
xr.use_spmd()

model = InternRMSNorm(448)
model = model.to(xm.xla_device())

num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices // 1, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('fsdp', 'replica'))


def shard_output(output, mesh):
    xs.mark_sharding(output, mesh, ('fsdp', None, None))


model = FSDPv2(model, mesh=mesh, shard_output=shard_output)

image = torch.rand([64, 448, 448]).to(xm.xla_device(), torch.bfloat16)
xs.mark_sharding(image, mesh, (('fsdp', 'replica'),) + (None, None))

with autocast(xm.xla_device(), dtype=torch.bfloat16):
    output = model(image)
    loss = output.sum()  # output.last_hidden_state[:, 1:].sum()

xm.unlazy([loss])
print("ans is ", loss)
The ordering of xr.use_spmd() and xp.start_server() turns out to matter:
# incorrect
server = xp.start_server(9012)
xr.use_spmd()
# correct
xr.use_spmd()
server = xp.start_server(9012)
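For reference, a self-contained version of the working ordering (assuming nothing else initializes the XLA runtime beforehand):

import torch_xla.runtime as xr
import torch_xla.debug.profiler as xp

# Switch the runtime to SPMD mode before anything else touches it,
# then start the profiler server.
xr.use_spmd()
server = xp.start_server(9012)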
❓ Questions and Help
In my test code, I found that an argument can arrive as a plain PjRtData when the argument is a scalar, and ExecuteReplicated then core dumps here:
https://github.com/pytorch/xla/blob/master/torch_xla/csrc/runtime/pjrt_computation_client.cc#L806
I earlier wrote a test function that tried to convert all of the arguments manually, but it still core dumped.
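A sketch of how the inputs could be inspected from Python (assuming _get_xla_sharding_spec behaves as in current torch_xla, returning an empty string for data without a sharding annotation):

import torch
import torch_xla

def dump_sharding(tensors):
    # Print the sharding spec recorded for each tensor; an empty spec suggests
    # the underlying handle is plain (unsharded) device data.
    for t in tensors:
        spec = torch_xla._XLAC._get_xla_sharding_spec(t)
        print(tuple(t.shape), '->', spec if spec else '<no sharding annotation>')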