GitHub user iyerr3 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/225#discussion_r162369645
--- Diff: src/ports/postgres/modules/knn/knn.py_in ---
@@ -167,22 +169,31 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
knn_neighbors = ""
label_out = ""
cast_to_int = ""
+ k_neighbours = ""
+ k_neighbours_unnest = ""
if output_neighbors:
knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY "
"knn_temp.dist ASC) AS k_nearest_neighbours ")
+ k_neighbours = ", array_agg(distinct k_neighbours) AS k_nearest_neighbours"
+ k_neighbours_unnest = ", unnest(k_nearest_neighbours) as k_neighbours"
if label_column_name:
is_classification = False
label_column_type = get_expr_type(
label_column_name, point_source).lower()
if label_column_type in ['boolean', 'integer', 'text']:
is_classification = True
cast_to_int = '::INTEGER'
-
- pred_out = ", avg({label_col_temp})".format(**locals())
+ if weighted_avg:
+ pred_out = ",sum( {label_col_temp} * 1/dist)/sum(1/dist)".format(**locals())
--- End diff --
We should avoid `**locals()` when the list of format arguments is this short. Pass them explicitly instead, e.g. `.format(label_col_temp=label_col_temp)`.
---
|