Skip to content

Commit

Permalink
Fix more select issues
Browse files (browse the repository at this point in the history)
  • Loading branch information
slundberg committed May 21, 2023
1 parent eb350be commit 0428ca3
Showing 1 changed file with 12 additions and 22 deletions.
34 changes: 12 additions & 22 deletions guidance/library/_select.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -65,8 +65,8 @@ async def recursive_select(current_prefix, allow_token_extension=True):
match_index = 0
for i in range(len(current_prefix), min([len(o[0]) for o in extension_options])):
if len(set([o[0][i] for o in extension_options])) > 1:
match_index = i
break
match_index += 1
if match_index > len(current_prefix):
current_prefix += extension_options[0][0][len(current_prefix):match_index]
# extension_options = [(option[i:], index) for option,index in extension_options]
Expand All @@ -81,14 +81,13 @@ async def recursive_select(current_prefix, allow_token_extension=True):
option_tokens = parser.program.llm.encode(parser_prefix + option)

# if we extended the last token to a longer one
if len(tmp_prefix_tokens) > 0 and option_tokens[len(tmp_prefix_tokens)-1] != tmp_prefix_tokens[-1]:
if len(tmp_prefix_tokens) > 0 and option_tokens[len(tmp_prefix_tokens)-1] != tmp_prefix_tokens[-1] or len(option_tokens) == len(tmp_prefix_tokens):
if allow_token_extension: # this is a valid extension only if we are not allowed to extend the token
logit_bias1[option_tokens[len(tmp_prefix_tokens)-1]] = 100

# if we did not extend the last token to a longer one we can bias the next token
else:
if len(option_tokens) > len(tmp_prefix_tokens):
logit_bias2[option_tokens[len(tmp_prefix_tokens)]] = 100
logit_bias2[option_tokens[len(tmp_prefix_tokens)]] = 100

# logit_bias[option_tokens[len(tmp_prefix_tokens)-1]] = tmp_prefix_tokens[-1]
# if len(option_tokens) > len(tmp_prefix_tokens) and :
Expand All @@ -106,12 +105,17 @@ async def recursive_select(current_prefix, allow_token_extension=True):
logit_bias = logit_bias2
last_token_str = ""

# check for where we are at the end of the prefix
if len(logit_bias) == 0 and current_prefix in [o[0] for o in extension_options]:
logprobs_out[current_prefix] = 0
return logprobs_out

# generate the token logprobs
gen_obj = await parser.llm_session(
call_prefix, # TODO: perhaps we should allow passing of token ids directly? (this could allow us to avoid retokenizing the whole prefix many times)
max_tokens=1,
logit_bias=logit_bias,
logprobs=10,
logprobs=len(logit_bias),
cache_seed=0,
token_healing=False # we manage token boundary healing ourselves for this function
)
Expand All @@ -127,7 +131,9 @@ async def recursive_select(current_prefix, allow_token_extension=True):
for token_str,logprob in top_logprobs.items():

# build our recursive call prefix
rec_prefix = token_str[len(last_token_str):]
if not token_str.startswith(last_token_str):
continue
rec_prefix = current_prefix + token_str[len(last_token_str):]

# if we did not extend the last token then we recurse while ignoring the possibility of extending the last token
if token_str == last_token_str:
Expand All @@ -144,22 +150,6 @@ async def recursive_select(current_prefix, allow_token_extension=True):
or_prob = p1 + p2 - p1*p2
logprobs_out[k] = np.log(or_prob)

# if we did token healing and did not extend past our prefix we need to consider the next token
# TODO: when returning all logprobs we need to consider all the options, which means we should
# force the model to not token heal and see what would have happened then on the next token...
# first_token_str = max(top_logprobs, key=top_logprobs.get)
# if len(logprobs_result["top_logprobs"]) > 1 and len(first_token_str) == remove_prefix:
# top_logprobs = logprobs_result["top_logprobs"][1]
# for token_str,logprob in top_logprobs.items():
# sub_logprobs = await recursive_select(current_prefix + token_str)
# for k in sub_logprobs:

# # compute the probability of a logical OR between the new extension and the previous possible ones
# p1 = np.exp(logprobs_out[k])
# p2 = np.exp(sub_logprobs[k] + logprob)
# or_prob = p1 + p2 - p1*p2
# logprobs_out[k] = np.log(or_prob)

return logprobs_out

# recursively compute the logprobs for each option
Expand Down

0 comments on commit 0428ca3

Please sign in to comment.