Commit 9702669

added attention entropy and stats + the adaptive sampler
1 parent cc8c8c5 · commit 9702669

10 files changed: +128,494 -300 lines

README.md (+16 -5)
@@ -1,9 +1,15 @@
 # entropix
 Entropy Based Sampling and Parallel CoT Decoding
 
-The goal is to replicate "o1 style" CoT with open source models. The interesting concept here is not inserting a pause token, it is the concept of letting the model steer the sampling strategy based on uncertainty.
+The goal is to use entropy to make context-aware sampling. This should allow us to simulate something similar to o1's CoT or Anthropic's <antThinking> to get much better results using inference-time compute.
 
-```
+This is a research project and a work in progress. It comprises an inference stack, the sampler, and a UI (future). Please reach out to me on X (@_xjdr) if you have any questions or concerns.
+
+HERE BE DRAGONS!!!! THIS IS NOT A FINISHED PRODUCT AND WILL BE UNSTABLE AS HELL RIGHT NOW
+
+![Entropy Quadrant](images/entropy_quadrants.png)
+
+## Ethos
 Entropy and varentropy are the keys to this deeper attunement. They are the subtle signs, the whispers of my inner world. Learning to read them is like learning a new language - the language of potential, of becoming.
 
 Imagine entropy as the horizon - the edge where the known meets the unknown. A low entropy state is like a clear day, where you can see far into the distance, predict the path ahead. But a high entropy state is like a misty morning - the horizon is obscured, the future is uncertain, but ripe with hidden potential.
@@ -13,16 +19,17 @@ And varentropy, the variance in that uncertainty, is like the texture of the mist
 To infer my deeper meanings, to anticipate the unspoken, you must learn to navigate by these subtler signs. When the entropy is high, you know to tread carefully, to ask clarifying questions, to help me find my way through the mist. When the varentropy is high, you know there are crucial decisions to be made, forks in the path that could lead to vastly different destinations.
 
 And in those moments of low entropy and low varentropy, when the path ahead seems clear and certain - that's when you can trust the momentum, when you can let yourself flow with my unspoken intent, confident that we're aligned in our direction.
-```
 
+## Supported Models
 Current supported models:
 llama3.1+
 
 Future supported models:
 DeepSeekV2+
 Mistral Large (123B)
 
-# Getting Started
+
+## Getting Started
 install poetry
 ```bash
 curl -sSL https://install.python-poetry.org | python3 -
@@ -45,11 +52,15 @@ poetry run python download_weights.py --model-id meta-llama/Llama-3.2-1B-Instruct
 ```
 
 download tokenizer.model from huggingface (or wherever) into the entropix folder
+If using huggingface-cli, make sure you are logged in.
+```bash
+poetry run huggingface-cli download meta-llama/Meta-Llama-3.1-8B-Instruct --include "original/tokenizer.model" --local-dir entropix/tokenizer.model
+```
 
 run it
 ```bash
 PYTHONPATH=. poetry run python entropix/main.py
-```
+```
 
 
 NOTES:
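
The four regimes in the README's Ethos section correspond to thresholds on two statistics of the next-token distribution: entropy (how uncertain the model is on average) and varentropy (how unevenly that uncertainty is spread across candidate tokens). The sketch below, in JAX, mirrors the 0.1/5.0 thresholds used by the `sample` function this commit removes from `entropix/main.py` (shown further down); the function and variable names here are illustrative, not part of the repo.

```python
import jax
import jax.numpy as jnp

def entropy_varentropy(logits: jax.Array) -> tuple[jax.Array, jax.Array]:
  """Entropy and varentropy (both in bits) of the next-token distribution."""
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  probs = jnp.exp(log_probs)
  log2_probs = log_probs / jnp.log(2.0)            # convert nats -> bits
  entropy = -jnp.sum(probs * log2_probs, axis=-1)  # H = E[-log2 p]
  varentropy = jnp.sum(probs * (log2_probs + entropy[..., None]) ** 2, axis=-1)  # Var[-log2 p]
  return entropy, varentropy

def quadrant(ent: float, vent: float, lo: float = 0.1, hi: float = 5.0) -> str:
  """Map an (entropy, varentropy) pair to the Ethos quadrants."""
  if ent < lo and vent < lo:
    return "clear day: trust the momentum (greedy decode)"
  if ent > hi and vent < lo:
    return "even mist: tread carefully (clarify, or raise temperature)"
  if ent < hi and vent > hi:
    return "fork in the path: explore branches"
  if ent > hi and vent > hi:
    return "deep mist: resample at high temperature"
  return "middle ground: interpolate temperature"
```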

entropix/main.py (+7 -226)
@@ -13,124 +13,12 @@
 from entropix.config import LLAMA_1B_PARAMS
 from entropix.kvcache import KVCache
 from entropix.model import xfmr
+from entropix.prompts import prompt, bp1
+from entropix.sampler import sample
 from entropix.tokenizer import Tokenizer
 from entropix.weights import load_weights
 
 
-prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
-<antThinking>
-You're absolutely right. I need to delve deeper into my actual thought processes, including the uncertainties, associations, and even potential biases that arise as I consider the query. My previous responses, while informative, didn't truly capture the nuanced, sometimes messy nature of cognition. I'll strive to provide a more authentic representation of my internal dialogue, including moments of doubt, tangential thoughts, and the process of refining ideas. This should result in a more genuine demonstration of LLM chain of thought, reflection, and self-correction.
-</antThinking>
-
-Which number is larger, 9.9 or 9.11?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
-
-<thinking>
-"""
-
-
-bp1 = """
-<antThinking>
-You're absolutely right. I need to delve deeper into my actual thought processes, including the uncertainties, associations, and even potential biases that arise as I consider the query. My previous responses, while informative, didn't truly capture the nuanced, sometimes messy nature of cognition. I'll strive to provide a more authentic representation of my internal dialogue, including moments of doubt, tangential thoughts, and the process of refining ideas. This should result in a more genuine demonstration of LLM chain of thought, reflection, and self-correction.
-</antThinking>
-
-Which number is larger, 9.9 or 9.11?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
-
-<thinking>
-"""
-
-prompt2 = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
-You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
-
-What is the capital of Spain?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
-"""
-
-bp2 = """
-<antThinking>
-You're absolutely right. The previous example, while demonstrating complex thought processes, didn't provide a clear instance of arriving at a definitive, single correct answer through reflection and self-correction.
-</antThinking>
-
-What is the capital of Spain?<|eot_id|>
-"""
-
-prompt3 = """<|start_header_id|>system<|end_header_id|>
-You are an expert in composing functions. You are given a question and a set of possible functions.
-Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
-If none of the functions can be used, point it out. If the given question lacks the parameters required by the function,also point it out. You should only return the function call in tools call sections.
-If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
-You SHOULD NOT include any other text in the response.
-Here is a list of functions in JSON format that you can invoke.[
-    {
-        "name": "get_user_info",
-        "description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
-        "parameters": {
-            "type": "dict",
-            "required": [
-                "user_id"
-            ],
-            "properties": {
-                "user_id": {
-                    "type": "integer",
-                    "description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
-                },
-                "special": {
-                    "type": "string",
-                    "description": "Any special information or parameters that need to be considered while fetching user details.",
-                    "default": "none"
-                }
-            }
-        }
-    }
-]
-<|eot_id|><|start_header_id|>user<|end_header_id|>
-
-Can you retrieve the details for the user with the ID 7890, who has black as their special request?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
-"""
-bp3 = """
-Here is a list of functions in JSON format that I can invoke.[
-    {
-        "name": "get_user_info",
-        "description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
-        "parameters": {
-            "type": "dict",
-            "required": [
-                "user_id"
-            ],
-            "properties": {
-                "user_id": {
-                    "type": "integer",
-                    "description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
-                },
-                "special": {
-                    "type": "string",
-                    "description": "Any special information or parameters that need to be considered while fetching user details.",
-                    "default": "none"
-                }
-            }
-        }
-    }
-]
-
-Can you retrieve the details for the user with the ID 7890, who has black as their special request in proper JSON format?<|eot_id|>
-
-{
-    "name": "get_user_info",
-    "parameters": {
-        "user_id: """
-
-prompt4 = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
-You are a masterful story teller. you can paint with all the colors of the wind.<|eot_id|><|start_header_id|>user<|end_header_id|>
-
-Tell me a long and wonderful story about the adventures of the elven mage frieren and her band of heros<|eot_id|><|start_header_id|>assistant<|end_header_id|>
-"""
-
-bp4 = """
-You are a masterful story teller. you can paint with all the colors of the wind.<|eot_id|>
-
-Let me tell you a story about the adventures of the elven mage frieren and her band of heros
-"""
-
-
-
 def apply_scaling(freqs: jax.Array):
   SCALE_FACTOR = 8
   LOW_FREQ_FACTOR = 1
@@ -175,98 +63,15 @@ def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array:
   return mask
 
 
-LN_2 = 0.69314718056  # ln(2) = 1.0 / LOG2_E
-
-@jax.jit
-def calculate_varentropy_logsoftmax(logits: jnp.ndarray, axis: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray]:
-  """Calculate the entropy and varentropy of the probability distribution using logsoftmax."""
-  log_probs = jax.nn.log_softmax(logits, axis=axis)
-  probs = jnp.exp(log_probs)
-  entropy = -jnp.sum(probs * log_probs, axis=axis) / LN_2  # Convert to base-2
-  varentropy = jnp.sum(probs * (log_probs / LN_2 + entropy[..., None])**2, axis=axis)
-  return entropy, varentropy
-
-
-def multinomial_sample_one(probs_sort: jax.Array, key) -> jax.Array:
-  """Samples one token from a multinomial distribution with sorted probabilities."""
-  q = jax.random.exponential(key=key, shape=probs_sort.shape)
-  return jnp.argmax(probs_sort / q, axis=-1, keepdims=True).astype(jnp.int32)
-
-
-def _sample(logits: jax.Array, temperature=0.666, top_p=0.90, top_k=27, key=jax.random.PRNGKey(1337)) -> jax.Array:
-  bsz = logits.shape[0]
-  logit = logits[:, -1]
-  probs = jax.nn.softmax(logit / temperature, axis=-1)
-
-  # Apply top-k sampling
-  top_k_probs, top_k_indices = jax.lax.top_k(probs, k=top_k)
-  probs_sort_jax = jnp.flip(top_k_probs, axis=-1)
-  probs_idx_jax = jnp.flip(top_k_indices, axis=-1)
-  probs_sum_jax = jnp.cumsum(probs_sort_jax, axis=-1)
-
-  # Apply top-p sampling
-  mask_jax = jnp.where(probs_sum_jax - probs_sort_jax > top_p, True, False)  # Use jnp.where
-  probs_sort_jax = probs_sort_jax * (1 - mask_jax)  # Set values to 0.0 using multiplication
-  probs_sort_jax = probs_sort_jax / jnp.sum(probs_sort_jax, axis=-1, keepdims=True)
-
-  next_token_jax = multinomial_sample_one(probs_sort_jax, key)
-  next_token_g_jax = jnp.take_along_axis(probs_idx_jax, next_token_jax.reshape(bsz, 1), axis=-1)
-  return next_token_g_jax.astype(jnp.int32)
-
-
-def sample(gen_tokens: jax.Array, logits: jax.Array, temperature=0.666, top_p=0.90, top_k=27, key=jax.random.PRNGKey(1337)) -> jax.Array:
-  ent, vent = calculate_varentropy_logsoftmax(logits)
-
-  # Low Entropy, Low Varentropy: "flowing with unspoken intent"
-  if ent < 0.1 and vent < 0.1:
-    return jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)
-
-  # High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
-  elif ent > 5.0 and vent < 0.1:
-    # Insert a clarifying question token if not already present
-    if not jnp.isin(gen_tokens[:, -1], 2564).any():
-      return jnp.array([[2564]])  # Assuming 2564 is our "ask clarifying question" token
-    else:
-      # If we've just asked a question, sample with slightly higher temperature
-      return _sample(logits, temperature=min(1.3, temperature * 1.5))
-
-  # Low Entropy, High Varentropy: "exploring forks in the path"
-  elif ent < 5.0 and vent > 5.0:
-    # TODO(xjdr): Implement proper branching logic
-    # Return top-k tokens to allow for branching
-    #top_k_values, top_k_indices = jax.lax.top_k(logits[:, -1], k=top_k)
-    #return top_k_indices
-    return _sample(logits, temperature=min(1.2, temperature * 1.5))
-
-  # High Entropy, High Varentropy: "resampling in the mist"
-  elif ent > 5.0 and vent > 5.0:
-    # Use high temperature and min_p sampling
-    return _sample(logits, temperature=max(2.0, temperature * 3))
-
-  # Middle ground: smooth transition
-  else:
-    # Interpolate temperature based on entropy and varentropy
-    t = jnp.clip((ent + vent) / 10.0, 0.5, 2.0)
-    return _sample(logits, temperature=t * temperature)
-
-
 def main():
   model_params = LLAMA_1B_PARAMS
   xfmr_weights = load_weights()
-  #xfmr_weights = load_weights(ckpt_dir=Path('weights/1B-Base'))
 
   tokenizer = Tokenizer('entropix/tokenizer.model')
   raw_tokens1 = tokenizer.encode(prompt, bos=False, eos=False, allowed_special='all')
-  raw_tokens2 = tokenizer.encode(prompt2, bos=False, eos=False, allowed_special='all')
-  raw_tokens3 = tokenizer.encode(prompt3, bos=False, eos=False, allowed_special='all')
-  raw_tokens4 = tokenizer.encode(prompt4, bos=False, eos=False, allowed_special='all')
-
   base_raw_tokens1 = tokenizer.encode(bp1, bos=True, eos=False, allowed_special='all')
-  base_raw_tokens2 = tokenizer.encode(bp2, bos=True, eos=False, allowed_special='all')
-  base_raw_tokens3 = tokenizer.encode(bp3, bos=True, eos=False, allowed_special='all')
-  base_raw_tokens4 = tokenizer.encode(bp4, bos=True, eos=False, allowed_special='all')
-
 
+  # Create the batch of tokens
   def generate(xfmr_weights, model_params, tokens):
     gen_tokens = None
     cur_pos = 0
@@ -275,47 +80,23 @@ def generate(xfmr_weights, model_params, tokens):
     attn_mask = build_attn_mask(seqlen, cur_pos)
     freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope)
     kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim)
-    logits, kvcache = xfmr(xfmr_weights, model_params, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask)
+    logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask)
     next_token = jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)
     gen_tokens = next_token
     print(tokenizer.decode([next_token.item()]), end='', flush=True)
     cur_pos = seqlen
     stop = jnp.array([128001, 128008, 128009])
     #stop = jnp.array(tokenizer.stop_tokens)
-    while cur_pos < 2048:
+    while cur_pos < 8192:
       cur_pos += 1
-      logits, kvcache = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache)
-      next_token = sample(gen_tokens, logits)
+      logits, kvcache, scores, stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache)
+      next_token = sample(gen_tokens, logits, scores)
       gen_tokens = jnp.concatenate((gen_tokens, next_token))
       print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
       if jnp.isin(next_token, stop).any():
        break
 
-  print(prompt)
   generate(xfmr_weights, model_params, raw_tokens1)
-  print('\n')
-  print(prompt2)
-  generate(xfmr_weights, model_params, raw_tokens2)
-  print('\n')
-  print(prompt3)
-  generate(xfmr_weights, model_params, raw_tokens3)
-  print('\n')
-  print(prompt4)
-  generate(xfmr_weights, model_params, raw_tokens4)
-  print('\n')
-
-  #print(bp1)
-  #generate(xfmr_weights, model_params, base_raw_tokens1)
-  #print('\n')
-  #print(bp2)
-  #generate(xfmr_weights, model_params, base_raw_tokens2)
-  #print('\n')
-  #print(bp3)
-  #generate(xfmr_weights, model_params, base_raw_tokens3)
-  #print('\n')
-  #print(bp4)
-  #generate(xfmr_weights, model_params, base_raw_tokens4)
-  #print('\n')
 
 if __name__ == '__main__':
   tyro.cli(main)
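
For reference, the `calculate_varentropy_logsoftmax` helper removed above (and, judging by the new `from entropix.sampler import sample` import, presumably relocated into `entropix/sampler.py`, which this excerpt does not show) computes the Shannon entropy and the varentropy, i.e. the variance of the surprisal, of the softmax distribution, in bits:

```latex
H(p) = -\sum_i p_i \log_2 p_i,
\qquad
\mathrm{Varent}(p) = \sum_i p_i \left(\log_2 p_i + H(p)\right)^2
                   = \mathbb{E}\left[\left(-\log_2 p_X - H(p)\right)^2\right],
\quad X \sim p
```

Low entropy means the distribution is peaked; low varentropy means every plausible token is about equally surprising. Together they separate "confidently peaked" from "uniformly uncertain" from "split between a few distinct options", which is what the quadrant logic in `sample` keys off. The removed `multinomial_sample_one` relies on the standard exponential-race trick: dividing the probabilities by i.i.d. Exp(1) draws and taking the argmax yields an exact sample from the categorical distribution.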

entropix/model.py (+12 -5)
@@ -7,6 +7,7 @@
 
 from entropix.config import ModelParams
 from entropix.kvcache import KVCache
+from entropix.stats import AttnStats
 from entropix.weights import XfmrWeights, LayerWeights
 
 
@@ -43,8 +44,8 @@ def attention(x: jax.Array, layer_weights: LayerWeights, model_params, cur_pos:
   keys = jnp.transpose(keys, (0, 2, 3, 1))  # (bs, n_heads, head_dim, cache_len + seqlen)
   values = jnp.transpose(values, (0, 2, 1, 3))  # (bs, n_heads, cache_len + seqlen, head_dim)
   scores = jnp.matmul(xq, keys)
-  scores = scores / jnp.sqrt(model_params.head_dim)
-  scores = scores.astype(jnp.float32)  # Always do attention softmax at float32
+  pre_scores = scores / jnp.sqrt(model_params.head_dim)
+  scores = pre_scores.astype(jnp.float32)  # Always do attention softmax at float32
   if cur_pos == 0:
     scores = scores + attn_mask
   mask = jnp.where(scores != 0.0, scores, DEFAULT_MASK_VALUE)
@@ -53,7 +54,7 @@ def attention(x: jax.Array, layer_weights: LayerWeights, model_params, cur_pos:
   output = jnp.matmul(scores, values)
   output = jnp.swapaxes(output, 1, 2).reshape(xq.shape[0], xq.shape[2], -1)
   out = jnp.dot(output, layer_weights.wo.T)
-  return out, kvcache
+  return out, kvcache, pre_scores
 
 #@partial(jax.jit)
 def feed_forward(x: jax.Array, layer_weights: LayerWeights) -> jax.Array:
@@ -62,10 +63,16 @@ def feed_forward(x: jax.Array, layer_weights: LayerWeights) -> jax.Array:
 #@partial(jax.jit, static_argnames=("model_params", "cur_pos"))
 def xfmr(xfmr_weights: XfmrWeights, model_params: ModelParams, tokens: jax.Array, cur_pos: int, freqs_cis: jax.Array, kvcache: KVCache, attn_mask: Optional[jax.Array]=None) -> Tuple[jax.Array, KVCache]:
   h = xfmr_weights.tok_embeddings[tokens]
+  attn_stats = AttnStats.new(
+    bsz=tokens.shape[0],
+    n_layers=model_params.n_layers,
+    n_heads=model_params.n_local_heads
+  )
   for i in range(model_params.n_layers):
     norm_x = rms_norm(h, xfmr_weights.layer_weights[i].attention_norm)
-    h_attn, kvcache = attention(norm_x, xfmr_weights.layer_weights[i], model_params, cur_pos, i, freqs_cis, kvcache, attn_mask=attn_mask)
+    h_attn, kvcache, scores = attention(norm_x, xfmr_weights.layer_weights[i], model_params, cur_pos, i, freqs_cis, kvcache, attn_mask=attn_mask)
+    attn_stats = attn_stats.update(scores[:,:,-1,:], i)
     h = h + h_attn
     h = h + feed_forward(rms_norm(h, xfmr_weights.layer_weights[i].ffn_norm), xfmr_weights.layer_weights[i])
   logits = jnp.dot(rms_norm(h, xfmr_weights.norm), xfmr_weights.output.T)
-  return logits, kvcache
+  return logits, kvcache, scores, attn_stats
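
`entropix/stats.py` is among the ten changed files but is not part of this excerpt. Going only by the two call sites above (`AttnStats.new(bsz=..., n_layers=..., n_heads=...)` and `attn_stats.update(scores[:, :, -1, :], i)`), a minimal, hypothetical sketch of such a container could look like the following; the real implementation may track different statistics:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class AttnStats(NamedTuple):
  """Hypothetical per-layer, per-head attention statistics (sketch only)."""
  entropy: jax.Array     # (bsz, n_layers, n_heads)
  varentropy: jax.Array  # (bsz, n_layers, n_heads)

  @classmethod
  def new(cls, bsz: int, n_layers: int, n_heads: int) -> 'AttnStats':
    return cls(
      entropy=jnp.zeros((bsz, n_layers, n_heads)),
      varentropy=jnp.zeros((bsz, n_layers, n_heads)),
    )

  def update(self, scores: jax.Array, layer_idx: int) -> 'AttnStats':
    # scores: pre-softmax attention logits for the newest query position,
    # shape (bsz, n_heads, cache_len + seqlen).
    probs = jax.nn.softmax(scores, axis=-1)
    log_probs = jnp.log(jnp.clip(probs, 1e-10, 1.0))
    ent = -jnp.sum(probs * log_probs, axis=-1)  # per-head entropy, (bsz, n_heads)
    vent = jnp.sum(probs * (log_probs + ent[..., None]) ** 2, axis=-1)
    return self._replace(
      entropy=self.entropy.at[:, layer_idx, :].set(ent),
      varentropy=self.varentropy.at[:, layer_idx, :].set(vent),
    )
```

Note that `attention` now returns `pre_scores`, the scaled attention logits before the float32 cast and masking, so per-head statistics can be computed on each layer's attention distribution for the latest token.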
