import { NDArray, array, asarray, empty, ndindex, shallow_array_equal, slice, tester } from './core.mjs';
/**
 * Take elements from an array, optionally along a single axis (modelled on `numpy.take`).
 *
 * With `axis == null` the source is indexed as if flattened and the result has
 * the shape of `indices`. Otherwise the result shape is `a.shape` with the
 * entry at `axis` replaced by the whole of `indices.shape`.
 *
 * @param {NDArray} a - source array; passed through `asarray`
 * @param {any[]} indices - indices to select; passed through `array`
 * @param {null|number} [axis] - axis to take along (negative values count from the end); null indexes the flattened array
 * @param {NDArray} [out] - optional destination whose shape must equal the result shape
 * @param {string} [mode] - handling of out-of-range indices: 'raise' (default), 'wrap' or 'clip'
 * @returns {NDArray} `out` (allocated with `empty` when not supplied), filled with the selected elements
 * @throws {string} if `axis` is out of range, if `out` has the wrong shape, or as thrown by index normalization
 */
export function take(a, indices, axis = null, out = null, mode = 'raise') {
  a = asarray(a);
  indices = array(indices);
  // NOTE(review): the normalized indices are written back into indices.data.
  // If array() returns an NDArray argument without copying, this modifies the
  // caller's indices — confirm array() always copies.
  if (axis == null) {
    indices.data = _indices(indices.data, mode, a.size);
    const newshape = indices.shape;
    if (out == null) out = empty(newshape);
    else if (!shallow_array_equal(out.shape, newshape))
      throw 'output array does not match result of ndarray.take';
    for (let i = 0; i < indices.size; i++) {
      // presumably a.item(k) reads flat element k of `a` — TODO confirm in core.mjs
      out.data[i] = a.item(indices.data[i]);
    }
    return out;
  }
  // Reject axes outside [-ndim, ndim) up front; otherwise a.shape[axis] is
  // undefined and index normalization silently produces garbage.
  if (axis < -a.ndim || axis >= a.ndim)
    throw `axis ${axis} is out of bounds for array of dimension ${a.ndim}`;
  if (axis < 0) axis += a.ndim;
  indices.data = _indices(indices.data, mode, a.shape[axis], axis);
  const newshape = a.shape.slice();
  newshape.splice(axis, 1, ...indices.shape);
  if (out == null) out = empty(newshape);
  else if (!shallow_array_equal(out.shape, newshape))
    throw 'output array does not match result of ndarray.take';
  // One selector per leading axis; presumably slice() with no arguments
  // selects a whole axis — verify against core.mjs.
  const rest = Array(axis).fill(slice());
  for (const index of ndindex(indices.shape)) {
    // Copy the sub-array of `a` at the chosen position along `axis` into the
    // region of `out` addressed by this multi-index into `indices`.
    out.get(rest.concat(index)).set(a.get([...rest, indices.item(index)]));
  }
  return out;
}
/**
 * Clamp `n` into the range [min, max] using plain `<` / `>` comparisons.
 * A value that compares false both ways (e.g. NaN) is returned unchanged.
 * The `min` bound is checked first.
 */
function clip(n, min, max) {
  if (n < min) {
    return min;
  }
  if (n > max) {
    return max;
  }
  return n;
}
/**
 * Normalize raw indices against an axis of length `size` according to `mode`.
 *
 * - 'wrap':  each index is reduced modulo `size` into [0, size).
 * - 'clip':  each index is clamped into [0, size - 1].
 * - 'raise': each index must lie in [-size, size); negative ones are shifted by `size`.
 *
 * @param {Iterable<number>} indices - raw index values (plain array or typed array)
 * @param {string} mode - 'wrap', 'clip' or 'raise'
 * @param {number} size - length of the axis being indexed
 * @param {number} [axis] - axis number, used only in the out-of-bounds message
 * @returns {number[]} a new plain array of normalized indices
 * @throws {string} on an unknown mode, on a non-empty take from an empty axis,
 *   or (in 'raise' mode) on an out-of-bounds index
 */
function _indices(indices, mode, size, axis = 0) {
  let normalize;
  switch (mode) {
    case 'wrap':
      normalize = (index) => {
        const remainder = index % size;
        // `%` keeps the sign of the dividend, so lift negatives into range.
        return remainder < 0 ? remainder + size : remainder;
      };
      break;
    case 'clip':
      normalize = (index) => clip(index, 0, size - 1);
      break;
    case 'raise':
      normalize = (index) => {
        if (index < -size || index >= size)
          throw `index ${index} is out of bounds for axis ${axis} with size ${size}`;
        // Negative indices count from the end of this axis (length `size`).
        return index < 0 ? index + size : index;
      };
      break;
    default:
      throw `unexpected mode = ${mode}`;
  }
  const raw = Array.from(indices);
  // With an empty axis, 'wrap' would yield NaN and 'clip' would yield -1;
  // fail loudly instead, as numpy does.
  if (size === 0 && raw.length > 0)
    throw 'cannot do a non-empty take from an empty axes.';
  return raw.map((index) => normalize(index));
}
// Self-tests for `take`, registered on the shared `tester` unless
// process.env.PRODUCTION is set. Each case is [name, compute, expected].
if (!process.env.PRODUCTION) {
  const takeCases = [
    ['take', () => take([4, 3, 5, 7, 6, 8], [0, 1, 4]), () => array([4, 3, 6])],
    [
      'take',
      () =>
        take(
          [4, 3, 5, 7, 6, 8],
          [
            [0, 1],
            [2, 3],
          ]
        ),
      () =>
        array([
          [4, 3],
          [5, 7],
        ]),
    ],
    [
      'take',
      () => {
        const source = array([
          [1, 2],
          [3, 4],
          [5, 6],
          [7, 8],
        ]);
        const picks = [
          [0, 1],
          [2, 3],
        ];
        return take(source, picks, 0);
      },
      () =>
        array([
          [
            [1, 2],
            [3, 4],
          ],
          [
            [5, 6],
            [7, 8],
          ],
        ]),
    ],
    [
      'take',
      () => {
        const source = array([
          [1, 2],
          [3, 4],
          [5, 6],
          [7, 8],
        ]);
        const picks = array([
          [
            [
              [0, 1],
              [0, 1],
            ],
          ],
        ]);
        return take(source, picks, 1);
      },
      () =>
        array([
          [
            [
              [
                [1, 2],
                [1, 2],
              ],
            ],
          ],
          [
            [
              [
                [3, 4],
                [3, 4],
              ],
            ],
          ],
          [
            [
              [
                [5, 6],
                [5, 6],
              ],
            ],
          ],
          [
            [
              [
                [7, 8],
                [7, 8],
              ],
            ],
          ],
        ]),
    ],
    [
      'take',
      () => {
        const source = [
          [5, 6, 2, 7, 1],
          [4, 9, 2, 9, 3],
        ];
        return take(source, [0, 4], 1);
      },
      () => [
        [5, 1],
        [4, 3],
      ],
    ],
  ];
  // Fold the cases onto `tester` exactly as a chain of `.add(...)` calls would.
  takeCases.reduce(
    (registry, [name, compute, expected]) => registry.add(name, compute, expected),
    tester
  );
}