
Support generation from input embedding #1265

Open
pfldy2850 wants to merge 33 commits into base: main
Conversation

Contributor

@pfldy2850 pfldy2850 commented Oct 5, 2023

This PR implements generating text from embedding input (commonly known as inputs_embeds) instead of a token prompt.
It is related to #369 and #416.
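For context, here is a minimal client-side sketch of how embedding input could be sent to the API server touched by this PR. It is an illustration under assumptions, not code from the PR: the model name (facebook/opt-125m), the endpoint (http://localhost:8000/generate), and the extra "prompt" field are placeholders based on vLLM's default api_server and the server changes reviewed further below.

```python
import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Build input embeddings with the same base model the server is running
# (facebook/opt-125m is only an example).
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.no_grad():
    # Shape: [seq_len, hidden_size]
    prompt_embeds = model.get_input_embeddings()(token_ids)[0]

# The patched api_server reads an optional "prompt_embeds" field from the JSON
# body; an empty "prompt" is included because the current code still pops it.
payload = {
    "prompt": "",
    "prompt_embeds": prompt_embeds.tolist(),
    "max_tokens": 32,
}
response = requests.post("http://localhost:8000/generate", json=payload)
print(response.json())
```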

More to do

  • Enhance the test code for generate.
  • Determine whether the feature degrades core performance.
  • Add more details to the comments.
  • Apply it to async_llm_engine and api_server.

@pfldy2850 pfldy2850 changed the title [WIP] Support generate from input embedding [WIP] Support generation from input embedding Oct 12, 2023
@pfldy2850
Contributor Author

We conducted several tests and confirmed that the performance degradation is not significant.

Specifically, we ran the latency benchmark five times each on the main branch and on the feature branch using the command below.

python benchmarks/benchmark_latency.py --input-len=2048 --num-iters=5

## main
Avg latency: 0.36247589644044637 seconds
Avg latency: 0.35677395705133674 seconds
Avg latency: 0.3622682703658938 seconds
Avg latency: 0.36043337155133487 seconds
Avg latency: 0.3593990854918957 seconds

## feature
Avg latency: 0.3586543008685112 seconds
Avg latency: 0.3557318979874253 seconds
Avg latency: 0.36645207908004523 seconds
Avg latency: 0.3598199490457773 seconds
Avg latency: 0.36111502479761837 seconds

@pfldy2850 pfldy2850 changed the title [WIP] Support generation from input embedding Support generation from input embedding Oct 12, 2023

@bobchen1980 bobchen1980 left a comment


fixed input embedding function related to #369 and #416

@pfldy2850
Contributor Author

@WoosukKwon @zhuohan123

Hello authors, I have tested this PR and aligned it with the latest prepare_inputs function.
Could you please review this PR?

@WoosukKwon WoosukKwon mentioned this pull request Nov 2, 2023
@js8544
Contributor

js8544 commented Jan 3, 2024

We've been using this branch in production and it works like a charm. Thanks so much for your contribution. Can't wait for it to be merged!

@fedshyvana

Thanks for this! Any plans to merge this into main anytime soon?

@pfldy2850
Contributor Author

Hello @zhuohan123 ,

I just saw that you created an issue for the vLLM Q1 2024 roadmap.

If you have any plans to consider this feature or to merge this PR,
I would like to resume work on updating it.

@matankley

This PR would be super valuable for us. @pfldy2850, do you plan to rebase it onto the current master branch? It looks a bit outdated.

    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    prompt_embeds = request_dict.pop("prompt_embeds", None)
    if prompt_embeds is not None:
        prompt_embeds = torch.tensor(prompt_embeds).to("cuda")

@bks5881 bks5881 Mar 1, 2024


This loads the embeddings in float32, which eats all the GPU memory.
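A possible fix, sketched under the assumption that the handler can look up the dtype the engine was started with (the `dtype` variable below is a placeholder, not code from this PR):

```python
# Cast the incoming embeddings to the model's dtype (e.g. float16) instead of
# letting torch.tensor default to float32, which doubles the memory footprint.
dtype = torch.float16  # placeholder: use the dtype the engine was configured with
prompt_embeds = torch.tensor(prompt_embeds, dtype=dtype).to("cuda")
```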

@@ -29,16 +30,27 @@ async def generate(request: Request) -> Response:

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - prompt_embeds: the prompt embedding to use for the generation
      instead of the prompt.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")


This throws an error when only prompt_embeds are passed.
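One way to avoid the error, a sketch rather than code from the PR, is to make `prompt` optional and validate that at least one of the two inputs is present (FastAPI's `Response` is already in scope in this handler):

```python
# Accept requests that carry only "prompt_embeds" by making "prompt" optional.
prompt = request_dict.pop("prompt", None)
prompt_embeds = request_dict.pop("prompt_embeds", None)
if prompt is None and prompt_embeds is None:
    return Response(content="Either prompt or prompt_embeds must be provided.",
                    status_code=400)
```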


@bks5881 bks5881 left a comment


Thanks a lot for doing this. I tested it, ran into some issues, and fixed them locally.
Also, for some reason torch.cuda.is_available() returned False during serialization, so I had to set CUDA_VISIBLE_DEVICES in Ray's __init__.py.
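For anyone who hits the same problem, a minimal workaround sketch is to pin the variable in the environment before Ray initializes, instead of patching Ray's `__init__.py`; the device index `0` below is an assumption:

```python
import os

# Make sure processes spawned by Ray still see the GPU when tensors are
# serialized; this must run before ray.init().
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

import ray

ray.init()
```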

@tweeter0830

@zhuohan123 Do you have plans for this? It would be really helpful to me for this MR to get merged. I can help push it through if you need.

@zhuohan123
Member

> @zhuohan123 Do you have plans for this? It would be really helpful to me for this MR to get merged. I can help push it through if you need.

We are doing this in this PR for llava support: #3042. Please take a look and let us know any suggestions!


This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale label Oct 30, 2024

mergify bot commented Oct 30, 2024

This pull request has merge conflicts that must be resolved before it can be merged. @pfldy2850 please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

9 participants