axdriver/
virtio.rs

1use core::marker::PhantomData;
2use core::ptr::NonNull;
3
4use axalloc::global_allocator;
5use axdriver_base::{BaseDriverOps, DevResult, DeviceType};
6use axdriver_virtio::{BufferDirection, PhysAddr, VirtIoHal};
7use axhal::mem::{phys_to_virt, virt_to_phys};
8use cfg_if::cfg_if;
9
10use crate::{AxDeviceEnum, drivers::DriverProbe};
11
// Select the VirtIO transport type according to the configured bus.
// `cfg_if!` takes the first matching branch, so `bus = "pci"` wins if both are set.
// NOTE(review): if neither `bus = "pci"` nor `bus = "mmio"` is set, `VirtIoTransport`
// is never defined and every use of it below fails to compile.
cfg_if! {
    if #[cfg(bus = "pci")] {
        // PCI probing (see `probe_pci`) needs the bus root and device-function descriptors.
        use axdriver_pci::{PciRoot, DeviceFunction, DeviceFunctionInfo};
        type VirtIoTransport = axdriver_virtio::PciTransport;
    } else if #[cfg(bus =  "mmio")] {
        type VirtIoTransport = axdriver_virtio::MmioTransport;
    }
}
20
21/// A trait for VirtIO device meta information.
22pub trait VirtIoDevMeta {
23    const DEVICE_TYPE: DeviceType;
24
25    type Device: BaseDriverOps;
26    type Driver = VirtIoDriver<Self>;
27
28    fn try_new(transport: VirtIoTransport) -> DevResult<AxDeviceEnum>;
29}
30
31cfg_if! {
32    if #[cfg(net_dev = "virtio-net")] {
33        pub struct VirtIoNet;
34
35        impl VirtIoDevMeta for VirtIoNet {
36            const DEVICE_TYPE: DeviceType = DeviceType::Net;
37            type Device = axdriver_virtio::VirtIoNetDev<VirtIoHalImpl, VirtIoTransport, 64>;
38
39            fn try_new(transport: VirtIoTransport) -> DevResult<AxDeviceEnum> {
40                Ok(AxDeviceEnum::from_net(Self::Device::try_new(transport)?))
41            }
42        }
43    }
44}
45
46cfg_if! {
47    if #[cfg(block_dev = "virtio-blk")] {
48        pub struct VirtIoBlk;
49
50        impl VirtIoDevMeta for VirtIoBlk {
51            const DEVICE_TYPE: DeviceType = DeviceType::Block;
52            type Device = axdriver_virtio::VirtIoBlkDev<VirtIoHalImpl, VirtIoTransport>;
53
54            fn try_new(transport: VirtIoTransport) -> DevResult<AxDeviceEnum> {
55                Ok(AxDeviceEnum::from_block(Self::Device::try_new(transport)?))
56            }
57        }
58    }
59}
60
61cfg_if! {
62    if #[cfg(display_dev = "virtio-gpu")] {
63        pub struct VirtIoGpu;
64
65        impl VirtIoDevMeta for VirtIoGpu {
66            const DEVICE_TYPE: DeviceType = DeviceType::Display;
67            type Device = axdriver_virtio::VirtIoGpuDev<VirtIoHalImpl, VirtIoTransport>;
68
69            fn try_new(transport: VirtIoTransport) -> DevResult<AxDeviceEnum> {
70                Ok(AxDeviceEnum::from_display(Self::Device::try_new(transport)?))
71            }
72        }
73    }
74}
75
76/// A common driver for all VirtIO devices that implements [`DriverProbe`].
77pub struct VirtIoDriver<D: VirtIoDevMeta + ?Sized>(PhantomData<D>);
78
79impl<D: VirtIoDevMeta> DriverProbe for VirtIoDriver<D> {
80    #[cfg(bus = "mmio")]
81    fn probe_mmio(mmio_base: usize, mmio_size: usize) -> Option<AxDeviceEnum> {
82        let base_vaddr = phys_to_virt(mmio_base.into());
83        if let Some((ty, transport)) =
84            axdriver_virtio::probe_mmio_device(base_vaddr.as_mut_ptr(), mmio_size)
85            && ty == D::DEVICE_TYPE
86        {
87            match D::try_new(transport) {
88                Ok(dev) => return Some(dev),
89                Err(e) => {
90                    warn!(
91                        "failed to initialize MMIO device at [PA:{:#x}, PA:{:#x}): {:?}",
92                        mmio_base,
93                        mmio_base + mmio_size,
94                        e
95                    );
96                    return None;
97                }
98            }
99        }
100        None
101    }
102
103    #[cfg(bus = "pci")]
104    fn probe_pci(
105        root: &mut PciRoot,
106        bdf: DeviceFunction,
107        dev_info: &DeviceFunctionInfo,
108    ) -> Option<AxDeviceEnum> {
109        if dev_info.vendor_id != 0x1af4 {
110            return None;
111        }
112        match (D::DEVICE_TYPE, dev_info.device_id) {
113            (DeviceType::Net, 0x1000) | (DeviceType::Net, 0x1041) => {}
114            (DeviceType::Block, 0x1001) | (DeviceType::Block, 0x1042) => {}
115            (DeviceType::Display, 0x1050) => {}
116            _ => return None,
117        }
118
119        if let Some((ty, transport)) =
120            axdriver_virtio::probe_pci_device::<VirtIoHalImpl>(root, bdf, dev_info)
121        {
122            if ty == D::DEVICE_TYPE {
123                match D::try_new(transport) {
124                    Ok(dev) => return Some(dev),
125                    Err(e) => {
126                        warn!(
127                            "failed to initialize PCI device at {}({}): {:?}",
128                            bdf, dev_info, e
129                        );
130                        return None;
131                    }
132                }
133            }
134        }
135        None
136    }
137}
138
139pub struct VirtIoHalImpl;
140
141unsafe impl VirtIoHal for VirtIoHalImpl {
142    fn dma_alloc(pages: usize, _direction: BufferDirection) -> (PhysAddr, NonNull<u8>) {
143        let vaddr = if let Ok(vaddr) = global_allocator().alloc_pages(pages, 0x1000) {
144            vaddr
145        } else {
146            return (0, NonNull::dangling());
147        };
148        let paddr = virt_to_phys(vaddr.into());
149        let ptr = NonNull::new(vaddr as _).unwrap();
150        (paddr.as_usize(), ptr)
151    }
152
153    unsafe fn dma_dealloc(_paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 {
154        global_allocator().dealloc_pages(vaddr.as_ptr() as usize, pages);
155        0
156    }
157
158    #[inline]
159    unsafe fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
160        NonNull::new(phys_to_virt(paddr.into()).as_mut_ptr()).unwrap()
161    }
162
163    #[inline]
164    unsafe fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
165        let vaddr = buffer.as_ptr() as *mut u8 as usize;
166        virt_to_phys(vaddr.into()).into()
167    }
168
169    #[inline]
170    unsafe fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {}
171}