Smooth ReLU activation implementation as Triton kernel #196
Conversation

dtunai commented on Apr 11, 2024
- Description: Introduces the Smooth ReLU (SmeLU) activation from Google Research as a Triton kernel inside the Zeta library; a sketch of the kernel and its wrapper is included below this list
- Issue: None
- Dependencies: None
- Tag maintainer: @kyegomez
- Twitter handle: @simudt
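For context, the sketch below pieces together what the new kernel and its Python wrapper might look like from the fragments quoted in the code-scanning annotations further down (`smooth_relu_activation_kernel`, `block_st`, `offsets`, `mask`, and the `smooth_relu_activation(x, beta=2.0)` wrapper). The piecewise SmeLU math, the `tl.where` formulation, and the launch grid are assumptions rather than the exact code merged in this PR, and it assumes a contiguous float32 CUDA tensor; in the actual diff the kernel is attached to a class as a `@staticmethod`, while the module-level layout here is only for readability.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def smooth_relu_activation_kernel(
    x_ptr, output_ptr, n_elements, beta, BLOCK_SIZE: tl.constexpr
):
    # Each program instance handles one contiguous block of BLOCK_SIZE elements.
    idx = tl.program_id(axis=0)
    block_st = idx * BLOCK_SIZE
    offsets = block_st + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)

    # SmeLU (Smooth ReLU) from Google Research's "Smooth activations and
    # reproducibility in deep networks": 0 for x <= -beta, the quadratic joint
    # (x + beta)^2 / (4 * beta) for |x| <= beta, and identity for x >= beta.
    # (Assumed formulation; the PR's exact kernel body may differ.)
    quad = (x + beta) * (x + beta) / (4.0 * beta)
    output = tl.where(x <= -beta, 0.0, tl.where(x >= beta, x, quad))
    tl.store(output_ptr + offsets, output, mask=mask)


def smooth_relu_activation(x: torch.Tensor, beta: float = 2.0) -> torch.Tensor:
    """Apply SmeLU element-wise to a contiguous float32 CUDA tensor."""
    output = torch.empty_like(x)
    n_elements = x.numel()
    # One program per block of 1024 elements; the block size is an assumption.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    smooth_relu_activation_kernel[grid](x, output, n_elements, beta, BLOCK_SIZE=1024)
    return output
```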
@staticmethod
@triton.jit
def smooth_relu_activation_kernel(
    x_ptr, output_ptr, n_elements, beta, BLOCK_SIZE: tl.constexpr

Warning from Code scanning / Pylint (python3), reported by Codacy: Argument name "BLOCK_SIZE" doesn't conform to snake_case naming style
@@ -50,6 +50,16 @@
)


def smooth_relu_activation(x: torch.Tensor, beta: float = 2.0):

Warning from Code scanning / Pylint (python3), reported by Codacy: Missing function or method docstring
@staticmethod
@triton.jit
def smooth_relu_activation_kernel(
    x_ptr, output_ptr, n_elements, beta, BLOCK_SIZE: tl.constexpr

Warning from Code scanning / Pylint, reported by Codacy: Wrong hanging indentation before block (add 4 spaces).
@@ -93,6 +93,27 @@
    output = tl.maximum(x, alpha * x)
    tl.store(output_ptr + offsets, output, mask=mask)

@staticmethod
@triton.jit
def smooth_relu_activation_kernel(

Warning from Code scanning / Pylint, reported by Codacy: Argument name "BLOCK_SIZE" doesn't conform to snake_case naming style
    block_st = idx * BLOCK_SIZE
    offsets = block_st + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)

Warning from Code scanning / Pylint, reported by Codacy: Variable name "x" doesn't conform to snake_case naming style
@@ -50,6 +50,16 @@
)


def smooth_relu_activation(x: torch.Tensor, beta: float = 2.0):

Warning from Code scanning / Pylint, reported by Codacy: Missing function docstring
Warning from Code scanning / Pylint, reported by Codacy: Argument name "x" doesn't conform to snake_case naming style
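Most of these findings are stylistic: a `BLOCK_SIZE: tl.constexpr` parameter and a single-letter tensor name `x` are conventional in Triton kernels and PyTorch-style APIs. One possible way to satisfy the checker without renaming, sketched below, is to add the missing docstrings and suppress `invalid-name` with targeted inline comments; the exact placement of the disable comments depends on which line Pylint attaches each message to, so treat this as an illustration rather than how the PR actually resolves the warnings.

```python
@staticmethod
@triton.jit
def smooth_relu_activation_kernel(
    x_ptr, output_ptr, n_elements, beta, BLOCK_SIZE: tl.constexpr  # pylint: disable=invalid-name
):
    """Compute the Smooth ReLU (SmeLU) activation for one block of elements."""
    ...  # kernel body omitted; see the sketch earlier in this PR description


def smooth_relu_activation(x: torch.Tensor, beta: float = 2.0):  # pylint: disable=invalid-name
    """Apply SmeLU element-wise to ``x`` with quadratic-joint half-width ``beta``."""
    ...  # wrapper body omitted
```

Alternatively, the flagged names could be whitelisted project-wide through the `good-names` option in the repository's Pylint configuration instead of adding inline disables.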