Error deploying GPTQ models to SageMaker #235
Comments
Hi @GlacierPurpleBison, taking a look. Do you know if this bug is SageMaker-specific, or if it also occurs when initializing a vanilla LoRAX Docker container?
Hi @geoffreyangus, it isn't working on SageMaker notebooks using Docker directly either. It seems to get stuck waiting for the shard to be ready, and when I run the client I get a connection refused error. When I use Teknium's OpenHermes directly, I am able to connect to the client and the output is as expected.

```
2024-02-16T07:04:33.693079Z  INFO download: lorax_launcher: Successfully downloaded weights.
2024-02-16T07:05:53.758610Z  INFO shard-manager: lorax_launcher: Waiting for shard to be ready... rank=0
```

I am unfortunately not able to verify this outside of SageMaker, since I have a Mac and hence can't test it locally for consistency. I am using the same A10G machine for the SageMaker notebook instance. But you'll notice that I am not getting the CUDA Triton error here.
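For reference, the client call I'm making looks roughly like this (the host and port are illustrative; they are whatever the container is mapped to):

```python
# Rough sketch of the failing client call; 127.0.0.1:8080 is illustrative.
from lorax import Client

client = Client("http://127.0.0.1:8080")

# With the unquantized OpenHermes this returns normally; with the GPTQ
# model the connection is refused because the shard never becomes ready.
response = client.generate("What is deep learning?", max_new_tokens=64)
print(response.generated_text)
```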
By the way, I am not getting any error when I try to deploy AWQ, only with GPTQ.
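Concretely, the only thing that changes between the two deployments is the quantization setting, roughly like this (the model IDs are illustrative, and the env var names are my assumption based on the lorax-launcher flags `--model-id` / `--quantize`):

```python
# Illustrative container environments; only the model ID and QUANTIZE differ.
env_awq = {
    "MODEL_ID": "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ",
    "QUANTIZE": "awq",   # deploys and serves fine
}
env_gptq = {
    "MODEL_ID": "TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ",
    "QUANTIZE": "gptq",  # shard never becomes ready / Triton CUDA error
}
```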
Hey @geoffreyangus, were you able to check further on this?
Hi @GlacierPurpleBison, I apologize, I took the Presidents' Day weekend off. I'll take a look at it this week, thanks!
System Info
I have used the following guide to deploy LoRAX to SageMaker. I am able to do so successfully using unquantized models and have deployed OpenHermes 2.5 without issue. However, when I try the GPTQ version of OpenHermes or Mixtral, I consistently get the following error:
```
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/interceptor.py", line 38, in intercept
    return await response
  File "/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 82, in _unary_interceptor
    raise error
  File "/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 73, in _unary_interceptor
    return await behavior(request_or_iterator, context)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 79, in Warmup
    max_supported_total_tokens = self.model.warmup(batch)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 726, in warmup
    _, batch = self.generate_token(batch)
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 855, in generate_token
    raise e
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 852, in generate_token
    out = self.forward(batch, adapter_data)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_mixtral.py", line 426, in forward
    logits = model.forward(
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_mixtral_modeling.py", line 979, in forward
    hidden_states = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_mixtral_modeling.py", line 922, in forward
    hidden_states, residual = layer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_mixtral_modeling.py", line 849, in forward
    attn_output = self.self_attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_mixtral_modeling.py", line 379, in forward
    qkv = self.query_key_value(hidden_states, adapter_data)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/layers.py", line 601, in forward
    result = self.base_layer(input)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/layers.py", line 399, in forward
    return self.linear.forward(x)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/gptq/quant_linear.py", line 349, in forward
    out = QuantLinearFunction.apply(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 123, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/gptq/quant_linear.py", line 244, in forward
    output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/gptq/quant_linear.py", line 216, in matmul248
    matmul_248_kernel[grid](
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/gptq/custom_autotune.py", line 110, in run
    timings = {
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/gptq/custom_autotune.py", line 111, in <dictcomp>
    config: self._bench(*args, config=config, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/gptq/custom_autotune.py", line 90, in _bench
    return triton.testing.do_bench(
  File "/opt/conda/lib/python3.10/site-packages/triton/testing.py", line 102, in do_bench
    fn()
  File "/opt/conda/lib/python3.10/site-packages/lorax_server/utils/gptq/custom_autotune.py", line 80, in kernel_call
    self.fn.run(
  File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 550, in run
    bin.c_wrapper(
  File "/opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 692, in __getattribute__
    self._init_handles()
  File "/opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 683, in _init_handles
    mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
RuntimeError: Triton Error [CUDA]: device kernel image is invalid
```
I am using the latest LoRAX image and am unable to figure out how to resolve this. Can someone please help figure this out?
Information
Tasks
Reproduction
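A minimal sketch of the deployment code, following the guide (the image URI, endpoint name, and instance type below are illustrative placeholders, not my exact values):

```python
# Minimal reproduction sketch using the SageMaker Python SDK; the image URI
# and endpoint name are placeholders.
import sagemaker
from sagemaker import Model

role = sagemaker.get_execution_role()

model = Model(
    image_uri="<lorax-image-uri>",  # latest LoRAX image
    role=role,
    env={
        "MODEL_ID": "TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ",
        "QUANTIZE": "gptq",
    },
)

# Same A10G instance class as the working unquantized deployment.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    endpoint_name="lorax-gptq-test",  # placeholder name
)
```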
Expected behavior
Expected a successful deployment, as with the unquantized models, but it fails with GPTQ.