Skip to content

Commit 0fcb92f

Browse files
committed
add java
1 parent 2bc22a2 commit 0fcb92f

File tree

9 files changed

+283
-7
lines changed

9 files changed

+283
-7
lines changed

BUILD.bazel

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,20 @@ load(":version.bzl", "VERSION", "write_version")
33

44
# gazelle:prefix github.com/ajwerner/tdigest
55
gazelle(name = "gazelle")
6+
7+
cc_library(
8+
name = "jni_headers",
9+
srcs = [
10+
"@local_jdk//:jni_header",
11+
"@local_jdk//:jni_md_header-linux",
12+
],
13+
includes = [
14+
"external/local_jdk/include",
15+
"external/local_jdk/include/linux",
16+
],
17+
linkstatic = 1,
18+
visibility = [
19+
"//visibility:public",
20+
],
21+
)
22+

c/include/tdigest.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ td_histogram_t *td_new(double compression);
4040
// td_free may only be called if the histogram was created with td_new.
4141
void td_free(td_histogram_t *h);
4242

43-
void td_clean(td_histogram_t *h);
44-
4543
void td_add(td_histogram_t *h, double mean, double count);
4644

4745
void td_merge(td_histogram_t *into, td_histogram_t *from);

c/src/tdigest.c

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <stdlib.h>
22
#include <stdbool.h>
3+
#include <stdio.h>
34
#include <string.h>
45
#include <math.h>
56

@@ -99,14 +100,59 @@ double td_total_count(td_histogram_t *h) {
99100
return h->merged_count + h->unmerged_count;
100101
}
101102

