Skip to content

Commit

Permalink
language_model_selector: Refresh the models when the providers change (
Browse files Browse the repository at this point in the history
…#22624)

This PR fixes an issue introduced in #21939 where the list of models in
the language model selector could be outdated.

Since we're no longer recreating the picker each render, we now need to
make sure we are updating the list of models accordingly when there are
changes to the language model providers.

I noticed it specifically in Assistant1.

Release Notes:

- Fixed a staleness issue with the language model selector.
  • Loading branch information
maxdeviant authored Jan 3, 2025
1 parent e4eef72 commit 04518b1
Showing 1 changed file with 56 additions and 21 deletions.
77 changes: 56 additions & 21 deletions crates/language_model_selector/src/language_model_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::sync::Arc;

use feature_flags::ZedPro;
use gpui::{
Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Task,
View, WeakView,
Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Model,
Subscription, Task, View, WeakView,
};
use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
use picker::{Picker, PickerDelegate};
Expand All @@ -17,6 +17,10 @@ type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &AppContext) + 'static>

pub struct LanguageModelSelector {
picker: View<Picker<LanguageModelPickerDelegate>>,
/// The task used to update the picker's matches when there is a change to
/// the language model registry.
update_matches_task: Option<Task<()>>,
_subscriptions: Vec<Subscription>,
}

impl LanguageModelSelector {
Expand All @@ -26,7 +30,51 @@ impl LanguageModelSelector {
) -> Self {
let on_model_changed = Arc::new(on_model_changed);

let all_models = LanguageModelRegistry::global(cx)
let all_models = Self::all_models(cx);
let delegate = LanguageModelPickerDelegate {
language_model_selector: cx.view().downgrade(),
on_model_changed: on_model_changed.clone(),
all_models: all_models.clone(),
filtered_models: all_models,
selected_index: 0,
};

let picker =
cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));

LanguageModelSelector {
picker,
update_matches_task: None,
_subscriptions: vec![cx.subscribe(
&LanguageModelRegistry::global(cx),
Self::handle_language_model_registry_event,
)],
}
}

fn handle_language_model_registry_event(
&mut self,
_registry: Model<LanguageModelRegistry>,
event: &language_model::Event,
cx: &mut ViewContext<Self>,
) {
match event {
language_model::Event::ProviderStateChanged
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
let task = self.picker.update(cx, |this, cx| {
let query = this.query(cx);
this.delegate.all_models = Self::all_models(cx);
this.delegate.update_matches(query, cx)
});
self.update_matches_task = Some(task);
}
_ => {}
}
}

fn all_models(cx: &AppContext) -> Vec<ModelInfo> {
LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
Expand All @@ -44,20 +92,7 @@ impl LanguageModelSelector {
}
})
})
.collect::<Vec<_>>();

let delegate = LanguageModelPickerDelegate {
language_model_selector: cx.view().downgrade(),
on_model_changed: on_model_changed.clone(),
all_models: all_models.clone(),
filtered_models: all_models,
selected_index: 0,
};

let picker =
cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));

LanguageModelSelector { picker }
.collect::<Vec<_>>()
}
}

Expand Down Expand Up @@ -152,25 +187,25 @@ impl PickerDelegate for LanguageModelPickerDelegate {

let llm_registry = LanguageModelRegistry::global(cx);

let configured_models: Vec<_> = llm_registry
let configured_providers = llm_registry
.read(cx)
.providers()
.iter()
.filter(|provider| provider.is_authenticated(cx))
.map(|provider| provider.id())
.collect();
.collect::<Vec<_>>();

cx.spawn(|this, mut cx| async move {
let filtered_models = cx
.background_executor()
.spawn(async move {
let displayed_models = if configured_models.is_empty() {
let displayed_models = if configured_providers.is_empty() {
all_models
} else {
all_models
.into_iter()
.filter(|model_info| {
configured_models.contains(&model_info.model.provider_id())
configured_providers.contains(&model_info.model.provider_id())
})
.collect::<Vec<_>>()
};
Expand Down

0 comments on commit 04518b1

Please sign in to comment.