Replace WeightOnlyInt8Linear with TorchAO int8_weight_only quantization #1328
base: main
Conversation
🔗 Helpful Links 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1328
Note: Links to docs will display an error until the docs builds have been completed. ❗ There is 1 currently active SEV; if your PR is affected, please view it below. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from d43d52e to 92e0a9d
Thanks! Can you add a generate.py speed benchmark result for before and after as well?
# Use tensor subclass API for int4 weight only.
if device == "cuda" and quantizer == "linear:int4":
    quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
elif quantizer == "linear:int8":
    print("quantizer is linear int8")
print("quantizer is linear int8") |
torchchat/utils/quantize.py (Outdated)
"precision": PrecisionHandler, | ||
"executor": ExecutorHandler, | ||
"linear:int4": Int4WeightOnlyQuantizer, | ||
"linear:int8": int8_weight_only, |
Do we need this?
We can probably use None for now and remove this later.
I think we check for int8_weight_only and finish that check before it looks at the table.
@vmpuri can you check?
Can you also ack that the numerics look good for MPS and CPU?
torchchat/utils/quantize.py (Outdated)
# Use tensor subclass API for int4 weight only.
if device == "cuda" and quantizer == "linear:int4":
    quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
elif quantizer == "linear:int8":
    print("quantizer is linear int8")
    quantize_(model, int8_weight_only())
Why not integrate it into a QuantHandler class dispatched thru the handler dict at a single call site rather than build a chain of if statements?
Hi @mikekgfb, I think we will refactor this part in the future, once all the quant APIs have been moved to torchao.
TorchAO already has a class-based API that is used for the other quantizers. Why do these differently and then refactor them later? Or why not do them all in a consistent way now, and if you refactor later, do that then?
Yeah, the quantizer API is deprecated in favor of quantize_, which is why we are gradually refactoring the quantizer APIs to use quantize_. The reason we do it one by one is that there might be missing support or numerics alignment issues that we need to address during the migration.
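For illustration, here is a hedged sketch of the dict-based dispatch suggested above, once the migration to quantize_ is complete. The class and variable names are hypothetical, not taken from torchchat/utils/quantize.py:

```
# Hypothetical sketch: dispatch TorchAO's quantize_ through a handler table
# instead of an if/elif chain. Names are illustrative only.
from torchao.quantization import quantize_, int4_weight_only, int8_weight_only


class Int8WeightOnlyHandler:
    def __init__(self, model, device, precision, **kwargs):
        self.model = model

    def quantize(self):
        quantize_(self.model, int8_weight_only())
        return self.model


class Int4WeightOnlyHandler:
    def __init__(self, model, device, precision, groupsize=128, **kwargs):
        self.model = model
        self.groupsize = groupsize

    def quantize(self):
        quantize_(self.model, int4_weight_only(self.groupsize))
        return self.model


quantizer_dispatch = {
    "linear:int8": Int8WeightOnlyHandler,
    "linear:int4": Int4WeightOnlyHandler,
}


def quantize_model(model, device, quantizer, q_kwargs):
    # Single call site: look up the handler and let it run quantize_.
    handler = quantizer_dispatch[quantizer](model, device, None, **q_kwargs)
    return handler.quantize()
```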
torchchat/utils/quantize.py (Outdated)
        return linear_int8_aoti(input, self.weight, self.scales)

    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
        return linear_int8_et(input, self.weight, self.scales)
Int8 seems like it is special-cased for ET; reminder to check that as well.
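As a reminder of what that path computes, here is a generic sketch of a weight-only int8 linear forward with per-output-channel scales. It illustrates the math only; it is not the exact linear_int8_et / linear_int8_aoti implementation:

```
# Generic sketch of a per-channel, weight-only int8 linear forward
# (dequantize-on-the-fly). Not torchchat's actual code.
import torch
import torch.nn.functional as F

def int8_weight_only_linear(input, weight_int8, scales):
    # weight_int8: [out_features, in_features] int8
    # scales:      [out_features] floating-point per-channel scales
    weight = weight_int8.to(dtype=input.dtype) * scales.view(-1, 1)
    return F.linear(input, weight)

x = torch.randn(2, 16)
w = torch.randint(-128, 127, (8, 16), dtype=torch.int8)
s = torch.rand(8)
print(int8_weight_only_linear(x, w, s).shape)  # torch.Size([2, 8])
```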
The previous Python version check was incorrect, allowing installations on unsupported interpreter versions, which caused installation failures. Additionally, we now respect the specified interpreter version if provided, consistently using it throughout the installation process by enforcing it with pip. Signed-off-by: Sébastien Han <[email protected]>
Co-authored-by: Jack-Khuu <[email protected]>
* tokenizer was missing an include * fix a nit --------- Co-authored-by: Jesse <[email protected]> Co-authored-by: Jack-Khuu <[email protected]>
…resses ctrl+c (#1352) Setup a SIGINT handler to gracefully exit the program once the user presses ctrl+c. Signed-off-by: Sébastien Han <[email protected]> Co-authored-by: Jack-Khuu <[email protected]>
…ed values (#1359) * Update cli.py to make --device/--dtype pre-empt quantize dict-specified values Users may expect that cli parameters override the JSON, as per #1278. Invert logic - case split: 1 - if none (no value) is specified, use value specified in quantize dict, if present; else 2 - if value is specified, override the respective handler if present. * Fix typo in cli.py fix typo --------- Co-authored-by: Jack-Khuu <[email protected]>
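A hedged sketch of the precedence rule described in that commit; the function name and dict keys below are assumptions for illustration, not the actual cli.py code:

```
# Illustrative sketch: an explicit CLI value pre-empts the quantize-dict value;
# otherwise the dict value (if any) is used. Key names are assumed, not verified.
def resolve_device(cli_device, quantize_dict):
    executor = quantize_dict.setdefault("executor", {})
    if cli_device is None:
        # Case 1: no CLI value -- fall back to the quantize dict (or a default).
        return executor.get("accelerator", "fast")
    # Case 2: CLI value given -- override the dict-specified handler value.
    executor["accelerator"] = cli_device
    return cli_device

print(resolve_device(None, {"executor": {"accelerator": "cuda"}}))   # cuda
print(resolve_device("cpu", {"executor": {"accelerator": "cuda"}}))  # cpu
```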
…1369) * Only set up during the first sample * Cleaner
…#1368) * Update install_requirements.sh to support python 3.10 >= , <3.13 * Update install_requirements.sh * Update install_requirements.sh
`gguf` was listed twice on the dependency list. Signed-off-by: Sébastien Han <[email protected]>
If the chat is exited or interrupted it will still print the stats with NaN values which is unnecessary. Signed-off-by: Sébastien Han <[email protected]>
) Let's gracefully fail if no model is given to the `download` command. Signed-off-by: Sébastien Han <[email protected]>
* Bumping ET Pin to Jan15 2025 pytorch/executorch@d596cd7 * Remove call to capture_pre_auto_grad_graph * Update naming for sdpa to custom ops * Fix export_for_train * Update et-pin.txt * Update Test perms
Fix typo Co-authored-by: Jack-Khuu <[email protected]>
fix typo as separate PR, as per @malfet Co-authored-by: Jack-Khuu <[email protected]>
* add xpu * add xpu device * update * profile * update install * update * update * update --------- Co-authored-by: Jack-Khuu <[email protected]> Co-authored-by: Guoqiong <[email protected]>
* Create run-readme-pr-linuxaarch64 Test torchchat on aarch64 linux * Rename run-readme-pr-linuxaarch64 to run-readme-pr-linuxaarch64.yml add yml extension. * Update ADVANCED-USERS.md Update doc to indicate testing for ARMv8/aarch64 on Linux/raspbian is introduced by this PR --------- Co-authored-by: Jack-Khuu <[email protected]>
* Update install_requirements.sh * Update pytorch minor version * Update install_requirements.sh
* Add warning in PTEModel when not defined * Add missing parans
bump this into the constructor of BuilderArgs Co-authored-by: Jack-Khuu <[email protected]>
`attention_backend` is a SDPBackend, not a string
Co-authored-by: Jack-Khuu <[email protected]>
* Update run-readme-pr-linuxaarch64.yml to use correct runner * Move to linux.arm64.m7g.4xlarge * Explicitly overriding the docker-image * Bumping Cuda version to 12.6 * Updating GPU Arch type * Testing various linux_job combos: v2 cuda, v2 cpu, v1 cpu * Adding permissions to linux job v2 * Switch everything to CPU linux v2 * Test with devtoolset-11 * Remove devtoolset install * Removing devtoolset from commands
* Add encoded size to start_pos * Only in chat mode --------- Co-authored-by: nlpfollower <[email protected]> Co-authored-by: Jack-Khuu <[email protected]>
) ExecuTorch now has XNN pybinding built by default pytorch/executorch#7473 Previously it was not built by default
In this PR we replace torchchat's own [RMSNorm](https://github.com/pytorch/torchchat/blob/f4ae60fc936328c7ebd4551019733dc0942c42f9/torchchat/model.py#L931-L942) implementation with nn.RMSNorm, and we bump the PyTorch pin to capture the massive speed-up (30x-40x) to RMSNorm on the MPS backend introduced in pytorch/pytorch#145301. Preliminary benchmarks on an M1 Pro with 16GB RAM show a 33% speed-up on token generation when running Llama 3.2 1B with 4-bit quantization.

Motivation: Token generation on the MPS backend is currently CPU-bound because of MPSGraph overhead. Surprisingly, the ops that impact performance the most are simple ones: mul, copy_, add, where, mean, rsqrt, sub, cat, stack. Experiments on an M1 Pro show that each of these op calls on the MPS backend has at least 20us of CPU overhead, and these ops dominate the graph. In aggregate, they are called 770 times for each token when running Llama 3.2 1B; compare that to SDPA, which is called only 33 times, and linear, which is called 113 times.

- mul is called 275 times per token
- copy_ is called 202 times per token
- add is called 97 times per token
- where is called 34 times per token
- mean is called 33 times per token
- rsqrt is called 33 times per token
- sub is called 32 times per token
- cat is called 32 times per token
- stack is called 32 times per token

Currently, torchchat's own [RMSNorm](https://github.com/pytorch/torchchat/blob/f4ae60fc936328c7ebd4551019733dc0942c42f9/torchchat/model.py#L931-L942) operation is basically implemented like this:

```
norm = x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
output = norm(x.float()).type_as(x) * weight
```

This means that a single call to torchchat's RMSNorm involves 3 calls to `aten::mul` plus one call each to `aten::rsqrt`, `aten::mean`, and `aten::add`. RMSNorm is called 33 times for each token, so it contributes 6 * 33 = 198 of those 770 op calls.
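A small sketch of the swap described above, assuming PyTorch >= 2.4 where torch.nn.RMSNorm is available; the hand-rolled version mirrors the snippet quoted from torchchat's model.py:

```
# Sketch comparing the hand-rolled RMSNorm (several small elementwise ops per
# call) with a single nn.RMSNorm module call. Assumes PyTorch >= 2.4.
import torch
import torch.nn as nn

dim, eps = 2048, 1e-5
x = torch.randn(1, 8, dim)
weight = torch.ones(dim)

def rms_norm_manual(x):
    # mul, mean, add, rsqrt, mul, mul -- six small kernel launches per call.
    norm = x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)
    return norm.float().type_as(x) * weight

rms_norm = nn.RMSNorm(dim, eps=eps)  # one fused module call

print(torch.allclose(rms_norm_manual(x), rms_norm(x), atol=1e-5))  # True
```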
Co-authored-by: Jack-Khuu <[email protected]>
…1409) * Add evaluation, multimodal, native tests to run-readme-pr-macos.yml Add evaluation, multimodal, native tests to run-readme-pr-macos.yml * Update run-readme-pr-mps.yml * Update build_native.sh Update to C++11 ABI for AOTI, similar to ET * Update run-readme-pr-macos.yml fix typo --------- Co-authored-by: Jack-Khuu <[email protected]>
) * Add evaluation, multimodal, native tests to run-readme-pr-mps.yml Add evaluation, multimodal, native tests to run-readme-pr-mps.yml * Update run-readme-pr-mps.yml Typos * Update run-readme-pr-mps.yml * Update run-readme-pr-mps.yml Extend timeout for test-readme-mps to avoid test failing from timeout. * Update build_native.sh Update to C++11 ABI for AOTI, similar to ET --------- Co-authored-by: Jack-Khuu <[email protected]>
…ng to MPS (#1417) * bandaid for run-readme-pr-macos.yml incorrectly loading to MPS as per #1416 torchchat on hosts without MPS (which is all github hosts, which use kvm to virtualize MacOS, but not MPS) should choose CPU as "fast" device. The logic is present (see discussion in #1416 ), but either not fully functional (that would be the easier one to fix, just print the result of get_device_str and fix the code!) or specifically ignored on load in torch/serialization.py (If this is the case, we're effectively looking at a core PyTorch issue....) In the meantime, this bandaid just forces the use of CPU on MacOS tests, to make MacOS tests run on CPU -- albeit short-circuiting test/execution of the "fast" device logic. Not ideal, but some testing beats no testing. * Update run-readme-pr-macos.yml Add informational message to MacOS CPU tests * Update build_native.sh Update to C++11 ABI for AOTI, similar to ET --------- Co-authored-by: Jack-Khuu <[email protected]>
* Add distributed tests to run-readme-pr.yml Need to ensure this is the right runner, @lessw2020 can you please have a look -- torchchat uses the same runners as pytorch. * Update run-docs Remove HF login because tokens not available as git secret * Update run-docs Replace llama3.1 with open-llama to avoid need for token. If this turns out running too long, then we can switch to stories110M * Update run-docs open-llama -> stories.
* Update run-docs to avoid duplicate code Update run-docs to avoid duplicate code * Update run-docs Add back command explaining seemingly extraneous `echo exit 1` * Update build_native.sh Update to C++11 ABI for AOTI, similar to ET * Update run-docs * Update run-docs Update to run distributed inference test with open-llama instead of llama3.1 * Update run-docs Open-llama -> stories to avoid tokens. * Update README.md Remove -l 3 since no longer necessary after Angea's change * Update quantization.md remove -l 3 from aoti run, and write -l3 for et_run * Update run-docs -l 3:-l 2 -> -l3:-l2 after modifying the command lines. Hopefully this is legal for et_run * Update run.cpp Update to support non-space separated args * Update run.cpp typo * Create cuda-32.json Add a gs=32 cuda.json for test runs with stories15M * Create mobile-32.json add gs=32 variant of mobile for tests * Update run-docs Use gs=32 variants with stories models * Update run-docs undo gs32 * Update run-readme-pr-mps.yml Extend timeout to avoid timeout of mps quantization test * Update run.cpp enforce that an argument must have at least length 2, and refine check for uniarg (i.e. arg plus flag value in one option) to be args with more than 2 characters * Update run.cpp typos --------- Co-authored-by: Jack-Khuu <[email protected]>
…p.tc` (#1465) * support model snapshots to save quantized models * import set backend --------- Co-authored-by: Michael Gschwind <[email protected]>
Handle situation where aspell is not available
* Add DeepSeek R1 Distill 8B * Update aliases to match Ollama * Update README
Hi @vmpuri! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but the CLA is no longer valid and will need to be resubmitted.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations; afterwards, the pull request will be tagged accordingly. If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Replace the WeightOnlyInt8Linear quantization code with TorchAO's int8_weight_only quantization.
Note - this commit also contains lintrunner changes.
Testing:
From current master:
Lint