Module: Mesa
Branch: main
Commit: 9e84e9e44b111a6afe8a346fb0bb74f9c597af61
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=9e84e9e44b111a6afe8a346fb0bb74f9c597af61

Author: Faith Ekstrand <[email protected]>
Date:   Wed Nov 22 14:32:21 2023 -0600

nak: Add base support for 8 and 16-bit types

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26348>

---

 src/nouveau/compiler/nak_builder.rs  |  38 +++++
 src/nouveau/compiler/nak_from_nir.rs | 259 +++++++++++++++++++++++++++--------
 2 files changed, 237 insertions(+), 60 deletions(-)

diff --git a/src/nouveau/compiler/nak_builder.rs b/src/nouveau/compiler/nak_builder.rs
index 585bde11d60..27211f72f90 100644
--- a/src/nouveau/compiler/nak_builder.rs
+++ b/src/nouveau/compiler/nak_builder.rs
@@ -233,6 +233,44 @@ pub trait SSABuilder: Builder {
         dst
     }
 
+    fn prmt4(&mut self, src: [Src;4], sel: [u8;4]) -> SSARef {
+        let max_sel = *sel.iter().max().unwrap();
+        if max_sel < 8 {
+            self.prmt(src[0], src[1], sel)
+        } else if max_sel < 12 {
+            let mut sel_a = [0_u8; 4];
+            let mut sel_b = [0_u8; 4];
+            for i in 0..4_u8 {
+                if sel[usize::from(i)] < 8 {
+                    sel_a[usize::from(i)] = sel[usize::from(i)];
+                    sel_b[usize::from(i)] = i;
+                } else {
+                    sel_b[usize::from(i)] = (sel[usize::from(i)] - 8) + 4;
+                }
+            }
+            let a = self.prmt(src[0], src[1], sel_a);
+            self.prmt(a.into(), src[2], sel_b)
+        } else if max_sel < 16 {
+            let mut sel_a = [0_u8; 4];
+            let mut sel_b = [0_u8; 4];
+            let mut sel_c = [0_u8; 4];
+            for i in 0..4_u8 {
+                if sel[usize::from(i)] < 8 {
+                    sel_a[usize::from(i)] = sel[usize::from(i)];
+                    sel_c[usize::from(i)] = i;
+                } else {
+                    sel_b[usize::from(i)] = sel[usize::from(i)] - 8;
+                    sel_c[usize::from(i)] = 4 + i;
+                }
+            }
+            let a = self.prmt(src[0], src[1], sel_a);
+            let b = self.prmt(src[2], src[3], sel_b);
+            self.prmt(a.into(), b.into(), sel_c)
+        } else {
+            panic!("Invalid permute value: {max_sel}");
+        }
+    }
+
     fn sel(&mut self, cond: Src, x: Src, y: Src) -> SSARef {
         assert!(cond.src_ref.is_predicate());
         assert!(x.is_predicate() == y.is_predicate());
diff --git a/src/nouveau/compiler/nak_from_nir.rs b/src/nouveau/compiler/nak_from_nir.rs
index 69d82252615..26ca4a7e2bf 100644
--- a/src/nouveau/compiler/nak_from_nir.rs
+++ b/src/nouveau/compiler/nak_from_nir.rs
@@ -243,12 +243,21 @@ impl<'a> ShaderFromNir<'a> {
             .or_insert(vec);
     }
 
-    fn get_ssa_comp(&mut self, def: &nir_def, c: u8) -> SSARef {
+    fn get_ssa_comp(&mut self, def: &nir_def, c: u8) -> (SSARef, u8) {
         let vec = self.get_ssa(def);
         match def.bit_size {
-            1 | 32 => vec[usize::from(c)].into(),
-            64 => [vec[usize::from(c) * 2], vec[usize::from(c) * 2 + 1]].into(),
-            _ => panic!("Unsupported bit size"),
+            1 => (vec[usize::from(c)].into(), 0),
+            8 => (vec[usize::from(c / 4)].into(), c % 4),
+            16 => (vec[usize::from(c / 2)].into(), (c * 2) % 4),
+            32 => (vec[usize::from(c)].into(), 0),
+            64 => {
+                let comps = [
+                    vec[usize::from(c) * 2 + 0],
+                    vec[usize::from(c) * 2 + 1],
+                ];
+                (comps.into(), 0)
+            }
+            _ => panic!("Unsupported bit size: {}", def.bit_size),
         }
     }
 
@@ -271,7 +280,7 @@ impl<'a> ShaderFromNir<'a> {
         if let Some(base_def) = std::ptr::NonNull::new(addr_offset.base.def) {
             let base_def = unsafe { base_def.as_ref() };
             let base_comp = u8::try_from(addr_offset.base.comp).unwrap();
-            let base = self.get_ssa_comp(base_def, base_comp);
+            let (base, _) = self.get_ssa_comp(base_def, base_comp);
             (base.into(), addr_offset.offset)
         } else {
             (SrcRef::Zero.into(), addr_offset.offset)
@@ -296,53 +305,156 @@ impl<'a> ShaderFromNir<'a> {
     }
 
     fn parse_alu(&mut self, b: &mut impl SSABuilder, alu: &nir_alu_instr) {
-        let mut srcs = Vec::new();
-        for (i, alu_src) in alu.srcs_as_slice().iter().enumerate() {
-            let bit_size = alu_src.src.bit_size();
-            let comps = alu.src_components(i.try_into().unwrap());
-
-            let alu_src_ssa = self.get_ssa(&alu_src.src.as_def());
-            let mut src_comps = Vec::new();
-            for c in 0..comps {
-                let s = usize::from(alu_src.swizzle[usize::from(c)]);
-                if bit_size == 1 || bit_size == 32 {
-                    src_comps.push(alu_src_ssa[s]);
-                } else if bit_size == 64 {
-                    src_comps.push(alu_src_ssa[s * 2]);
-                    src_comps.push(alu_src_ssa[s * 2 + 1]);
-                } else {
-                    panic!("Unhandled bit size");
-                }
-            }
-            srcs.push(Src::from(SSARef::try_from(src_comps).unwrap()));
-        }
-
-        /* Handle vectors as a special case since they're the only ALU ops that
-         * can produce more than a 16B of data.
-         */
+        // Handle vectors and pack ops as a special case since they're the only
+        // ALU ops that can produce more than 16B. They are also the only ALU
+        // ops which we allow to consume small (8 and 16-bit) vector data
+        // scattered across multiple dwords
         match alu.op {
-            nir_op_mov | nir_op_vec2 | nir_op_vec3 | nir_op_vec4
+            nir_op_mov
+            | nir_op_pack_32_4x8_split
+            | nir_op_pack_32_2x16_split
+            | nir_op_pack_64_2x32_split
+            | nir_op_vec2 | nir_op_vec3 | nir_op_vec4
             | nir_op_vec5 | nir_op_vec8 | nir_op_vec16 => {
-                let file = if alu.def.bit_size == 1 {
-                    RegFile::Pred
+                let src_bit_size = alu.get_src(0).src.bit_size();
+                let bits = alu.def.num_components * alu.def.bit_size;
+
+                // Collect the sources into a vec with src_bit_size per SSA
+                // value in the vec.  This implicitly makes 64-bit sources look
+                // like two 32-bit values
+                let mut srcs = Vec::new();
+                if alu.op == nir_op_mov {
+                    let src = alu.get_src(0);
+                    for c in 0..alu.def.num_components {
+                        let s = src.swizzle[usize::from(c)];
+                        let (src, byte) =
+                            self.get_ssa_comp(src.src.as_def(), s);
+                        for ssa in src.iter() {
+                            srcs.push((*ssa, byte));
+                        }
+                    }
                 } else {
-                    RegFile::GPR
-                };
+                    for src in alu.srcs_as_slice().iter() {
+                        let s = src.swizzle[0];
+                        let (src, byte) =
+                            self.get_ssa_comp(src.src.as_def(), s);
+                        for ssa in src.iter() {
+                            srcs.push((*ssa, byte));
+                        }
+                    }
+                }
 
-                let mut dst_vec = Vec::new();
-                for src in srcs {
-                    for v in src.as_ssa().unwrap().iter() {
-                        let dst = b.alloc_ssa(file, 1)[0];
-                        b.copy_to(dst.into(), (*v).into());
-                        dst_vec.push(dst);
+                let mut comps = Vec::new();
+                match src_bit_size {
+                    1 | 32 | 64 => {
+                        for (ssa, _) in srcs {
+                            comps.push(ssa);
+                        }
+                    }
+                    8 => {
+                        for dc in 0..bits.div_ceil(32) {
+                            let mut psrc = [Src::new_zero(); 4];
+                            let mut psel = [0_u8; 4];
+
+                            for b in 0..4 {
+                                let sc = usize::from(dc * 4 + b);
+                                if sc < srcs.len() {
+                                    let (ssa, byte) = srcs[sc];
+                                    for i in 0..4_u8 {
+                                        let psrc_i = &mut psrc[usize::from(i)];
+                                        if *psrc_i == Src::new_zero() {
+                                            *psrc_i = ssa.into();
+                                        } else if *psrc_i != Src::from(ssa) {
+                                            continue;
+                                        }
+                                        psel[usize::from(b)] = i * 4 + byte;
+                                    }
+                                }
+                            }
+                            comps.push(b.prmt4(psrc, psel)[0]);
+                        }
+                    }
+                    16 => {
+                        for dc in 0..bits.div_ceil(32) {
+                            let mut psrc = [Src::new_zero(); 2];
+                            let mut psel = [0_u8; 4];
+
+                            for w in 0..2 {
+                                let sc = usize::from(dc * 2 + w);
+                                if sc < srcs.len() {
+                                    let (ssa, byte) = srcs[sc];
+                                    let w_usize = usize::from(w);
+                                    psrc[w_usize] = ssa.into();
+                                    psel[w_usize * 2 + 0] = (w * 4) + byte;
+                                    psel[w_usize * 2 + 1] = (w * 4) + byte + 1;
+                                }
+                            }
+                            comps.push(b.prmt(psrc[0], psrc[1], psel)[0]);
+                        }
                     }
+                    _ => panic!("Unknown bit size: {src_bit_size}"),
                 }
-                self.set_ssa(&alu.def, dst_vec);
+
+                self.set_ssa(&alu.def, comps);
                 return;
             }
             _ => (),
         }
 
+        let mut srcs: Vec<Src> = Vec::new();
+        for (i, alu_src) in alu.srcs_as_slice().iter().enumerate() {
+            let bit_size = alu_src.src.bit_size();
+            let comps = alu.src_components(i.try_into().unwrap());
+            let ssa = self.get_ssa(&alu_src.src.as_def());
+
+            match bit_size {
+                1 => {
+                    assert!(comps == 1);
+                    let s = usize::from(alu_src.swizzle[0]);
+                    srcs.push(ssa[s].into());
+                }
+                8 => {
+                    assert!(comps <= 4);
+                    let s = alu_src.swizzle[0];
+                    let dw = ssa[usize::from(s / 4)];
+
+                    let mut prmt = [4_u8; 4];
+                    for c in 0..comps {
+                        let cs = alu_src.swizzle[usize::from(c)];
+                        assert!(s / 4 == cs / 4);
+                        prmt[usize::from(c)] = cs;
+                    }
+                    srcs.push(b.prmt(dw.into(), 0.into(), prmt).into());
+                }
+                16 => {
+                    assert!(comps <= 2);
+                    let s = alu_src.swizzle[0];
+                    let dw = ssa[usize::from(s / 2)];
+
+                    let mut prmt = [0_u8; 4];
+                    for c in 0..comps {
+                        let cs = alu_src.swizzle[usize::from(c)];
+                        assert!(s / 2 == cs / 2);
+                        prmt[usize::from(c) * 2 + 0] = cs * 2 + 0;
+                        prmt[usize::from(c) * 2 + 1] = cs * 2 + 1;
+                    }
+                    // TODO: Some ops can handle swizzles
+                    srcs.push(b.prmt(dw.into(), 0.into(), prmt).into());
+                }
+                32 => {
+                    assert!(comps == 1);
+                    let s = usize::from(alu_src.swizzle[0]);
+                    srcs.push(ssa[s].into());
+                }
+                64 => {
+                    assert!(comps == 1);
+                    let s = usize::from(alu_src.swizzle[0]);
+                    srcs.push([ssa[s * 2], ssa[s * 2 + 1]].into());
+                }
+                _ => panic!("Invalid bit size: {bit_size}"),
+            }
+        }
+
         let dst: SSARef = match alu.op {
             nir_op_b2b1 => {
                 assert!(alu.get_src(0).bit_size() == 32);
@@ -813,12 +925,6 @@ impl<'a> ShaderFromNir<'a> {
             nir_op_ixor => {
                 b.lop2(LogicOp::new_lut(&|x, y, _| x ^ y), srcs[0], srcs[1])
             }
-            nir_op_pack_64_2x32_split => {
-                let dst = b.alloc_ssa(RegFile::GPR, 2);
-                b.copy_to(dst[0].into(), srcs[0]);
-                b.copy_to(dst[1].into(), srcs[1]);
-                dst
-            }
             nir_op_pack_half_2x16_split => {
                 assert!(alu.get_src(0).bit_size() == 32);
                 let low = b.alloc_ssa(RegFile::GPR, 1);
@@ -867,6 +973,12 @@ impl<'a> ShaderFromNir<'a> {
             nir_op_ult => {
                 b.isetp(IntCmpType::U32, IntCmpOp::Lt, srcs[0], srcs[1])
             }
+            nir_op_unpack_32_2x16_split_x => {
+                b.prmt(srcs[0], 0.into(), [0, 1, 4, 4])
+            }
+            nir_op_unpack_32_2x16_split_y => {
+                b.prmt(srcs[0], 0.into(), [2, 3, 4, 4])
+            }
             nir_op_unpack_64_2x32_split_x => {
                 let src0_x = srcs[0].as_ssa().unwrap()[0];
                 b.copy(src0_x.into())
@@ -2136,22 +2248,49 @@ impl<'a> ShaderFromNir<'a> {
         b: &mut impl SSABuilder,
         load_const: &nir_load_const_instr,
     ) {
-        let mut dst_vec = Vec::new();
-        for c in 0..load_const.def.num_components {
-            if load_const.def.bit_size == 1 {
-                let imm_b1 = unsafe { load_const.values()[c as usize].b };
-                dst_vec.push(b.copy(imm_b1.into())[0]);
-            } else if load_const.def.bit_size == 32 {
-                let imm_u32 = unsafe { load_const.values()[c as usize].u32_ };
-                dst_vec.push(b.copy(imm_u32.into())[0]);
-            } else if load_const.def.bit_size == 64 {
-                let imm_u64 = unsafe { load_const.values()[c as usize].u64_ };
-                dst_vec.push(b.copy((imm_u64 as u32).into())[0]);
-                dst_vec.push(b.copy(((imm_u64 >> 32) as u32).into())[0]);
+        let values = &load_const.values();
+
+        let mut dst = Vec::new();
+        match load_const.def.bit_size {
+            1 => for c in 0..load_const.def.num_components {
+                let imm_b1 = unsafe { values[usize::from(c)].b };
+                dst.push(b.copy(imm_b1.into())[0]);
+            }
+            8 => for dw in 0..load_const.def.num_components.div_ceil(4) {
+                let mut imm_u32 = 0;
+                for b in 0..4 {
+                    let c = dw * 4 + b;
+                    if c < load_const.def.num_components {
+                        let imm_u8 = unsafe { values[usize::from(c)].u8_ };
+                        imm_u32 |= u32::from(imm_u8) << b * 8;
+                    }
+                }
+                dst.push(b.copy(imm_u32.into())[0]);
+            }
+            16 => for dw in 0..load_const.def.num_components.div_ceil(2) {
+                let mut imm_u32 = 0;
+                for w in 0..2 {
+                    let c = dw * 2 + w;
+                    if c < load_const.def.num_components {
+                        let imm_u16 = unsafe { values[usize::from(c)].u16_ };
+                        imm_u32 |= u32::from(imm_u16) << w * 16;
+                    }
+                }
+                dst.push(b.copy(imm_u32.into())[0]);
+            }
+            32 => for c in 0..load_const.def.num_components {
+                let imm_u32 = unsafe { values[usize::from(c)].u32_ };
+                dst.push(b.copy(imm_u32.into())[0]);
+            }
+            64 => for c in 0..load_const.def.num_components {
+                let imm_u64 = unsafe { values[c as usize].u64_ };
+                dst.push(b.copy((imm_u64 as u32).into())[0]);
+                dst.push(b.copy(((imm_u64 >> 32) as u32).into())[0]);
             }
+            _ => panic!("Unknown bit size: {}", load_const.def.bit_size),
         }
 
-        self.set_ssa(&load_const.def, dst_vec);
+        self.set_ssa(&load_const.def, dst);
     }
 
     fn parse_undef(

Reply via email to