
PjRtComputationClient::ExecuteReplicated core dump when encountering a scalar #8057

Open
mars1248 opened this issue Sep 24, 2024 · 6 comments

Comments

@mars1248
Contributor

❓ Questions and Help

In my test code, I found that an argument can arrive as a plain PjRtData (when the argument is a scalar), which then causes the core dump.
https://github.com/pytorch/xla/blob/master/torch_xla/csrc/runtime/pjrt_computation_client.cc#L806
I wrote a test function earlier that tried to transform all of the arguments manually, but it still core dumped.
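For context, here is a hedged sketch in plain Python (the names are hypothetical, not the actual torch_xla implementation): ExecuteReplicated appears to assume every argument is sharded data with one shard per device, so a scalar that stayed a plain PjRtData has no shard list to index. A defensive transform would replicate such handles:

```python
def to_sharded(arg, num_devices):
    """Replicate a plain (unsharded) handle across devices; pass
    already-sharded arguments (modeled here as per-device lists)
    through unchanged. Hypothetical sketch, not torch_xla code."""
    if isinstance(arg, list):          # already one shard per device
        assert len(arg) == num_devices, "shard count must match device count"
        return arg
    return [arg] * num_devices         # replicate the scalar to every device

# A scalar argument becomes one copy per device:
shards = to_sharded(3.14, 4)           # [3.14, 3.14, 3.14, 3.14]
```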
(screenshots of the backtrace omitted)

@JackCaoG
Collaborator

Hmmm, why is the last data's device CUDA:0 instead of SPMD:0? Under SPMD mode all tensors should be on the SPMD:0 device. Do you have a small repro?
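The invariant being described can be illustrated with a small plain-Python helper (hypothetical, for illustration only): under SPMD every argument handle should report the virtual device SPMD:0, so a quick scan over the device strings would flag the stray CUDA:0 handle from the backtrace:

```python
def find_non_spmd(devices):
    """Return the device strings that violate the SPMD invariant,
    i.e. anything other than the virtual 'SPMD:0' device."""
    return [d for d in devices if d != "SPMD:0"]

# The last argument in the crash reported CUDA:0, so it would be flagged:
bad = find_non_spmd(["SPMD:0", "SPMD:0", "CUDA:0"])  # -> ["CUDA:0"]
```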

@mars1248
Contributor Author

This can be reproduced with the following unit test; I used an 8×A100 environment.

from transformers import AutoModel
import torch
import torch.optim as optim
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.amp import autocast, GradScaler
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2

import numpy as np
import torch_xla.debug.profiler as xp
import time
import os
# Setup profiler env var
os.environ['XLA_HLO_DEBUG'] = '1'
server = xp.start_server(9012)
# Enable XLA SPMD execution mode.
xr.use_spmd()
model = AutoModel.from_pretrained("/root/OpenGVLab__InternViT-6B-448px-V1-5", low_cpu_mem_usage=True, device_map="cpu", trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.to(xm.xla_device())
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices // 1, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('fsdp', 'replica'))
def shard_output(output, mesh):
    xs.mark_sharding(output.last_hidden_state[:,1:], mesh, ('fsdp', None, None))
model = FSDPv2(model, mesh, shard_output)
image = torch.rand([8, 3, 448, 448]).to(xm.xla_device(), torch.bfloat16)
xs.mark_sharding(image, mesh, (('fsdp', 'replica'), ) + (None, None, None))
with autocast(xm.xla_device(), dtype=torch.bfloat16):
    output = model(image)
    loss = output.last_hidden_state[:,1:].sum()
xm.unlazy([loss])
print("ans is ", loss)

@zjjott

zjjott commented Sep 27, 2024

@JackCaoG

@JackCaoG
Collaborator

busy this week, will try to find some time to take a look

@zjjott

zjjott commented Oct 11, 2024

I have a simpler repro @JackCaoG, pure torch/torch_xla:

import torch
import torch.optim as optim
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd.xla_sharding as xs
from torch_xla.distributed.spmd.xla_sharding import Mesh
from torch_xla.amp import autocast, GradScaler
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
from torch import nn
import numpy as np
import torch_xla.debug.profiler as xp
import time
import os
class InternRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)  # the `+ self.variance_epsilon` (a Python scalar) here triggers the issue
        return self.weight * hidden_states.to(input_dtype)
os.environ['XLA_HLO_DEBUG'] = '1'
server = xp.start_server(9012)
# Enable XLA SPMD execution mode.
xr.use_spmd()
model = InternRMSNorm(448)
model = model.to(xm.xla_device())
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices // 1, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('fsdp', 'replica'))
def shard_output(output, mesh):
    xs.mark_sharding(output, mesh, ('fsdp', None, None))
model = FSDPv2(model, mesh=mesh, shard_output=shard_output)
image = torch.rand([64, 448, 448]).to(xm.xla_device(), torch.bfloat16)
xs.mark_sharding(image, mesh, (('fsdp', 'replica'), ) + (None, None))
with autocast(xm.xla_device(), dtype=torch.bfloat16):
    output = model(image)
    loss = output.sum()
xm.unlazy([loss])
print("ans is ", loss)

@zjjott

zjjott commented Oct 12, 2024

server = xp.start_server(9012) causes the segfault; when I comment out this line, the code works fine.
You must call start_server after use_spmd(), otherwise it will segfault:

# incorrect
server = xp.start_server(9012)
xr.use_spmd()
# correct
xr.use_spmd()
server = xp.start_server(9012)
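The ordering constraint above could be enforced defensively in a launcher script. A minimal sketch (the stub names are hypothetical, standing in for xr.use_spmd and xp.start_server):

```python
class Launcher:
    """Enforce that SPMD mode is enabled before the profiler server
    starts, mirroring the ordering that avoids the segfault."""

    def __init__(self):
        self.spmd_enabled = False

    def use_spmd(self):
        # Stand-in for xr.use_spmd(): must run first.
        self.spmd_enabled = True

    def start_server(self, port):
        # Stand-in for xp.start_server(port): refuse to run early.
        if not self.spmd_enabled:
            raise RuntimeError(
                "call use_spmd() before start_server(); "
                "the reverse order segfaults (see above)")
        return f"profiler listening on port {port}"
```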