103+
double td_quantile_of(td_histogram_t *h, double val) {
104+
merge(h);
105+
if (h->merged_nodes == 0) {
106+
return NAN;
107+
}
108+
/* if (h->merged_nodes == 1) { */
109+
/* if (h->nodes[0].mean > val) { */
110+
/* return 1; */
111+
/* } else if (h->nodes[0].mean < val) { */
112+
/* return 0; */
113+
/* } */
114+
/* return 0.5; */
115+
/* } */
116+
double k = 0;
117+
int i = 0;
118+
node_t *n = NULL;
119+
for (i = 0; i < h->merged_nodes; i++) {
120+
n = &h->nodes[i];
121+
if (n->mean >= val) {
122+
break;
123+
}
124+
k += n->count;
125+
}
126+
if (val == n->mean) {
127+
// technically this needs to find all of the nodes which contain this value and sum their weight
128+
double count_at_value = n->count;
129+
for (i += 1; i < h->merged_nodes && h->nodes[i].mean == n->mean; i++) {
130+
count_at_value += h->nodes[i].count;
131+
}
132+
return (k + (count_at_value/2)) / h->merged_count;
133+
} else if (val > n->mean) { // past the largest
134+
return 1;
135+
} else if (i == 0) {
136+
return 0;
137+
}
138+
// we want to figure out where along the line from the prev node to this node, the value falls
139+
node_t *nr = n;
140+
node_t *nl = n-1;
141+
k -= (nl->count/2);
142+
// we say that at zero we're at nl->mean
143+
// and at (nl->count/2 + nr->count/2) we're at nr
144+
double m = (nr->mean - nl->mean) / (nl->count/2 + nr->count/2);
145+
double x = (val - nl->mean) / m;
146+
printf("hi %f %f %f %f\n", m, x, k, h->merged_count);
147+
return (k + x) / h->merged_count;
148+
}
149+
150+
102151
double td_value_at(td_histogram_t *h, double q) {
103152
merge(h);
104153
if (q < 0 || q > 1 || h->merged_nodes == 0) {
105154
return NAN;
106155
}
107-
if (h->merged_nodes == 1) {
108-
return h->nodes[0].mean;
109-
}
110156
// if left of the first node, use the first node
111157
// if right of the last node, use the last node, use it
112158
double goal = q * h->merged_count;
@@ -118,7 +164,7 @@ double td_value_at(td_histogram_t *h, double q) {
118164
if (k + n->count > goal) {
119165
break;
120166
}
121-
k += h->nodes[i].count;
167+
k += n->count;
122168
}
123169
double delta_k = goal - k - (n->count/2);
124170
if (is_very_small(delta_k)) {
@@ -143,7 +189,7 @@ double td_value_at(td_histogram_t *h, double q) {
143189
double x = goal - k;
144190
// we have two points (0, nl->mean), (nr->count, nr->mean)
145191
// and we want x
146-
double m = (nr->mean - nl->mean) / (nr->count);
192+
double m = (nr->mean - nl->mean) / (nl->count/2 + nr->count/2);
147193
return m * x + nl->mean;
148194
}
149195

c/test/tdigest_test.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,24 @@ static char *test_two_interp() {
7272
return NULL;
7373
}
7474

75+
static char *test_quantile_of() {
76+
td_histogram_t *t = td_new(1000);
77+
td_add(t, 1, 1);
78+
td_add(t, 10, 1);
79+
mu_assert("test_quantile_of: .99", td_quantile_of(t, .99) == 0);
80+
mu_assert("test_quantile_of: 1", td_quantile_of(t, 1) == .25);
81+
mu_assert("test_quantile_of: 5.5", td_quantile_of(t, 5.5) == .5);
82+
83+
td_free(t);
84+
return NULL;
85+
}
86+
7587
static char *all_tests() {
7688
mu_run_test(test_basic);
7789
mu_run_test(test_uniform_rand);
7890
mu_run_test(test_nans);
7991
mu_run_test(test_two_interp);
92+
mu_run_test(test_quantile_of);
8093
return NULL;
8194
}
8295

java/BUILD.bazel

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
load(":extract_header.bzl", "extract_native_header_jar")
2+
3+
java_library(
4+
name = "TDigest-java",
5+
srcs = [
6+
"src/main/java/com/ajwerner/tdigestc/TDigest.java",
7+
],
8+
)
9+
10+
extract_native_header_jar(
11+
name = "TDigest-header",
12+
outs = ["com_ajwerner_tdigestc_TDigest.h"],
13+
lib = ":TDigest-java",
14+
)
15+
16+
cc_binary(
17+
name = "TDigest.so",
18+
srcs = [
19+
"src/main/c/TDigest.c",
20+
"com_ajwerner_tdigestc_TDigest.h",
21+
],
22+
deps = [
23+
"//c:tdigest",
24+
"//:jni_headers",
25+
],
26+
includes = [ "." ],
27+
linkshared = 1,
28+
)
29+
30+
java_library(
31+
name = "TDigest",
32+
data = [ ":TDigest.so" ],
33+
srcs = [
34+
"src/main/java/com/ajwerner/tdigestc/TDigest.java",
35+
],
36+
resources = [
37+
":TDigest.so",
38+
],
39+
exports = [
40+
":TDigest.so",
41+
],
42+
)
43+
44+
java_test(
45+
name = "TDigestTest",
46+
srcs = glob(["src/test/java/**/*.java"]),
47+
deps = [
48+
":TDigest",
49+
],
50+
)

java/extract_header.bzl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
2+
def _impl(ctx):
3+
# https://github.com/bazelbuild/rules_scala/pull/286/files
4+
print(ctx.outputs.outs[0].dirname)
5+
ctx.actions.run(
6+
inputs = [ctx.attr.lib.java.outputs.native_headers],
7+
tools = [ctx.executable._zipper],
8+
outputs = ctx.outputs.outs,
9+
executable = ctx.executable._zipper.path,
10+
arguments = ["vxf", ctx.attr.lib.java.outputs.native_headers.path, "-d", ctx.outputs.outs[0].dirname],
11+
)
12+
13+
extract_native_header_jar = rule(
14+
implementation=_impl,
15+
attrs={
16+
"lib": attr.label(mandatory=True, single_file=True),
17+
"outs": attr.output_list(),
18+
# https://github.com/bazelbuild/bazel/issues/2414
19+
"_zipper": attr.label(executable=True, cfg="host", default=Label("@bazel_tools//tools/zip:zipper"), allow_files=True)
20+
},
21+
output_to_genfiles = True,
22+
)

java/src/main/c/TDigest.c

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
2+
#include <jni.h>
3+
#include "tdigest.h"
4+
#include "com_ajwerner_tdigestc_TDigest.h"
5+
6+
/*
7+
* Class: com_ajwerner_tdigestc_TDigest
8+
* Method: td_new
9+
* Signature: (I)J
10+
*/
11+
JNIEXPORT jlong JNICALL Java_com_ajwerner_tdigestc_TDigest_td_1new
12+
(JNIEnv *env, jobject this, jint size) {
13+
return (jlong)(td_new((double)(size)));
14+
}
15+
16+
/*
17+
* Class: com_ajwerner_tdigestc_TDigest
18+
* Method: td_add
19+
* Signature: (JDD)V
20+
*/
21+
JNIEXPORT void JNICALL Java_com_ajwerner_tdigestc_TDigest_td_1add
22+
(JNIEnv *env, jobject this, jlong ptr, jdouble val, jdouble count) {
23+
td_add((td_histogram_t *)(ptr), (double)(val), (double)(count));
24+
}
25+
26+
/*
27+
* Class: com_ajwerner_tdigestc_TDigest
28+
* Method: td_value_at
29+
* Signature: (JD)D
30+
*/
31+
JNIEXPORT jdouble JNICALL Java_com_ajwerner_tdigestc_TDigest_td_1value_1at
32+
(JNIEnv *env, jobject this, jlong ptr, jdouble q) {
33+
34+
return (jdouble)(td_value_at((td_histogram_t *)(ptr), (double)(q)));
35+
}
36+
37+
/*
38+
* Class: com_ajwerner_tdigestc_TDigest
39+
* Method: td_free
40+
* Signature: (J)V
41+
*/
42+
JNIEXPORT void JNICALL Java_com_ajwerner_tdigestc_TDigest_td_1free
43+
(JNIEnv *env, jobject this, jlong ptr) {
44+
td_free((td_histogram_t *)(ptr));
45+
}
46+
47+
JNIEXPORT jdouble JNICALL Java_com_ajwerner_tdigestc_TDigest_td_1quantile_1of
48+
(JNIEnv *env, jobject this, jlong ptr, jdouble val) {
49+
return (jdouble)(td_quantile_of((td_histogram_t *)(ptr), (double)(val)));
50+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package com.ajwerner.tdigestc;
2+
3+
import java.io.InputStream;
4+
import java.io.OutputStream;
5+
import java.io.FileOutputStream;
6+
import java.io.File;
7+
import java.io.IOException;
8+
import java.nio.file.Files;
9+
10+
public class TDigest {
11+
private long ptr;
12+
13+
public TDigest(int size) {
14+
this.ptr = td_new(size);
15+
}
16+
17+
public void add(double val) {
18+
this.add(val, 1);
19+
}
20+
21+
public void add(double val, double count) {
22+
td_add(this.ptr, val, count);
23+
}
24+
25+
public double valueAt(double q) {
26+
return td_value_at(this.ptr, q);
27+
}
28+
29+
public double quantileOf(double val) {
30+
return td_quantile_of(this.ptr, val);
31+
}
32+
33+
@Override
34+
public void finalize() {
35+
td_free(this.ptr);
36+
}
37+
38+
private native long td_new(int size);
39+
private native void td_add(long ptr, double val, double count);
40+
private native double td_value_at(long ptr, double q);
41+
private native void td_free(long ptr);
42+
private native double td_quantile_of(long ptr, double val);
43+
44+
static {
45+
InputStream is = TDigest.class.getResourceAsStream("/TDigest.so");
46+
try {
47+
File file = File.createTempFile("lib", ".so");
48+
OutputStream os = new FileOutputStream(file);
49+
byte[] buffer = new byte[1024];
50+
int length;
51+
while ((length = is.read(buffer)) != -1) {
52+
os.write(buffer, 0, length);
53+
}
54+
is.close();
55+
os.close();
56+
System.load(file.getAbsolutePath());
57+
file.deleteOnExit();
58+
} catch (IOException e) {
59+
System.out.println(e);
60+
}
61+
62+
}
63+
64+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.ajwerner.tdigestc;
2+
import static org.junit.Assert.*;
3+
import org.junit.Test;
4+
import com.ajwerner.tdigestc.TDigest;
5+
6+
public class TDigestTest {
7+
@Test
8+
public void basicallyWorks() {
9+
TDigest td = new TDigest(100);
10+
td.add(1);
11+
td.add(2);
12+
assertEquals(td.valueAt(.5), 1.5, 0);
13+
assertEquals(td.quantileOf(1.5), .5, 0);
14+
}
15+
}
16+

0 commit comments

Comments
 (0)