Since we need to pass max_M to the kernel, is it OK to pass a max_M and then run the kernel with an input that is smaller than max_M?
self.ag_gemm_op = flux.AGKernel(
    get_tp_group().device_group,
    1,  # One node
    8192,  # Max M. TODO: Pass in correctly.
    weight.shape[0],  # N
    weight.shape[1],  # K
    # TODO: Pass in input dtype correctly.
    # TODO: It would be nicer to modify flux to dispatch based on dtype
    # at run time, but I don't know what the downside would be.
    # Similar comment for max M.
    torch.float16,
    torch.float16,
    # Note: transpose_weight=False means that B is transposed.
    transpose_weight=False,
    # Note: if local_copy=True, I hit the following runtime error:
    # /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
    # Check failed: 33554432((input.numel() * input.element_size()))
    # == 139836453421056((this->chunk_size))
    local_copy=False,
)
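To make the question concrete, here is a minimal sketch of the intended usage: build the op once with a large max_M, then call it with a smaller input. This is only an illustration; the tp_group setup, the weight shape, and in particular the forward(input, weight) call are assumptions on my part, not the documented flux API.

    import torch
    import torch.distributed as dist
    import flux  # the comm-overlap library providing AGKernel

    # Sketch only: assumes a torchrun-style launch with one GPU per rank.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
    tp_group = dist.group.WORLD  # stand-in for get_tp_group().device_group

    max_m = 8192  # maximum M the kernel is sized for
    weight = torch.randn(4096, 2048, dtype=torch.float16, device="cuda")  # [N, K] placeholder
    n, k = weight.shape

    ag_gemm_op = flux.AGKernel(
        tp_group,
        1,              # one node
        max_m,          # max M
        n,
        k,
        torch.float16,  # input dtype
        torch.float16,  # output dtype
        transpose_weight=False,
        local_copy=False,
    )

    # The question: run the same op with M = 1024 < max_m.
    x = torch.randn(1024, k, dtype=torch.float16, device="cuda")
    out = ag_gemm_op.forward(x, weight)  # assumed signature; the real flux API may differ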
Hi, and thank you for the great work. To restate the title: is it OK to construct the kernel with a given max_M and then run it with inputs whose M is smaller than max_M?