Commit 7df174c

Add max_tokens argument
tom-doerr committed Sep 6, 2021
1 parent 342f6a5 commit 7df174c
Showing 2 changed files with 23 additions and 8 deletions.
6 changes: 4 additions & 2 deletions plugin/vim_codex.vim
@@ -22,15 +22,17 @@ EOF
 
 
 
-function! CreateCompletion()
+function! CreateCompletion(max_tokens)
     python3 plugin.create_completion()
 endfunction
 
 function! CreateCompletionLine()
     python3 plugin.create_completion(stop='\n')
 endfunction
 
-command! -nargs=0 CreateCompletion call CreateCompletion()
+
+
+command! -nargs=? CreateCompletion call CreateCompletion(<q-args>)
 command! -nargs=0 CreateCompletionLine call CreateCompletionLine()
 
 map <Leader>co :CreateCompletion<CR>
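
With -nargs=? the count is now optional, and <q-args> forwards it as a single string, which is empty when the argument is omitted (the <Leader>co mapping still invokes the bare command). A brief usage sketch; 128 is an arbitrary example value:

    " No count: the Python side falls back to MAX_TOKENS_DEFAULT (64).
    :CreateCompletion

    " Request up to 128 completion tokens.
    :CreateCompletion 128
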
25 changes: 19 additions & 6 deletions python/plugin.py
@@ -15,32 +15,45 @@
 openai.api_key = SECRET_KEY
 MAX_SUPPORTED_INPUT_LENGTH = 4096
 USE_STREAM_FEATURE = True
+MAX_TOKENS_DEFAULT = 64
 
-def complete_input_max_length(input_prompt, max_input_length=MAX_SUPPORTED_INPUT_LENGTH, stop=None):
+def complete_input_max_length(input_prompt, max_input_length=MAX_SUPPORTED_INPUT_LENGTH, stop=None, max_tokens=64):
     input_prompt = input_prompt[-max_input_length:]
 
-    response = openai.Completion.create(engine='davinci-codex', prompt=input_prompt, best_of=1, temperature=0.5, max_tokens=512, stream=USE_STREAM_FEATURE, stop=stop)
+    response = openai.Completion.create(engine='davinci-codex', prompt=input_prompt, best_of=1, temperature=0.5, max_tokens=max_tokens, stream=USE_STREAM_FEATURE, stop=stop)
     return response
 
-def complete_input(input_prompt, stop):
+def complete_input(input_prompt, stop, max_tokens):
     try:
-        response = complete_input_max_length(input_prompt, int(2.5 * MAX_SUPPORTED_INPUT_LENGTH), stop=stop)
+        response = complete_input_max_length(input_prompt, int(2.5 * MAX_SUPPORTED_INPUT_LENGTH), stop=stop, max_tokens=max_tokens)
     except openai.error.InvalidRequestError:
         response = complete_input_max_length(input_prompt, MAX_SUPPORTED_INPUT_LENGTH, stop=stop)
         # print('Using shorter input.')
-
     return response
 
+def get_max_tokens():
+    max_tokens = None
+    if vim.eval('exists("a:max_tokens")') == '1':
+        max_tokens_str = vim.eval('a:max_tokens')
+        if max_tokens_str:
+            max_tokens = int(max_tokens_str)
+
+    if not max_tokens:
+        max_tokens = MAX_TOKENS_DEFAULT
+
+    return max_tokens
+
 
 def create_completion(stop=None):
+    max_tokens = get_max_tokens()
     vim_buf = vim.current.buffer
     input_prompt = '\n'.join(vim_buf[:])
 
     row, col = vim.current.window.cursor
     input_prompt = '\n'.join(vim_buf[row:])
     input_prompt += '\n'.join(vim_buf[:row-1])
     input_prompt += '\n' + vim_buf[row-1][:col]
-    response = complete_input(input_prompt, stop=stop)
+    response = complete_input(input_prompt, stop=stop, max_tokens=max_tokens)
     write_response(response, stop=stop)
 
 def write_response(response, stop):
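
The new get_max_tokens() checks exists("a:max_tokens") via vim.eval() because create_completion() is also reached from CreateCompletionLine, whose Vim function declares no such argument; an absent, empty, or zero count falls back to MAX_TOKENS_DEFAULT. A minimal standalone sketch of that fallback logic, with the vim.eval() lookups replaced by a plain string parameter so it runs outside Vim (the helper name parse_max_tokens is ours, not the plugin's):

    MAX_TOKENS_DEFAULT = 64

    def parse_max_tokens(max_tokens_str):
        # Mirrors get_max_tokens(): an omitted count arrives as '' and is
        # never converted; '0' converts to 0; both are falsy, so both
        # fall back to MAX_TOKENS_DEFAULT.
        max_tokens = None
        if max_tokens_str:
            max_tokens = int(max_tokens_str)
        if not max_tokens:
            max_tokens = MAX_TOKENS_DEFAULT
        return max_tokens

    assert parse_max_tokens('') == 64      # :CreateCompletion
    assert parse_max_tokens('128') == 128  # :CreateCompletion 128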
