-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.ml
More file actions
77 lines (72 loc) * 2.92 KB
/
test.ml
File metadata and controls
77 lines (72 loc) * 2.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
open Printf
module CLI = Minicli.CLI
module L = BatList
module Log = Dolog.Log
(* Element type handed to Cpm's ROC functor: a ground-truth label paired
   with the score the model predicted for that sample. *)
module Score_label = struct
  (* (label, pred_score) *)
  type t = bool * float

  (* Ground-truth label: first component of the pair. *)
  let get_label = fst

  (* Predicted score: second component of the pair. *)
  let get_score = snd
end
(* ROC analysis over (label, score) pairs described by [Score_label];
   [ROC.auc] is what [main] uses to score each model's predictions. *)
module ROC = Cpm.MakeROC.Make(Score_label)
module Svm = Orsvm_e1071.Svm
module Svmpath = Orsvm_e1071.Svmpath
module Utls = Orsvm_e1071.Utls
(* Parse the ground-truth label file: its first line holds tab-separated
   labels, "1" meaning [true] and "-1" meaning [false].
   @raise Failure on any other token. *)
let read_labels labels_fn =
  let labels_line = Utls.with_in_file labels_fn input_line in
  BatString.split_on_char '\t' labels_line
  |> L.map (function
      | "1" -> true
      | "-1" -> false
      | other -> failwith (sprintf "read_labels: unexpected label %S" other))

(* Return the (lambda, auc) pair with the highest AUC (the earliest one on
   ties), or [None] for an empty list. Seeding the fold with a real pair
   avoids reporting a lambda that was never evaluated; [compare] ranks nan
   below every other float, so a nan AUC never beats a real one. *)
let best_by_auc lambda_aucs =
  match lambda_aucs with
  | [] -> None
  | first :: rest ->
    let pick ((_, best_auc) as best) ((_, (auc : float)) as candidate) =
      if compare auc best_auc > 0 then candidate else best in
    Some (L.fold_left pick first rest)

(* Benchmark driver: trains RBF, linear and sparse-linear SVMs plus an
   svmpath model on the files under data/, then prints the AUC of each
   model's predictions on that same training data, and the svmpath lambda
   with the best AUC. [-np N] sets the [ncores] passed to Parmap when
   scoring the svmpath lambdas (default 1). *)
let main () =
  Log.set_log_level Log.DEBUG;
  Log.color_on ();
  let _argc, args = CLI.init () in
  let ncores = CLI.get_int_def ["-np"] args 1 in
  let data_fn = "data/train_data.txt" in
  let sparse_data_fn = "data/train_data.csr" in
  let labels_fn = "data/train_labels.txt" in
  (* expected dataset dimensions: samples (checked against every model's
     prediction count) and features (sparse dimension and RBF gamma) *)
  let nb_samples = 88 in
  let nb_features = 1831 in
  let cost = 1.0 in
  let rbf_preds =
    (* gamma = 1 / nb_features *)
    let rbf = Svm.RBF (1.0 /. float_of_int nb_features) in
    let rbf_model = Svm.train ~debug:true Dense ~cost rbf data_fn labels_fn in
    let rbf_preds_fn = Svm.predict ~debug:true Dense rbf_model data_fn in
    Svm.read_predictions rbf_preds_fn in
  let lin_preds =
    let lin_model =
      Svm.train ~debug:true Dense ~cost Svm.Linear data_fn labels_fn in
    let lin_preds_fn = Svm.predict ~debug:true Dense lin_model data_fn in
    Svm.read_predictions lin_preds_fn in
  let sparse_lin_preds =
    let sparse_lin_model =
      Svm.train ~debug:true (Sparse nb_features) ~cost Svm.Linear
        sparse_data_fn labels_fn in
    let sparse_lin_preds_fn =
      Svm.predict ~debug:true (Sparse nb_features) sparse_lin_model
        sparse_data_fn in
    Svm.read_predictions sparse_lin_preds_fn in
  (* every model must produce exactly one prediction per sample *)
  assert (List.length rbf_preds = nb_samples);
  assert (List.length lin_preds = nb_samples);
  assert (List.length sparse_lin_preds = nb_samples);
  let labels = read_labels labels_fn in
  (* List.combine raises Invalid_argument if labels and preds differ in
     length *)
  let auc_of preds = ROC.auc (List.combine labels preds) in
  printf "RBF AUC: %.3f\n" (auc_of rbf_preds);
  printf "Lin AUC: %.3f\n" (auc_of lin_preds);
  printf "sparse Lin AUC: %.3f\n" (auc_of sparse_lin_preds);
  let maybe_model = Svmpath.train ~debug:true data_fn labels_fn in
  let lambdas = Svmpath.read_lambdas ~debug:true maybe_model in
  (* score each lambda along the svmpath solution, one lambda per task *)
  let lambda_aucs =
    Parmap.parmap ~ncores ~chunksize:1 (fun lambda ->
        let svmpath_preds_fn =
          Svmpath.predict ~debug:false ~lambda maybe_model data_fn in
        (lambda, auc_of (Svmpath.read_predictions svmpath_preds_fn))
      ) (Parmap.L lambdas) in
  match best_by_auc lambda_aucs with
  | Some (best_lambda, best_auc) ->
    printf "svmpath best_lambda: %f best_AUC: %.3f\n" best_lambda best_auc
  | None -> failwith "svmpath: no lambda to evaluate"

let () = main ()