Commit 318c3cc8 authored by Miguel Ojeda's avatar Miguel Ojeda
Browse files

rust: alloc: vec: Add some try_* methods we need



Add some missing fallible methods that we need.

They are all marked as:

    #[stable(feature = "kernel", since = "1.0.0")]

for easy identification.

Lina: Extracted from commit 487d7578bd03 ("rust: alloc: add some `try_*`
methods we need") in rust-for-linux/rust.

Signed-off-by: default avatarMiguel Ojeda <ojeda@kernel.org>
Signed-off-by: default avatarAsahi Lina <lina@asahilina.net>
Link: https://github.com/Rust-for-Linux/linux/commit/487d7578bd03
Link: https://lore.kernel.org/r/20230224-rust-vec-v1-4-733b5b5a57c5@asahilina.net


[ Match the non-fallible methods from version 1.62.0, since those
  in commit 487d7578bd03 were written for 1.54.0-beta.1. ]
Signed-off-by: default avatarMiguel Ojeda <ojeda@kernel.org>
parent 3dcb652a
Loading
Loading
Loading
Loading
+134 −3
Original line number Diff line number Diff line
@@ -122,10 +122,8 @@
#[cfg(not(no_global_oom_handling))]
mod spec_from_elem;

#[cfg(not(no_global_oom_handling))]
use self::set_len_on_drop::SetLenOnDrop;

#[cfg(not(no_global_oom_handling))]
mod set_len_on_drop;

#[cfg(not(no_global_oom_handling))]
@@ -149,7 +147,8 @@
#[cfg(not(no_global_oom_handling))]
use self::spec_extend::SpecExtend;

#[cfg(not(no_global_oom_handling))]
use self::spec_extend::TrySpecExtend;

mod spec_extend;

/// A contiguous growable array type, written as `Vec<T>`, short for 'vector'.
@@ -1919,6 +1918,17 @@ unsafe fn append_elements(&mut self, other: *const [T]) {
        self.len += count;
    }

    /// Tries to append elements to `self` from other buffer.
    #[inline]
    unsafe fn try_append_elements(&mut self, other: *const [T]) -> Result<(), TryReserveError> {
        let count = unsafe { (*other).len() };
        self.try_reserve(count)?;
        let len = self.len();
        unsafe { ptr::copy_nonoverlapping(other as *const T, self.as_mut_ptr().add(len), count) };
        self.len += count;
        Ok(())
    }

    /// Removes the specified range from the vector in bulk, returning all
    /// removed elements as an iterator. If the iterator is dropped before
    /// being fully consumed, it drops the remaining removed elements.
@@ -2340,6 +2350,45 @@ pub fn resize(&mut self, new_len: usize, value: T) {
        }
    }

    /// Tries to resize the `Vec` in-place so that `len` is equal to `new_len`.
    ///
    /// If `new_len` is greater than `len`, the `Vec` is extended by the
    /// difference, with each additional slot filled with `value`.
    /// If `new_len` is less than `len`, the `Vec` is simply truncated.
    ///
    /// This method requires `T` to implement [`Clone`],
    /// in order to be able to clone the passed value.
    /// If you need more flexibility (or want to rely on [`Default`] instead of
    /// [`Clone`]), use [`Vec::resize_with`].
    /// If you only need to resize to a smaller size, use [`Vec::truncate`].
    ///
    /// # Examples
    ///
    /// ```
    /// let mut vec = vec!["hello"];
    /// vec.try_resize(3, "world").unwrap();
    /// assert_eq!(vec, ["hello", "world", "world"]);
    ///
    /// let mut vec = vec![1, 2, 3, 4];
    /// vec.try_resize(2, 0).unwrap();
    /// assert_eq!(vec, [1, 2]);
    ///
    /// let mut vec = vec![42];
    /// let result = vec.try_resize(usize::MAX, 0);
    /// assert!(result.is_err());
    /// ```
    #[stable(feature = "kernel", since = "1.0.0")]
    pub fn try_resize(&mut self, new_len: usize, value: T) -> Result<(), TryReserveError> {
        let len = self.len();

        if new_len > len {
            self.try_extend_with(new_len - len, ExtendElement(value))
        } else {
            self.truncate(new_len);
            Ok(())
        }
    }

    /// Clones and appends all elements in a slice to the `Vec`.
    ///
    /// Iterates over the slice `other`, clones each element, and then appends
