Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 42 additions & 51 deletions controllers/nvidiadriver_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,33 @@ func (r *NVIDIADriverReconciler) updateCrStatus(
return nil
}

// enqueueAllNVIDIADrivers builds one reconcile request per NVIDIADriver instance returned by
// r.List. Watch map functions call it so that a single relevant event (e.g. a
// ClusterPolicy/NVIDIADriver update or a node label change) triggers reconciliation of every
// NVIDIADriver instance.
//
// If listing fails, the error is logged and an empty, non-nil slice is returned, so nothing is
// enqueued for that event.
func (r *NVIDIADriverReconciler) enqueueAllNVIDIADrivers(ctx context.Context) []reconcile.Request {
	drivers := &nvidiav1alpha1.NVIDIADriverList{}
	if err := r.List(ctx, drivers); err != nil {
		log.FromContext(ctx).Error(err, "Unable to list NVIDIADriver resources")
		return []reconcile.Request{}
	}

	// Fill the result by index; each request carries only the name/namespace key of a driver.
	requests := make([]reconcile.Request, len(drivers.Items))
	for i := range drivers.Items {
		driver := &drivers.Items[i]
		requests[i] = reconcile.Request{
			NamespacedName: types.NamespacedName{
				Name:      driver.GetName(),
				Namespace: driver.GetNamespace(),
			},
		}
	}

	return requests
}

// SetupWithManager sets up the controller with the Manager.
func (r *NVIDIADriverReconciler) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error {
// Create state manager
Expand Down Expand Up @@ -277,11 +304,17 @@ func (r *NVIDIADriverReconciler) SetupWithManager(ctx context.Context, mgr ctrl.
return err
}

// Watch for changes to the primary resource NVIDIaDriver
// Watch for changes to NVIDIADriver CRs. Whenever an event is generated for a NVIDIADriver CR,
// enqueue a reconcile request for all NVIDIADriver instances.
nvidiaDriverMapFn := func(ctx context.Context, _ *nvidiav1alpha1.NVIDIADriver) []reconcile.Request {
return r.enqueueAllNVIDIADrivers(ctx)
}

// Watch for changes to the primary resource NVIDIADriver
err = c.Watch(source.Kind(
mgr.GetCache(),
&nvidiav1alpha1.NVIDIADriver{},
&handler.TypedEnqueueRequestForObject[*nvidiav1alpha1.NVIDIADriver]{},
handler.TypedEnqueueRequestsFromMapFunc(nvidiaDriverMapFn),
predicate.TypedGenerationChangedPredicate[*nvidiav1alpha1.NVIDIADriver]{},
),
)
Expand All @@ -291,63 +324,21 @@ func (r *NVIDIADriverReconciler) SetupWithManager(ctx context.Context, mgr ctrl.

// Watch for changes to ClusterPolicy. Whenever an event is generated for ClusterPolicy, enqueue
// a reconcile request for all NVIDIADriver instances.
mapFn := func(ctx context.Context, cp *gpuv1.ClusterPolicy) []reconcile.Request {
logger := log.FromContext(ctx)
opts := []client.ListOption{}
list := &nvidiav1alpha1.NVIDIADriverList{}

err := mgr.GetClient().List(ctx, list, opts...)
if err != nil {
logger.Error(err, "Unable to list NVIDIADriver resources")
return []reconcile.Request{}
}

reconcileRequests := []reconcile.Request{}
for _, nvidiaDriver := range list.Items {
reconcileRequests = append(reconcileRequests,
reconcile.Request{
NamespacedName: types.NamespacedName{
Name: nvidiaDriver.GetName(),
Namespace: nvidiaDriver.GetNamespace(),
},
})
}

return reconcileRequests
mapFn := func(ctx context.Context, _ *gpuv1.ClusterPolicy) []reconcile.Request {
return r.enqueueAllNVIDIADrivers(ctx)
}

// Watch for changes to the Nodes. Whenever an event is generated for ClusterPolicy, enqueue
// Watch for changes to the Nodes. Whenever an event is generated for a Node, enqueue
// a reconcile request for all NVIDIADriver instances.
nodeMapFn := func(ctx context.Context, cp *corev1.Node) []reconcile.Request {
logger := log.FromContext(ctx)
opts := []client.ListOption{}
list := &nvidiav1alpha1.NVIDIADriverList{}

err := mgr.GetClient().List(ctx, list, opts...)
if err != nil {
logger.Error(err, "Unable to list NVIDIADriver resources")
return []reconcile.Request{}
}

reconcileRequests := []reconcile.Request{}
for _, nvidiaDriver := range list.Items {
reconcileRequests = append(reconcileRequests,
reconcile.Request{
NamespacedName: types.NamespacedName{
Name: nvidiaDriver.GetName(),
Namespace: nvidiaDriver.GetNamespace(),
},
})
}

return reconcileRequests
nodeMapFn := func(ctx context.Context, _ *corev1.Node) []reconcile.Request {
return r.enqueueAllNVIDIADrivers(ctx)
}

err = c.Watch(
source.Kind(
mgr.GetCache(),
&gpuv1.ClusterPolicy{},
handler.TypedEnqueueRequestsFromMapFunc[*gpuv1.ClusterPolicy](mapFn),
handler.TypedEnqueueRequestsFromMapFunc(mapFn),
predicate.TypedGenerationChangedPredicate[*gpuv1.ClusterPolicy]{},
),
)
Expand Down Expand Up @@ -385,7 +376,7 @@ func (r *NVIDIADriverReconciler) SetupWithManager(ctx context.Context, mgr ctrl.
err = c.Watch(
source.Kind(mgr.GetCache(),
&corev1.Node{},
handler.TypedEnqueueRequestsFromMapFunc[*corev1.Node](nodeMapFn),
handler.TypedEnqueueRequestsFromMapFunc(nodeMapFn),
nodePredicate,
),
)
Expand Down
22 changes: 22 additions & 0 deletions controllers/nvidiadriver_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bytes"
"context"
"errors"
"sort"
"testing"

"github.com/go-logr/logr"
Expand Down Expand Up @@ -193,3 +194,24 @@ func TestReconcile(t *testing.T) {
})
}
}

// TestEnqueueAllNVIDIADrivers checks that enqueueAllNVIDIADrivers returns exactly one reconcile
// request for each NVIDIADriver object registered with the fake client.
func TestEnqueueAllNVIDIADrivers(t *testing.T) {
	testScheme := runtime.NewScheme()
	require.NoError(t, nvidiav1alpha1.AddToScheme(testScheme))

	driverA := &nvidiav1alpha1.NVIDIADriver{ObjectMeta: metav1.ObjectMeta{Name: "driver-a", Namespace: "default"}}
	driverB := &nvidiav1alpha1.NVIDIADriver{ObjectMeta: metav1.ObjectMeta{Name: "driver-b", Namespace: "default"}}
	fakeClient := fake.NewClientBuilder().
		WithScheme(testScheme).
		WithObjects(driverA, driverB).
		Build()

	r := &NVIDIADriverReconciler{Client: fakeClient}
	requests := r.enqueueAllNVIDIADrivers(context.Background())
	require.Len(t, requests, 2)

	// The test sorts the keys before comparing, so it does not depend on the order of the
	// returned requests.
	keys := make([]string, 0, len(requests))
	for _, req := range requests {
		keys = append(keys, req.String())
	}
	sort.Strings(keys)
	require.Equal(t, []string{"default/driver-a", "default/driver-b"}, keys)
}
Loading