Add new models, drop old models, remove default_max_tokens
Refs hex#4, hex#5
simonw committed May 8, 2024
1 parent acb4a5d commit 3ce76c2
Showing 1 changed file with 14 additions and 27 deletions.

llm_perplexity.py
@@ -3,19 +3,17 @@
 from pydantic import Field, field_validator, model_validator
 from typing import Optional, List


 @llm.hookimpl
 def register_models(register):
     # https://docs.perplexity.ai/docs/model-cards
-    register(Perplexity("sonar-small-chat", default_max_tokens=16384), aliases=("pp-small-chat",))
-    register(Perplexity("sonar-small-online", default_max_tokens=12288), aliases=("pp-small-online",))
-    register(Perplexity("sonar-medium-chat", default_max_tokens=16384), aliases=("pp-medium-chat",))
-    register(Perplexity("sonar-medium-online", default_max_tokens=12288), aliases=("pp-medium-online",))
-    register(Perplexity("codellama-70b-instruct", default_max_tokens=16384), aliases=("pp-70b-instruct",))
-    register(Perplexity("mistral-7b-instruct", default_max_tokens=16384), aliases=("pp-7b-instruct",))
-    register(Perplexity("mixtral-8x7b-instruct", default_max_tokens=16384), aliases=("pp-8x7b-instruct",))
-    register(Perplexity("mixtral-8x22b-instruct", default_max_tokens=16384), aliases=("pp-8x22b-instruct",))
-    register(Perplexity("llama-3-8b-instruct", default_max_tokens=8192))
-    register(Perplexity("llama-3-70b-instruct", default_max_tokens=8192))
+    register(Perplexity("llama-3-sonar-small-32k-chat"))
+    register(Perplexity("llama-3-sonar-small-32k-online"))
+    register(Perplexity("llama-3-sonar-large-32k-chat"))
+    register(Perplexity("llama-3-sonar-large-32k-online"))
+    register(Perplexity("llama-3-8b-instruct"))
+    register(Perplexity("llama-3-70b-instruct"))
+    register(Perplexity("mixtral-8x7b-instruct"), aliases=("pp-8x7b-instruct",))


 class PerplexityOptions(llm.Options):
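For context: after this change the new model IDs resolve through llm's model registry, and the dropped default_max_tokens argument no longer caps max_tokens. A minimal sketch of calling one of the new models from Python, assuming the plugin is installed and a Perplexity API key is configured (the prompt text is illustrative):

import llm

# Look up one of the newly registered models by its ID.
model = llm.get_model("llama-3-sonar-small-32k-online")
response = model.prompt("What is the capital of France?")
print(response.text())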
@@ -54,15 +52,6 @@ class PerplexityOptions(llm.Options):
         default=None,
     )

-    @field_validator("max_tokens")
-    def validate_max_tokens(cls, values):
-        max_tokens = values.get("max_tokens")
-        default_max_tokens = values.get("default_max_tokens")
-        if max_tokens is not None and default_max_tokens is not None:
-            if not (0 < max_tokens <= default_max_tokens):
-                raise ValueError(f"max_tokens must be in range 1-{default_max_tokens}")
-        return max_tokens
-
     @field_validator("temperature")
     @classmethod
     def validate_temperature(cls, temperature):
@@ -76,14 +65,14 @@ def validate_top_p(cls, top_p):
         if top_p is not None and not (0.0 <= top_p <= 1.0):
             raise ValueError("top_p must be in range 0.0-1.0")
         return top_p

     @field_validator("top_k")
     @classmethod
     def validate_top_k(cls, top_k):
         if top_k is not None and top_k <= 0 or top_k > 2048:
             raise ValueError("top_k must be in range 0-2048")
         return top_k

     @model_validator(mode="after")
     def validate_temperature_top_p(self):
         if self.temperature != 1.0 and self.top_p is not None:
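An aside on the top_k validator visible in the context above (it is unchanged by this commit): Python's operator precedence parses "top_k is not None and top_k <= 0 or top_k > 2048" as "(top_k is not None and top_k <= 0) or (top_k > 2048)", so when top_k is None the right-hand comparison still runs and raises a TypeError. A parenthesized sketch of the presumably intended check, not part of this commit:

    @field_validator("top_k")
    @classmethod
    def validate_top_k(cls, top_k):
        # Short-circuit on None before the range comparison.
        if top_k is not None and not (0 < top_k <= 2048):
            raise ValueError("top_k must be in range 0-2048")
        return top_k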
@@ -97,12 +86,10 @@ class Perplexity(llm.Model):
     model_id = "perplexity"
     can_stream = True

-    class Options(PerplexityOptions):
-        ...
+    class Options(PerplexityOptions): ...

-    def __init__(self, model_id, default_max_tokens=None):
+    def __init__(self, model_id):
         self.model_id = model_id
-        self.default_max_tokens = default_max_tokens

     def build_messages(self, prompt, conversation) -> List[dict]:
         messages = []
@@ -145,8 +132,8 @@ def execute(self, prompt, stream, response, conversation):
             for text in stream:
                 yield text.choices[0].delta.content
         else:
-                completion = client.chat.completions.create(**kwargs)
-                yield completion.choices[0].message.content
+            completion = client.chat.completions.create(**kwargs)
+            yield completion.choices[0].message.content

     def __str__(self):
         return f"Perplexity: {self.model_id}"
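Since can_stream stays True, iterating a response exercises the streaming branch of execute() above. A minimal sketch, again assuming the plugin and an API key are set up (model ID and prompt are illustrative):

import llm

model = llm.get_model("llama-3-sonar-large-32k-online")
# Iterating the response streams chunks as Perplexity returns them;
# a non-streaming call would take the completions branch instead.
for chunk in model.prompt("Briefly explain what a unified diff is."):
    print(chunk, end="")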
