Skip to content

Commit

Permalink
Merge branch 'mainline' into mmap-mlmodel
Browse files Browse the repository at this point in the history
  • Loading branch information
weiliw-amz authored Feb 8, 2024
2 parents 8f397c0 + 1a18d2c commit a0ce0e9
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
run: |
if [[ "${{ github.event_name }}" == "push" ]] && \
[[ "${{ github.event.ref }}" =~ ^refs/tags/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo ::set-output name=match::true
echo "match=true" >> $GITHUB_OUTPUT
fi
- name: Upload to PyPI
Expand Down
35 changes: 31 additions & 4 deletions pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2309,23 +2309,46 @@ def link_calibrator_methods(self):
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f32,
c_uint32,
[c_uint64, POINTER(c_float), POINTER(c_float), POINTER(c_double)],
[
c_uint64,
POINTER(c_float),
POINTER(c_float),
POINTER(c_double),
c_uint64, # max_iter
c_double, # eps
],
)
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f64,
c_uint32,
[c_uint64, POINTER(c_double), POINTER(c_double), POINTER(c_double)],
[
c_uint64,
POINTER(c_double),
POINTER(c_double),
POINTER(c_double),
c_uint64, # max_iter
c_double, # eps
],
)

def fit_platt_transform(self, logits, targets, clip_tgt_prob=True):
def fit_platt_transform(
self,
logits,
targets,
max_iter=100,
eps=1e-5,
clip_tgt_prob=True,
):
"""Python to C/C++ interface for platt transfrom fit.
Ref: https://www.csie.ntu.edu.tw/~cjlin/papers/plattprob.pdf
Args:
logits (ndarray): 1-d array of logit with length N.
targets (ndarray): 1-d array of target probability scores within [0, 1] with length N.
clip_tgt_prob (bool): whether to clip the target probability to
max_iter (int, optional): max number of iterations to train. Default 100
eps (float, optional): epsilon. Defaults to 1e-5
clip_tgt_prob (bool, optional): whether to clip the target probability to
[1/(prior0 + 2), 1 - 1/(prior1 + 2)]
where prior1 = sum(targets), prior0 = N - prior1
Returns:
Expand Down Expand Up @@ -2356,13 +2379,17 @@ def fit_platt_transform(self, logits, targets, clip_tgt_prob=True):
logits.ctypes.data_as(POINTER(c_float)),
tgt_prob.ctypes.data_as(POINTER(c_float)),
AB.ctypes.data_as(POINTER(c_double)),
max_iter,
eps,
)
elif tgt_prob.dtype == np.float64:
return_code = clib.clib_float32.c_fit_platt_transform_f64(
len(logits),
logits.ctypes.data_as(POINTER(c_double)),
tgt_prob.ctypes.data_as(POINTER(c_double)),
AB.ctypes.data_as(POINTER(c_double)),
max_iter,
eps,
)
else:
raise ValueError(f"Unsupported dtype: {tgt_prob.dtype}")
Expand Down
6 changes: 4 additions & 2 deletions pecos/core/libpecos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -842,9 +842,11 @@ extern "C" {
size_t num_samples, \
const VAL_TYPE* logits, \
const VAL_TYPE* tgt_probs, \
double* AB \
double* AB, \
size_t max_iter, \
double eps \
) { \
return pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \
return pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1], max_iter, eps); \
}
C_FIT_PLATT_TRANSFORM(_f32, float32_t)
C_FIT_PLATT_TRANSFORM(_f64, float64_t)
Expand Down
34 changes: 17 additions & 17 deletions pecos/core/third_party/nlohmann_json/json.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14698,7 +14698,7 @@ The invariants are checked by member function assert_invariant().
@endinternal

@see [RFC 7159: The JavaScript Object Notation (JSON) Data Interchange
Format](http://rfc7159.net/rfc7159)
Format](https://datatracker.ietf.org/doc/html/rfc7159)

@since version 1.0.0

Expand Down Expand Up @@ -14939,7 +14939,7 @@ class basic_json
/*!
@brief a type for an object

[RFC 7159](http://rfc7159.net/rfc7159) describes JSON objects as follows:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) describes JSON objects as follows:
> An object is an unordered collection of zero or more name/value pairs,
> where a name is a string and a value is a string, number, boolean, null,
> object, or array.
Expand Down Expand Up @@ -14993,7 +14993,7 @@ class basic_json

#### Limits

[RFC 7159](http://rfc7159.net/rfc7159) specifies:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) specifies:
> An implementation may set limits on the maximum depth of nesting.

In this class, the object's limit of nesting is not explicitly constrained.
Expand All @@ -15016,7 +15016,7 @@ class basic_json
name/value pairs in a different order than they were originally stored. In
fact, keys will be traversed in alphabetical order as `std::map` with
`std::less` is used by default. Please note this behavior conforms to [RFC
7159](http://rfc7159.net/rfc7159), because any order implements the
7159](https://datatracker.ietf.org/doc/html/rfc7159), because any order implements the
specified "unordered" nature of JSON objects.
*/
using object_t = ObjectType<StringType,
Expand All @@ -15028,7 +15028,7 @@ class basic_json
/*!
@brief a type for an array

[RFC 7159](http://rfc7159.net/rfc7159) describes JSON arrays as follows:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) describes JSON arrays as follows:
> An array is an ordered sequence of zero or more values.

To store objects in C++, a type is defined by the template parameters
Expand All @@ -15052,7 +15052,7 @@ class basic_json

#### Limits

[RFC 7159](http://rfc7159.net/rfc7159) specifies:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) specifies:
> An implementation may set limits on the maximum depth of nesting.

In this class, the array's limit of nesting is not explicitly constrained.
Expand All @@ -15074,7 +15074,7 @@ class basic_json
/*!
@brief a type for a string

[RFC 7159](http://rfc7159.net/rfc7159) describes JSON strings as follows:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) describes JSON strings as follows:
> A string is a sequence of zero or more Unicode characters.

To store objects in C++, a type is defined by the template parameter
Expand All @@ -15101,7 +15101,7 @@ class basic_json

#### String comparison

[RFC 7159](http://rfc7159.net/rfc7159) states:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) states:
> Software implementations are typically required to test names of object
> members for equality. Implementations that transform the textual
> representation into sequences of Unicode code units and then perform the
Expand All @@ -15127,7 +15127,7 @@ class basic_json
/*!
@brief a type for a boolean

[RFC 7159](http://rfc7159.net/rfc7159) implicitly describes a boolean as a
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) implicitly describes a boolean as a
type which differentiates the two literals `true` and `false`.

To store objects in C++, a type is defined by the template parameter @a
Expand All @@ -15153,7 +15153,7 @@ class basic_json
/*!
@brief a type for a number (integer)

[RFC 7159](http://rfc7159.net/rfc7159) describes numbers as follows:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) describes numbers as follows:
> The representation of numbers is similar to that used in most
> programming languages. A number is represented in base 10 using decimal
> digits. It contains an integer component that may be prefixed with an
Expand Down Expand Up @@ -15191,7 +15191,7 @@ class basic_json

#### Limits

[RFC 7159](http://rfc7159.net/rfc7159) specifies:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) specifies:
> An implementation may set limits on the range and precision of numbers.

When the default type is used, the maximal integer number that can be
Expand All @@ -15202,7 +15202,7 @@ class basic_json
will be automatically be stored as @ref number_unsigned_t or @ref
number_float_t.

[RFC 7159](http://rfc7159.net/rfc7159) further states:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) further states:
> Note that when such software is used, numbers that are integers and are
> in the range \f$[-2^{53}+1, 2^{53}-1]\f$ are interoperable in the sense
> that implementations will agree exactly on their numeric values.
Expand All @@ -15225,7 +15225,7 @@ class basic_json
/*!
@brief a type for a number (unsigned)

[RFC 7159](http://rfc7159.net/rfc7159) describes numbers as follows:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) describes numbers as follows:
> The representation of numbers is similar to that used in most
> programming languages. A number is represented in base 10 using decimal
> digits. It contains an integer component that may be prefixed with an
Expand Down Expand Up @@ -15263,7 +15263,7 @@ class basic_json

#### Limits

[RFC 7159](http://rfc7159.net/rfc7159) specifies:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) specifies:
> An implementation may set limits on the range and precision of numbers.

When the default type is used, the maximal integer number that can be
Expand All @@ -15273,7 +15273,7 @@ class basic_json
deserialization, too large or small integer numbers will be automatically
be stored as @ref number_integer_t or @ref number_float_t.

[RFC 7159](http://rfc7159.net/rfc7159) further states:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) further states:
> Note that when such software is used, numbers that are integers and are
> in the range \f$[-2^{53}+1, 2^{53}-1]\f$ are interoperable in the sense
> that implementations will agree exactly on their numeric values.
Expand All @@ -15296,7 +15296,7 @@ class basic_json
/*!
@brief a type for a number (floating-point)

[RFC 7159](http://rfc7159.net/rfc7159) describes numbers as follows:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) describes numbers as follows:
> The representation of numbers is similar to that used in most
> programming languages. A number is represented in base 10 using decimal
> digits. It contains an integer component that may be prefixed with an
Expand Down Expand Up @@ -15334,7 +15334,7 @@ class basic_json

#### Limits

[RFC 7159](http://rfc7159.net/rfc7159) states:
[RFC 7159](https://datatracker.ietf.org/doc/html/rfc7159) states:
> This specification allows implementations to set limits on the range and
> precision of numbers accepted. Since software that implements IEEE
> 754-2008 binary64 (double precision) numbers is generally available and
Expand Down
22 changes: 12 additions & 10 deletions pecos/core/utils/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ namespace pecos {
// https://github.com/cjlin1/libsvm/blob/master/svm.cpp

template <typename value_type>
uint32_t fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) {
uint32_t fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B, size_t max_iter, double eps) {
// define the return code
enum {
SUCCESS=0,
Expand All @@ -288,10 +288,8 @@ namespace pecos {
};

// hyper parameters
int max_iter = 100; // Maximal number of iterations
double min_step = 1e-10; // Minimal step taken in line search
double sigma = 1e-12; // For numerically strict PD of Hessian
double eps = 1e-5;

// calculate prior of B
double prior1 = 0;
Expand All @@ -300,7 +298,6 @@ namespace pecos {
}
double prior0 = double(num_samples) - prior1;


// Initial Point and Initial Fun Value
A = 0.0; B = log((prior0 + 1.0) / (prior1 + 1.0));
double fval = 0.0;
Expand All @@ -313,7 +310,7 @@ namespace pecos {
fval += (tgt_probs[i] - 1) * fApB + log(1 + exp(fApB));
}
}
int iter;
size_t iter = 0;
for (iter = 0; iter < max_iter; iter++) {
// Update Gradient and Hessian (use H' = H + sigma I)
double h11 = sigma;
Expand Down Expand Up @@ -342,16 +339,22 @@ namespace pecos {
g2 += d1;
}

// Stopping Criteria
if (fabs(g1) < eps && fabs(g2) < eps)
break;

// Finding Newton direction: -inv(H') * g
double det = h11 * h22 - h21 * h21;
double dA = -(h22 * g1 - h21 * g2) / det;
double dB = -(-h21 * g1 + h11 * g2) / det;
double gd = g1 * dA + g2 * dB;

// Stopping Criteria
if (fabs(g1) < eps && fabs(g2) < eps) {
break;
}
// additional stop criteria to handle the case when det is large
if (fabs(dA) < eps && fabs(dB) < eps) {
break;
}

// Line Search
double stepsize = 1.0;

Expand All @@ -370,8 +373,7 @@ namespace pecos {
}
}
// Check sufficient decrease
if (newf < fval + 0.0001 * stepsize * gd)
{
if (newf < fval + 0.0001 * stepsize * gd) {
A = newA;
B = newB;
fval = newf;
Expand Down

0 comments on commit a0ce0e9

Please sign in to comment.