819 Stimmen

Wie erhalte ich Indizes von N maximalen Werten in einem NumPy-Array?

NumPy bietet eine Möglichkeit, den Index des maximalen Wertes eines Arrays über np.argmax .

Ich würde gerne etwas Ähnliches machen, aber die Indizes der N Höchstwerte.

Zum Beispiel, wenn ich ein Array habe, [1, 3, 2, 4, 5] entonces nargmax(array, n=3) würde die Indizes zurückgeben [4, 3, 1] die den folgenden Elementen entsprechen [5, 4, 3] .

1001voto

Fred Foo Punkte 341230

Neuere NumPy-Versionen (1.8 und höher) haben eine Funktion namens argpartition für diese. Um die Indizes der vier größten Elemente zu erhalten, gehen Sie wie folgt vor

>>> a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
>>> a
array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])

>>> ind = np.argpartition(a, -4)[-4:]
>>> ind
array([1, 5, 8, 0])

>>> top4 = a[ind]
>>> top4
array([4, 9, 6, 9])

Anders als argsort läuft diese Funktion im ungünstigsten Fall in linearer Zeit, aber die zurückgegebenen Indizes sind nicht sortiert, wie man am Ergebnis der Auswertung von a[ind] . Wenn Sie das auch brauchen, sortieren Sie sie anschließend:

>>> ind[np.argsort(a[ind])]
array([1, 8, 5, 0])

Um die Top- k Elemente in sortierter Reihenfolge auf diese Weise dauert O( n + k Protokoll k ) Zeit.

515voto

NPE Punkte 462670

Die einfachste Lösung, die mir eingefallen ist, lautet:

>>> import numpy as np
>>> arr = np.array([1, 3, 2, 4, 5])
>>> arr.argsort()[-3:][::-1]
array([4, 3, 1])

Dazu ist eine vollständige Sortierung des Arrays erforderlich. Ich frage mich, ob numpy bietet eine eingebaute Möglichkeit, eine partielle Sortierung durchzuführen; bisher konnte ich keine finden.

Wenn sich diese Lösung als zu langsam erweist (insbesondere bei kleinen n ), könnte es sich lohnen, etwas in Cython .

82voto

Ketan Punkte 1297

Das ist noch einfacher:

idx = (-arr).argsort()[:n]

donde n ist die Anzahl der Maximalwerte.

52voto

anishpatel Punkte 1372

Verwendung:

>>> import heapq
>>> import numpy
>>> a = numpy.array([1, 3, 2, 4, 5])
>>> heapq.nlargest(3, range(len(a)), a.take)
[4, 3, 1]

Für reguläre Python-Listen:

>>> a = [1, 3, 2, 4, 5]
>>> heapq.nlargest(3, range(len(a)), a.__getitem__)
[4, 3, 1]

Wenn Sie Python 2 verwenden, benutzen Sie xrange anstelle von range .

Quelle: heapq - Heap-Warteschlangen-Algorithmus

44voto

danvk Punkte 14538

Wenn Sie mit einem mehrdimensionalen Array arbeiten, müssen Sie die Indizes glätten und entflechten:

def largest_indices(ary, n):
    """Returns the n largest indices from a numpy array."""
    flat = ary.flatten()
    indices = np.argpartition(flat, -n)[-n:]
    indices = indices[np.argsort(-flat[indices])]
    return np.unravel_index(indices, ary.shape)

Zum Beispiel:

>>> xs = np.sin(np.arange(9)).reshape((3, 3))
>>> xs
array([[ 0.        ,  0.84147098,  0.90929743],
       [ 0.14112001, -0.7568025 , -0.95892427],
       [-0.2794155 ,  0.6569866 ,  0.98935825]])
>>> largest_indices(xs, 3)
(array([2, 0, 0]), array([2, 2, 1]))
>>> xs[largest_indices(xs, 3)]
array([ 0.98935825,  0.90929743,  0.84147098])

CodeJaeger.com

CodeJaeger ist eine Gemeinschaft für Programmierer, die täglich Hilfe erhalten..
Wir haben viele Inhalte, und Sie können auch Ihre eigenen Fragen stellen oder die Fragen anderer Leute lösen.

Powered by:

X