@@ -2365,6 +2414,30 @@ pub fn extend_from_slice(&mut self, other: &[T]) {
        self.spec_extend(other.iter())
    }

    /// Tries to clone and append all elements in a slice to the `Vec`.
    ///
    /// Iterates over the slice `other`, clones each element, and then appends
    /// it to this `Vec`. The `other` slice is traversed in-order.
    ///
    /// Note that this function is same as [`extend`] except that it is
    /// specialized to work with slices instead. If and when Rust gets
    /// specialization this function will likely be deprecated (but still
    /// available).
    ///
    /// # Examples
    ///
    /// ```
    /// let mut vec = vec![1];
    /// vec.try_extend_from_slice(&[2, 3, 4]).unwrap();
    /// assert_eq!(vec, [1, 2, 3, 4]);
    /// ```
    ///
    /// [`extend`]: Vec::extend
    #[stable(feature = "kernel", since = "1.0.0")]
    pub fn try_extend_from_slice(&mut self, other: &[T]) -> Result<(), TryReserveError> {
        self.try_spec_extend(other.iter())
    }

    /// Copies elements from `src` range to the end of the vector.
    ///
    /// # Panics
@@ -2504,6 +2577,36 @@ fn extend_with<E: ExtendWith<T>>(&mut self, n: usize, mut value: E) {
            // len set by scope guard
        }
    }

    /// Try to extend the vector by `n` values, using the given generator.
    fn try_extend_with<E: ExtendWith<T>>(&mut self, n: usize, mut value: E) -> Result<(), TryReserveError> {
        self.try_reserve(n)?;

        unsafe {
            let mut ptr = self.as_mut_ptr().add(self.len());
            // Use SetLenOnDrop to work around bug where compiler
            // might not realize the store through `ptr` through self.set_len()
            // don't alias.
            let mut local_len = SetLenOnDrop::new(&mut self.len);

            // Write all elements except the last one
            for _ in 1..n {
                ptr::write(ptr, value.next());
                ptr = ptr.offset(1);
                // Increment the length in every step in case next() panics
                local_len.increment_len(1);
            }

            if n > 0 {
                // We can write the last element directly without cloning needlessly
                ptr::write(ptr, value.last());
                local_len.increment_len(1);
            }

            // len set by scope guard
            Ok(())
        }
    }
}

impl<T: PartialEq, A: Allocator> Vec<T, A> {
@@ -2838,6 +2941,34 @@ fn extend_desugared<I: Iterator<Item = T>>(&mut self, mut iterator: I) {
        }
    }

    // leaf method to which various SpecFrom/SpecExtend implementations delegate when
    // they have no further optimizations to apply
    fn try_extend_desugared<I: Iterator<Item = T>>(&mut self, mut iterator: I) -> Result<(), TryReserveError> {
        // This is the case for a general iterator.
        //
        // This function should be the moral equivalent of:
        //
        //      for item in iterator {
        //          self.push(item);
        //      }
        while let Some(element) = iterator.next() {
            let len = self.len();
            if len == self.capacity() {
                let (lower, _) = iterator.size_hint();
                self.try_reserve(lower.saturating_add(1))?;
            }
            unsafe {
                ptr::write(self.as_mut_ptr().add(len), element);
                // Since next() executes user code which can panic we have to bump the length
                // after each step.
                // NB can't overflow since we would have had to alloc the address space
                self.set_len(len + 1);
            }
        }

        Ok(())
    }

    /// Creates a splicing iterator that replaces the specified range in the vector
    /// with the given `replace_with` iterator and yields the removed items.
    /// `replace_with` does not need to be the same length as `range`.
+85 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: Apache-2.0 OR MIT

use crate::alloc::Allocator;
use crate::collections::{TryReserveError, TryReserveErrorKind};
use core::iter::TrustedLen;
use core::ptr::{self};
use core::slice::{self};
@@ -8,10 +9,17 @@
use super::{IntoIter, SetLenOnDrop, Vec};

// Specialization trait used for Vec::extend
#[cfg(not(no_global_oom_handling))]
pub(super) trait SpecExtend<T, I> {
    fn spec_extend(&mut self, iter: I);
}

// Specialization trait used for Vec::try_extend
pub(super) trait TrySpecExtend<T, I> {
    fn try_spec_extend(&mut self, iter: I) -> Result<(), TryReserveError>;
}

