From 776a9c22936c0bd0f1e17b44bf269e64cba62166 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Wed, 26 Jul 2023 21:30:23 +0530
Subject: [PATCH] Fix for Langchain (#1694)

For CPU, remove max time stopping criteria

Fix web UI issue
---
 .flake8                                  |  2 +-
 .../langchain/h2oai_pipeline.py          | 27 ++++++++++++++++---
 .../langchain/langchain_requirements.txt |  6 ++---
 apps/stable_diffusion/web/index.py       |  2 +-
 4 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/.flake8 b/.flake8
index e97d80b17c..a1ecf37d07 100644
--- a/.flake8
+++ b/.flake8
@@ -2,4 +2,4 @@
 count = 1
 show-source = 1
 select = E9,F63,F7,F82
-exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py
+exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py
diff --git a/apps/language_models/langchain/h2oai_pipeline.py b/apps/language_models/langchain/h2oai_pipeline.py
index 8f09cb486f..aaac360a21 100644
--- a/apps/language_models/langchain/h2oai_pipeline.py
+++ b/apps/language_models/langchain/h2oai_pipeline.py
@@ -30,7 +30,15 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
 from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
 
 
-def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
+
+def brevitas〇matmul_rhs_group_quant〡shape(
+    lhs: List[int],
+    rhs: List[int],
+    rhs_scale: List[int],
+    rhs_zero_point: List[int],
+    rhs_bit_width: int,
+    rhs_group_size: int,
+) -> List[int]:
     if len(lhs) == 3 and len(rhs) == 2:
         return [lhs[0], lhs[1], rhs[0]]
     elif len(lhs) == 2 and len(rhs) == 2:
@@ -39,20 +47,30 @@ def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rh
         raise ValueError("Input shapes not supported.")
 
 
-def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
+def brevitas〇matmul_rhs_group_quant〡dtype(
+    lhs_rank_dtype: Tuple[int, int],
+    rhs_rank_dtype: Tuple[int, int],
+    rhs_scale_rank_dtype: Tuple[int, int],
+    rhs_zero_point_rank_dtype: Tuple[int, int],
+    rhs_bit_width: int,
+    rhs_group_size: int,
+) -> int:
     # output dtype is the dtype of the lhs float input
     lhs_rank, lhs_dtype = lhs_rank_dtype
     return lhs_dtype
 
 
-def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
+def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
+    lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
+) -> None:
     return
 
 
 brevitas_matmul_rhs_group_quant_library = [
     brevitas〇matmul_rhs_group_quant〡shape,
     brevitas〇matmul_rhs_group_quant〡dtype,
-    brevitas〇matmul_rhs_group_quant〡has_value_semantics]
+    brevitas〇matmul_rhs_group_quant〡has_value_semantics,
+]
 
 global_device = "cuda"
 global_precision = "fp16"
@@ -541,6 +559,7 @@ def generate_new_token(self):
         return next_token
 
     def generate_token(self, **generate_kwargs):
+        del generate_kwargs["max_time"]
         self.truncated_input_ids = []
 
         generation_config_ = GenerationConfig.from_model_config(
diff --git a/apps/language_models/langchain/langchain_requirements.txt b/apps/language_models/langchain/langchain_requirements.txt
index b301a373e6..78bd6e7562 100644
--- a/apps/language_models/langchain/langchain_requirements.txt
+++ b/apps/language_models/langchain/langchain_requirements.txt
@@ -1,12 +1,10 @@
 # for generate (gradio server) and finetune
 datasets==2.13.0
 sentencepiece==0.1.99
-# gradio==3.37.0
 huggingface_hub==0.16.4
 appdirs==1.4.4
 fire==0.5.0
 docutils==0.20.1
-# torch==2.0.1; sys_platform != "darwin" and platform_machine != "arm64"
 evaluate==0.4.0
 rouge_score==0.1.2
 sacrebleu==2.3.1
@@ -21,7 +19,7 @@ bitsandbytes==0.39.0
 accelerate==0.20.3
 peft==0.4.0
 # 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
-# transformers==4.30.2
+transformers==4.30.2
 tokenizers==0.13.3
 APScheduler==3.10.1
 
@@ -67,7 +65,7 @@ tiktoken==0.4.0
 openai==0.27.8
 
 # optional for chat with PDF
-langchain==0.0.235
+langchain==0.0.202
 pypdf==3.12.2
 # avoid textract, requires old six
 #textract==1.6.5
diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py
index 2d93f4d087..9c7fbff560 100644
--- a/apps/stable_diffusion/web/index.py
+++ b/apps/stable_diffusion/web/index.py
@@ -244,7 +244,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
                     upscaler_status,
                 ]
             )
-        with gr.TabItem(label="DocuChat(Experimental)", id=9):
+        with gr.TabItem(label="DocuChat(Experimental)", id=10):
             h2ogpt_web.render()
 
     # send to buttons
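-- 
Note on the generate_token hunk above: the added line deletes "max_time" from
generate_kwargs unconditionally, so it raises KeyError if a caller never passed
that kwarg. A minimal defensive sketch of the same idea (hypothetical, not part
of this patch; the helper name is invented for illustration):

    def drop_max_time(generate_kwargs: dict) -> dict:
        # The CPU path has no max-time stopping criteria, so discard the
        # kwarg when present; pop() with a default is a no-op otherwise.
        generate_kwargs.pop("max_time", None)
        return generate_kwargs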