From 337871c6fd581b74949849ad645064318896801b Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 28 May 2023 02:51:42 -0700
Subject: [PATCH] Enable LLaMA fast tokenizer (#132)

---
 cacheflow/sampling_params.py        |  2 +-
 cacheflow/server/tokenizer_utils.py | 16 ++++++++++------
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py
index 031eb82058..b7623ff58d 100644
--- a/cacheflow/sampling_params.py
+++ b/cacheflow/sampling_params.py
@@ -129,7 +129,7 @@ class SamplingParams:
                 f"frequency_penalty={self.frequency_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
-                f"top_k={self.top_k},"
+                f"top_k={self.top_k}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"stop={self.stop}, "
                 f"ignore_eos={self.ignore_eos}, "
diff --git a/cacheflow/server/tokenizer_utils.py b/cacheflow/server/tokenizer_utils.py
index 6e12249952..8aede295d2 100644
--- a/cacheflow/server/tokenizer_utils.py
+++ b/cacheflow/server/tokenizer_utils.py
@@ -7,11 +7,7 @@ from cacheflow.logger import init_logger
 
 logger = init_logger(__name__)
 
-_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
-    # LLaMA fast tokenizer has a bug related to protobuf.
-    # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
-    "llama",
-]
+_MODEL_TYPES_WITH_SLOW_TOKENIZER = []
 
 
 def get_tokenizer(
@@ -20,7 +16,15 @@ def get_tokenizer(
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     config = AutoConfig.from_pretrained(model_name)
-    if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+    if config.model_type == "llama" and getattr(kwargs, "use_fast", True):
+        # LLaMA fast tokenizer causes protobuf errors in some environments.
+        # However, we found that the below LLaMA fast tokenizer works well in
+        # most environments.
+        model_name = "hf-internal-testing/llama-tokenizer"
+        logger.info(
+            f"Using the LLaMA fast tokenizer in '{model_name}' to avoid "
+            "potential protobuf errors.")
+    elif config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
         if getattr(kwargs, "use_fast", False) == True:
             raise ValueError(
                 f"Cannot use the fast tokenizer for {config.model_type} due to "
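
For reference, a minimal usage sketch of the patched get_tokenizer helper.
This is not part of the commit; it assumes the cacheflow package at this
revision is importable, and the model name passed in is hypothetical:

    from cacheflow.server.tokenizer_utils import get_tokenizer

    # After this patch, any model whose HF config reports
    # model_type == "llama" loads the fast tokenizer published at
    # "hf-internal-testing/llama-tokenizer" instead of the model's own
    # slow tokenizer, sidestepping the protobuf errors from issue #80.
    tokenizer = get_tokenizer("my-org/llama-7b")  # hypothetical model name
    print(type(tokenizer).__name__)  # expected: PreTrainedTokenizerFast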
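
A behavioral note on the new condition: kwargs inside get_tokenizer is a
plain dict, and getattr performs attribute lookup rather than key lookup,
so getattr(kwargs, "use_fast", True) always returns the default True no
matter what the caller passed. A dict-key lookup, sketched here only to
illustrate the distinction, would read:

    # Key lookup: returns the caller-supplied use_fast value when present,
    # unlike getattr(kwargs, ...), which never finds a dict key.
    use_fast = kwargs.get("use_fast", True)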