
Replace WeightOnlyInt8Linear with TorchAO int8_weight_only quantization #1328

Open · wants to merge 84 commits into base: main from torchao_int8_weight_only

Conversation

Contributor

@vmpuri vmpuri commented Oct 24, 2024

Replace the WeightOnlyInt8Linear quantization code with TorchAO's int8_weight_only quantization.

Note - this commit also contains lintrunner changes.
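For context, a minimal sketch of the torchao flow this PR switches to (the toy model below is illustrative; torchchat applies the same call to its own Transformer):

```
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

# Illustrative toy model; torchchat passes its Transformer here instead.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# Previously handled by torchchat's custom WeightOnlyInt8Linear module swap;
# quantize_ mutates the model in place, replacing eligible nn.Linear weights
# with int8 weight-only quantized tensor subclasses.
quantize_(model, int8_weight_only())

out = model(torch.randn(1, 64))
print(out.shape)  # torch.Size([1, 8])
```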

Testing:

python3 torchchat.py eval llama3.2-1b --quantize '{"linear:int8": {"groupsize": 0}, "executor":{"accelerator":"cuda"}}' --compile
Using device=cuda
Loading model...
Time to load model: 1.21 seconds
Quantizing the model with: {'linear:int8': {'groupsize': 0}, 'executor': {'accelerator': 'cuda'}}
quantizer is linear int8
Time to quantize model: 0.31 seconds
-----------------------------------------------------------
2024-10-24:15:55:20,261 INFO     [huggingface.py:162] Using device 'cuda'
2024-10-24:15:55:27,792 WARNING  [task.py:763] [Task: wikitext] metric word_perplexity is defined, but aggregation is not. using default aggregation=weighted_perplexity
2024-10-24:15:55:27,792 WARNING  [task.py:775] [Task: wikitext] metric word_perplexity is defined, but higher_is_better is not. using default higher_is_better=False
2024-10-24:15:55:27,792 WARNING  [task.py:763] [Task: wikitext] metric byte_perplexity is defined, but aggregation is not. using default aggregation=weighted_perplexity
2024-10-24:15:55:27,792 WARNING  [task.py:775] [Task: wikitext] metric byte_perplexity is defined, but higher_is_better is not. using default higher_is_better=False
2024-10-24:15:55:27,792 WARNING  [task.py:763] [Task: wikitext] metric bits_per_byte is defined, but aggregation is not. using default aggregation=bits_per_byte
2024-10-24:15:55:27,792 WARNING  [task.py:775] [Task: wikitext] metric bits_per_byte is defined, but higher_is_better is not. using default higher_is_better=False
Repo card metadata block was not found. Setting CardData to empty.
2024-10-24:15:55:28,687 WARNING  [repocard.py:108] Repo card metadata block was not found. Setting CardData to empty.
2024-10-24:15:55:28,760 INFO     [task.py:395] Building contexts for wikitext on rank 0...
100%|███████████████████████████████████████████████████████████████████████████████████████████| 62/62 [00:00<00:00, 501.80it/s]
2024-10-24:15:55:28,889 INFO     [evaluator.py:362] Running loglikelihood_rolling requests
100%|████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [01:10<00:00,  1.13s/it]
Time to run eval: 78.96s.
Time in model.forward: 62.57s, over 162 model evaluations
forward run time stats - Median: 0.00s Min: 0.00s Max: 41.80s
For model /home/puri/.torchchat/model-cache/meta-llama/Meta-Llama-3.2-1B-Instruct/model.pth
wikitext:
 word_perplexity,none: 19.2032
 byte_perplexity,none: 1.7378
 bits_per_byte,none: 0.7973
 alias: wikitext

From current master:

python3 torchchat.py eval llama3.2-1b --quantize '{"linear:int8": {"groupsize": 0}, "executor":{"accelerator":"cuda"}}' --compile
Using device=cuda
Loading model...
Time to load model: 1.20 seconds
Quantizing the model with: {'linear:int8': {'groupsize': 0}, 'executor': {'accelerator': 'cuda'}}
Time to quantize model: 0.19 seconds
-----------------------------------------------------------
2024-10-24:15:43:59,945 INFO     [huggingface.py:162] Using device 'cuda'
2024-10-24:15:44:07,664 WARNING  [task.py:763] [Task: wikitext] metric word_perplexity is defined, but aggregation is not. using default aggregation=weighted_perplexity
2024-10-24:15:44:07,664 WARNING  [task.py:775] [Task: wikitext] metric word_perplexity is defined, but higher_is_better is not. using default higher_is_better=False
2024-10-24:15:44:07,664 WARNING  [task.py:763] [Task: wikitext] metric byte_perplexity is defined, but aggregation is not. using default aggregation=weighted_perplexity
2024-10-24:15:44:07,664 WARNING  [task.py:775] [Task: wikitext] metric byte_perplexity is defined, but higher_is_better is not. using default higher_is_better=False
2024-10-24:15:44:07,664 WARNING  [task.py:763] [Task: wikitext] metric bits_per_byte is defined, but aggregation is not. using default aggregation=bits_per_byte
2024-10-24:15:44:07,664 WARNING  [task.py:775] [Task: wikitext] metric bits_per_byte is defined, but higher_is_better is not. using default higher_is_better=False
Repo card metadata block was not found. Setting CardData to empty.
2024-10-24:15:44:09,261 WARNING  [repocard.py:108] Repo card metadata block was not found. Setting CardData to empty.
2024-10-24:15:44:09,342 INFO     [task.py:395] Building contexts for wikitext on rank 0...
100%|████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [00:00<00:00, 463.50it/s]
2024-10-24:15:44:09,482 INFO     [evaluator.py:362] Running loglikelihood_rolling requests
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [01:00<00:00,  1.03it/s]
Time to run eval: 70.16s.
Time in model.forward: 53.46s, over 162 model evaluations
forward run time stats - Median: 0.00s Min: 0.00s Max: 33.02s
For model /home/puri/.torchchat/model-cache/meta-llama/Meta-Llama-3.2-1B-Instruct/model.pth
wikitext:
 word_perplexity,none: 19.2432
 byte_perplexity,none: 1.7385
 bits_per_byte,none: 0.7978
 alias: wikitext

Lint

pip install -r install/requirements-lintrunner.txt 
lintrunner -a


pytorch-bot bot commented Oct 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1328

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please review it.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 24, 2024
@vmpuri vmpuri force-pushed the torchao_int8_weight_only branch from d43d52e to 92e0a9d on October 24, 2024 at 22:52
@vmpuri vmpuri marked this pull request as ready for review October 24, 2024 22:57
@jerryzh168
Contributor

jerryzh168 commented Oct 24, 2024

Thanks! Can you add a generate.py speed benchmark result for before and after as well?

# Use tensor subclass API for int4 weight only.
if device == "cuda" and quantizer == "linear:int4":
    quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
elif quantizer == "linear:int8":
    print("quantizer is linear int8")
Contributor

Suggested change (remove the debug print):
    print("quantizer is linear int8")

"precision": PrecisionHandler,
"executor": ExecutorHandler,
"linear:int4": Int4WeightOnlyQuantizer,
"linear:int8": int8_weight_only,
Contributor

Do we need this?

Contributor

we can probably use None for now, and remove this later

Contributor

We check for int8_weight_only and finish that check before it looks at the table, I think.

@vmpuri can you check?

@Jack-Khuu
Contributor

Can you ack that the numerics look good for MPS and CPU as well?

# Use tensor subclass API for int4 weight only.
if device == "cuda" and quantizer == "linear:int4":
    quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
elif quantizer == "linear:int8":
    print("quantizer is linear int8")
    quantize_(model, int8_weight_only())
Contributor

Why not integrate it into a QuantHandler class dispatched through the handler dict at a single call site, rather than building a chain of if statements?
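For illustration only, a table-driven version of that dispatch might look roughly like the sketch below (QUANT_HANDLERS and apply_quantizer are hypothetical names, not torchchat's actual handler dict or call site):

```
from torchao.quantization import quantize_, int4_weight_only, int8_weight_only

# Hypothetical handler table: each entry applies quantize_ with the right
# config, so the call site becomes a single lookup instead of an if/elif chain.
QUANT_HANDLERS = {
    "linear:int4": lambda model, q_kwargs: quantize_(
        model, int4_weight_only(q_kwargs["groupsize"])
    ),
    "linear:int8": lambda model, q_kwargs: quantize_(model, int8_weight_only()),
}

def apply_quantizer(model, quantizer, q_kwargs):
    handler = QUANT_HANDLERS.get(quantizer)
    if handler is None:
        raise ValueError(f"unknown quantizer: {quantizer}")
    handler(model, q_kwargs)
```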

Contributor

Hi @mikekgfb, we will refactor this part in the future, after all quant APIs are moved to torchao, I think.

Contributor

TorchAO already has a class-based API that is used for other quantizers. Why do these differently and then refactor them later? Or why not do them all in a consistent way now and, if you refactor later, do that then?

Contributor

Yeah, the quantizer API is deprecated in favor of quantize_; that's why we are gradually refactoring the quantizer APIs to use quantize_. The reason we do it one by one is that there might be missing support or numerics alignment that we need to address during the migration.

        return linear_int8_aoti(input, self.weight, self.scales)

    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
        return linear_int8_et(input, self.weight, self.scales)
Contributor

Int8 seems like it is special-cased for ET; reminder to check that as well.

leseb and others added 11 commits February 4, 2025 13:50
The previous Python version check was incorrect, allowing installations
on unsupported interpreter versions, which caused installation failures.
Additionally, we now respect the specified interpreter version if
provided, consistently using it throughout the installation process by
enforcing it with pip.

Signed-off-by: Sébastien Han <[email protected]>
* tokenizer was missing an include

* fix a nit

---------

Co-authored-by: Jesse <[email protected]>
Co-authored-by: Jack-Khuu <[email protected]>
…resses ctrl+c (#1352)

Setup a SIGINT handler to gracefully exit the program once the user
presses ctrl+c.

Signed-off-by: Sébastien Han <[email protected]>
Co-authored-by: Jack-Khuu <[email protected]>
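A minimal sketch of the kind of SIGINT handler described in the commit above (the actual torchchat implementation may register and clean up differently):

```
import signal
import sys

def _handle_sigint(signum, frame):
    # Exit cleanly instead of printing a KeyboardInterrupt traceback.
    print("\nInterrupted by user, exiting.")
    sys.exit(0)

signal.signal(signal.SIGINT, _handle_sigint)
```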
…ed values (#1359)

* Update cli.py to make --device/--dtype pre-empt quantize dict-specified values

Users may expect that CLI parameters override the JSON, as per #1278.
Invert the logic as a case split:
1 - if no value is specified, use the value from the quantize dict, if present; else
2 - if a value is specified, override the respective handler setting if present.
(A sketch of this precedence rule follows this commit message.)

* Fix typo in cli.py

fix typo

---------

Co-authored-by: Jack-Khuu <[email protected]>
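A minimal sketch of the precedence rule described in the cli.py commit above (resolve_option and the flat dict shape are illustrative, not torchchat's actual code):

```
def resolve_option(cli_value, quantize_dict, key):
    # Case 1: no CLI value -> fall back to the quantize dict entry, if present.
    if cli_value is None:
        return quantize_dict.get(key)
    # Case 2: an explicit CLI value pre-empts and overrides the quantize dict.
    quantize_dict[key] = cli_value
    return cli_value

# Example: --device cuda overrides {"executor": "cpu"}.
print(resolve_option("cuda", {"executor": "cpu"}, "executor"))   # cuda
print(resolve_option(None, {"precision": "bf16"}, "precision"))  # bf16
```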
…#1368)

* Update install_requirements.sh to support Python >= 3.10, < 3.13

* Update install_requirements.sh

* Update install_requirements.sh
`gguf` was listed twice on the dependency list.

Signed-off-by: Sébastien Han <[email protected]>
If the chat is exited or interrupted it will still print the stats with
NaN values which is unnecessary.

Signed-off-by: Sébastien Han <[email protected]>
)

Let's gracefully fail if no model is given to the `download` command.

Signed-off-by: Sébastien Han <[email protected]>
angelayi and others added 29 commits February 4, 2025 13:52
* Bumping ET Pin to Jan15 2025

pytorch/executorch@d596cd7

* Remove call to capture_pre_auto_grad_graph

* Update naming for sdpa to custom ops

* Fix export_for_train

* Update et-pin.txt

* Update Test perms
Fix typo

Co-authored-by: Jack-Khuu <[email protected]>
Fix typo as a separate PR, as per @malfet

Co-authored-by: Jack-Khuu <[email protected]>
* add xpu

* add xpu device

* update

* profile

* update install

* update

* update

* update

---------

Co-authored-by: Jack-Khuu <[email protected]>
Co-authored-by: Guoqiong <[email protected]>
* Create run-readme-pr-linuxaarch64

Test torchchat on aarch64 linux

* Rename run-readme-pr-linuxaarch64 to run-readme-pr-linuxaarch64.yml

add yml extension.

* Update ADVANCED-USERS.md

Update doc to indicate testing for ARMv8/aarch64 on Linux/raspbian is introduced by this PR

---------

Co-authored-by: Jack-Khuu <[email protected]>
* Update install_requirements.sh

* Update pytorch minor version

* Update install_requirements.sh
* Add warning in PTEModel when not defined

* Add missing parans
bump this into the constructor of BuilderArgs

Co-authored-by: Jack-Khuu <[email protected]>
`attention_backend` is an SDPBackend, not a string
* Update run-readme-pr-linuxaarch64.yml to use correct runner

* Move to linux.arm64.m7g.4xlarge

* Explicitly overriding the docker-image

* Bumping Cuda version to 12.6

* Updating GPU Arch type

* Testing various linux_job combos: v2 cuda, v2 cpu, v1 cpu

* Adding permissions to linux job v2

* Switch everything to CPU linux v2

* Test with devtoolset-11

* Remove devtoolset install

* Removing devtoolset from commands
* Add encoded size to start_pos

* Only in chat mode

---------

Co-authored-by: nlpfollower <[email protected]>
Co-authored-by: Jack-Khuu <[email protected]>
)

ExecuTorch now has XNN pybinding built by default pytorch/executorch#7473

Previously it was not built by default
In this PR we replace torchchat's own [RMSNorm](https://github.com/pytorch/torchchat/blob/f4ae60fc936328c7ebd4551019733dc0942c42f9/torchchat/model.py#L931-L942) implementation by nn.RMSNorm, and we bump the PyTorch pin to capture the massive speed up (30x-40x) to RMSNorm on MPS backend introduced in pytorch/pytorch#145301

Preliminary benchmarks on an M1 Pro with 16GB RAM show a 33% speed-up in token generation when running Llama 3.2 1B with 4-bit quantization.

Motivation: Token generation on the MPS backend is currently CPU bound because of MPSGraph overhead. Surprisingly, the ops that impact performance the most are simple ones: mul, copy_, add, where, mean, rsqrt, sub, cat, stack. Experiments on an M1 Pro show that each of these op calls on the MPS backend has at least 20us of CPU overhead. These ops also dominate the graph: in aggregate, they are called 770 times for each token when running Llama 3.2 1B. Compare that to SDPA, which is called only 33 times, and linear, which is called 113 times.
- mul is called 275 times per token
- copy_ is called 202 times per token
- add is called 97 times per token
- where is called 34 times per token
- mean is called 33 times per token
- rsqrt is called 33 times per token
- sub is called 32 times per token
- cat is called 32 times per token
- stack is called 32 times per token

Currently, torchchat's own [RMSNorm](https://github.com/pytorch/torchchat/blob/f4ae60fc936328c7ebd4551019733dc0942c42f9/torchchat/model.py#L931-L942) operation is basically implemented like this:
```
def _norm(x):
    return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

output = _norm(x.float()).type_as(x) * self.weight
```
This means that a single call to torchchat's RMSNorm involves 3 calls to `aten::mul` plus one call each to `aten::rsqrt`, `aten::mean`, and `aten::add`, i.e. 6 op calls. RMSNorm is called 33 times for each token, so it contributes 6 * 33 = 198 of those 770 op calls.
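For reference, a minimal sketch of the swap described above, using core PyTorch's nn.RMSNorm (shapes below are illustrative, and this assumes a PyTorch build that ships nn.RMSNorm):

```
import torch
import torch.nn as nn

dim, eps = 2048, 1e-5  # illustrative model dimension
x = torch.randn(1, 8, dim)

# A single module replaces the hand-rolled mul/mean/rsqrt/add sequence above,
# cutting the number of small op dispatches per RMSNorm call.
norm = nn.RMSNorm(dim, eps=eps)
y = norm(x)
print(y.shape)  # torch.Size([1, 8, 2048])
```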
…1409)

* Add evaluation, multimodal, native tests to run-readme-pr-macos.yml

Add evaluation, multimodal, native tests to run-readme-pr-macos.yml

* Update run-readme-pr-mps.yml

* Update build_native.sh

Update to C++11 ABI for AOTI, similar to ET

* Update run-readme-pr-macos.yml

fix typo

---------

Co-authored-by: Jack-Khuu <[email protected]>
)

* Add evaluation, multimodal, native tests to run-readme-pr-mps.yml

Add evaluation, multimodal, native tests to run-readme-pr-mps.yml

* Update run-readme-pr-mps.yml

Typos

* Update run-readme-pr-mps.yml

* Update run-readme-pr-mps.yml

Extend timeout for test-readme-mps to avoid test failing from timeout.

* Update build_native.sh

Update to C++11 ABI for AOTI, similar to ET

---------

Co-authored-by: Jack-Khuu <[email protected]>
…ng to MPS (#1417)

* bandaid for run-readme-pr-macos.yml incorrectly loading to MPS

As per #1416, torchchat on hosts without MPS (which is all GitHub hosts, since they use KVM to virtualize macOS but not MPS) should choose CPU as the "fast" device. The logic is present (see the discussion in #1416), but it is either not fully functional (that would be the easier one to fix: just print the result of get_device_str and fix the code!) or specifically ignored on load in torch/serialization.py (if the latter, we're effectively looking at a core PyTorch issue).

In the meantime, this bandaid just forces the use of CPU on macOS tests, albeit short-circuiting test/execution of the "fast" device logic. Not ideal, but some testing beats no testing. (A sketch of this device-selection fallback follows this commit message.)

* Update run-readme-pr-macos.yml

Add informational message to MacOS CPU tests

* Update build_native.sh

Update to C++11 ABI for AOTI, similar to ET

---------

Co-authored-by: Jack-Khuu <[email protected]>
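For reference, a hedged sketch of the "fast" device fallback discussed in the first bullet above (get_fast_device is a stand-in name, not torchchat's actual get_device_str):

```
import torch

def get_fast_device() -> str:
    # Prefer CUDA, then MPS; fall back to CPU on hosts (like these virtualized
    # macOS CI runners) where MPS is not actually available.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(get_fast_device())
```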
* Add distributed tests to run-readme-pr.yml

Need to ensure this is the right runner, @lessw2020 can you please have a look -- torchchat uses the same runners as pytorch.

* Update run-docs

Remove HF login because tokens not available as git secret

* Update run-docs

Replace llama3.1 with open-llama to avoid the need for a token.
If this turns out to run too long, we can switch to stories110M.

* Update run-docs

open-llama -> stories.
* Update run-docs to avoid duplicate code

Update run-docs to avoid duplicate code

* Update run-docs

Add back command explaining seemingly extraneous `echo exit 1`

* Update build_native.sh

Update to C++11 ABI for AOTI, similar to ET

* Update run-docs

* Update run-docs

Update to run distributed inference test with open-llama instead of llama3.1

* Update run-docs

Open-llama -> stories to avoid tokens.

* Update README.md

Remove -l 3 since it is no longer necessary after Angela's change

* Update quantization.md

remove -l 3 from aoti run , and write -l3 for et_run

* Update run-docs

-l 3:-l 2 -> -l3:-l2

after modifying the command lines. Hopefully this is legal for et_run

* Update run.cpp

Update to support non-space separated args

* Update run.cpp

typo

* Create cuda-32.json

Add a gs=32 cuda.json for test runs with stories15M

* Create mobile-32.json

add gs=32 variant of mobile for tests

* Update run-docs

Use gs=32 variants with stories models

* Update run-docs

undo gs32

* Update run-readme-pr-mps.yml

Extend timeout to avoid timeout of mps quantization test

* Update run.cpp

Enforce that an argument must have at least length 2, and refine the check for a uni-arg (i.e., arg plus flag value in one option) to be args with more than 2 characters.

* Update run.cpp

typos

---------

Co-authored-by: Jack-Khuu <[email protected]>
…p.tc` (#1465)

* support model snapshots to save quantized models

* import set backend

---------

Co-authored-by: Michael Gschwind <[email protected]>
Handle the situation where aspell is not available
* Add DeepSeek R1 Distill 8B

* Update aliases to match Ollama

* Update README
@facebook-github-bot

Hi @vmpuri!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

Labels: CLA Signed (managed by the Meta Open Source bot), Quantization (issues related to Quantization or torchao)