20 void Compute(tensorflow::OpKernelContext* context)
override {
21 using namespace tensorflow;
23 const Tensor& inp_tensor = context->input(0);
26 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
27 errors::InvalidArgument(
"ThreeNN expects "
28 "(batch_size,num_points,3) inp shape"));
29 int batch_size = inp_tensor.shape().dim_size(0);
30 int pts_num_out = inp_tensor.shape().dim_size(1);
31 auto inp_flat = inp_tensor.flat<
float>();
32 const float* inp = &(inp_flat(0));
34 const Tensor& data_tensor = context->input(1);
37 data_tensor.dims() == 3 && data_tensor.shape().dim_size(2) == 3,
38 errors::InvalidArgument(
40 "(batch_size,num_points,3) data shape"));
41 int pts_num_in = data_tensor.shape().dim_size(1);
42 auto data_flat = data_tensor.flat<
float>();
43 const float* data = &(data_flat(0));
48 context->allocate_output(
49 0, TensorShape{batch_size, pts_num_out, 3}, &out_dist));
50 auto out_flat0 = out_dist->flat<
float>();
51 float* out0 = &(out_flat0(0));
56 context->allocate_output(
57 1, TensorShape{batch_size, pts_num_out, 3}, &out_idx));
58 auto out_flat1 = out_idx->flat<
int>();
59 int* out1 = &(out_flat1(0));
61 Kernel(context, batch_size, pts_num_out, pts_num_in, inp, data, out0,