#[cfg(not(no_global_oom_handling))]
impl<T, I, A: Allocator> SpecExtend<T, I> for Vec<T, A>
where
    I: Iterator<Item = T>,
@@ -21,6 +29,16 @@ impl<T, I, A: Allocator> SpecExtend<T, I> for Vec<T, A>
    }
}

impl<T, I, A: Allocator> TrySpecExtend<T, I> for Vec<T, A>
where
    I: Iterator<Item = T>,
{
    default fn try_spec_extend(&mut self, iter: I) -> Result<(), TryReserveError> {
        self.try_extend_desugared(iter)
    }
}

#[cfg(not(no_global_oom_handling))]
impl<T, I, A: Allocator> SpecExtend<T, I> for Vec<T, A>
where
    I: TrustedLen<Item = T>,
@@ -59,6 +77,41 @@ impl<T, I, A: Allocator> SpecExtend<T, I> for Vec<T, A>
    }
}

impl<T, I, A: Allocator> TrySpecExtend<T, I> for Vec<T, A>
where
    I: TrustedLen<Item = T>,
{
    default fn try_spec_extend(&mut self, iterator: I) -> Result<(), TryReserveError> {
        // This is the case for a TrustedLen iterator.
        let (low, high) = iterator.size_hint();
        if let Some(additional) = high {
            debug_assert_eq!(
                low,
                additional,
                "TrustedLen iterator's size hint is not exact: {:?}",
                (low, high)
            );
            self.try_reserve(additional)?;
            unsafe {
                let mut ptr = self.as_mut_ptr().add(self.len());
                let mut local_len = SetLenOnDrop::new(&mut self.len);
                iterator.for_each(move |element| {
                    ptr::write(ptr, element);
                    ptr = ptr.offset(1);
                    // Since the loop executes user code which can panic we have to bump the pointer
                    // after each step.
                    // NB can't overflow since we would have had to alloc the address space
                    local_len.increment_len(1);
                });
            }
            Ok(())
        } else {
            Err(TryReserveErrorKind::CapacityOverflow.into())
        }
    }
}

#[cfg(not(no_global_oom_handling))]
impl<T, A: Allocator> SpecExtend<T, IntoIter<T>> for Vec<T, A> {
    fn spec_extend(&mut self, mut iterator: IntoIter<T>) {
        unsafe {
@@ -68,6 +121,17 @@ fn spec_extend(&mut self, mut iterator: IntoIter<T>) {
    }
}

impl<T, A: Allocator> TrySpecExtend<T, IntoIter<T>> for Vec<T, A> {
    fn try_spec_extend(&mut self, mut iterator: IntoIter<T>) -> Result<(), TryReserveError> {
        unsafe {
            self.try_append_elements(iterator.as_slice() as _)?;
        }
        iterator.forget_remaining_elements();
        Ok(())
    }
}

#[cfg(not(no_global_oom_handling))]
impl<'a, T: 'a, I, A: Allocator + 'a> SpecExtend<&'a T, I> for Vec<T, A>
where
    I: Iterator<Item = &'a T>,
@@ -78,6 +142,17 @@ impl<'a, T: 'a, I, A: Allocator + 'a> SpecExtend<&'a T, I> for Vec<T, A>
    }
}

impl<'a, T: 'a, I, A: Allocator + 'a> TrySpecExtend<&'a T, I> for Vec<T, A>
where
    I: Iterator<Item = &'a T>,
    T: Clone,
{
    default fn try_spec_extend(&mut self, iterator: I) -> Result<(), TryReserveError> {
        self.try_spec_extend(iterator.cloned())
    }
}

#[cfg(not(no_global_oom_handling))]
impl<'a, T: 'a, A: Allocator + 'a> SpecExtend<&'a T, slice::Iter<'a, T>> for Vec<T, A>
where
    T: Copy,
@@ -87,3 +162,13 @@ fn spec_extend(&mut self, iterator: slice::Iter<'a, T>) {
        unsafe { self.append_elements(slice) };
    }
}

impl<'a, T: 'a, A: Allocator + 'a> TrySpecExtend<&'a T, slice::Iter<'a, T>> for Vec<T, A>
where
    T: Copy,
{
    fn try_spec_extend(&mut self, iterator: slice::Iter<'a, T>) -> Result<(), TryReserveError> {
        let slice = iterator.as_slice();
        unsafe { self.try_append_elements(slice) }
    }
}