+static inline int
+minstrel_get_duration(int index)
+{
+ const struct mcs_group *group = &minstrel_mcs_groups[index / MCS_GROUP_RATES];
+ unsigned int duration = group->duration[index % MCS_GROUP_RATES];
+ return duration << group->shift;
+}
+
+static bool
+minstrel_ht_probe_group(struct minstrel_ht_sta *mi, const struct mcs_group *tp_group,
+ int tp_idx, const struct mcs_group *group)
+{
+ if (group->bw < tp_group->bw)
+ return false;
+
+ if (group->streams == tp_group->streams)
+ return true;
+
+ if (tp_idx < 4 && group->streams == tp_group->streams - 1)
+ return true;
+
+ return group->streams == tp_group->streams + 1;
+}
+
+static void
+minstrel_ht_find_probe_rates(struct minstrel_ht_sta *mi, u16 *rates, int *n_rates,
+ bool faster_rate)
+{
+ const struct mcs_group *group, *tp_group;
+ int i, g, max_dur;
+ int tp_idx;
+
+ tp_group = &minstrel_mcs_groups[mi->max_tp_rate[0] / MCS_GROUP_RATES];
+ tp_idx = mi->max_tp_rate[0] % MCS_GROUP_RATES;
+
+ max_dur = minstrel_get_duration(mi->max_tp_rate[0]);
+ if (faster_rate)
+ max_dur -= max_dur / 16;
+
+ for (g = 0; g < MINSTREL_GROUPS_NB; g++) {
+ u16 supported = mi->supported[g];
+
+ if (!supported)
+ continue;
+
+ group = &minstrel_mcs_groups[g];
+ if (!minstrel_ht_probe_group(mi, tp_group, tp_idx, group))
+ continue;
+
+ for (i = 0; supported; supported >>= 1, i++) {
+ int idx;
+
+ if (!(supported & 1))
+ continue;
+
+ if ((group->duration[i] << group->shift) > max_dur)
+ continue;
+
+ idx = g * MCS_GROUP_RATES + i;
+ if (idx == mi->max_tp_rate[0])
+ continue;
+
+ rates[(*n_rates)++] = idx;
+ break;
+ }
+ }
+}
+
+static void
+minstrel_ht_rate_sample_switch(struct minstrel_priv *mp,
+ struct minstrel_ht_sta *mi)
+{
+ struct minstrel_rate_stats *mrs;
+ u16 rates[MINSTREL_GROUPS_NB];
+ int n_rates = 0;
+ int probe_rate = 0;
+ bool faster_rate;
+ int i;
+ u8 random;
+
+ /*
+ * Use rate switching instead of probing packets for devices with
+ * little control over retry fallback behavior
+ */
+ if (mp->hw->max_rates > 1)
+ return;
+
+ /*
+ * If the current EWMA prob is >75%, look for a rate that's 6.25%
+ * faster than the max tp rate.
+ * If that fails, look again for a rate that is at least as fast
+ */
+ mrs = minstrel_get_ratestats(mi, mi->max_tp_rate[0]);
+ faster_rate = mrs->prob_ewma > MINSTREL_FRAC(75, 100);
+ minstrel_ht_find_probe_rates(mi, rates, &n_rates, faster_rate);
+ if (!n_rates && faster_rate)
+ minstrel_ht_find_probe_rates(mi, rates, &n_rates, false);
+
+ /* If no suitable rate was found, try to pick the next one in the group */
+ if (!n_rates) {
+ int g_idx = mi->max_tp_rate[0] / MCS_GROUP_RATES;
+ u16 supported = mi->supported[g_idx];
+
+ supported >>= mi->max_tp_rate[0] % MCS_GROUP_RATES;
+ for (i = 0; supported; i++) {
+ if (!(supported & 1))
+ continue;
+
+ probe_rate = mi->max_tp_rate[0] + i;
+ goto out;
+ }
+
+ return;
+ }
+
+ i = 0;
+ if (n_rates > 1) {
+ random = prandom_u32();
+ i = random % n_rates;
+ }
+ probe_rate = rates[i];
+
+out:
+ mi->sample_rate = probe_rate;
+ mi->sample_mode = MINSTREL_SAMPLE_ACTIVE;
+}
+