Skip to content

Commit

Permalink
isort imports
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyansh26 committed Nov 17, 2024
1 parent b9ff744 commit 63a63fd
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
20 changes: 13 additions & 7 deletions src/medusa.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from typing import Iterable, List, Optional, Tuple, Union
import os
import sys
from typing import Any, Generator, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from torch import nn
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PretrainedConfig, PreTrainedModel)
from transformers.modeling_outputs import CausalLMOutputWithPast

from conversation_format import get_conv_template
from medusa_utils import generate_medusa_buffers, initialize_past_key_values, initialize_medusa, generate_candidates, format_input, reset_medusa_mode, tree_decoding, evaluate_posterior, update_inference_inputs
from medusa_utils import (evaluate_posterior, format_input,
generate_candidates, generate_medusa_buffers,
initialize_medusa, initialize_past_key_values,
reset_medusa_mode, tree_decoding,
update_inference_inputs)
from modeling_llama_kv import LlamaForCausalLM
from typing import Any, Generator


class ResBlock(nn.Module):
"""
Expand Down
7 changes: 4 additions & 3 deletions src/medusa_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from sklearn import tree
import torch
from conversation_format import Conversation
import torch.nn.functional as F
from sklearn import tree
from transformers import PreTrainedModel

from conversation_format import Conversation
from eta import eta_sampling_with_temperature
from top_p import top_p_sampling_with_temperature
import torch.nn.functional as F

TOPK = 10 # topk for sparse tree (10 is a placeholder and it is sufficient)

Expand Down

0 comments on commit 63a63fd

Please sign in to comment.