
Commit

add CI to disallow syntax errors and undefined vars in all Python files (#861)

Summary:

Adds two codebase-wide checks for Python files:
1. syntax errors (E999)
2. undefined variables (F821)

Both of these caused internal breakages recently, so it's worth having CI block
them from landing in OSS.

Test Plan:

Tested that the new rules pass locally (`--isolated` makes ruff ignore any config files, so the rules apply to every Python file in the repo):

```
ruff check --isolated --select E999,F821
```
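For context, here is a minimal hypothetical snippet (not from this repo) showing the class of bug each rule catches: E999 fires when a file fails to parse at all, while F821 fires on any reference to a name that is never defined, such as a typo'd variable.

```python
# Hypothetical example of an F821-class bug: the file parses fine (no E999),
# but `facotr` is never defined, so ruff flags it statically.
def scale(x, factor=2):
    # return x * facotr  # F821: undefined name `facotr` (typo)
    return x * factor    # fixed: reference the defined parameter

print(scale(3))  # 6
```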

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo authored Sep 10, 2024
1 parent 3e9746c commit 93b5869
Showing 10 changed files with 19 additions and 8 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ruff_linter.yml
```diff
@@ -25,6 +25,9 @@ jobs:
       - name: Analyzing the code with ruff
         run: |
           ruff check .
+      - name: Check all Python files for syntax errors (E999) and undefined vars (F821)
+        run: |
+          ruff check --isolated --select E999,F821
       - name: Check well formatted code
         run: |
           ruff format --check
```
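As an aside (not part of this commit), `ruff check` also accepts explicit paths, so the same rules can be run against a single file while iterating (`path/to/file.py` is a placeholder):

```
ruff check --isolated --select E999,F821 path/to/file.py
```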
6 changes: 4 additions & 2 deletions benchmarks/benchmark_gpu_sparsity.py
```diff
@@ -56,8 +56,10 @@ def run_gpu_sparse_benchmark(m, k, n, args):
     elif args.eval_fn == "mm":
         dense_output = torch.mm(A, x.t())
         sparse_output = torch.mm(A_sparse, x.t())
-        dense_time = benchmark_in_us(torch.mm, A, x.t())
-        sparse_time = benchmark_in_us(torch.mm, A_sparse, x.t())
+        # dense_time = benchmark_in_us(torch.mm, A, x.t())
+        # sparse_time = benchmark_in_us(torch.mm, A_sparse, x.t())
+        # TODO(future PR) fixme
+        dense_time, sparse_time = 1.0, 1.0
     else:
         raise ValueError(f"Unknown eval_fn: {args.eval_fn}")
```

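The `benchmark_in_us` helper referenced above was never defined in the file, which is the F821 hit; the commit stubs the timings out with a TODO rather than implementing it. For illustration only, a minimal sketch of what such a helper might look like using `torch.utils.benchmark` (an assumption, not the repo's actual implementation):

```python
import torch
import torch.utils.benchmark as benchmark

def benchmark_in_us(f, *args, **kwargs):
    # Time f(*args, **kwargs) and report the mean latency in microseconds.
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    return timer.timeit(100).mean * 1e6  # Measurement.mean is in seconds

# Usage sketch, mirroring the commented-out calls above:
# dense_time = benchmark_in_us(torch.mm, A, x.t())
```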
2 changes: 2 additions & 0 deletions benchmarks/float8/bench_matmul.py
```diff
@@ -113,6 +113,8 @@ def run(
     scale_b = torch.tensor([1.0], device=device)
 
     def do_matmul(A, B):
+        nonlocal scale_a
+        nonlocal scale_b
         return torch._scaled_mm(
             A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
         )
```
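The `nonlocal` declarations are presumably there so the undefined-name check can resolve `scale_a`/`scale_b` to the enclosing `run()` scope. As a standalone illustration of the scoping involved (hypothetical code, not from the benchmark):

```python
def run():
    scale_a = 1.0
    scale_b = 2.0

    def do_matmul(a, b):
        # Declare that these names refer to run()'s locals rather than
        # globals; reading closure variables works without this, but the
        # declaration makes the binding explicit.
        nonlocal scale_a, scale_b
        return a * scale_a + b * scale_b

    return do_matmul(3.0, 4.0)

print(run())  # 11.0
```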
2 changes: 1 addition & 1 deletion torchao/dtypes/affine_quantized_tensor.py
```diff
@@ -1237,7 +1237,7 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
 
 
 def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
-    assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
+    assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
     assert input_tensor.shape[-1] == weight_tensor.shape[1], (
         f"need input_tensor shape: {input_tensor.shape} final"
         f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "
```
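This fix is a nice example of why F821 matters inside f-strings: names in an f-string are evaluated at runtime, so the undefined `block_size` would only blow up, as a `NameError` masking the intended message, at the exact moment the assert fails. A small illustration (hypothetical code):

```python
class Weight:
    block_size = (1, 32)  # stand-in for the real quantized-weight attribute

w = Weight()

# Buggy form: `block_size` is not a defined name here, so if the assert
# ever failed, building the message would raise NameError instead of
# showing the intended diagnostic:
# assert w.block_size[0] == 1, f"got block_size: {block_size}"

# Fixed form references the attribute that actually exists:
assert w.block_size[0] == 1, f"got block_size: {w.block_size}"
print("ok")
```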
```diff
@@ -17,6 +17,7 @@
 import torch.multiprocessing as mp
 from ax.modelbridge.cross_validation import cross_validate
 from utils import write_history_to_csv, cal_wikitext_ppl, cal_model_size, load_model, quantize_by_fqn_to_config, load_parameters_from_json, load_initial_samples
+from BO_acc_throughput import define_parameter_list
 
 # return evaluation results to complete BO trials
 def eval(model, tokenizer, num_PPL_eval_samples, fqn_to_config):
```
```diff
@@ -45,7 +45,7 @@
     _load_model,
 )
 
-from utils import write_history_to_csv, cal_wikitext_ppl, load_model, quantize_by_fqn_to_config, load_parameters_from_json
+from utils import write_history_to_csv, cal_wikitext_ppl, load_model, quantize_by_fqn_to_config, load_parameters_from_json, load_initial_samples
 
 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -380,6 +380,8 @@ def run_sequential_BO(device, checkpoint_path, repo_id, num_PPL_eval_samples, nu
     parameters_list = load_parameters_from_json(args.parameters_list)
 
     # sample initial points
+    # TODO(future PR): fix me
+    initial_samples = []
     initial_points_set = load_initial_samples(initial_samples)
     num_BO_initial_samples = len(initial_points_set)
```
2 changes: 1 addition & 1 deletion torchao/quantization/subclass.py
```diff
@@ -231,7 +231,7 @@ class Int8DynamicallyQuantizedLinearWeight(QuantizedLinearWeightBase):
     @staticmethod
     def __new__(cls, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
         if dtype is None:
-            dtype = qscales.dtype
+            dtype = q_scales.dtype
         kwargs["dtype"] = dtype
         return super().__new__(cls, int_data, transposed, shape, **kwargs)  # type: ignore[attr-defined]
```

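The `qscales` → `q_scales` typo also illustrates why a static check beats runtime coverage here: the buggy line only executes when `dtype is None`, so any caller that always passes a dtype would never trip the `NameError`. A contrived sketch (hypothetical code):

```python
def pick_dtype(dtype=None):
    q_scales_dtype = "int8"  # stand-in for q_scales.dtype
    if dtype is None:
        # dtype = qscales_dtype  # F821: NameError, but only on this branch
        dtype = q_scales_dtype   # fixed spelling
    return dtype

print(pick_dtype())         # "int8" -- exercises the once-buggy branch
print(pick_dtype("float"))  # "float" -- the buggy line never ran
```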
4 changes: 2 additions & 2 deletions torchao/sparsity/prototype/superblock/evaluate.py
```diff
@@ -16,7 +16,7 @@
 
 from torchao.sparsity import sparsify_, semi_sparse_weight
 from torchao.sparsity.prototype.superblock.supermask import apply_supermask
-from torchao.sparsity.prototype.superblock.utils import apply_sparsity, verify_sparsity, mlp_only_with_args
+from torchao.sparsity.prototype.superblock.utils import apply_sparsity, verify_sparsity, mlp_only_with_args, simulate_sparsity, accelerate_with_sparsity
 from torchao.sparsity.prototype.superblock.train import evaluate, _get_cache_path, load_data
 from torchao.sparsity.prototype.sparsifier.weight_norm_sparsifier import WeightNormSparsifier
 
@@ -56,7 +56,7 @@ def main(args):
     model.to(device).bfloat16()
 
     if sparsifier_or_none is not None:
-        sparsifier.squash_mask()
+        sparsifier_or_none.squash_mask()
         accelerate_with_sparsity(model, args)
 
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
```
1 change: 1 addition & 0 deletions torchao/sparsity/prototype/superblock/utils.py
```diff
@@ -12,6 +12,7 @@
 import torch
 import torch.distributed as dist
 
+from torchao.quantization import quantize_, int8_dynamic_activation_int8_semi_sparse_weight
 from torchao.sparsity import sparsify_, semi_sparse_weight
 from torchao.sparsity.prototype.superblock.supermask import SupermaskLinear, apply_supermask
 from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight
```
2 changes: 1 addition & 1 deletion tutorials/developer_api_guide/my_dtype_tensor_subclass.py
```diff
@@ -20,7 +20,7 @@
     LayoutType,
     PlainLayoutType,
 )
-from torchao.utils import TorchAOBaseTensor
+from torchao.utils import TorchAOBaseTensor, _register_layout_cls, _get_layout_tensor_constructor
 
 aten = torch.ops.aten
```
