Fix logits_to_logprobs for 2-D and 3-D logits #1002

Merged
3 commits merged into abetlen:main on Dec 16, 2023
Conversation

kddubey (Contributor) commented on Dec 12, 2023

The implementation in main (from this PR) only works for 1-D logits. It silently fails for 2-D or 3-D logits. The implementation in this PR works out-of-the-box for 1-D, 2-D, and 3-D logits. (3-D is possible in the future w/ batch inference and logits_all=True.) This feature might be useful b/c there are some places in the code where we can save time by vectorizing / not converting data to lists. I'll do that in a future PR.

The minimal and sufficient fix is to set axis=-1 in the np.max call, and set keepdims=True in the np.sum call. I decided to instead go with a more robust implementation. It's almost copy-pasted from scipy.special.log_softmax. I decided against adding scipy as a required dependency b/c it's not lightweight—the latest version is ~37 MB.
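For reference, here is a minimal sketch of a numerically stable log-softmax along the last axis, assuming only numpy. The merged code follows scipy.special.log_softmax more closely (e.g., it also guards against -inf entries), so treat this as illustrative rather than the exact implementation:

    import numpy as np

    def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
        # Shift by the per-row max so np.exp cannot overflow, then subtract
        # the log of the summed exponentials along the chosen axis.
        x = np.asarray(x)
        x_max = np.max(x, axis=axis, keepdims=True)
        shifted = x - x_max
        log_sum_exp = np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))
        return shifted - log_sum_exp

Because axis=-1 and keepdims=True broadcast correctly, the same function handles 1-D, 2-D, and 3-D logits.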

How has this been tested?

Script

  1. Install the new test dependency, scipy, which contains a correct implementation

    python -m pip install scipy
  2. Check out main

    git checkout main
  3. Run this script in main to verify that the current implementation is silently wrong for 2-D logits

    from __future__ import annotations
    
    import numpy as np
    from scipy.special import log_softmax
    
    from llama_cpp import Llama
    
    atol = 1e-3  # intentionally set to be loose when testing the impl in main
    size = (2, 3)
    logits: list = (
        (-np.random.uniform(low=0, high=60, size=size)).astype(np.single).tolist()
    )
    
    logprobs = Llama.logits_to_logprobs(logits)
    logprobs_correct = log_softmax(logits, axis=-1)
    assert np.allclose(logprobs, logprobs_correct, atol=atol)
  4. Check out this branch

    git checkout kddubey/fix-logits-to-logprobs
  5. Run the same script with atol=1e-6. No error should be raised.

New unit tests

pytest tests/test_llama.py -k test_logits_to_logprobs
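The idea behind the new test is the same comparison against scipy, across shapes. Below is a sketch of that kind of parametrized check, reusing the names from the script above; the actual test in tests/test_llama.py may differ:

    import numpy as np
    import pytest
    from scipy.special import log_softmax

    from llama_cpp import Llama

    @pytest.mark.parametrize("size", [(32,), (2, 32), (2, 3, 32)])
    def test_logits_to_logprobs(size):
        # Random negative float32 logits, passed as nested lists like the callers do
        logits = (-np.random.uniform(low=0, high=60, size=size)).astype(np.single).tolist()
        logprobs = Llama.logits_to_logprobs(logits)
        assert np.allclose(logprobs, log_softmax(np.asarray(logits), axis=-1), atol=1e-6)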

kddubey (Contributor, Author) commented on the scipy>=1.10 pin added to the test dependencies in pyproject.toml:

    test = [
        "pytest>=7.4.0",
        "httpx>=0.24.1",
        "scipy>=1.10",
    ]

This is the oldest version compatible with numpy>=1.20.0

source: https://docs.scipy.org/doc/scipy/dev/toolchain.html#numpy

kddubey changed the title from "Fix logits_to_logprobs" to "Fix logits_to_logprobs for 2-D and 3-D logits" on Dec 12, 2023
abetlen (Owner) commented on Dec 16, 2023

@kddubey thank you, yes that's a good idea wrt vectorizing the logits -> logprobs calculation

abetlen merged commit 5a89446 into abetlen:main on Dec 16, 2023
kddubey deleted the kddubey/fix-logits-to-logprobs branch on Dec 17, 2023