yapf set blank_line_before_nested_class_or_def to false
AlexCheema committed Aug 22, 2024
1 parent ea70c9f commit 14f2846
Showing 38 changed files with 2 additions and 69 deletions.
3 changes: 2 additions & 1 deletion .style.yapf
@@ -10,4 +10,5 @@ continuation_indent_width = 2
 indent_dictionary_value = True
 allow_multiline_dictionary_keys = True
 each_dict_entry_on_separate_line = False
-allow_multiline_lambdas = True
+allow_multiline_lambdas = True
+blank_line_before_nested_class_or_def = False
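
For context: yapf's blank_line_before_nested_class_or_def knob controls whether the formatter keeps a blank line between a class or def line and the first def or class nested inside it. With the knob set to False, yapf strips that blank line, which is exactly what the 69 deletions below are. A minimal before/after sketch (illustrative only, using a simplified signature for the Message class that appears in the next file):

# blank_line_before_nested_class_or_def = True
class Message:

  def __init__(self, role: str):
    self.role = role


# blank_line_before_nested_class_or_def = False (this commit)
class Message:
  def __init__(self, role: str):
    self.role = role

Re-running the formatter in place (e.g. yapf --in-place --recursive ., which picks up .style.yapf from the project root) reproduces the deletions in the files that follow.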
5 changes: 0 additions & 5 deletions exo/api/chatgpt_api.py
@@ -18,7 +18,6 @@
 
 
 class Message:
-
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
     self.role = role
     self.content = content
@@ -28,7 +27,6 @@ def to_dict(self):
 
 
 class ChatCompletionRequest:
-
   def __init__(self, model: str, messages: List[Message], temperature: float):
     self.model = model
     self.messages = messages
@@ -148,15 +146,13 @@ def parse_chat_request(data: dict):
 
 
 class PromptSession:
-
   def __init__(self, request_id: str, timestamp: int, prompt: str):
     self.request_id = request_id
     self.timestamp = timestamp
     self.prompt = prompt
 
 
 class ChatGPTAPI:
-
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
@@ -183,7 +179,6 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
     self.app.middlewares.append(self.log_request)
 
   async def log_request(self, app, handler):
-
     async def middleware(request):
       if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
       return await handler(request)
2 changes: 0 additions & 2 deletions exo/download/hf/hf_shard_download.py
@@ -10,7 +10,6 @@
 
 
 class HFShardDownloader(ShardDownloader):
-
   def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
     self.quick_check = quick_check
     self.max_parallel_downloads = max_parallel_downloads
@@ -63,7 +62,6 @@ async def ensure_shard(self, shard: Shard) -> Path:
       self.active_downloads.pop(shard)
 
   async def _download_shard(self, shard: Shard) -> Path:
-
     async def wrapped_progress_callback(event: RepoProgressEvent):
       self._on_progress.trigger_all(shard, event)
 
1 change: 0 additions & 1 deletion exo/download/shard_download.py
@@ -7,7 +7,6 @@
 
 
 class ShardDownloader(ABC):
-
   @abstractmethod
   async def ensure_shard(self, shard: Shard) -> Path:
     """
3 changes: 0 additions & 3 deletions exo/helpers.py
@@ -91,7 +91,6 @@ def terminal_link(uri, label=None):
 
 
 class AsyncCallback(Generic[T]):
-
   def __init__(self) -> None:
     self.condition: asyncio.Condition = asyncio.Condition()
     self.result: Optional[Tuple[T, ...]] = None
@@ -118,7 +117,6 @@ async def notify(self) -> None:
 
 
 class AsyncCallbackSystem(Generic[K, T]):
-
   def __init__(self) -> None:
     self.callbacks: Dict[K, AsyncCallback[T]] = {}
 
@@ -145,7 +143,6 @@ def trigger_all(self, *args: T) -> None:
 
 
 class PrefixDict(Generic[K, V]):
-
   def __init__(self):
     self.items: Dict[K, V] = {}
 
1 change: 0 additions & 1 deletion exo/inference/inference_engine.py
@@ -7,7 +7,6 @@
 
 
 class InferenceEngine(ABC):
-
   @abstractmethod
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     pass
1 change: 0 additions & 1 deletion exo/inference/mlx/models/base.py
@@ -5,6 +5,5 @@
 
 
 class IdentityBlock(nn.Module):
-
   def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array:
     return x
2 changes: 0 additions & 2 deletions exo/inference/mlx/models/deepseek_v2.py
@@ -24,7 +24,6 @@ def __post_init__(self):
 
 
 class DeepseekV2Model(nn.Module):
-
   def __init__(self, config: ModelArgs):
     super().__init__()
     self.args = config
@@ -71,7 +70,6 @@ def __call__(
 
 
 class Model(nn.Module):
-
   def __init__(self, config: ModelArgs):
     super().__init__()
     self.args = config
2 changes: 0 additions & 2 deletions exo/inference/mlx/models/llama.py
@@ -26,7 +26,6 @@ def __post_init__(self):
 
 
 class LlamaModel(nn.Module):
-
   def __init__(self, args: ModelArgs):
     super().__init__()
     self.args = args
@@ -70,7 +69,6 @@ def __call__(
 
 
 class Model(nn.Module):
-
   def __init__(self, args: ModelArgs):
     super().__init__()
     self.args = args
14 changes: 0 additions & 14 deletions exo/inference/mlx/models/llava.py
@@ -33,7 +33,6 @@ def from_dict(cls, params):
 
 
 class VisionAttention(nn.Module):
-
   def __init__(
     self,
     dims: int,
@@ -86,7 +85,6 @@ def __call__(self, queries, keys, values, mask=None):
 
 
 class VisionMLP(nn.Module):
-
   def __init__(self, config: VisionConfig):
     super().__init__()
     self.activation_fn = nn.GELU(approx="fast")
@@ -100,7 +98,6 @@ def __call__(self, x: mx.array) -> mx.array:
 
 
 class VisionEncoderLayer(nn.Module):
-
   def __init__(self, config: VisionConfig):
     super().__init__()
     self.embed_dim = config.hidden_size
@@ -119,14 +116,12 @@ def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
 
 
 class VisionEncoder(nn.Module):
-
   def __init__(self, config: VisionConfig):
     super().__init__()
     self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
 
 
 class VisionEmbeddings(nn.Module):
-
   def __init__(self, config: VisionConfig):
     super().__init__()
     self.config = config
@@ -160,7 +155,6 @@ def __call__(self, x: mx.array) -> mx.array:
 
 
 class ClipVisionModel(nn.Module):
-
   def __init__(self, config: VisionConfig):
     super().__init__()
     self.embeddings = VisionEmbeddings(config)
@@ -188,7 +182,6 @@ def __call__(
 
 
 class VisionModel(nn.Module):
-
   def __init__(self, config: VisionConfig):
     super().__init__()
 
@@ -258,7 +251,6 @@ def __post_init__(self):
 
 
 class TextAttention(nn.Module):
-
   def __init__(self, config: TextConfig):
     super().__init__()
 
@@ -313,7 +305,6 @@ def __call__(
 
 
 class TextMLP(nn.Module):
-
   def __init__(self, dim, hidden_dim):
     super().__init__()
     self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
@@ -325,7 +316,6 @@ def __call__(self, x) -> mx.array:
 
 
 class TransformerBlock(nn.Module):
-
   def __init__(self, config: TextConfig):
     super().__init__()
     self.num_attention_heads = config.num_attention_heads
@@ -350,7 +340,6 @@ def __call__(
 
 
 class Llama(nn.Module):
-
   def __init__(self, config: TextConfig, shard: Shard):
     super().__init__()
     self.config = config
@@ -404,7 +393,6 @@ def __call__(
 
 
 class LanguageModel(nn.Module):
-
   def __init__(self, config: TextConfig, shard: Shard):
     super().__init__()
     self.model_type = config.model_type
@@ -486,7 +474,6 @@ def __post_init__(self):
 
 
 class LlavaMultiModalProjector(nn.Module):
-
   def __init__(self, config: LlaVAConfig):
     super().__init__()
     self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
@@ -501,7 +488,6 @@ def __call__(self, x: mx.array) -> mx.array:
 
 
 class Model(nn.Module):
-
   def __init__(self, config: ModelArgs):
     super().__init__()
     self.config = config
1 change: 0 additions & 1 deletion exo/inference/mlx/sharded_inference_engine.py
@@ -9,7 +9,6 @@
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
-
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
2 changes: 0 additions & 2 deletions exo/inference/mlx/sharded_model.py
@@ -10,7 +10,6 @@
 
 
 class StatefulShardedModel:
-
   def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
     self.shard = shard
     self.model = model
@@ -27,7 +26,6 @@ def step(
     top_p: float = 1.0,
     logit_bias: Optional[Dict[int, float]] = None,
   ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
       if logit_bias:
         indices = mx.array(list(logit_bias.keys()))
1 change: 0 additions & 1 deletion exo/inference/mlx/sharded_utils.py
@@ -25,7 +25,6 @@
 
 
 class ModelNotFoundError(Exception):
-
   def __init__(self, message):
     self.message = message
     super().__init__(self.message)
1 change: 0 additions & 1 deletion exo/inference/mlx/test_sharded_model.py
@@ -6,7 +6,6 @@
 
 
 class DummyModel(nn.Module):
-
   def __init__(self, shard: Optional[Shard] = None):
     self.shard = shard
     self.layers = [
1 change: 0 additions & 1 deletion exo/inference/tinygrad/inference.py
@@ -48,7 +48,6 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
 
 
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
-
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
5 changes: 0 additions & 5 deletions exo/inference/tinygrad/models/llama.py
@@ -38,7 +38,6 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
 
 
 class Attention:
-
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
     self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
@@ -88,7 +87,6 @@ def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor
 
 
 class FeedForward:
-
   def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
@@ -99,7 +97,6 @@ def __call__(self, x: Tensor) -> Tensor:
 
 
 class TransformerBlock:
-
   def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, linear=nn.Linear, feed_forward=FeedForward):
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
@@ -165,7 +162,6 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 
 
 class Transformer:
-
   def __init__(
     self,
     dim: int,
@@ -222,7 +218,6 @@ def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0
 
 
 def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
-
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 
1 change: 0 additions & 1 deletion exo/inference/tinygrad/tinygrad_helpers.py
@@ -11,7 +11,6 @@
 
 # **** helper functions ****
 def concat_weights(models, device=None):
-
   def convert(name) -> Tensor:
     disk_tensors: List[Tensor] = [model[name] for model in models]
     if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
1 change: 0 additions & 1 deletion exo/networking/discovery.py
@@ -4,7 +4,6 @@
 
 
 class Discovery(ABC):
-
   @abstractmethod
   async def start(self) -> None:
     pass
2 changes: 0 additions & 2 deletions exo/networking/grpc/grpc_discovery.py
@@ -11,7 +11,6 @@
 
 
 class ListenProtocol(asyncio.DatagramProtocol):
-
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
     super().__init__()
     self.on_message = on_message
@@ -25,7 +24,6 @@ def datagram_received(self, data, addr):
 
 
 class GRPCDiscovery(Discovery):
-
   def __init__(
     self,
     node_id: str,
1 change: 0 additions & 1 deletion exo/networking/grpc/grpc_peer_handle.py
@@ -13,7 +13,6 @@
 
 
 class GRPCPeerHandle(PeerHandle):
-
   def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
     self._id = _id
     self.address = address
1 change: 0 additions & 1 deletion exo/networking/grpc/grpc_server.py
@@ -11,7 +11,6 @@
 
 
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
-
   def __init__(self, node: Node, host: str, port: int):
     self.node = node
     self.host = host
[Diff truncated: the remaining changed files of the 38 were not loaded in this view.]
